Hi everyone,
I’m fine-tuning GPT-Neo on the GSM8K dataset and evaluating it with lm-eval-harness. After training, the model generates reasonable-looking reasoning steps but still gets the arithmetic wrong, with final-answer accuracy hovering around 2–3% on the test set.
Here’s an example where the reasoning is there, but the final answer is incorrect:
Example 1
Prompt:
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers’ market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers’ market?
Answer: Let’s think step by step.
Fine-tuned model output:
1. Janet eats three eggs for breakfast every day.
2. She bakes muffins for her friends every day with four.
3. She sells the remainder at the farmers' market daily for $2 per fresh duck egg.
4. She makes $16 per day at the farmers' market.
So the final answer is:
4
Base (pretrained) GPT-Neo output:
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers’ market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers’ market?
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers’ market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers’ market?
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers’ market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers
The base model simply repeats the question over and over without ever producing a coherent answer.
The fine-tuned model is clearly better than the base, but still fails the actual math.
Training Details
training_args = TrainingArguments(
    output_dir="/kaggle/working/fine_tuned_gptneo_gsm8k",
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    gradient_checkpointing=False,
    fp16=True,
    learning_rate=5e-5,
    warmup_steps=200,
    weight_decay=0.01,
    max_grad_norm=1.0,
    num_train_epochs=5,
    logging_dir="./logs",
    logging_steps=5,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=data_collator,
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
Student–Teacher Approach
To improve reasoning and correctness, I’m experimenting with a student–teacher setup. A larger LLM (Mistral 7B in this case) generates chain-of-thought (CoT) explanations, which I hope to use as improved labels for re-training.
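The generation pass is roughly the following sketch. The specific Mistral checkpoint name, the prompt framing, and greedy decoding are placeholder choices on my part, not settled details:

```python
def build_prompt(question: str) -> str:
    # Same "Let's think step by step" framing as the fine-tuning prompts
    return f"Question: {question}\nAnswer: Let's think step by step."

def generate_teacher_cots(split: str = "train", max_new_tokens: int = 256):
    # Heavy imports kept inside the function so the helper above stays cheap
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from datasets import load_dataset

    teacher_name = "mistralai/Mistral-7B-Instruct-v0.2"  # placeholder checkpoint
    tok = AutoTokenizer.from_pretrained(teacher_name)
    model = AutoModelForCausalLM.from_pretrained(teacher_name, device_map="auto")

    for ex in load_dataset("gsm8k", "main", split=split):
        inputs = tok(build_prompt(ex["question"]), return_tensors="pt").to(model.device)
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # Drop the prompt tokens so only the generated chain of thought remains
        cot = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        yield ex["question"], ex["answer"], cot
```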
Here’s an example of the teacher model’s output:
Example 2 (Teacher-generated CoT)
Prompt:
Question: Every Sunday John is allowed to watch cartoons. However, for every 10 minutes of cartoons, he has to do 8 minutes of chores. If he watches cartoons for 2 hours, how many minutes of chores does he have to do?
Answer: Let’s think step by step.
Teacher output (raw):
First, we need to find out how many minutes of cartoons John watches in 2 hours. We know that 1 hour has 60 minutes, so 2 hours have 120 minutes. Now, we know that for every 10 minutes of cartoons, John has to do 8 minutes of chores. So, we can set up a proportion: 10 minutes of cartoons : 8 minutes of chores = 120 minutes of cartoons : x minutes of chores. Solving for x, we get x = (120 minutes * 8 minutes) / 10 minutes = 96 minutes. Therefore, John has to do 96 minutes of chores for watching cartoons for 2 hours.
I’d like to replace the original CoT labels in GSM8K with these more natural explanations, but only if the final answer matches the ground truth. However, parsing out the final answer from freeform CoTs is tricky.
Does anyone have advice on how to reliably extract and verify the final numeric answer? Regex is the obvious first idea, but teacher outputs usually contain several numbers (intermediate results, quantities restated from the question), so I’d like something more consistent.
Also, are there other methods you’d recommend for improving arithmetic reasoning beyond CoT distillation? And is there anything worth tweaking in the hyperparameters above?
Any feedback would be great.
Thanks for your time!