I have a list of IterableDatasets and would like to return homogeneous batches, meaning that each batch should only contain examples from one dataset at a time. I am running training with the HF Trainer + accelerate for multi-GPU training.
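For concreteness, a minimal sketch of the setup (the file names below are just placeholders for my real data sources):

from datasets import load_dataset

# placeholder streaming sources standing in for my real IterableDatasets
ds_a = load_dataset("json", data_files="a.jsonl", split="train", streaming=True)
ds_b = load_dataset("json", data_files="b.jsonl", split="train", streaming=True)
my_datasets = [ds_a, ds_b]
# goal: every batch the Trainer sees should come from exactly one of these datasets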
Hi! Have you tried this?
ds = interleave_datasets(datasets)
dataloader = DataLoader(ds, batch_size=len(datasets))
This will just sample randomly from both datasets, and thus not create batches that consist only of examples from a single dataset…
Ah sorry, I misread, I thought you asked for batches with examples from every dataset.
In your case you want all the examples in a batch to come from the same dataset:
from datasets import interleave_datasets
from torch.utils.data import DataLoader

datasets = [dataset.batch(batch_size, drop_last_batch=True) for dataset in datasets]  # pre-batch each dataset separately
ds = interleave_datasets(datasets)
dataloader = DataLoader(ds, batch_size=None)  # batch_size=None yields the pre-built batches as-is
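A quick sketch to check that this really yields homogeneous batches (the toy generators below are mine, not from the answer above):

from datasets import IterableDataset, interleave_datasets
from torch.utils.data import DataLoader

def gen_a():
    for i in range(8):
        yield {"x": i, "src": "a"}

def gen_b():
    for i in range(100, 108):
        yield {"x": i, "src": "b"}

batch_size = 4
parts = [
    IterableDataset.from_generator(gen_a).batch(batch_size, drop_last_batch=True),
    IterableDataset.from_generator(gen_b).batch(batch_size, drop_last_batch=True),
]
ds = interleave_datasets(parts)
dataloader = DataLoader(ds, batch_size=None)

for batch in dataloader:
    # each item is already a full batch from a single source,
    # e.g. {'x': [0, 1, 2, 3], 'src': ['a', 'a', 'a', 'a']}
    assert len(set(batch["src"])) == 1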
I found a solution:
from datasets import IterableDataset, Features, Value, interleave_datasets
from transformers.trainer_pt_utils import IterableDatasetShard
from torch.utils.data import DataLoader

features = Features({
    'count': Value('int64'),
    'doc': Value('string'),  # allows string or None
})

def count_generator():
    for i in range(201):
        yield {"count": i, "doc": "text"}

def count_generator2():
    for i in range(1000, 1201):
        yield {"count": i, "doc": ""}

batch_size = 4
dataset = IterableDataset.from_generator(generator=count_generator, features=features).batch(batch_size=batch_size)
dataset2 = IterableDataset.from_generator(generator=count_generator2, features=features).batch(batch_size=batch_size)

interleaved_dataset = interleave_datasets([dataset, dataset2])

# mirrors the sharding used in the Trainer override below (process_index must be < num_processes);
# the single-process demo DataLoader below iterates the interleaved dataset directly
sharded_dataset = IterableDatasetShard(
    interleaved_dataset,
    batch_size=1,
    num_processes=1,
    process_index=0,
)

data_loader = DataLoader(
    interleaved_dataset,
    batch_size=1,
)

for batch in data_loader:
    print(batch)
But it is a lot slower than just interleaving the datasets and passing the result to the HF Trainer (heterogeneous batches).
To be more precise, I'm overriding get_train_dataloader in my Trainer:
def get_train_dataloader(self):
    train_dataset = self._remove_unused_columns(self.train_dataset)
    if isinstance(self.train_dataset, IterableDataset):
        train_dataset = IterableDatasetShard(
            train_dataset,  # wrap the column-stripped dataset, not self.train_dataset
            batch_size=1,
            num_processes=self.args.world_size,
            process_index=self.args.process_index,
        )
        return DataLoader(
            train_dataset,
            batch_size=None,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )
    else:
        return super().get_train_dataloader()
I'm using accelerate, so passing batch_size=None is not possible in this case.
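One idea I haven't tested yet (the unwrap_prebatched helper below is hypothetical, and I'm not sure how it interacts with accelerate's batch dispatching): keep a real batch_size=1 so accelerate can prepare the dataloader, and undo the extra nesting in a custom collate_fn before handing the examples to the normal data collator, i.e. replace the DataLoader call in get_train_dataloader above with something like:

def unwrap_prebatched(features):
    # with the pre-batched IterableDataset and DataLoader(batch_size=1), `features`
    # is a list containing exactly one element: the dict of lists built by .batch(...)
    batch = features[0]
    # rebuild the list of per-example dicts for the usual data collator
    return [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]

return DataLoader(
    train_dataset,
    batch_size=1,  # a real int instead of None, so accelerate accepts it
    collate_fn=lambda features: self.data_collator(unwrap_prebatched(features)),
    num_workers=self.args.dataloader_num_workers,
    pin_memory=self.args.dataloader_pin_memory,
)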