Creating a batch from a list of ids in the dataset is very slow

Hi all,

For a given use case, I need to build batches where the sampling depends on the elements already in the batch, e.g. to create batches of similar elements.
The solution I found online for this kind of approach is to first get elements from the dataloader (as usual) and then fetch additional elements, depending on these first ones, in the collate function.

So I built a custom collate function where I get a list of ids corresponding to the initial elements plus the additional mined ones, and I try to return the corresponding batch. However, constructing the dictionary to return is really slow. It seems to come from the selection of items given the list of ids (examples = train_dataset[mined_ids]).
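For context, here is roughly what my collate function looks like (mine_similar_ids stands in for my actual mining step, and I assume input_ids are already padded to a fixed length):

    import torch
    from torch.utils.data import DataLoader

    def collate_with_mining(seed_ids):
        # mine_similar_ids is a placeholder for the step that picks extra rows
        # based on the initial ones.
        mined_ids = list(seed_ids) + mine_similar_ids(seed_ids)
        examples = train_dataset[mined_ids]  # <- the slow step
        return {
            "input_ids": torch.tensor(examples["input_ids"], dtype=torch.long),
        }

    # The DataLoader only yields row indices; the collate function builds the batch.
    loader = DataLoader(range(len(train_dataset)), batch_size=8, shuffle=True,
                        collate_fn=collate_with_mining)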
I also tried examples = train_dataset.select(mined_ids), which is faster, but then accessing the columns is very slow (input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)).
Finally, I also tried set_format to get numpy arrays/tensors directly, but it is still very slow.
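Roughly what I tried there:

    # Format the relevant columns as numpy up front (column names are from my setup).
    train_dataset.set_format(type="numpy", columns=["input_ids"])
    examples = train_dataset[mined_ids]  # still slow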

Is there a better way to create batches given a list of ids, or another way to build batches where elements depend on other elements of the batch?

After digging more into this, the slowness clearly comes from the _getitem() function, but not from querying the table; rather, from formatting the resulting table.
More precisely, it comes from this function:

    def extract_batch(self, pa_table: pa.Table) -> dict:
        # Converts each Arrow column of the batch to a numpy array, one column at a time.
        return {col: self._arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names}

I suppose this is due to the storage format of the data, but it seems like a plain numpy array would allow much faster indexing with a list of ids.
Is there a special method I’m missing? Is there any workaround?
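For now, the only workaround I can think of is to pay the conversion cost once, materialize the needed columns as plain numpy arrays, and index those in the collate function. A minimal sketch, assuming input_ids are padded to a fixed length so the column stacks into one array:

    import numpy as np
    import torch

    # One-time cost: pull the whole column out of Arrow into a contiguous numpy array.
    # Assumes every row of "input_ids" has the same (padded) length.
    all_input_ids = np.asarray(train_dataset.with_format("numpy")["input_ids"])

    def fast_batch(mined_ids):
        # Fancy indexing on the in-memory array bypasses the Arrow formatting path.
        return {"input_ids": torch.from_numpy(all_input_ids[mined_ids])}

But this duplicates the data in memory, so I'd prefer a proper solution on the datasets side if one exists.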