I ran into this issue earlier.
The issue was caused by the loss value missing its `grad_fn`.
As stated in the documentation of gradient checkpointing:
> If `use_reentrant=True` is specified, at least one of the inputs needs to have `requires_grad=True` if grads are needed for model inputs, otherwise the checkpointed part of the model won't have gradients. At least one of the outputs needs to have `requires_grad=True` as well. Note that this does not apply if `use_reentrant=False` is specified.
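
To make that concrete, here is a minimal, self-contained illustration of the documented behavior; the `Linear` layer and shapes are arbitrary placeholders, not Whisper code:

```python
import torch
import torch.utils.checkpoint as cp

layer = torch.nn.Linear(4, 4)   # stand-in for a checkpointed model block
x = torch.randn(2, 4)           # input with requires_grad=False, e.g.
                                # features coming from a frozen input stage

# Reentrant checkpointing: since no input requires grad, the output is
# detached from the autograd graph, so a loss built from it has no grad_fn.
out = cp.checkpoint(layer, x, use_reentrant=True)
print(out.requires_grad)        # False -> loss.backward() would fail

# Non-reentrant checkpointing has no such requirement: parameter gradients
# are tracked regardless of the inputs' requires_grad flags.
out = cp.checkpoint(layer, x, use_reentrant=False)
print(out.requires_grad)        # True -> the loss keeps its grad_fn
```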
Thus, I fixed it by passing `use_reentrant=False` to `torch.utils.checkpoint.checkpoint()` in the transformers/src/transformers/models/whisper/modeling_whisper.py file.
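
For reference, the change amounts to something like the sketch below. This is illustrative only: the actual call sits inside the encoder/decoder forward loop in modeling_whisper.py, and the exact arguments vary across transformers versions; `layer`, `hidden_states`, and `attention_mask` stand in for the local variables there.

```python
import torch
import torch.utils.checkpoint

# Illustrative sketch of the edited call, not the verbatim upstream code.
def checkpointed_block(layer, hidden_states, attention_mask):
    return torch.utils.checkpoint.checkpoint(
        layer,                # the Whisper encoder/decoder layer
        hidden_states,
        attention_mask,
        use_reentrant=False,  # the added flag; see the docs quoted above
    )
```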