Access Hidden States in Custom Loss Function in Finetuning

When using SFTTrainer, how can I create a custom loss function (as described here) that can access the model's hidden states?

I know hidden states can be accessed if we pass output_hidden_states=True to the model’s forward call (as described here). So maybe another version of my question is how can we make SFTTrainer output hidden states?
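
To make the question concrete, here is a minimal sketch of the kind of thing I have in mind. It assumes the trainer's `compute_loss` forwards every key in `inputs` to the model, so that adding `output_hidden_states=True` to the batch reaches the forward call, and that the returned outputs expose a `hidden_states` attribute. `HiddenStateLossMixin`, `hidden_state_term` and `hidden_state_weight` are hypothetical names made up for illustration; they are not part of TRL or Transformers.

```python
class HiddenStateLossMixin:
    """Hypothetical mixin that adds a hidden-state term to a trainer's loss.

    Meant to be listed before a trainer class in the bases so that
    super().compute_loss resolves to the trainer's own implementation.
    """

    # Hypothetical knob: weight of the extra hidden-state term.
    hidden_state_weight = 0.1

    def hidden_state_term(self, hidden_states):
        """Override this: map the per-layer hidden states to a scalar."""
        raise NotImplementedError

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Assumption: the parent passes all keys of `inputs` to model(**inputs),
        # so this flag asks the forward pass to return hidden states too.
        inputs = {**inputs, "output_hidden_states": True}
        # Let the parent compute its usual loss and also hand back the outputs.
        loss, outputs = super().compute_loss(
            model, inputs, return_outputs=True, **kwargs
        )
        # Add the custom term computed from the hidden states.
        extra = self.hidden_state_term(outputs.hidden_states)
        loss = loss + self.hidden_state_weight * extra
        return (loss, outputs) if return_outputs else loss
```

The intended use would be something like `class MySFTTrainer(HiddenStateLossMixin, SFTTrainer)` with `hidden_state_term` overridden, for example returning `hidden_states[-1].pow(2).mean()` for PyTorch tensors. I have not verified this against every TRL version. Note that requesting hidden states keeps the activations of every layer around, so memory use goes up.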
