Hi Accelerate team — first off, thank you for the phenomenal work on the library and its FSDP integration.
I’m running large-scale multi-node jobs and have hit what looks like a systematic resume-instability whenever I save checkpoints with fsdp_state_dict_type: SHARDED_STATE_DICT.
In short:
Training is perfectly stable until I interrupt it and resume.
The very first step after --resume_from_checkpoint shows a huge loss spike (see first plot at step ≈ 10 k).
If I change only the state-dict format to FULL_STATE_DICT, the same run resumes smoothly with no spike (second plot) — but the wall-clock time and I/O overhead of saving full checkpoints are prohibitive at this scale.
Because the model, data, and training hyper-parameters are identical between the two runs, the culprit seems to be how the sharded checkpoint restores the optimizer state and/or parameter dtypes (I train in bf16, yet the saved shards are fp32).
Below I’ve included system details, a minimal reproducible script, and the exact configs that trigger the problem. I’d be grateful for any insights or work-arounds that would let me keep the sharded format while resuming safely.
Expected behavior: Resumed training should continue smoothly at the previous loss level, identical to a FULL_STATE_DICT checkpoint, but without the high I/O overhead.
This plot shows the loss when saving with FULL_STATE_DICT and resuming via resume_from_checkpoint: the curve continues smoothly across the resume point, but saving full checkpoints takes far too long at this scale.
Questions
Is this a known limitation or bug with FSDP SHARDED_STATE_DICT in accelerate?
Could the optimizer state be dropped or mismatched while saving/loading shards?
Is there a recommended workaround (e.g., enabling fsdp_cpu_ram_efficient_loading or changing another flag) that lets me keep sharded checkpoints yet resume stably?
Why are the weights serialized in fp32 even though mixed_precision=bf16 was used during training?
Any pointers would be greatly appreciated—thanks for your time and for all the work on Accelerate!
Best practice: keep sharded checkpoints and restore the optimizer state via FSDP’s APIs, preferably using Distributed Checkpoint (DCP) for load-time resharding. Avoid FULL_STATE_DICT except for export.
Checklist
Save under StateDictType.SHARDED_STATE_DICT and capture the optimizer with FSDP:
# refs:
# https://pytorch.org/docs/stable/distributed.checkpoint.html
# https://pytorch.org/docs/stable/fsdp.html
import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# Produce sharded state dicts for both the model and the optimizer.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    to_save = {
        "model": model.state_dict(),                       # sharded model weights
        "optim": FSDP.optim_state_dict(model, optimizer),  # optimizer state mapped for FSDP
    }
dcp.save(to_save, checkpoint_id=ckpt_dir)  # DCP writes one file per rank into ckpt_dir
DCP handles parallel I/O and load-time resharding. (docs.pytorch.org)
Load with DCP and remap the optimizer state for the current shards before calling optimizer.load_state_dict:
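A minimal sketch of the load side, assuming the same model, optimizer, and ckpt_dir as in the save snippet above (the optimizer entry was saved under the key "optim" there; the argument order of optim_state_dict_to_load follows current PyTorch docs and changed around 2.1):
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

reader = dcp.FileSystemReader(ckpt_dir)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    # Model: build a template with the current sharding, let DCP fill it, then load it.
    state = {"model": model.state_dict()}
    dcp.load(state, storage_reader=reader)
    model.load_state_dict(state["model"])
    # Optimizer: read the saved "optim" entry resharded for the current layout...
    optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state["model"], optimizer_key="optim", storage_reader=reader
    )
    # ...then remap it onto the current FSDP parameters before loading.
    flattened_osd = FSDP.optim_state_dict_to_load(model, optimizer, optim_state["optim"])
    optimizer.load_state_dict(flattened_osd)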
These FSDP APIs exist exactly to prevent optimizer-state mismatches on resume. (docs.pytorch.org)
Accelerate path: keep both the model and optimizer state dicts sharded via the FSDP plugin (state_dict_type=SHARDED_STATE_DICT); accelerator.save_state / load_state then route the optimizer through FSDP.optim_state_dict / optim_state_dict_to_load for you. Note that PyTorch requires the optimizer state-dict config type to match the state-dict type, so SHARDED_STATE_DICT pairs with ShardedOptimStateDictConfig, not FullOptimStateDictConfig. (Hugging Face)
Keep invariants identical across save→resume: use_orig_params, auto-wrap policy, and module wrapping. Changing them changes parameter identities and breaks mapping. (Hugging Face)
Elasticity: if world size or topology may change, rely on DCP for load-time resharding. Don’t use plain sharded loads across different topologies. (docs.pytorch.org)
Verify after load: spot-check a few param groups for state['step'] and moment tensor shapes; mismatches indicate a bad mapping. FSDP docs define the mapping API expectations. (docs.pytorch.org)
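A rough spot check, sketched for an Adam-style optimizer with use_orig_params=True after optimizer.load_state_dict has run (the "exp_avg" key is optimizer-specific and illustrative):
# Inspect the first param group and a few of its parameters (illustration only).
for group in optimizer.param_groups[:1]:
    for p in group["params"][:3]:
        state = optimizer.state.get(p, {})
        assert "step" in state, "optimizer step counter missing after resume"
        assert state["exp_avg"].shape == p.shape, "moment shape does not match its parameter"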
Known caveat: optim_state_dict_to_load can be memory-spiky with use_orig_params=True; consider CPU staging or use_orig_params=False if you hit OOM. (GitHub)
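One way to stage the materialized state on CPU, sketched with FSDP's state-dict configs (whether this is enough to avoid the reported OOMs depends on your model size and host RAM):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType

# Keep state dicts on CPU instead of GPU while they are materialized for save/load.
FSDP.set_state_dict_type(
    model,
    StateDictType.SHARDED_STATE_DICT,
    state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
    optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
)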
Why this is the best practice for your case
The report above shows loss spikes when resuming from sharded checkpoints but not from FULL_STATE_DICT. That pattern is consistent with a mismapped or partial optimizer state on restore; the remedy is FSDP's optimizer-state mapping APIs, or DCP to reshard and restore the state correctly.
Minimal Accelerate example
# docs:
# https://huggingface.co/docs/accelerate/en/usage_guides/fsdp
# https://huggingface.co/docs/accelerate/en/package_reference/fsdp
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import StateDictType

plugin = FullyShardedDataParallelPlugin(
    # Sharded model and optimizer shards; PyTorch requires the optimizer state-dict
    # config type to match the state-dict type, so no FullOptimStateDictConfig here.
    state_dict_type=StateDictType.SHARDED_STATE_DICT,
    use_orig_params=True,  # keep identical across the save run and the resume run
)
accelerator = Accelerator(fsdp_plugin=plugin)
# prepare(...) as usual; then accelerator.save_state(...) / load_state(...)
This pattern keeps the sharded checkpoint files fast to write and read, and prevents the resume spike because accelerator.save_state / load_state restore the optimizer state through FSDP's mapping APIs. (Hugging Face)
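A usage sketch of the save/resume loop around that accelerator (resume_dir, save_every, and the checkpoint paths are illustrative names, not Accelerate arguments):
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
if resume_dir is not None:                          # e.g. the value of --resume_from_checkpoint
    accelerator.load_state(resume_dir)              # restores sharded weights and optimizer state
for step, batch in enumerate(dataloader):
    ...                                             # forward / backward / optimizer.step()
    if step % save_every == 0:
        accelerator.save_state(f"ckpts/step_{step}")  # writes per-rank sharded files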
Supplemental reading
PyTorch DCP docs: load-time resharding and per-rank files; useful when the world size changes. (docs.pytorch.org)
FSDP docs: optim_state_dict and optim_state_dict_to_load APIs. Canonical reference for correct optimizer restore. (docs.pytorch.org)
HF Accelerate FSDP guide + API: plugin knobs, state-dict configs, FSDP utilities. Practical knobs for the recipe above. (Hugging Face)
Related issues: reports of resume failures or spikes when optimizer mapping is wrong or topology changes. Confirms the failure mode. (GitHub)