Use with Spark

This document is a quick introduction to using 🌍 Datasets with Spark, with a particular focus on how to load a Spark DataFrame into a Dataset object.

From there, you have fast access to any element and you can use it as a data loader to train models.

Load from Spark

A Dataset object is a wrapper of an Arrow table, which allows fast reads from arrays in the dataset to PyTorch, TensorFlow and JAX tensors. The Arrow table is memory mapped from disk, so it can load datasets bigger than your available RAM.

You can get a Dataset from a Spark DataFrame using Dataset.from_spark():

>>> from datasets import Dataset
>>> df = spark.createDataFrame(  # `spark` is your existing SparkSession
...     data=[[1, "Elia"], [2, "Teo"], [3, "Fang"]],
...     schema=["id", "name"],
... )
>>> ds = Dataset.from_spark(df)

The Spark workers write the dataset on disk in a cache directory as Arrow files, and the Dataset is loaded from there.
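
Once loaded, you can index any row directly or switch the output format before feeding the dataset to a training loop. A minimal sketch, assuming the ds created above and that PyTorch is installed:

>>> ds[0]                          # fast random access backed by the memory-mapped Arrow files
{'id': 1, 'name': 'Elia'}
>>> ds = ds.with_format("torch")   # numeric columns are returned as PyTorch tensors
>>> ds[0]["id"]
tensor(1)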

Alternatively, you can skip materialization by using IterableDataset.from_spark(), which returns an IterableDataset:

>>> from datasets import IterableDataset
>>> df = spark.createDataFrame(
...     data=[[1, "Elia"], [2, "Teo"], [3, "Fang"]],
...     schema=["id", "name"],
... )
>>> ds = IterableDataset.from_spark(df)
>>> print(next(iter(ds)))
{'id': 1, 'name': 'Elia'}

Caching

You can set the cache location by passing cache_dir= to Dataset.from_spark(). Make sure to use a disk that is available to both your workers and your current machine (the driver).
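
For example, a minimal sketch assuming a hypothetical shared mount /mnt/shared that both the driver and the workers can read and write:

>>> ds = Dataset.from_spark(df, cache_dir="/mnt/shared/hf_datasets_cache")  # hypothetical shared path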

Feature types

If your dataset is made of images, audio data or N-dimensional arrays, you can specify the features= argument in Dataset.from_spark() (or IterableDataset.from_spark()):

>>> from datasets import Dataset, Features, Image, Value
>>> data = [(0, open("image.png", "rb").read())]
>>> df = spark.createDataFrame(data, "idx: int, image: binary")
>>> # Also works if you have arrays
>>> # data = [(0, np.zeros(shape=(32, 32, 3), dtype=np.int32).tolist())]
>>> # df = spark.createDataFrame(data, "idx: int, image: array<array<array<int>>>")
>>> features = Features({"idx": Value("int64"), "image": Image()})
>>> dataset = Dataset.from_spark(df, features=features)
>>> dataset[0]
{'idx': 0, 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>}

When using Dataset.from_spark(), the resulting Dataset is cached; if you call Dataset.from_spark() multiple times on the same DataFrame it won’t re-run the Spark job that writes the dataset as Arrow files on disk.

In a different session, a Spark DataFrame doesn’t have the same semantic hash, so it will re-run the Spark job and store the dataset in a new cache.

You can check the Features documentation to learn about all the available feature types.
