Recently, I started playing around with PyTorch XLA, mainly for two reasons:

1. The project and API seem really mature now compared to where they were about a year ago. Some widely used PyTorch-based libraries (like HuggingFace Transformers) have even started adding TPU support to their codebases.
2. I'm tired of having to switch to TensorFlow every time I need to use a TPU for heavy lifting (yes, all my friends know I'm a huge PyTorch fanatic).

So I went ahead and learned how to get TPUs to play nicely with PyTorch, writing two notebooks in the process:

• One where I ran language modeling on 8 TPU cores with a small dataset, just to test the waters.
• One where I ran finetuning on 8 TPU cores with a lot of data, specifically the dataset from Jigsaw's 2018 Toxic Comment Classification Challenge on Kaggle.
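Both notebooks follow the standard PyTorch XLA multiprocessing pattern: spawn one process per TPU core, give each its own XLA device, and let the optimizer step all-reduce gradients across cores. Here's a minimal sketch of that pattern; `build_model` and `build_train_loader` are placeholder factories, the hyperparameters are illustrative, and this only runs on a machine with a TPU runtime attached:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each of the 8 processes gets its own XLA device (one TPU core).
    device = xm.xla_device()
    model = build_model().to(device)    # placeholder model factory
    loader = build_train_loader()       # placeholder DataLoader (sharded per core)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    for epoch in range(3):
        # ParallelLoader prefetches batches onto the device in the background.
        para_loader = pl.ParallelLoader(loader, [device])
        for batch in para_loader.per_device_loader(device):
            optimizer.zero_grad()
            loss = model(**batch).loss
            loss.backward()
            # xm.optimizer_step also all-reduces gradients across the cores.
            xm.optimizer_step(optimizer)

if __name__ == "__main__":
    xmp.spawn(_mp_fn, nprocs=8)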

## Language Modeling

Finetuning a DistilGPT2 model for language modeling on the WikiText-2 dataset was a good first experiment. I used a per-core batch size of 8 (it has to be small, since DistilGPT2 takes in a maximum of 1024 tokens at once), but since training is distributed across 8 cores, the effective batch size is 64. That's pretty neat.
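The arithmetic behind that effective batch size is simply the per-core batch multiplied by the number of cores, since each of the 8 processes trains on its own shard of the data:

```python
# Each TPU core processes its own batch of 8, so one distributed
# step covers per-core batch size * number of cores examples.
per_core_batch = 8
num_cores = 8
effective_batch = per_core_batch * num_cores
print(effective_batch)  # → 64
```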

The finetuning code ran for 3 epochs and finished in about 2 minutes, reaching a validation perplexity of ~32. That's fast!

```
Number of training batches: 35
Number of evaluation batches: 3

===========================Epoch 1 of 3===========================
[TRAIN] Iteration    0 | Loss 4.7338 | Time Elapsed 2.97 seconds
[TRAIN] Iteration   20 | Loss 4.1438 | Time Elapsed 79.53 seconds
===========================Epoch 2 of 3===========================
[TRAIN] Iteration    0 | Loss 3.6926 | Time Elapsed 89.75 seconds
[TRAIN] Iteration   20 | Loss 3.7793 | Time Elapsed 103.52 seconds
===========================Epoch 3 of 3===========================
[TRAIN] Iteration    0 | Loss 3.4750 | Time Elapsed 113.67 seconds
[TRAIN] Iteration   20 | Loss 3.6648 | Time Elapsed 127.69 seconds

Finished training 3 epochs in 136.46 seconds.

============================Validation============================
[VALID] Iteration    0 | Loss 3.3704 | Time Elapsed 141.23 seconds

Finished evaluation in 8.59 seconds. Validation Loss: 3.4767 | Validation PPL 32.3538
```
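As a sanity check, the perplexity in the logs is just the exponential of the mean cross-entropy loss, so it can be reproduced directly from the logged loss:

```python
import math

# Validation loss from the run above; perplexity = exp(mean cross-entropy loss).
val_loss = 3.4767
val_ppl = math.exp(val_loss)
print(val_ppl)  # ≈ 32.35, matching the logged PPL
```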

I also evaluated on the test set, getting a test perplexity of ~31.

```
Testing batches: 32
[TEST] Iteration    0 | Loss 3.4673 | Time Elapsed 3.39 seconds
[TEST] Iteration   10 | Loss 3.4809 | Time Elapsed 4.87 seconds
[TEST] Iteration   20 | Loss 3.9778 | Time Elapsed 6.31 seconds
[TEST] Iteration   30 | Loss 3.2503 | Time Elapsed 7.74 seconds

Finished evaluation in 7.98 seconds. Test Loss: 3.4646 | Test PPL 31.9635
```

And then I tested the model by generating some text.

```
Input: this is a brand new

Generation: this is a brand new product that will be available in the US and Europe in the coming months.
```

## Text Classification

I then decided to check how much of a speedup multi-core TPU training gives PyTorch on a sizable dataset. For this purpose, I finetuned a multi-label text classification model for the Toxic Comment Classification Challenge on Kaggle.

We finetuned a DistilBERT model for this purpose, using a per-core batch size of 16 (across 8 cores, that's an effective batch size of 128).

We finished training in about 18 minutes for a total of 3 epochs. In contrast, running the same finetuning setup on an NVIDIA Tesla P100 GPU (batch size of 32, all other settings kept the same) finished in about 2 hours and 21 minutes. That's a 7.83x speedup!
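That speedup figure follows directly from the two wall-clock times:

```python
# Wall-clock comparison: 8 TPU cores vs. a single P100 (times from the post).
tpu_minutes = 18
gpu_minutes = 2 * 60 + 21  # 2 h 21 min = 141 min
speedup = gpu_minutes / tpu_minutes
print(round(speedup, 2))  # → 7.83
```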

```
Number of training batches: 872
Number of evaluation batches: 374

===========================Epoch 1 of 3===========================
[TRAIN] Iteration    0 | Loss 0.7006 | Time Elapsed 4.41 seconds
[TRAIN] Iteration  150 | Loss 0.0825 | Time Elapsed 98.67 seconds
[TRAIN] Iteration  300 | Loss 0.1211 | Time Elapsed 153.47 seconds
[TRAIN] Iteration  450 | Loss 0.1496 | Time Elapsed 208.38 seconds
[TRAIN] Iteration  600 | Loss 0.0523 | Time Elapsed 263.86 seconds
[TRAIN] Iteration  750 | Loss 0.0276 | Time Elapsed 319.04 seconds
===========================Epoch 2 of 3===========================
[TRAIN] Iteration    0 | Loss 0.0946 | Time Elapsed 366.67 seconds
[TRAIN] Iteration  150 | Loss 0.0289 | Time Elapsed 423.45 seconds
[TRAIN] Iteration  300 | Loss 0.0364 | Time Elapsed 479.03 seconds
[TRAIN] Iteration  450 | Loss 0.0785 | Time Elapsed 534.68 seconds
[TRAIN] Iteration  600 | Loss 0.0576 | Time Elapsed 593.00 seconds
[TRAIN] Iteration  750 | Loss 0.0168 | Time Elapsed 648.79 seconds
===========================Epoch 3 of 3===========================
[TRAIN] Iteration    0 | Loss 0.0861 | Time Elapsed 697.14 seconds
[TRAIN] Iteration  150 | Loss 0.0270 | Time Elapsed 753.59 seconds
[TRAIN] Iteration  300 | Loss 0.0323 | Time Elapsed 809.44 seconds
[TRAIN] Iteration  450 | Loss 0.0666 | Time Elapsed 865.38 seconds
[TRAIN] Iteration  600 | Loss 0.0401 | Time Elapsed 921.27 seconds
[TRAIN] Iteration  750 | Loss 0.0122 | Time Elapsed 977.26 seconds

Finished training 3 epochs in 1023.42 seconds.

============================Validation============================
[VALID] Iteration    0 | Loss 0.0582 | Time Elapsed 1033.13 seconds
[VALID] Iteration  150 | Loss 0.0009 | Time Elapsed 1056.29 seconds
[VALID] Iteration  300 | Loss 0.0129 | Time Elapsed 1075.24 seconds

Finished evaluation in 61.76 seconds. Validation AUROC: 0.9874
```

We then ran inference, again using all 8 TPU cores, to predict on the test set. Submitting the predictions to Kaggle got us a score of 0.97847 (mean column-wise ROC AUC), which was 0.01054 points away from the top scorer on the public leaderboard!
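For reference, the competition metric averages the ROC AUC of each toxicity column. Here's a minimal pure-Python sketch of that metric using the rank-sum (Mann-Whitney U) formulation of AUC; the two columns of labels and scores are made-up toy data, not the actual predictions:

```python
def auroc(labels, scores):
    """ROC AUC via the rank-sum (Mann-Whitney U) formulation."""
    # Average 1-based ranks of the scores, with ties sharing the mean rank.
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    ranks = [0.0] * len(scores)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # mean of 1-based ranks i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    rank_sum = sum(r for r, y in zip(ranks, labels) if y == 1)
    return (rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

# Two toy toxicity columns; the submission score is the mean column-wise AUC.
cols = [
    ([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]),  # AUC = 0.75
    ([0, 1, 0, 1], [0.2, 0.9, 0.3, 0.7]),   # AUC = 1.0
]
mean_auc = sum(auroc(y, s) for y, s in cols) / len(cols)
print(mean_auc)  # → 0.875
```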

## Some Final Words

Ultimately, learning PyTorch XLA was a good idea, since I won't have to mix PyTorch and TensorFlow code in my experiments anymore. Given that the models and datasets I normally work with are pretty big, being able to use TPUs with code that's interoperable with the rest of my codebase is a huge boost. It also lessens my cognitive load, since I no longer have to juggle the semantic differences between two frameworks.

With TPUs looking to be a mainstay of AI compute for years to come, there's never been a better time to learn how to use them to speed up your code.