I am training a custom speech encoder-decoder model with a wav2vec2 encoder and a decoder taken from a language model (e.g., BART). The architecture can be trained as an `AutoModelForSpeechSeq2Seq` model and works well so far. I noticed that my setup resembles Whisper in some ways. Whisper describes itself as a multilingual and multitask model, which got me curious: how can I implement multitasking in my model?
For multilinguality, I prepend a language identifier token to the target text during training (e.g., `[en] it's a really nice day outside`). During generation, I can either force the decoder input IDs to include a specific language token (e.g., `[en]`) or let the model infer the language from the input audio. This approach works well for language-specific tasks.
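For reference, this is roughly what the multilingual part looks like in my code. The checkpoint path and the `[en]`/`[de]` token names are placeholders from my own setup, and `forced_decoder_ids` is what I use on my transformers version:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelForSpeechSeq2Seq.from_pretrained("./wav2vec2-bart-checkpoint")  # placeholder path

# Register the language identifiers as special tokens and grow the decoder's embedding matrix.
tokenizer.add_special_tokens({"additional_special_tokens": ["[en]", "[de]"]})
model.decoder.resize_token_embeddings(len(tokenizer))

# Training targets simply start with the language token:
labels = tokenizer("[en] it's a really nice day outside", return_tensors="pt").input_ids

# At generation time I can force the first real decoder token to be a given language ID ...
input_values = torch.randn(1, 16000)  # stand-in for one second of 16 kHz audio
lang_id = tokenizer.convert_tokens_to_ids("[en]")
generated = model.generate(input_values, forced_decoder_ids=[[1, lang_id]])

# ... or drop forced_decoder_ids and let the model predict the language from the audio.
```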
For multitasking, however, the challenge is that the model cannot infer the task type (e.g., translation vs. summarization) directly from the audio. I haven't fully explored this yet, but I suspect task-specific identifiers (like `[summarization]`) could be handled with a custom PyTorch training script, roughly as sketched below. Unfortunately, writing such a script would take significantly more time than I currently have.
My questions are:
- Is it possible to incorporate task-specific identifiers (similar to the language identifiers) into the training process using `transformers.Seq2SeqTrainer`?
- If yes, how can I modify or extend `Seq2SeqTrainer` to support multitasking?
Any guidance or pointers would be greatly appreciated!