Loss spike when resuming from FSDP SHARDED_STATE_DICT checkpoint (possible optimizer-state mismatch)

Hi Accelerate team — first off, thank you for the phenomenal work on the library and its FSDP integration.
I’m running large-scale multi-node jobs and have hit what looks like a systematic resume instability whenever I save checkpoints with fsdp_state_dict_type: SHARDED_STATE_DICT.

In short:

  • Training is perfectly stable until I interrupt it and resume.
  • The very first step after --resume_from_checkpoint shows a huge loss spike (see first plot at step ≈ 10 k).
  • If I change only the state-dict format to FULL_STATE_DICT, the same run resumes smoothly with no spike (second plot) — but the wall-clock time and I/O overhead of saving full checkpoints are prohibitive at this scale.

Because the model, data, and training hyper-parameters are identical between the two runs, the culprit seems to be how the sharded checkpoint restores the optimizer state and/or parameter dtypes (I train in bf16, yet the saved shards are fp32).
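For reference, here is a minimal sketch of how the recorded dtype of each shard can be inspected from the torch.distributed.checkpoint (DCP) metadata; the checkpoint sub-directory below is a placeholder for wherever accelerate wrote the sharded model shards in my run.

# Sketch: list the dtypes recorded in a DCP sharded checkpoint.
# The directory is a placeholder; adjust it to the actual shard location.
import torch.distributed.checkpoint as dcp

ckpt_dir = "path/to/checkpoint/pytorch_model_fsdp_0"  # placeholder path

metadata = dcp.FileSystemReader(ckpt_dir).read_metadata()
for fqn, item in metadata.state_dict_metadata.items():
    # Tensor entries carry the serialized dtype and shape; non-tensor
    # entries (byte streams) have no .properties and are skipped here.
    props = getattr(item, "properties", None)
    if props is not None:
        print(fqn, props.dtype, tuple(item.size))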

Below I’ve included system details, a minimal reproducible script, and the exact configs that trigger the problem. I’d be grateful for any insights or work-arounds that would let me keep the sharded format while resuming safely.

Expected behavior: Resumed training should continue smoothly at the previous loss level, identical to a FULL_STATE_DICT checkpoint, but without the high I/O overhead.

The first plot shows the loss after resuming (resume_from_checkpoint) from a checkpoint saved with SHARDED_STATE_DICT; the spike at roughly step 10 k is clearly visible.

The second plot shows the loss when the checkpoint is saved as FULL_STATE_DICT and training is resumed the same way; the curve continues smoothly across the resume point, but saving in this format takes far too long at this scale.

Questions

  1. Is this a known limitation or bug with FSDP SHARDED_STATE_DICT in accelerate?

  2. Could the optimizer state be dropped or mismatched while saving/loading shards? (See the verification sketch after this list.)

  3. Is there a recommended workaround (e.g., enabling fsdp_cpu_ram_efficient_loading or changing another flag) that lets me keep sharded checkpoints yet resume stably?

  4. Why are the weights serialized in fp32 even though mixed_precision=bf16 was used during training?

  5. Any pointers would be greatly appreciated; thanks for your time and for all the work on Accelerate!
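Regarding question 2, this is the kind of rank-local check I can run on my side to see whether the optimizer moments survive the sharded round trip. It is only a sketch: optimizer stands for whatever accelerator.prepare returned in my script.

# Sketch: fingerprint the rank-local optimizer state so it can be compared
# right before accelerator.save_state(...) and right after
# accelerator.load_state(...). Matching norms suggest the moments survived
# the sharded round trip; a mismatch points at the optimizer restore.
import torch


def optimizer_state_fingerprint(optimizer):
    """Return {param_id.key: norm} for every floating-point tensor in the
    optimizer's rank-local state (e.g. AdamW's exp_avg / exp_avg_sq)."""
    fingerprint = {}
    for param_id, per_param_state in optimizer.state_dict()["state"].items():
        for name, value in per_param_state.items():
            if torch.is_tensor(value) and value.is_floating_point():
                fingerprint[f"{param_id}.{name}"] = value.float().norm().item()
    return fingerprint

Comparing the dict produced just before saving with the one produced just after resuming (per rank) should show whether anything was dropped or re-initialized.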

System Info

AWS p5en.48xlarge, 16 nodes

  • Accelerate version: 1.8.1
  • Platform: Linux-6.8.0-1028-aws-x86_64-with-glibc2.35
  • Python version: 3.12.11
  • Numpy version: 2.3.1
  • PyTorch version: 2.7.0+cu126
  • PyTorch accelerator: CUDA
  • System RAM: 1999.95 GB
  • GPU type: NVIDIA H200

Accelerate configs
distributed_type: FSDP
mixed_precision: bf16
fsdp_config:
  fsdp_state_dict_type: SHARDED_STATE_DICT
  offload_to_cpu: false
  save_optimizer_state: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: PreNormDecoderLayer
  fsdp_activation_checkpointing: false
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: false
  fsdp_offload_params: false
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false

Reproduction

  • Train for a few hundred steps with the first config until the loss plateaus.
  • Allow accelerate to save an automatic checkpoint.
  • Relaunch the same command with --resume_from_checkpoint path/to/latest.
  • Observe the immediate loss jump on the next optimizer step.

train.py

...
if train_args.resume_from_checkpoint:
    # Restore the sharded model/optimizer state, then pass the same path to
    # the Trainer so it also restores its own training state (scheduler, step count).
    trainer.accelerator.load_state(train_args.resume_from_checkpoint)
    trainer.train(resume_from_checkpoint=train_args.resume_from_checkpoint)
else:
    trainer.train()
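
In case it helps, here is a stripped-down, Trainer-free sketch of the same save/resume cycle driven directly through Accelerate; the model, data, loss, and checkpoint path are placeholders, and the FSDP/bf16 behavior comes entirely from the config above.

# Sketch: minimal Trainer-free save/resume loop through Accelerate.
# Model, dataloader, loss, and checkpoint path are placeholders; the FSDP
# and bf16 behavior come from the accelerate config shown in this report.
import torch
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(512, 512)                      # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(              # placeholder data
    torch.utils.data.TensorDataset(torch.randn(1024, 512)), batch_size=8
)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

ckpt_dir = "path/to/checkpoint"                        # placeholder path
resume = False                                         # set True on the second launch

if resume:
    accelerator.load_state(ckpt_dir)                   # restores sharded model + optimizer

for step, (batch,) in enumerate(dataloader):
    loss = model(batch).pow(2).mean()                  # placeholder loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    if step == 100 and not resume:
        accelerator.save_state(ckpt_dir)               # writes the sharded checkpoint
        break

If the spike also appears with a loop like this, the Trainer path can be ruled out as the cause.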