Hey everyone,
I’m not sure whether the code below, which I’m using for supervised fine-tuning on a dataset, is correct. The code runs, but the model starts to overfit (training logs are attached below the code). Could anyone please tell me whether there is an issue in the script and, if so, in which part?
script.py:
import os
import torch
import peft as p
import pandas as pd
import argparse as g
import datasets as d
import accelerate as a
import transformers as t
from tqdm import tqdm
from torch.utils.data import DataLoader
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if __name__ == '__main__':
    parser = g.ArgumentParser(description="Train a causal LM with HuggingFace Transformers.")
    parser.add_argument("--model", type=str, required=True, help="Model name or local path.")
    parser.add_argument("--precision", type=str, choices=['none', 'fp16', 'bf16'], default='none', help="Precision.")
    parser.add_argument("--quantization", type=str, choices=['none', '4bit', '8bit'], default='none', help="Quantization.")
    parser.add_argument("--lora", action='store_true', help="Should LoRA be used.")
    parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs.")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--max_length", type=int, default=512, help="Max sequence length.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training.")
    parser.add_argument("--patience", type=int, default=10, help="Patience for early stopping.")
    args = parser.parse_args()
    DATASET = 'zou-lab/MedCaseReasoning'
    MODEL = args.model
    MAX_LENGTH = args.max_length
    FP16 = args.precision == 'fp16'
    BF16 = args.precision == 'bf16'
    LOAD_IN_4BIT = args.quantization == '4bit'
    LOAD_IN_8BIT = args.quantization == '8bit'
    LORA = args.lora
    BATCH_SIZE = args.batch_size
    LEARNING_RATE = args.learning_rate
    EPOCHS = args.epochs
    PATIENCE = args.patience
    OUTPUT_DIR = f'{MODEL.replace("/", "-")}_{DATASET.replace("/", "-")}'
    df = d.load_dataset(DATASET)
    df = df.remove_columns(column_names=['title', 'text', 'Unnamed: 0', 'pmcid', 'journal', 'article_link', 'publication_date'])
    df['train'] = df['test']  # replace the train split with the test split
    tokenizer = t.AutoTokenizer.from_pretrained(MODEL, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    def format_dataset(x):
        # Tokenize only the prompt so overly long cases can be filtered out below.
        input_tokens = tokenizer([f"CASE:\n{case}\n\nDIAGNOSTIC REASONING:\n" for case in x['case_prompt']])
        return input_tokens
    df = df.map(format_dataset, batched=True)
    df = df.filter(lambda x: len(x['input_ids']) < MAX_LENGTH)  # keep examples whose tokenized prompt is shorter than MAX_LENGTH
    df = df.remove_columns(column_names=['input_ids', 'attention_mask'])
    def tokenize_data(x):
        inp = f"CASE:\n{x['case_prompt']}\n\nDIAGNOSTIC REASONING:\n"
        out = f"{x['diagnostic_reasoning']}.\n\nFINAL DIAGNOSIS:\n{x['final_diagnosis']}"
        full = f"{inp}{out}"
        full_tokens = tokenizer(full, truncation=True, max_length=MAX_LENGTH, padding='max_length')  # [<BOS>, $A, $B, <EOS>, ..., <EOS>]
        input_tokens = tokenizer(inp, truncation=True, max_length=MAX_LENGTH)['input_ids']  # [<BOS>, $A]
        output_tokens = tokenizer(text=out, truncation=True, max_length=MAX_LENGTH, padding='max_length')['input_ids']  # [<BOS>, $B, <EOS>, ..., <EOS>]
        # Mask the prompt portion with -100, then truncate/pad the labels to MAX_LENGTH.
        labels = [-100] * len(input_tokens) + output_tokens
        labels = labels[:MAX_LENGTH]
        labels = labels + [tokenizer.pad_token_id] * (MAX_LENGTH - len(labels))
        full_tokens['labels'] = labels
        return full_tokens
    df = df.map(tokenize_data, batched=False, remove_columns=['case_prompt', 'diagnostic_reasoning', 'final_diagnosis'])
    df.set_format(type='torch')
    print(df)
    print(df['train'][0])
    train_dataloader = DataLoader(
        df['train'],
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
    )
    val_dataloader = DataLoader(
        df['val'],
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
    )
    accelerator = a.Accelerator(mixed_precision=args.precision if args.precision in ['fp16', 'bf16'] else None, gradient_accumulation_steps=4)
    device = accelerator.device
    quantization_config = None
    if LOAD_IN_4BIT or LOAD_IN_8BIT:
        quantization_config = t.BitsAndBytesConfig(
            load_in_8bit=LOAD_IN_8BIT,
            load_in_4bit=LOAD_IN_4BIT,
            llm_int8_enable_fp32_cpu_offload=True,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.bfloat16 if BF16 else torch.float16,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_use_double_quant=True,
        )
    model = t.AutoModelForCausalLM.from_pretrained(
        MODEL,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16 if LOAD_IN_4BIT or LOAD_IN_8BIT else None,
        device_map={"": device.index if device.type == 'cuda' else 0} if quantization_config is not None else None,
    )
    if quantization_config is not None:
        model = p.prepare_model_for_kbit_training(model)
    print(model)
    if LORA:
        lora_config = p.LoraConfig(
            task_type='CAUSAL_LM',
            inference_mode=False,
            r=16,
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
            lora_alpha=8,
            lora_dropout=0.1,
            fan_in_fan_out=False,
            bias='none',
        )
        model = p.get_peft_model(model, lora_config)
        print(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0)
    lr_scheduler = t.get_scheduler(
        name="cosine",
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=EPOCHS * len(train_dataloader),
    )
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(model, optimizer, train_dataloader, val_dataloader, lr_scheduler)
    log_df = {'Epoch': [], 'Train_Loss': [], 'Val_Loss': [], 'Learning_Rate': []}
    print("============ TRAINING STARTED ============")
    best_val_loss = float('inf')
    epochs_no_improve = 0
    early_stop = False
    # Training loop; gradient accumulation (4 steps) is handled by accelerator.accumulate.
    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                train_loss = outputs.loss
                if torch.isnan(train_loss):
                    print("NaN detected in training loss")
                    print("Input IDs:", batch['input_ids'][0])
                    print("Labels:", batch['labels'][0])
                    assert False
                accelerator.backward(train_loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            total_train_loss = total_train_loss + accelerator.gather(train_loss.detach()).mean().item()
            global_step = epoch * len(train_dataloader) + step
        avg_train_loss = total_train_loss / len(train_dataloader)
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                outputs = model(**batch)
                val_loss = outputs.loss
                total_val_loss = total_val_loss + accelerator.gather(val_loss.detach()).mean().item()
        avg_val_loss = total_val_loss / len(val_dataloader)
        print(f'EPOCH {epoch + 1}: TRAIN LOSS ----> {avg_train_loss}, VALIDATION LOSS ----> {avg_val_loss}, LEARNING RATE ----> {lr_scheduler.get_last_lr()[0]}')
        log_df['Epoch'].append(epoch)
        log_df['Train_Loss'].append(avg_train_loss)
        log_df['Val_Loss'].append(avg_val_loss)
        log_df['Learning_Rate'].append(lr_scheduler.get_last_lr()[0])
        # Early stopping on validation loss; the best checkpoint is saved from the main process only.
        if PATIENCE > 0:
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                epochs_no_improve = 0
                print("Validation loss improved. Saving best model...")
                unwrapped_model = accelerator.unwrap_model(model)
                if accelerator.is_main_process:
                    if LORA:
                        unwrapped_model.save_pretrained(OUTPUT_DIR, save_adapter=True)
                    else:
                        unwrapped_model.save_pretrained(OUTPUT_DIR)
                    tokenizer.save_pretrained(OUTPUT_DIR)
            else:
                epochs_no_improve = epochs_no_improve + 1
                print(f"No improvement in validation loss for {epochs_no_improve} epoch(s).")
                if epochs_no_improve >= PATIENCE:
                    print("Early stopping triggered!")
                    early_stop = True
                    break
    print(log_df)
    log_df = pd.DataFrame(log_df)
    log_df.to_csv(f'{OUTPUT_DIR}/logs.csv', index=False)
    print("============ TRAINING ENDED ============")
logs.csv:
Epoch,Train_Loss,Val_Loss,Learning_Rate
0,1.6755454306395292,1.2330640574633065,9.996651457418269e-05
1,1.204512412677943,1.197177179789139,9.985584850050682e-05
2,1.1752366725108967,1.1813179224224415,9.966798371124085e-05
3,1.1584693791028282,1.1724048307386494,9.940321110345845e-05
4,1.1442841144155436,1.1653084876173634,9.9061940661261e-05
5,1.132354564355969,1.1615709036083546,9.864470082094231e-05
6,1.1195868387451122,1.1601799867920957,9.815213765273755e-05
7,1.108680416096664,1.1578090019145255,9.758501386042321e-05
8,1.0922332310954885,1.1589073164988373,9.694420760031725e-05
9,1.079660495199582,1.1611757339057276,9.623071112150801e-05
10,1.0658968001720042,1.1645658208151994,9.544562922941746e-05
11,1.0500226521383773,1.166986705893177,9.459017757507808e-05
12,1.0339320170585593,1.1740479408684423,9.366568077277185e-05
13,1.0168083200736113,1.1799179931818429,9.267357034894651e-05
14,0.9996888250228496,1.1892044332067846,9.161538252558502e-05
15,0.9836784733022079,1.1998685891345395,9.049275584146009e-05
16,0.9637572932722658,1.2207969629158408,8.93074286149578e-05
17,0.9444253675050154,1.2331248673342041,8.806123625239845e-05
I run the code on a multi-GPU setup with: accelerate launch script.py --model <decoder-only LLM> --batch_size <batch size that fits> --lora
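After training I plan to reload the saved checkpoint roughly like this (a sketch with placeholder names, not something from my actual run), mentioning it here in case the issue is with how I save or reload the adapter rather than with the training itself:

# Sketch of how I intend to reload the best checkpoint after training.
# "meta-llama/Llama-2-7b-hf" is only a placeholder for whatever I pass via --model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "meta-llama/Llama-2-7b-hf"                            # placeholder
OUTPUT_DIR = "meta-llama-Llama-2-7b-hf_zou-lab-MedCaseReasoning"   # per script.py's naming

tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)
base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.bfloat16)
# With --lora, script.py saves only the adapter weights, so they are re-attached
# to the freshly loaded base model here.
model = PeftModel.from_pretrained(base, OUTPUT_DIR)
model.eval()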
Please feel free to comment and point out my mistakes. Thanks in advance.