Using gradient_checkpointing=True in Trainer causes an error with LLaMA

Hey,

I am trying to fine-tune LLaMA using the transformers library. I noticed that when I set gradient_checkpointing=True in the Trainer args, I get the following error:

Expects BACKWARD_PRE or BACKWARD_POST state but got HandleTrainingState.FORWARD
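
For reference, this is roughly how I'm setting things up (the checkpoint name and the tiny dummy dataset below are just placeholders to show the shape of the setup, not my actual training data):

```python
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder; my actual checkpoint differs
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# Tiny dummy dataset just to illustrate the setup
enc = tokenizer(["Hello world"] * 8, padding="max_length", max_length=32)
train_dataset = Dataset.from_dict({"input_ids": enc["input_ids"], "labels": enc["input_ids"]})

args = TrainingArguments(
    output_dir="llama-ft",
    per_device_train_batch_size=1,
    gradient_checkpointing=True,  # turning this on is what triggers the error
    bf16=True,
)

Trainer(model=model, args=args, train_dataset=train_dataset).train()
```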

Has anyone come across this before?

@sgugger I have seen you respond to gradient_checkpointing questions before, so I thought I would tag you.

This happens when I set use_reentrant=False in the call to torch.utils.checkpoint.checkpoint; when I set use_reentrant=True instead, I get an error when torch.compile() runs on the model.
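
To be concrete about which flag I'm toggling, here is a simplified standalone example of the checkpoint call (not the actual modeling_llama code, just an illustration of the two variants):

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# Variant 1: non-reentrant checkpointing -> in my run this is what hits the
# "Expects BACKWARD_PRE or BACKWARD_POST state" error during training
out = checkpoint(layer, x, use_reentrant=False)

# Variant 2: reentrant checkpointing -> trains, but fails once torch.compile()
# is applied to the model
out = checkpoint(layer, x, use_reentrant=True)

out.sum().backward()
```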