I want to load a state_dict exported by peft.get_peft_model_state_dict into a new model with the same LoRA config, but I find that the keys of the state_dict exported by get_peft_model_state_dict don't contain the LoRA adapter name. For example:
from transformers import AutoModel, BertModel
from peft import get_peft_model, get_peft_model_state_dict

model_name_or_path = "bert-base-cased"
bert: BertModel = AutoModel.from_pretrained(model_name_or_path)
lora_bert = get_peft_model(bert, lora_config)
state_dict = get_peft_model_state_dict(lora_bert, save_embedding_layers=False)
for key in state_dict.keys():
    print(key)
"""
fragment of output: the 'default' adapter name is missing from the keys
base_model.model.encoder.layer.0.attention.self.query.lora_A.weight
base_model.model.encoder.layer.0.attention.self.query.lora_B.weight
"""
model_name_or_path = "bert-base-cased"
bert: BertModel = AutoModel.from_pretrained(model_name_or_path)
lora_bert = get_peft_model(bert, lora_config)
state_dict = lora_bert.state_dict()
for key in state_dict.keys():
    print(key)
"""
fragment of output:
base_model.model.encoder.layer.0.attention.self.query.lora_A.default.weight
base_model.model.encoder.layer.0.attention.self.query.lora_B.default.weight
"""
I know I can manually manage this mapping to solve this problem, but it is a bit complicated. Is there an easier solution?
I use torch.nn.Module.load_state_dict() to load the state_dict into the new model, and I don't understand where the parameter remove_duplicate=False comes into play. Another question: what is the effect of the state_dict parameter of get_peft_model_state_dict? Thanks for your help!
The issue you're facing stems from a difference in how the state dicts are stored by get_peft_model_state_dict and the regular state_dict when dealing with LoRA layers. The keys are altered when the model is wrapped with LoRA, specifically by the addition of .default in the keys, which is why they don't match exactly when you're trying to reload them into the model.
One possible solution is to modify the get_peft_model_state_dict method or implement a utility function that ensures the keys are consistent when saving and loading the LoRA layers. Since manually mapping the keys seems complicated, here's an approach that could streamline the process:
1. **Post-process the State Dict**: After exporting the state dict using `get_peft_model_state_dict`, you can remove or replace the `.default` suffix in the keys so that it matches the original model's state dict format. This can be done by simply iterating over the keys and modifying them before loading the state dict into the model.
Here's a code example that does this:
import re

def fix_lora_state_dict(state_dict):
    fixed_state_dict = {}
    for key, value in state_dict.items():
        # Remove the '.default' suffix if present
        new_key = re.sub(r'\.default', '', key)
        fixed_state_dict[new_key] = value
    return fixed_state_dict
2. **Using `strict=False`**: In the `load_state_dict` method, setting `strict=False` ensures that any mismatched keys won't throw errors, allowing the model to load the parameters even if some keys are missing or slightly different. This can be useful if the LoRA keys have minor differences but are still compatible (see the short sketch after this list).
3. **Automation**: You can integrate this process into your workflow so that the keys are automatically fixed before loading the state dict, reducing the manual mapping overhead.
This approach should help streamline the process without needing to manually track each key. Let me know if this works for you!
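As a small illustration of point 2, `load_state_dict` with `strict=False` returns the mismatches instead of raising, so you can check whether the remapped keys actually landed. Here `remapped_state_dict` is just a placeholder for whatever post-processed dict you built above:

# Sketch only: strict=False reports mismatches rather than raising an error.
# 'remapped_state_dict' stands in for the post-processed dict from above.
result = lora_bert.load_state_dict(remapped_state_dict, strict=False)
print("missing keys:", result.missing_keys[:5])
print("unexpected keys:", result.unexpected_keys[:5])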
I'm not sure either, but if I summarize the conversation at the link, I think this code should be fine. Replace AutoModel... with the class you actually use.
from collections import OrderedDict

from peft import VBLoRAConfig, get_peft_model, get_peft_model_state_dict
from transformers import AutoModelForSequenceClassification

model_name_or_path = "roberta-large"
peft_config = VBLoRAConfig()  # it probably won't work as-is...
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True, max_length=None)
model = get_peft_model(model, peft_config)
params_dict = OrderedDict(
    (name, param.detach())
    for name, param in model.named_parameters(remove_duplicate=False)
    if "default" in name
)
get_peft_model_state_dict(model, params_dict, "default")
This is my solution for loading the state_dict into a new model:
state_dict = {
    (key.replace('.weight', '.default.weight') if 'lora' in key else key): value
    for key, value in state_dict.items()
}
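Putting the pieces together, here is a minimal end-to-end sketch of the round trip discussed in this thread (variable names are illustrative; the new model must be wrapped with the same lora_config):

# Sketch: export from the trained model, re-add the adapter name, load into a fresh model.
exported = get_peft_model_state_dict(lora_bert, save_embedding_layers=False)

new_bert = AutoModel.from_pretrained("bert-base-cased")
new_lora_bert = get_peft_model(new_bert, lora_config)

# Re-insert '.default' into the LoRA keys so they match new_lora_bert.state_dict().
remapped = {key.replace('.weight', '.default.weight') if 'lora' in key else key: value
            for key, value in exported.items()}
# strict=False because the exported dict only contains the LoRA parameters.
new_lora_bert.load_state_dict(remapped, strict=False)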
I hope get_peft_model_state_dict can offer a parameter to control its default behavior of stripping the LoRA adapter name in a future version. Thanks for your help.
Just to explain the situation, PEFT removes the adapter name from the keys because the adapter name is somewhat arbitrary. E.g. it is "default" if not indicated otherwise. Therefore, if I train 2 adapters in separate sessions and then save them, they would both have the name "default". However, when I later want to load them both at the same time, I would get a name clash. Therefore, the name is removed and then dynamically re-added when calling PeftModel.from_pretrained, load_adapter, etc.
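A hedged sketch of that behaviour (the adapter paths below are hypothetical): the name lives outside the checkpoint keys and is only supplied at load time, which is what lets two adapters that were both trained as "default" coexist:

from transformers import AutoModel
from peft import PeftModel

# Sketch only: both checkpoints were saved under the name "default", but they
# receive distinct names when re-added at load time, so their keys don't clash.
base = AutoModel.from_pretrained("bert-base-cased")
model = PeftModel.from_pretrained(base, "path/to/adapter_one", adapter_name="task_a")
model.load_adapter("path/to/adapter_two", adapter_name="task_b")
# The model's state_dict now contains ...lora_A.task_a.weight and ...lora_A.task_b.weight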
Regarding the linked issue #2302, note that this concerns VBLoRA, which is a different method than LoRA.
If you want to have an option to keep the name, please open an issue on the PEFT github page and explain your reasoning.
I see. Even if the specification were changed to save the name, the name would end up inside the tensor key itself, so handling it at load time would become more troublesome. If it were changed, there would probably be compatibility issues…
For example, you could store an empty tensor in the state_dict and encode the name in its key, but that's dirty.
As long as we users know about this behavior, we can work around it ourselves, so the current specification is simple enough.
I guess we should only save the state in state_dict…