How to handle IterableDataset with HuggingFace trainer and num_workers in DDP setup

Hey, I am trying to train a custom model (which inherits from PreTrainedModel) on an IterableDataset using the HuggingFace Trainer in a DDP setup, and I have a couple of questions on how best to do it, as well as some observations.

  1. I know there is a datasets.distributed.split_dataset_by_node() function that can distribute the shards across nodes, but should I call it myself inside the Trainer's get_train_dataloader(), or is it handled automatically?
  2. I noticed that training fails with the error below if I do not add the argument `accelerator_config=AcceleratorConfig(dispatch_batches=False, split_batches=False)` to the `TrainingArguments`:

     ```
     RuntimeError: You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`.either pass `dispatch_batches=False` and have each process fetch its own batch  or pass `split_batches=True`. By doing so, the main process will fetch a full batch and slice it into `num_processes` batches for each process.
     ```

     Shouldn't this be handled automatically? Am I doing something wrong in how I specify it?

  3. How do I set num_workers for the DataLoader with DDP? When I set num_workers > n_gpus (i.e. world_size), I get the error `RuntimeError: DataLoader worker (pid(s) 572878) exited unexpectedly`. This is a problem for me because the optimal setting on a single GPU is num_workers=16, so with only one worker per GPU/process, training is slower than on a single GPU. I also checked with nvidia-smi that, because of this, the GPUs are idle most of the time.
  4. I noticed that training is faster with this:

     ```python
     def get_train_dataloader(self):
         # dataloader_params is built as in the full method below
         dataset = split_dataset_by_node(self.train_dataset, rank=self.args.local_rank, world_size=self.args.world_size)
         return DataLoader(dataset, **dataloader_params)
     ```

     than with this:

     ```python
     def get_train_dataloader(self):
         dataset = split_dataset_by_node(self.train_dataset, rank=self.args.local_rank, world_size=self.args.world_size)
         return self.accelerator.prepare(DataLoader(dataset, **dataloader_params))
     ```

     Any idea why this is?
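For the shard-splitting and num_workers questions above, a rough mental model of how a sharded streaming dataset gets divided may help. The sketch below is a simplified, stdlib-only model, assuming (as the datasets streaming docs describe) that split_dataset_by_node hands out whole shards round-robin when the shard count is divisible by the world size, and that each rank's shards are then spread round-robin over its DataLoader workers. The function names are hypothetical and this is not the library's implementation; the file counts come from later in this thread, while num_workers=60 per process is an assumption for illustration.

```python
from typing import List


def shards_for_rank(shards: List[str], rank: int, world_size: int) -> List[str]:
    """Round-robin whole shards across ranks (simplified model)."""
    if len(shards) % world_size != 0:
        # The real library falls back to per-example skipping in this case.
        raise ValueError("shard count not divisible by world_size")
    return shards[rank::world_size]


def shards_per_worker(rank_shards: List[str], num_workers: int) -> List[List[str]]:
    """Round-robin one rank's shards across its DataLoader workers."""
    return [rank_shards[w::num_workers] for w in range(num_workers)]


def busy_workers(n_files: int, world_size: int, num_workers: int) -> int:
    """Number of workers on rank 0 that receive at least one shard."""
    files = [f"part-{i:04d}.parquet" for i in range(n_files)]
    rank_shards = shards_for_rank(files, rank=0, world_size=world_size)
    return sum(1 for assigned in shards_per_worker(rank_shards, num_workers) if assigned)


if __name__ == "__main__":
    # 100 files over 4 GPUs: 25 shards per rank, so at most 25 workers per rank have work.
    print(busy_workers(100, world_size=4, num_workers=60))
    # 800 files over 4 GPUs: 200 shards per rank, enough for all 60 workers.
    print(busy_workers(800, world_size=4, num_workers=60))
```

Under this model, raising num_workers above n_shards / world_size adds workers that have no shard to read, which is why more, smaller files can raise the useful worker count.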

Here is what my get_train_dataloader() override in the Trainer currently looks like:

```python
import datasets
import torch
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader
from transformers import Trainer
from transformers.trainer_utils import seed_worker
from transformers.utils import is_datasets_available


class CustomTrainer(Trainer):  # class name is illustrative
    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            # Map-style dataset: use the Trainer's (distributed-aware) sampler
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
        elif self.args.world_size > 1:
            # Streaming dataset: if using DDP, we need to split the dataset by node
            # (a streaming datasets.IterableDataset is also a torch IterableDataset,
            # so this must live in the IterableDataset branch to take effect)
            train_dataset = split_dataset_by_node(
                train_dataset, rank=self.args.local_rank, world_size=self.args.world_size
            )
        return DataLoader(train_dataset, **dataloader_params)
```

This is how my IterableDataset is created:

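```python
from datasets import load_dataset

# data_files and transform_fn are defined elsewhere in my script
train_dataset = load_dataset("parquet", data_files=data_files, split="train", streaming=True).map(
    transform_fn, batched=False, with_indices=False
)
```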

and I launch training on 4 GPUs with:

```
torchrun --nproc_per_node 4 <args>
```
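torchrun exports each process's distributed coordinates as environment variables (RANK, LOCAL_RANK, WORLD_SIZE). The minimal sketch below reads them; dist_coords and DistCoords are hypothetical helper names. The distinction matters for the code in this post that passes self.args.local_rank to split_dataset_by_node: on a single node, as with --nproc_per_node 4 here, LOCAL_RANK equals RANK, but across several nodes the split would need the global rank (which TrainingArguments exposes as process_index).

```python
import os
from typing import Mapping, NamedTuple


class DistCoords(NamedTuple):
    rank: int        # global rank across all nodes
    local_rank: int  # rank within this node (typically the GPU index)
    world_size: int  # total number of processes


def dist_coords(env: Mapping[str, str] = os.environ) -> DistCoords:
    """Read the variables torchrun sets; default to a single process otherwise."""
    return DistCoords(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )


if __name__ == "__main__":
    # Simulated second process on the second of two 4-GPU nodes.
    node2_proc1 = {"RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8"}
    print(dist_coords(node2_proc1))  # global rank 5, local rank 1
    print(dist_coords({}))           # not launched by torchrun
```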

I use:


Number 3 is the biggest issue! DDP training is currently slower than single-GPU training because each process uses only one worker instead of 16, and there is significant on-the-fly processing, hence the importance of having multiple workers.

I think self.accelerator from accelerate takes care of skipping examples from the IterableDataset depending on the node rank (@muellerzr, can you confirm? I found IterableDatasetShard in the docs).

Otherwise, using split_dataset_by_node does the job.

Note that datasets.distributed.split_dataset_by_node distributes the shards across nodes (if possible) instead of simply skipping examples and wasting resources allocated to on-the-fly processing, which can be useful.
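To make the difference between skipping examples and distributing shards concrete, here is a toy, stdlib-only model that counts how many times an expensive per-example transform runs on one rank under each strategy. It is a simplification under stated assumptions: skipping is modeled per example, whereas accelerate's IterableDatasetShard slices per batch, and all names are hypothetical.

```python
from typing import List


class CountingTransform:
    """Stand-in for an expensive .map() function; counts how often it runs."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, x: int) -> int:
        self.calls += 1
        return x * 2


def skip_examples(shards: List[List[int]], rank: int, world_size: int,
                  fn: CountingTransform) -> List[int]:
    """Each rank streams and transforms ALL shards, then keeps every world_size-th example."""
    stream = (fn(x) for shard in shards for x in shard)
    return [x for i, x in enumerate(stream) if i % world_size == rank]


def distribute_shards(shards: List[List[int]], rank: int, world_size: int,
                      fn: CountingTransform) -> List[int]:
    """Each rank streams and transforms only its own shards."""
    return [fn(x) for shard in shards[rank::world_size] for x in shard]


if __name__ == "__main__":
    shards = [list(range(i * 100, (i + 1) * 100)) for i in range(8)]  # 8 files x 100 rows
    for strategy in (skip_examples, distribute_shards):
        fn = CountingTransform()
        kept = strategy(shards, rank=0, world_size=4, fn=fn)
        print(strategy.__name__, "kept", len(kept), "transform calls", fn.calls)
```

Both strategies give rank 0 a quarter of the rows, but with skipping the transform runs on every row on every rank, which is the wasted on-the-fly processing mentioned above.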

Thanks @lhoestq!

In this case, I will use both self.accelerator and datasets.distributed.split_dataset_by_node.

Some other things that helped me with efficient DDP training on a large iterable dataset (5TB+):

  1. Doing all of the preprocessing beforehand and minimising the columns read by the dataset to save memory.
  2. Splitting large files (in my case each file is a shard) into smaller ones. So now, instead of 100 large files to stream from, I have ~800 smaller files, and each worker is assigned its own shards (i.e. files).

Because of this, I can now run training with many more workers (60 instead of 12) without getting the error (RuntimeError: DataLoader worker (pid(s) 572878) exited unexpectedly), which makes DDP training much faster.

> In this case, I will use both self.accelerator and datasets.distributed.split_dataset_by_node.

This might be redundant? cc @muellerzr in case self.accelerator would also skip examples on a dataset that is already split by node.


The accelerator (self.accelerator) does skip examples on a dataset that is already split by node, as you pointed out, so there is no need for datasets.distributed.split_dataset_by_node.
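Taking that reply at face value, applying both mechanisms would not just be redundant but would drop data. A toy stdlib model of the consequence, assuming the node split hands whole shards to ranks round-robin and the accelerator then keeps every world_size-th batch of whatever stream each rank iterates over (helper names are hypothetical, not accelerate's code):

```python
from typing import List


def split_by_node(rows: List[int], rank: int, world_size: int, shard_size: int) -> List[int]:
    """Give each rank every world_size-th shard of `shard_size` consecutive rows."""
    shards = [rows[i:i + shard_size] for i in range(0, len(rows), shard_size)]
    return [x for shard in shards[rank::world_size] for x in shard]


def accelerator_skip(rows: List[int], rank: int, world_size: int, batch_size: int) -> List[int]:
    """Keep every world_size-th batch of the stream this rank iterates over."""
    batches = [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]
    return [x for batch in batches[rank::world_size] for x in batch]


def fraction_seen(n_rows: int, world_size: int, both: bool) -> float:
    """Fraction of all rows that some rank actually trains on."""
    rows = list(range(n_rows))
    seen = set()
    for rank in range(world_size):
        local = split_by_node(rows, rank, world_size, shard_size=100) if both else rows
        seen.update(accelerator_skip(local, rank, world_size, batch_size=10))
    return len(seen) / n_rows


if __name__ == "__main__":
    print(fraction_seen(800, world_size=4, both=False))  # 1.0
    print(fraction_seen(800, world_size=4, both=True))   # 0.25
```

Under these assumptions, each rank trains on a quarter of its own quarter, so only 1/world_size of the rows would be used overall.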