Quantization reduces the precision of floating point data types, decreasing the memory required to store model weights. However, quantization can degrade inference quality because information is lost when the precision is reduced. 8-bit or int8 quantization uses only a quarter of the bits of full precision (fp32), yet it largely preserves model quality because it doesn’t just drop the bits or data. Instead, int8 quantization rounds values from one data type to the other.
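To make the rounding concrete, here is a minimal sketch of absmax quantization in plain Python. It is an illustration only and not necessarily the scheme the int8 backend uses internally; the function names and sample weights are made up for this example.

```python
def quantize_absmax(values):
    """Map floats to int8 codes in [-127, 127] using one shared scale."""
    max_abs = max(abs(v) for v in values)
    # Guard against an all-zero input, where any scale would do.
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [round(v / scale) for v in values]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate floats from the int8 codes."""
    return [c * scale for c in codes]


# Hypothetical weights, chosen only for illustration.
weights = [0.12, -0.5, 0.33, 1.27, -1.0]
codes, scale = quantize_absmax(weights)
restored = dequantize(codes, scale)

# Rounding keeps every value within half a quantization step of the original.
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(codes, max_error <= scale / 2 + 1e-12)
```

Because each value is rounded to the nearest representable step rather than truncated, the reconstruction error is bounded by half the step size (scale / 2).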
This guide will show you how to train an openai/whisper-large-v2 model for multilingual automatic speech recognition (ASR) using a combination of int8 quantization and LoRA. You’ll train Whisper on Marathi from the Common Voice 11.0 dataset.
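Before you start, make sure you have all the necessary libraries installed. The code in this guide imports transformers, peft, evaluate, torch, numpy, and tqdm; loading Common Voice also relies on datasets, and load_in_8bit relies on bitsandbytes.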
Let’s take care of some of the setup first so you can start training faster later. Set the CUDA_VISIBLE_DEVICES environment variable to 0 to use the first GPU on your machine. Then specify the model name (either a Hub model repository id or a path to a directory containing the model), the language and language abbreviation to train on, the task type, and the dataset name:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
model_name_or_path = "openai/whisper-large-v2"
language = "Marathi"
language_abbr = "mr"
task = "transcribe"
dataset_name = "mozilla-foundation/common_voice_11_0"
You can also log in to your BOINC AI account to save and share your trained model on the Hub if you’d like:
from boincai_hub import notebook_login
notebook_login()
Load dataset and metric
The Common Voice 11.0 dataset contains many hours of recorded speech in many different languages. This guide uses the Marathi language as an example, but feel free to use any other language you’re interested in.
Initialize a DatasetDict structure, and load the train and test splits from the dataset into it, combining the train and validation splits into train:
Let’s prepare the dataset for training. Load a feature extractor, tokenizer, and processor. You should also pass the language and task to the tokenizer and processor so they know how to process the inputs:
If you look at the sampling_rate, you’ll see the audio was sampled at 48kHz. The Whisper model was pretrained on audio inputs at 16kHz, which means you’ll need to downsample the audio inputs to match what the model was pretrained on. Downsample the audio by using the cast_column method on the audio column, and set the sampling_rate to 16kHz. The audio input is resampled on the fly the next time you call it:
Apply the prepare_dataset function to the dataset with the map function, and set the num_proc argument to 2 to enable multiprocessing (if map hangs, then set num_proc=1):
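A prepare_dataset function for this step might look like the following sketch. It assumes the Common Voice column names audio and sentence, and that the feature extractor and tokenizer behave like the ones loaded earlier. The make_prepare_dataset wrapper and the stand-in objects are hypothetical and exist only so the example runs without downloading Whisper.

```python
from types import SimpleNamespace


def make_prepare_dataset(feature_extractor, tokenizer):
    """Build a per-example function that can be passed to Dataset.map."""

    def prepare_dataset(batch):
        audio = batch["audio"]
        # Turn the raw waveform into the model's input features.
        batch["input_features"] = feature_extractor(
            audio["array"], sampling_rate=audio["sampling_rate"]
        ).input_features[0]
        # Tokenize the transcript to get the label ids.
        batch["labels"] = tokenizer(batch["sentence"]).input_ids
        return batch

    return prepare_dataset


# Hypothetical stand-ins that only mimic the call shapes of the real objects.
def fake_extractor(array, sampling_rate):
    return SimpleNamespace(input_features=[[len(array), sampling_rate]])


def fake_tokenizer(text):
    return SimpleNamespace(input_ids=[len(word) for word in text.split()])


prepare_dataset = make_prepare_dataset(fake_extractor, fake_tokenizer)
example = prepare_dataset(
    {"audio": {"array": [0.0, 0.1, -0.1], "sampling_rate": 16000}, "sentence": "hello world"}
)
print(example["input_features"], example["labels"])
```

With the real feature_extractor and tokenizer in place of the stand-ins, you would pass the returned function to map, for example common_voice.map(prepare_dataset, num_proc=2).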
Finally, create a DataCollator class to pad the labels in each batch to the maximum length, and replace the padding tokens with -100 so they’re ignored by the loss function. Then initialize an instance of the data collator:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad the audio input features into a single batch of tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # Pad the tokenized labels separately from the audio inputs
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # Replace padding with -100 so the loss function ignores it
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # Drop the leading BOS token when every sequence in the batch starts with one
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
Train
Now that the dataset is ready, you can turn your attention to the model. Start by loading the pretrained openai/whisper-large-v2 model from AutoModelForSpeechSeq2Seq, and make sure to set the load_in_8bit argument to True to enable int8 quantization. The device_map=auto argument automatically determines how to load and store the model weights:
from transformers import AutoModelForSpeechSeq2Seq
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")
You should configure forced_decoder_ids=None because no tokens are used before sampling, and you won’t need to suppress any tokens during generation either:
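As a rough sketch, these two settings could be applied as shown below. It assumes both live on model.config, as the text implies, and that an empty list is how "suppress nothing" is expressed; depending on your library version they may instead belong on the generation config. The helper name and the stand-in object are hypothetical, used only so the example runs without loading Whisper.

```python
from types import SimpleNamespace


def disable_forced_decoding(model):
    """Clear forced decoder tokens and the suppressed-token list on a model config."""
    model.config.forced_decoder_ids = None  # no tokens are forced before sampling
    model.config.suppress_tokens = []  # assumption: empty list means nothing is suppressed
    return model


# Hypothetical stand-in with the two attributes this sketch touches.
stand_in = SimpleNamespace(
    config=SimpleNamespace(forced_decoder_ids=[(1, 50259)], suppress_tokens=[1, 2])
)
disable_forced_decoding(stand_in)
print(stand_in.config.forced_decoder_ids, stand_in.config.suppress_tokens)
```

With the real model, the same two assignments would be made directly on the model loaded above.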
To get the model ready for int8 training, use the utility function prepare_model_for_int8_training to handle the following:
casts all the non-int8 modules to full precision (fp32) for stability
adds a forward hook to the input embedding layer to calculate the gradients of the input hidden states
enables gradient checkpointing for more memory-efficient training
from peft import prepare_model_for_int8_training
model = prepare_model_for_int8_training(model)
Let’s also apply LoRA to the training to make it even more efficient. Load a LoraConfig and configure the following parameters:
r, the dimension of the low-rank matrices
lora_alpha, scaling factor for the weight matrices
target_modules, the name of the attention matrices to apply LoRA to (q_proj and v_proj, or query and value in this case)
lora_dropout, dropout probability of the LoRA layers
bias, set to none
💡 The weight matrix is scaled by lora_alpha/r, and a higher lora_alpha value assigns more weight to the LoRA activations. For performance, we recommend setting bias to none first, and then lora_only, before trying all.
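The parameters above can be sketched as a plain dictionary of keyword arguments for LoraConfig. The values for r, lora_alpha, and lora_dropout are assumptions chosen for illustration, not settings confirmed by this guide. The helper lora_param_count is hypothetical: it counts the two low-rank matrices added per targeted projection, assuming Whisper large-v2 has a hidden size of 1280, 32 encoder layers, and 32 decoder layers (each decoder layer has both self-attention and cross-attention).

```python
# Illustrative LoRA settings; only the parameter names come from the guide.
lora_kwargs = {
    "r": 32,  # assumed rank of the low-rank matrices
    "lora_alpha": 64,  # assumed scaling numerator
    "target_modules": ["q_proj", "v_proj"],
    "lora_dropout": 0.05,  # assumed dropout probability
    "bias": "none",
}
# With peft installed, this would become: config = LoraConfig(**lora_kwargs)

# The LoRA update is scaled by lora_alpha / r.
scaling = lora_kwargs["lora_alpha"] / lora_kwargs["r"]


def lora_param_count(d_model, r, encoder_layers, decoder_layers, targets=2):
    """Count LoRA parameters when q_proj and v_proj are adapted in every attention block."""
    per_module = r * d_model + d_model * r  # matrix A (r x d) plus matrix B (d x r)
    encoder_modules = encoder_layers * targets  # self-attention only
    decoder_modules = decoder_layers * 2 * targets  # self-attention and cross-attention
    return (encoder_modules + decoder_modules) * per_module


trainable = lora_param_count(
    d_model=1280, r=lora_kwargs["r"], encoder_layers=32, decoder_layers=32
)
print(scaling, trainable)  # 2.0 15728640
```

Under these assumptions the count is 15,728,640, which agrees with the trainable parameter figure reported by print_trainable_parameters() in this guide.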
After you set up the LoraConfig, wrap it and the base model with the get_peft_model() function to create a PeftModel. Print out the number of trainable parameters to see how much more efficient LoRA is compared to fully training the model!
from peft import get_peft_model
model = get_peft_model(model, config)
model.print_trainable_parameters()
"trainable params: 15728640 || all params: 1559033600 || trainable%: 1.0088711365810203"
Now you’re ready to define some training hyperparameters in the Seq2SeqTrainingArguments class, such as where to save the model to, batch size, learning rate, and number of epochs to train for. The PeftModel doesn’t have the same signature as the base model, so you’ll need to explicitly set remove_unused_columns=False and label_names=["labels"].
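A minimal sketch of such arguments is shown below as a plain dictionary. The output directory, batch size, learning rate, and epoch count are placeholder values assumed for illustration; only remove_unused_columns and label_names are taken from the text above.

```python
# Placeholder hyperparameters; tune them for your hardware and dataset.
training_kwargs = {
    "output_dir": "whisper-large-v2-marathi-lora",  # hypothetical directory name
    "per_device_train_batch_size": 8,  # assumed
    "learning_rate": 1e-3,  # assumed
    "num_train_epochs": 3,  # assumed
    "fp16": True,  # assumed mixed-precision setting
    # Required because the PeftModel signature differs from the base model's.
    "remove_unused_columns": False,
    "label_names": ["labels"],
}
# With transformers installed, this would become:
# training_args = Seq2SeqTrainingArguments(**training_kwargs)
print(sorted(training_kwargs))
```

Keeping the settings in a dictionary makes it easy to log or adjust them before constructing the arguments object.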
It is also a good idea to write a custom TrainerCallback to save model checkpoints during training:
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        # Save only the adapter weights inside the checkpoint folder
        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)
        # Remove the full model weights from the checkpoint, if present
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control
Pass the Seq2SeqTrainingArguments, model, datasets, data collator, tokenizer, and callback to the Seq2SeqTrainer. You can optionally set model.config.use_cache = False to silence any warnings. Once everything is ready, call train to start training!
Word error rate (WER) is a common metric for evaluating ASR models. Load the WER metric from 🤗 Evaluate:
import evaluate
metric = evaluate.load("wer")
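To show what the metric measures, here is a small pure-Python sketch of WER: the word-level edit distance (substitutions, insertions, and deletions) summed over all examples, divided by the total number of reference words. It illustrates the definition only and is not the implementation behind the loaded metric; the function names and sample sentences are made up, and it assumes the references contain at least one word.

```python
def edit_distance(ref_words, hyp_words):
    """Word-level Levenshtein distance between a reference and a hypothesis."""
    previous = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        current = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            substitution = previous[j - 1] + (ref_word != hyp_word)
            deletion = previous[j] + 1
            insertion = current[j - 1] + 1
            current.append(min(substitution, deletion, insertion))
        previous = current
    return previous[-1]


def word_error_rate(references, predictions):
    """Total word errors divided by total reference words across all pairs."""
    errors = 0
    total_words = 0
    for reference, prediction in zip(references, predictions):
        ref_words = reference.split()
        errors += edit_distance(ref_words, prediction.split())
        total_words += len(ref_words)
    return errors / total_words


references = ["the cat sat on the mat", "hello world"]
predictions = ["the cat sit on mat", "hello world"]
print(word_error_rate(references, predictions))  # 0.25
```

Here the first prediction has one substitution and one deletion against eight reference words in total, giving a WER of 0.25, or 25 when multiplied by 100 as in the evaluation loop.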
Write a loop to evaluate the model performance. Set the model to evaluation mode first, and write the loop with torch.cuda.amp.autocast() because int8 training requires autocasting. Then, pass a batch of examples to the model to evaluate. Get the decoded predictions and labels, and add them as a batch to the WER metric before calling compute to get the final WER score:
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import gc
eval_dataloader = DataLoader(common_voice["test"], batch_size=8, collate_fn=data_collator)
model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            generated_tokens = (
                model.generate(
                    input_features=batch["input_features"].to("cuda"),
                    decoder_input_ids=batch["labels"][:, :4].to("cuda"),
                    max_new_tokens=255,
                )
                .cpu()
                .numpy()
            )
            labels = batch["labels"].cpu().numpy()
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
            metric.add_batch(
                predictions=decoded_preds,
                references=decoded_labels,
            )
    del generated_tokens, labels, batch
    gc.collect()
wer = 100 * metric.compute()
print(f"{wer=}")
Share model
Once you’re happy with your results, you can upload your model to the Hub with the push_to_hub method:
Instantiate the model configuration from PeftConfig, and from here, you can use the configuration to load the base and PeftModel, tokenizer, processor, and feature extractor. Remember to define the language and task in the tokenizer, processor, and forced_decoder_ids:
Then use the pipeline with autocast as a context manager on the audio sample:
with torch.cuda.amp.autocast():
    text = pipe(audio, generate_kwargs={"forced_decoder_ids": forced_decoder_ids}, max_new_tokens=255)["text"]
text
"मी तुमच्यासाठी काही करू शकतो का?"