Bad Performance Finetuning Llama Chat and Instruct Models on GSM8K

Hi all!

I’m having an extremely difficult time finetuning Llama-2-7B-Chat and Llama-3.1-8B-Instruct models on the GSM8K dataset. I have spent a week optimizing hyperparameters but never seem to obtain satisfactory results when evaluating on GSM8K using lm-eval-harness.

I have trained LLMs before, so this is really frustrating at this point. I am wondering whether there is an error in my finetuning script and/or in the way I evaluate and process the results.

What I want to do:
Specifically, I want to finetune using PEFT/LoRA without the overhead of unsloth and other finetuning frameworks. The goal is to investigate a research question targeting Chat/Instruct models.

How I do it:
Here is my finetuning script for Llama-2-7B-Chat-HF:

import functools
from typing import Dict, Any

import yaml
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    set_seed
)
from trl import SFTTrainer, SFTConfig

def chat_format(example: Dict[str, Any], tokenizer) -> Dict[str, str]:
    """Format the example to include the question and answer in the text field."""
    prompt = example['question']
    answer = example['answer']
    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": answer},
    ]
    text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
    return {'text': text}

def finetune() -> None:
    """Fine-tune the model based on the config."""

    # Seed
    seed = 42
    set_seed(seed)

    # Model and tokenizer
    model_name = "meta-llama/Llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Handle pad token: Llama has no pad token, so reuse EOS for padding.
    # (Adding a new special token such as "<pad>" would also require resizing
    # the model embeddings, so it is not done here.)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load and preprocess dataset
    dataset = load_dataset("openai/gsm8k", "main")
    chat_dataset = dataset.map(
        functools.partial(chat_format, tokenizer=tokenizer),
        remove_columns=dataset["train"].column_names
    )

    # LoRA
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=16,
        lora_dropout=0.01,
        bias="none"
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Hugging Face output directory; also push the adapter to the Hub
    hf_output_dir = "./results/llama-2-7b-chat-hf"
    push_to_hub = True

    # SFTConfig
    training_args = SFTConfig(
        output_dir=hf_output_dir,
        max_seq_length=1024,
        dataset_text_field='text',
        auto_find_batch_size=False,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=16,
        num_train_epochs=3,
        logging_steps=1,
        save_steps=100,
        eval_strategy="steps",
        eval_steps=50,
        bf16=True,
        lr_scheduler_type="cosine",
        learning_rate=5e-5,
        warmup_ratio=0.1,
        weight_decay=0.0,
        push_to_hub=push_to_hub,
        load_best_model_at_end=False,
        ddp_find_unused_parameters=False
    )

    # SFTTrainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=chat_dataset['train'],
        eval_dataset=chat_dataset['test'],
        tokenizer=tokenizer,
        args=training_args
    )

    trainer.train()

    # Evaluation
    print(trainer.evaluate())

    # Save locally
    save_path = "./results/llama-2-7b-chat-hf"
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)


if __name__ == "__main__":
    finetune()
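
For reference, the templating (question 1 below) can be sanity-checked by dumping a single formatted example and making sure the assistant turn contains the full reasoning and ends with the GSM8K "#### <number>" marker. A minimal standalone sketch (re-loading the tokenizer and dataset just for this check):

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
example = load_dataset("openai/gsm8k", "main", split="train[:1]")[0]

chat = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": example["question"]},
    {"role": "assistant", "content": example["answer"]},
]

# Print the exact string the trainer will see, special tokens and all
print(tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False))

# GSM8K gold answers end with "#### <number>", which the eval scores on
assert "####" in example["answer"]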

I have found the current hyperparameters (learning rate, epochs, etc.) to be the sweet spot.

How I launch training:
I finetune using FSDP as follows:

accelerate launch --config_file "configs/fsdp.yaml" --num_processes=8 src/finetune_basic.py --config configs/llama-2-7B-chat-hf-gsm8k.yaml
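
After training finishes, the saved adapter can be reloaded and (optionally) merged into the base weights, so the checkpoint can be evaluated or shared as a standalone model. A minimal single-process sketch, assuming the adapter was saved unsharded to the local save path from the script above (with FSDP this may require fsdp_state_dict_type: FULL_STATE_DICT in the accelerate config):

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16
)
model = PeftModel.from_pretrained(base, "./results/llama-2-7b-chat-hf")
merged = model.merge_and_unload()  # fold the LoRA deltas into the base weights

merged.save_pretrained("./results/llama-2-7b-chat-hf-merged")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.save_pretrained("./results/llama-2-7b-chat-hf-merged")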

How I evaluate:
I evaluate using lm-eval-harness (0-shot, gsm8k task)

accelerate launch --num_processes=8 -m lm_eval --model hf --model_args pretrained=meta-llama/Llama-2-7b-chat-hf,peft=ketchup123/llama-2-7b-chat-hf --tasks gsm8k --num_fewshot 0 --batch_size 16 --apply_chat_template

For Llama-3.1-8B-Instruct, I have used the gsm8k_cot_llama task from lm-eval-harness, as suggested.
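
As far as I understand, the gsm8k task's strict-match metric extracts the number after "####" with a regex, so the finetuned model must keep emitting that marker or it scores zero regardless of the reasoning. A rough illustration (the pattern is an approximation, not lm-eval's exact filter):

import re

ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")  # approximation of the strict-match filter

def extract_answer(completion: str):
    match = ANSWER_RE.search(completion)
    return match.group(1).replace(",", "") if match else None

print(extract_answer("She earns 10 * 2 = 20 dollars.\n#### 20"))  # -> "20"
print(extract_answer("The answer is 20."))                        # -> None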

My problem:
I do not encounter a specific error per se, but compared to the non-finetuned Llama-2-7B-Chat-HF, I only get a 2-3 point improvement in exact match on GSM8K, from 22% to 25%. I have seen other works achieve much more (e.g. the neuralmagic model reaches 37%). I would expect better performance after finetuning, especially for 7B and 8B models.
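
To see where the points are lost, here is a quick smoke-test sketch for inspecting what the finetuned model actually generates on one test question, using the same chat template as in training (the adapter path is the local save path from the script above):

import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(model, "./results/llama-2-7b-chat-hf")

question = load_dataset("openai/gsm8k", "main", split="test[:1]")[0]["question"]
chat = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": question},
]
input_ids = tokenizer.apply_chat_template(
    chat, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# Greedy decoding; check whether the output ends with "#### <number>"
output = model.generate(input_ids, max_new_tokens=512, do_sample=False)
print(tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True))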

My questions:

  1. Is my chat templating ok?
  2. Are there any severely wrong settings?
  3. Is evaluation using lm-eval-harness the right way to go?

I have exhausted hyperparameter tuning (1-7 epochs, learning rates from 1e-4 to 1e-5, several batch sizes and gradient accumulation settings, etc.). I am finetuning on 8x A100 (80 GB) GPUs.

I would appreciate any help and suggestions! :slight_smile: At this point I am exhausted and cannot find the error, if there is one.

Thank you!
