Image captioning is the task of predicting a caption for a given image. Common real-world applications include assisting visually impaired people, for example by helping them navigate different situations. Image captioning therefore improves content accessibility by describing images to people.
This guide will show you how to:
Fine-tune an image captioning model.
Use the fine-tuned model for inference.
Before you begin, make sure you have all the necessary libraries installed:
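One way to check this from Python is sketched below. It only reports which top-level modules imported by the snippets in this guide cannot be found; the helper name missing_modules and the module list are illustrative assumptions drawn from those import statements, not an official requirements list.

```python
from importlib.util import find_spec

# Top-level modules imported by the code snippets in this guide (assumed list).
REQUIRED_MODULES = ["boincai_hub", "datasets", "transformers", "matplotlib", "numpy"]


def missing_modules(names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All listed modules are importable.")
```

Any module reported as missing needs to be installed with your package manager before continuing.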
We encourage you to log in to your BOINC AI account so you can upload and share your model with the community. When prompted, enter your token to log in:
from boincai_hub import notebook_login
notebook_login()
Load the Pokémon BLIP captions dataset
Use the 🤗 Dataset library to load a dataset that consists of {image-caption} pairs. To create your own image captioning dataset in PyTorch, you can follow .
from datasets import load_dataset
ds = load_dataset("lambdalabs/pokemon-blip-captions")
ds
Many image captioning datasets contain multiple captions per image. In those cases, a common strategy is to randomly sample a caption amongst the available ones during training.
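The sampling strategy just described can be sketched in a few lines. The record layout (a "captions" list per image) and the helper name sample_caption are hypothetical and used only for illustration; they are not part of the dataset loaded above.

```python
import random


def sample_caption(captions, rng=random):
    """Pick one caption uniformly at random from the captions of a single image."""
    if not captions:
        raise ValueError("expected at least one caption")
    return rng.choice(captions)


# Hypothetical record with several captions for one image.
example = {
    "image": "<image placeholder>",
    "captions": ["a red bird with a long tail", "a drawing of a bird", "a cartoon bird"],
}

rng = random.Random(0)  # seeded only so the sketch is reproducible
for epoch in range(3):
    # A different caption may be drawn each time the image is visited.
    print(epoch, sample_caption(example["captions"], rng))
```

Drawing a fresh caption every time an image is visited lets the model see all available captions over the course of training without duplicating the image in the dataset.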
Split the dataset's train split into a train and test set with the datasets.Dataset.train_test_split method:
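A minimal sketch of this split, producing the train_ds and test_ds variables that the rest of the guide relies on, might look like the following. The helper name split_train_test, the 10% test size, and the seed are illustrative assumptions rather than values taken from this guide.

```python
def split_train_test(dataset_dict, test_size=0.1, seed=42):
    """Split dataset_dict["train"] into train and test subsets.

    Assumes dataset_dict["train"] exposes a train_test_split method,
    as datasets.Dataset does. test_size and seed are illustrative defaults.
    """
    split = dataset_dict["train"].train_test_split(test_size=test_size, seed=seed)
    return split["train"], split["test"]
```

With the dataset loaded earlier as ds, calling train_ds, test_ds = split_train_test(ds) would define the two subsets used in the following steps.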
Let's visualize a couple of samples from the training set.
from textwrap import wrap
import matplotlib.pyplot as plt
import numpy as np


def plot_images(images, captions):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        caption = captions[i]
        caption = "\n".join(wrap(caption, 12))
        plt.title(caption)
        plt.imshow(images[i])
        plt.axis("off")


sample_images_to_visualize = [np.array(train_ds[i]["image"]) for i in range(5)]
sample_captions = [train_ds[i]["text"] for i in range(5)]
plot_images(sample_images_to_visualize, sample_captions)
Preprocess the dataset
Because the dataset has two modalities (image and text), the preprocessing pipeline must handle both the images and the captions.
To do so, load the processor class associated with the model you are about to fine-tune.
from transformers import AutoProcessor
checkpoint = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint)
The processor internally preprocesses the image (which includes resizing and pixel scaling) and tokenizes the caption.
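To make two of these steps concrete, here is a pure-Python sketch of pixel scaling and of padding token ids to a fixed length. It is an illustration only and not the processor's actual implementation: the function names, the 1/255 scale factor, the pad id 0, and the example token ids are all assumptions.

```python
def rescale_pixels(pixels, factor=1 / 255):
    """Map 8-bit pixel values (0-255) to floats in the range [0, 1]."""
    return [[value * factor for value in row] for row in pixels]


def pad_to_max_length(token_ids, max_length, pad_id=0):
    """Right-pad token ids up to max_length and build the matching attention mask.

    Sequences already longer than max_length are returned unpadded
    (this sketch does not truncate).
    """
    ids = list(token_ids)
    n_pad = max(0, max_length - len(ids))
    mask = [1] * len(ids) + [0] * n_pad  # 1 = real token, 0 = padding
    return ids + [pad_id] * n_pad, mask


gray_image = [[0, 128], [255, 64]]  # tiny 2x2 grayscale "image"
print(rescale_pixels(gray_image))

ids, mask = pad_to_max_length([101, 7, 42, 102], max_length=8)
print(ids)   # [101, 7, 42, 102, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
```

Padding every caption to the same length gives all examples in a batch the same shape, and the attention mask records which positions hold real tokens.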
def transforms(example_batch):
    images = [x for x in example_batch["image"]]
    captions = [x for x in example_batch["text"]]
    inputs = processor(images=images, text=captions, padding="max_length")
    inputs.update({"labels": inputs["input_ids"]})
    return inputs


train_ds.set_transform(transforms)
test_ds.set_transform(transforms)
With the dataset ready, you can now set up the model for fine-tuning.
Load a base model
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(checkpoint)