Audio Course Unit 6: Unable to train SpeechT5

Hello,

Thank you for the audio course.

@MariaK / @sanchit-gandhi, I'm assuming you are among the course instructors.

I am trying to follow the fine-tuning example (Fine-tuning SpeechT5 - Hugging Face Audio Course) but am unable to train the model.
I haven’t changed any code.

Here is the full traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[30], line 1
----> 1 trainer.train()

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1528, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1525 try:
   1526     # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
   1527     hf_hub_utils.disable_progress_bars()
-> 1528     return inner_training_loop(
   1529         args=args,
   1530         resume_from_checkpoint=resume_from_checkpoint,
   1531         trial=trial,
   1532         ignore_keys_for_eval=ignore_keys_for_eval,
   1533     )
   1534 finally:
   1535     hf_hub_utils.enable_progress_bars()

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1854, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1851     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   1853 with self.accelerator.accumulate(model):
-> 1854     tr_loss_step = self.training_step(model, inputs)
   1856 if (
   1857     args.logging_nan_inf_filter
   1858     and not is_torch_tpu_available()
   1859     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   1860 ):
   1861     # if loss is nan or inf simply add the average of previous logged losses
   1862     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2735, in Trainer.training_step(self, model, inputs)
   2732     return loss_mb.reduce_mean().detach().to(self.args.device)
   2734 with self.compute_loss_context_manager():
-> 2735     loss = self.compute_loss(model, inputs)
   2737 if self.args.n_gpu > 1:
   2738     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2758, in Trainer.compute_loss(self, model, inputs, return_outputs)
   2756 else:
   2757     labels = None
-> 2758 outputs = model(**inputs)
   2759 # Save past state if it exists
   2760 # TODO: this needs to be fixed and made cleaner later.
   2761 if self.args.past_index >= 0:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py:687, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    686 def forward(*args, **kwargs):
--> 687     return model_forward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py:675, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    674 def __call__(self, *args, **kwargs):
--> 675     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /opt/conda/lib/python3.10/site-packages/torch/amp/autocast_mode.py:14, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     11 @functools.wraps(func)
     12 def decorate_autocast(*args, **kwargs):
     13     with autocast_instance:
---> 14         return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/transformers/models/speecht5/modeling_speecht5.py:2719, in SpeechT5ForTextToSpeech.forward(self, input_ids, attention_mask, decoder_input_values, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, speaker_embeddings, labels, stop_labels)
   2716     if self.config.use_guided_attention_loss:
   2717         output_attentions = True
-> 2719 outputs = self.speecht5(
   2720     input_values=input_ids,
   2721     attention_mask=attention_mask,
   2722     decoder_input_values=decoder_input_values,
   2723     decoder_attention_mask=decoder_attention_mask,
   2724     head_mask=head_mask,
   2725     decoder_head_mask=decoder_head_mask,
   2726     cross_attn_head_mask=cross_attn_head_mask,
   2727     encoder_outputs=encoder_outputs,
   2728     past_key_values=past_key_values,
   2729     use_cache=use_cache,
   2730     speaker_embeddings=speaker_embeddings,
   2731     output_attentions=output_attentions,
   2732     output_hidden_states=output_hidden_states,
   2733     return_dict=True,
   2734 )
   2736 outputs_before_postnet, outputs_after_postnet, logits = self.speech_decoder_postnet(outputs[0])
   2738 loss = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/speecht5/modeling_speecht5.py:2213, in SpeechT5Model.forward(self, input_values, attention_mask, decoder_input_values, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, use_cache, speaker_embeddings, output_attentions, output_hidden_states, return_dict)
   2210 else:
   2211     decoder_args = {}
-> 2213 decoder_outputs = self.decoder(
   2214     input_values=decoder_input_values,
   2215     attention_mask=decoder_attention_mask,
   2216     encoder_hidden_states=encoder_outputs[0],
   2217     encoder_attention_mask=encoder_attention_mask,
   2218     head_mask=decoder_head_mask,
   2219     cross_attn_head_mask=cross_attn_head_mask,
   2220     past_key_values=past_key_values,
   2221     use_cache=use_cache,
   2222     output_attentions=output_attentions,
   2223     output_hidden_states=output_hidden_states,
   2224     return_dict=return_dict,
   2225     **decoder_args,
   2226 )
   2228 if not return_dict:
   2229     return decoder_outputs + encoder_outputs

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/speecht5/modeling_speecht5.py:1734, in SpeechT5DecoderWithSpeechPrenet.forward(self, input_values, attention_mask, encoder_hidden_states, encoder_attention_mask, speaker_embeddings, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   1719 def forward(
   1720     self,
   1721     input_values: Optional[torch.FloatTensor] = None,
   (...)
   1732     return_dict: Optional[bool] = None,
   1733 ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
-> 1734     decoder_hidden_states = self.prenet(input_values, speaker_embeddings)
   1736     outputs = self.wrapped_decoder(
   1737         hidden_states=decoder_hidden_states,
   1738         attention_mask=attention_mask,
   (...)
   1747         return_dict=return_dict,
   1748     )
   1750     return outputs

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/speecht5/modeling_speecht5.py:702, in SpeechT5SpeechDecoderPrenet.forward(self, input_values, speaker_embeddings)
    700     speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1)
    701     speaker_embeddings = speaker_embeddings.repeat(inputs_embeds.size(0), 1, 1)
--> 702     inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
    703     inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))
    705 return inputs_embeds

RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 4 but got size 16 for tensor number 1 in the list.
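
From the last frame, the concatenation fails because the decoder inputs have a batch size of 4 while the repeated speaker embeddings end up with a batch size of 16. In case it helps, the batch shapes coming out of the collator can be inspected like this (a sketch, assuming the data_collator and processed dataset names from the course notebook; adjust if yours differ):

# Sketch of a shape check on one collated batch, assuming the
# `data_collator` and processed `dataset` from the course notebook.
features = [dataset["train"][i] for i in range(4)]
batch = data_collator(features)

# SpeechT5 expects one 512-dim x-vector per example.
print(batch["input_ids"].shape)           # text tokens: (4, max_text_len)
print(batch["labels"].shape)              # log-mel targets: (4, max_frames, 80)
print(batch["speaker_embeddings"].shape)  # should be (4, 512)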

Any suggestions would greatly help.

Thank you
Shamik

Hello,

Can someone respond to this? Whom should I tag?

I couldn't find any problem with the course code itself, so I suspect a problem with the transformers version or one of the other libraries. Looking at the last traceback frame, speaker_embeddings.repeat(inputs_embeds.size(0), 1, 1) multiplies the speaker embeddings' batch dimension by the batch size; if the embeddings already carry one x-vector per example, a batch of 4 becomes 16, which is exactly the mismatch in the error.
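
A minimal standalone reproduction of the mismatch, using only the three tensor ops from the last traceback frame (the hidden and sequence dimensions below are my own assumptions; the batch size of 4 matches the error):

import torch

batch_size, seq_len, hidden_dim, spk_dim = 4, 10, 256, 512
inputs_embeds = torch.randn(batch_size, seq_len, hidden_dim)
# One x-vector per example, already batched, with a length-1 time axis:
speaker_embeddings = torch.randn(batch_size, 1, spk_dim)

# The ops from modeling_speecht5.py lines 700-702 of the traceback:
speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1)  # (4, 10, 512)
speaker_embeddings = speaker_embeddings.repeat(inputs_embeds.size(0), 1, 1)    # (16, 10, 512)
torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
# RuntimeError: Sizes of tensors must match except in dimension 2.
# Expected size 4 but got size 16 for tensor number 1 in the list.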

I executed the notebook code on Google Colab and it worked perfectly.

The versions of the libraries that I tested with are:

soundfile: 0.12.1
speechbrain: 0.5.16
accelerate: 0.26.1
datasets: 2.16.1
transformers: 4.35.2
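
If anyone else runs into this, pinning those versions before running the notebook should reproduce the working environment, e.g. in a Colab/Jupyter cell:

!pip install soundfile==0.12.1 speechbrain==0.5.16 accelerate==0.26.1 datasets==2.16.1 transformers==4.35.2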
