According to the set_transform documentation:
A formatting function is a callable that takes a batch (as a dict) as input and returns a batch.
So a single dictionary is always passed to your transform, but the values in that dictionary are lists of size batch_size. Can you check whether this is the case?
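One way to check is to wrap the transform in a small logging helper. This is only a minimal sketch: `log_batch_sizes`, `uppercase`, and the `text` column are hypothetical names chosen for illustration, not part of the library.

```python
def log_batch_sizes(transform):
    """Wrap a batch transform and print the list length of each column it receives."""
    def wrapped(batch):
        # `batch` is expected to be a dict mapping column names to lists.
        print({column: len(values) for column, values in batch.items()})
        return transform(batch)
    return wrapped


def uppercase(batch):
    # Hypothetical batch transform: processes every example in the batch at once.
    return {"text": [t.upper() for t in batch["text"]]}


checked = log_batch_sizes(uppercase)
result = checked({"text": ["a", "b", "c"]})  # simulated batch of 3 examples
```

With a real dataset you would pass the wrapped function to set_transform and look at the printed lengths while iterating over the data loader.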
If it is, then you are good and you can process the examples by batch.
But if the lists contain only 1 element, then the issue might come from the data loader.
Indeed, by default the PyTorch data loader builds each batch by fetching items from the dataset one by one, like this:
batch = [dataset[idx] for idx in range(start, end)]
Therefore the augmentation function passed to set_transform is called batch_size times, with one element each time. For the function to get more than one item per execution, the dataset should be indexed like this instead:
batch = dataset[start:end]
# or
batch = dataset[list_of_indices]
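The difference between the two indexing styles can be illustrated with a toy stand-in. This is a sketch under assumptions: `ToyDataset` is a hypothetical class that only mimics the batch-transform behavior described above, not the library's actual implementation.

```python
class ToyDataset:
    """Hypothetical dataset that applies a batch-level transform on access."""

    def __init__(self, values):
        self.values = values
        self.transform = lambda batch: batch

    def set_transform(self, transform):
        self.transform = transform

    def __getitem__(self, key):
        if isinstance(key, int):
            # A single index is wrapped into a batch of one, then unwrapped.
            out = self.transform({"x": [self.values[key]]})
            return {column: items[0] for column, items in out.items()}
        if isinstance(key, slice):
            rows = self.values[key]
        else:
            rows = [self.values[i] for i in key]
        return self.transform({"x": rows})


call_sizes = []

def augment(batch):
    call_sizes.append(len(batch["x"]))  # record how many items each call receives
    return batch

dataset = ToyDataset(list(range(10)))
dataset.set_transform(augment)

per_item = [dataset[idx] for idx in range(0, 4)]  # one transform call per item
per_item_calls = list(call_sizes)

call_sizes.clear()
sliced = dataset[0:4]                             # a single transform call
slice_calls = list(call_sizes)
```

Here `per_item_calls` records four calls of size 1, while `slice_calls` records one call of size 4.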
I think you can change the PyTorch data loading behavior to work this way if you use the BatchSampler.
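A rough sketch of that idea: passing a BatchSampler as the `sampler` argument together with `batch_size=None` disables automatic batching, so the dataset is indexed with a whole list of indices at once. `ToyDataset` and `augment` are hypothetical stand-ins; with a real dataset you would pass it to the DataLoader directly.

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler


class ToyDataset:
    """Hypothetical stand-in for a dataset with a batch-level transform."""

    def __init__(self, values, transform):
        self.values = values
        self.transform = transform

    def __len__(self):
        return len(self.values)

    def __getitem__(self, indices):
        # With a BatchSampler used as `sampler`, `indices` is a list of ints.
        return self.transform({"x": [self.values[i] for i in indices]})


seen_sizes = []

def augment(batch):
    seen_sizes.append(len(batch["x"]))  # record the size of each batch received
    return batch

dataset = ToyDataset(list(range(10)), augment)
sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

# batch_size=None: each list of indices from the sampler is passed to
# dataset[...] as-is, so the transform runs once per batch.
loader = DataLoader(dataset, sampler=sampler, batch_size=None)
batches = list(loader)
```

To shuffle, RandomSampler can be swapped in for SequentialSampler.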
Let me know if that helps!