For a given use case, I need to build batches where the sampling of an element depends on the other elements in the batch, e.g. to create batches of similar elements.
The solution I found online for this kind of approach is to first get elements from the dataloader (as usual) and then fetch additional elements, depending on these first ones, in the collate function.
So I built a custom collate function where I get a list of ids corresponding to the initial elements plus the additional mined ones, and I try to return the corresponding batch. However, constructing the dictionary to return is really slow. It seems to come from the selection of items given the list of ids (`examples = train_dataset[mined_ids]`).
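Roughly, the collate function looks like the sketch below; `mine_similar_ids` is a placeholder for my actual mining logic, and the line indexing the dataset with the full list of ids is where the time goes:

```python
import torch

def collate_with_mining(batch_ids):
    # batch_ids: indices of the initial elements picked by the sampler.
    # mine_similar_ids is a placeholder for the mining logic that returns
    # the indices of additional, similar elements.
    mined_ids = list(batch_ids) + list(mine_similar_ids(batch_ids))

    # This is the slow part: selecting a batch of rows by a list of ids.
    examples = train_dataset[mined_ids]  # dict of columns

    return {
        "input_ids": torch.tensor(examples["input_ids"], dtype=torch.long),
    }
```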
I also tried to use `examples = train_dataset.select(mined_ids)`, which is faster, but then accessing a column is very slow (`input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)`).
Finally, I also tried to use `set_format` to directly get NumPy arrays/tensors, but it is still very slow.
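Concretely, these two variants look roughly like this (the column name is just the one from my setup):

```python
import torch

# Variant 1: select() itself is fast, but building a column by iterating
# over the selected examples one by one is very slow.
examples = train_dataset.select(mined_ids)
input_ids = torch.tensor(
    [example["input_ids"] for example in examples], dtype=torch.long
)

# Variant 2: let the dataset return torch tensors (or NumPy arrays) directly.
train_dataset.set_format(type="torch", columns=["input_ids"])
batch = train_dataset[mined_ids]  # still very slow in my case
```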
Is there a better way to create batches given a list of ids, or another way to build batches where elements depend on other elements of the batch?
After digging more into this, the slowness is clearly coming from the `_getitem()` function, but not from querying the table; rather, it comes from formatting the resulting table.
More precisely, it comes from this function:
```python
def extract_batch(self, pa_table: pa.Table) -> dict:
    return {col: self._arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names}
```
I suppose this is due to the stored format of the data, but it seems like a NumPy dataset allows faster indexing through a list of ids.
Is there a special method I’m missing? Is there any workaround?
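The kind of workaround I have in mind is something like the following sketch: take the rows directly from the underlying Arrow table and convert each column in one shot. This assumes `train_dataset.data.table` exposes the underlying `pyarrow.Table` (which may depend on the `datasets` version) and that the columns are fixed-length, so I am not sure it is the intended way:

```python
import numpy as np
import pyarrow as pa

# Hypothetical sketch: bypass the formatting layer and gather the rows straight
# from the underlying pyarrow Table, then convert each column once.
pa_table = train_dataset.data.table           # underlying pyarrow.Table (assumption)
sub_table = pa_table.take(pa.array(mined_ids))
batch = {
    col: np.array(sub_table.column(col).to_pylist())
    for col in sub_table.column_names
}
```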
I’m running into this myself. `PythonArrowExtractor.extract_batch` is very slow.
Is there a better or faster way to do this? It’s causing my GPUs to be work-starved while data is being extracted from the Arrow table. I chose Arrow because it was supposed to be memory-mapped, zero-copy, fast, etc., but this seems largely untrue now.