Customize the generation strategy


Text generation is essential to many NLP tasks, such as open-ended text generation, summarization, translation, and more. It also plays a role in a variety of mixed-modality applications that produce text as output, such as speech-to-text and vision-to-text. Some of the models that can generate text include GPT2, XLNet, OpenAI GPT, CTRL, TransformerXL, XLM, Bart, T5, GIT, and Whisper.

Check out a few examples that use the generate() method to produce text outputs for different tasks: text summarization, image captioning, and audio transcription.

Note that the inputs to the generate method depend on the model’s modality. They are returned by the model’s preprocessor class, such as AutoTokenizer or AutoProcessor. If a model’s preprocessor creates more than one kind of input, pass all the inputs to generate(). You can learn more about the individual model’s preprocessor in the corresponding model’s documentation.

The process of selecting output tokens to generate text is known as decoding, and you can customize the decoding strategy that the generate() method will use. Modifying a decoding strategy does not change the values of any trainable parameters. However, it can have a noticeable impact on the quality of the generated output. It can help reduce repetition in the text and make it more coherent.

This guide describes:

  • default generation configuration

  • common decoding strategies and their main parameters

  • saving and sharing custom generation configurations with your fine-tuned model on 🌍 Hub

Default text generation configuration

A decoding strategy for a model is defined in its generation configuration. When you use pre-trained models for inference within a pipeline(), the models call the PreTrainedModel.generate() method, which applies a default generation configuration under the hood. The default configuration is also used when no custom configuration has been saved with the model.

When you load a model explicitly, you can inspect the generation configuration that comes with it through model.generation_config:


>>> from transformers import AutoModelForCausalLM

>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> model.generation_config
GenerationConfig {
    "bos_token_id": 50256,
    "eos_token_id": 50256,
}

Printing out the model.generation_config reveals only the values that are different from the default generation configuration, and does not list any of the default values.

The default generation configuration limits the size of the output combined with the input prompt to a maximum of 20 tokens to avoid running into resource limitations. The default decoding strategy is greedy search, the simplest decoding strategy, which picks the token with the highest probability as the next token. For many tasks and small output sizes this works well. However, when used to generate longer outputs, greedy search can start producing highly repetitive results.
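To make the length limit concrete, here is a minimal pure-Python sketch (no model involved) of how many tokens a generation loop may still add. The function name new_token_budget is hypothetical; it assumes that max_length counts prompt plus output, as described above, and that max_new_tokens, when given, counts only the output and takes precedence over max_length.

```python
from typing import Optional

def new_token_budget(prompt_len: int, max_length: int = 20,
                     max_new_tokens: Optional[int] = None) -> int:
    """How many tokens generation may still append (toy arithmetic only)."""
    if max_new_tokens is not None:
        # Counts only generated tokens; the prompt length does not matter.
        return max(max_new_tokens, 0)
    # max_length counts prompt and output together.
    return max(max_length - prompt_len, 0)

print(new_token_budget(prompt_len=15))                     # 5
print(new_token_budget(prompt_len=25))                     # 0
print(new_token_budget(prompt_len=25, max_new_tokens=50))  # 50
```

A long prompt can therefore exhaust the default budget before a single token is generated, which is one reason the examples in this guide pass max_new_tokens explicitly.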

Customize text generation

You can override any generation_config parameter by passing the parameter and its value directly to the generate() method:


>>> my_model.generate(**inputs, num_beams=4, do_sample=True)
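Values passed this way apply to that call only. The sketch below mimics this precedence with a hypothetical ToyGenerationConfig dataclass and an effective_config() helper. It is not the library's GenerationConfig, only an illustration of per-call values being layered over stored settings without modifying them; rejecting unknown names is an assumption of this toy.

```python
from dataclasses import dataclass, fields, replace

@dataclass(frozen=True)
class ToyGenerationConfig:
    """Hypothetical stand-in for a model's stored generation settings."""
    max_length: int = 20
    num_beams: int = 1
    do_sample: bool = False
    num_return_sequences: int = 1

def effective_config(stored, **overrides):
    """Return the settings for one call: stored values with overrides on top."""
    known = {f.name for f in fields(stored)}
    unknown = sorted(set(overrides) - known)
    if unknown:
        raise ValueError(f"unknown generation parameters: {unknown}")
    return replace(stored, **overrides)  # builds a new object; `stored` is untouched

stored = ToyGenerationConfig()
per_call = effective_config(stored, num_beams=4, do_sample=True)
print(per_call)  # ToyGenerationConfig(max_length=20, num_beams=4, do_sample=True, num_return_sequences=1)
print(stored)    # ToyGenerationConfig(max_length=20, num_beams=1, do_sample=False, num_return_sequences=1)
```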

Even if the default decoding strategy mostly works for your task, you can still tweak a few things. Some of the commonly adjusted parameters include:

  • num_beams: by specifying a number of beams higher than 1, you are effectively switching from greedy search to beam search. This strategy evaluates several hypotheses at each time step and eventually chooses the hypothesis that has the overall highest probability for the entire sequence. This has the advantage of identifying high-probability sequences that start with lower-probability initial tokens and would’ve been ignored by greedy search.

  • do_sample: if set to True, this parameter enables decoding strategies such as multinomial sampling, beam-search multinomial sampling, Top-K sampling and Top-p sampling. All these strategies select the next token from the probability distribution over the entire vocabulary with various strategy-specific adjustments.

  • num_return_sequences: the number of sequence candidates to return for each input. This option is only available for the decoding strategies that support multiple sequence candidates, e.g. variations of beam search and sampling. Decoding strategies like greedy search and contrastive search return a single output sequence.
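Parameters like these select a decoding strategy by their combination. The hypothetical decoding_strategy() function below is a simplified sketch of that mapping, built only from the combinations this guide describes (including penalty_alpha, top_k, num_beam_groups, and assistant_model, which appear in the sections that follow). The default top_k=50 and the error checks are assumptions; the library's own selection logic covers more cases.

```python
def decoding_strategy(num_beams=1, do_sample=False, num_beam_groups=1,
                      penalty_alpha=0.0, top_k=50, assistant_model=None):
    """Name the decoding strategy a parameter combination selects (toy mapping)."""
    if num_beam_groups > 1 and do_sample:
        # This toy treats diverse beam search as deterministic only.
        raise ValueError("diverse beam search cannot be combined with do_sample=True")
    if assistant_model is not None:
        if num_beams > 1:
            raise ValueError("assisted decoding supports only greedy search and sampling")
        return "assisted sampling" if do_sample else "assisted greedy search"
    if num_beams == 1:
        if do_sample:
            return "multinomial sampling"
        if penalty_alpha > 0 and top_k > 1:
            return "contrastive search"
        return "greedy search"
    if num_beam_groups > 1:
        return "diverse beam search"
    return "beam-search multinomial sampling" if do_sample else "beam search"

print(decoding_strategy())                                   # greedy search
print(decoding_strategy(num_beams=4, do_sample=True))        # beam-search multinomial sampling
print(decoding_strategy(penalty_alpha=0.6, top_k=4))         # contrastive search
```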

Save a custom decoding strategy with your model

If you would like to share your fine-tuned model with a specific generation configuration, you can:

  • Specify the decoding strategy parameters

  • Set push_to_hub to True to upload your config to the model’s repo


>>> from transformers import AutoModelForCausalLM, GenerationConfig

>>> model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
>>> generation_config = GenerationConfig(
...     max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
... )
>>> generation_config.save_pretrained("my_account/my_model", push_to_hub=True)


>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

>>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

>>> translation_generation_config = GenerationConfig(
...     num_beams=4,
...     early_stopping=True,
...     decoder_start_token_id=0,
...     eos_token_id=model.config.eos_token_id,
...     pad_token_id=model.config.pad_token_id,
... )

>>> # Tip: add `push_to_hub=True` to push to the Hub
>>> translation_generation_config.save_pretrained("/tmp", "translation_generation_config.json")

>>> # You could then use the named generation config file to parameterize generation
>>> generation_config = GenerationConfig.from_pretrained("/tmp", "translation_generation_config.json")
>>> inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
>>> outputs = model.generate(**inputs, generation_config=generation_config)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['Les fichiers de configuration sont faciles à utiliser!']
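Because generation configurations are stored as JSON files, keeping several of them side by side comes down to choosing different file names. The stdlib-only sketch below imitates that idea with hypothetical save_config() and load_config() helpers. They are stand-ins for GenerationConfig.save_pretrained() and GenerationConfig.from_pretrained(), not their implementation, and the default file name generation_config.json is an assumption of the sketch.

```python
import json
import tempfile
from pathlib import Path

DEFAULT_FILE = "generation_config.json"  # assumed default file name

def save_config(directory, config, config_file_name=DEFAULT_FILE):
    """Write one generation config as JSON under the given file name."""
    path = Path(directory) / config_file_name
    path.write_text(json.dumps(config, indent=2, sort_keys=True))
    return path

def load_config(directory, config_file_name=DEFAULT_FILE):
    """Read a generation config back by file name."""
    return json.loads((Path(directory) / config_file_name).read_text())

with tempfile.TemporaryDirectory() as repo:
    # One config under the default name, one under an explicit name.
    save_config(repo, {"do_sample": True, "top_k": 50, "max_new_tokens": 50})
    save_config(repo, {"num_beams": 4, "early_stopping": True},
                config_file_name="summarization_generation_config.json")
    creative = load_config(repo)
    summarize = load_config(repo, "summarization_generation_config.json")
    files = sorted(p.name for p in Path(repo).iterdir())

print(files)  # ['generation_config.json', 'summarization_generation_config.json']
```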

Streaming

The generate() method supports streaming through its streamer input. The streamer input is compatible with an instance of any class that has the following methods: put() and end(). Internally, put() is used to push new tokens and end() is used to flag the end of text generation.

The API for the streamer classes is still under development and may change in the future.
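As a minimal sketch of the put()/end() contract, the hypothetical ListStreamer below decodes token ids with a toy vocabulary, and the hypothetical toy_generate() loop drives it. Pushing the prompt first and then each new token is an assumption meant to resemble how a streamer sees the prompt before the generated text; the real generate() passes tensors rather than Python lists.

```python
class ListStreamer:
    """Hypothetical streamer implementing put() and end()."""

    def __init__(self, vocab):
        self.vocab = vocab      # toy id -> text mapping
        self.chunks = []        # every piece of text pushed so far
        self.finished = False

    def put(self, token_ids):
        # Called with new token ids; decode and show them immediately.
        text = " ".join(self.vocab[t] for t in token_ids)
        self.chunks.append(text)
        print(text, end=" ", flush=True)

    def end(self):
        # Called once generation is over.
        self.finished = True
        print()

def toy_generate(next_token, prompt_ids, max_new_tokens, streamer=None):
    """Greedy toy loop that reports each token to the streamer as it appears."""
    ids = list(prompt_ids)
    if streamer is not None:
        streamer.put(ids)       # the prompt is pushed first
    for _ in range(max_new_tokens):
        ids.append(next_token(ids))
        if streamer is not None:
            streamer.put([ids[-1]])
    if streamer is not None:
        streamer.end()
    return ids

vocab = ["one,", "two,", "three,", "four,"]
streamer = ListStreamer(vocab)
out = toy_generate(lambda ids: (ids[-1] + 1) % len(vocab), [0], 3, streamer=streamer)
# prints: one, two, three, four,
```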


>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextStreamer(tok)

>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,

Decoding strategies

Here, we’ll show some of the parameters that control the decoding strategies and illustrate how you can use them.

Greedy Search

generate() uses greedy search decoding by default, so you don’t have to pass any parameters to enable it. This means that num_beams is set to 1 and do_sample=False.


>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "I look forward to"
>>> checkpoint = "distilgpt2"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> outputs = model.generate(**inputs)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['I look forward to seeing you all again!\n\n\n\n\n\n\n\n\n\n\n']
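Greedy search itself fits in a few lines. This sketch runs it over a hypothetical bigram table (the next-token distribution depends only on the previous word) and shows how always taking the most probable token can fall into a repetitive loop.

```python
def greedy_decode(next_token_probs, prompt, max_new_tokens):
    """Greedy search: always append the single most probable next token."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        probs = next_token_probs(tokens[-1])
        tokens.append(max(probs, key=probs.get))
    return tokens

# Hypothetical bigram model: the distribution depends only on the last token.
BIGRAM = {
    "I":       {"look": 0.9, "am": 0.1},
    "look":    {"forward": 0.8, "at": 0.2},
    "forward": {"to": 0.9, ".": 0.1},
    "to":      {"you": 0.5, "the": 0.4, ".": 0.1},
    "you":     {"to": 0.6, ".": 0.4},   # "you" -> "to" -> "you" forms a loop
    "the":     {".": 1.0},
    "at":      {"the": 1.0},
    "am":      {".": 1.0},
    ".":       {".": 1.0},
}
print(" ".join(greedy_decode(BIGRAM.get, ["I"], 7)))  # I look forward to you to you to
```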

Contrastive search


>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> checkpoint = "gpt2-large"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)

>>> prompt = "BOINC AI Company is"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=100)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['BOINC AI Company is a family owned and operated business. We pride ourselves on being the best
in the business and our customer service is second to none.\n\nIf you have any questions about our
products or services, feel free to contact us at any time. We look forward to hearing from you!']
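Contrastive search picks among the top_k most probable candidates using a score that trades model confidence against a degeneration penalty: the maximum cosine similarity between a candidate's hidden state and the hidden states of the tokens already in the sequence, weighted by penalty_alpha. The sketch below computes that score on hypothetical two-dimensional hidden states (contrastive_pick and all values are illustrative); with penalty_alpha=0 it reduces to greedy selection among the candidates.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def contrastive_pick(candidates, context_hidden, penalty_alpha):
    """Choose among top_k candidates given as (token, probability, hidden_state).

    score(v) = (1 - penalty_alpha) * p(v) - penalty_alpha * max_j cos(h_v, h_j),
    where h_j ranges over hidden states of tokens already in the sequence.
    """
    def score(candidate):
        _, prob, hidden = candidate
        degeneration = max(cosine(hidden, h) for h in context_hidden)
        return (1 - penalty_alpha) * prob - penalty_alpha * degeneration
    return max(candidates, key=score)[0]

context = [[1.0, 0.0], [0.9, 0.1]]                # hypothetical hidden states so far
top_k_candidates = [("again", 0.6, [1.0, 0.0]),    # likely, but mirrors the context
                    ("new", 0.4, [0.0, 1.0])]      # less likely, but dissimilar
print(contrastive_pick(top_k_candidates, context, penalty_alpha=0.0))  # again
print(contrastive_pick(top_k_candidates, context, penalty_alpha=0.6))  # new
```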

Multinomial sampling

As opposed to greedy search, which always chooses the token with the highest probability as the next token, multinomial sampling (also called ancestral sampling) randomly selects the next token based on the probability distribution over the entire vocabulary given by the model. Every token with a non-zero probability has a chance of being selected, thus reducing the risk of repetition.

To enable multinomial sampling set do_sample=True and num_beams=1.


>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> set_seed(0)  # For reproducibility

>>> checkpoint = "gpt2-large"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)

>>> prompt = "Today was an amazing day because"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> outputs = model.generate(**inputs, do_sample=True, num_beams=1, max_new_tokens=100)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today was an amazing day because when you go to the World Cup and you don\'t, or when you don\'t get invited,
that\'s a terrible feeling."']
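The sketch below shows multinomial sampling on a hypothetical logits vector, along with the temperature, Top-K, and Top-p adjustments mentioned earlier in this guide. The helper names are hypothetical, and the Top-p variant shown (keep the smallest set of most probable tokens whose cumulative probability reaches top_p) is one common formulation rather than the library's exact code.

```python
import math
import random

def softmax(logits, temperature=1.0):
    """Turn logits into probabilities; temperature < 1 sharpens, > 1 flattens."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def renormalize(probs):
    total = sum(probs)
    return [p / total for p in probs]

def top_k_filter(probs, k):
    """Zero out everything except the k most probable tokens."""
    keep = set(sorted(range(len(probs)), key=lambda i: -probs[i])[:k])
    return renormalize([p if i in keep else 0.0 for i, p in enumerate(probs)])

def top_p_filter(probs, top_p):
    """Keep the smallest set of most probable tokens whose mass reaches top_p."""
    keep, mass = set(), 0.0
    for i in sorted(range(len(probs)), key=lambda j: -probs[j]):
        keep.add(i)
        mass += probs[i]
        if mass >= top_p:
            break
    return renormalize([p if i in keep else 0.0 for i, p in enumerate(probs)])

def sample(probs, rng):
    """Multinomial (ancestral) sampling: any non-zero token can be drawn."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.5, -1.0]   # hypothetical next-token logits
probs = softmax(logits)
rng = random.Random(0)           # seeded for reproducibility, like set_seed(0) above
draws = [sample(probs, rng) for _ in range(10)]
print(top_p_filter(probs, 0.5))  # [1.0, 0.0, 0.0, 0.0]
```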

Beam-search decoding

Unlike greedy search, beam-search decoding keeps several hypotheses at each time step and eventually chooses the hypothesis that has the overall highest probability for the entire sequence. This has the advantage of identifying high-probability sequences that start with lower probability initial tokens and would’ve been ignored by the greedy search.

To enable this decoding strategy, set num_beams (the number of hypotheses to keep track of) to a value greater than 1.


>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "It is astonishing how one can"
>>> checkpoint = "gpt2-medium"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)

>>> outputs = model.generate(**inputs, num_beams=5, max_new_tokens=50)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['It is astonishing how one can have such a profound impact on the lives of so many people in such a short period of
time."\n\nHe added: "I am very proud of the work I have been able to do in the last few years.\n\n"I have']
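The advantage described above can be reproduced with a tiny hypothetical model in which the most probable first token does not lead to the most probable sequence. The beam_search() sketch ranks partial sequences by summed log-probabilities; it omits the end-of-sequence handling and length penalties that a full implementation would need.

```python
import math

def beam_search(next_token_probs, prompt, num_beams, steps):
    """Keep the num_beams partial sequences with the highest total log-probability."""
    beams = [(0.0, list(prompt))]
    for _ in range(steps):
        candidates = []
        for score, seq in beams:
            for tok, p in next_token_probs(tuple(seq)).items():
                candidates.append((score + math.log(p), seq + [tok]))
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:num_beams]
    best_score, best_seq = beams[0]
    return best_seq, math.exp(best_score)

# Hypothetical model: "A" is likelier first, but "B" starts the likelier sequence.
TABLE = {
    ("<s>",): {"A": 0.6, "B": 0.4},
    ("<s>", "A"): {"x": 0.55, "y": 0.45},
    ("<s>", "B"): {"x": 0.9, "y": 0.1},
}
print(beam_search(TABLE.__getitem__, ["<s>"], num_beams=1, steps=2))  # greedy: A x, p about 0.33
print(beam_search(TABLE.__getitem__, ["<s>"], num_beams=2, steps=2))  # beams:  B x, p about 0.36
```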

Beam-search multinomial sampling

As the name implies, this decoding strategy combines beam search with multinomial sampling. To use it, set num_beams to a value greater than 1 and set do_sample=True.


>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, set_seed
>>> set_seed(0)  # For reproducibility

>>> prompt = "translate English to German: The house is wonderful."
>>> checkpoint = "t5-small"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

>>> outputs = model.generate(**inputs, num_beams=5, do_sample=True)
>>> tokenizer.decode(outputs[0], skip_special_tokens=True)
'Das Haus ist wunderbar.'

Diverse beam search decoding


>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

>>> checkpoint = "google/pegasus-xsum"
>>> prompt = (
...     "The Permaculture Design Principles are a set of universal design principles "
...     "that can be applied to any location, climate and culture, and they allow us to design "
...     "the most efficient and sustainable human habitation and food production systems. "
...     "Permaculture is a design system that encompasses a wide variety of disciplines, such "
...     "as ecology, landscape design, environmental science and energy conservation, and the "
...     "Permaculture design principles are drawn from these various disciplines. Each individual "
...     "design principle itself embodies a complete conceptual framework based on sound "
...     "scientific principles. When we bring all these separate  principles together, we can "
...     "create a design system that both looks at whole systems, the parts that these systems "
...     "consist of, and how those parts interact with each other to create a complex, dynamic, "
...     "living system. Each design principle serves as a tool that allows us to integrate all "
...     "the separate parts of a design, referred to as elements, into a functional, synergistic, "
...     "whole system, where the elements harmoniously interact and work together in the most "
...     "efficient way possible."
... )

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

>>> outputs = model.generate(**inputs, num_beams=5, num_beam_groups=5, max_new_tokens=30, diversity_penalty=1.0)
>>> tokenizer.decode(outputs[0], skip_special_tokens=True)
'The Design Principles are a set of universal design principles that can be applied to any location, climate and
culture, and they allow us to design the'
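Diverse beam search processes the beam groups one after another at each step, penalizing tokens that earlier groups have already chosen. With num_beams equal to num_beam_groups, as in the example above, each group holds a single beam, which keeps this sketch short. The function diverse_step() and the log-probabilities are hypothetical; the penalty shown (diversity_penalty times the number of earlier groups that picked the same token at this step) is a Hamming-style diversity term.

```python
def diverse_step(group_log_probs, diversity_penalty):
    """One step of diverse beam search with one beam per group (toy version)."""
    chosen = []
    for log_probs in group_log_probs:
        def adjusted(tok):
            # Lower the score of tokens already picked by earlier groups.
            return log_probs[tok] - diversity_penalty * chosen.count(tok)
        chosen.append(max(log_probs, key=adjusted))
    return chosen

# Hypothetical: three groups that see identical next-token log-probabilities.
step = [{"The": -0.1, "Design": -1.0, "Permaculture": -1.5}] * 3
print(diverse_step(step, diversity_penalty=0.0))  # ['The', 'The', 'The']
print(diverse_step(step, diversity_penalty=2.0))  # ['The', 'Design', 'Permaculture']
```

With a smaller penalty such as 1.0, the third group returns to "The", since that token's penalized score (-1.1) still beats "Permaculture" (-1.5); the penalty size decides how strongly groups are pushed apart.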

Assisted Decoding

To enable assisted decoding, set the assistant_model argument to a model.


>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
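The acceptance rule behind greedy assisted decoding can be sketched with two hypothetical toy models on integer tokens. The assistant drafts a few tokens; the main model keeps the longest prefix it agrees with, plus its own token at the first disagreement (or one extra token if every draft matched), so the result matches plain greedy decoding with the main model. Here the main model is called once per position for clarity, whereas the library verifies all candidates in a single forward pass; num_candidates and all helper names are assumptions.

```python
def assisted_greedy_step(main_next, assistant_next, tokens, num_candidates):
    """One round: draft with the assistant, then verify with the main model."""
    draft = list(tokens)
    for _ in range(num_candidates):
        draft.append(assistant_next(draft))
    accepted = list(tokens)
    for pos in range(len(tokens), len(draft)):
        target = main_next(accepted)
        accepted.append(target)      # the main model's token is always kept
        if target != draft[pos]:
            break                    # stop at the first disagreement
    else:
        accepted.append(main_next(accepted))  # all drafts matched: one bonus token
    return accepted

def assisted_generate(main_next, assistant_next, prompt, max_new_tokens, num_candidates=3):
    tokens, rounds = list(prompt), 0
    while len(tokens) - len(prompt) < max_new_tokens:
        tokens = assisted_greedy_step(main_next, assistant_next, tokens, num_candidates)
        rounds += 1
    return tokens[: len(prompt) + max_new_tokens], rounds

def greedy_generate(main_next, prompt, max_new_tokens):
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(main_next(tokens))
    return tokens

def main_next(seq):       # hypothetical "large" model
    return (seq[-1] * 3 + 1) % 10

def assistant_next(seq):  # hypothetical "small" model, wrong after token 4
    return 0 if seq[-1] == 4 else main_next(seq)

fast, rounds = assisted_generate(main_next, assistant_next, [1], max_new_tokens=10)
print(fast == greedy_generate(main_next, [1], 10), rounds)  # True 3
```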

When using assisted decoding with sampling methods, you can use the temperature argument to control the randomness, just like in multinomial sampling. However, in assisted decoding, reducing the temperature can also help improve latency.


>>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
>>> set_seed(42)  # For reproducibility

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are going to the same party. It is a small party, in a small']

Another commonly adjusted parameter is max_new_tokens: the maximum number of tokens to generate, that is, the size of the output sequence not including the tokens in the prompt. As an alternative to using the output length as a stopping criterion, you can choose to stop generation whenever the full generation exceeds some amount of time. To learn more, check StoppingCriteria.
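The time-based alternative can be sketched as a toy loop with two stopping criteria. generate_until() is hypothetical and takes an injectable clock so its behavior is deterministic; in the library, such conditions are expressed through StoppingCriteria rather than through this function.

```python
import itertools
import time

def generate_until(next_token, prompt, max_new_tokens=None, max_time=None,
                   clock=time.monotonic):
    """Append tokens until max_new_tokens is reached or max_time seconds pass.

    The time check runs after each new token, so generation can overshoot
    max_time by up to one token's duration.
    """
    if max_new_tokens is None and max_time is None:
        raise ValueError("set max_new_tokens, max_time, or both")
    tokens = list(prompt)
    start = clock()
    new = 0
    while max_new_tokens is None or new < max_new_tokens:
        tokens.append(next_token(tokens))
        new += 1
        if max_time is not None and clock() - start > max_time:
            break
    return tokens

fake_clock = itertools.count().__next__  # hypothetical clock: 0, 1, 2, ... seconds
print(generate_until(len, [0], max_time=3.5, clock=fake_clock))  # [0, 1, 2, 3, 4]
print(generate_until(len, [0], max_new_tokens=2))                # [0, 1, 2]
```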

To save a custom decoding strategy with your model, create a GenerationConfig class instance and save it with GenerationConfig.save_pretrained(), making sure to leave its config_file_name argument empty.

You can also store several generation configurations in a single directory by using the config_file_name argument in GenerationConfig.save_pretrained(). You can later instantiate them with GenerationConfig.from_pretrained(). This is useful if you want to store several generation configurations for a single model (e.g. one for creative text generation with sampling, and one for summarization with beam search). You must have the right Hub permissions to add configuration files to a model.

For streaming, you can in practice craft your own streaming class for all sorts of purposes. We also have basic streaming classes ready for you to use. For example, you can use the TextStreamer class to stream the output of generate() to your screen, one word at a time.

Certain combinations of the generate() parameters, and ultimately generation_config, can be used to enable specific decoding strategies. If you are new to this concept, we recommend reading this blog post that illustrates how common decoding strategies work.

The contrastive search decoding strategy was proposed in the 2022 paper A Contrastive Framework for Neural Text Generation. It demonstrates superior results for generating non-repetitive yet coherent long outputs. To learn how contrastive search works, check out this blog post. The two main parameters that enable and control the behavior of contrastive search are penalty_alpha (the weight of the degeneration penalty) and top_k (the number of candidate tokens considered at each step).

The diverse beam search decoding strategy is an extension of the beam search strategy that allows for generating a more diverse set of beam sequences to choose from. To learn how it works, refer to Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models. This approach has three main parameters: num_beams, num_beam_groups, and diversity_penalty. The diversity penalty ensures the outputs are distinct across groups, and beam search is used within each group.

This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the generate() method, which give you even further control over its behavior. For the complete list of the available parameters, refer to the API documentation.

Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, assisted decoding supports only greedy search and sampling, and it doesn’t support batched inputs. To learn more about assisted decoding, check this blog post.
