How to combine LoRA and gradient_checkpointing in Whisper?

I’m trying to fine-tune Whisper with LoRA. When I enable gradient_checkpointing, I get the following error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Are LoRA and gradient_checkpointing inherently incompatible? Or is this a bug?

This is the code I have:

from transformers import WhisperForConditionalGeneration


model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
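# (Not in the original post) The LoRA wrapping itself is omitted here; a minimal
# sketch with peft might look like the following, with the rank, alpha and target
# modules chosen purely for illustration.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()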
from transformers import Seq2SeqTrainingArguments


training_args = Seq2SeqTrainingArguments(
    output_dir=f"./temp",  # change to a repo name of your choice
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    warmup_steps=500,
    max_steps=40000,
    evaluation_strategy="steps",
    gradient_checkpointing=True,
    fp16=True,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=10000,
    eval_steps=1000,
    logging_steps=25,
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    auto_find_batch_size=True,
)
import os

from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        # Save only the LoRA adapter weights for this checkpoint
        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        # Drop the full model weights the Trainer wrote, keeping only the adapter
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control


trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["validation"],
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[27], line 1
----> 1 trainer.train()

File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/transformers/trainer.py:1539, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1537         hf_hub_utils.enable_progress_bars()
   1538 else:
-> 1539     return inner_training_loop(
   1540         args=args,
   1541         resume_from_checkpoint=resume_from_checkpoint,
   1542         trial=trial,
   1543         ignore_keys_for_eval=ignore_keys_for_eval,
   1544     )

File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/accelerate/utils/memory.py:136, in find_executable_batch_size.<locals>.decorator(*args, **kwargs)
    134     raise RuntimeError("No executable batch size found, reached zero.")
    135 try:
--> 136     return function(batch_size, *args, **kwargs)
    137 except Exception as e:
    138     if should_reduce_batch_size(e):

File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/transformers/trainer.py:1821, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1818     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   1820 with self.accelerator.accumulate(model):
-> 1821     tr_loss_step = self.training_step(model, inputs)
   1823 if (
   1824     args.logging_nan_inf_filter
   1825     and not is_torch_tpu_available()
   1826     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   1827 ):
   1828     # if loss is nan or inf simply add the average of previous logged losses
   1829     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/transformers/trainer.py:2677, in Trainer.training_step(self, model, inputs)
   2675         scaled_loss.backward()
   2676 else:
-> 2677     self.accelerator.backward(loss)
   2679 return loss.detach() / self.args.gradient_accumulation_steps

File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/accelerate/accelerator.py:1851, in Accelerator.backward(self, loss, **kwargs)
   1849     return
   1850 elif self.scaler is not None:
-> 1851     self.scaler.scale(loss).backward(**kwargs)
   1852 else:
   1853     loss.backward(**kwargs)

File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    477 if has_torch_function_unary(self):
    478     return handle_torch_function(
    479         Tensor.backward,
    480         (self,),
   (...)
    485         inputs=inputs,
    486     )
--> 487 torch.autograd.backward(
    488     self, gradient, retain_graph, create_graph, inputs=inputs
    489 )

File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    195     retain_graph = create_graph
    197 # The reason we repeat same the comment below is that
    198 # some Python versions print out the first line of a multi-line function
    199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    201     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    202     allow_unreachable=True, accumulate_grad=True)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I ran into this issue earlier.
The cause was that the loss tensor had no grad_fn, so the backward pass had nothing to differentiate.
As stated in the documentation of gradient checkpointing:

If use_reentrant=True is specified, at least one of the inputs needs to have requires_grad=True if grads are needed for model inputs, otherwise the checkpointed part of the model won’t have gradients. At least one of the outputs needs to have requires_grad=True as well. Note that this does not apply if use_reentrant=False is specified.
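To see the behavior the quote describes in isolation, here is a minimal sketch, independent of Whisper and LoRA, that checkpoints a single trainable layer whose input does not require grad, which mirrors activations coming out of a frozen base model:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)   # trainable weights, standing in for a LoRA adapter
x = torch.randn(2, 4)           # output of frozen layers: requires_grad is False

out = checkpoint(layer, x, use_reentrant=True)
print(out.requires_grad)        # False -> a loss built from it has no grad_fn

out = checkpoint(layer, x, use_reentrant=False)
print(out.requires_grad)        # True -> gradients reach the layer's weights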

Thus, I fixed it by passing use_reentrant=False to the torch.utils.checkpoint.checkpoint() call in the transformers/src/transformers/models/whisper/modeling_whisper.py file.
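Depending on the transformers version installed, the same flag can probably be set without editing the library source: newer releases accept a gradient_checkpointing_kwargs argument that is forwarded to the checkpoint call. The snippet below is only a sketch under that assumption (an alternative, enable_input_require_grads(), keeps the reentrant path but makes the model inputs require grad):

# Sketch, assuming a transformers release that supports gradient_checkpointing_kwargs
training_args = Seq2SeqTrainingArguments(
    output_dir="./temp",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # ... remaining arguments as in the original script ...
)

# Alternative: keep the default reentrant checkpointing but force the model
# inputs to require grad so the checkpointed blocks produce a grad_fn.
# model.enable_input_require_grads()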