Trainer default distributed training behaviour

Here’s the code I am using:

training_args = Seq2SeqTrainingArguments(
    output_dir,
    per_device_train_batch_size=16,
    num_train_epochs=1,
    logging_dir=f"{output_dir}/logs",
    report_to="tensorboard",
    save_strategy="no",
    eval_strategy="no"
)

trainer = Seq2SeqTrainer(
    model,
    args=training_args,
    train_dataset=dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer)
)

By default, does it use DDP for training?

And since I am using an IterableDataset, will it handle sending different batches to different GPUs by default? It should, right?


Just to be sure, I searched for it, and it seems that there are some complicated cases.

The following is AI-generated code. I’m not sure if it works.

by Hugging Chat: HuggingChat


The behavior of your code regarding distributed training and data handling depends on several factors. Here’s a detailed breakdown:


1. Does the code use DDP (Distributed Data Parallel) by default?

No, your code does not automatically use DDP unless the script is launched for it. Here’s why:

  • Default Behavior: The Hugging Face Trainer only uses DDP when the script is started by a distributed launcher (e.g., torchrun, accelerate launch, or deepspeed), which sets local_rank for each process. If you run plain python your_script.py on a multi-GPU machine, the Trainer falls back to nn.DataParallel (DP), not DDP.
  • How to Enable DDP:
    • Don’t set local_rank by hand; the launcher passes it (via the LOCAL_RANK environment variable or the --local_rank argument) and TrainingArguments picks it up automatically.
    • Run your script with:
      torchrun --nproc_per_node=NUM_GPUS your_script.py
      
      (python -m torch.distributed.launch still works, but it is deprecated in favor of torchrun.)
    • Alternatively, use accelerate launch or deepspeed for more advanced setups. A quick way to verify which mode you ended up in is sketched right after this list.
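
For a quick sanity check inside the script, you can print how the Trainer interprets the launch environment. A minimal sketch, assuming a reasonably recent transformers version where TrainingArguments exposes parallel_mode, world_size, and process_index:

# Run this after constructing training_args.
print("parallel_mode:", training_args.parallel_mode)    # ParallelMode.DISTRIBUTED under torchrun / accelerate launch
print("world_size:", training_args.world_size)          # number of training processes
print("process_index:", training_args.process_index)    # rank of this process

If parallel_mode comes back as NOT_DISTRIBUTED on a multi-GPU machine, you are getting DataParallel rather than DDP.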

2. Does the IterableDataset handle distributing batches to different GPUs?

Not by itself. A plain torch.utils.data.IterableDataset has no notion of process ranks, so without some form of partitioning every process would iterate over the same stream. Newer Trainer versions do wrap iterable datasets for distributed runs (through Accelerate’s dataloader preparation), but it is worth verifying this for your version, and you can always shard explicitly:

  • Why? IterableDataset streams data sequentially, so all processes would otherwise see identical batches unless the stream is explicitly partitioned per rank.
  • How to Fix:
    • Use torch.utils.data.distributed.DistributedSampler for map-style datasets (it does not work with IterableDataset), or implement custom sharding for an IterableDataset (an alternative based on the datasets library is sketched at the end of this post).
    • Example for sharding an IterableDataset:
      from torch.utils.data import IterableDataset

      class ShardedIterableDataset(IterableDataset):
          """Wrap an iterable dataset, keeping only every world_size-th item for this rank."""

          def __init__(self, dataset, rank=0, world_size=1):
              self.dataset = dataset
              self.rank = rank
              self.world_size = world_size

          def __iter__(self):
              # Round-robin assignment: item i goes to process (i % world_size).
              for idx, item in enumerate(self.dataset):
                  if idx % self.world_size == self.rank:
                      yield item

    • Pass the rank and world size to your dataset (these calls require the distributed
      process group to be initialized, i.e. the script must run under a launcher):
      from torch.distributed import get_rank, get_world_size

      dataset = ShardedIterableDataset(raw_dataset, rank=get_rank(), world_size=get_world_size())

Key Takeaways

  1. Enable DDP by launching the script with a distributed launcher (torchrun, accelerate launch, or deepspeed); the launcher sets local_rank for you.
  2. Shard your IterableDataset (manually or via the datasets library) to avoid redundant data loading across GPUs.
  3. Verify your setup at runtime, e.g. by printing training_args.parallel_mode and training_args.world_size.

Example Fix for Distributed Training

# No need to set local_rank yourself: launch with
#   torchrun --nproc_per_node=NUM_GPUS your_script.py
# and TrainingArguments picks up the rank from the environment.
training_args = Seq2SeqTrainingArguments(
    output_dir,
    per_device_train_batch_size=16,
    num_train_epochs=1,
    logging_dir=f"{output_dir}/logs",
    report_to="tensorboard",
    save_strategy="no",
    eval_strategy="no",
)

# Shard the dataset manually (get_rank/get_world_size need the process group,
# which is set up when the script runs under a distributed launcher).
from torch.distributed import get_rank, get_world_size

def get_sharded_dataset(rank, world_size):
    return ShardedIterableDataset(dataset, rank, world_size)

# Pass the sharded dataset to the Trainer
trainer = Seq2SeqTrainer(
    model,
    args=training_args,
    train_dataset=get_sharded_dataset(get_rank(), get_world_size()),
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer)
)
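
If your iterable dataset comes from the Hugging Face datasets library in streaming mode, a simpler alternative to the hand-written wrapper is datasets.distributed.split_dataset_by_node. A minimal sketch, assuming raw_dataset is a streaming datasets.IterableDataset and the script runs under a distributed launcher:

# Alternative: let the datasets library split the stream across processes.
import torch.distributed as dist
from datasets.distributed import split_dataset_by_node

train_dataset = split_dataset_by_node(
    raw_dataset,                       # streaming datasets.IterableDataset (assumed name)
    rank=dist.get_rank(),              # rank of this process
    world_size=dist.get_world_size(),  # total number of processes
)

trainer = Seq2SeqTrainer(
    model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
)

split_dataset_by_node assigns different shards (or, failing that, interleaved examples) to each rank, so every GPU streams a different part of the data.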

By addressing these points, you’ll ensure efficient distributed training with proper data sharding. Let me know if you need further clarification!