Dear Accelerate Team,
First of all, thank you for developing such a powerful and versatile library. Accelerate has been an incredible tool for simplifying distributed training and mixed precision setups. While working with the library, I’ve come across a potential improvement that could make it even more flexible and user-friendly, particularly for users who have custom serialization workflows.
Current Limitations
The get_state_dict function is a great utility for retrieving the state dictionary of a model, but it is currently limited to models. This creates challenges for users who also want to retrieve and manage the state of other components, such as optimizers, schedulers, or RNG states, in a similarly abstracted way.
For example:
- To retrieve the state of an optimizer (e.g., an FSDP optimizer), users must manually interact with lower-level PyTorch APIs like torch.distributed.fsdp. This somewhat breaks the abstraction that Accelerate provides and can be cumbersome for users unfamiliar with these internals.
- While save_state does provide a comprehensive solution for saving models, optimizers, schedulers, and RNG states together, it assumes that users will adopt its specific approach to checkpointing. For those who already have custom serialization logic in place, integrating with Accelerate becomes less straightforward.
Proposed Enhancements
- Extend get_state_dict for Broader Use Cases:
  - Allow get_state_dict to support not just models but also optimizers, schedulers, RNG states, or even custom objects.
  - This would enable users to retrieve the state of any component in a consistent manner without needing to directly interact with lower-level APIs.
- Introduce load_from_state_dict:
  - Complement the above by adding a method like load_from_state_dict, which allows users to load models, optimizers, schedulers, or RNG states directly from user-defined dictionaries (e.g., OrderedDict).
  - This would make it easier for users with custom checkpointing workflows to integrate their logic with Accelerate while still benefiting from its abstractions.
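To make the proposal concrete, here is a minimal sketch of the kind of uniform get/load pair I have in mind. It is not Accelerate code: get_component_state, load_component_state, and ToyOptimizer are hypothetical names invented for illustration, and the sketch uses only the Python standard library. It assumes a component either exposes state_dict()/load_state_dict() (as PyTorch modules, optimizers, and schedulers do) or is the Python random module, and it ignores sharded or distributed state entirely.

```python
import copy
import random
from collections import OrderedDict


class ToyOptimizer:
    """Hypothetical stand-in for any object with state_dict()/load_state_dict()."""

    def __init__(self, lr):
        self.lr = lr
        self.step_count = 0

    def state_dict(self):
        return OrderedDict(lr=self.lr, step_count=self.step_count)

    def load_state_dict(self, state):
        self.lr = state["lr"]
        self.step_count = state["step_count"]


def get_component_state(component):
    """Hypothetical uniform getter: returns a plain dictionary for any supported component."""
    if component is random:
        # Python's RNG has no state_dict(), so wrap its state in a dictionary.
        return {"python_rng": random.getstate()}
    if hasattr(component, "state_dict"):
        # Copy so later changes to the component do not alter the snapshot.
        return copy.deepcopy(component.state_dict())
    raise TypeError(f"Unsupported component: {type(component).__name__}")


def load_component_state(component, state):
    """Hypothetical uniform loader: restores a component from a user-supplied dictionary."""
    if component is random:
        random.setstate(state["python_rng"])
    elif hasattr(component, "load_state_dict"):
        component.load_state_dict(state)
    else:
        raise TypeError(f"Unsupported component: {type(component).__name__}")


# Snapshot an optimizer-like object and the RNG into a user-owned checkpoint.
optimizer = ToyOptimizer(lr=0.1)
optimizer.step_count = 5
checkpoint = {
    "optimizer": get_component_state(optimizer),
    "rng": get_component_state(random),
}
expected_draw = random.random()  # advances the RNG after the snapshot

# Restore into a fresh object using the same two calls for every component.
restored = ToyOptimizer(lr=0.0)
load_component_state(restored, checkpoint["optimizer"])
load_component_state(random, checkpoint["rng"])
replayed_draw = random.random()  # same value as expected_draw after restoring
```

The point of the sketch is that the checkpoint dictionary stays under the user's control, so it can be written with whatever serialization logic is already in place. A real implementation would additionally have to handle the distributed cases (such as FSDP optimizer state) that motivated this request.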
Benefits of These Changes
- Improved Flexibility: Users with custom serialization workflows can more easily integrate their logic with Accelerate without being forced into specific patterns like save_state.
- Enhanced Abstraction: By extending get_state_dict and introducing load_from_state_dict, users can interact with models, optimizers, schedulers, and RNG states in a consistent and abstracted manner.
Closing Thoughts
I believe these enhancements would make Accelerate even more flexible and accessible for a wider range of users. I’d love to hear your thoughts on this proposal—whether it aligns with the library’s design philosophy and if there are any plans to address these challenges in future updates.
Thank you again for all your hard work on this amazing library! I look forward to seeing how Accelerate continues to grow and improve.