When reading tokenized text from Arrow, the bottleneck is often the conversion of the tokenized Arrow data to Python lists.
It's much faster to load it as torch tensors directly, since the data is then loaded zero-copy from disk:
from datasets import load_from_disk
from torch.utils.data import DataLoader

dataset = load_from_disk("mydata", keep_in_memory=False)
dataset = dataset.with_format("torch")
loader = DataLoader(dataset, batch_size=1000, collate_fn=collate_fn)
Also make sure to use recent versions of datasets and torch (at least datasets 2.10 and torch 1.13) to get the best speed.