This document is a quick introduction to using datasets with JAX, with a particular focus on how to get jax.Array objects out of our datasets, and how to use them to train JAX models.
jax and jaxlib are required to reproduce the code in this document, so please make sure you install them with pip install datasets[jax].
Dataset format
By default, datasets return regular Python objects: integers, floats, strings, lists, etc.
To get JAX arrays (numpy-like) instead, you can set the format of the dataset to jax. String and binary objects are left unchanged, since JAX only supports numbers.
Note that the exact same procedure applies to DatasetDict objects. When you set the format of a DatasetDict to jax, all the Datasets it contains are formatted as jax.
Also keep in mind that formatting is applied on the fly, only when you actually access the data. Setting the format does not convert the underlying Arrow table. You get JAX arrays only when you index or iterate over the formatted dataset.
Finally, to load the data on the device of your choice, you can specify the device argument. Note that a jaxlib.xla_extension.Device object is not supported, because it can be serialized with neither pickle nor dill. Pass its string identifier instead:
```python
>>> import jax
>>> from datasets import Dataset
>>> data = [[1, 2], [3, 4]]
>>> ds = Dataset.from_dict({"data": data})
>>> device = str(jax.devices()[0])  # Not casting to `str` before passing it to `with_format` will raise a `ValueError`
>>> ds = ds.with_format("jax", device=device)
>>> ds[0]
{'data': DeviceArray([1, 2], dtype=int32)}
>>> ds[0]["data"].device()
TFRT_CPU_0
>>> ds[0]["data"].device() == jax.devices()[0]
True
```
Note that if the device argument is not provided to with_format then it will use the default device which is jax.devices()[0].
N-dimensional arrays
If your dataset consists of N-dimensional arrays, you will see that by default they are returned as nested lists. A JAX-formatted dataset instead outputs a JAX array (DeviceArray in older JAX versions), which is a numpy-like array, whenever the nested lists have a fixed shape. Unlike the PyTorch or TensorFlow formatters, it therefore does not need the Array feature type to be specified.
That’s why JAX formatting in datasets is so useful: it lets you use any model from the HuggingFace Hub with JAX without having to worry about the data loading part.
A Dataset object is a wrapper of an Arrow table, which allows fast reads from arrays in the dataset to JAX arrays.
ClassLabel data is properly converted to arrays.
The Image and Audio feature types are also supported.
To use the Image feature type, you’ll need to install the vision extra with pip install datasets[vision].
To use the Audio feature type, you’ll need to install the audio extra with pip install datasets[audio].
JAX doesn’t have any built-in data loading capabilities, so you’ll need a library such as PyTorch to load your data with a DataLoader, or TensorFlow with a tf.data.Dataset. As the JAX documentation puts it: “JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don’t include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let’s just use them instead of reinventing anything. We’ll grab PyTorch’s data loader, and make a tiny shim to make it work with NumPy arrays.”
The easiest way to get JAX arrays out of a dataset is to use the with_format('jax') method. Let's assume that we want to train a neural network on the  available at the HuggingFace Hub at .