Verification of a script to train an LLM on supervised data

Hey everyone,
I’m unsure whether the code below, which I’m using for supervised fine-tuning on a dataset, is correct. The code runs, but the model starts to overfit (training logs are attached after the code). Can anyone help me figure out whether something is wrong and, if so, in which part of the script?
script.py:

import os
import torch
import peft as p
import pandas as pd
import argparse as g
import datasets as d
import accelerate as a
import transformers as t
from tqdm import tqdm
from torch.utils.data import DataLoader
os.environ["TOKENIZERS_PARALLELISM"] = "false"

if __name__ == '__main__':
    parser = g.ArgumentParser(description="Train a causal LM with HuggingFace Transformers.")
    parser.add_argument("--model", type=str, required=True, help="Model name or local path.")
    parser.add_argument("--precision", type=str, choices=['none', 'fp16', 'bf16'], default='none', help="Precision.")
    parser.add_argument("--quantization", type=str, choices=['none', '4bit', '8bit'], default='none', help="Quantization.")
    parser.add_argument("--lora", action='store_true', help="Should LoRA be used.")
    parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs.")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--max_length", type=int, default=512, help="Max sequence length.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training.")
    parser.add_argument("--patience", type=int, default=10, help="Patience for early stopping.")
    args = parser.parse_args()

    DATASET = 'zou-lab/MedCaseReasoning'
    MODEL = args.model
    MAX_LENGTH = args.max_length
    FP16 = True if args.precision == 'fp16' else False
    BF16 = True if args.precision == 'bf16' else False
    LOAD_IN_4BIT = True if args.quantization == '4bit' else False
    LOAD_IN_8BIT = True if args.quantization == '8bit' else False
    LORA = args.lora
    BATCH_SIZE = args.batch_size
    LEARNING_RATE = args.learning_rate
    EPOCHS = args.epochs
    PATIENCE = args.patience
    OUTPUT_DIR = f'{MODEL.replace("/", "-")}_{DATASET.replace("/", "-")}'

    df = d.load_dataset(DATASET)
    df = df.remove_columns(column_names=['title', 'text', 'Unnamed: 0', 'pmcid', 'journal', 'article_link', 'publication_date'])
    df['train'] = df['test']

    tokenizer = t.AutoTokenizer.from_pretrained(MODEL, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def format_dataset(x):
        input_tokens = tokenizer([f"CASE:\n{case}\n\nDIAGNOSTIC REASONING:\n" for case in x['case_prompt']])
        return input_tokens
    df = df.map(format_dataset, batched=True)
    df = df.filter(lambda x: len(x['input_ids']) < MAX_LENGTH)
    df = df.remove_columns(column_names=['input_ids', 'attention_mask'])

    def tokenize_data(x):
        inp = f"CASE:\n{x['case_prompt']}\n\nDIAGNOSTIC REASONING:\n"
        out = f"{x['diagnostic_reasoning']}.\n\nFINAL DIAGNOSIS:\n{x['final_diagnosis']}"
        full = f"{inp}{out}"
        full_tokens = tokenizer(full, truncation=True, max_length=MAX_LENGTH, padding='max_length') # [<BOS>, $A, $B, <EOS>, .... , <EOS>]
        input_tokens = tokenizer(inp, truncation=True, max_length=MAX_LENGTH)['input_ids'] # [<BOS>, $A]
        output_tokens = tokenizer(text=out, truncation=True, max_length=MAX_LENGTH, padding='max_length')['input_ids'] # [BOS, $B, <EOS>, .... , <EOS>]
        labels = [-100] * len(input_tokens) + output_tokens
        labels = labels[:MAX_LENGTH]
        labels = labels + [tokenizer.pad_token_id] * (MAX_LENGTH - len(labels))
        full_tokens['labels'] = labels
        return full_tokens
    df = df.map(tokenize_data, batched=False, remove_columns=['case_prompt', 'diagnostic_reasoning', 'final_diagnosis'])
    df.set_format(type='torch')
    print(df)
    print(df['train'][0])

    train_dataloader = DataLoader(
        df['train'],
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
    )
    val_dataloader = DataLoader(
        df['val'],
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
    )

    accelerator = a.Accelerator(mixed_precision=args.precision if args.precision in ['fp16', 'bf16'] else None, gradient_accumulation_steps=4)
    device = accelerator.device

    quantization_config = None
    if LOAD_IN_4BIT or LOAD_IN_8BIT:
        quantization_config = t.BitsAndBytesConfig(
            load_in_8bit=LOAD_IN_8BIT,
            load_in_4bit=LOAD_IN_4BIT,
            llm_int8_enable_fp32_cpu_offload=True,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.bfloat16 if BF16 else torch.float16,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_use_double_quant=True,
        )

    model = t.AutoModelForCausalLM.from_pretrained(
        MODEL,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16 if LOAD_IN_4BIT or LOAD_IN_8BIT else None,
        device_map={"": device.index if device.type == 'cuda' else 0} if quantization_config is not None else None,
    )
    if quantization_config is not None:
        model = p.prepare_model_for_kbit_training(model)
    print(model)

    if LORA:
        lora_config = p.LoraConfig(
            task_type='CAUSAL_LM',
            inference_mode=False,
            r=16,
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
            lora_alpha=8,
            lora_dropout=0.1,
            fan_in_fan_out=False,
            bias='none',
        )
        model = p.get_peft_model(model, lora_config)
    print(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0)

    lr_scheduler = t.get_scheduler(
        name="cosine",
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=EPOCHS * len(train_dataloader),
    )

    model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(model, optimizer, train_dataloader, val_dataloader, lr_scheduler)

    log_df = {'Epoch': [], 'Train_Loss': [], 'Val_Loss': [], 'Learning_Rate': []}

    print("============ TRAINING STARTED ============")
    best_val_loss = float('inf')
    epochs_no_improve = 0
    early_stop = False
    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                train_loss = outputs.loss
                if torch.isnan(train_loss):
                    print("NaN detected in training loss")
                    print("Input IDs:", batch['input_ids'][0])
                    print("Labels:", batch['labels'][0])
                    assert False
                accelerator.backward(train_loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                total_train_loss = total_train_loss + accelerator.gather(train_loss.detach()).mean().item()
            global_step = epoch * len(train_dataloader) + step
        avg_train_loss = total_train_loss / len(train_dataloader)
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                outputs = model(**batch)
                val_loss = outputs.loss
                total_val_loss = total_val_loss + accelerator.gather(val_loss.detach()).mean().item()
            avg_val_loss = total_val_loss / len(val_dataloader)
        print(f'EPOCH {epoch + 1}: TRAIN LOSS ----> {avg_train_loss}, VALIDATION LOSS ----> {avg_val_loss}, LEARNING RATE ----> {lr_scheduler.get_last_lr()[0]}')
        log_df['Epoch'].append(epoch)
        log_df['Train_Loss'].append(avg_train_loss)
        log_df['Val_Loss'].append(avg_val_loss)
        log_df['Learning_Rate'].append(lr_scheduler.get_last_lr()[0])
        if PATIENCE > 0:
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                epochs_no_improve = 0
                print("Validation loss improved. Saving best model...")
                unwrapped_model = accelerator.unwrap_model(model)
                if accelerator.is_main_process:
                    if LORA:
                        unwrapped_model.save_pretrained(OUTPUT_DIR, save_adapter=True)
                    else:
                        unwrapped_model.save_pretrained(OUTPUT_DIR)
                    tokenizer.save_pretrained(OUTPUT_DIR)
            else:
                epochs_no_improve = epochs_no_improve + 1
                print(f"No improvement in validation loss for {epochs_no_improve} epoch(s).")
            if epochs_no_improve >= PATIENCE:
                print("Early stopping triggered!")
                early_stop = True
                break
        print(log_df)
    log_df = pd.DataFrame(log_df)
    log_df.to_csv(f'{OUTPUT_DIR}/logs.csv', index=False)
    print("============ TRAINING ENDED ============")

logs.csv:

Epoch,Train_Loss,Val_Loss,Learning_Rate
0,1.6755454306395292,1.2330640574633065,9.996651457418269e-05
1,1.204512412677943,1.197177179789139,9.985584850050682e-05
2,1.1752366725108967,1.1813179224224415,9.966798371124085e-05
3,1.1584693791028282,1.1724048307386494,9.940321110345845e-05
4,1.1442841144155436,1.1653084876173634,9.9061940661261e-05
5,1.132354564355969,1.1615709036083546,9.864470082094231e-05
6,1.1195868387451122,1.1601799867920957,9.815213765273755e-05
7,1.108680416096664,1.1578090019145255,9.758501386042321e-05
8,1.0922332310954885,1.1589073164988373,9.694420760031725e-05
9,1.079660495199582,1.1611757339057276,9.623071112150801e-05
10,1.0658968001720042,1.1645658208151994,9.544562922941746e-05
11,1.0500226521383773,1.166986705893177,9.459017757507808e-05
12,1.0339320170585593,1.1740479408684423,9.366568077277185e-05
13,1.0168083200736113,1.1799179931818429,9.267357034894651e-05
14,0.9996888250228496,1.1892044332067846,9.161538252558502e-05
15,0.9836784733022079,1.1998685891345395,9.049275584146009e-05
16,0.9637572932722658,1.2207969629158408,8.93074286149578e-05
17,0.9444253675050154,1.2331248673342041,8.806123625239845e-05

I run the code on a multi-GPU setup with: accelerate launch script.py --model <decoder-only LLM> --batch_size <batch size that fits> --lora

Please feel free to comment and point out my mistakes. Thanks in advance.

1 Like
df['train'] = df['test']

This makes the dataset smaller, so it seems prone to overfitting, but it’s probably just set up this way for testing purposes.

I don’t see any obvious issues with the data preprocessing code. The one thing I’m unsure about is whether padding tokens should be included in the labels, and whether that differs between Seq2Seq and CausalLM models… I’m not certain about that part.
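
For reference, what I usually see for causal LM fine-tuning is that padded label positions are set to -100 so the loss ignores them, rather than left as pad_token_id. A rough sketch of that idea, reusing tokenizer and MAX_LENGTH from your script (untested, and it assumes right padding):

def tokenize_data(x):
    inp = f"CASE:\n{x['case_prompt']}\n\nDIAGNOSTIC REASONING:\n"
    out = f"{x['diagnostic_reasoning']}.\n\nFINAL DIAGNOSIS:\n{x['final_diagnosis']}"
    full_tokens = tokenizer(inp + out, truncation=True, max_length=MAX_LENGTH, padding='max_length')
    # Measure the prompt length on the un-padded prompt alone.
    prompt_len = len(tokenizer(inp, truncation=True, max_length=MAX_LENGTH)['input_ids'])
    labels = list(full_tokens['input_ids'])
    # No loss on the prompt tokens (assumes right padding, so the prompt sits at the start).
    labels[:prompt_len] = [-100] * prompt_len
    # No loss on padding positions either.
    labels = [tok if mask == 1 else -100
              for tok, mask in zip(labels, full_tokens['attention_mask'])]
    full_tokens['labels'] = labels
    return full_tokens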

The LR is slightly high (though I don’t think it’s a problem), and training with weight_decay set to 0.0 might also be related to the overfitting.
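
Concretely, something like this is what I mean (guessed values, not tuned):

# Hypothetical tweak, not from the original script: slightly lower LR plus non-zero weight decay.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)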

Thanks for the insight, I appreciate you taking a look. As you mentioned, I set train = test just for testing purposes. Apart from that, are there any issues with the quantization, LoRA, the accelerator, or the training loop?

1 Like

I don’t think there are any major issues with the quantization, and the LoRA settings also appear to be correct. If something were wrong there, it would usually just raise an error…:sweat_smile:

Regarding tokenization, I later noticed that the handling of special tokens and masks may not be complete.

And since the training loop is written manually, debugging is difficult… For example, I can help fix cases where it doesn’t run due to syntax errors, but it’s hard to spot logical oversights.

Given that manually performing all of this correctly is quite tedious (there are numerous model-specific conventions and details that aren’t typically obvious), I think it would be simpler to implement using an existing trainer.
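
For example, roughly like this with the transformers Trainer (an untested sketch: it reuses model, tokenizer, df, and the constants from your script, and some argument names shift between transformers versions):

import transformers as t

training_args = t.TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    fp16=FP16,
    bf16=BF16,
    eval_strategy="epoch",            # "evaluation_strategy" in older transformers versions
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)
trainer = t.Trainer(
    model=model,
    args=training_args,
    train_dataset=df["train"],
    eval_dataset=df["val"],
    data_collator=t.default_data_collator,  # the dataset already carries input_ids / attention_mask / labels
    callbacks=[t.EarlyStoppingCallback(early_stopping_patience=PATIENCE)],
)
trainer.train()
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)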

Trainer does sound like a great tool, but I’ve noticed it’s slower than a manual loop because of all the extra machinery built on top of it, and writing the loop myself gives me a sense of control over the training. As far as errors are concerned, the script runs fine :face_exhaling:. Also, Hugging Face doesn’t provide a tutorial for supervised fine-tuning with the Trainer from transformers (only a trl-based SFTTrainer tutorial), while my script attempts supervised fine-tuning with transformers alone. On top of that, the Hugging Face tutorials aren’t always up to date with the stable versions of the modules they use, which leads to a lot of errors and debugging sessions on my end. That’s why I prefer to do things like tokenization, data loading, and the training loop manually. Can you help me with the tokenization part and figure out my mistake?

1 Like

I see, that makes sense for research and optimization purposes.:grinning_face:

the tokenization part

One point is special token masking. I’m not sure whether the special tokens mask needs to be handled at all when you’re training a plain causal LM rather than an embedding model; I just wondered.
https://stackoverflow.com/questions/61707371/about-get-special-tokens-mask-in-huggingface-transformers
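
As a concrete thing to check (just a hunch, not something I’ve verified against your run): many tokenizers prepend a BOS when you call them on the completion by itself, in which case the labels built as [-100] * len(input_tokens) + output_tokens end up shifted relative to full_tokens['input_ids']. For example:

# Hypothetical check; out stands in for one example's completion string.
out = "some diagnostic reasoning.\n\nFINAL DIAGNOSIS:\nsome diagnosis"
with_special = tokenizer(out)['input_ids']
without_special = tokenizer(out, add_special_tokens=False)['input_ids']
# If a BOS appears in the first list but not the second, the concatenated labels
# are misaligned with full_tokens['input_ids'] by at least one position.
print(with_special[:3], without_special[:3])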

The other point is chat templates. If you use an instruct model, you would want to feed it tokens produced from strings that have been formatted with each LLM’s own chat template. I guess this probably isn’t an issue with a base model?
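
If you do try an instruct model, the usual pattern is roughly this (a sketch only; case_prompt, diagnostic_reasoning, and final_diagnosis stand in for one dataset row, and it assumes the tokenizer ships a chat template):

messages = [
    {"role": "user", "content": f"CASE:\n{case_prompt}\n\nDIAGNOSTIC REASONING:"},
]
prompt_text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,  # appends the assistant header so the completion starts there
)
full_text = prompt_text + f"{diagnostic_reasoning}.\n\nFINAL DIAGNOSIS:\n{final_diagnosis}"
# The templated string usually already contains BOS/other special tokens.
full_tokens = tokenizer(full_text, truncation=True, max_length=MAX_LENGTH, add_special_tokens=False)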