I use the DPO trainer to train a model with FSDP via Accelerate.
After training is done, I want to merge the adapter into the base model.
I am using the code below, but it throws an error saying it cannot find the layer name.
```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from peft import PeftModel

with FSDP.summon_full_params(trainer.model):
    base_model = trainer.model.unload()
    print(base_model)  # I see it still has the wrapper layer name inside: (_fsdp_wrapped_module)
    peft_model = PeftModel.from_pretrained(base_model, adapter_path)
    merged_model = peft_model.merge_and_unload()
    merged_model.save_pretrained(output_path, safe_serialization=False)
    self.tokenizer.save_pretrained(output_path)
```
The error is:

```
[rank0]: AssertionError: FSDP assumes model.layers.0.self_attn.q_proj.base_layer.weight is in the state_dict but the state_dict only has odict_keys([....
```
It seems possible that assertions are being made where they should not be. Not so much a bug, apparently… See the related PyTorch issue below:
(PyTorch GitHub issue, opened 03:09 AM 07 Jun 2023 UTC, closed 06:58 PM 21 Jun 2023 UTC; labels: triaged, module: fsdp)
### 🐛 Describe the bug
I wrote a context manager that allows me to only save specific modules when calling `module.state_dict()`:
```python
with partial_state_only(model):
    sd = model.state_dict()
```
However, this seems to violate an FSDP assumption:
```python
Traceback (most recent call last):
  File "/home/carmocca/git/nightly-venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/carmocca/git/lit-stablelm/kk.py", line 63, in work
    sd = model.state_dict()
  File "/home/carmocca/git/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1835, in state_dict
    hook_result = hook(self, destination, prefix, local_metadata)
  File "/home/carmocca/git/nightly-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/carmocca/git/nightly-venv/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 661, in _post_state_dict_hook
    processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
  File "/home/carmocca/git/nightly-venv/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 325, in _full_post_state_dict_hook
    return _common_unshard_post_state_dict_hook(
  File "/home/carmocca/git/nightly-venv/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 209, in _common_unshard_post_state_dict_hook
    assert fqn in state_dict, (
AssertionError: FSDP assumes l2.weight is in the state_dict but the state_dict only has odict_keys(['l1.weight']). prefix=, module_name=l2., param_name=weight rank=0.
```
Repro code:
```python
import contextlib
import os
from functools import partial

import torch.cuda
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.api import FullOptimStateDictConfig, FullStateDictConfig, StateDictType


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(100, 50, bias=False)
        self.l2 = nn.Linear(50, 1, bias=False)


@contextlib.contextmanager
def partial_state_only(module: nn.Module):
    originals = {}

    def save(name, destination, prefix, keep_vars):
        if "l1" in prefix:
            original_fn = originals[name]
            return original_fn(destination, prefix, keep_vars)

    for name, submodule in module.named_modules():
        originals[name] = submodule._save_to_state_dict
        submodule._save_to_state_dict = partial(save, name)
    yield
    for name, module in module.named_modules():
        module._save_to_state_dict = originals[name]


def work(rank):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "1234"
    dist.init_process_group("nccl", world_size=1, rank=rank)
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)

    model = MyModel().to(device)
    model = FullyShardedDataParallel(model)

    state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
    state_dict_type_context = FullyShardedDataParallel.state_dict_type(
        module=model,
        state_dict_type=StateDictType.FULL_STATE_DICT,
        state_dict_config=state_dict_config,
        optim_state_dict_config=optim_state_dict_config,
    )

    with partial_state_only(model), state_dict_type_context:
        sd = model.state_dict()
    print(sd)


def run():
    mp.spawn(work, nprocs=1)


if __name__ == "__main__":
    run()
```
I'd appreciate advice on either how to make the context manager work with FSDP, or whether FSDP could be updated to support this sort of behavior.
Thank you!
### Versions
pytorch-triton 2.1.0+7d1a95b046
torch 2.1.0.dev20230505+cu118
cc @zhaojuanmao @mrshenli @rohan-varma @awgu
After I run `trainer.model.unload()`, the model still has the FSDP-wrapped layer names. The saved adapter state dict does not have these names because I save the full state dict.
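A quick way to see this (a sketch, assuming `trainer.model` is still FSDP-wrapped at this point) is to list the in-memory module names and confirm they carry the extra `_fsdp_wrapped_module` segment, which the keys of the saved full state dict do not contain:

```python
# Sketch: inspect the in-memory module names of the (still wrapped) model.
# Names containing "_fsdp_wrapped_module" will not match the clean key names
# in the full state dict that was saved to disk.
wrapped_names = [
    name for name, _ in trainer.model.named_modules() if "_fsdp_wrapped_module" in name
]
print(wrapped_names[:3])
```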
What I do now is reload the base model with the normal `from_pretrained` method, use `PeftModel.from_pretrained` to convert it to a PEFT model, and then call `merge_and_unload()`.
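For reference, a minimal sketch of that workaround (assuming a causal-LM checkpoint and placeholder paths `base_model_path`, `adapter_path`, and `output_path`), run as a plain single-process script after training:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Reload the original, unwrapped base model; the dtype here is just an example choice.
base_model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)

# Attach the adapter saved by the trainer, then fold its weights into the base weights.
peft_model = PeftModel.from_pretrained(base_model, adapter_path)
merged_model = peft_model.merge_and_unload()

# Save the merged model and tokenizer for later use.
merged_model.save_pretrained(output_path, safe_serialization=False)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
tokenizer.save_pretrained(output_path)
```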
Is it possible to merge the adapter directly on the FSDP-wrapped PEFT model instead?