Whisper fine-tuning without Seq2SeqTrainer

Hi,

I’m currently working with Whisper models and exploring modifications to the Whisper processing pipeline. As a first step, I’m decomposing the model into its encoder and decoder.

Decomposing the model into the encoder and the decoder is straightforward. However, I’m running into a problem when predicting the output sequence from the decoder. Specifically, I have a decoder-only Whisper, such as WhisperForCausalLM or the .decoder of a WhisperModel. My goal is to predict the whole output sequence in a single forward pass, so I can then compute the loss against the labels.
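To make this concrete, here is a minimal sketch of the decomposition I have in mind. The checkpoint (openai/whisper-tiny), the batch size, and the random log-mel features are just placeholders:

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

encoder = model.get_encoder()  # audio encoder
decoder = model.get_decoder()  # text decoder (cross-attends to the encoder output)

# Dummy log-mel features, shape (batch_size, num_mel_bins, 3000) for a 30 s window
input_features = torch.randn(2, model.config.num_mel_bins, 3000)
encoder_out = encoder(input_features).last_hidden_state  # (2, 1500, d_model)
```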

I’ve set up decoder_input_ids with the decoder_start_token_id, the language token, and the task token, so decoder_input_ids.shape = (batch_size, 3). When I call the forward function, the output has the same sequence length as the input. Should I pad the decoder_input_ids (e.g., up to the label length) in this case?
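Continuing the sketch above (English transcription is just an example), this is roughly how I build the prompt and call the decoder, and where I get stuck:

```python
# 3-token prompt: <|startoftranscript|> <|en|> <|transcribe|>
start_id = model.config.decoder_start_token_id
lang_id = processor.tokenizer.convert_tokens_to_ids("<|en|>")
task_id = processor.tokenizer.convert_tokens_to_ids("<|transcribe|>")
decoder_input_ids = torch.tensor([[start_id, lang_id, task_id]] * 2)  # (2, 3)

decoder_out = decoder(
    input_ids=decoder_input_ids,
    encoder_hidden_states=encoder_out,
).last_hidden_state                   # (2, 3, d_model)
logits = model.proj_out(decoder_out)  # (2, 3, vocab_size) -- same length as the input
```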
I’m sure I’m missing something trivial.

I believe this question may apply to other Seq2Seq models as well.

I appreciate any guidance you can provide.