How to combine LoRA and gradient_checkpointing in Whisper?

I ran into this issue earlier.
It was caused by the loss tensor missing its grad_fn, so backward had nothing to propagate.
As stated in the PyTorch documentation for torch.utils.checkpoint:

If use_reentrant=True is specified, at least one of the inputs needs to have requires_grad=True if grads are needed for model inputs, otherwise the checkpointed part of the model won’t have gradients. At least one of the outputs needs to have requires_grad=True as well. Note that this does not apply if use_reentrant=False is specified.

Thus, I fixed it by passing use_reentrant=False to the torch.utils.checkpoint.checkpoint() call in transformers/src/transformers/models/whisper/modeling_whisper.py.
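
To illustrate why this happens, here is a minimal, self-contained sketch (not Whisper itself, just a toy Block class and sizes made up for the example): a gradient-checkpointed block whose base weights are frozen and whose only trainable parameters sit inside the block, which is the situation LoRA creates. With reentrant checkpointing the output loses its grad_fn because none of the input tensors require grad; with use_reentrant=False the graph survives.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for a transformer layer under gradient checkpointing:
# the base weights are frozen (as with LoRA) and a small trainable
# "adapter" lives inside the checkpointed block.
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(8, 8)
        self.adapter = nn.Linear(8, 8)  # stand-in for a LoRA adapter
        self.base.weight.requires_grad_(False)
        self.base.bias.requires_grad_(False)

    def forward(self, x):
        return self.base(x) + self.adapter(x)

block = Block()
x = torch.randn(2, 8)  # hidden states coming from frozen layers: no requires_grad

# Reentrant checkpointing: since no *input tensor* requires grad, the output
# has no grad_fn, even though the adapter parameters are trainable.
out = checkpoint(block, x, use_reentrant=True)
print(out.grad_fn)  # None -> any loss built on this has no grad_fn

# Non-reentrant checkpointing keeps the autograd graph through the adapter.
out = checkpoint(block, x, use_reentrant=False)
print(out.grad_fn is not None)            # True
out.sum().backward()
print(block.adapter.weight.grad is not None)  # True
```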
