Sudden Loss Drop and Poor Performance During Model Training
Issue Description
I'm encountering a perplexing issue while training a model with the script below. Within the first epoch, both the training and evaluation losses drop dramatically, yet the model performs poorly when tested. This suggests a potential bug in the code, the configuration, or the libraries used. I would greatly appreciate any insights or guidance from the community to help diagnose and resolve this issue.
Training Script
Below is the complete training script used:
import os
import torch
from datetime import datetime
from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)
from trl import SFTTrainer, SFTConfig
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Model and training parameters
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True
)
tokenizer.pad_token = "<|end_of_text|>"
tokenizer.padding_side = "right"
# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    use_safetensors=True,
    torch_dtype=torch.bfloat16,
    use_cache=False
)
# Enable gradient checkpointing to save memory
model.gradient_checkpointing_enable()
max_seq_len = 5000
num_epochs = 2
output_dir = "outputs"
# Load and process dataset
proc_ds = load_from_disk('updated_cleaned_sft_ds_350k')
class CustomSFTTrainer(SFTTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute training loss with emphasis on EOS token, and additionally compute token accuracies.
        """
        mode = "eval" if self.control.should_evaluate else "train"

        # Token counting logic
        if mode == "train":
            if "attention_mask" in inputs:
                num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
            elif "position_ids" in inputs:
                num_tokens_in_batch = (
                    self.accelerator.gather_for_metrics(torch.tensor(inputs["position_ids"].size(1))).sum().item()
                )
            else:
                raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.")
            self._total_train_tokens += num_tokens_in_batch
        self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

        # Process logits and labels
        if "labels" in inputs and not self.args.use_liger_kernel:
            # Get base outputs from model
            (base_loss, outputs) = super().compute_loss(model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch)
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = inputs["labels"][..., 1:].contiguous()
            eos_token_id = self.processing_class.eos_token_id

            # Log sample predictions and actual text
            if self.state.global_step % 100 == 0:
                sample_indices = list(range(min(3, shift_labels.size(0))))
                for idx in sample_indices:
                    predictions = shift_logits[idx].argmax(dim=-1)
                    actual_labels = shift_labels[idx]
                    inputs_label = inputs["input_ids"][idx]
                    mask = actual_labels != -100
                    pred_tokens = predictions.cpu().tolist()
                    actual_tokens = actual_labels.cpu().tolist()
                    inputs_label = inputs_label.cpu().tolist()
                    try:
                        pred_text = self.processing_class.decode(pred_tokens)
                        actual_text = self.processing_class.decode(actual_tokens)
                        with open("model_predictions.txt", "a") as f:
                            f.write(f"\n\nStep {self.state.global_step}, Sample {idx}:\n")
                            f.write(f"Predicted text: {pred_text}\n")
                            f.write(f"Actual text: {actual_text}\n")
                            f.write("-" * 50 + "\n")
                    except Exception as e:
                        logger.warning(f"Error decoding tokens: {e}")

            # Calculate loss differently based on mode
            if mode == "train":
                loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
                per_token_loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )
                eos_mask = (shift_labels.view(-1) == eos_token_id).float()
                eos_weight = 2.0
                weights = torch.ones_like(per_token_loss) + eos_mask * (eos_weight - 1.0)
                loss = (per_token_loss * weights).mean()
            else:
                loss = base_loss

            # Compute token accuracy
            predictions = shift_logits.argmax(dim=-1)
            mask = shift_labels != -100
            correct_predictions = (predictions == shift_labels) & mask
            total_tokens = mask.sum()
            correct_tokens = correct_predictions.sum()
            correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
            total_tokens = self.accelerator.gather_for_metrics(total_tokens)
            total_sum = total_tokens.sum()
            accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
            self._metrics[mode]["mean_token_accuracy"].append(accuracy)

            # Compute EOS-specific accuracy
            eos_mask = (shift_labels == eos_token_id)
            eos_correct = ((predictions == eos_token_id) & eos_mask).sum()
            eos_total = eos_mask.sum()
            if eos_total > 0:
                eos_accuracy = (eos_correct / eos_total).item()
                if f"{mode}_eos_accuracy" not in self._metrics:
                    self._metrics[mode]["eos_accuracy"] = []
                self._metrics[mode]["eos_accuracy"].append(eos_accuracy)
        else:
            (loss, outputs) = super().compute_loss(
                model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
            )

        return (loss, outputs) if return_outputs else loss
def format_deepseek_text_for_training(example):
    tokens = tokenizer(
        example['text'],
        max_length=max_seq_len,
        padding='max_length',
        add_special_tokens=False,
        truncation=True
    )
    tokens['labels'] = tokens['input_ids'].copy()
    return tokens
# Prepare dataset
proc_ds = proc_ds.shuffle().select(range(5000))
formatted_ds = proc_ds.map(format_deepseek_text_for_training, remove_columns=proc_ds.column_names)
formatted_ds = formatted_ds.train_test_split(test_size=0.01)
# Training arguments
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    weight_decay=0.01,
    warmup_steps=10,
    logging_steps=50,
    eval_steps=50,
    save_strategy="no",
    eval_strategy="steps",
    bf16=True,
    max_grad_norm=1.0,
    report_to="none",
    max_seq_length=max_seq_len,
    packing=False,
    optim="adamw_8bit",
    gradient_checkpointing=True,
    dataset_text_field="text"
)
# Create the SFT trainer
trainer = CustomSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_ds['train'],
    eval_dataset=formatted_ds['test'],
    processing_class=tokenizer,
    formatting_func=format_deepseek_text_for_training
)
# Train the model
trainer.train()
# Save the model
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
final_save_path = os.path.join(output_dir, "final-model_v1.15")
trainer.save_model(final_save_path)
tokenizer.save_pretrained(final_save_path)
print(f"Final model saved to {final_save_path}")
Configuration YAML
The following YAML configuration is used for FSDP:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_offload_optimizer_state: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Training Output
The training output shows a sudden drop in loss, which is unexpected:
/home/bello/tts_env/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:502: UserWarning: You passed a dataset that is already processed (contains an `input_ids` field) together with a formatting function. Therefore `formatting_func` will be ignored. Either remove the `formatting_func` or pass a dataset that is not already processed.
warnings.warn(
Truncating eval dataset: 100%|██████████| 50/50 [00:00<00:00, 4730.03 examples/s]
Truncating train dataset: 100%|██████████| 4950/4950 [00:01<00:00, 3712.92 examples/s]
/home/bello/tts_env/lib/python3.10/site-packages/accelerate/accelerator.py:1731: UserWarning: Upcasted low precision parameters in LlamaForCausalLM because mixed precision turned on in FSDP. Affects: model.embed_tokens.weight, model.norm.weight, lm_head.weight.
warnings.warn(
/home/bello/tts_env/lib/python3.10/site-packages/accelerate/accelerator.py:1737: UserWarning: FSDP upcast of low precision parameters may affect the precision of model checkpoints.
warnings.warn(
15%|███▊                      | 46/308 [11:16<1:04:06, 14.68s/it]
{'loss': 13.1457, 'grad_norm': 2.065039873123169, 'learning_rate': 8.691275167785235e-06, 'epoch': 0.32}
{'eval_loss': 0.2015465497970581, 'eval_runtime': 6.5752, 'eval_samples_per_second': 7.604, 'eval_steps_per_second': 1.065, 'eval_num_tokens': 16000000.0, 'eval_mean_token_accuracy': 0.9532942261014666, 'eval_eos_accuracy': 0.800000011920929, 'epoch': 0.32}
{'loss': 0.7535, 'grad_norm': 1.572648286819458, 'learning_rate': 7.013422818791947e-06, 'epoch': 0.65}
{'eval_loss': 0.1875360906124115, 'eval_runtime': 6.7008, 'eval_samples_per_second': 7.462, 'eval_steps_per_second': 1.045, 'eval_num_tokens': 32000000.0, 'eval_mean_token_accuracy': 0.9556053962026324, 'eval_eos_accuracy': 0.800000011920929, 'epoch': 0.65}
{'loss': 0.7283, 'grad_norm': 1.5819722414016724, 'learning_rate': 5.335570469798658e-06, 'epoch': 0.97}
{'eval_loss': 0.18005725741386414, 'eval_runtime': 6.5648, 'eval_samples_per_second': 7.616, 'eval_steps_per_second': 1.066, 'eval_num_tokens': 48000000.0, 'eval_mean_token_accuracy': 0.9569306799343654, 'eval_eos_accuracy': 0.800000011920929, 'epoch': 0.97}
{'loss': 0.6623, 'grad_norm': 1.321942925453186, 'learning_rate': 3.6577181208053697e-06, 'epoch': 1.3}
{'eval_loss': 0.17572417855262756, 'eval_runtime': 6.7027, 'eval_samples_per_second': 7.46, 'eval_steps_per_second': 1.044, 'eval_num_tokens': 64160000.0, 'eval_mean_token_accuracy': 0.9581202013151986, 'eval_eos_accuracy': 0.800000011920929, 'epoch': 1.3}
Observations
- The training loss drops from 13.1457 at epoch 0.32 to 0.7535 at epoch 0.65, which seems unusually rapid.
- The evaluation loss also decreases significantly, but the model's post-training performance is poor, indicating possible overfitting or an issue with loss computation.
- The `CustomSFTTrainer` class modifies the loss computation to emphasize the EOS token, which might be related to the issue (a standalone sketch of that weighting logic follows this list).
- Warnings about FSDP upcasting and dataset processing are present, which could indicate configuration or compatibility issues.
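To look at the weighting observation in isolation, here is a minimal, self-contained sketch of the same EOS-weighted cross-entropy run on random tensors outside the trainer. All names and sizes below (`toy_logits`, `vocab_size`, etc.) are made up for illustration and are not part of the training script:

```python
import torch

# Rough standalone sketch: reproduce the EOS-weighted cross-entropy from
# CustomSFTTrainer.compute_loss on random tensors, so the weighting can be
# inspected without the trainer, FSDP, or the dataset in the loop.
vocab_size, seq_len, eos_token_id, eos_weight = 128, 16, 5, 2.0
toy_logits = torch.randn(2, seq_len, vocab_size)          # stand-in for shift_logits
toy_labels = torch.randint(0, vocab_size, (2, seq_len))   # stand-in for shift_labels

loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
per_token_loss = loss_fct(toy_logits.view(-1, vocab_size), toy_labels.view(-1))

# Weight EOS positions by eos_weight, everything else by 1.0 (same as the trainer).
eos_mask = (toy_labels.view(-1) == eos_token_id).float()
weights = torch.ones_like(per_token_loss) + eos_mask * (eos_weight - 1.0)

print(f"unweighted mean loss:   {per_token_loss.mean():.4f}")
print(f"EOS-weighted mean loss: {(per_token_loss * weights).mean():.4f}")
```

Since the stock `SFTTrainer` shows the same pattern, I don't think the weighting alone is the culprit, but I am including this for completeness.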
Questions
- Could the sudden loss drop be due to an issue in the `compute_loss` method of the `CustomSFTTrainer` class, particularly the EOS token weighting? (I also tried the stock `SFTTrainer` and saw the same behavior.)
- Are there known issues with FSDP or mixed precision (bf16) that could cause this behavior?
- Could the dataset preprocessing or tokenization be contributing to the problem? (A small tokenization sanity check is sketched after this list.)
- Any recommendations for debugging or stabilizing the training process?
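For the preprocessing question, this is the kind of quick check I can run on the tokenized data. It is only a rough sketch and assumes `formatted_ds` and `tokenizer` from the script above are already in scope:

```python
# Sanity-check sketch: look at one tokenized example to see how much of the
# 5000-token window is padding and whether any label positions are masked
# out with -100.
sample = formatted_ds["train"][0]
pad_id = tokenizer.pad_token_id

num_pad = sum(1 for t in sample["input_ids"] if t == pad_id)
num_ignored = sum(1 for l in sample["labels"] if l == -100)

print(f"sequence length: {len(sample['input_ids'])}")
print(f"padding tokens: {num_pad}")
print(f"label positions set to -100: {num_ignored}")
```

Since `format_deepseek_text_for_training` copies `input_ids` into `labels` and pads to `max_length`, I expect the padding count to be high and the -100 count to be zero, but I am not sure whether that matters here.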
Environment
- Python: 3.10
- Libraries: `transformers` 4.51.3, `trl` 0.17.0, `datasets` 4.51.3, `torch` 2.5.1+cu124, `accelerate` 1.6.0
- Model: `meta-llama/Meta-Llama-3-8B-Instruct`
- Hardware: local machine, 8 processes (FSDP), NVIDIA A100 80 GB GPUs
Any help or suggestions would be greatly appreciated. Thank you!