Training on TPU with TensorFlow

If you don’t need long explanations and just want TPU code samples to get started with, check out our TPU example notebook!

What is a TPU?

A TPU is a Tensor Processing Unit. TPUs are hardware designed by Google to greatly speed up the tensor computations within neural networks, much like GPUs. They can be used for both network training and inference. They are generally accessed through Google’s cloud services, but small TPUs can also be accessed directly for free through Google Colab and Kaggle Kernels.

Because all TensorFlow models in BOINC AI Transformers are Keras models, most of the methods in this document are generally applicable to TPU training for any Keras model! However, there are a few points that are specific to the BOINC AI ecosystem (hug-o-system?) of Transformers and Datasets, and we’ll make sure to flag them up when we get to them.

What kinds of TPU are available?

New users are often very confused by the range of TPUs, and the different ways to access them. The first key distinction to understand is the difference between TPU Nodes and TPU VMs.

When you use a TPU Node, you are effectively indirectly accessing a remote TPU. You will need a separate VM, which will initialize your network and data pipeline and then forward them to the remote node. When you use a TPU on Google Colab, you are accessing it in the TPU Node style.

Using TPU Nodes can have some quite unexpected behaviour for people who aren’t used to them! In particular, because the TPU is located on a physically different system to the machine you’re running your Python code on, your data cannot be local to your machine - any data pipeline that loads from your machine’s internal storage will totally fail! Instead, data must be stored in Google Cloud Storage where your data pipeline can still access it, even when the pipeline is running on the remote TPU node.

If you can fit all your data in memory as np.ndarray or tf.Tensor, then you can fit() on that data even when using Colab or a TPU Node, without needing to upload it to Google Cloud Storage.

🌍Specific BOINC AI Tip🌍: The methods Dataset.to_tf_dataset() and its higher-level wrapper model.prepare_tf_dataset(), which you will see throughout our TF code examples, will both fail on a TPU Node. The reason for this is that even though they create a tf.data.Dataset, it is not a “pure” tf.data pipeline: it uses tf.numpy_function or Dataset.from_generator() to stream data from the underlying BOINC AI Dataset. This BOINC AI Dataset is backed by data on a local disk, which the remote TPU Node will not be able to read.
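
One way to get a “pure” tf.data pipeline is to store the data as TFRecord files and read them back with tf.data.TFRecordDataset. The sketch below is an illustration under assumptions, not the method used in the example notebook: the feature names, the sequence length and the file location are all made up for the example. It writes to a local temporary directory so that it runs anywhere; on a TPU Node the path would instead point at a Google Cloud Storage bucket.

```python
import os
import tempfile

import tensorflow as tf

SEQ_LEN = 8  # assumed fixed (already padded) sequence length


def serialize_example(input_ids, label):
    """Pack one sample into a serialized tf.train.Example."""
    feature = {
        "input_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=input_ids)),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()


# Hypothetical location: for a TPU Node this would be something like
# "gs://my-bucket/train.tfrecord" rather than a local temp directory.
record_path = os.path.join(tempfile.mkdtemp(), "train.tfrecord")

with tf.io.TFRecordWriter(record_path) as writer:
    for i in range(4):
        writer.write(serialize_example([i] * SEQ_LEN, i % 2))


def parse_example(serialized):
    """Decode one serialized Example back into (input_ids, label)."""
    schema = {
        "input_ids": tf.io.FixedLenFeature([SEQ_LEN], tf.int64),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, schema)
    return parsed["input_ids"], parsed["label"]


# Every step here is a native tf.data op, with no Python generator involved.
dataset = (
    tf.data.TFRecordDataset(record_path)
    .map(parse_example)
    .batch(2, drop_remainder=True)
)
```

Using drop_remainder=True keeps every batch the same shape, which also helps with the XLA shape rules discussed later in this document.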

The second way to access a TPU is via a TPU VM. When using a TPU VM, you connect directly to the machine that the TPU is attached to, much like training on a GPU VM. TPU VMs are generally easier to work with, particularly when it comes to your data pipeline. All of the above warnings do not apply to TPU VMs!

This is an opinionated document, so here’s our opinion: Avoid using TPU Node if possible. It is more confusing and more difficult to debug than TPU VMs. It is also likely to be unsupported in future - Google’s latest TPU, TPUv4, can only be accessed as a TPU VM, which suggests that TPU Nodes are increasingly going to become a “legacy” access method. However, we understand that the only free TPU access is on Colab and Kaggle Kernels, which use TPU Node - so we’ll try to explain how to handle it if you have to! Check the TPU example notebook for code samples that explain this in more detail.

What sizes of TPU are available?

A single TPU (a v2-8/v3-8/v4-8) runs 8 replicas. TPUs exist in pods that can run hundreds or thousands of replicas simultaneously. When you use more than a single TPU but less than a whole pod (for example, a v3-32), your TPU fleet is referred to as a pod slice.

When you access a free TPU via Colab, you generally get a single v2-8 TPU.

I keep hearing about this XLA thing. What’s XLA, and how does it relate to TPUs?

XLA is an optimizing compiler, used by both TensorFlow and JAX. In JAX it is the only compiler, whereas in TensorFlow it is optional (but mandatory on TPU!). The easiest way to enable it when training a Keras model is to pass the argument jit_compile=True to model.compile(). If you don’t get any errors and performance is good, that’s a great sign that you’re ready to move to TPU!

Debugging on TPU is generally a bit harder than on CPU/GPU, so we recommend getting your code running on CPU/GPU with XLA first before trying it on TPU. You don’t have to train for long, of course - just for a few steps to make sure that your model and data pipeline are working like you expect them to.
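
A minimal sketch of such a smoke test on CPU/GPU is shown here. The toy model and the random data are placeholders invented for illustration; only the jit_compile=True argument to model.compile() comes from the text above.

```python
import numpy as np
import tensorflow as tf

# Placeholder data: 64 random samples with 8 features each (illustrative only).
features = np.random.rand(64, 8).astype("float32")
targets = np.random.rand(64, 1).astype("float32")

# Placeholder model: any Keras model can be compiled the same way.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(1),
    ]
)

# jit_compile=True asks Keras to compile the training step with XLA.
model.compile(optimizer="adam", loss="mse", jit_compile=True)

# A handful of steps is enough to surface XLA incompatibilities.
history = model.fit(features, targets, batch_size=16, epochs=1, verbose=0)
```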

XLA compiled code is usually faster - so even if you’re not planning to run on TPU, adding jit_compile=True can improve your performance. Be sure to note the caveats below about XLA compatibility, though!

Tip born of painful experience: Although using jit_compile=True is a good way to get a speed boost and test if your CPU/GPU code is XLA-compatible, it can actually cause a lot of problems if you leave it in when actually training on TPU. XLA compilation will happen implicitly on TPU, so remember to remove that line before actually running your code on a TPU!

How do I make my model XLA compatible?

In many cases, your code is probably XLA-compatible already! However, there are a few things that work in normal TensorFlow that don’t work in XLA. We’ve distilled them into three core rules below:

🌍Specific BOINC AI Tip🌍: We’ve put a lot of effort into rewriting our TensorFlow models and loss functions to be XLA-compatible. Our models and loss functions generally obey rules #1 and #2 by default, so you can skip over them if you’re using transformers models. Don’t forget about these rules when writing your own models and loss functions, though!

XLA Rule #1: Your code cannot have “data-dependent conditionals”

What that means is that any if statement cannot depend on values inside a tf.Tensor. For example, this code block cannot be compiled with XLA!

if tf.reduce_sum(tensor) > 10:
    tensor = tensor / 2.0

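This might seem very restrictive at first, but most neural net code doesn’t need to do this. You can often get around this restriction by using tf.cond (see the documentation here) or by removing the conditional and finding a clever math trick with indicator variables instead, like so: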

sum_over_10 = tf.cast(tf.reduce_sum(tensor) > 10, tf.float32)
tensor = tensor / (1.0 + sum_over_10)

This code has exactly the same effect as the code above, but by avoiding a conditional, we ensure it will compile with XLA without problems!
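
To check this for yourself, the sketch below wraps both XLA-friendly options in functions compiled with jit_compile=True: one uses tf.cond, TensorFlow's graph-level conditional, and the other uses the indicator-variable trick. The function names and test values are made up for illustration.

```python
import tensorflow as tf


@tf.function(jit_compile=True)
def halve_if_large_cond(tensor):
    # tf.cond keeps the branch inside the compiled graph.
    return tf.cond(
        tf.reduce_sum(tensor) > 10,
        lambda: tensor / 2.0,
        lambda: tensor,
    )


@tf.function(jit_compile=True)
def halve_if_large_indicator(tensor):
    # The boolean becomes 0.0 or 1.0, so we divide by either 1.0 or 2.0.
    sum_over_10 = tf.cast(tf.reduce_sum(tensor) > 10, tf.float32)
    return tensor / (1.0 + sum_over_10)


big = tf.constant([6.0, 8.0])    # sum is 14, so the values are halved
small = tf.constant([1.0, 2.0])  # sum is 3, so the values are unchanged

big_cond = halve_if_large_cond(big).numpy().tolist()
big_indicator = halve_if_large_indicator(big).numpy().tolist()
small_cond = halve_if_large_cond(small).numpy().tolist()
small_indicator = halve_if_large_indicator(small).numpy().tolist()
```

Both functions return the same values for both inputs. The indicator version avoids control flow entirely, which is why it is often the simpler choice.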

XLA Rule #2: Your code cannot have “data-dependent shapes”

What this means is that the shape of all of the tf.Tensor objects in your code cannot depend on their values. For example, the function tf.unique cannot be compiled with XLA, because it returns a tensor containing one instance of each unique value in the input. The shape of this output will obviously be different depending on how repetitive the input Tensor was, and so XLA refuses to handle it!

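In general, most neural network code obeys rule #2 by default. However, there are a few common cases where it becomes a problem. One very common one is when you use label masking, setting your labels to a negative value to indicate that those positions should be ignored when computing the loss. If you look at NumPy or PyTorch loss functions that support label masking, you will often see code like this that uses boolean indexing: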

label_mask = labels >= 0
masked_outputs = outputs[label_mask]
masked_labels = labels[label_mask]
loss = compute_loss(masked_outputs, masked_labels)
mean_loss = torch.mean(loss)

This code is totally fine in NumPy or PyTorch, but it breaks in XLA! Why? Because the shape of masked_outputs and masked_labels depends on how many positions are masked - that makes it a data-dependent shape. However, just like for rule #1, we can often rewrite this code to yield exactly the same output without any data-dependent shapes.

label_mask = tf.cast(labels >= 0, tf.float32)
loss = compute_loss(outputs, labels)
loss = loss * label_mask  # Set negative label positions to 0
mean_loss = tf.reduce_sum(loss) / tf.reduce_sum(label_mask)

Here, we avoid data-dependent shapes by computing the loss for every position, but zeroing out the masked positions in both the numerator and denominator when we calculate the mean, which yields exactly the same result as the first block while maintaining XLA compatibility. Note that we use the same trick as in rule #1 - converting a tf.bool to tf.float32 and using it as an indicator variable. This is a really useful trick, so remember it if you need to convert your own code to XLA!
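
The equivalence of the two approaches can be checked numerically. The following NumPy sketch is only an illustration: compute_loss is a hypothetical per-position squared error, and the sample values are invented.

```python
import numpy as np


def compute_loss(outputs, labels):
    # Hypothetical per-position loss (squared error), for illustration only.
    return (outputs - labels) ** 2


outputs = np.array([0.5, 2.0, 1.0, 3.0], dtype=np.float32)
labels = np.array([1.0, -100.0, 1.0, 2.0], dtype=np.float32)  # -100 marks a masked position

# Variant 1: boolean indexing. The result shape depends on the data.
keep = labels >= 0
loss_indexed = np.mean(compute_loss(outputs[keep], labels[keep]))

# Variant 2: fixed shapes. Masked positions are zeroed out in the numerator
# and left out of the denominator.
label_mask = (labels >= 0).astype(np.float32)
loss_all_positions = compute_loss(outputs, labels) * label_mask
loss_masked = np.sum(loss_all_positions) / np.sum(label_mask)
```

The two means agree, even though the second variant never changes the shape of any array.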

XLA Rule #3: XLA will need to recompile your model for every different input shape it sees

This is the big one. What this means is that if your input shapes are very variable, XLA will have to recompile your model over and over, which will create huge performance problems. This commonly arises in NLP models, where input texts have variable lengths after tokenization. In other modalities, static shapes are more common and this rule is much less of a problem.

How can you get around rule #3? The key is padding - if you pad all your inputs to the same length, and then use an attention_mask, you can get the same results as you’d get from variable shapes, but without any XLA issues. However, excessive padding can cause severe slowdown too - if you pad all your samples to the maximum length in the whole dataset, you might end up with batches consisting of endless padding tokens, which will waste a lot of compute and memory!

There isn’t a perfect solution to this problem. However, you can try some tricks. One very useful trick is to pad batches of samples up to a multiple of a number like 32 or 64 tokens. This often only increases the number of tokens by a small amount, but it hugely reduces the number of unique input shapes, because every input shape now has to be a multiple of 32 or 64. Fewer unique input shapes means fewer XLA compilations!
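
The effect on the number of distinct shapes is easy to see with a little arithmetic. In this pure-Python sketch, pad_length is a hypothetical helper and the batch lengths are invented; each number stands for the longest sequence in one batch.

```python
def pad_length(length, multiple=32):
    """Round a sequence length up to the next multiple of `multiple`."""
    return ((length + multiple - 1) // multiple) * multiple


# Invented example: the longest sequence length in each of eight batches.
raw_lengths = [5, 17, 31, 33, 40, 64, 65, 100]
padded_lengths = [pad_length(n) for n in raw_lengths]

unique_raw_shapes = len(set(raw_lengths))
unique_padded_shapes = len(set(padded_lengths))
```

Here eight distinct batch lengths collapse to four, so XLA would compile the model four times instead of eight.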

🌍Specific BOINC AI Tip🌍: Our tokenizers and data collators have methods that can help you here. You can use padding="max_length" or padding="longest" when calling tokenizers to get them to output padded data. Our tokenizers and data collators also have a pad_to_multiple_of argument that you can use to reduce the number of unique input shapes you see!

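How do I actually train my model on TPU?

Once your training is XLA-compatible and (if you’re using TPU Node / Colab) your dataset has been prepared appropriately, running on TPU is surprisingly easy! All you really need to change in your code is to add a few lines to initialize your TPU, and to ensure that your model and dataset are created inside a TPUStrategy scope. Take a look at our TPU example notebook to see this in action!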
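
As a rough sketch of the TPU-specific lines involved, the code below initializes a TPU and builds a model and dataset inside the strategy scope. It is an illustration under assumptions rather than the notebook's code: get_strategy is a hypothetical helper, the model and data are toy placeholders, and the fallback to the default strategy (assuming a missing TPU raises ValueError) is only there so the sketch also runs on a machine without a TPU.

```python
import numpy as np
import tensorflow as tf


def get_strategy():
    """Return a TPUStrategy if a TPU is reachable, else the default strategy."""
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    except ValueError:
        # Assumption: no TPU was found, so fall back to CPU/GPU.
        return tf.distribute.get_strategy()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)


strategy = get_strategy()

# Placeholder in-memory data (illustrative only).
features = np.random.rand(64, 8).astype("float32")
targets = np.random.rand(64, 1).astype("float32")

with strategy.scope():
    # Dataset and model are both created inside the scope.
    dataset = tf.data.Dataset.from_tensor_slices((features, targets)).batch(
        16, drop_remainder=True
    )
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    # Note: no jit_compile=True here, since XLA is implicit on TPU.
    model.compile(optimizer="adam", loss="mse")

history = model.fit(dataset, epochs=1, verbose=0)
```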

Summary

There was a lot in here, so let’s summarize with a quick checklist you can follow when you want to get your model ready for TPU training:

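  • Make sure your code follows the three rules of XLA

  • Compile your model with jit_compile=True on CPU/GPU and confirm that you can train it with XLA

  • Either load your dataset into memory or use a TPU-compatible dataset loading approach (see notebook)

  • Migrate your code either to Colab (with accelerator set to “TPU”) or a TPU VM on Google Cloud

  • Add TPU initializer code (see notebook)

  • Create your TPUStrategy and make sure dataset loading and model creation are inside the strategy.scope() (see notebook)

  • Don’t forget to take jit_compile=True out again when you move to TPU!

  • 🙏🙏🙏🥺🥺🥺

  • Call model.fit()

  • You did it!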
