I tried to use one of my data collators inside a function passed to `datasets.map`. The goal was to measure something on the model outputs. As a sketch, I wanted to do something like:
```python
def measure_sth(examples, model):
    batch = COLLATE_FUNCTION(examples)  # COLLATE_FUNCTION: one of my data collators
    out = model.forward(batch)
    return out
```
The problem is that a torch DataLoader (which draws indices from its Sampler) passes a List[dict] to the collate function, whereas `map` with `batched=True` passes a dict[str, list]. So I ended up doing an ugly dict[str, list] -> List[dict] conversion:
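```python
def measure_sth(examples, model):
    # map(batched=True) passes dict[str, list]; take the batch size from the data,
    # because the last batch can be smaller than the default batch_size of 1000
    batch_size = len(next(iter(examples.values())))
    batch_of_examples = [{} for _ in range(batch_size)]
    for key, values in examples.items():
        for i, v in enumerate(values):
            batch_of_examples[i][key] = v
    # the collator expects List[dict], so pass the converted batch
    batch = COLLATE_FUNCTION(batch_of_examples)
    out = model.forward(batch)
    return out
```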
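The same conversion can be written more compactly with `zip`. This is a plain-Python sketch; `columns_to_rows` and `rows_to_columns` are hypothetical helper names, not part of `datasets` or `torch`. It assumes every column in the batch has the same length, which holds for batches produced by `map` (`zip` would otherwise stop silently at the shortest column).

```python
from typing import Any, Dict, List


def columns_to_rows(examples: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """dict[str, list] (what batched map passes) -> List[dict] (what collators take)."""
    keys = list(examples.keys())
    # zip(*columns) walks all columns in lockstep and yields one row per example,
    # so the batch size comes from the data instead of being hard-coded.
    return [dict(zip(keys, row)) for row in zip(*examples.values())]


def rows_to_columns(rows: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """List[dict] -> dict[str, list], the shape a batched map function returns."""
    if not rows:
        return {}
    return {key: [row[key] for row in rows] for key in rows[0]}


batch = {"audio": ["a0", "a1", "a2"], "label": [0, 1, 0]}
print(columns_to_rows(batch))
# → [{'audio': 'a0', 'label': 0}, {'audio': 'a1', 'label': 1}, {'audio': 'a2', 'label': 0}]
```

`rows_to_columns` is the inverse, useful when per-example results have to be handed back to `map` as columns.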
So my question is: is there a way to change how batching is done, either on the torch side (Sampler/DataLoader) or in `datasets.map`?
@mariosasko
Unfortunately this doesn't work. I'm getting the following exception:
`'dict' object has no attribute 'to_dict'`
I updated `datasets` to version 2.11 and it still doesn't work.
The dataset stores audio files. Also, I used `.with_format("torch")` instead of the pandas format.
Well, then you need to implement the conversion yourself in the `map` transform, because `datasets` does not support stacking formatting transforms: calling `with_format("pandas")` after `with_format("torch")` overwrites the torch formatting.
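A minimal sketch of doing the conversion inside the `map` transform. It assumes the collator accepts a List[dict] and that the model returns one value per example. The extra `collate_fn` parameter and the `measurement` output column are hypothetical choices for illustration, and `model`/`collate_fn` stand in for your own objects.

```python
from typing import Any, Callable, Dict, List


def measure_sth(
    examples: Dict[str, List[Any]],
    model: Callable[[Any], Any],
    collate_fn: Callable[[List[Dict[str, Any]]], Any],
) -> Dict[str, List[Any]]:
    """Batched map transform that feeds a List[dict]-style collator.

    Assumption: `model(batch)` returns one value per example in the batch.
    """
    keys = list(examples.keys())
    # Convert inside the transform: dict[str, list] -> List[dict].
    rows = [dict(zip(keys, row)) for row in zip(*examples.values())]
    batch = collate_fn(rows)
    # Calling the model (rather than .forward) is the usual style for nn.Module.
    out = model(batch)
    # A batched map function must return a dict of columns, each as long as the batch.
    return {"measurement": list(out)}
```

It could then be applied with something like `ds.with_format("torch").map(measure_sth, batched=True, fn_kwargs={"model": model, "collate_fn": collator})`, where `fn_kwargs` is the `Dataset.map` argument for passing extra keyword arguments. If your collator returns a dict of tensors, `model(**batch)` is the more common call.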