For a given use case, I need to build batches where the sampling depends on the other elements in the batch, e.g. to create batches of similar elements.
The solution I found online for this kind of approach is to first get elements from the dataloader (as usual) and then, in the collate function, get additional elements that depend on these first ones.
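A minimal sketch of that approach in plain Python, to make the setup concrete. It assumes each example carries its own dataset index under a hypothetical "idx" key, and uses a hypothetical "label" column as the similarity signal. `toy_rows`, `mine_similar` and `collate_with_mining` are illustrative names, not part of any library.

```python
from collections import defaultdict

# Toy stand-in for the real dataset: a list of row dicts.
# "idx" and "label" are hypothetical columns used only for this sketch.
toy_rows = [
    {"idx": i, "label": i % 3, "input_ids": [i, i + 1, i + 2]}
    for i in range(9)
]

# Precompute label -> ids once, outside the collate function.
ids_by_label = defaultdict(list)
for row in toy_rows:
    ids_by_label[row["label"]].append(row["idx"])


def mine_similar(seed_ids, per_seed=1):
    """Hypothetical mining rule: for each seed, add other ids with the same label."""
    mined = []
    seen = set(seed_ids)
    for sid in seed_ids:
        label = toy_rows[sid]["label"]
        candidates = [j for j in ids_by_label[label] if j not in seen]
        for j in candidates[:per_seed]:
            mined.append(j)
            seen.add(j)
    return mined


def collate_with_mining(examples):
    """Collate fn: take the examples drawn by the DataLoader, append mined ones."""
    seed_ids = [ex["idx"] for ex in examples]
    all_ids = seed_ids + mine_similar(seed_ids)
    # Row-by-row lookup of the mined ids: the step that gets slow on a real dataset.
    rows = [toy_rows[i] for i in all_ids]
    return {
        "idx": all_ids,
        "input_ids": [r["input_ids"] for r in rows],
    }
```

For example, `collate_with_mining([toy_rows[0], toy_rows[1]])` keeps ids 0 and 1 and mines 3 and 4 (same labels), giving a batch of four rows.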
So I built a custom collate function in which I get a list of ids corresponding to the initial elements plus the additional mined ones, and I am trying to return the corresponding batch. However, constructing the dictionary to return is really slow. The slowness seems to come from selecting the items given the list of ids (examples = train_dataset[mined_ids]).
I also tried to use examples = train_dataset.select(mined_ids), which is faster, but then accessing a column is very slow (input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)).
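One pattern that might avoid building per-row dicts (a sketch, not tested against the real data): read each needed column into memory once, before training, and index it column-wise with the mined ids inside the collate function. This assumes the columns fit in RAM; with a datasets.Dataset the column could be obtained once via train_dataset["input_ids"]. The `columns` dict and `gather_batch` helper are hypothetical.

```python
# Columns held in memory once, e.g. loaded before training.
# Hypothetical toy data; with a datasets.Dataset this could come from
# train_dataset["input_ids"] (the whole column), assuming it fits in RAM.
columns = {
    "input_ids": [[i, i + 1, i + 2] for i in range(9)],
    "attention_mask": [[1, 1, 1] for _ in range(9)],
}


def gather_batch(cols, ids):
    """Column-wise gather: one indexing pass per column, no per-row dicts."""
    return {name: [values[i] for i in ids] for name, values in cols.items()}
```

The resulting lists could then be passed to torch.tensor(..., dtype=torch.long), provided the sequences are already padded to the same length.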
Finally, I also tried to use set_format to get NumPy arrays/tensors directly, but it is still very slow.
Is there a better way to create batches from a list of ids, or another way to build batches where elements depend on other elements of the batch?
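One alternative worth trying is to do the mining before any rows are fetched: a batch sampler yields complete lists of indices (random seeds plus mined neighbours), and torch.utils.data.DataLoader accepts any such iterable through its batch_sampler argument. Assuming the default fetching behaviour, the collate function then only has to stack the rows it receives. The sketch below is pure Python; `MinedBatchSampler` and the `labels` similarity signal are hypothetical.

```python
import random
from collections import defaultdict


class MinedBatchSampler:
    """Yields lists of dataset indices: random seeds plus mined neighbours.

    Hypothetical sketch: `labels` stands in for whatever similarity signal
    is available (one entry per dataset row).
    """

    def __init__(self, labels, seeds_per_batch, mined_per_seed, seed=0):
        self.labels = labels
        self.seeds_per_batch = seeds_per_batch
        self.mined_per_seed = mined_per_seed
        self.rng = random.Random(seed)
        # Precompute label -> indices once.
        self.ids_by_label = defaultdict(list)
        for i, lab in enumerate(labels):
            self.ids_by_label[lab].append(i)

    def __len__(self):
        # Number of batches per epoch (the last, partial batch is kept).
        n = len(self.labels)
        return (n + self.seeds_per_batch - 1) // self.seeds_per_batch

    def __iter__(self):
        order = list(range(len(self.labels)))
        self.rng.shuffle(order)
        for start in range(0, len(order), self.seeds_per_batch):
            seeds = order[start:start + self.seeds_per_batch]
            batch, seen = list(seeds), set(seeds)
            for s in seeds:
                # Candidates share the seed's label and are not yet in the batch.
                pool = [j for j in self.ids_by_label[self.labels[s]] if j not in seen]
                k = min(self.mined_per_seed, len(pool))
                for j in self.rng.sample(pool, k):
                    batch.append(j)
                    seen.add(j)
            yield batch
```

Usage would look like DataLoader(train_dataset, batch_sampler=MinedBatchSampler(labels, 16, 1), collate_fn=...). Note that mined rows can also appear as seeds later in the epoch, so each row is seeded exactly once but may be seen more than once.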