Doing inference with FSDP during training affects checkpointing

Hello! I’m running into an issue with the checkpoints saved when training an LLM with FSDP and the default Hugging Face Trainer, if I also run inference during training. I provide a minimal example at the end of this post for clarity.

What I’m trying to achieve

I want to write a callback to monitor model outputs on a validation set throughout training. This requires running inference with model.generate(). Since I’m training with FSDP, I need to summon the full, unsharded parameters before calling generate(), as described in this GitHub issue.

My issue

The callback I provide below seems to work fine for evaluation, but it affects the checkpoints that get saved. Specifically, when I unshard the final checkpoint and try to replicate the results I see from my training script, the checkpoint produces different, much worse outputs.

To test this, I trained an LLM to memorize a simple phrase: “Two times 10 equals 20.”. At the end of training, my callback reports the completions I expect, meaning the model trained well. However, if I load the checkpoint from disk and feed it the same prompts, I get this:

# With callback
# Outputs from the training script, after training.
"Two"                 -> "times 10 equals 20."
"Two times"           -> "10 equals 20."
"Two times 10"        -> "equals 20."
"Two times 10 equals" -> "20."
# Outputs from the checkpoint loaded from disk.
"Two"                 -> "               "
"Two times"           -> "equals               "
"Two times 10"        -> "               "
"Two times 10 equals" -> "               "

This does not happen if I don’t run the callback during training. If I remove it, the saved checkpoint produces the expected outputs:

# Without callback
# Outputs from the checkpoint loaded from disk.
"Two"                 -> "times 10 equals 20."
"Two times"           -> "10 equals 20."
"Two times 10"        -> "equals 20."
"Two times 10 equals" -> "20."

To make extra sure, I also ran this experiment with DDP instead of FSDP (with the summon_full_params call removed). The DDP checkpoint is correct whether or not I use the callback.

# With DDP
# Outputs from the training script, after training.
"Two"                 -> "times 10 equals 20."
"Two times"           -> "10 equals 20."
"Two times 10"        -> "equals 20."
"Two times 10 equals" -> "20."
# Outputs from the checkpoint loaded from disk.
"Two"                 -> "times 10 equals 20."
"Two times"           -> "10 equals 20."
"Two times 10"        -> "equals 20."
"Two times 10 equals" -> "20."

I believe this points to summon_full_params being the problem. Do you think this could be a problem with the library, or maybe with my implementation? Any ideas or advice? Thank you!

Minimal example

main.py
from typing import cast

import accelerate
import datasets
import torch
import transformers
from torch.distributed import fsdp


class ValidCallback(transformers.TrainerCallback):
    def __init__(self, tokenizer: transformers.PreTrainedTokenizerBase, dataset: datasets.Dataset) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.dataset = dataset

    def on_epoch_end(
        self,
        args: transformers.TrainingArguments,
        state: transformers.TrainerState,
        control: transformers.TrainerControl,
        **kwargs,
    ) -> None:
        if state.epoch is None or int(state.epoch) % 25 != 0:
            return
        model = cast(transformers.PreTrainedModel, kwargs["model"])
        with torch.no_grad():
            self.run(model)

    @torch.no_grad()
    def run(self, model: transformers.PreTrainedModel) -> None:
        model.eval()

        for batch in self.dataset.iter(batch_size=7):
            encoding = self.tokenizer(batch["text"], return_tensors="pt", padding=True).to(model.device)

            # Gather the full, unsharded parameters so generate() sees the complete weights.
            with fsdp.FullyShardedDataParallel.summon_full_params(model):
                outputs = model.generate(
                    inputs=encoding.input_ids,
                    attention_mask=encoding.attention_mask,
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_new_tokens=16,
                    do_sample=False,
                )

            predictions = self.tokenizer.batch_decode(
                outputs[:, encoding.input_ids.shape[1] :],  # Skip the returned prompt.
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

            if accelerate.PartialState().is_main_process:
                print(predictions)


def main() -> None:
    # Load model and tokenizer.
    checkpoint = "mistralai/Mistral-7B-v0.3"
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    tokenizer.padding_side = "left"
    if not tokenizer.pad_token:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
    model.resize_token_embeddings(len(tokenizer))

    # Load and prepare a toy dataset.
    def tokenize_function(examples):
        tokenized = tokenizer(examples["text"], max_length=32, padding="max_length", truncation=True)
        tokenized["labels"] = cast(list, tokenized["input_ids"]).copy()
        return tokenized

    train_dataset = datasets.Dataset.from_dict({"text": ["Two times 10 equals 20."] * 100})
    valid_dataset = datasets.Dataset.from_dict(
        {"text": ["Two", "Two times", "Two times 10", "Two times 10 equals", "Two times 10 equals 20."]}
    )
    train_dataset = train_dataset.map(
        tokenize_function, batched=True, remove_columns=list(train_dataset.features)
    )

    # Train.
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_dataset,
        args=transformers.TrainingArguments(
            output_dir="./output-minimal",
            save_strategy="steps",
            save_steps=1_000_000,
            overwrite_output_dir=True,
            remove_unused_columns=False,
            optim="adamw_torch_fused",
            bf16=True,
            learning_rate=1e-2,
            num_train_epochs=100,
            per_device_train_batch_size=1,
            ddp_timeout=9999999,
            report_to=[],
        ),
        callbacks=[
            ValidCallback(tokenizer, valid_dataset),
        ],
    )
    trainer.train()
    trainer.save_model()  # Write the final (sharded) checkpoint that is later unsharded and evaluated.


if __name__ == "__main__":
    main()

fsdp.yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

I run my code on Slurm, using this command:

srun bash -c "accelerate launch \
    --config_file fsdp.yaml \
    --main_process_ip $(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) \
    --main_process_port 6000 \
    --machine_rank \$SLURM_PROCID \
    main.py"

To unshard an FSDP checkpoint, I use the code from this forum post.
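
For illustration, one way to do this kind of conversion (not necessarily identical to that post’s code) is via torch’s distributed-checkpoint format utilities. The paths below are illustrative, and the key layout of the converted file is an assumption:

import torch
import transformers
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Illustrative paths: the sharded FSDP checkpoint written by the Trainer and a
# temporary consolidated file.
sharded_dir = "./output-minimal/checkpoint-XXXX/pytorch_model_fsdp_0"
consolidated_path = "./consolidated.pt"

# Gather the distributed checkpoint shards into a single torch.save file.
dcp_to_torch_save(sharded_dir, consolidated_path)

# Load the consolidated weights into a fresh model and re-save in HF format.
# Assumption: the weights may be nested under a "model" key, depending on how
# the sharded state dict was saved; adjust if your file is laid out differently.
state_dict = torch.load(consolidated_path, map_location="cpu")
state_dict = state_dict.get("model", state_dict)

tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3")
tokenizer.padding_side = "left"
if not tokenizer.pad_token:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

model = transformers.AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.3", torch_dtype=torch.bfloat16
)
model.resize_token_embeddings(len(tokenizer))  # Match the [PAD] token added at training time.
model.load_state_dict(state_dict)
model.save_pretrained("./output-minimal-unsharded")
tokenizer.save_pretrained("./output-minimal-unsharded")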

Relevant environment and library versions:

Linux Debian 6.1.90-1
CUDA version: 12.4

accelerate==1.0.1
datasets==3.0.1
torch==2.4.1
torchaudio==2.4.1
torchvision==0.19.1
transformers==4.45.2

Hello! Unfortunately, I don’t have a solution for this problem, but I was able to reproduce your findings and did a quick weight comparison test.

After unsharding the model, only the *.safetensors files differ between runs, so there are no weird offsets or other metadata differences affecting how the model is loaded.

I was curious whether there’s something obviously different between the weights of different training runs, so I started 2 control runs without the callback (where inference from the saved checkpoints is fine) and 1 test run with the callback (where inference from the saved checkpoints is wrong). I compared the weights of corresponding layers with np.linalg.norm.
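
For concreteness, here is a sketch of this kind of comparison (not my exact script; paths are illustrative, and it merges each checkpoint’s *.safetensors shards before comparing per-tensor norms):

import glob
import os

import numpy as np
from safetensors.torch import load_file

# Illustrative paths to two unsharded checkpoints (e.g. a control run and the test run).
dir_a = "./output-control-unsharded"
dir_b = "./output-callback-unsharded"

def load_weights(directory):
    # Merge all *.safetensors shards in a checkpoint directory into one dict.
    weights = {}
    for path in sorted(glob.glob(os.path.join(directory, "*.safetensors"))):
        weights.update(load_file(path))
    return weights

weights_a = load_weights(dir_a)
weights_b = load_weights(dir_b)

# Compare per-tensor norms of corresponding layers.
for name in sorted(weights_a.keys() & weights_b.keys()):
    norm_a = np.linalg.norm(weights_a[name].float().numpy())
    norm_b = np.linalg.norm(weights_b[name].float().numpy())
    print(f"{name}: {norm_a:.4f} vs {norm_b:.4f} (diff {abs(norm_a - norm_b):.4f})")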

I didn’t find any notable differences between the runs: nothing in particular stands out for the test run compared to the control runs, and the differences are scattered across multiple layers, as expected. I was hoping to see an isolated set of weights that differ significantly.

Maybe more test runs are needed to account for the variance, or a more sophisticated weight-comparison method, but in the end I’m not sure this route is very productive.

Not much, but hope it helps!
