Creating a batch from a list of ids in the dataset is very slow

Hi all,

For a given use case, I need to forge batches where the sampling depends on the elements within the batch, e.g. to create batches of similar elements.
The solution I found online for such approaches is to first get elements from the dataloader (as usual) and then get additional elements depending on these first ones in the collate function.
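
Roughly, the pattern I mean looks like this (just a sketch, assuming train_dataset is the Hugging Face dataset; mine_similar_ids and the "id"/"labels" columns are placeholders for my actual mining logic and schema):

    import torch
    from torch.utils.data import DataLoader

    def mine_similar_ids(initial_ids):
        # Placeholder for the mining step: return the ids of extra examples that are
        # similar to the initial ones (e.g. nearest neighbours in some index).
        return []

    def collate_fn(initial_examples):
        initial_ids = [ex["id"] for ex in initial_examples]   # assumes an "id" column
        mined_ids = initial_ids + mine_similar_ids(initial_ids)
        examples = train_dataset[mined_ids]                    # this indexing step is the slow part
        return {
            # assumes the sequences are already padded to the same length
            "input_ids": torch.tensor(examples["input_ids"], dtype=torch.long),
            "labels": torch.tensor(examples["labels"], dtype=torch.long),
        }

    loader = DataLoader(train_dataset, batch_size=8, collate_fn=collate_fn)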

So I built a custom collate function where I get a list of ids corresponding to the initial elements and the additional mined ones, and I try to return the corresponding batch. However, constructing the dictionary to return is really slow. It seems to come from the selection of items given the list of ids (examples = train_dataset[mined_ids]).
I also tried to use examples = train_dataset.select(mined_ids), which is faster, but then accessing columns is very slow (input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)).
Finally, I also tried to use set_format to directly get numpy arrays/tensors, but it is still very slow.
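
In code, besides the direct indexing shown in the sketch above, the two other variants I tried look like this (column names other than input_ids are placeholders):

    import torch

    # .select() itself is fast, but accessing columns example by example afterwards is slow
    examples = train_dataset.select(mined_ids)
    input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)

    # asking the dataset to return numpy directly is still slow when indexing with a list of ids
    train_dataset.set_format(type="numpy", columns=["input_ids", "labels"])
    examples = train_dataset[mined_ids]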

Is there a better way to create batches given a list of ids, or another way to build batches where elements depend on other elements of the batch?

After digging more into this, it is clearly coming from the _getitem() function, but not from querying the table; rather, from formatting the resulting table.
More precisely, it comes from this function:

    def extract_batch(self, pa_table: pa.Table) -> dict:
        return {col: self._arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names}

I suppose this is due to the stored format of the data, but it seems like a numpy dataset allows faster indexing through a list of ids.
Is there a special method I’m missing? Is there any workaround?
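
The only workaround I have found so far (just a sketch, and it assumes the sequences are already padded to a fixed length and that the columns fit in RAM) is to pay the extraction cost once up front, keep the columns as plain numpy arrays, and index those in the collate function instead of the Arrow-backed dataset:

    import numpy as np
    import torch

    # One-time conversion: pays the Arrow -> Python -> numpy cost once instead of per batch.
    # Assumes fixed-length sequences, so the result is a regular 2-D array.
    all_input_ids = np.asarray(train_dataset["input_ids"], dtype=np.int64)
    all_labels = np.asarray(train_dataset["labels"], dtype=np.int64)

    def fast_batch(mined_ids):
        # Fancy indexing on contiguous numpy arrays is fast, even for large id lists.
        ids = np.asarray(mined_ids)
        return {
            "input_ids": torch.from_numpy(all_input_ids[ids]),
            "labels": torch.from_numpy(all_labels[ids]),
        }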

I’m running into this myself. PythonArrowExtractor.extract_batch is very slow.

Is there a better or faster way to do this? It’s causing my GPUs to be work-starved while data is being extracted from the arrow table. I chose Arrow because it was supposed to be memory-mapped, zero-copy, fast, etc., but this seems largely untrue now.

I tried a suggestion from the thread "Local dataset loading performance: HF's arrow vs torch.load" (#3 by mztelus) to call .with_format('torch'), but that did NOT help either. Now most of the time is spent in PyArrow’s ChunkedArray.to_numpy() method (pyarrow.ChunkedArray, Apache Arrow v18.0.0 docs).
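
For reference, this is roughly what that looks like (dataset, list_of_ids and the column names here are placeholders for my setup):

    # Ask the dataset to return torch tensors directly instead of Python objects.
    dataset = dataset.with_format("torch", columns=["input_ids", "labels"])
    batch = dataset[list_of_ids]     # still slow: the time just moves into ChunkedArray.to_numpy()
    input_ids = batch["input_ids"]   # already a torch.Tensor with this format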

Update: for me, a suggestion from @nbroad helped: increasing the number of dataloader workers to 2. I also increased the prefetch factor to 16.
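
In DataLoader terms, that is roughly the following (batch size and collate_fn being whatever you already use):

    from torch.utils.data import DataLoader

    # More workers plus a larger prefetch factor lets batch construction overlap with GPU work
    # (prefetch_factor only applies when num_workers > 0).
    loader = DataLoader(
        train_dataset,
        batch_size=8,
        collate_fn=collate_fn,
        num_workers=2,
        prefetch_factor=16,
    )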

Hi! I am experiencing the same issue. I have a dataset with images, not too many, about 2,000. I want to create embeddings from a ViT model using the .map() method.
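
Roughly what I am doing (a sketch; the checkpoint name and the image/embedding column names are just my setup):

    import torch
    from transformers import AutoImageProcessor, AutoModel

    processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    model = AutoModel.from_pretrained("google/vit-base-patch16-224-in21k").eval()

    def embed(batch):
        # With batched=True, batch["image"] is a list of PIL images.
        inputs = processor(images=batch["image"], return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # Use the CLS token as the embedding for each image.
        batch["embedding"] = outputs.last_hidden_state[:, 0].numpy()
        return batch

    dataset = dataset.map(embed, batched=True, batch_size=32)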

The issue I see is that, compared to iterating over the dataset in a simple for loop, it is ~5 times slower!

I tried many different options, like various batch sizes and keep_in_memory=True, but it makes no difference. The only thing that makes a difference is to set batched=False. Then I can see it runs at the same speed as the for loop, but it gets stuck at a certain point, waiting for something, before continuing, which is also not ideal.

It’s kind of OK for smaller datasets, but definitely an issue for bigger ones. Does anyone have similar experiences? Or any idea what could be wrong? Or is this normal behaviour with the map() function?
