Proposal to Enhance `get_state_dict` and Introduce `load_from_state_dict` for Greater Flexibility

Dear Accelerate Team,

First of all, thank you for developing such a powerful and versatile library. Accelerate has been an incredible tool for simplifying distributed training and mixed precision setups. While working with the library, I’ve come across a potential improvement that could make it even more flexible and user-friendly, particularly for users who have custom serialization workflows.

Current Limitations

The `get_state_dict` function is a convenient utility for retrieving a model's state dictionary, but it only accepts models. This makes things harder for users who also want to retrieve and manage the state of other components, such as optimizers, schedulers, or RNG states, through the same abstraction.

For example:

  • To retrieve the state of an optimizer (e.g., one whose parameters are sharded by FSDP), users must call lower-level PyTorch APIs from `torch.distributed.fsdp` directly, such as `FullyShardedDataParallel.optim_state_dict`. This partly breaks the abstraction that Accelerate provides and is cumbersome for users unfamiliar with these internals.
  • `save_state` does offer a comprehensive solution, saving models, optimizers, schedulers, and RNG states together, but it assumes that users will adopt its specific approach to checkpointing. For users who already have custom serialization logic in place, integrating with Accelerate is less straightforward.

Proposed Enhancements

  1. Extend `get_state_dict` for Broader Use Cases:

    • Allow `get_state_dict` to accept not just models but also optimizers, schedulers, RNG states, or even custom objects.
    • This would let users retrieve the state of any component consistently, without interacting directly with lower-level APIs.
  2. Introduce `load_from_state_dict`:

    • Complement the above with a method such as `load_from_state_dict` that loads models, optimizers, schedulers, or RNG states directly from user-defined dictionaries (e.g., `OrderedDict`).
    • This would make it easier for users with custom checkpointing workflows to integrate their logic with Accelerate while still benefiting from its abstractions.
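To make the proposal more concrete, here is a rough sketch of the interface shape, written as standalone functions rather than `Accelerator` methods. The names `get_state_dict_any`, `load_from_state_dict`, `Stateful`, and `PYTHON_RNG` are hypothetical and do not exist in Accelerate today. The sketch relies only on the `state_dict()` / `load_state_dict()` protocol that PyTorch modules, optimizers, and LR schedulers already follow, plus Python's `random.getstate()` / `random.setstate()` for RNG state. It deliberately leaves out the part that would make the real feature valuable: gathering sharded FSDP or DeepSpeed state, and covering torch, CUDA, and NumPy RNGs.

```python
import random
from collections import OrderedDict
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class Stateful(Protocol):
    """Hypothetical: any object following the state_dict protocol used by
    PyTorch modules, optimizers, and LR schedulers."""

    def state_dict(self) -> Dict[str, Any]: ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> Any: ...


# Hypothetical sentinel meaning "the global Python RNG". A real implementation
# would also handle torch, CUDA, and NumPy generators.
PYTHON_RNG = object()


def get_state_dict_any(components: Dict[str, Any]) -> "OrderedDict[str, Any]":
    """Hypothetical: collect the state of each named component.

    NOTE: no unwrapping or gathering of FSDP/DeepSpeed shards happens here.
    """
    states: "OrderedDict[str, Any]" = OrderedDict()
    for name, obj in components.items():
        if obj is PYTHON_RNG:
            states[name] = random.getstate()
        elif isinstance(obj, Stateful):
            states[name] = obj.state_dict()
        else:
            raise TypeError(f"cannot get state of {name!r} ({type(obj).__name__})")
    return states


def load_from_state_dict(components: Dict[str, Any], states: Dict[str, Any]) -> None:
    """Hypothetical counterpart: restore each named component from `states`."""
    for name, obj in components.items():
        if name not in states:
            raise KeyError(f"no saved state for component {name!r}")
        if obj is PYTHON_RNG:
            random.setstate(states[name])
        elif isinstance(obj, Stateful):
            obj.load_state_dict(states[name])
        else:
            raise TypeError(f"cannot load state into {name!r} ({type(obj).__name__})")
```

Because both helpers take the same name-to-object mapping, the returned `OrderedDict` can be handed to whatever serialization logic a user already has and later passed back to `load_from_state_dict`, which is exactly the workflow `save_state` does not accommodate today.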

Benefits of These Changes

  • Improved Flexibility: Users with custom serialization workflows can integrate their logic with Accelerate more easily, without being forced into specific patterns like `save_state`.
  • Enhanced Abstraction: With an extended `get_state_dict` and a new `load_from_state_dict`, users can handle models, optimizers, schedulers, and RNG states in a consistent, abstracted manner.

Closing Thoughts

I believe these enhancements would make Accelerate even more flexible and accessible for a wider range of users. I’d love to hear your thoughts on this proposal: whether it aligns with the library’s design philosophy, and whether there are any plans to address these challenges in future updates.

Thank you again for all your hard work on this amazing library! I look forward to seeing how Accelerate continues to grow and improve.
