Weird example of batching in the Dataset.map documentation

In the documentation of Dataset.map (here), the example given under “Batch processing” → “Split long examples” says “Batch processing enables interesting applications such as splitting long sentences into shorter chunks and data augmentation”, with the following code:

def chunk_examples(examples):
    chunks = []
    for sentence in examples["sentence1"]:
        chunks += [sentence[i:i + 50] for i in range(0, len(sentence), 50)]
    return {"chunks": chunks}
chunked_dataset = dataset.map(chunk_examples, batched=True, remove_columns=dataset.column_names)

This example looks weird to me, because the function seems to chunk each row independently, so it shouldn’t matter whether we process the rows one at a time or in a batched way. Did I misunderstand anything here?

By default, map takes one example as input and must output one example.

But a batched map can take an input batch of size N and return an output batch of a different size M.
The code you provided indeed returns a batch with more examples than the input, which is only possible with batched=True.
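To make the N-in, M-out behavior concrete, here is a minimal sketch that calls the chunking function from the docs directly on a hand-made batch dict (the toy strings are my own; the column name sentence1 follows the example above):

```python
def chunk_examples(examples):
    chunks = []
    for sentence in examples["sentence1"]:
        # split each sentence into chunks of at most 50 characters
        chunks += [sentence[i:i + 50] for i in range(0, len(sentence), 50)]
    return {"chunks": chunks}

# a batch of N = 2 input examples...
batch = {"sentence1": ["a" * 120, "b" * 30]}
out = chunk_examples(batch)
# ...becomes M = 4 output examples: 120 chars -> 3 chunks, 30 chars -> 1 chunk
print(len(out["chunks"]))  # 4
```

This is exactly the size change that a non-batched map would reject, since it expects one output example per input example.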


Hi @lhoestq, the example above only splits sentence1 and drops the labels, i.e. the label and idx columns, from chunked_dataset. How does one split a long sentence and keep the labels in a format the Trainer understands?

You can duplicate the label for each chunk:

def chunk_examples(examples):
    chunks = []
    labels = []
    for sentence, label in zip(examples["sentence1"], examples["label"]):
        sentence_chunks = [sentence[i:i + 50] for i in range(0, len(sentence), 50)]
        chunks += sentence_chunks
        labels += [label] * len(sentence_chunks)  # one copy of the label per chunk
    return {"chunk": chunks, "label": labels}
chunked_dataset = dataset.map(chunk_examples, batched=True, remove_columns=dataset.column_names)
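As a quick sanity check outside of map (the toy values below are mine, not from the docs), the duplicated labels stay aligned with the chunks:

```python
def chunk_examples(examples):
    chunks = []
    labels = []
    for sentence, label in zip(examples["sentence1"], examples["label"]):
        sentence_chunks = [sentence[i:i + 50] for i in range(0, len(sentence), 50)]
        chunks += sentence_chunks
        # repeat the sentence-level label once per chunk
        labels += [label] * len(sentence_chunks)
    return {"chunk": chunks, "label": labels}

# "x" * 70 splits into 2 chunks (label 1 twice); "y" * 10 stays whole (label 0 once)
batch = {"sentence1": ["x" * 70, "y" * 10], "label": [1, 0]}
out = chunk_examples(batch)
print(out["label"])  # [1, 1, 0]
```

Since every output column has the same length (3 here), map can assemble the chunks and labels back into rows.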

And if you’re doing NER and have multiple labels per sentence:

def chunk_examples(examples):
    chunks = []
    chunk_labels = []
    for sentence, labels in zip(examples["sentence1"], examples["labels"]):
        # sentence and labels must be aligned sequences of the same length
        chunks += [sentence[i:i + 50] for i in range(0, len(sentence), 50)]
        chunk_labels += [labels[i:i + 50] for i in range(0, len(labels), 50)]
    return {"chunk": chunks, "labels": chunk_labels}
chunked_dataset = dataset.map(chunk_examples, batched=True, remove_columns=dataset.column_names)

(Note the accumulator is named chunk_labels here so it isn’t shadowed by the per-sentence labels loop variable.)
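The same kind of check works for the NER case. This sketch uses made-up token/tag lists of my own, and names the accumulator chunk_labels to keep it distinct from the per-sentence labels; slicing both with the same stride only makes sense when each sentence and its labels have the same length:

```python
def chunk_examples(examples):
    chunks = []
    chunk_labels = []
    for sentence, labels in zip(examples["sentence1"], examples["labels"]):
        # slice tokens and their tags with the same stride so they stay aligned
        chunks += [sentence[i:i + 50] for i in range(0, len(sentence), 50)]
        chunk_labels += [labels[i:i + 50] for i in range(0, len(labels), 50)]
    return {"chunk": chunks, "labels": chunk_labels}

# 60 tokens with one tag per token -> 2 chunks, each with a matching tag slice
tokens = ["tok"] * 60
tags = ["O"] * 60
out = chunk_examples({"sentence1": [tokens], "labels": [tags]})
for chunk, labels in zip(out["chunk"], out["labels"]):
    assert len(chunk) == len(labels)
print([len(c) for c in out["chunk"]])  # [50, 10]
```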