Using gradient checkpointing and KV caching when generation happens in a no-grad context

Hello,

During my training loop I generate from my LLM, but I do not need gradients from this generation, so it happens within a no-grad context. When I try to use gradient checkpointing together with use_cache=True I get an error, and use_cache is automatically set to False. But since my generation happens in a no-grad context, I don’t see why I shouldn’t be able to use KV caching here?
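For reference, a minimal sketch of the setup (the model name and generation arguments here are placeholders, not my actual code):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model; the real setup uses my own LLM and training loop.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

model.gradient_checkpointing_enable()  # needed for the training steps

# ... training steps with gradients happen elsewhere ...

# Generation needs no gradients, so it runs under no_grad, but
# use_cache=True is still rejected while checkpointing is enabled
# and gets forced back to False.
inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=20, use_cache=True)
```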

Thanks,

It is because with gradient checkpointing the model has to re-run forward passes from the given checkpoint during backward to recompute activations, and those extra forward passes would interact badly with the key/value cache. So we disable caching in that case.
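To illustrate, here is a toy sketch using torch.utils.checkpoint directly (which is what gradient checkpointing relies on under the hood); the counter shows the forward function running again during backward:

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = 0

def block(x):
    global calls
    calls += 1  # count how often the forward function runs
    return torch.nn.functional.relu(x) * 2

x = torch.randn(4, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
print(calls)  # 1 -> the normal forward pass
y.sum().backward()
print(calls)  # 2 -> forward is re-run during backward to recompute activations
```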

I am currently trying to enable generation during the eval stage of HFTrainer, but I am not 100% sure whether gradient_checkpointing is disabled for the eval loop. Thanks for raising this question; it is worth investigating, which I will do next week. For now, maybe you can try disabling checkpointing before generation and then enabling it back afterwards. And let me know if that or any other solution works for you :slight_smile:
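Something along these lines (a rough sketch written against the transformers API as I understand it; the helper name and generation arguments are mine):

```python
import torch

def generate_without_checkpointing(model, tokenizer, prompt, **gen_kwargs):
    # Temporarily turn checkpointing off so use_cache=True is allowed,
    # then restore the previous state afterwards.
    was_checkpointing = model.is_gradient_checkpointing
    if was_checkpointing:
        model.gradient_checkpointing_disable()
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            return model.generate(**inputs, use_cache=True, **gen_kwargs)
    finally:
        if was_checkpointing:
            model.gradient_checkpointing_enable()
```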

Yeah, I’m just not using gradient_checkpointing for now. I’m using my own trainer, but I’ll be interested to see how you get on with HFTrainer.

Thanks,