Hi all!
I’m having an extremely difficult time finetuning Llama-2-7B-Chat and Llama-3.1-8B-Instruct models on the GSM8K dataset. I have spent a week optimizing hyperparameters but never seem to obtain satisfactory results when evaluating on GSM8K using lm-eval-harness.
I have trained LLMs before, so this is really frustrating at this point. I am wondering whether there is an error in my finetuning script and/or in the way I evaluate and process the results.
What I want to do:
Specifically, I want to finetune using PEFT/LoRA without the overhead of unsloth and other finetuning frameworks. The goal is to investigate a research question targeting Chat/Instruct models.
How I do it:
Here is my finetuning script for Llama-2-7B-Chat-HF:
import functools
from typing import Dict, Any
import yaml
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    set_seed
)
from trl import SFTTrainer, SFTConfig
def chat_format(example: Dict[str, Any], tokenizer) -> Dict[str, str]:
"""Format the example to include the question and answer in the text field."""
prompt = example['question']
answer = example['answer']
chat = [
{"role": "system", "content": "You are a helpful assistant."},
{'role': 'user', 'content': prompt},
{'role': 'assistant', 'content': answer}
]
text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt = False)
return {'text': text}
def finetune() -> None:
"""Fine-tune the model based on the config."""
# Seed
seed = 42
set_seed(seed)
# Model and tokenizer
model_name = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Handle pad token
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "<pad>"})
tokenizer.pad_token = tokenizer.eos_token
# Load and preprocess dataset
dataset = load_dataset("openai/gsm8k", "main")
chat_dataset = dataset.map(functools.partial(chat_format, tokenizer=tokenizer), remove_columns=dataset['train'].column_names)
    # LoRA
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=16,
        lora_dropout=0.01,
        bias="none"
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    # Huggingface (set hf_output_dir to None to skip pushing to the Hub)
    hf_output_dir = "./results/llama-2-7b-chat-hf"
    push_to_hub = hf_output_dir is not None
    # SFTConfig
    training_args = SFTConfig(
        output_dir=hf_output_dir if push_to_hub else "./results",
        max_seq_length=1024,
        dataset_text_field='text',
        auto_find_batch_size=False,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=16,
        num_train_epochs=3,
        logging_steps=1,
        save_steps=100,
        eval_strategy="steps",
        eval_steps=50,
        bf16=True,
        lr_scheduler_type="cosine",
        learning_rate=5e-5,
        warmup_ratio=0.1,
        weight_decay=0.0,
        push_to_hub=push_to_hub,
        load_best_model_at_end=False,
        ddp_find_unused_parameters=False
    )
    # SFTTrainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=chat_dataset['train'],
        eval_dataset=chat_dataset['test'],
        tokenizer=tokenizer,
        args=training_args
    )
    trainer.train()
    # Evaluation
    print(trainer.evaluate())
    # Save locally
    save_path = "./results/llama-2-7b-chat-hf"
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)
if __name__ == "__main__":
    finetune()
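To double-check the templating, here is the quick sanity check I use to print the exact string the trainer sees for one GSM8K example (a minimal sketch, separate from the training script; it assumes access to the gated meta-llama repo):

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
example = load_dataset("openai/gsm8k", "main", split="train")[0]

chat = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": example["question"]},
    {"role": "assistant", "content": example["answer"]},
]
# Prints the fully rendered training text, with Llama-2's special markers
# ([INST], <<SYS>>, </s>) inserted by the chat template
print(tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False))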
I have found the current hyperparams (LR, epochs, etc.) to be the sweet spot.
How I finetune:
I finetune using FSDP as follows:
accelerate launch --config_file "configs/fsdp.yaml" --num_processes=8 src/finetune_basic.py --config configs/llama-2-7B-chat-hf-gsm8k.yaml
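configs/fsdp.yaml is a standard accelerate FSDP config, roughly along these lines (a sketch rather than my verbatim file, since the exact keys vary across accelerate versions):

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_use_orig_params: true  # keeps the LoRA params as separate trainables under FSDP
machine_rank: 0
mixed_precision: bf16
num_machines: 1
num_processes: 8
use_cpu: false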
How I evaluate:
I evaluate using lm-eval-harness (0-shot, gsm8k task):
accelerate launch --num_processes=8 -m lm_eval --model hf --model_args pretrained=meta-llama/Llama-2-7b-chat-hf,peft=ketchup123/llama-2-7b-chat-hf --tasks gsm8k --num_fewshot 0 --batch_size 16 --apply_chat_template
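To rule out adapter-loading issues on the lm-eval side, the LoRA weights can also be merged into the base model first, so lm-eval is pointed at a plain checkpoint (a minimal sketch; the output path is arbitrary):

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
peft_model = PeftModel.from_pretrained(base, "ketchup123/llama-2-7b-chat-hf")

# Fold the LoRA deltas into the base weights and save a standalone HF checkpoint
merged = peft_model.merge_and_unload()
merged.save_pretrained("./results/merged")
AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf").save_pretrained("./results/merged")

The merged model can then be evaluated with --model_args pretrained=./results/merged.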
For Llama-3.1-8B-Instruct, I have used the gsm8k_cot_llama task from lm-eval-harness, as suggested.
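The command mirrors the Llama-2 one, roughly as follows (the adapter path is a placeholder for my Llama-3.1 run, and as far as I know this task is meant to be run with the chat template plus --fewshot_as_multiturn):

accelerate launch --num_processes=8 -m lm_eval --model hf --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct,peft=<my-llama-3.1-adapter> --tasks gsm8k_cot_llama --batch_size 16 --apply_chat_template --fewshot_as_multiturn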
My problem:
I do not encounter a specific error per se, but compared to the non-finetuned Llama-2-7B-Chat-HF, I only see a 2-3 point increase, from 22% to 25% exact match on GSM8K. I have seen other work achieve much more (e.g. a neuralmagic model reports 37%). I would expect better performance after finetuning, especially for the larger 7B and 8B models.
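For context on how I read these numbers: as far as I understand the harness, GSM8K exact match compares the number after "####" in the reference against the number extracted from the model's output. A rough sketch of that extraction (my paraphrase, not the harness's actual code):

import re
from typing import Optional

def extract_gsm8k_answer(text: str) -> Optional[str]:
    """Pull the final numeric answer from a GSM8K-style solution.

    GSM8K references end in '#### <number>'; scoring compares this number
    (commas stripped) between reference and model output.
    """
    match = re.search(r"####\s*(-?[0-9][0-9,.]*)", text)
    return match.group(1).replace(",", "") if match else None

print(extract_gsm8k_answer("So she has 18 eggs left. #### 18"))  # -> 18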
My questions:
- Is my chat templating ok?
- Are there any severely wrong settings?
- Is evaluation using lm-eval-harness the right way to go?
I have exhausted hyperparameter tuning (1-7 epochs, LRs from 1e-4 down to 1e-5, several batch sizes and gradient-accumulation settings, etc.). I am finetuning on 8x A100 (80 GB) GPUs.
I would appreciate any help and suggestions! At this point I am exhausted and cannot seem to find the error, if there is one.
Thank you!