Fine-tune with SFTTrainer

I noticed that, according to the trainer’s documentation, when fine-tuning the model, I am required to provide a text field (trl/trl/trainer/sft_trainer.py at 18a33ffcd3a576f809b6543a710e989333428bd3 · huggingface/trl · GitHub). However, this does not seem to be a supervised task!

Upon further examination, I observed that the model’s labels are the same as the input_ids, except they are shifted. This leads me to ask how this can be considered supervised learning. In my understanding, the prompt should serve as the input, and the completion should be the label. However, in this case, there are no distinct prompts and completions, only raw text.

Could you clarify what I am missing here?
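For concreteness, here is a minimal pure-Python sketch of what “the labels are the same as the input_ids, except shifted” means. The token IDs are made up; in Hugging Face causal LMs, labels is passed in as a plain copy of input_ids, and the one-position shift happens inside the model’s loss computation:

```python
# Toy token IDs standing in for a tokenized sequence.
input_ids = [101, 2023, 2003, 1037, 7953, 102]

# In Hugging Face causal LMs, labels are just a copy of input_ids;
# the shift is applied inside the model when the loss is computed.
labels = list(input_ids)

# Effective training pairs: predict token t+1 from the tokens up to t.
contexts = input_ids[:-1]   # what the model conditions on
targets = labels[1:]        # what it must predict next
pairs = list(zip(contexts, targets))
# pairs[0] == (101, 2023): given token 101, predict token 2023.
```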


Hi,

So SFT (supervised fine-tuning) is called supervised since we’re collecting the data from humans. However, we’re still training the model with the same cross-entropy loss as during pre-training (i.e., predicting the next token).

We now just make it more likely that the model will generate a useful completion given an instruction. For an instruction like “what are 10 things to do in London?”, the model should learn to generate “In London, you can visit (…)”, for instance.

Since the model is still trained to predict the next token, we just concatenate the instruction and completion into a single “text” column, and create the labels by shifting the inputs by one position, exactly as during pre-training. One can then decide to train the model only on the completions rather than the instructions, but the default SFTTrainer of TRL trains the model to predict both instructions and completions.
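The concatenation step described above can be sketched as follows. The “### Instruction/Response” template here is a made-up example for illustration, not TRL’s prescribed format:

```python
# Sketch: turn (prompt, completion) pairs into the single "text" column
# that SFTTrainer consumes. The template below is an assumption.
examples = [
    {
        "prompt": "What are 10 things to do in London?",
        "completion": "In London, you can visit (…)",
    }
]

def to_text(example):
    # Instruction and completion are concatenated into one string; the
    # model is then trained with next-token prediction over the whole thing.
    return {
        "text": f"### Instruction:\n{example['prompt']}\n"
                f"### Response:\n{example['completion']}"
    }

dataset = [to_text(ex) for ex in examples]
```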


I have the same question as you. Can you show me how to check how the dataset is created after passing the “text” field into the trainer?

Hi, not sure if you have tried or seen this. When I try to do SFT on only the completions using DataCollatorForCompletionOnlyLM, I get NaN in the gradients very quickly. However, when I use the default SFT on the entire input, everything works well. Do you happen to have any ideas why?

My issue is linked here: TRL SFT super prone to nan when using data collator

You could check this by calling trainer.get_train_dataloader(), and then inspecting the first batch.
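Since a real trainer needs a model and a tokenized dataset, a plain list of dicts stands in for the dataloader in the runnable sketch below; the inspection pattern itself (iterate, grab the first batch, compare labels to input_ids) is what trainer.get_train_dataloader() gives you, and the token IDs here are made up:

```python
# Stand-in for trainer.get_train_dataloader(): an iterable of batches,
# each a dict with the same keys a real batch would have.
fake_dataloader = [
    {
        "input_ids": [[101, 2023, 2003, 1037, 102]],
        "labels":    [[101, 2023, 2003, 1037, 102]],
    }
]

# Inspect the first batch, as you would with the real dataloader.
batch = next(iter(fake_dataloader))

# With the default SFTTrainer collator, labels mirror input_ids;
# with DataCollatorForCompletionOnlyLM, prompt positions would be -100.
labels_mirror_inputs = batch["labels"][0] == batch["input_ids"][0]
```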

I had the same surprise as ron5569 when I looked at the SFTTrainer code.

One can then decide to only train the model on the completions, rather than the instructions, but the default SFTTrainer of TRL trains the model to predict both instructions and completions.

Questions:

  • is it fair to say that most people fine-tune on both the instruction and the completion?
  • if so, does that mean fine-tuning on both leads to roughly the same performance as fine-tuning on just the completions?
  • fine-tuning on the instruction as well as the completion seems like unnecessary computation, no? Shouldn’t the default be to fine-tune only on the completions?
  • The DataCollatorForCompletionOnlyLM works by setting the labels of the instruction tokens to -100. But the forward pass over those tokens still happens; they are just excluded from the loss calculation. Again, wouldn’t that be wasteful of compute?
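The -100 masking in the last bullet can be sketched in a few lines of pure Python. Token IDs are made up, and the fixed prompt_len is an assumption; the real collator locates the completion by searching for a response template rather than using a known prompt length:

```python
# -100 is the index PyTorch's cross-entropy loss ignores, and the value
# DataCollatorForCompletionOnlyLM writes over prompt positions.
IGNORE_INDEX = -100

input_ids = [5, 6, 7, 8, 9, 10]  # first 3 tokens: prompt; last 3: completion
prompt_len = 3                   # assumed known here for illustration

labels = list(input_ids)
for i in range(prompt_len):
    labels[i] = IGNORE_INDEX     # prompt tokens contribute nothing to the loss
```

On the compute question: the forward pass over the prompt tokens is needed anyway, since their hidden states are the context the model attends to when predicting the completion; masking only removes them from the loss, it does not add extra computation.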

Btw, thank you @nielsr for Tutorials/Mistral/Supervised_fine_tuning_(SFT)_of_an_LLM_using_Hugging_Face_tooling! Very nice to see the complete example, but it raises the above questions about why training only on the completions is not the default.
