Hello! I’m running into an issue with the checkpoints saved by the default HuggingFace Trainer when training an LLM with FSDP, but only if I also run inference during training. I provide code at the end of this post for clarity.
What I’m trying to achieve
I want to write a callback to monitor model outputs on a validation set throughout the training process. This requires running inference with model.generate(). Since I’m also using FSDP, I need to summon the full weights on a single device first, as described in this GitHub issue.
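In isolation, the pattern I follow looks roughly like this (a condensed sketch of the callback shown in full at the end of the post; generate_on_full_params is just an illustrative name):

import torch
from torch.distributed import fsdp

@torch.no_grad()
def generate_on_full_params(model, tokenizer, prompts):
    # Gather the full, unsharded parameters before calling generate().
    encoding = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    with fsdp.FullyShardedDataParallel.summon_full_params(model):
        outputs = model.generate(
            inputs=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=16,
            do_sample=False,
        )
    # Strip the prompt tokens and decode only the completions.
    return tokenizer.batch_decode(outputs[:, encoding.input_ids.shape[1]:], skip_special_tokens=True)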
My issue
The callback I provide below seems to work fine for evaluation, but it affects the checkpoints that get saved. Specifically, when I unshard the final checkpoint and try to replicate the results I see from my training script, the checkpoint gives different, much worse outputs.
To test this, I trained an LLM to memorize a simple phrase: “Two times 10 equals 20.”. At the end of training, my callback reports the completions I expect, meaning the model trained well. However, if I load the checkpoint from disk and feed it the same prompts, I get this:
# With callback
# Outputs from the training script, after training.
"Two" -> "times 10 equals 20."
"Two times" -> "10 equals 20."
"Two times 10" -> "equals 20."
"Two times 10 equals" -> "20."
# Outputs from the checkpoint loaded from disk.
"Two" -> " "
"Two times" -> "equals "
"Two times 10" -> " "
"Two times 10 equals" -> " "
This does not happen if I don’t run the callback during training. If I remove it, the saved checkpoint produces the expected outputs:
# Without callback
# Outputs from the checkpoint loaded from disk.
"Two" -> "times 10 equals 20."
"Two times" -> "10 equals 20."
"Two times 10" -> "equals 20."
"Two times 10 equals" -> "20."
To be extra sure, I also ran the same experiment with DDP instead of FSDP (removing the summon_full_params call, since it is not needed there). The DDP checkpoint is correct regardless of whether I use my callback:
# With DDP
# Outputs from the training script, after training.
"Two" -> "times 10 equals 20."
"Two times" -> "10 equals 20."
"Two times 10" -> "equals 20."
"Two times 10 equals" -> "20."
# Outputs from the checkpoint loaded from disk.
"Two" -> "times 10 equals 20."
"Two times" -> "10 equals 20."
"Two times 10" -> "equals 20."
"Two times 10 equals" -> "20."
I believe this points to summon_full_params being the problem. Do you think this could be an issue with the library, or with my implementation? Any ideas or advice? Thank you!
Minimal example
main.py
from typing import cast
import accelerate
import datasets
import torch
import transformers
from torch.distributed import fsdp
class ValidCallback(transformers.TrainerCallback):
def __init__(self, tokenizer: transformers.PreTrainedTokenizerBase, dataset: datasets.Dataset) -> None:
super().__init__()
self.tokenizer = tokenizer
self.dataset = dataset
def on_epoch_end(
self,
args: transformers.TrainingArguments,
state: transformers.TrainerState,
control: transformers.TrainerControl,
**kwargs,
) -> None:
if state.epoch is None or int(state.epoch) % 25 != 0:
return
model = cast(transformers.PreTrainedModel, kwargs["model"])
with torch.no_grad():
self.run(model)
@torch.no_grad()
def run(self, model: transformers.PreTrainedModel) -> None:
model.eval()
for batch in self.dataset.iter(batch_size=7):
encoding = self.tokenizer(batch["text"], return_tensors="pt", padding=True).to(model.device)
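            # Gather the full, unsharded parameters on every rank so generate() can run.
            # This is the call I suspect is the culprit.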
with fsdp.FullyShardedDataParallel.summon_full_params(model):
outputs = model.generate(
inputs=encoding.input_ids,
attention_mask=encoding.attention_mask,
pad_token_id=self.tokenizer.eos_token_id,
max_new_tokens=16,
do_sample=False,
)
predictions = self.tokenizer.batch_decode(
outputs[:, encoding.input_ids.shape[1] :], # Skip the returned prompt.
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
if accelerate.PartialState().is_main_process:
print(predictions)
def main() -> None:
# Load model and tokenizer.
checkpoint = "mistralai/Mistral-7B-v0.3"
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
tokenizer.padding_side = "left"
if not tokenizer.pad_token:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
model.resize_token_embeddings(len(tokenizer))
# Load and prepare a toy dataset.
def tokenize_function(examples):
tokenized = tokenizer(examples["text"], max_length=32, padding="max_length", truncation=True)
tokenized["labels"] = cast(list, tokenized["input_ids"]).copy()
return tokenized
train_dataset = datasets.Dataset.from_dict({"text": ["Two times 10 equals 20."] * 100})
valid_dataset = datasets.Dataset.from_dict(
{"text": ["Two", "Two times", "Two times 10", "Two times 10 equals", "Two times 10 equals 20."]}
)
train_dataset = train_dataset.map(
tokenize_function, batched=True, remove_columns=list(train_dataset.features)
)
# Train.
trainer = transformers.Trainer(
model=model,
train_dataset=train_dataset,
args=transformers.TrainingArguments(
output_dir="./output-minimal",
save_strategy="steps",
save_steps=1_000_000,
overwrite_output_dir=True,
remove_unused_columns=False,
optim="adamw_torch_fused",
bf16=True,
learning_rate=1e-2,
num_train_epochs=100,
per_device_train_batch_size=1,
ddp_timeout=9999999,
report_to=[],
),
callbacks=[
ValidCallback(tokenizer, valid_dataset),
],
)
trainer.train()
if __name__ == "__main__":
main()
fsdp.yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
I run my code on Slurm, using this command:
srun bash -c "accelerate launch \
--config_file fsdp.yaml \
--main_process_ip $(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) \
--main_process_port 6000 \
--machine_rank \$SLURM_PROCID \
main.py"
To unshard an FSDP checkpoint, I use the code from this forum post.
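For completeness, the outputs labelled “loaded from disk” above come from a small standalone script along these lines (a sketch; the checkpoint path is a placeholder for wherever the unsharded weights end up):

import torch
import transformers

# Placeholder path to the unsharded checkpoint produced from the FSDP output.
checkpoint_dir = "./output-minimal/unsharded"

# Rebuild the tokenizer the same way as in training (base tokenizer plus [PAD] token).
tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3")
tokenizer.padding_side = "left"
if not tokenizer.pad_token:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint_dir, torch_dtype=torch.bfloat16)
model = model.to("cuda").eval()

prompts = ["Two", "Two times", "Two times 10", "Two times 10 equals"]
encoding = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
with torch.no_grad():
    outputs = model.generate(
        inputs=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=16,
        do_sample=False,
    )
completions = tokenizer.batch_decode(
    outputs[:, encoding.input_ids.shape[1]:], skip_special_tokens=True
)
for prompt, completion in zip(prompts, completions):
    print(f'"{prompt}" -> "{completion}"')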
Relevant environment and library versions:
Linux Debian 6.1.90-1
CUDA version: 12.4
accelerate==1.0.1
datasets==3.0.1
torch==2.4.1
torchaudio==2.4.1
torchvision==0.19.1
transformers==4.45.2