Single-node multi-GPU FlanT5 fine-tuning using HF Datasets and HF Trainer

I have been stuck on this for quite a while and would really appreciate it if someone could help.

So, I am trying to fine-tune FlanT5-Large on a dataset using HF Datasets and the HF Trainer. I want to do it in a distributed fashion since my machine has 2 GPUs.

Attempt 1
This was the script I used - single_node_multi_gpu_v0.py · GitHub

To test if this works for a single GPU, I did
CUDA_VISIBLE_DEVICES=0 python run_lora_train.py --train_file_path datasets/train_sample.csv --valid_file_path datasets/valid_sample.csv --output_dir test_flant5_large
and it worked fine.

It has been mentioned on a few PyTorch/HF forum threads that extending this to a single-node multi-GPU setting only requires changing how the script is launched. So I did this -
torchrun --nproc_per_node 2 run_lora_train.py --train_file_path datasets/train_sample.csv --valid_file_path datasets/valid_sample.csv --output_dir test_flant5_large

I get this error -

  File "run_lora_train.py", line 124, in <module>
    trainer.train()
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/transformers/trainer.py", line 1637, in train
    ignore_keys_for_eval=ignore_keys_for_eval,
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/transformers/trainer.py", line 1720, in _inner_training_loop
    model = self._wrap_model(self.model_wrapped)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/transformers/trainer.py", line 1549, in _wrap_model
    **kwargs,
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 589, in __init__
    {p.device for p in module.parameters()},
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 686, in _log_and_throw
    raise err_type(err_msg)
ValueError: DistributedDataParallel device_ids and output_device arguments only work with single-device/multiple-device GPU modules or CPU modules, but got device_ids [0], output_device 0, and module parameters {device(type='cuda', index=0), device(type='cuda', index=1)}.
/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/transformers/optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,
Traceback (most recent call last):
  File "run_lora_train.py", line 124, in <module>
    trainer.train()
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/transformers/trainer.py", line 1637, in train
    ignore_keys_for_eval=ignore_keys_for_eval,
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/transformers/trainer.py", line 1720, in _inner_training_loop
    model = self._wrap_model(self.model_wrapped)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/transformers/trainer.py", line 1549, in _wrap_model
    **kwargs,
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 589, in __init__
    {p.device for p in module.parameters()},
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 686, in _log_and_throw
    raise err_type(err_msg)
ValueError: DistributedDataParallel device_ids and output_device arguments only work with single-device/multiple-device GPU modules or CPU modules, but got device_ids [1], output_device 1, and module parameters {device(type='cuda', index=0), device(type='cuda', index=1)}.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 16406) of binary: /home/shreyansh/miniconda/envs/shreyansh-env/bin/python
Traceback (most recent call last):
  File "/home/shreyansh/miniconda/envs/shreyansh-env/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/distributed/run.py", line 762, in main
    run(args)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/distributed/run.py", line 756, in run
    )(*cmd_args)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 248, in launch_agent
    failures=result.failures,
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

Attempt 2

I decided to wrap the model in DDP explicitly (even though the HF Trainer normally handles that, as the forum post above also says) with the following script - single_node_multi_gpu_v1.py · GitHub
where I only added line 107; the rest of the script is the same.
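
Roughly, the change amounts to the sketch below (paraphrased, not the exact gist; the rank handling is an assumption on my part):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun sets LOCAL_RANK for each spawned process
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
if not dist.is_initialized():
    dist.init_process_group(backend="nccl")

# `model` is the FlanT5 model loaded earlier in the script;
# move it to this process's GPU and wrap it in DDP before building the Trainer
model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)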

This leads to a different error -

  File "run_lora_train.py", line 124, in <module>
    trainer.train()
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/transformers/trainer.py", line 1637, in train
    ignore_keys_for_eval=ignore_keys_for_eval,
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/transformers/trainer.py", line 1872, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/datasets/arrow_dataset.py", line 2357, in __getitem__
    key,
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/datasets/arrow_dataset.py", line 2340, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/datasets/formatting/formatting.py", line 463, in query_table
    _check_valid_index_key(key, size)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/datasets/formatting/formatting.py", line 406, in _check_valid_index_key
    raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
IndexError: Invalid key: 13 is out of bounds for size 0
  0%|          | 0/40 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "run_lora_train.py", line 124, in <module>
    trainer.train()
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/transformers/trainer.py", line 1637, in train
    ignore_keys_for_eval=ignore_keys_for_eval,
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/transformers/trainer.py", line 1872, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/datasets/arrow_dataset.py", line 2357, in __getitem__
    key,
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/datasets/arrow_dataset.py", line 2340, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/datasets/formatting/formatting.py", line 463, in query_table
    _check_valid_index_key(key, size)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/datasets/formatting/formatting.py", line 406, in _check_valid_index_key
    raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
IndexError: Invalid key: 6 is out of bounds for size 0
  0%|          | 0/40 [00:00<?, ?it/s]
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 16640) of binary: /home/shreyansh/miniconda/envs/shreyansh-env/bin/python
Traceback (most recent call last):
  File "/home/shreyansh/miniconda/envs/shreyansh-env/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/distributed/run.py", line 762, in main
    run(args)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/distributed/run.py", line 756, in run
    )(*cmd_args)
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/shreyansh/miniconda/envs/shreyansh-env/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 248, in launch_agent
    failures=result.failures,

Any suggestions on how I can fix these issues?


Were you able to fix this issue? I’ve been stuck on the same thing for quite some time now. Not sure what the issue is.


Encountered the same issue.

@mariosasko @muellerzr
Sorry for pinging you directly in this discussion, but I've encountered the same issue and have been stuck on it for a while now. I've tested multiple scripts, and it seems that Hugging Face's Trainer class simply doesn't work for single-node multi-GPU setups. When I test with a single GPU, training runs without a problem. However, when I run with multiple GPUs, training deadlocks and makes zero progress (even if given enough time).

As @shreyansh26 said in his post, I also attempted to use PyTorch's torchrun, but encountered that vague error. Could it be that something is wrong with Hugging Face's Trainer class for multi-GPU setups, or is it specifically a script issue? For reference, here is the Falcon-7b fine-tuning script that works with a single GPU but deadlocks with multiple:

from datasets import load_dataset

dataset_name = "timdettmers/openassistant-guanaco"
dataset = load_dataset(dataset_name, split="train")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "ybelkada/falcon-7b-sharded-bf16"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True
)
model.config.use_cache = False

print(model.hf_device_map)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

from peft import LoraConfig

lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
)

from transformers import TrainingArguments

output_dir = "./results"
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
optim = "paged_adamw_32bit"
save_steps = 10
logging_steps = 10
learning_rate = 2e-4
max_grad_norm = 0.3
max_steps = 500
warmup_ratio = 0.03
lr_scheduler_type = "constant"

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    fp16=True,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=False,
    lr_scheduler_type=lr_scheduler_type,
)

from trl import SFTTrainer

max_seq_length = 512


trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
)


# cast LayerNorm-style modules to fp32 for stability when training with fp16
for name, module in trainer.model.named_modules():
    if "norm" in name:
        module.to(torch.float32)


trainer.train()
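
For context, the multi-GPU attempt launches this same script through torchrun, along the same lines as above (the script filename here is just a placeholder):

torchrun --nproc_per_node 2 falcon_sft.py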

For everyone else in this discussion, I managed to find the section of the Transformers documentation that explains how to use the Accelerate API in conjunction with Hugging Face's Trainer (it's at the bottom of the page). Here's the link:
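
In the meantime, a minimal sketch of what that page describes, i.e. driving the unchanged Trainer script through the accelerate CLI (the script name is a placeholder; the config answers are just what a 2-GPU DDP setup would use):

# one-time interactive setup: multi-GPU, 1 machine, 2 processes, no DeepSpeed/FSDP
accelerate config
# then launch the same training script via accelerate instead of torchrun
accelerate launch your_training_script.py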