Use with Spark
This document is a quick introduction to using 🤗 Datasets with Spark, with a particular focus on how to load a Spark DataFrame into a Dataset object.
From there, you have fast access to any element and you can use it as a data loader to train models.
A Dataset object is a wrapper around an Arrow table, which allows fast reads from arrays in the dataset into PyTorch, TensorFlow and JAX tensors. The Arrow table is memory-mapped from disk, so you can load datasets bigger than your available RAM.
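As a minimal sketch of the tensor access described above, the helpers below switch a Dataset to PyTorch output and wrap it in a data loader. The function names and the default batch size are illustrative assumptions, and PyTorch is assumed to be installed when make_torch_loader is called.

```python
def as_torch_dataset(ds):
    """Return a view of the dataset whose reads yield PyTorch tensors."""
    # with_format leaves the underlying Arrow table untouched; it only
    # changes the type of the objects returned when reading examples.
    return ds.with_format("torch")


def make_torch_loader(ds, batch_size=32):
    """Wrap the dataset in a PyTorch DataLoader (batch_size is an arbitrary choice)."""
    # Imported lazily so this module loads even if PyTorch is not installed.
    from torch.utils.data import DataLoader

    return DataLoader(as_torch_dataset(ds), batch_size=batch_size)
```

Because the table stays memory-mapped, changing the format this way does not copy the dataset into RAM.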
You can get a Dataset from a Spark DataFrame using Dataset.from_spark().
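The following is a minimal sketch of that call. The helper names and the sample rows are hypothetical, and an active SparkSession is assumed to be passed in as spark.

```python
def make_example_df(spark):
    """Build a small illustrative DataFrame (made-up rows) from a SparkSession."""
    return spark.createDataFrame(
        data=[(1, "Elia"), (2, "Teo"), (3, "Fang")],
        schema=["id", "name"],
    )


def spark_df_to_dataset(df):
    """Materialize a Spark DataFrame as a Dataset backed by Arrow files."""
    # Imported lazily so this module loads without the library installed.
    from datasets import Dataset

    return Dataset.from_spark(df)
```

With a running session, ds = spark_df_to_dataset(make_example_df(spark)) gives a Dataset that supports indexing such as ds[0].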
The Spark workers write the dataset to disk in a cache directory as Arrow files, and the Dataset is loaded from there.
Alternatively, you can skip materialization by using IterableDataset.from_spark(), which returns an IterableDataset.
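A minimal sketch of the streaming variant follows. The helper names are hypothetical; take_first is included only to show that an IterableDataset is consumed by iteration rather than by index.

```python
from itertools import islice


def spark_df_to_iterable_dataset(df):
    """Wrap a Spark DataFrame as an IterableDataset without writing Arrow files first."""
    # Imported lazily so this module loads without the library installed.
    from datasets import IterableDataset

    return IterableDataset.from_spark(df)


def take_first(iterable_ds, n):
    """Collect the first n examples by iterating over the dataset."""
    return list(islice(iterable_ds, n))
```

For example, take_first(spark_df_to_iterable_dataset(df), 2) pulls two examples without materializing the whole DataFrame on disk.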
You can set the cache location by passing cache_dir= to Dataset.from_spark(). Make sure to use a disk that is available to both your workers and your current machine (the driver).
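A minimal sketch of passing a cache location is shown below. The helper name is hypothetical, and any path you pass (for instance a mount such as /mnt/shared/hf_cache, which is only an illustrative assumption) must be visible to both the Spark workers and the driver.

```python
def spark_df_to_dataset_in(df, cache_dir):
    """Materialize a Spark DataFrame as a Dataset, writing Arrow files under cache_dir."""
    # Imported lazily so this module loads without the library installed.
    from datasets import Dataset

    # cache_dir should point to storage shared by the workers and the driver.
    return Dataset.from_spark(df, cache_dir=cache_dir)
```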
If your dataset is made of images, audio data or N-dimensional arrays, you can specify the features= argument in Dataset.from_spark() (or IterableDataset.from_spark()).
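As a minimal sketch, the helpers below declare an image column through features=. The column names idx and image are assumptions: the DataFrame is assumed to hold an integer idx column and a binary image column containing encoded image bytes.

```python
def image_features():
    """Describe a dataset with an integer index and an image column (assumed column names)."""
    # Imported lazily so this module loads without the library installed.
    from datasets import Features, Image, Value

    return Features({"idx": Value("int64"), "image": Image()})


def spark_images_to_dataset(df):
    """Load a DataFrame with columns (idx, image) and decode the image column as images."""
    from datasets import Dataset

    return Dataset.from_spark(df, features=image_features())
```

The same features object can be passed to IterableDataset.from_spark() if you prefer not to materialize the data.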
When using Dataset.from_spark(), the resulting Dataset is cached; if you call Dataset.from_spark() multiple times on the same DataFrame, it won't re-run the Spark job that writes the dataset as Arrow files on disk.
In a different session, a Spark DataFrame doesn't have the same semantic hash, so the Spark job runs again and the result is stored in a new cache.
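To check whether two DataFrames would be treated as the same input, you can compare their semantic hashes. This is a minimal sketch: it relies on PySpark's DataFrame.semanticHash(), and it assumes that this is the value the cache lookup is based on. The helper name is hypothetical.

```python
def same_semantic_hash(df_a, df_b):
    """Return True if two Spark DataFrames report the same semantic hash."""
    # semanticHash() is computed from the DataFrame's query plan, so it is
    # only meaningful to compare values obtained within the same session.
    return df_a.semanticHash() == df_b.semanticHash()
```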
You can check the Features documentation to learn about all the feature types available.