I am using the code provided in this blog post; the link to the Google Colab notebook is here.
The notebook runs perfectly fine on a machine with a single GPU. However, when I run it on a machine with multiple GPUs (n=4, NVIDIA Tesla T4), I get the error pasted at the end of this post.
To debug, I set use_cpu=True and the training loop runs fine as expected, so I am fairly sure the problem is specific to multi-GPU training.
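For reference, the only change for that CPU debugging run was adding the flag to the training arguments, roughly like this (the argument list is trimmed here for brevity, and fp16 also has to come off since mixed precision is GPU-only):

import os
from transformers import Seq2SeqTrainingArguments

# Debug run on CPU only: same arguments as in the notebook (trimmed here),
# with use_cpu=True added and fp16 disabled because fp16 is not supported on CPU.
training_args = Seq2SeqTrainingArguments(
    output_dir="./speecht5_tts_voxpopuli_nl",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    max_steps=4000,
    fp16=False,
    use_cpu=True,
)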
Trainer code (full code in the Colab notebook):
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./speecht5_tts_voxpopuli_nl",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    greater_is_better=False,
    label_names=["labels"],
    push_to_hub=True,
)

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    tokenizer=processor.tokenizer,
)
trainer.train()
Error output:

/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
warnings.warn('Was asked to gather along dimension 0, but all '
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
File <timed eval>:1
File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1575, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1573 hf_hub_utils.enable_progress_bars()
1574 else:
-> 1575 return inner_training_loop(
1576 args=args,
1577 resume_from_checkpoint=resume_from_checkpoint,
1578 trial=trial,
1579 ignore_keys_for_eval=ignore_keys_for_eval,
1580 )
File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1875, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
1872 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
1874 with self.accelerator.accumulate(model):
-> 1875 tr_loss_step = self.training_step(model, inputs)
1877 if (
1878 args.logging_nan_inf_filter
1879 and not is_torch_tpu_available()
1880 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
1881 ):
1882 # if loss is nan or inf simply add the average of previous logged losses
1883 tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2740, in Trainer.training_step(self, model, inputs)
2737 return loss_mb.reduce_mean().detach().to(self.args.device)
2739 with self.compute_loss_context_manager():
-> 2740 loss = self.compute_loss(model, inputs)
2742 if self.args.n_gpu > 1:
2743 loss = loss.mean() # mean() to average on multi-gpu parallel training
File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2765, in Trainer.compute_loss(self, model, inputs, return_outputs)
2763 else:
2764 labels = None
-> 2765 outputs = model(**inputs)
2766 # Save past state if it exists
2767 # TODO: this needs to be fixed and made cleaner later.
2768 if self.args.past_index >= 0:
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:172, in DataParallel.forward(self, *inputs, **kwargs)
170 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
171 outputs = self.parallel_apply(replicas, inputs, kwargs)
--> 172 return self.gather(outputs, self.output_device)
File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:184, in DataParallel.gather(self, outputs, output_device)
183 def gather(self, outputs, output_device):
--> 184 return gather(outputs, output_device, dim=self.dim)
File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:86, in gather(outputs, target_device, dim)
83 # Recursive function calls like this create reference cycles.
84 # Setting the function to None clears the refcycle.
85 try:
---> 86 res = gather_map(outputs)
87 finally:
88 gather_map = None
File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:77, in gather.<locals>.gather_map(outputs)
75 if not all(len(out) == len(d) for d in outputs):
76 raise ValueError('All dicts must have the same number of keys')
---> 77 return type(out)((k, gather_map([d[k] for d in outputs]))
78 for k in out)
79 if _is_namedtuple(out):
80 return type(out)._make(map(gather_map, zip(*outputs)))
File <string>:12, in __init__(self, loss, spectrogram, past_key_values, decoder_hidden_states, decoder_attentions, cross_attentions, encoder_last_hidden_state, encoder_hidden_states, encoder_attentions)
File /opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:343, in ModelOutput.__post_init__(self)
340 # if we provided an iterator as first field and the iterator is a (key, value) iterator
341 # set the associated fields
342 if first_field_iterator:
--> 343 for idx, element in enumerate(iterator):
344 if (
345 not isinstance(element, (list, tuple))
346 or not len(element) == 2
347 or not isinstance(element[0], str)
348 ):
349 if idx == 0:
350 # If we do not have an iterator of key/values, set it as attribute
File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:77, in <genexpr>(.0)
75 if not all(len(out) == len(d) for d in outputs):
76 raise ValueError('All dicts must have the same number of keys')
---> 77 return type(out)((k, gather_map([d[k] for d in outputs]))
78 for k in out)
79 if _is_namedtuple(out):
80 return type(out)._make(map(gather_map, zip(*outputs)))
File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:81, in gather.<locals>.gather_map(outputs)
79 if _is_namedtuple(out):
80 return type(out)._make(map(gather_map, zip(*outputs)))
---> 81 return type(out)(map(gather_map, zip(*outputs)))
File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:71, in gather.<locals>.gather_map(outputs)
69 out = outputs[0]
70 if isinstance(out, torch.Tensor):
---> 71 return Gather.apply(target_device, dim, *outputs)
72 if out is None:
73 return None
File /opt/conda/lib/python3.10/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs)
503 if not torch._C._are_functorch_transforms_active():
504 # See NOTE: [functorch vjp and autograd interaction]
505 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506 return super().apply(*args, **kwargs) # type: ignore[misc]
508 if cls.setup_context == _SingleLevelFunction.setup_context:
509 raise RuntimeError(
510 'In order to use an autograd.Function with functorch transforms '
511 '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
512 'staticmethod. For more details, please see '
513 'https://pytorch.org/docs/master/notes/extending.func.html')
File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:56, in Gather.forward(ctx, target_device, dim, *inputs)
54 @staticmethod
55 def forward(ctx, target_device, dim, *inputs):
---> 56 assert all(i.device.type != 'cpu' for i in inputs), (
57 'Gather function not implemented for CPU tensors'
58 )
59 if (target_device == 'cpu'):
60 ctx.target_device = 'cpu'
File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:56, in <genexpr>(.0)
54 @staticmethod
55 def forward(ctx, target_device, dim, *inputs):
---> 56 assert all(i.device.type != 'cpu' for i in inputs), (
57 'Gather function not implemented for CPU tensors'
58 )
59 if (target_device == 'cpu'):
60 ctx.target_device = 'cpu'
AttributeError: 'NoneType' object has no attribute 'device'
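From the traceback, my reading is that torch.nn.DataParallel replicates the model across the four T4s and then tries to gather the returned ModelOutput, and the gather trips over a field that is None on the replicas (hence 'NoneType' object has no attribute 'device'), which would explain why single-GPU and CPU runs are fine. As a sanity-check workaround (my own guess, not from the blog) I could pin the process to one GPU before anything initializes CUDA, e.g. at the very top of the notebook:

import os

# Expose only the first GPU so the Trainer never wraps the model in DataParallel.
# This has to run before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
print(torch.cuda.device_count())  # should now report 1

But that obviously gives up the other three GPUs, so I would still like to understand why the gather fails when multiple GPUs are used.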