Datasets
  • 🌍GET STARTED
    • Datasets
    • Quickstart
    • Installation
  • 🌍TUTORIALS
    • Overview
    • Load a dataset from the Hub
    • Know your dataset
    • Preprocess
    • Evaluate predictions
    • Create a data
    • Share a dataset to the Hub
  • 🌍HOW-TO GUIDES
    • Overview
    • 🌍GENERAL USAGE
      • Load
      • Process
      • Stream
      • Use with TensorFlow
      • Use with PyTorch
      • Use with JAX
      • Use with Spark
      • Cache management
      • Cloud storage
      • Search index
      • Metrics
      • Beam Datasets
    • 🌍AUDIO
      • Load audio data
      • Process audio data
      • Create an audio dataset
    • 🌍VISION
      • Load image data
      • Process image data
      • Create an image dataset
      • Depth estimation
      • Image classification
      • Semantic segmentation
      • Object detection
    • 🌍TEXT
      • Load text data
      • Process text data
    • 🌍TABULAR
      • Load tabular data
    • 🌍DATASET REPOSITORY
      • Share
      • Create a dataset card
      • Structure your repository
      • Create a dataset loading script
  • 🌍CONCEPTUAL GUIDES
    • Datasets with Arrow
    • The cache
    • Dataset or IterableDataset
    • Dataset features
    • Build and load
    • Batch mapping
    • All about metrics
  • 🌍REFERENCE
    • Main classes
    • Builder classes
    • Loading methods
    • Table Classes
    • Logging methods
    • Task templates
Powered by GitBook
On this page
  1. HOW-TO GUIDES
  2. VISION

Semantic segmentation

PreviousImage classificationNextObject detection

Last updated 1 year ago

Semantic segmentation

Semantic segmentation datasets are used to train a model to classify every pixel in an image. There are a wide variety of applications enabled by these datasets such as background removal from images, stylizing images, or scene understanding for autonomous driving. This guide will show you how to apply transformations to an image segmentation dataset.

Before you start, make sure you have up-to-date versions of albumentations and cv2 installed:

Copied

pip install -U albumentations opencv-python

is a Python library for performing data augmentation for computer vision. It supports various computer vision tasks such as image classification, object detection, segmentation, and keypoint estimation.

This guide uses the dataset for segmenting and parsing an image into different image regions associated with semantic categories, such as sky, road, person, and bed.

Load the train split of the dataset and take a look at an example:

Copied

>>> from datasets import load_dataset

>>> dataset = load_dataset("scene_parse_150", split="train")
>>> index = 10
>>> dataset[index]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=683x512 at 0x7FB37B0EC810>,
 'annotation': <PIL.PngImagePlugin.PngImageFile image mode=L size=683x512 at 0x7FB37B0EC9D0>,
 'scene_category': 927}

The dataset has three fields:

  • image: a PIL image object.

  • annotation: segmentation mask of the image.

  • scene_category: the label or scene category of the image (like β€œkitchen” or β€œoffice”).

Next, check out an image with:

Copied

>>> dataset[index]["image"]

Similarly, you can check out the respective segmentation mask:

Copied

>>> dataset[index]["annotation"]

After defining the color palette, you should be ready to visualize some overlays.

Copied

>>> import matplotlib.pyplot as plt

>>> def visualize_seg_mask(image: np.ndarray, mask: np.ndarray):
...    color_seg = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
...    palette = np.array(create_ade20k_label_colormap())
...    for label, color in enumerate(palette):
...        color_seg[mask == label, :] = color
...    color_seg = color_seg[..., ::-1]  # convert to BGR

...    img = np.array(image) * 0.5 + color_seg * 0.5  # plot the image with the segmentation map
...    img = img.astype(np.uint8)

...    plt.figure(figsize=(15, 10))
...    plt.imshow(img)
...    plt.axis("off")
...    plt.show()


>>> visualize_seg_mask(
...     np.array(dataset[index]["image"]),
...     np.array(dataset[index]["annotation"])
... )

Now apply some augmentations with albumentations. You’ll first resize the image and adjust its brightness.

Copied

>>> import albumentations

>>> transform = albumentations.Compose(
...     [
...         albumentations.Resize(256, 256),
...         albumentations.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
...     ]
... )

Create a function to apply the transformation to the images:

Copied

>>> def transforms(examples):
...     transformed_images, transformed_masks = [], []
...
...     for image, seg_mask in zip(examples["image"], examples["annotation"]):
...         image, seg_mask = np.array(image), np.array(seg_mask)
...         transformed = transform(image=image, mask=seg_mask)
...         transformed_images.append(transformed["image"])
...         transformed_masks.append(transformed["mask"])
...
...     examples["pixel_values"] = transformed_images
...     examples["label"] = transformed_masks
...     return examples

Copied

>>> dataset.set_transform(transforms)

You can verify the transformation worked by indexing into the pixel_values and label of an example:

Copied

>>> image = np.array(dataset[index]["pixel_values"])
>>> mask = np.array(dataset[index]["label"])

>>> visualize_seg_mask(image, mask)

In this guide, you have used albumentations for augmenting the dataset. It’s also possible to use torchvision to apply some similar transforms.

Copied

>>> from torchvision.transforms import Resize, ColorJitter, Compose

>>> transformation_chain = Compose([
...     Resize((256, 256)),
...     ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
... ])
>>> resize = Resize((256, 256))

>>> def train_transforms(example_batch):
...     example_batch["pixel_values"] = [transformation_chain(x) for x in example_batch["image"]]
...     example_batch["label"] = [resize(x) for x in example_batch["annotation"]]
...     return example_batch

>>> dataset.set_transform(train_transforms)

>>> image = np.array(dataset[index]["pixel_values"])
>>> mask = np.array(dataset[index]["label"])

>>> visualize_seg_mask(image, mask)

We can also add a on the segmentation mask and overlay it on top of the original image to visualize the dataset:

Use the function to apply the transformation on-the-fly to batches of the dataset to consume less disk space:

Now that you know how to process a dataset for semantic segmentation, learn and use it for inference.

🌍
🌍
Albumentations
Scene Parsing
color palette
set_transform()
how to train a semantic segmentation model