Sudden Loss Drop and Poor Performance During Model Training

Issue Description

I’m encountering a perplexing issue while training a model with the script below. During the first epoch, both the training and evaluation losses drop dramatically, yet the model performs poorly when tested afterwards. This suggests a bug in the code, the configuration, or the libraries used. I would greatly appreciate any insights or guidance from the community to help diagnose and resolve this issue.

Training Script

Below is the complete training script used:

import os
import torch
from datetime import datetime
from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)
from trl import SFTTrainer, SFTConfig
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Model and training parameters
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True
)
tokenizer.pad_token = "<|end_of_text|>"
tokenizer.padding_side = "right"
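# NOTE: "<|end_of_text|>" is an ordinary vocabulary token used here as padding,
# so padded positions in input_ids carry a real token id.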

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    use_safetensors=True,
    torch_dtype=torch.bfloat16,
    use_cache=False
)

# Enable gradient checkpointing to save memory
model.gradient_checkpointing_enable()

max_seq_len = 5000
num_epochs = 2
output_dir = "outputs"

# Load and process dataset
proc_ds = load_from_disk('updated_cleaned_sft_ds_350k')

class CustomSFTTrainer(SFTTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute training loss with emphasis on EOS token, and additionally compute token accuracies.
        """
        mode = "eval" if self.control.should_evaluate else "train"

        # Token counting logic
        if mode == "train":
            if "attention_mask" in inputs:
                num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
            elif "position_ids" in inputs:
                num_tokens_in_batch = (
                    self.accelerator.gather_for_metrics(torch.tensor(inputs["position_ids"].size(1))).sum().item()
                )
            else:
                raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.")
            self._total_train_tokens += num_tokens_in_batch
        self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

        # Process logits and labels
        if "labels" in inputs and not self.args.use_liger_kernel:
            # Get base outputs from model
            (base_loss, outputs) = super().compute_loss(model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch)
            
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = inputs["labels"][..., 1:].contiguous()
            eos_token_id = self.processing_class.eos_token_id
                               
            # Log sample predictions and actual text
            if self.state.global_step % 100 == 0:
                sample_indices = list(range(min(3, shift_labels.size(0))))
                
                for idx in sample_indices:
                    predictions = shift_logits[idx].argmax(dim=-1)
                    actual_labels = shift_labels[idx]

                    # Decode only supervised positions: -100 is not a valid token id
                    mask = actual_labels != -100
                    pred_tokens = predictions[mask].cpu().tolist()
                    actual_tokens = actual_labels[mask].cpu().tolist()
                    
                    try:
                        pred_text = self.processing_class.decode(pred_tokens)
                        actual_text = self.processing_class.decode(actual_tokens)
                        
                        with open("model_predictions.txt", "a") as f:
                            f.write(f"\n\nStep {self.state.global_step}, Sample {idx}:\n")
                            f.write(f"Predicted text: {pred_text}\n")
                            f.write(f"Actual text: {actual_text}\n")
                            f.write("-" * 50 + "\n")
                        
                    except Exception as e:
                        logger.warning(f"Error decoding tokens: {e}")
            
            # Calculate loss differently based on mode
            if mode == "train":
                loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
                per_token_loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)), 
                    shift_labels.view(-1)
                )
                
                eos_mask = (shift_labels.view(-1) == eos_token_id).float()
                eos_weight = 2.0
                weights = torch.ones_like(per_token_loss) + eos_mask * (eos_weight - 1.0)
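                # .mean() below divides by the total token count, not weights.sum()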
                loss = (per_token_loss * weights).mean()
                
            else:
                loss = base_loss
            
            # Compute token accuracy
            predictions = shift_logits.argmax(dim=-1)
            mask = shift_labels != -100
            correct_predictions = (predictions == shift_labels) & mask
            total_tokens = mask.sum()
            correct_tokens = correct_predictions.sum()
            
            correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
            total_tokens = self.accelerator.gather_for_metrics(total_tokens)
            
            total_sum = total_tokens.sum()
            accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
            self._metrics[mode]["mean_token_accuracy"].append(accuracy)

            # Compute EOS-specific accuracy, gathered across processes to match
            # the token-accuracy metric above
            eos_mask = (shift_labels == eos_token_id)
            eos_correct = ((predictions == eos_token_id) & eos_mask).sum()
            eos_total = eos_mask.sum()

            eos_correct = self.accelerator.gather_for_metrics(eos_correct).sum()
            eos_total = self.accelerator.gather_for_metrics(eos_total).sum()

            if eos_total > 0:
                eos_accuracy = (eos_correct / eos_total).item()
                self._metrics[mode]["eos_accuracy"].append(eos_accuracy)

        else:
            (loss, outputs) = super().compute_loss(
                model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
            )
        
        return (loss, outputs) if return_outputs else loss
    
def format_deepseek_text_for_training(example):
    tokens = tokenizer(
        example['text'],
        max_length=max_seq_len,
        padding='max_length',
        add_special_tokens=False,
        truncation=True
    )
    
    tokens['labels'] = tokens['input_ids'].copy()
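    # NOTE: labels are a verbatim copy of input_ids; padded positions are not
    # masked with -100, so padding contributes to the loss like real tokens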
    return tokens

# Prepare dataset
proc_ds = proc_ds.shuffle().select(range(5000))
formatted_ds = proc_ds.map(format_deepseek_text_for_training, remove_columns=proc_ds.column_names)
formatted_ds = formatted_ds.train_test_split(test_size=0.01)

# Training arguments
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    weight_decay=0.01,
    warmup_steps=10,
    logging_steps=50,
    eval_steps=50,
    save_strategy="no",
    eval_strategy="steps",
    bf16=True,
    max_grad_norm=1.0,
    report_to="none",
    max_seq_length=max_seq_len,
    packing=False,
    optim="adamw_8bit",
    gradient_checkpointing=True,
    dataset_text_field="text"
)

# Create the SFT trainer
trainer = CustomSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_ds['train'],
    eval_dataset=formatted_ds['test'],
    processing_class=tokenizer,
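    # NOTE: formatting_func is ignored here because the dataset is already
    # tokenized (see the UserWarning in the training output below)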
    formatting_func=format_deepseek_text_for_training
)

# Train the model
trainer.train()

# Save the model
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

final_save_path = os.path.join(output_dir, "final-model_v1.15")
trainer.save_model(final_save_path)
tokenizer.save_pretrained(final_save_path)
print(f"Final model saved to {final_save_path}")

Configuration YAML

The following YAML configuration is used for FSDP:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_offload_optimizer_state: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Training Output

The training output shows a sudden drop in loss, which is unexpected:

/home/bello/tts_env/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:502: UserWarning: You passed a dataset that is already processed (contains an `input_ids` field) together with a formatting function. Therefore `formatting_func` will be ignored. Either remove the `formatting_func` or pass a dataset that is not already processed.
  warnings.warn(
Truncating eval dataset: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:00<00:00, 4730.03 examples/s]
Truncating train dataset: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 4950/4950 [00:01<00:00, 3712.92 examples/s]
/home/bello/tts_env/lib/python3.10/site-packages/accelerate/accelerator.py:1731: UserWarning: Upcasted low precision parameters in LlamaForCausalLM because mixed precision turned on in FSDP. Affects: model.embed_tokens.weight, model.norm.weight, lm_head.weight.
  warnings.warn(
/home/bello/tts_env/lib/python3.10/site-packages/accelerate/accelerator.py:1737: UserWarning: FSDP upcast of low precision parameters may affect the precision of model checkpoints.
  warnings.warn(
 15%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ                                                                                      | 46/308 [11:16<1:04:06, 14.68s/it]
{'loss': 13.1457, 'grad_norm': 2.065039873123169, 'learning_rate': 8.691275167785235e-06, 'epoch': 0.32}
{'eval_loss': 0.2015465497970581, 'eval_runtime': 6.5752, 'eval_samples_per_second': 7.604, 'eval_steps_per_second': 1.065, 'eval_num_tokens': 16000000.0, 'eval_mean_token_accuracy': 0.9532942261014666, 'eval_eos_accuracy': 0.800000011920929, 'epoch': 0.32}
{'loss': 0.7535, 'grad_norm': 1.572648286819458, 'learning_rate': 7.013422818791947e-06, 'epoch': 0.65}
{'eval_loss': 0.1875360906124115, 'eval_runtime': 6.7008, 'eval_samples_per_second': 7.462, 'eval_steps_per_second': 1.045, 'eval_num_tokens': 32000000.0, 'eval_mean_token_accuracy': 0.9556053962026324, 'eval_eos_accuracy': 0.800000011920929, 'epoch': 0.65}
{'loss': 0.7283, 'grad_norm': 1.5819722414016724, 'learning_rate': 5.335570469798658e-06, 'epoch': 0.97}
{'eval_loss': 0.18005725741386414, 'eval_runtime': 6.5648, 'eval_samples_per_second': 7.616, 'eval_steps_per_second': 1.066, 'eval_num_tokens': 48000000.0, 'eval_mean_token_accuracy': 0.9569306799343654, 'eval_eos_accuracy': 0.800000011920929, 'epoch': 0.97}
{'loss': 0.6623, 'grad_norm': 1.321942925453186, 'learning_rate': 3.6577181208053697e-06, 'epoch': 1.3}
{'eval_loss': 0.17572417855262756, 'eval_runtime': 6.7027, 'eval_samples_per_second': 7.46, 'eval_steps_per_second': 1.044, 'eval_num_tokens': 64160000.0, 'eval_mean_token_accuracy': 0.9581202013151986, 'eval_eos_accuracy': 0.800000011920929, 'epoch': 1.3}

Observations

  • The training loss drops from 13.1457 at epoch 0.32 to 0.7535 at epoch 0.65, which seems unusually rapid.
  • The evaluation loss also falls sharply, yet the model’s post-training performance is poor, which points to overfitting or a problem in the loss computation.
  • The CustomSFTTrainer class modifies the loss computation to up-weight the EOS token, which might be related to the issue (see the sketch after this list).
  • Warnings about FSDP upcasting and dataset processing appear in the output, which could indicate configuration or compatibility issues.
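
To make the last two observations concrete, below is a minimal, self-contained sketch of the weighting logic from compute_loss, run on toy tensors. All sizes and token ids in it are made up for illustration and are not taken from the real run:

import torch

# Toy reproduction of the EOS-weighting logic from compute_loss above.
vocab_size, seq_len, eos_token_id, eos_weight = 8, 6, 7, 2.0

shift_logits = torch.randn(1, seq_len, vocab_size)
shift_labels = torch.tensor([[3, 1, 5, 7, 7, 7]])  # trailing pads share the EOS id

loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
per_token_loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))

eos_mask = (shift_labels.view(-1) == eos_token_id).float()
weights = torch.ones_like(per_token_loss) + eos_mask * (eos_weight - 1.0)

# .mean() divides by the token count, not by weights.sum(), so up-weighting
# EOS rescales the loss; and with no -100 masking, padded positions count too.
loss = (per_token_loss * weights).mean()
print(loss.item(), f"({int(eos_mask.sum())} of {seq_len} positions weighted as EOS)")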

Questions

  1. Could the sudden loss drop be caused by the compute_loss override in the CustomSFTTrainer class, particularly the EOS token weighting? (I also tried the stock SFTTrainer and observed the same behavior.)
  2. Are there known issues with FSDP or mixed precision (bf16) that could cause this behavior?
  3. Could the dataset preprocessing or tokenization be contributing to the problem? (A sanity check for this is sketched after this list.)
  4. Any recommendations for debugging or stabilizing the training process?
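
For question 3, here is a short sanity check that might narrow things down. It assumes formatted_ds and tokenizer from the script above are in scope; the variable names are mine:

# Inspect one tokenized example: how many label positions are padding,
# and how many are masked out with -100?
ex = formatted_ds["train"][0]

pad_id = tokenizer.pad_token_id
labels = ex["labels"]

n_pad_in_labels = sum(1 for t in labels if t == pad_id)
n_masked = sum(1 for t in labels if t == -100)

print(f"sequence length:                  {len(labels)}")
print(f"label positions equal to pad id:  {n_pad_in_labels}")
print(f"label positions masked with -100: {n_masked}")
# If n_masked is 0 while n_pad_in_labels is large, most of the loss measures
# how well the model predicts padding, which would deflate the loss without
# improving actual generation quality.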

Environment

  • Python: 3.10
  • Libraries: transformers (4.51.3), trl (0.17.0), datasets (4.51.3), torch (2.5.1+cu124), accelerate (1.6.0)
  • Model: meta-llama/Meta-Llama-3-8B-Instruct
  • Hardware: local machine, 8 processes (FSDP), NVIDIA A100 80 GB GPUs

Any help or suggestions would be greatly appreciated. Thank you!
