Finetuning T5 problems

Hey, thank you very much again. I tried to rebuild this with my current setup, but I think my case is a bit different because I am working with a classification scenario. So I am using the PT5_classification_model and the DataCollatorForTokenClassification. One possible problem is the mapping I use from the amino-acid tokens to the class indices:

def create_token_to_aa_mapping(tokenizer):
    """
    Create mapping from tokenizer token IDs to amino acid class indices (0-19).
    """
    # Standard 20 amino acids
    AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
    AA_TO_CLASS = {aa: i for i, aa in enumerate(AMINO_ACIDS)}
    
    # Create mapping from token ID to class ID
    token_to_class = {}
    
    # Map single amino acid tokens
    for aa in AMINO_ACIDS:
        token_ids = tokenizer.encode(aa, add_special_tokens=False)
        if len(token_ids) == 1:
            token_to_class[token_ids[0]] = AA_TO_CLASS[aa]
    
    # Handle special tokens - map them to -100 (ignore in loss)
    special_token_ids = [
        tokenizer.pad_token_id,
        tokenizer.eos_token_id,
        tokenizer.unk_token_id,
    ]
    
    for token_id in special_token_ids:
        if token_id is not None:
            token_to_class[token_id] = -100
    
    return token_to_class
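
For reference, convert_tokens_to_classes (called in preprocess further down) is essentially just a lookup into this dict — a minimal sketch, assuming unmapped token IDs should also fall back to -100:

def convert_tokens_to_classes(token_ids, token_to_class):
    """
    Map tokenizer token IDs to amino acid class indices (0-19).
    Sketch assumption: IDs missing from the mapping fall back to -100
    so they are ignored by the loss.
    """
    return [token_to_class.get(tid, -100) for tid in token_ids]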

So my main function currently looks like this:

def main():

    args = get_input()

    set_seed()

    df = pd.read_csv("data.csv")

    if 'data_split' in df.columns:
        train_df = df[df['data_split'] == 'train']  
        val_df = df[df['data_split'] == 'valid']      

        train_df = train_df.iloc[:args.num_samples]
        val_df = val_df.iloc[:int(args.num_samples*0.1)]

        print(f'training samples:{len(train_df)}')
        print(f'validation samples:{len(val_df)}')
    
        train_pairs = create_pairs(train_df)
        val_pairs = create_pairs(val_df)
        
        # Create datasets
        train = Dataset.from_list(train_pairs)
        val = Dataset.from_list(val_pairs)
        
        print(f"Train samples: {len(train_pairs)}")
        print(f"Val samples: {len(val_pairs)}")

    # Examples
    print("\nFirst few training examples:")
    for i in range(min(3, len(train))):
        example = train[i]
        print(f"Example {i+1}:")
        print(f"  Source (3Di): {example['src'][:50]}...")
        print(f"  Target (AA):  {example['tgt'][:50]}...")
        if 'key' in example:
            print(f"  Key: {example['key']}")
        print()

    # Check sequence lengths
    src_lengths = [len(example['src'].split()) for example in train]
    tgt_lengths = [len(example['tgt'].split()) for example in train]

    print(f"Source (3Di) length stats: min={min(src_lengths)}, max={max(src_lengths)}, avg={sum(src_lengths)/len(src_lengths):.1f}")
    print(f"Target (AA) length stats: min={min(tgt_lengths)}, max={max(tgt_lengths)}, avg={sum(tgt_lengths)/len(tgt_lengths):.1f}")

    ####################### model + tokenizer #######################

    model, tokenizer = PT5_classification_model(
        num_labels=20, model_dir="model_snapshot"
    )
    
    # Create token to amino acid class mapping
    token_to_class = create_token_to_aa_mapping(tokenizer)

    print(f"\n=== Token Mapping ===")
    print(f"Tokenizer vocab size: {len(tokenizer)}")
    print(f"Token to class mappings created: {len(token_to_class)}")
    print(f"Sample mappings: {dict(list(token_to_class.items())[:10])}")

    # Set sequence lengths
    src_max = min(max(src_lengths) + 10, 512)
    tgt_max = min(max(tgt_lengths) + 10, 512)
    print(f"Using src_max={src_max}, tgt_max={tgt_max}")

    print(f"Using src_max={src_max}, tgt_max={tgt_max}")

    # Data collator
    data_collator = DataCollatorForTokenClassification(tokenizer)
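    # DataCollatorForTokenClassification pads input_ids/attention_mask with the
    # tokenizer's pad token and pads the 'labels' column with -100, so padded
    # positions are ignored by the loss.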

    def preprocess(ex):
        # Tokenize source 
        enc = tokenizer(ex["src"], truncation=True, max_length=src_max)
        
        # Tokenize target (AA sequence) 
        tgt_tokens = tokenizer(ex["tgt"], truncation=True, max_length=tgt_max, 
                               add_special_tokens=False)
        
        # Convert token IDs to amino acid class labels (0-19)
        class_labels = convert_tokens_to_classes(tgt_tokens["input_ids"], token_to_class)
        
        # class_labels are per-residue class indices (0-19); padding to the
        # batch length is handled later by the data collator
        enc["labels"] = class_labels
        
        return enc
    
    # Process datasets
    train_processed = train.map(preprocess, remove_columns=train.column_names)
    val_processed = val.map(preprocess, remove_columns=val.column_names)

    # Verification
    print("\n=== DEBUGGING TOKENIZATION ===")
    print(f"Tokenizer vocab size: {len(tokenizer)}")

    sample = train_processed[0]
    print(f"Sample input_ids: {sample['input_ids'][:10]}...")
    print(f"Sample labels (should be 0-19 or -100): {sample['labels'][:10]}...")

    # Check label values
    all_labels = []
    for item in train_processed:
        all_labels.extend([x for x in item['labels'] if x != -100])

    if all_labels:
        unique_labels = set(all_labels)
        print(f"Unique label values: {sorted(unique_labels)}")
        print(f"Max label: {max(all_labels)}")
        print(f"Min label: {min(all_labels)}")
        
        if max(all_labels) >= 20:
            print(f"❌ ERROR: Labels exceed 20 classes! Max: {max(all_labels)}")
            print("Need to debug the token mapping...")
            
            # Debug: show what's being tokenized
            test_aa = "ACDEFG"
            tokens = tokenizer(test_aa, add_special_tokens=False)
            print(f"\nTest AA sequence: {test_aa}")
            print(f"Token IDs: {tokens['input_ids']}")
            print(f"Tokens: {tokenizer.convert_ids_to_tokens(tokens['input_ids'])}")
            
            for tid in tokens['input_ids']:
                class_id = token_to_class.get(tid, -100)
                print(f"  Token {tid} -> Class {class_id}")
            
            return
        else:
            print(f"✓ All labels within [0, 19] range (20 amino acid classes)")

    # Sanity check: verify that label pads become -100
    batch = data_collator([train_processed[i] for i in range(min(2, len(train_processed)))])
    assert (batch["labels"] == -100).any().item(), "Label pad masking failed"
    print("✓ Label padding correctly masked to -100")

    # DEBUG: Check value ranges of input_ids and labels
    print(f"Max input_id: {max(sample['input_ids']) if sample['input_ids'] else 'None'}")
    print(f"Min input_id: {min(sample['input_ids']) if sample['input_ids'] else 'None'}")
    print(f"Max label: {max([x for x in sample['labels'] if x != -100]) if sample['labels'] else 'None'}")
    print(f"Min label: {min([x for x in sample['labels'] if x != -100]) if sample['labels'] else 'None'}")

    # Check if any labels are out of bounds
    # (labels are class indices 0-19, so the relevant bound is the number of
    #  classes, not the tokenizer vocab size; all_labels was collected above)
    num_labels = 20
    out_of_bounds = [x for x in all_labels if x >= num_labels or x < 0]
    if out_of_bounds:
        print(f"❌ FOUND {len(out_of_bounds)} OUT-OF-BOUNDS LABELS:")
        print(f"   Number of classes: {num_labels}")
        print(f"   Out of bounds values: {sorted(set(out_of_bounds))[:10]}...")
        
        
        # See what the tokenizer produces for sequences
        print(f"\nDEBUG: Raw sequences vs tokenized:")
        raw_3di = train[0]['src'][:20]  # First 20 chars
        raw_aa = train[0]['tgt'][:20]   # First 20 chars
        print(f"Raw 3Di: {raw_3di}")
        print(f"Raw AA:  {raw_aa}")
        
        tok_3di = tokenizer(raw_3di)['input_ids']
        tok_aa = tokenizer(text_target=raw_aa)['input_ids']
        print(f"Tokenized 3Di: {tok_3di}")
        print(f"Tokenized AA:  {tok_aa}")
        
        return  # Stop execution to debug
    else:
        print("✓ All labels are within vocabulary bounds")

    # Debug: compare raw pair counts with processed dataset sizes
    print(f"\n=== Dataset Debug ===")
    print(f"Raw train_pairs: {len(train_pairs)}")
    print(f"Raw val_pairs: {len(val_pairs)}")

    print(f"Processed train: {len(train_processed)}")
    print(f"Processed val: {len(val_processed)}")


    # Training arguments (following safe code pattern)
    training_args = Seq2SeqTrainingArguments(
        output_dir="finetuning",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=1,
        learning_rate=1e-4,
        lr_scheduler_type="linear",
        warmup_ratio=0.05,
        eval_strategy="steps",
        eval_steps=100,  # Adjust based on your dataset size
        save_strategy="steps",
        save_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        predict_with_generate=False,
        generation_max_length=tgt_max,
        group_by_length=True,
        fp16=False,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        report_to="none",
        remove_unused_columns=False, # added
        save_safetensors=False
    )
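    # Note: with predict_with_generate=False the Seq2SeqTrainer should not call
    # generate(), so generation_max_length presumably has no effect here.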

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_processed,
        eval_dataset=val_processed,
        data_collator=data_collator,
        processing_class=tokenizer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    # Train
    print("\nStarting training...")
    metrics = trainer.train()
    print(metrics)
    
    print("\nEvaluation results:")
    eval_results = trainer.evaluate()
    print(eval_results)

    # Test predictions
    print("\n=== Testing predictions ===")
    pred = trainer.predict(val_processed.select(range(min(3, len(val_processed)))))
    
    # For classification, predictions are logits
    if hasattr(pred, 'predictions'):
        pred_logits = pred.predictions
        pred_classes = np.argmax(pred_logits, axis=-1)
        
        print(f"Prediction shape: {pred_classes.shape}")
        print(f"First prediction (class indices): {pred_classes[0][:20]}")
        
        # Convert class indices back to amino acids
        CLASS_TO_AA = "ACDEFGHIKLMNPQRSTVWY"
        for i in range(min(3, len(pred_classes))):
            pred_aa = ''.join([CLASS_TO_AA[c] if 0 <= c < 20 else '?' 
                              for c in pred_classes[i] if c != -100])
            print(f"Sample {i+1} predicted AA: {pred_aa[:50]}...")

    print("\n✓ Training completed successfully!")