Is there a way to change batching behaviour of `map`?

Hello,

I tried to use one of my data collators inside a function passed to datasets.map. The goal was to measure something on model outputs. As a sketch, I wanted to do something like:

def measure_sth(examples, model):
    batch = COLLATE_FUNCTION(examples)
    out = model.forward(batch)
    return out

dataset = dataset.map(measure_sth, fn_kwargs={"model": model}, batched=True)

The problem is that the torch DataLoader passes a List[dict] to the collate function, while map (with batched=True) passes a dict[str, list] to the mapped function. So I ended up doing an ugly dict[str, list] -> List[dict] conversion:

def measure_sth(examples, model):
    # examples arrives as dict[str, list]; convert it to the List[dict] the collator expects
    batch_size = len(next(iter(examples.values())))
    batch_of_examples = [{} for _ in range(batch_size)]
    for key, value in examples.items():
        for i, v in enumerate(value):
            batch_of_examples[i][key] = v

    batch = COLLATE_FUNCTION(batch_of_examples)
    out = model.forward(batch)
    return out

So my question is: is there a way to change how batches are represented, either for the torch DataLoader or for datasets.map?


Hi!

You can do the following to avoid having to convert dict[string, list] -> List[dict] manually:

def measure_sth(examples, model):
    # with the "pandas" format, `examples` is a pandas DataFrame,
    # so to_dict(orient="records") yields the List[dict] a collator expects
    examples = examples.to_dict(orient="records")
    batch = COLLATE_FUNCTION(examples)
    out = model.forward(batch)
    return out

dataset = dataset.with_format("pandas")
dataset = dataset.map(measure_sth, fn_kwargs={"model": model}, batched=True)
dataset = dataset.with_format(None)
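
For illustration, here is a small standalone sketch of what that conversion does, using made-up columns rather than your actual data:

import pandas as pd

# batched map normally passes the batch in dict[str, list] layout
batch = {"input_ids": [[1, 2], [3, 4]], "label": [0, 1]}

# with with_format("pandas"), the same batch arrives as a DataFrame instead,
# and to_dict(orient="records") gives the List[dict] layout that collators expect
records = pd.DataFrame(batch).to_dict(orient="records")
# -> [{'input_ids': [1, 2], 'label': 0}, {'input_ids': [3, 4], 'label': 1}]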

Perhaps we could make orient a parameter of the PythonFormatter (the default one) to allow doing:

dataset = dataset.with_format(None, orient="records")
dataset = dataset.map(...)

@mariosasko
Unfortunately, this doesn’t work. I’m getting the following exception:

'dict' object has no attribute 'to_dict'

I updated datasets to version 2.11 and it still doesn’t work.
What I store in the dataset are audio files, and I used .with_format("torch") instead of pandas.

Well, then you need to implement the conversion yourself in the map transform, as datasets does not support stacking formatting transforms (calling with_format("pandas") after with_format("torch") overwrites the torch formatting).
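
For example, here is a minimal sketch of that manual conversion while keeping the torch format. COLLATE_FUNCTION, model, and the returned "measurement" column are placeholders standing in for your own collator, model, and metric, and the out.logits line assumes a typical transformers-style output:

import torch

def measure_sth(examples, model):
    # `examples` arrives as dict[str, list] (values may be tensors with the "torch" format);
    # convert it to the List[dict] layout the collator expects
    keys = list(examples.keys())
    batch_size = len(examples[keys[0]])
    batch_of_examples = [{key: examples[key][i] for key in keys} for i in range(batch_size)]

    batch = COLLATE_FUNCTION(batch_of_examples)
    with torch.no_grad():
        out = model(**batch)  # assumes the collator returns a dict of tensors

    # return a dict[str, list] so map can store the measurement as a new column
    return {"measurement": out.logits.tolist()}

dataset = dataset.with_format("torch")
dataset = dataset.map(measure_sth, fn_kwargs={"model": model}, batched=True)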