Use with PyTorch
This document is a quick introduction to using datasets with PyTorch, with a particular focus on how to get torch.Tensor objects out of our datasets, and how to use a PyTorch DataLoader and a Hugging Face Dataset with the best performance.
Dataset format
By default, datasets return regular Python objects: integers, floats, strings, lists, etc.
To get PyTorch tensors instead, you can set the format of the dataset to "torch" using Dataset.with_format():
>>> from datasets import Dataset
>>> data = [[1, 2],[3, 4]]
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("torch")
>>> ds[0]
{'data': tensor([1, 2])}
>>> ds[:2]
{'data': tensor([[1, 2],
        [3, 4]])}
A Dataset object is a wrapper of an Arrow table, which allows fast zero-copy reads from arrays in the dataset to PyTorch tensors.
To load the data as tensors on a GPU, specify the device argument:
N-dimensional arrays
If your dataset consists of N-dimensional arrays, you will see that by default they are treated as nested lists. In particular, a PyTorch-formatted dataset outputs nested lists instead of a single tensor:
To get a single tensor, you must explicitly use an Array feature type (Array2D, Array3D, Array4D, or Array5D) and specify the shape of your tensors:
Other feature types
ClassLabel data are properly converted to tensors:
String and binary objects are unchanged, since PyTorch only supports numbers.
The Image and Audio feature types are also supported.
To use the Image feature type, you'll need to install the vision extra as pip install datasets[vision].
To use the Audio feature type, you'll need to install the audio extra as pip install datasets[audio].
Data loading
Like torch.utils.data.Dataset objects, a Dataset can be passed directly to a PyTorch DataLoader:
Optimize data loading
There are several ways to increase the speed at which your data is loaded, which can save you time, especially if you are working with large datasets. You can parallelize data loading, retrieve batches of indices instead of individual indices, and stream the dataset to iterate over it without downloading it to disk.
Use multiple workers
You can parallelize data loading with the num_workers argument of a PyTorch DataLoader and get a higher throughput.
Under the hood, the DataLoader starts num_workers processes. Each process reloads the dataset passed to the DataLoader and uses it to query examples. Reloading the dataset inside a worker doesn't fill up your RAM, since it simply memory-maps the dataset again from your disk.
Stream data
Stream a dataset by loading it as an IterableDataset. This lets you progressively iterate over a remote dataset without downloading it to disk, or over local data files. Learn more about which type of dataset is best for your use case in the guide on choosing between a regular dataset and an iterable dataset.
An iterable dataset from datasets inherits from torch.utils.data.IterableDataset, so you can pass it to a torch.utils.data.DataLoader:
If the dataset is split into several shards (i.e. if the dataset consists of multiple data files), then you can stream it in parallel using num_workers:
In this case each worker is given a subset of the list of shards to stream from.
Distributed
To split your dataset across your training nodes, you can use datasets.distributed.split_dataset_by_node():
This works for both map-style datasets and iterable datasets. The dataset is split for the node at rank rank in a pool of nodes of size world_size.
For map-style datasets:
Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
For iterable datasets:
If the dataset has a number of shards that is a multiple of world_size (i.e. if dataset.n_shards % world_size == 0), then the shards are evenly assigned across the nodes, which is the most efficient option. Otherwise, each node keeps 1 example out of world_size, skipping the other examples.
This can also be combined with a torch.utils.data.DataLoader if you want each node to use multiple workers to load the data.