Fine-tune Llama on a main task and an auxiliary task

Hello everyone,

I am trying to fine-tune a Llama model on two tasks at the same time:

  1. Main task: causal language modelling, as the model was initially trained for.
  2. Auxiliary task: a classification task based on the whole input sequence (recommending an article). For this task I am using the LlamaForCausalLM class as a reference, overriding its __init__ and forward functions, as in the sketch after this list.
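
To make this concrete, here is a rough sketch of what I have in mind for the subclass. The `num_labels` argument, the `cls_head` name, and the last-token pooling are my own assumptions, not anything from the library:

```python
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM


class LlamaForCausalLMWithClassification(LlamaForCausalLM):
    """Causal LM with an extra sequence-classification head (my own sketch)."""

    def __init__(self, config, num_labels=10):
        super().__init__(config)
        # extra head mapping the last token's hidden state to class logits
        self.cls_head = nn.Linear(config.hidden_size, num_labels, bias=False)

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                cls_labels=None, **kwargs):
        # standard causal-LM forward; `labels` gives the usual shifted LM loss
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            **kwargs,
        )

        # pool the hidden state of the last non-padded token of each sequence
        last_hidden = outputs.hidden_states[-1]        # (batch, seq, hidden)
        seq_lengths = attention_mask.sum(dim=1) - 1    # index of last real token
        batch_idx = torch.arange(last_hidden.size(0), device=last_hidden.device)
        pooled = last_hidden[batch_idx, seq_lengths]
        cls_logits = self.cls_head(pooled)

        # classification loss is computed once per sequence, not per token
        cls_loss = None
        if cls_labels is not None:
            cls_loss = nn.functional.cross_entropy(cls_logits, cls_labels)

        return outputs, cls_logits, cls_loss
```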

However, I want to combine the two tasks above into a single training process. The main problem is that language modelling computes a loss for every next-token prediction in the input sequence, while the classification loss should only be calculated once per sequence.

How can I hold back the classification loss so that it is only computed once the language-modelling part of the sequence has been processed? Is there any example you can recommend for combining a main LM task with an auxiliary classification task?
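
For reference, this is roughly what I imagine a combined training step would look like, but I am not sure it is the right way to weight the two losses. The `alpha` factor and the batch field names below are just placeholders I made up:

```python
# hypothetical training step: the LM loss is already averaged over all
# next-token predictions inside the model, while the classification loss
# is computed once per sequence; alpha is a made-up weighting factor
alpha = 0.5

def training_step(model, batch, optimizer):
    outputs, cls_logits, cls_loss = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["input_ids"],          # standard next-token LM labels
        cls_labels=batch["article_label"],  # one label per sequence
    )
    loss = outputs.loss + alpha * cls_loss  # joint objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```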

This is my first question here, thanks everyone for your understanding.