Homogeneous batches from list of IterableDatasets

I have a list of IterableDatasets and would like to return homogeneous batches, meaning that each batch should only contain examples from one dataset at a time. I am running training with the HF Trainer + accelerate for multi-GPU training.
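For concreteness, here is a minimal toy setup that the snippets below can be read against (the generators, field names and contents are just illustrative, not my real data):

from datasets import IterableDataset, Features, Value

features = Features({"count": Value("int64"), "doc": Value("string")})

def gen_a():
    for i in range(8):
        yield {"count": i, "doc": "from dataset A"}

def gen_b():
    for i in range(100, 108):
        yield {"count": i, "doc": "from dataset B"}

# a list of IterableDatasets, as described above
datasets = [
    IterableDataset.from_generator(gen_a, features=features),
    IterableDataset.from_generator(gen_b, features=features),
]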

Hi! Have you tried this?

from datasets import interleave_datasets
from torch.utils.data import DataLoader

ds = interleave_datasets(datasets)
dataloader = DataLoader(ds, batch_size=len(datasets))

This will just sample randomly from both datasets, and thus not create batches that consist only of examples from a single dataset…

Ah sorry, I misread; I thought you asked for batches with examples from every dataset.

In your case you want all examples in a batch to come from the same dataset:

# pre-batch each dataset, then interleave; with batch_size=None the DataLoader
# yields the pre-built (single-dataset) batches unchanged
datasets = [dataset.batch(batch_size, drop_last_batch=True) for dataset in datasets]
ds = interleave_datasets(datasets)
dataloader = DataLoader(ds, batch_size=None)
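To sanity-check that the interleaved stream really is homogeneous, a quick sketch using the toy datasets from the setup above (the "doc" field marks the source dataset) could look like this:

from datasets import interleave_datasets
from torch.utils.data import DataLoader

batch_size = 4
batched = [d.batch(batch_size, drop_last_batch=True) for d in datasets]
ds = interleave_datasets(batched)
dataloader = DataLoader(ds, batch_size=None)

for batch in dataloader:
    # each yielded batch should contain examples from a single source dataset only
    assert len(set(batch["doc"])) == 1, batch["doc"]
    print(batch["count"], batch["doc"][0])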

I found a solution:

from datasets import IterableDataset, Features, Value
import datasets
from transformers.trainer_pt_utils import IterableDatasetShard
from torch.utils.data import DataLoader

features = Features({
    'count': Value('int64'),
    'doc': Value('string'),  # allows string or None
})

def count_generator():
    for i in range(201):
        yield {"count": i, "doc": "text"}

def count_generator2():
    for i in range(1000, 1201):
        yield {"count": i, "doc": ""}

batch_size = 4

# pre-batch each dataset so every yielded element is already a full, single-source batch
dataset = IterableDataset.from_generator(generator=count_generator, features=features).batch(batch_size=batch_size)
dataset2 = IterableDataset.from_generator(generator=count_generator2, features=features).batch(batch_size=batch_size)
interleaved_dataset = datasets.interleave_datasets([dataset, dataset2])

# shard across processes; process_index must be in [0, num_processes)
sharded_dataset = IterableDatasetShard(
    interleaved_dataset,
    batch_size=1,
    num_processes=1,
    process_index=0,
)

# batch_size=1 because each element is already a pre-built batch
# (the default collate_fn wraps every field in an extra leading dimension of size 1)
data_loader = DataLoader(
    sharded_dataset,
    batch_size=1,
)

for batch in data_loader:
    print(batch)

But it is a lot slower than just interleaving the datasets and passing the result to the HF Trainer (heterogeneous batches).
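One plausible contributor to the slowdown (an assumption, not something I have profiled): with batch_size=1 the DataLoader's default collate_fn re-collates every pre-built batch and wraps each field in an extra leading dimension of size 1, which adds per-step overhead and produces awkwardly shaped outputs. A minimal sketch of the passthrough variant, reusing sharded_dataset from the snippet above:

# batch_size=None disables automatic batching, so each pre-built
# single-source batch is handed through without re-collation
data_loader = DataLoader(
    sharded_dataset,
    batch_size=None,
)

for batch in data_loader:
    print(batch)  # roughly the dict produced by IterableDataset.batch(...)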

To be more precise, I'm overriding get_train_dataloader in my Trainer:

def get_train_dataloader(self):
    train_dataset = self._remove_unused_columns(self.train_dataset)
    if isinstance(self.train_dataset, IterableDataset):
        # wrap the cleaned dataset (not self.train_dataset) so the column removal takes effect
        train_dataset = IterableDatasetShard(
            train_dataset,
            batch_size=1,
            num_processes=self.args.world_size,
            process_index=self.args.process_index,
        )
        return DataLoader(
            train_dataset,
            batch_size=None,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )
    else:
        return super().get_train_dataloader()

I'm using accelerate; passing batch_size=None is not possible in this case.
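A possible workaround, sketched here under the assumption that accelerate only needs an integer batch_size on the DataLoader: keep batch_size=1 and move the unwrapping of the size-1 wrapper into a custom collate function, so the data collator still receives one pre-built, single-source batch per step. Both _unwrap_and_collate and the HomogeneousBatchTrainer class name are hypothetical, not part of transformers or accelerate, and this is untested with multi-GPU runs:

from functools import partial

from datasets import IterableDataset
from torch.utils.data import DataLoader
from transformers import Trainer
from transformers.trainer_pt_utils import IterableDatasetShard


def _unwrap_and_collate(samples, data_collator):
    # with batch_size=1, `samples` is a list holding exactly one element:
    # a pre-built, single-source batch; unwrap it and apply the collator,
    # which mirrors what batch_size=None + collate_fn=data_collator would do
    return data_collator(samples[0])


class HomogeneousBatchTrainer(Trainer):  # hypothetical subclass name
    def get_train_dataloader(self):
        train_dataset = self._remove_unused_columns(self.train_dataset)
        if not isinstance(self.train_dataset, IterableDataset):
            return super().get_train_dataloader()

        train_dataset = IterableDatasetShard(
            train_dataset,
            batch_size=1,
            num_processes=self.args.world_size,
            process_index=self.args.process_index,
        )
        return DataLoader(
            train_dataset,
            batch_size=1,  # an integer batch size, so accelerate can prepare the dataloader
            collate_fn=partial(_unwrap_and_collate, data_collator=self.data_collator),
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )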