When fine-tuning an LLM, we pass in a long string containing the context, the user input, and the expected answer. How does the model know which word to start generating from? Is the loss during fine-tuning not computed only on the answer tokens?
Hi,
During supervised fine-tuning (SFT), LLMs are trained to predict the next token, just like during the pre-training stage. The data consists of (instruction, completion) pairs, like this dataset for instance. By default, one trains on all tokens, i.e. the LLM is made to predict the next token at every position.
Data preparation
Let’s take an (input, output) or (instruction, completion) pair as an example. Say the instruction is “give me a recipe to bake pancakes” and the completion is “to bake pancakes, one takes a pan”. You would then prepare this for the model like so:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

instruction = "give me a recipe to bake pancakes"
completion = "to bake pancakes, one takes a pan"

# format the pair as a chat conversation
messages = [
    {"role": "user", "content": instruction},
    {"role": "assistant", "content": completion},
]

# render the conversation as a single string (no tokenization yet)
inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
print(inputs)
As you will see, the instruction and completion are concatenated and some special tokens are added which indicate where the completion starts:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ngive me a recipe to bake pancakes<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nto bake pancakes, one takes a pan<|eot_id|>
As can be seen, the tokenizer has added a default system prompt, then the instruction (marked with a special user token), then the completion (marked with a special assistant token), all concatenated into a single string.
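As a small check (not in the original snippet, but reusing the same tokenizer and messages), you can pass tokenize=True to get the token ids directly and inspect the special tokens:

token_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False)
# the first few tokens are the special begin-of-text and role header tokens
print(tokenizer.convert_ids_to_tokens(token_ids)[:10])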
Training the model
Next, we will train the LLM to predict the next token at each position. By default, one simply trains on all tokens (yes, that includes the system and user prompts). This means the model is taught to predict <|start_header_id|> after <|begin_of_text|>, then system after <|start_header_id|>, and so on (the model is trained to predict the next token at every position, all in parallel in one go). In other words, by default the labels are just a copy of the inputs (the model shifts the labels internally by one time step to predict the next token). Visually:
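In code, this default looks roughly as follows (a minimal sketch reusing the tokenizer, messages and model name from the snippet above, just to illustrate where the loss comes from):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16
)

# tokenize the chat-templated conversation
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=False, return_tensors="pt"
)

# by default, the labels are simply a copy of the inputs (train on every token)
labels = input_ids.clone()

# the model shifts the labels internally by one position, so the loss is the
# cross-entropy of predicting each next token given the tokens before it
outputs = model(input_ids=input_ids, labels=labels)
print(outputs.loss)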
One could avoid this by using the DataCollatorForCompletionOnlyLM in the TRL library, which ensures that the labels are a copy of the inputs with the system prompt and user prompt tokens replaced by -100 (the ignore index of PyTorch’s CrossEntropyLoss). This ensures that the model is only trained to predict the completion (“assistant”) tokens. Visually:
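In code, using the collator looks roughly like this (my own sketch, not part of the original answer; it reuses the tokenizer and messages from above, and the response_template string matches the Llama 3.1 assistant header shown earlier, but double-check it against your own tokenizer's chat template):

from trl import DataCollatorForCompletionOnlyLM

# everything up to and including the assistant header gets label -100,
# so the loss is only computed on the completion tokens
response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

example = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False)
batch = collator([example])
print(batch["labels"])  # -100 everywhere except the completion tokens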
Does this not only exclude the prompt tokens from the calculation of PyTorch’s CrossEntropyLoss, but also exclude
<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ngive me a recipe to bake pancakes<|eot_id|><|start_header_id|>assistant<|end_header_id|>
as well?
That is a good question. I think for that we would need to look into the actual code of DataCollatorForCompletionOnlyLM, which is defined here. As can be seen, the class takes a response_template as input, where you can indicate where the response (completion) starts.
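To see concretely what ends up masked, one option is to decode only the positions whose label is not -100 (continuing the collator sketch above; my own check, not from the thread):

labels = batch["labels"][0]
kept = [tid for tid, lab in zip(batch["input_ids"][0].tolist(), labels.tolist()) if lab != -100]
# only the completion and the trailing <|eot_id|> should survive
print(tokenizer.decode(kept))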
Of course, you could also create the labels yourself from the input_ids, making sure they are set to -100 for any input id you don’t want the model trained to predict.
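A rough sketch of that manual approach (my own, not from the thread): apply the chat template once more without the assistant message to measure how many tokens the prompt takes up, then mask those positions. This assumes the prompt tokenization is a prefix of the full tokenization, which holds for this chat template:

# full conversation vs. prompt-only (system + user + assistant header)
full_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False)
prompt_ids = tokenizer.apply_chat_template(messages[:-1], tokenize=True, add_generation_prompt=True)

# copy the inputs and mask everything up to the start of the completion
labels = list(full_ids)
labels[: len(prompt_ids)] = [-100] * len(prompt_ids)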
Thank you, I now understand how the DataCollatorForCompletionOnlyLM
works. It marks the start of the response and simply ignores tokens before that point.