How do LLMs identify generation start point during fine-tuning?

Hi,

During supervised fine-tuning (SFT), LLMs are trained to predict the next token, just like during the pre-training stage. The data consists of (instruction, completion) pairs, like this dataset for instance. By default, one trains on all tokens (i.e., the LLM learns to predict the next token at every position).

Data preparation

Let’s take an (input, output) or (instruction, completion) pair as an example. Say the instruction is “give me a recipe to bake pancakes” and the completion is “to bake pancakes, one takes a pan”. You would then prepare this for the model like so:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

instruction = "give me a recipe to bake pancakes"
completion = "to bake pancakes, one takes a pan"

messages = [
    {"role": "user", "content": instruction},
    {"role": "assistant", "content": completion},
]

inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
print(inputs)

As you will see, the instruction and completion are concatenated and some special tokens are added which indicate where the completion starts:

<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ngive me a recipe to bake pancakes<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nto bake pancakes, one takes a pan<|eot_id|>

As can be seen, the tokenizer has added a default system prompt, then the instruction (marked with special user header tokens), then the completion (marked with special assistant header tokens), all concatenated into a single string.
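
If you want to inspect the actual token ids rather than the formatted string, you can call the same method with tokenize=True. A minimal sketch, reusing the tokenizer and messages defined above:

input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False)
print(input_ids)
# Convert the ids back to tokens to see exactly where the assistant part starts
print(tokenizer.convert_ids_to_tokens(input_ids))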

Training the model

Next, we will train the LLM to predict the next token for each of the tokens. By default, one simply trains on all tokens (yes, that includes the system and user prompts). This means that the model is taught to predict <|start_header_id|> after <|begin_of_text|>, then system after <|start_header_id|>, and so on (the model is trained to predict the next token at every position, all in parallel in one go). In other words, by default the labels are just a copy of the inputs (the model will shift the labels internally by one time step to predict the next token). Visually:
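
And in code, this default behaviour boils down to roughly the following sketch (reusing the checkpoint from above and the input_ids from the previous snippet; any causal LM would do, and this is not the exact training loop, just an illustration):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

input_ids = torch.tensor([input_ids])  # add a batch dimension
labels = input_ids.clone()             # the labels are just a copy of the inputs

# The model shifts the labels one position internally and computes the
# cross-entropy loss over all tokens (system, user and assistant alike).
outputs = model(input_ids=input_ids, labels=labels)
print(outputs.loss)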

One could avoid this by using the DataCollatorForCompletionOnlyLM in the TRL library, which ensures that the labels are a copy of the inputs, with the system prompt and user prompt tokens replaced by -100 (the ignore index of PyTorch’s CrossEntropyLoss). This ensures that the model is only trained to predict the completion (“assistant”) tokens. Visually:
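
As a rough sketch of how that collator can be used (the response_template string below is the Llama 3.1 assistant header, which is an assumption on my side; double-check it against your tokenizer's chat template and your TRL version):

from trl import DataCollatorForCompletionOnlyLM

# Llama has no pad token by default; the collator needs one to batch examples
tokenizer.pad_token = tokenizer.eos_token

# Everything up to and including this template gets its label set to -100
response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

# Feed one tokenized conversation (a list of token ids, as obtained above)
batch = collator([{"input_ids": tokenizer.apply_chat_template(messages, tokenize=True)}])
print(batch["labels"])  # -100 for the system/user part, real token ids for the completion

You can then pass this collator as the data_collator of your trainer (e.g. SFTTrainer), so that only the assistant tokens contribute to the loss.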
