I tried adapting @sanchit-gandhi’s tutorial (Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers) on fine-tuning Whisper models to my own scenario (fine-tuning on LibriSpeech), but I am running into an issue I cannot make much sense of. Running the code below raises "forward() got an unexpected keyword argument 'attention_mask'", even though the Whisper model definition appears to declare a forward argument with exactly that name.
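A quick way to check which keyword arguments the loaded model's forward() actually accepts (just a sketch; it assumes model is the WhisperForConditionalGeneration instance loaded earlier in the script, and the signature may differ between transformers versions):

import inspect

# Print the parameters accepted by the loaded model's forward();
# whether 'attention_mask' shows up here depends on the installed transformers version.
print(inspect.signature(model.forward))

The training setup that triggers the error: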
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Training arguments (following the tutorial, adjusted for my setup)
training_args = Seq2SeqTrainingArguments(
    output_dir="./testing_training",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    weight_decay=0,
    warmup_steps=500,
    num_train_epochs=5,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_strategy="no",
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=False,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)

# Initialize the data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# Initialize the trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

trainer.train()
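For reference, the data collator is essentially the DataCollatorSpeechSeq2SeqWithPadding from the tutorial, roughly this (reproduced here as a sketch in case it is relevant):

import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # pad the log-mel input features into a batched torch tensor
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # pad the tokenized label sequences separately
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding token ids with -100 so they are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # drop the BOS token if it was already appended during tokenization
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

As far as I can tell, the batch it returns only has input_features and labels, so I would not expect an attention_mask key to reach the model unless the feature extractor adds one during padding, which makes the error all the more confusing to me.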