Question about gradient checkpointing and use_cache

I want to use the past_key_values to compute some auxiliary losses during training, but gradient checkpointing is incompatible with use_cache in causal LMs (Transformers disables use_cache and warns when gradient checkpointing is enabled). I wonder why there is such a conflict.

I think it’s because gradient checkpointing discards intermediate activations and re-runs parts of the forward pass during the backward pass to recompute them. If we used past_key_values in such a scenario, the cache would be filled with duplicate keys/values from those extra forward calls.
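A toy sketch of this effect (not real Transformers internals, just a pure-Python simulation of a layer that appends to a KV cache, and a checkpointing step that re-runs the forward during backward):

```python
# Toy illustration: why a KV cache gets polluted when gradient
# checkpointing re-runs the forward pass. All names here are
# hypothetical stand-ins, not the transformers API.

cache = []  # stands in for past_key_values


def attention_forward(x):
    # With use_cache=True, each forward call appends this step's
    # key/value to the cache.
    cache.append(("kv", x))
    return x * 2


def checkpointed_backward(x):
    # Gradient checkpointing discards activations and re-runs the
    # forward pass during backward to recompute them.
    attention_forward(x)  # the extra forward call


# One training step: a single forward...
attention_forward(1)
# ...then backward, which under checkpointing re-runs the forward:
checkpointed_backward(1)

print(len(cache))  # the single step now occupies two cache slots
```

The single token position ends up cached twice, which is why Transformers forces use_cache=False whenever gradient checkpointing is on.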
