Using `completion_only_loss=True` seems like the smartest approach.

You didn't miss a magical "auto-mask the prompt" feature.

With your current setup, SFTTrainer is computing loss on the entire formatted string (prompt + SMILES + (A)/(B)), except for padding. When you pass only a `formatting_func`, SFTTrainer treats your data as plain language-model text, not as "prompt vs. completion."
I'll walk through:

- What SFTTrainer actually does when you pass `formatting_func`
- What the masks (`labels` and `attention_mask`) look like in that case
- Why that makes Tx Gemma "learn nothing" for your (A)/(B) label
- What you need to change to get the mask you expected
1. What SFTTrainer does with `formatting_func`

Step 1: `formatting_func` transforms your rows to pure text

You wrote:

```python
def formatting_func(example):
    text = f"{tdc_prompts_json[task_name].replace(input_type, example['input_text'])} {example['output_text']}<eos>"
    return text
```
So for each dataset row `{ "input_text": SMILES, "output_text": "(B)" }` you produce a single string:

`"Instructions: ... Question: ... Drug SMILES: ... Answer: (B)<eos>"`
Inside SFTTrainer, when you pass `formatting_func`, it:

- Applies `formatting_func` to each example in `train_dataset` and `eval_dataset`.
- Replaces the original columns with this formatted text (internally it becomes something like a `"text"` field).
- Then tokenizes that text into `input_ids` and `attention_mask`.

At this point, SFTTrainer has no explicit notion of "prompt" vs. "answer" anymore. It just sees a long text sequence per example. This is treated as a language modeling (LM) dataset. (Hugging Face)
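
Conceptually, that pipeline boils down to something like the following sketch (not SFTTrainer's actual source code; it assumes the `refactored_dataset_dict` and `tokenizer` from your setup):

```python
# Rough sketch of what SFTTrainer effectively does when given formatting_func.
# Illustrative only; assumes refactored_dataset_dict and tokenizer are already defined.
def to_text(example):
    return {"text": formatting_func(example)}

lm_dataset = refactored_dataset_dict["train"].map(to_text)

tokenized = lm_dataset.map(
    lambda ex: tokenizer(ex["text"]),  # -> input_ids, attention_mask
    remove_columns=lm_dataset.column_names,
)
```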
Step 2: default collator = normal LM objective

If you do not pass a custom `data_collator` and you do not use the new prompt-completion dataset type, TRL chooses a standard LM-style collator (conceptually the same as `DataCollatorForLanguageModeling`):

- It builds `input_ids` and `attention_mask`.
- It builds `labels` equal to `input_ids` for all non-pad tokens.
- It sets `labels = -100` only for padding positions, so they are ignored. (Hugging Face)

There is no special masking for the prompt when you only use `formatting_func`.
That's the key missing piece.
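
You can reproduce that default behaviour with the plain Hugging Face collator (a minimal sketch; the checkpoint name is just a placeholder for whichever Tx Gemma tokenizer you actually load):

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

# Placeholder checkpoint; substitute the tokenizer you already use.
tokenizer = AutoTokenizer.from_pretrained("google/txgemma-2b-predict")

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = collator([tokenizer("Drug SMILES: CCO Answer: (B)")])

# labels mirror input_ids; -100 shows up only where padding was added.
print(batch["input_ids"][0])
print(batch["labels"][0])
```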
2. What the masks actually look like in your setup

Let's look at one example conceptually.

Your formatted text:

```
"Instructions: Answer the following question about drug properties.
Context: Human ether-à-go-go related gene (hERG) ...
Question: Given a drug SMILES string, predict whether it
(A) does not inhibit hERG (B) inhibits hERG
Drug SMILES: O=C(NC[C@H]1CCCC[NH2+]1)c1cc(OCC(F)(F)F)ccc1OCC(F)(F)F
Answer: (B)<eos>"
```
After tokenization, you get:

```
input_ids      = [Inst, ructions, :, Answer, the, following, ..., '(', 'B', ')', <eos>]
attention_mask = [1, 1, 1, 1, ..., 1]
```

The default SFTTrainer collator (LM style) will then set:

```
labels = [Inst, ructions, :, Answer, the, following, ..., '(', 'B', ')', <eos>]
```

with `labels = -100` only on any padding added when batching.
So:

- Every token of the instructions, context, question, SMILES, "Answer:" and (B) contributes to the loss.
- Only padding is ignored.
- There is no "loss only on (B)<eos>" behaviour.

This matches the old "train on completions only" warning in the TRL docs: you must use a special collator or a prompt-completion dataset format to get completion-only loss. If you just pass text, the trainer uses plain LM loss on all non-padding tokens. (Hugging Face)
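
If you want to see how lopsided that is on your own rows, a small sketch like this (assuming your `tokenizer`, `formatting_func`, `tdc_prompts_json`, and dataset are in scope) counts prompt tokens versus total supervised tokens:

```python
# Compare the prompt length to the full formatted string for one row.
example = refactored_dataset_dict["train"][0]

full_text = formatting_func(example)
prompt_text = tdc_prompts_json[task_name].replace(input_type, example["input_text"])

n_full = len(tokenizer(full_text)["input_ids"])
n_prompt = len(tokenizer(prompt_text)["input_ids"])

print(f"{n_prompt} of {n_full} supervised tokens are prompt; "
      f"only ~{n_full - n_prompt} cover the answer ' (B)<eos>'")
```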
3. Why this makes Tx Gemma look like it "learned nothing"

Two effects combine:

3.1 The prompt dominates the loss

Your answer (A) or (B) is just a handful of tokens; your prompt is long (instructions + context + question + SMILES).

If we roughly say:

- Prompt ≈ 150 tokens
- Answer ≈ 3-5 tokens

then roughly 150 / 155 ≈ 97%, i.e. >95% of the supervised tokens are in the prompt.
The model's gradients are dominated by:
- Getting the instructions right
- Getting the context paragraph right
- Getting the SMILES and surrounding text right
The tiny suffix (A) / (B) contributes almost nothing to the total loss compared to the prompt.
3.2 The prompt is mostly constant; the label is tiny
You re-use the same TDC-style instructions and question each time, with only the SMILES and (A)/(B) changing. So the model can reduce loss substantially simply by:
- Memorizing the fixed instruction & context text.
- Modestly improving how it reproduces SMILES-like patterns.
It does not need to learn the mapping "SMILES → (A)/(B)" to lower the loss meaningfully, because the supervision on those few label tokens is dwarfed by supervision on the fixed prompt.
3.3 Tx Gemma has a known multiple-choice positional bias

There is a GitHub issue from the Gemma cookbook showing Tx Gemma on a hERG task where, for multiple-choice prompts like:

```
(A) is a hERG blocker (B) isn't ...
(A) isn't a hERG blocker (B) is ...
1: is a hERG blocker 0: isn't ...
0: isn't a hERG blocker 1: is ...
```

the model tends to always pick the first choice, regardless of content. (GitHub)
If your fine-tuning signal barely touches the answer tokens (A)/(B) because masking is wrong, that pre-existing positional bias remains. At inference, you observe:
- "The model always outputs the same answer."
- "It seems to have learned nothing."
This is exactly what you described.
4. What you need to change to get the mask you expected
You have two clean options depending on which TRL style you want.
Option 1: use a prompt-completion dataset + `completion_only_loss=True` (no `formatting_func`)

This is the newer / recommended SFTTrainer pattern. (Hugging Face)

- Preprocess your dataset into explicit `prompt` and `completion` fields:

```python
def preprocess(example):
    # prompt includes everything up to "Answer:"
    prompt = tdc_prompts_json[task_name].replace(
        input_type,
        example["input_text"],  # SMILES
    )
    prompt = prompt + "\nAnswer:"  # important: consistent marker
    completion = f" {example['output_text']}<eos>"  # " (A)" or " (B)"
    return {"prompt": prompt, "completion": completion}
```
- Map over your dataset:

```python
ds = refactored_dataset_dict.map(
    preprocess,
    remove_columns=["input_text", "output_text"],
)
```
- Train with SFTTrainer configured for prompt-completion data:

```python
from trl import SFTTrainer, SFTConfig

training_args = SFTConfig(
    output_dir="txgemma-herg",
    completion_only_loss=True,  # ensures loss on completion only
    max_seq_length=512,
    # other hyperparams...
)

trainer = SFTTrainer(
    model=model,
    train_dataset=ds["train"],
    eval_dataset=ds["val"],
    args=training_args,
    peft_config=lora_config,
    # no formatting_func needed now
)
```

In this mode, the prompt tokens get `labels = -100` and only the completion tokens are supervised. That gives you the "loss only on generated tokens (the answer)" behaviour you expected.
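
If you want to double-check the masking before committing to a long run, a quick sketch like this (assuming the `trainer` above) pulls one batch and counts the supervised positions:

```python
# Inspect one batch produced by the trainer's collator.
batch = next(iter(trainer.get_train_dataloader()))
labels = batch["labels"][0]

n_supervised = (labels != -100).sum().item()
print(n_supervised, "supervised tokens")  # should be only the few tokens of " (A)"/" (B)<eos>"
```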
Option 2: keep `formatting_func`, add `DataCollatorForCompletionOnlyLM`

If you really want to keep a single formatted text string (your current style), then you must explicitly instruct the collator where the answer starts, using `DataCollatorForCompletionOnlyLM`. (Hugging Face)

- Change your `formatting_func` slightly so it includes a stable answer prefix:

```python
def formatting_func(example):
    prompt = tdc_prompts_json[task_name].replace(input_type, example["input_text"])
    # ensure this exact pattern is used:
    text = f"{prompt}\nAnswer: {example['output_text']}<eos>"
    return text
```
- Create the collator with a `response_template` that matches the prefix before the answer:

```python
from trl import DataCollatorForCompletionOnlyLM

response_template = "\nAnswer:"  # exact substring before the answer

collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    tokenizer=tokenizer,
)
```

TRL docs and multiple guides show this pattern for "train on completions only" with `formatting_func`. (Hugging Face)
- Instantiate SFTTrainer with `packing=False` (required for this collator) and pass the collator:

```python
trainer = SFTTrainer(
    model=model,
    train_dataset=refactored_dataset_dict["train"],
    eval_dataset=refactored_dataset_dict["val"],
    args=SFTConfig(
        output_dir="txgemma-herg",
        packing=False,  # important with DataCollatorForCompletionOnlyLM
        max_seq_length=512,
        # ...
    ),
    formatting_func=formatting_func,
    data_collator=collator,
    peft_config=lora_config,
)
```
Now the masking looks like:

- The collator searches the tokenized sequence for the first occurrence of the `response_template` string.
- All tokens before that point get `labels = -100` (ignored).
- All tokens after that point (including (A)/(B) and `<eos>`) retain their token IDs as labels. (Gist)

That is exactly "compute loss only on the generated output," while still using `formatting_func` to build the text.
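
As a quick sanity check (a sketch, assuming the `tokenizer`, `collator`, and `formatting_func` defined above), you can run the collator on one formatted example and inspect the labels directly:

```python
# Apply the completion-only collator to a single tokenized example.
text = formatting_func(refactored_dataset_dict["train"][0])
batch = collator([tokenizer(text)])

labels = batch["labels"][0]
print(labels)  # -100 everywhere before "\nAnswer:", real token IDs for " (B)<eos>"
```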
5. Short mental model answer to your question

"How are masks set up when a formatting_func is used, in my case?"

With only `formatting_func` and no special collator or prompt-completion dataset, `labels` are simply a copy of `input_ids`, with `-100` only on padding.

So:

- Prompt + SMILES + "Answer: (B)" are all supervised.
- There is no built-in logic that "knows" where the answer starts.
- As a result, the loss is dominated by the long, mostly constant prompt, and Tx Gemma's pre-existing bias (e.g., "always pick the first option") remains, giving you almost the same output every time.

To get the loss mask you wanted (only on (A) / (B)), you must either:

- Switch to a `prompt` / `completion` dataset with `completion_only_loss=True`, or
- Keep `formatting_func` but add `DataCollatorForCompletionOnlyLM` with a proper `response_template` and `packing=False`.

Once you do that, the masks will be:

- Prompt tokens: `labels = -100` (ignored)
- Answer tokens: real label IDs → loss is computed only there.