How do LLMs identify generation start point during fine-tuning?

When fine-tuning an LLM, we pass a long string containing the context, the user input, and the expected answer. How does the model know which word to start generating from?
Isn't the loss during fine-tuning computed only on the answer tokens?

Hi,

During supervised fine-tuning (SFT), LLMs are trained to predict the next token, just like during the pre-training stage. The data consists of (instruction, completion) pairs, like this dataset for instance. By default, one trains on all tokens (i.e. the LLM learns to predict the next token at every position).

Data preparation

Let’s take an (input, output) or (instruction, completion) pair as an example. Say the instruction is “give me a recipe to bake pancakes” and the completion is “to bake pancakes, one takes a pan”. You would prepare this for the model like so:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

instruction = "give me a recipe to bake pancakes"
completion = "to bake pancakes, one takes a pan"

messages = [
    {"role": "user", "content": instruction},
    {"role": "assistant", "content": completion},
]

# tokenize=False returns the formatted string; add_generation_prompt=False because
# the assistant turn (the completion) is already included in the messages
inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
print(inputs)

As you will see, the instruction and completion are concatenated and some special tokens are added which indicate where the completion starts:

<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ngive me a recipe to bake pancakes<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nto bake pancakes, one takes a pan<|eot_id|>

As can be seen, the tokenizer has added a default system prompt, then the instruction (marked with special user tokens), then the completion (marked with special assistant tokens), all concatenated into a single string.

Training the model

Next, we train the LLM to predict the next token at each position. By default, one simply trains on all tokens (yes, including the system and user prompt). This means the model is taught to predict <|start_header_id|> after <|begin_of_text|>, then system after <|start_header_id|>, and so on (the next-token predictions for all positions are made in parallel, in one go). In other words, the labels are by default just a copy of the inputs (the model shifts the labels internally by one time step so that each position predicts the next token).
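
To make “the labels are just a copy of the inputs, shifted internally” concrete, here is a minimal sketch of the loss computation (the token ids and vocabulary size are made-up placeholders, and in practice this shift happens inside the model’s forward pass):

import torch
import torch.nn.functional as F

# Dummy batch of token ids; the labels start out as an exact copy of the inputs
input_ids = torch.tensor([[128000, 9906, 1917, 11, 1268, 527, 499, 30]])
labels = input_ids.clone()

# Stand-in for model(input_ids).logits
vocab_size = 128256
logits = torch.randn(1, input_ids.shape[1], vocab_size)

# Shift by one position so that position i is trained to predict token i+1
shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
shift_labels = labels[:, 1:].reshape(-1)

# Positions whose label is -100 would simply not contribute to the loss
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
print(loss)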

One can avoid this by using the DataCollatorForCompletionOnlyLM from the TRL library, which ensures that the labels are a copy of the inputs with the system and user prompt tokens replaced by -100 (the ignore index of PyTorch’s CrossEntropyLoss). That way, the model is only trained to predict the completion (“assistant”) tokens.
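
For illustration, here is a rough sketch of how that collator could be used with the example above; the response_template string is my assumption of how the assistant turn starts in the Llama 3.1 chat format, so double-check how it tokenizes for your model:

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

messages = [
    {"role": "user", "content": "give me a recipe to bake pancakes"},
    {"role": "assistant", "content": "to bake pancakes, one takes a pan"},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

# The collator uses this marker to decide where the completion starts;
# prompt tokens get label -100
response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

# add_special_tokens=False because the chat template already contains <|begin_of_text|>
input_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
batch = collator([input_ids])
print(batch["labels"])  # -100 for masked positions, real token ids for the completion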


Does this only exclude the user prompt tokens from PyTorch’s CrossEntropyLoss, or does it also ignore
<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ngive me a recipe to bake pancakes<|eot_id|><|start_header_id|>assistant<|end_header_id|>
as well?

That is a good question. To answer it, we would need to look into the actual code of DataCollatorForCompletionOnlyLM, which is defined here. As can be seen, the class takes a response_template argument, with which you can indicate where the response (completion) starts.

Of course, you could also create the labels yourself from the input_ids, making sure they are set to -100 for any input id you don’t want the model to be trained to predict.
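
As a rough sketch of that manual approach (assuming the Llama 3.1 chat template renders the prompt plus the assistant header as a strict prefix of the full conversation):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

messages = [
    {"role": "user", "content": "give me a recipe to bake pancakes"},
    {"role": "assistant", "content": "to bake pancakes, one takes a pan"},
]

# Token ids for the prompt (up to and including the assistant header) and for the full conversation
prompt_ids = tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
full_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=False)

input_ids = torch.tensor(full_ids)
labels = input_ids.clone()
labels[: len(prompt_ids)] = -100  # only the completion tokens contribute to the loss

print(labels)

This way you control exactly which tokens count towards the loss, including the special tokens of the template.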


Thank you, I now understand how the DataCollatorForCompletionOnlyLM works. It marks the start of the response and simply ignores tokens before that point.
