RuntimeError: Backward through graph with Whisper-medium and gradient_checkpointing=True

I am trying to fine-tune Whisper-medium and am getting this specific error during trainer.train():

/tmp/ipython-input-774985985.py:8: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Seq2SeqTrainer.__init__`. Use `processing_class` instead.
  trainer = Seq2SeqTrainer(
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipython-input-774985985.py in <cell line: 0>()
     16     tokenizer=processor,
     17 )
---> 18 trainer.train()
     19 #trainer.push_to_hub()

10 frames
/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py in _engine_run_backward(t_outputs, *args, **kwargs)
    827         unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    828     try:
--> 829         return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    830             t_outputs, *args, **kwargs
    831         )  # Calls into the C++ engine to run the backward pass

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

These are the steps I’ve tried:

  • Gradient checkpointing enabled (gradient_checkpointing=True).

  • FP16 disabled (fp16=False).

  • use_cache=False, set explicitly via model.config.use_cache = False.

  • predict_with_generate=True.

  • Running on a minimal dataset subset.

  • Using the original openai/whisper-medium model.

  • Restarting the runtime.

Env:

PyTorch version: 2.8.0+cu126
Transformers version: 4.56.2
Accelerate version: 1.10.1
Datasets version: 4.1.1
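
For anyone who wants to compare environments, here is a minimal sketch of how a version list like the one above could be collected. The helper name collect_versions is hypothetical (it is not part of any of these libraries); it only uses the standard-library importlib.metadata module.

```python
from importlib import metadata


def collect_versions(packages):
    """Return {package: version}, or "not installed" when a package is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


# Print one line per package, similar to the Env list in this post.
report = collect_versions(["torch", "transformers", "accelerate", "datasets"])
for name, version in report.items():
    print(f"{name}: {version}")
```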

Modified code (per Gemini):

from transformers import WhisperForConditionalGeneration
# Diag
from accelerate import Accelerator
accelerator = Accelerator()
device = accelerator.device

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")

#Diag
model.to(device)

from functools import partial

# disable cache during training since it's incompatible with gradient checkpointing
model.config.use_cache = False

# set language and task for generation and re-enable cache
model.generate = partial(
    model.generate, language="en", use_cache=True
)

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
#training_args = TrainingArguments(
    #Diag
    output_dir="./whisper-medium-tp-test",  # name on the HF Hub
    per_device_train_batch_size=16,
    gradient_accumulation_steps=8,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=50,
    #Diag
    max_steps=50,  # increase to 4000 if you have your own GPU or a Colab paid plan
    gradient_checkpointing=True,
    fp16=False,
    fp16_full_eval=False,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    #Diag
    save_steps=50,
    eval_steps=10,
    logging_steps=10,
    report_to=["tensorboard"],
    save_strategy="steps",
    #Diag
    load_best_model_at_end=False,
    metric_for_best_model="wer",
    greater_is_better=False,
    #Diag
    push_to_hub=False,
)

from transformers import Seq2SeqTrainer

#Diag
small_train_dataset = dataset["train"].select(range(10)) # Select first 10 samples
small_eval_dataset = dataset["test"].select(range(10)) # Select first 10 samples


trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    #Diag
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)
trainer.train()
#trainer.push_to_hub()

It seems the KV cache conflicts with the gradient checkpointing graph…

Wow, I appreciate you putting it all together in one place. I see several things I need to modify, and I will report back with success or failure (hopefully the former).


Success!

The significant changes I made based on your example were:

gradient_checkpointing_kwargs={"use_reentrant": False},   
fp16=False,   
fp16_full_eval=False,

and I removed the model.generate = partial(…) call. That resolved the issue. Thank you!
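
In case it helps someone else, here is a rough sketch of the changes listed above, not the exact code I ran. The helper names working_training_kwargs and configure_generation are made up for illustration. Setting the language on model.generation_config instead of wrapping model.generate is my assumption about a reasonable replacement for the partial(…) call; the thread itself only confirms that removing the wrapper helped. A SimpleNamespace stands in for the model so the sketch runs without downloading a checkpoint.

```python
from types import SimpleNamespace


def working_training_kwargs():
    """Keyword arguments reflecting the settings reported to work in this thread."""
    return {
        "output_dir": "./whisper-medium-tp-test",
        "gradient_checkpointing": True,
        # The key change: use non-reentrant gradient checkpointing.
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "fp16": False,
        "fp16_full_eval": False,
        "predict_with_generate": True,
        "generation_max_length": 225,
    }


def configure_generation(model, language="en"):
    """Set generation defaults as attributes instead of patching model.generate.

    Assumption: the language is read from model.generation_config at
    generation time, so no functools.partial wrapper is needed.
    """
    model.config.use_cache = False  # keep the cache off for the training forward pass
    model.generation_config.language = language
    return model


# Stand-in object with the two attributes the helper touches.
dummy_model = SimpleNamespace(
    config=SimpleNamespace(), generation_config=SimpleNamespace()
)
configure_generation(dummy_model)
print(working_training_kwargs()["gradient_checkpointing_kwargs"])
```

With transformers installed, the dictionary would be passed as Seq2SeqTrainingArguments(**working_training_kwargs()), and configure_generation would be called on the real WhisperForConditionalGeneration model rather than the stand-in.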

Should I go ahead and try your other suggestions as well? I’m so pumped that it’s running that I don’t want to break it again…


I think it’s best to copy stable code somewhere first before making changes. That’s what I always do. It gets messy though…
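
One way to keep it less messy is to write the known-good settings to a timestamped file before experimenting. This is only a sketch: snapshot_config is a hypothetical helper, and the directory and label names are placeholders. It uses only the standard library.

```python
import json
import tempfile
import time
from pathlib import Path


def snapshot_config(config, directory="stable_configs", label="working"):
    """Write `config` to a timestamped JSON file and return the file path."""
    target_dir = Path(directory)
    target_dir.mkdir(parents=True, exist_ok=True)
    stamp = time.strftime("%Y%m%d-%H%M%S")
    path = target_dir / f"{label}-{stamp}.json"
    path.write_text(json.dumps(config, indent=2, sort_keys=True))
    return path


# Demo in a temporary directory so nothing is left behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    saved = snapshot_config(
        {"gradient_checkpointing_kwargs": {"use_reentrant": False}, "fp16": False},
        directory=tmp_dir,
    )
    print(saved.name)
```

The dictionary passed in should be whatever settings are currently working; anything that is not JSON-serializable would need converting first.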