SFTTrainer loss function and formatting_func

I would like to know how the masks are set up when a formatting_func is used. In my case, I am trying to fine-tune TX Gemma. I have refactored my dataset into two columns: the SMILES string and the expected output (in my case, (A) or (B)).

I use the TDC prompt in the formatting_func and concatenate the generated prompt with the expected output, so each example ends up as a single text string at the end of the data refactoring pipeline.

def formatting_func(example):
    text = f"{tdc_prompts_json[task_name].replace(input_type, example['input_text'])} {example['output_text']}<eos>"
    return text

My text is as follows:
“Instructions: Answer the following question about drug properties.\nContext: Human ether-à-go-go related gene (hERG) is crucial for the coordination of the heart’s beating. Thus, if a drug blocks the hERG, it could lead to severe adverse effects. Therefore, reliable prediction of hERG liability in the early stages of drug design is quite important to reduce the risk of cardiotoxicity-related attritions in the later development stages.\nQuestion: Given a drug SMILES string, predict whether it\n(A) does not inhibit hERG (B) inhibits hERG\nDrug SMILES: O=C(NC[C@H]1CCCC[NH2+]1)c1cc(OCC(F)(F)F)ccc1OCC(F)(F)F\nAnswer: (B)”

When I create my trainer object, I just pass the dataset containing ['input_text'] and ['output_text'] together with the formatting_func, and I assume that the trainer will take care of the mask so the loss is computed only on the generated tokens and not on the entire prompt.

trainer = SFTTrainer(
    model=model,
    train_dataset=refactored_dataset_dict['train'],
    eval_dataset=refactored_dataset_dict['val'],
    args=SFTConfig(),
    formatting_func=formatting_func,
    peft_config=lora_config
)

The training goes “well” in the sense that the loss decreases, but when I test the fine-tuned model, I realize that it has learned nothing: the generated output is always the same, a hint that the loss has been computed on the entire prompt, not only on the expected output.

What did I miss?

Thanks for the help

Jerome


Using completion_only_loss=True seems like the smartest approach.


You didn’t miss a magical “auto-mask the prompt” feature.
With your current setup, SFTTrainer is computing loss on the entire formatted string (prompt + SMILES + (A)/(B)), except for padding. When you pass only a formatting_func, SFTTrainer treats your data as plain language-model text, not as “prompt vs. completion.”

I’ll walk through:

  1. What SFTTrainer actually does when you pass formatting_func
  2. What the masks (labels and attention_mask) look like in that case
  3. Why that makes Tx Gemma “learn nothing” for your (A)/(B) label
  4. What you need to change to get the mask you expected

1. What SFTTrainer does with formatting_func

Step 1: formatting_func transforms your rows to pure text

You wrote:

def formatting_func(example):
    text = f"{tdc_prompts_json[task_name].replace(input_type, example['input_text'])} {example['output_text']}<eos>"
    return text

So for each dataset row { "input_text": SMILES, "output_text": "(B)" } you produce a single string:

"Instructions: ... Question: ... Drug SMILES: ... Answer: (B)<eos>"

Inside SFTTrainer, when you pass formatting_func, it:

  1. Applies formatting_func to each example in train_dataset and eval_dataset.
  2. Replaces the original columns with this formatted text (internally it becomes something like a "text" field).
  3. Then tokenizes that text into input_ids and attention_mask.

At this point, SFTTrainer has no explicit notion of “prompt” vs. “answer” anymore. It just sees a long text sequence per example. This is treated as a language modeling (LM) dataset. (Hugging Face)
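
Conceptually, the preprocessing amounts to something like this (a sketch of the idea, not TRL’s exact code path; tokenizer is assumed to be the Tx Gemma tokenizer you already loaded):

# Rough equivalent of what SFTTrainer does internally with formatting_func:
text_ds = refactored_dataset_dict["train"].map(
    lambda ex: {"text": formatting_func(ex)},
    remove_columns=["input_text", "output_text"],
)
# ...then plain tokenization, with no prompt/completion boundary recorded
tokenized_ds = text_ds.map(lambda ex: tokenizer(ex["text"]))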

Step 2: default collator = normal LM objective

If you do not pass a custom data_collator and you do not use the new prompt-completion dataset type, TRL chooses a standard LM-style collator (conceptually the same as DataCollatorForLanguageModeling):

  • It builds input_ids, attention_mask.
  • It builds labels equal to input_ids for all non-pad tokens.
  • It sets labels = -100 only for padding positions so they are ignored. (Hugging Face)

There is no special masking for the prompt when you only use formatting_func.

That’s the key missing piece.
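
In code, the collation behaviour is roughly this (a minimal sketch, not TRL’s actual implementation):

import torch

def lm_style_labels(input_ids, attention_mask):
    # labels start as a copy of the inputs...
    labels = input_ids.clone()
    # ...and only padding positions are masked out of the loss
    labels[attention_mask == 0] = -100
    return labels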


2. What the masks actually look like in your setup

Let’s look at one example conceptually.

Your formatted text:

"Instructions: Answer the following question about drug properties.
Context: Human ether-à-go-go related gene (hERG) ...
Question: Given a drug SMILES string, predict whether it
(A) does not inhibit hERG (B) inhibits hERG
Drug SMILES: O=C(NC[C@H]1CCCC[NH2+]1)c1cc(OCC(F)(F)F)ccc1OCC(F)(F)F
Answer: (B)<eos>"

After tokenization, you get:

input_ids = [Inst, ructions, :, Answer, the, following, ..., '(', 'B', ')', <eos>]
attention_mask = [1, 1, 1, 1, ..., 1]

The default SFTTrainer collator (LM style) will then set:

labels      = [Inst, ructions, :, Answer, the, following, ..., '(', 'B', ')', <eos>]

with labels = -100 only on any padding added when batching.

So:

  • Every token of the instructions, context, question, SMILES, “Answer:” and (B) contributes to the loss.
  • Only padding is ignored.
  • There is no “loss only on (B)<eos>” behaviour.

This matches the old “train on completions only” warning in the TRL docs: you must use a special collator or a prompt-completion dataset format to get completion-only loss. If you just pass text, the trainer uses plain LM loss on all non-padding tokens. (Hugging Face)
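
You can check this on your own run by decoding the tokens that actually carry loss in one batch (using the trainer and tokenizer objects from your script):

batch = next(iter(trainer.get_train_dataloader()))
labels = batch["labels"][0]
print(tokenizer.decode(batch["input_ids"][0][labels != -100]))
# with formatting_func only, this prints the whole prompt + answer,
# not just " (B)<eos>"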


3. Why this makes Tx Gemma look like it “learned nothing”

Two effects combine:

3.1 The prompt dominates the loss

Your answer (A) or (B) is just a handful of tokens; your prompt is long (instructions + context + question + SMILES).

If we roughly say:

  • Prompt ≈ 150 tokens
  • Answer ≈ 3–5 tokens

then >95% of the supervised tokens are in the prompt.

The model’s gradients are dominated by:

  • Getting the instructions right
  • Getting the context paragraph right
  • Getting the SMILES and surrounding text right

The tiny suffix (A) / (B) contributes almost nothing to the total loss compared to the prompt.
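
The arithmetic is stark even with rough numbers (illustrative counts, not measured):

prompt_tokens, answer_tokens = 150, 4
answer_share = answer_tokens / (prompt_tokens + answer_tokens)
print(f"answer tokens carry only {answer_share:.1%} of the loss")  # ~2.6%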

3.2 The prompt is mostly constant; the label is tiny

You re-use the same TDC-style instructions and question each time, with only the SMILES and (A)/(B) changing. So the model can reduce loss substantially simply by:

  • Memorizing the fixed instruction & context text.
  • Modestly improving how it reproduces SMILES-like patterns.

It does not need to learn the mapping “SMILES → (A)/(B)” to lower the loss meaningfully, because the supervision on those few label tokens is dwarfed by supervision on the fixed prompt.

3.3 Tx Gemma has a known multiple-choice positional bias

There is a GitHub issue from the Gemma cookbook showing Tx Gemma on a hERG task where, for multiple-choice prompts like:

(A) is a hERG blocker (B) isn't ...
(A) isn't a hERG blocker (B) is ...
1: is a hERG blocker  0: isn't ...
0: isn't a hERG blocker 1: is ...

the model tends to always pick the first choice, regardless of content. (GitHub)

If your fine-tuning signal barely touches the answer tokens (A)/(B) because masking is wrong, that pre-existing positional bias remains. At inference, you observe:

  • “The model always outputs the same answer.”
  • “It seems to have learned nothing.”

This is exactly what you described.


4. What you need to change to get the mask you expected

You have two clean options depending on which TRL style you want.

Option 1: use a prompt-completion dataset + completion_only_loss=True (no formatting_func)

This is the newer / recommended SFTTrainer pattern. (Hugging Face)

  1. Preprocess your dataset into explicit prompt and completion fields:

    def preprocess(example):
        # prompt includes everything up to and including "Answer:";
        # the TDC template already ends with "Answer:" (see the example
        # text above), so don't append it a second time
        prompt = tdc_prompts_json[task_name].replace(
            input_type,
            example["input_text"],  # SMILES
        )

        completion = f" {example['output_text']}<eos>"  # " (A)" or " (B)"

        return {"prompt": prompt, "completion": completion}
    
  2. Map over your dataset:

    ds = refactored_dataset_dict.map(
        preprocess,
        remove_columns=["input_text", "output_text"],
    )
    
  3. Train with SFTTrainer configured for prompt-completion data:

    from trl import SFTTrainer, SFTConfig
    
    training_args = SFTConfig(
        output_dir="txgemma-herg",
        completion_only_loss=True,  # ensures loss on completion only
        max_seq_length=512,
        # other hyperparams...
    )
    
    trainer = SFTTrainer(
        model=model,
        train_dataset=ds["train"],
        eval_dataset=ds["val"],
        args=training_args,
        peft_config=lora_config,
        # no formatting_func needed now
    )
    

In this mode:

  • SFTTrainer knows which tokens belong to the prompt and which belong to the completion. (Hugging Face)

  • It automatically masks labels so that:

    • labels = -100 for all prompt tokens.
    • labels = token_id only for completion tokens.

That gives you the “loss only on generated tokens (the answer)” behaviour you expected.
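
You can confirm it with the same dataloader check from section 2 (assuming the trainer object from step 3 above):

batch = next(iter(trainer.get_train_dataloader()))
labels = batch["labels"][0]
print(tokenizer.decode(batch["input_ids"][0][labels != -100]))
# should now print only the completion, e.g. " (B)<eos>"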


Option 2: keep formatting_func, add DataCollatorForCompletionOnlyLM

If you really want to keep a single formatted text string (your current style), then you must explicitly instruct the collator where the answer starts, using DataCollatorForCompletionOnlyLM. (Hugging Face)

  1. Make sure your formatting_func exposes a stable answer prefix. Your TDC template already ends with "\nAnswer:", so your current function is fine; what matters is that this exact pattern appears once, right before the label:

    def formatting_func(example):
        prompt = tdc_prompts_json[task_name].replace(input_type, example["input_text"])
        # the TDC template already ends with "\nAnswer:", so just append
        # the label; don't insert the marker a second time
        text = f"{prompt} {example['output_text']}<eos>"
        return text
    
  2. Create the collator with a response_template that matches the prefix before the answer:

    from trl import DataCollatorForCompletionOnlyLM
    
    response_template = "\nAnswer:"  # exact substring before the answer
    
    collator = DataCollatorForCompletionOnlyLM(
        response_template=response_template,
        tokenizer=tokenizer,
    )
    

    TRL docs and multiple guides show this pattern for “train on completions only” with formatting_func. (Hugging Face)

  3. Instantiate SFTTrainer with packing=False (required for this collator) and pass the collator:

    trainer = SFTTrainer(
        model=model,
        train_dataset=refactored_dataset_dict["train"],
        eval_dataset=refactored_dataset_dict["val"],
        args=SFTConfig(
            output_dir="txgemma-herg",
            packing=False,   # important with DataCollatorForCompletionOnlyLM
            max_seq_length=512,
            # ...
        ),
        formatting_func=formatting_func,
        data_collator=collator,
        peft_config=lora_config,
    )
    

Now the masking looks like:

  • The collator searches the tokenized sequence for the first occurrence of the response_template string.
  • All tokens before that point get labels = -100 (ignored).
  • All tokens after that point (including (A)/(B) and <eos>) retain their token IDs as labels. (Gist)

That is exactly “compute loss only on the generated output,” while still using formatting_func to build the text.
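
One caveat: some tokenizers encode the same template string differently depending on what precedes it, in which case the collator logs a warning that it could not find the response key and drops that instance from the loss. The TRL docs recommend passing pre-tokenized IDs instead; a sketch, where the slice that strips context-dependent leading tokens is tokenizer-specific and should be checked against Gemma’s tokenizer:

from trl import DataCollatorForCompletionOnlyLM

# inspect tokenizer.encode("\nAnswer:", add_special_tokens=False) first and
# pick the offset that keeps only the stable trailing tokens ([1:] here is
# illustrative, not a universal value)
response_template_ids = tokenizer.encode("\nAnswer:", add_special_tokens=False)[1:]

collator = DataCollatorForCompletionOnlyLM(
    response_template_ids,
    tokenizer=tokenizer,
)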


5. Short mental model answer to your question

“How are masks set up when a formatting_func is used, in my case?”

With only formatting_func and no special collator or prompt-completion dataset:

  • SFTTrainer:

    • Applies formatting_func → gets a single text per example.

    • Tokenizes to input_ids, attention_mask.

    • Sets:

      • attention_mask = 1 for all real tokens, 0 for padding.
      • labels = input_ids for all non-padding tokens.
      • labels = -100 only for padding tokens.

So:

  • Prompt + SMILES + “Answer: (B)” are all supervised.
  • There is no built-in logic that “knows” where the answer starts.
  • As a result, the loss is dominated by the long, mostly constant prompt, and Tx Gemma’s pre-existing bias (e.g., “always pick the first option”) remains, giving you almost the same output every time.

To get the loss mask you wanted (only on (A) / (B)), you must either:

  • Switch to a prompt / completion dataset with completion_only_loss=True, or
  • Keep formatting_func but add DataCollatorForCompletionOnlyLM with a proper response_template and packing=False.

Once you do that, the masks will be:

  • Prompt tokens: labels = -100 (ignored)
  • Answer tokens: real label IDs → loss is computed only there.
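
Schematically, for one example (pseudo-token values, in the style of section 2):

input_ids = [Inst, ructions, ..., Answer, :, '(', 'B', ')', <eos>]
labels    = [-100, -100, ..., -100, -100, '(', 'B', ')', <eos>]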