Finetuning T5 problems

In that case, either the DataCollator or the input IDs are probably wrong. Here's a minimal, safe baseline you can diff against your own script.

# pip install -U transformers accelerate datasets huggingface_hub[hf_xet] trackio
# Minimal, safe baseline for token→token seq2seq (e.g., protein tokens).
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments, Seq2SeqTrainer,
    EarlyStoppingCallback,
)
# from transformers.integrations import TrackioCallback  # only needed if you pass the callback explicitly
import torch
import numpy as np
import random

def set_seed(s=13):
    random.seed(s); np.random.seed(s); torch.manual_seed(s)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(s)
set_seed()

# --- toy tokenized protein-like pairs (replace with your data) ---
def join(chars): return " ".join(list(chars))
pairs = [
    dict(src=join("dvavqavvvvyvyyvvvvqvvqcvllllvvvvvvvvvcy"),
         tgt="10 8 2 18 13 9 12 15 15 5 12 19 3 9 14 7 3 17 13 12"),
]*40
raw = Dataset.from_list(pairs).train_test_split(test_size=0.1, seed=0)

# --- model + tokenizer ---
model_name = "t5-small"  # swap with your pretrained protein T5
tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# --- preprocess: build labels via text_target (no -100 here) ---
src_max, tgt_max = 128, 64
def preprocess(ex):
    enc = tok(ex["src"], truncation=True, max_length=src_max)
    lab = tok(text_target=ex["tgt"], truncation=True, max_length=tgt_max)
    enc["labels"] = lab["input_ids"]  # collator will mask pad to -100
    return enc

train = raw["train"].map(preprocess, remove_columns=raw["train"].column_names)
val   = raw["test"].map(preprocess,  remove_columns=raw["test"].column_names)
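
# Optional roundtrip check (a sketch; adapt it to your real data): decode one
# preprocessed example and compare it by eye with the raw pair. If either string
# does not match what you fed in, the input IDs (not the collator) are the problem.
ex = train[0]
print("SRC decoded back:", tok.decode(ex["input_ids"], skip_special_tokens=True))
print("TGT decoded back:", tok.decode(ex["labels"], skip_special_tokens=True))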

# --- collator: masks label padding to -100 automatically ---
collator = DataCollatorForSeq2Seq(tokenizer=tok, model=model, pad_to_multiple_of=8)
# sanity: check that label pads become -100
batch = collator([train[i] for i in range(min(2, len(train)))])
assert (batch["labels"] == -100).any().item(), "Label pad masking failed"
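# Deeper collator check (sketch): un-mask the collated labels and confirm they still
# decode to the raw target, i.e. the collator should only have replaced padding with -100.
lab0 = batch["labels"][0].clone()
lab0[lab0 == -100] = tok.pad_token_id
print("Collated TGT decoded back:", tok.decode(lab0, skip_special_tokens=True))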

# --- training args: early stopping tracks eval_loss ---
args = Seq2SeqTrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    learning_rate=1e-4,
    lr_scheduler_type="linear",
    warmup_ratio=0.05,
    eval_strategy="steps",
    eval_steps=5,
    save_strategy="steps",
    save_steps=5,                      # keep equal to eval_steps
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    predict_with_generate=True,         # eval/predict call generate(), so predictions are token ids
    generation_max_length=tgt_max,
    group_by_length=True,
    fp16=False,                         # T5 is unstable in fp16; prefer bf16=True if your GPU supports it
    logging_strategy="steps",
    logging_steps=1,
    logging_first_step=True,
    report_to="none",
    #report_to="trackio", # https://huggingface.co/docs/trackio/index
)
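
# Optional: with predict_with_generate=True the eval predictions are generated token ids,
# so a sequence-level metric is easy to add. A minimal exact-match sketch (an extra,
# not required by the setup above): pass compute_metrics=exact_match to the trainer
# below, and optionally set metric_for_best_model="exact_match", greater_is_better=True.
def exact_match(eval_pred):
    preds, labels = eval_pred.predictions, eval_pred.label_ids
    preds = np.where(preds != -100, preds, tok.pad_token_id)     # generation padding
    labels = np.where(labels != -100, labels, tok.pad_token_id)  # label masking
    pred_txt = tok.batch_decode(preds, skip_special_tokens=True)
    label_txt = tok.batch_decode(labels, skip_special_tokens=True)
    return {"exact_match": float(np.mean([p.strip() == t.strip()
                                          for p, t in zip(pred_txt, label_txt)]))}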

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    data_collator=collator,
    processing_class=tok,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],  # add TrackioCallback() or set report_to="trackio" for tracking
)

# --- train + quick decode demo ---
train_result = trainer.train()
print(train_result.metrics)
print(trainer.evaluate())

pred = trainer.predict(val)
pred_ids = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
pred_ids = np.where(pred_ids != -100, pred_ids, tok.pad_token_id)  # cross-batch padding can be -100
decoded = tok.batch_decode(pred_ids[:3], skip_special_tokens=True)
print("DECODED SAMPLES:")
for s in decoded: print(s)
"""
Step	Training Loss	Validation Loss
5	4.073900	4.016737
10	3.765400	3.520519
15	3.668000	3.347217
20	3.344500	3.184257
25	3.069900	3.072643
30	3.276600	3.006171
35	3.094000	2.975766

{'eval_loss': 2.97576642036438, 'eval_runtime': 0.5958, 'eval_samples_per_second': 6.714, 'eval_steps_per_second': 6.714, 'epoch': 1.0}
DECODED SAMPLES:
d v a v q a v y v y y v q v q c v l l l l l l l l l l l l l l 
d v a v q a v y v y y v q v q c v l l l l l l l l l l l l l l 
d v a v q a v y v y y v q v q c v l l l l l l l l l l l l l l 
"""