Loss spike when resuming from FSDP SHARDED_STATE_DICT checkpoint (possible optimizer-state mismatch)

Hi Accelerate team — first off, thank you for the phenomenal work on the library and its FSDP integration.
I’m running large-scale multi-node jobs and have hit what looks like a systematic resume-instability whenever I save checkpoints with fsdp_state_dict_type: SHARDED_STATE_DICT.

In short:

  • Training is perfectly stable until I interrupt it and resume.
  • The very first step after --resume_from_checkpoint shows a huge loss spike (see first plot at step ≈ 10 k).
  • If I change only the state-dict format to FULL_STATE_DICT, the same run resumes smoothly with no spike (second plot) — but the wall-clock time and I/O overhead of saving full checkpoints are prohibitive at this scale.

Because the model, data, and training hyper-parameters are identical between the two runs, the culprit seems to be how the sharded checkpoint restores the optimizer state and/or parameter dtypes (I train in bf16, yet the saved shards are fp32).

Below I’ve included system details, a minimal reproducible script, and the exact configs that trigger the problem. I’d be grateful for any insights or work-arounds that would let me keep the sharded format while resuming safely.

Expected behavior: Resumed training should continue smoothly at the previous loss level, identical to a FULL_STATE_DICT checkpoint, but without the high I/O overhead.

First plot: loss when resuming with resume_from_checkpoint from a checkpoint saved with SHARDED_STATE_DICT. The spike is visible at the 10K step.

Second plot: loss when saving with FULL_STATE_DICT and resuming with resume_from_checkpoint. The curve continues smoothly across the resume point, but saving this way takes far too long.

Questions

  1. Is this a known limitation or bug with FSDP SHARDED_STATE_DICT in accelerate?

  2. Could the optimizer state be dropped or mismatched while saving/loading shards?

  3. Is there a recommended workaround (e.g., enabling fsdp_cpu_ram_efficient_loading or changing another flag) that lets me keep sharded checkpoints yet resume stably?

  4. Why are the weights serialized in fp32 even though mixed_precision=bf16 was used during training? (A small dtype-inspection sketch follows these questions.)

Any pointers would be greatly appreciated. Thanks for your time and for all the work on Accelerate!
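For reference, the on-disk dtypes can be checked offline from the checkpoint's metadata. This is a minimal sketch, assuming the shards were written with PyTorch Distributed Checkpoint (which is what Accelerate's SHARDED_STATE_DICT path uses) and the usual pytorch_model_fsdp_0 sub-directory layout; adjust the path to your own checkpoint:

# Offline dtype check for a DCP/sharded checkpoint; only the .metadata file is read,
# so no process group is needed. The path below is a placeholder.
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.metadata import TensorStorageMetadata

meta = FileSystemReader("path/to/checkpoint/pytorch_model_fsdp_0").read_metadata()
for name, m in list(meta.state_dict_metadata.items())[:5]:
    if isinstance(m, TensorStorageMetadata):
        print(name, m.properties.dtype, tuple(m.size))

Because this only parses the .metadata file, it runs in a single process without initializing distributed.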

System Info

AWS p5en.48xlarge, 16 nodes

  • Accelerate version: 1.8.1
  • Platform: Linux-6.8.0-1028-aws-x86_64-with-glibc2.35
  • Python version: 3.12.11
  • Numpy version: 2.3.1
  • PyTorch version: 2.7.0+cu126
  • PyTorch accelerator: CUDA
  • System RAM: 1999.95 GB
  • GPU type: NVIDIA H200

Accelerate configs
distributed_type: FSDP
mixed_precision: bf16
fsdp_config:
  fsdp_state_dict_type: SHARDED_STATE_DICT
  offload_to_cpu: false
  save_optimizer_state: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: PreNormDecoderLayer
  fsdp_activation_checkpointing: false
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: false
  fsdp_offload_params: false
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false

Reproduction

  • Train for a few hundred steps with the first config until the loss plateaus.
  • Allow accelerate to save an automatic checkpoint.
  • Relaunch the same command with --resume_from_checkpoint path/to/latest.
  • Observe the immediate loss jump on the next optimizer step.

train.py

...
if train_args.resume_from_checkpoint:
    # restore the state saved by accelerator.save_state, then resume the Trainer loop
    trainer.accelerator.load_state(train_args.resume_from_checkpoint)
    trainer.train(resume_from_checkpoint=train_args.resume_from_checkpoint)
else:
    trainer.train()

I'd like to know about any known issues regarding this.

Did you manage to address the problem?
I've run into the same issue.

I couldn’t figure out how to resolve the issue, so I ended up saving checkpoints with `FULL_STATE_DICT`, even though it took longer.

Using FULL_STATE_DICT seems simpler…?


Best practice: keep sharded checkpoints and restore the optimizer state via FSDP’s APIs, preferably using Distributed Checkpoint (DCP) for load-time resharding. Avoid FULL_STATE_DICT except for export.

Checklist

  • Save under StateDictType.SHARDED_STATE_DICT and capture the optimizer with FSDP:

    # refs:
    # https://pytorch.org/docs/stable/distributed.checkpoint.html
    # https://pytorch.org/docs/stable/fsdp.html
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
    import torch.distributed.checkpoint as dcp

    FSDP.set_state_dict_type(model, StateDictType.SHARDED_STATE_DICT)
    to_save = {
        "model": model.state_dict(),                                 # sharded per rank
        "optim": FSDP.optim_state_dict(model, optimizer),            # mapped for FSDP
    }
    dcp.save(to_save, checkpoint_id=ckpt_dir)                        # DCP writes per-rank files
    

    DCP handles parallel IO and future resharding; docs last updated 2025-06-16. (docs.pytorch.org)

  • Load with DCP and remap the optimizer state for the current shards before calling optimizer.load_state_dict:

    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

    FSDP.set_state_dict_type(model, StateDictType.SHARDED_STATE_DICT)
    to_load = {"model": model.state_dict()}                          # sharded placeholder
    dcp.load(to_load, checkpoint_id=ckpt_dir)                        # reshard-safe
    model.load_state_dict(to_load["model"])
    optim_state = load_sharded_optimizer_state_dict(                 # rebuild optimizer shards
        model_state_dict=to_load["model"], optimizer_key="optim",
        storage_reader=FileSystemReader(ckpt_dir))
    optim_sd = FSDP.optim_state_dict_to_load(model, optimizer, optim_state["optim"])
    optimizer.load_state_dict(optim_sd)                              # now consistent
    

    These FSDP APIs exist exactly to prevent optimizer-state mismatches on resume. (docs.pytorch.org)

  • Accelerate path: keep the model sharded but make the optimizer state “full” via the FSDP plugin (optim_state_dict_config=FullOptimStateDictConfig). This simplifies restore while preserving sharded weights. (Hugging Face)

  • Keep invariants identical across save→resume: use_orig_params, auto-wrap policy, and module wrapping. Changing them changes parameter identities and breaks mapping. (Hugging Face)

  • Elasticity: if world size or topology may change, rely on DCP for load-time resharding. Don’t use plain sharded loads across different topologies. (docs.pytorch.org)

  • Verify after load: spot-check a few param groups for state['step'] and moment tensor shapes; mismatches indicate a bad mapping. FSDP docs define the mapping API expectations. (docs.pytorch.org) A small sanity-check sketch follows this list.

  • Known caveat: optim_state_dict_to_load can be memory-spiky with use_orig_params=True; consider CPU staging or use_orig_params=False if you hit OOM. (GitHub)
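
For the "Verify after load" item above, here is a minimal sanity-check sketch. It assumes an Adam-style optimizer whose per-parameter state contains "step" and "exp_avg" keys (adapt the keys for other optimizers); check_resumed_optimizer is an illustrative helper, not an FSDP or Accelerate API.

# Illustrative helper (not a library API): confirm the optimizer state was restored.
# Assumes Adam-style per-parameter state with "step" and "exp_avg"; adjust keys as needed.
def check_resumed_optimizer(optimizer):
    total, restored = 0, 0
    for group in optimizer.param_groups:
        for p in group["params"]:
            total += 1
            state = optimizer.state.get(p, {})
            if not state:
                continue                      # empty state => nothing was restored
            restored += 1
            assert state["exp_avg"].shape == p.shape, "moment/param shape mismatch"
    steps = sorted({int(s["step"]) for s in optimizer.state.values() if "step" in s})
    print(f"{restored}/{total} params carry optimizer state; step values: {steps}")

Run it on each rank right after optimizer.load_state_dict (or after accelerator.load_state); zero restored parameters or a missing step value is exactly the kind of dropped optimizer state that produces the resume spike.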

Why this is the best practice for your case

  • The HF thread shows loss spikes when resuming from sharded checkpoints but not from FULL_STATE_DICT. The cause is consistent with a mismapped or partial optimizer state. The remedy is the FSDP optimizer-state mapping or using DCP to reshard and restore correctly.

Minimal Accelerate example

# docs:
# https://huggingface.co/docs/accelerate/en/usage_guides/fsdp
# https://huggingface.co/docs/accelerate/en/package_reference/fsdp
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.fsdp import FullOptimStateDictConfig

plugin = FullyShardedDataParallelPlugin(
    state_dict_type=StateDictType.SHARDED_STATE_DICT,
    optim_state_dict_config=FullOptimStateDictConfig(),  # full optimizer, sharded model
    use_orig_params=True,                                # keep consistent across runs
)
accelerator = Accelerator(fsdp_plugin=plugin)
# prepare(...) as usual; then accelerator.save_state(...) / load_state(...)

This pattern keeps sharded model files fast to write and read, yet prevents the resume spike by restoring a consistent optimizer state. (Hugging Face)
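
For completeness, a minimal sketch of the surrounding save/resume flow (the checkpoint path and variable names are placeholders; save_state and load_state are the standard Accelerate entry points):

# Sketch only: wire the plugin into a normal Accelerate save/resume cycle.
model, optimizer, train_dl = accelerator.prepare(model, optimizer, train_dl)

# ... training loop ...
accelerator.save_state("ckpts/step_10000")    # sharded model shards + optimizer state + RNG

# on resume: rebuild and prepare() model/optimizer/dataloader identically, then
accelerator.load_state("ckpts/step_10000")    # restores what save_state wrote

The key point is that prepare() must wrap the model with the same FSDP settings before load_state is called; otherwise the shard mapping will not line up.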

Supplemental reading

  • PyTorch DCP: load-time resharding, per-rank files. Updated 2025-06-16. Useful when world size changes. (docs.pytorch.org)
  • FSDP docs: optim_state_dict and optim_state_dict_to_load APIs. Canonical reference for correct optimizer restore. (docs.pytorch.org)
  • HF Accelerate FSDP guide + API: plugin knobs, state-dict configs, FSDP utilities. Practical knobs for the recipe above. (Hugging Face)
  • Related issues: reports of resume failures or spikes when optimizer mapping is wrong or topology changes. Confirms the failure mode. (GitHub)