Trying Grain, a data loader for JAX

Published: 2024-12-18

Even though I use JAX for training my neural networks, I have been using PyTorch for loading my data, as the JAX documentation recommends. Recently, I learned about Grain, a data loader built specifically for JAX. Although it is still in its early days, I decided to give it a try.

The documentation is still pretty primitive (as of version 0.2.2). The getting-started page only covers Google's own data-loading infrastructure, such as ArrayRecord and TFDS. ArrayRecord doesn't seem to have any documentation, and its last release was about a year old at this point (Oct 30, 2024), so it doesn't sound like a good option to me. And I hate anything related to TensorFlow with a passion, so not TFDS either. After some digging, I found that Grain does have a way to define a custom data source; here is what I cooked up for loading the CIFAR-10 dataset:

import os
import pickle
import tarfile

import grain.python as grain
import jax.numpy as jnp
import requests
from jaxtyping import Array, Int

class CIFAR10DataSource(grain.RandomAccessDataSource):

    def read_cifar(self, path: str) -> tuple[Int[Array, "50000 3 32 32"], Int[Array, "50000"]]:
        # Each of the five batch files holds 10000 records as a pickled dict
        inputs = []
        labels = []
        for i in range(1, 6):
            with open(path + 'data_batch_' + str(i), 'rb') as f:
                data = pickle.load(f, encoding='bytes')
                inputs.append(jnp.array(data[b'data']))
                labels.append(jnp.array(data[b'labels']))

        return jnp.concatenate(inputs).reshape(-1, 3, 32, 32), jnp.concatenate(labels)

    def __init__(self, data_dir: str = 'data/'):
        super().__init__()
        try:
            self.data = self.read_cifar(data_dir + 'cifar-10-batches-py/')
        except FileNotFoundError:
            print('Downloading CIFAR-10 dataset...')
            os.makedirs(data_dir, exist_ok=True)  # the write below fails without the directory
            url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
            r = requests.get(url)
            with open(data_dir + 'cifar-10-python.tar.gz', 'wb') as f:
                f.write(r.content)
            with tarfile.open(data_dir + 'cifar-10-python.tar.gz') as tar:
                tar.extractall(data_dir)
            self.data = self.read_cifar(data_dir + 'cifar-10-batches-py/')

    def __getitem__(self, idx):
        # self.data is an (inputs, labels) tuple, so index into each half;
        # a record is a single (image, label) pair
        inputs, labels = self.data
        return inputs[idx], labels[idx]

    def __len__(self):
        # number of records, not the length of the (inputs, labels) tuple
        return len(self.data[0])

This is pretty similar to PyTorch, but unlike PyTorch, you can (and have to) define your own shuffling and sampling. Taking an example from the documentation:

index_sampler = grain.IndexSampler(
    num_records=5,
    num_epochs=2,
    shard_options=grain.ShardOptions(
        shard_index=0, shard_count=1, drop_remainder=True),
    shuffle=True,
    seed=0)

data_loader = grain.DataLoader(
    data_source=source,
    operations=transformations,
    sampler=index_sampler,
    worker_count=0)
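To make the sampler's knobs concrete, here is a rough mental model in plain NumPy of the index stream a configuration like the one above would produce — this is my own sketch of the semantics (seeded shuffle per epoch, sharded by stride), not grain's actual implementation:

```python
import numpy as np

def index_stream(num_records, num_epochs, shard_index, shard_count, seed):
    """Sketch of an IndexSampler-style stream: every record index is
    visited once per epoch, shuffled deterministically from the seed."""
    for epoch in range(num_epochs):
        rng = np.random.default_rng(seed + epoch)  # reshuffle each epoch
        order = rng.permutation(num_records)
        # each shard takes a strided slice of the shuffled order
        yield from order[shard_index::shard_count]

# mirrors num_records=5, num_epochs=2, one shard, seed=0 from above
indices = list(index_stream(5, 2, 0, 1, seed=0))
print(len(indices))  # 10: each of the 5 records appears once per epoch
```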

I kind of like having fine-grained control over my code, so a more explicit approach like this suits me. I haven't benchmarked it yet, so that will probably be the topic of the next post.