Thank you so much for your detailed response!
@John6666
I was able to save the training state with accelerate without timeouts! But I am now running into issues when loading the state back to continue training.
This is what I am doing:
# agent is a custom class which uses a VLM model (i.e. agent.vlm is a HF model)
agent, optimizer = accelerator.prepare(agent, optimizer)
# debug: print the parameter names on the prepared model
for name, param in agent.named_parameters():
    print(name)
# module.vlm.model.base_model.model.model.language_model.layers.35.self_attn.o_proj.lora_A.lora_1.weight
# module.vlm.model.base_model.model.model.language_model.layers.35.self_attn.o_proj.lora_A.lora_2.weight
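For context, saving is done along these lines (simplified; checkpoint_path is a placeholder for the real output directory):
# Simplified save side: every rank calls save_state so DeepSpeed can
# write its per-rank shards (checkpoint_path is a placeholder)
accelerator.wait_for_everyone()
accelerator.save_state(checkpoint_path)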
But when I then call:
accelerator.load_state(checkpoint_path)
I get:
[2025-10-30 16:25:58,412] [INFO] [torch_checkpoint_engine.py:27:load] [Torch] Loading checkpoint from <model_path>/mp_rank_00_model_states.pt...
[2025-10-30 16:26:27,156] [INFO] [torch_checkpoint_engine.py:29:load] [Torch] Loaded checkpoint from <model_path>/mp_rank_00_model_states.pt.
[2025-10-30 16:26:28,022] [INFO] [torch_checkpoint_engine.py:29:load] [Torch] Loaded checkpoint from <model_path>/mp_rank_00_model_states.pt.
[2025-10-30 16:26:28,158] [INFO] [torch_checkpoint_engine.py:29:load] [Torch] Loaded checkpoint from <model_path>/mp_rank_00_model_states.pt.
[2025-10-30 16:26:28,324] [INFO] [torch_checkpoint_engine.py:27:load] [Torch] Loading checkpoint from <model_path>/mp_rank_00_model_states.pt...
[2025-10-30 16:26:28,977] [INFO] [torch_checkpoint_engine.py:27:load] [Torch] Loading checkpoint from <model_path>/mp_rank_00_model_states.pt...
[2025-10-30 16:26:29,302] [INFO] [torch_checkpoint_engine.py:27:load] [Torch] Loading checkpoint from <model_path>/mp_rank_00_model_states.pt...
[2025-10-30 16:26:29,344] [INFO] [torch_checkpoint_engine.py:29:load] [Torch] Loaded checkpoint from <model_path>/mp_rank_00_model_states.pt.
[2025-10-30 16:26:30,254] [INFO] [torch_checkpoint_engine.py:27:load] [Torch] Loading checkpoint from <model_path>/mp_rank_00_model_states.pt...
[2025-10-30 16:26:57,833] [INFO] [torch_checkpoint_engine.py:29:load] [Torch] Loaded checkpoint from <model_path>/mp_rank_00_model_states.pt.
[rank2]: Traceback (most recent call last):
[rank2]: File "script.py", line 143, in <module>
[rank2]: accelerator.load_state(args.checkpoint_dir)
[rank2]: File "python/lib/python3.10/site-packages/accelerate/accelerator.py", line 3089, in load_state
[rank2]: model.load_checkpoint(input_dir, ckpt_id, **load_model_func_kwargs)
[rank2]: File "python/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2806, in load_checkpoint
[rank2]: load_path, client_states = self._load_checkpoint(load_dir,
[rank2]: File "python/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2889, in _load_checkpoint
[rank2]: self.load_module_state_dict(checkpoint=checkpoint,
[rank2]: File "python/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2681, in load_module_state_dict
[rank2]: param.data.copy_(saved_frozen_params[name].data)
[rank2]: KeyError: 'vlm.model.base_model.model.model.visual.blocks.0.attn.qkv.lora_A.lora_2.weight'
[the identical traceback and KeyError are printed by rank0 and rank1 as well (and once more without a rank prefix), after which torch.distributed.elastic sends SIGTERM to the remaining worker processes]
agent is a custom class, but it still inherits from nn.Module, so I thought that should be fine.
It looks like both the model instance and the checkpoint contain the right parameters, just stored under different names.
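To narrow it down, I am planning to compare the names along these lines (a minimal sketch; <model_path> is the directory from the log above, and the assumption that the shard's dict-valued entries hold tensors keyed by parameter name may not hold for every DeepSpeed version):
import torch

# Load the model-states shard that DeepSpeed reports loading in the log above.
# weights_only=False because the shard stores more than raw tensors.
shard = torch.load(
    "<model_path>/mp_rank_00_model_states.pt",
    map_location="cpu",
    weights_only=False,
)
print(list(shard.keys()))

# Gather every name that maps to a tensor somewhere in the shard
saved_names = set()
for value in shard.values():
    if isinstance(value, dict):
        saved_names.update(
            k for k, v in value.items() if isinstance(v, torch.Tensor)
        )

# `agent` is the prepared model from the snippet above; its names carry a
# leading "module." (see the debug print), so strip it before comparing
model_names = {
    name.removeprefix("module.") for name, _ in agent.named_parameters()
}

print("only in model:     ", sorted(model_names - saved_names))
print("only in checkpoint:", sorted(saved_names - model_names))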