Not able to scale Trainer code to single-node multi-GPU

I am using the code provided in this blog post. The Google Colab notebook is linked here.

The notebook runs perfectly fine on a machine with a single GPU. However, when I run it on a machine with multiple GPUs (n=4, NVIDIA Tesla T4), I get the error shown at the end of this post.

To debug, I set use_cpu=True and the training loop runs fine, as expected, so I am fairly sure the problem is specific to the multi-GPU setup.
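
For reference, a minimal sketch of that debug variant: it uses the same arguments as the block below with only the flags shown changed (fp16=False is included here because mixed precision requires a GPU).

from transformers import Seq2SeqTrainingArguments

# Debug variant: identical to the full arguments below except for these two flags.
debug_args = Seq2SeqTrainingArguments(
    output_dir="./speecht5_tts_voxpopuli_nl",
    # ... all other arguments as in the full block below ...
    use_cpu=True,   # force single-process CPU training; with this the loop runs fine
    fp16=False,     # mixed precision needs a GPU, so it is disabled in the CPU run
)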

Trainer code (full code is in the Colab notebook):

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./speecht5_tts_voxpopuli_nl",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    greater_is_better=False,
    label_names=["labels"],
    push_to_hub=True,
)

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    tokenizer=processor.tokenizer,
)

trainer.train()

/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File <timed eval>:1

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1575, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1573         hf_hub_utils.enable_progress_bars()
   1574 else:
-> 1575     return inner_training_loop(
   1576         args=args,
   1577         resume_from_checkpoint=resume_from_checkpoint,
   1578         trial=trial,
   1579         ignore_keys_for_eval=ignore_keys_for_eval,
   1580     )

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1875, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1872     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   1874 with self.accelerator.accumulate(model):
-> 1875     tr_loss_step = self.training_step(model, inputs)
   1877 if (
   1878     args.logging_nan_inf_filter
   1879     and not is_torch_tpu_available()
   1880     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   1881 ):
   1882     # if loss is nan or inf simply add the average of previous logged losses
   1883     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2740, in Trainer.training_step(self, model, inputs)
   2737     return loss_mb.reduce_mean().detach().to(self.args.device)
   2739 with self.compute_loss_context_manager():
-> 2740     loss = self.compute_loss(model, inputs)
   2742 if self.args.n_gpu > 1:
   2743     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2765, in Trainer.compute_loss(self, model, inputs, return_outputs)
   2763 else:
   2764     labels = None
-> 2765 outputs = model(**inputs)
   2766 # Save past state if it exists
   2767 # TODO: this needs to be fixed and made cleaner later.
   2768 if self.args.past_index >= 0:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:172, in DataParallel.forward(self, *inputs, **kwargs)
    170 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    171 outputs = self.parallel_apply(replicas, inputs, kwargs)
--> 172 return self.gather(outputs, self.output_device)

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:184, in DataParallel.gather(self, outputs, output_device)
    183 def gather(self, outputs, output_device):
--> 184     return gather(outputs, output_device, dim=self.dim)

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:86, in gather(outputs, target_device, dim)
     83 # Recursive function calls like this create reference cycles.
     84 # Setting the function to None clears the refcycle.
     85 try:
---> 86     res = gather_map(outputs)
     87 finally:
     88     gather_map = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:77, in gather.<locals>.gather_map(outputs)
     75     if not all(len(out) == len(d) for d in outputs):
     76         raise ValueError('All dicts must have the same number of keys')
---> 77     return type(out)((k, gather_map([d[k] for d in outputs]))
     78                      for k in out)
     79 if _is_namedtuple(out):
     80     return type(out)._make(map(gather_map, zip(*outputs)))

File <string>:12, in __init__(self, loss, spectrogram, past_key_values, decoder_hidden_states, decoder_attentions, cross_attentions, encoder_last_hidden_state, encoder_hidden_states, encoder_attentions)

File /opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:343, in ModelOutput.__post_init__(self)
    340 # if we provided an iterator as first field and the iterator is a (key, value) iterator
    341 # set the associated fields
    342 if first_field_iterator:
--> 343     for idx, element in enumerate(iterator):
    344         if (
    345             not isinstance(element, (list, tuple))
    346             or not len(element) == 2
    347             or not isinstance(element[0], str)
    348         ):
    349             if idx == 0:
    350                 # If we do not have an iterator of key/values, set it as attribute

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:77, in <genexpr>(.0)
     75     if not all(len(out) == len(d) for d in outputs):
     76         raise ValueError('All dicts must have the same number of keys')
---> 77     return type(out)((k, gather_map([d[k] for d in outputs]))
     78                      for k in out)
     79 if _is_namedtuple(out):
     80     return type(out)._make(map(gather_map, zip(*outputs)))

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:81, in gather.<locals>.gather_map(outputs)
     79 if _is_namedtuple(out):
     80     return type(out)._make(map(gather_map, zip(*outputs)))
---> 81 return type(out)(map(gather_map, zip(*outputs)))

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:71, in gather.<locals>.gather_map(outputs)
     69 out = outputs[0]
     70 if isinstance(out, torch.Tensor):
---> 71     return Gather.apply(target_device, dim, *outputs)
     72 if out is None:
     73     return None

File /opt/conda/lib/python3.10/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs)
    503 if not torch._C._are_functorch_transforms_active():
    504     # See NOTE: [functorch vjp and autograd interaction]
    505     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506     return super().apply(*args, **kwargs)  # type: ignore[misc]
    508 if cls.setup_context == _SingleLevelFunction.setup_context:
    509     raise RuntimeError(
    510         'In order to use an autograd.Function with functorch transforms '
    511         '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
    512         'staticmethod. For more details, please see '
    513         'https://pytorch.org/docs/master/notes/extending.func.html')

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:56, in Gather.forward(ctx, target_device, dim, *inputs)
     54 @staticmethod
     55 def forward(ctx, target_device, dim, *inputs):
---> 56     assert all(i.device.type != 'cpu' for i in inputs), (
     57         'Gather function not implemented for CPU tensors'
     58     )
     59     if (target_device == 'cpu'):
     60         ctx.target_device = 'cpu'

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:56, in <genexpr>(.0)
     54 @staticmethod
     55 def forward(ctx, target_device, dim, *inputs):
---> 56     assert all(i.device.type != 'cpu' for i in inputs), (
     57         'Gather function not implemented for CPU tensors'
     58     )
     59     if (target_device == 'cpu'):
     60         ctx.target_device = 'cpu'

AttributeError: 'NoneType' object has no attribute 'device'
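
From the traceback, the failure happens inside torch.nn.DataParallel's gather step; Trainer wraps the model in DataParallel automatically when more than one GPU is visible and no distributed launcher is used. As another sanity check, here is a minimal sketch (hypothetical, not part of the notebook) that exposes only one GPU so DataParallel is never engaged:

import os

# Hypothetical sanity check, not from the Colab notebook:
# expose a single GPU *before* torch initializes CUDA, so Trainer
# sees n_gpu == 1 and skips the DataParallel wrapping shown in the
# traceback above.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
print(torch.cuda.device_count())  # should now report 1 instead of 4

If training also succeeds under this restriction, that would further suggest the problem is in DataParallel's gather of the model outputs (note the 'NoneType' object has no attribute 'device' inside Gather.forward) rather than in the model or the data.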