Hi Accelerate team — first off, thank you for the phenomenal work on the library and its FSDP integration.
I’m running large-scale multi-node jobs and have hit what looks like a systematic resume instability whenever I save checkpoints with `fsdp_state_dict_type: SHARDED_STATE_DICT`.
In short:
- Training is perfectly stable until I interrupt it and resume.
- The very first step after `--resume_from_checkpoint` shows a huge loss spike (see the first plot at step ≈ 10k).
- If I change only the state-dict format to `FULL_STATE_DICT`, the same run resumes smoothly with no spike (second plot), but the wall-clock time and I/O overhead of saving full checkpoints are prohibitive at this scale.
Because the model, data, and training hyper-parameters are identical between the two runs, the culprit seems to be how the sharded checkpoint restores the optimizer state and/or parameter dtypes (I train in bf16, yet the saved shards are fp32).
Below I’ve included system details, a minimal reproducible script, and the exact configs that trigger the problem. I’d be grateful for any insights or work-arounds that would let me keep the sharded format while resuming safely.
Expected behavior: Resumed training should continue smoothly at the previous loss level, identical to a `FULL_STATE_DICT` checkpoint but without the high I/O overhead.
The first plot shows the loss when resuming with `resume_from_checkpoint` after saving a checkpoint with `SHARDED_STATE_DICT`; the spike at roughly step 10k is clearly visible.
The second plot shows the loss when the checkpoint is saved with `FULL_STATE_DICT` and training is resumed with `resume_from_checkpoint`; the curve continues smoothly across the resume point, but saving in this format takes far too long.
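To back up the fp32 observation above, something like the following offline check could be used to consolidate the sharded (DCP-format) checkpoint and print the dtypes that were actually written. This is only a sketch: the checkpoint path is illustrative (the exact folder name Accelerate writes may differ), and it relies on PyTorch's `torch.distributed.checkpoint.format_utils.dcp_to_torch_save` utility.

```python
# Sketch: offline inspection of the dtypes actually stored in the sharded checkpoint.
# Paths are illustrative; point the first argument at the folder of DCP shard files
# inside the saved checkpoint directory.
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_to_torch_save("checkpoint-10000/pytorch_model_fsdp_0", "consolidated.pt")
state = torch.load("consolidated.pt", map_location="cpu", weights_only=False)

def print_dtypes(obj, prefix=""):
    """Recursively print shape/dtype for every tensor leaf, whatever the nesting is."""
    if torch.is_tensor(obj):
        print(f"{prefix}: {tuple(obj.shape)} {obj.dtype}")
    elif isinstance(obj, dict):
        for key, value in obj.items():
            print_dtypes(value, f"{prefix}.{key}" if prefix else str(key))

print_dtypes(state)
```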
Questions
- Is this a known limitation or bug with FSDP `SHARDED_STATE_DICT` in Accelerate?
- Could the optimizer state be dropped or mismatched while saving/loading shards?
- Is there a recommended workaround (e.g., enabling `fsdp_cpu_ram_efficient_loading` or changing another flag) that lets me keep sharded checkpoints yet resume stably?
- Why are the weights serialized in fp32 even though `mixed_precision=bf16` was used during training?

Any pointers would be greatly appreciated. Thanks for your time and for all the work on Accelerate!
System Info
- AWS p5en.48xlarge, 16 nodes
- Accelerate version: 1.8.1
- Platform: Linux-6.8.0-1028-aws-x86_64-with-glibc2.35
- Python version: 3.12.11
- Numpy version: 2.3.1
- PyTorch version: 2.7.0+cu126
- PyTorch accelerator: CUDA
- System RAM: 1999.95 GB
- GPU type: NVIDIA H200
Accelerate configs
```yaml
distributed_type: FSDP
mixed_precision: bf16
fsdp_config:
  fsdp_state_dict_type: SHARDED_STATE_DICT
  offload_to_cpu: false
  save_optimizer_state: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: PreNormDecoderLayer
  fsdp_activation_checkpointing: false
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: false
  fsdp_offload_params: false
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
```
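The only line I change between the failing and the working run is `fsdp_state_dict_type`. For completeness, a rough programmatic equivalent of just that switch (a sketch based on my reading of the `FullyShardedDataParallelPlugin` docs, not what my actual launcher does) would be:

```python
# Sketch only: the same state-dict switch expressed in code instead of the YAML above.
# Assumes FullyShardedDataParallelPlugin accepts state_dict_type as documented by
# Accelerate; all other FSDP settings stay as in the config.
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_type="SHARDED_STATE_DICT",  # switching this to "FULL_STATE_DICT" resumes cleanly
)
accelerator = Accelerator(mixed_precision="bf16", fsdp_plugin=fsdp_plugin)
```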
Reproduction
- Train for a few hundred steps with the first config until the loss plateaus.
- Allow accelerate to save an automatic checkpoint.
- Relaunch the same command with --resume_from_checkpoint path/to/latest.
- Observe the immediate loss jump on the next optimizer step.
train.py
```python
...
if train_args.resume_from_checkpoint:
    # Restore the Accelerate state, then let the Trainer resume from the same checkpoint.
    trainer.accelerator.load_state(train_args.resume_from_checkpoint)
    trainer.train(resume_from_checkpoint=train_args.resume_from_checkpoint)
else:
    trainer.train()
```
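If it helps with the optimizer-state question above, I can also log a quick fingerprint of the optimizer state right before saving and right after resuming to see whether anything is dropped or mismatched. The helper below is only an illustrative sketch (it assumes access to the built optimizer, e.g. from a callback or a hand-rolled loop) and is not part of the training script.

```python
# Debugging sketch (not part of train.py): fingerprint the per-rank optimizer state so
# values logged just before saving can be compared with values seen right after resuming.
import torch


def optimizer_state_fingerprint(optimizer: torch.optim.Optimizer):
    """Return (#floating-point state tensors, sum of their absolute values)."""
    n_tensors, total = 0, 0.0
    for param_state in optimizer.state.values():
        for value in param_state.values():
            if torch.is_tensor(value) and value.is_floating_point():
                n_tensors += 1
                total += value.detach().float().abs().sum().item()
    return n_tensors, total


# Example usage once the optimizer exists, on every rank:
# print(f"rank {accelerator.process_index}: {optimizer_state_fingerprint(optimizer)}")
```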