Trainer RuntimeError: The size of tensor a (462) must match the size of tensor b (448) at non-singleton dimension 1

Hi, I am fine-tuning Whisper and ran into a Trainer issue that I don't know how to solve:

RuntimeError: The size of tensor a (462) must match the size of tensor b (448) at non-singleton dimension 1

The trace goes:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File <timed exec>:28

File ~/.local/lib/python3.9/site-packages/transformers/trainer.py:1515, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1510     self.model_wrapped = self.model
   1512 inner_training_loop = find_executable_batch_size(
   1513     self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1514 )
-> 1515 return inner_training_loop(
   1516     args=args,
   1517     resume_from_checkpoint=resume_from_checkpoint,
   1518     trial=trial,
   1519     ignore_keys_for_eval=ignore_keys_for_eval,
   1520 )

File ~/.local/lib/python3.9/site-packages/transformers/trainer.py:1763, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1761         tr_loss_step = self.training_step(model, inputs)
   1762 else:
-> 1763     tr_loss_step = self.training_step(model, inputs)
   1765 if (
   1766     args.logging_nan_inf_filter
   1767     and not is_torch_tpu_available()
   1768     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   1769 ):
   1770     # if loss is nan or inf simply add the average of previous logged losses
   1771     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File ~/.local/lib/python3.9/site-packages/transformers/trainer.py:2522, in Trainer.training_step(self, model, inputs)
   2519     return loss_mb.reduce_mean().detach().to(self.args.device)
   2521 with self.compute_loss_context_manager():
-> 2522     loss = self.compute_loss(model, inputs)
   2524 if self.args.n_gpu > 1:
   2525     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File ~/.local/lib/python3.9/site-packages/transformers/trainer.py:2554, in Trainer.compute_loss(self, model, inputs, return_outputs)
   2552 else:
   2553     labels = None
-> 2554 outputs = model(**inputs)
   2555 # Save past state if it exists
   2556 # TODO: this needs to be fixed and made cleaner later.
   2557 if self.args.past_index >= 0:

File ~/.local/lib/python3.9/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:1192, in WhisperForConditionalGeneration.forward(self, input_features, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1187     if decoder_input_ids is None and decoder_inputs_embeds is None:
   1188         decoder_input_ids = shift_tokens_right(
   1189             labels, self.config.pad_token_id, self.config.decoder_start_token_id
   1190         )
-> 1192 outputs = self.model(
   1193     input_features,
   1194     decoder_input_ids=decoder_input_ids,
   1195     encoder_outputs=encoder_outputs,
   1196     decoder_attention_mask=decoder_attention_mask,
   1197     head_mask=head_mask,
   1198     decoder_head_mask=decoder_head_mask,
   1199     cross_attn_head_mask=cross_attn_head_mask,
   1200     past_key_values=past_key_values,
   1201     decoder_inputs_embeds=decoder_inputs_embeds,
   1202     use_cache=use_cache,
   1203     output_attentions=output_attentions,
   1204     output_hidden_states=output_hidden_states,
   1205     return_dict=return_dict,
   1206 )
   1207 lm_logits = self.proj_out(outputs[0])
   1209 loss = None

File ~/.local/lib/python3.9/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:1061, in WhisperModel.forward(self, input_features, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
   1054     encoder_outputs = BaseModelOutput(
   1055         last_hidden_state=encoder_outputs[0],
   1056         hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
   1057         attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
   1058     )
   1060 # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
-> 1061 decoder_outputs = self.decoder(
   1062     input_ids=decoder_input_ids,
   1063     attention_mask=decoder_attention_mask,
   1064     encoder_hidden_states=encoder_outputs[0],
   1065     head_mask=decoder_head_mask,
   1066     cross_attn_head_mask=cross_attn_head_mask,
   1067     past_key_values=past_key_values,
   1068     inputs_embeds=decoder_inputs_embeds,
   1069     use_cache=use_cache,
   1070     output_attentions=output_attentions,
   1071     output_hidden_states=output_hidden_states,
   1072     return_dict=return_dict,
   1073 )
   1075 if not return_dict:
   1076     return decoder_outputs + encoder_outputs

File ~/.local/lib/python3.9/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:868, in WhisperDecoder.forward(self, input_ids, attention_mask, encoder_hidden_states, head_mask, cross_attn_head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    865 # embed positions
    866 positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
--> 868 hidden_states = inputs_embeds + positions
    869 hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
    871 # decoder layers

RuntimeError: The size of tensor a (462) must match the size of tensor b (448) at non-singleton dimension 1

I find this strange because I am training on the Fleurs training set in several languages, yet it only happens with the Telugu dataset, and always right at the 24th step.

Hey @navissivan!

Hmm, I’m wondering if there are some empty input/target sequences in the dataset. Could you try updating your prepare_dataset function and introducing one function to filter the inputs and one function to filter the labels:

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute input length
    batch["input_length"] = len(batch["audio"])

    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["sentence"]).input_ids

    # compute labels length
    batch["labels_length"] = len(tokenizer(batch["sentence"], add_special_tokens=False).input_ids)
    return batch

MAX_DURATION_IN_SECONDS = 30.0
max_input_length = MAX_DURATION_IN_SECONDS * 16000

def filter_inputs(input_length):
    """Filter inputs with zero input length or longer than 30s"""
    return 0 < input_length < max_input_length

def filter_labels(labels_length):
    """Filter empty label sequences"""
    return 0 < len(labels_length)

You can then apply the prepare_dataset function and the two filter functions to your dataset my_dataset as follows:

my_dataset = my_dataset.map(prepare_dataset, remove_columns=my_dataset.column_names["train"])

my_dataset = my_dataset.filter(filter_inputs, input_columns=["input_length"], remove_columns=["input_length"])

my_dataset = my_dataset.filter(filter_labels, input_columns=["labels_length"], remove_columns=["labels_length"])

(of course you will need to change the name of my_dataset to match your dataset name accordingly)

You can then run training as before with this filtered my_dataset.

If that doesn’t work, could you send a link to your colab so I can try and reproduce the error?


Hi @sanchit-gandhi , thanks so much for the reply!!

First, I think two modifications are needed in your functions:

# compute input length
batch["input_length"] = len(batch["audio"]['array'])

and

return 0 < labels_length
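To see why both changes are needed, here is a small sketch using a toy stand-in for a decoded audio entry (a dict with `path`, `array` and `sampling_rate` keys; the values are made up). Calling `len()` on the dict counts its keys rather than its samples, and calling `len()` on an integer length raises a `TypeError`. The function names `filter_labels_original` and `filter_labels_fixed` are hypothetical, chosen only for the comparison.

```python
# Toy stand-in for a decoded audio entry: 2 seconds of silence at 16 kHz.
# (Made-up values; a real entry comes from the datasets Audio feature.)
audio = {
    "path": "example.wav",
    "array": [0.0] * 32000,
    "sampling_rate": 16000,
}

# Original version: counts the dict's keys, not the audio samples.
wrong_input_length = len(audio)       # 3
# Corrected version: counts the samples in the waveform.
input_length = len(audio["array"])    # 32000

labels_length = 12  # stored as a plain int by prepare_dataset


def filter_labels_original(labels_length):
    return 0 < len(labels_length)  # len() of an int raises TypeError


def filter_labels_fixed(labels_length):
    return 0 < labels_length


try:
    filter_labels_original(labels_length)
    raised = False
except TypeError:
    raised = True

print(wrong_input_length, input_length, raised, filter_labels_fixed(labels_length))
```

With `input_columns=["labels_length"]`, the predicate receives the integer itself, so comparing it directly is enough.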

May I ask why labels_length is not simply len(batch["labels"]), but is computed with add_special_tokens=False? And why is input_length not len(batch["input_features"])?

So I looked into the preprocessing steps, and it looks like this is indeed caused by a length mismatch. The preprocessing logs show:
Token indices sequence length is longer than the specified maximum sequence length for this model (459 > 448). Running this sequence through the model will result in indexing errors

But I thought the feature extractor would take care of this by truncating. Doesn't it?

I applied the filters as you suggested, but I don't think the dataset has any empty inputs; instead, the filters removed 7 overly long entries. However, according to the preprocessing warnings, there should be 10 problematic sequences.

And the same error still occurs even after filtering.

So I checked the Whisper feature extractor and the Whisper tokenizer.

I assume the problem is that 7 samples have input lengths exceeding 30 s, and 10 samples have label lengths exceeding the model's maximum length of 448.
So I tried:

MAX_DURATION_IN_SECONDS = 30.0
# max_length must be an integer number of samples (16 kHz audio)
max_input_length = int(MAX_DURATION_IN_SECONDS * 16000)

def prepare_dataset(batch):
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"],
                                                sampling_rate=audio["sampling_rate"],
                                                max_length=max_input_length).input_features[0]

    # encode target text to label ids, truncated to the model's max length (448)
    batch["labels"] = tokenizer(batch["raw_transcription"],
                                truncation=True,
                                max_length=448).input_ids
    return batch

I guess the truncation in the feature extractor doesn't matter (?), since the output feature size is fixed anyway (80 log-Mel bins), but this change lets me proceed with training. Please correct me if my understanding is wrong!
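A quick way to confirm a diagnosis like the one above (some inputs longer than 30 s, some labels longer than 448 tokens) is to list the offending rows before training. The sketch below assumes you have already computed per-example `input_length` (in samples) and `labels_length` (token count including special tokens) as plain Python lists; the helper `find_offending` and the toy numbers are hypothetical.

```python
MAX_INPUT_LENGTH = 30 * 16000   # 30 s at 16 kHz, in samples
MAX_LABEL_LENGTH = 448          # Whisper's maximum target length


def find_offending(input_lengths, labels_lengths,
                   max_input=MAX_INPUT_LENGTH, max_label=MAX_LABEL_LENGTH):
    """Return (too_long_audio, too_long_labels) as lists of row indices."""
    too_long_audio = [i for i, n in enumerate(input_lengths) if n > max_input]
    too_long_labels = [i for i, n in enumerate(labels_lengths) if n > max_label]
    return too_long_audio, too_long_labels


# Toy lengths for five rows (made-up numbers for illustration only).
input_lengths = [160000, 500000, 320000, 470000, 90000]
labels_lengths = [120, 300, 462, 400, 80]

audio_idx, label_idx = find_offending(input_lengths, labels_lengths)
print(audio_idx, label_idx)   # [1] [2]
```

With a `datasets.Dataset`, indexing by column name (for example `ds["labels_length"]`) returns such a list, so the same helper can be applied to the real split.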

Hey @sanchit-gandhi, may I also ask which part of Whisper is fine-tuned and which part is kept frozen in the process described in your post?

This process trains both the encoder and the decoder. If you check the number of trainable parameters, it will be about 241M.

There is also an option to fine-tune only the decoder, by adding this code:

model.freeze_encoder()

Hey @navissivan!

The Whisper feature extractor returns log-Mel features of fixed dimension: it first pads/truncates the audio samples to 30s, and then computes the log-Mel filter bank coefficients. Therefore, they are always of fixed size (no matter what the input audio length), and so we have to compute the input audio length directly from the audio sample.
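As a rough illustration of why the features are always the same size, here is a pure-Python sketch of the pad/truncate step. It assumes Whisper's standard settings (16 kHz audio, 30 s chunks, a hop length of 160 samples, 80 Mel bins); `pad_or_truncate` is a hypothetical helper, not the feature extractor's actual code, and the frame count is computed arithmetically rather than by running an STFT.

```python
SAMPLING_RATE = 16000
CHUNK_SECONDS = 30
N_SAMPLES = SAMPLING_RATE * CHUNK_SECONDS   # 480000 samples
HOP_LENGTH = 160                            # samples between STFT frames
N_MELS = 80                                 # log-Mel bins


def pad_or_truncate(samples, length=N_SAMPLES):
    """Zero-pad short audio and cut long audio to exactly `length` samples."""
    if len(samples) >= length:
        return samples[:length]
    return samples + [0.0] * (length - len(samples))


for seconds in (5, 30, 42):
    clip = [0.1] * (SAMPLING_RATE * seconds)
    fixed = pad_or_truncate(clip)
    n_frames = len(fixed) // HOP_LENGTH
    # every clip ends up as 480000 samples, i.e. an (80, 3000) feature matrix
    print(seconds, len(fixed), (N_MELS, n_frames))
```

Note that the 42 s clip silently loses its last 12 s, which is why the length filter has to look at the raw audio array rather than at the features.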

Regarding the labels, the Whisper tokenizer adds ‘special’ tokens to the beginning and end of the sequence (the start-of-transcript prefix tokens and an EOS token). So even an empty string will have a non-zero labels length:

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="English")

input_str = ""
encoded_ids = tokenizer(input_str).input_ids
labels_length = len(encoded_ids)
decoded_str = tokenizer.decode(encoded_ids)

print(labels_length)
print(decoded_str)

Print Output:

4
<|startoftranscript|><|en|><|notimestamps|><|endoftext|>

If we set add_special_tokens=False, these extra labels aren’t added.

Your last message has revealed exactly what the problem is! The input features are entirely fine - as you say they are being padded/truncated to 30s, so they are always fixed dimension.

The issue is with your target label sequences. Some of the label sequences have a length that exceeds the model’s maximum generation length. These must be very long sequences, as the maximum generation length is 448. This is the longest sequence the model is configured to handle (model.config.max_length).
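This limit is also what produces the exact numbers in the error message. Assuming the decoder's positional embedding is a 448-row table looked up by slicing (as it appears to be in the transformers version from the traceback), a slice that runs past the end of the table silently stops at the last row instead of failing. The sketch below mimics that with plain Python lists; the names and the tiny hidden size are hypothetical. A 462-token label sequence gets only 448 position vectors, matching the 462 vs 448 mismatch at `hidden_states = inputs_embeds + positions`.

```python
MAX_TARGET_POSITIONS = 448   # rows in the decoder's positional-embedding table
D_MODEL = 4                  # tiny hidden size, just for illustration

# Stand-in for the learned positional-embedding weight: one row per position.
position_table = [[0.0] * D_MODEL for _ in range(MAX_TARGET_POSITIONS)]


def embed_positions(seq_len, past_key_values_length=0):
    """Mimic a slice-based lookup: rows [past, past + seq_len)."""
    return position_table[past_key_values_length:past_key_values_length + seq_len]


for seq_len in (120, 448, 462):
    positions = embed_positions(seq_len)
    ok = len(positions) == seq_len
    print(seq_len, len(positions), "ok" if ok else "size mismatch")
# 120 120 ok
# 448 448 ok
# 462 448 size mismatch
```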

What we can do is compute the labels length with the special tokens, in order to compute the total length of the label sequence (BOS and EOS tokens included):

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute input length
    batch["input_length"] = len(audio["array"])

    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["sentence"]).input_ids

    # compute labels length **with** special tokens! -> total label length
    batch["labels_length"] = len(batch["labels"])
    return batch

And then filter out those that exceed the model's maximum length:

MAX_DURATION_IN_SECONDS = 30.0
max_input_length = MAX_DURATION_IN_SECONDS * 16000

def filter_inputs(input_length):
    """Filter inputs with zero input length or longer than 30s"""
    return 0 < input_length < max_input_length

max_label_length = model.config.max_length

def filter_labels(labels_length):
    """Filter label sequences longer than max length (448)"""
    return labels_length < max_label_length

That should remove any label sequences that are too long for the model.
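To make the behaviour of the two predicates concrete, here is a plain-Python sketch that applies them to a few toy rows, mimicking what `Dataset.filter(..., input_columns=[...])` does: the predicate receives the value of the named column, and the row is kept when it returns `True`. The rows are made up, and `max_label_length` is hard-coded to 448 instead of being read from `model.config.max_length`.

```python
MAX_DURATION_IN_SECONDS = 30.0
max_input_length = MAX_DURATION_IN_SECONDS * 16000
max_label_length = 448  # assumed value of model.config.max_length


def filter_inputs(input_length):
    """Filter inputs with zero input length or longer than 30s"""
    return 0 < input_length < max_input_length


def filter_labels(labels_length):
    """Filter label sequences longer than max length (448)"""
    return labels_length < max_label_length


# Made-up rows: (id, input_length in samples, labels_length in tokens)
rows = [
    ("a", 160000, 60),    # kept
    ("b", 0, 4),          # empty audio -> dropped by filter_inputs
    ("c", 520000, 200),   # longer than 30 s -> dropped by filter_inputs
    ("d", 400000, 462),   # labels too long -> dropped by filter_labels
    ("e", 300000, 447),   # just under the label limit -> kept
]

kept = [r for r in rows if filter_inputs(r[1])]
kept = [r for r in kept if filter_labels(r[2])]
print([r[0] for r in kept])   # ['a', 'e']
```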

Note: we can also change the model’s max length to any value we want:

model.config.max_length = 500

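This will update the max length to 500 tokens. Make sure to do this before you filter for it to take effect. Keep in mind, though, that this only moves the filtering threshold: the decoder's positional embeddings still have model.config.max_target_positions (448) entries, so label sequences longer than 448 tokens will still raise this same error during training.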

Thanks for the answers and clarification!!
So according to your description:

  1. My implementation here truncates the labels at the max length instead of filtering the samples out. Does that still make sense?
  2. Regarding model.config.max_length = 500: does this mean I can set this hyperparameter to any value, rather than sticking with the original 448?

Hey @navissivan!

  1. Your implementation truncates the label sequences to the max length, whereas mine filters out those above the max length. Both work in terms of getting the model to train! However, in yours there might be instances where the audio doesn't match the labels, since you've truncated the label sequences to the max length. For example, with a very small max length of 5, the audio might contain the utterance "the cat sat on the mat", but the labels get truncated to "the cat sat on the". This could derail predictions for these utterances, as the audio doesn't match the target text. My implementation removes the instances where we would otherwise have to truncate the labels (so the audio always matches the text).
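  2. Yep, as long as you stay within the decoder's positional-embedding limit (model.config.max_target_positions, 448 for Whisper); going beyond it triggers the same size-mismatch error. You can check out this excellent blog post to find out more about generation: How to generate text: using different decoding methods for language generation with Transformers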
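The difference between the two approaches in point 1 can be sketched with its toy example, treating each word as one "token" and using a max length of 5. Both are simplifications: real Whisper labels are subword tokens and the limit is 448. The helper names `truncate_labels` and `filter_labels` are hypothetical.

```python
MAX_LENGTH = 5  # toy limit; Whisper's real limit is 448 tokens

# Toy dataset: (clip id, transcript); words stand in for tokens.
samples = [
    ("clip_1", "the cat sat on the mat"),
    ("clip_2", "hello world"),
]


def truncate_labels(samples, max_length=MAX_LENGTH):
    """Keep every sample, but cut labels to max_length tokens."""
    return [(clip, " ".join(text.split()[:max_length])) for clip, text in samples]


def filter_labels(samples, max_length=MAX_LENGTH):
    """Drop samples whose labels exceed max_length tokens."""
    return [(clip, text) for clip, text in samples if len(text.split()) <= max_length]


print(truncate_labels(samples))
# [('clip_1', 'the cat sat on the'), ('clip_2', 'hello world')]
print(filter_labels(samples))
# [('clip_2', 'hello world')]
```

Truncation keeps clip_1 but pairs its audio with an incomplete transcript; filtering gives up that clip so every remaining pair is consistent.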