Loading PEFT model from checkpoint leads to size mismatch

I am fine-tuning a model using PEFT adapters and want to check whether a checkpoint I have saved is good enough. However, I added an extra token to the vocabulary before fine-tuning, which results in a different embedding size.

I try to load the model like this:

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_dir,
    quantization_config=bnb_config,
    device_map=device_map,
)

And I run into the following error:

Loading checkpoint shards: 100% 3/3 [01:10<00:00, 23.21s/it]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-ceef6d0bb055> in <cell line: 2>()
      1 # Load base model
----> 2 model = AutoModelForCausalLM.from_pretrained(
      3    checkpoint_dir,
      4     quantization_config=bnb_config,
      5     device_map=device_map,

/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    565         elif type(config) in cls._model_mapping.keys():
    566             model_class = _get_model_class(config, cls._model_mapping)
--> 567             return model_class.from_pretrained(
    568                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    569             )

/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3643 
   3644         if _adapter_model_path is not None:
-> 3645             model.load_adapter(
   3646                 _adapter_model_path,
   3647                 adapter_name=adapter_name,

/usr/local/lib/python3.10/dist-packages/transformers/integrations/peft.py in load_adapter(self, peft_model_id, adapter_name, revision, token, device_map, max_memory, offload_folder, offload_index, peft_config, adapter_state_dict, adapter_kwargs)
    204 
    205         # Load state dict
--> 206         incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name)
    207 
    208         if incompatible_keys is not None:

/usr/local/lib/python3.10/dist-packages/peft/utils/save_and_load.py in set_peft_model_state_dict(model, peft_model_state_dict, adapter_name)
    239         raise NotImplementedError
    240 
--> 241     load_result = model.load_state_dict(peft_model_state_dict, strict=False)
    242     if config.is_prompt_learning:
    243         model.prompt_encoder[adapter_name].embedding.load_state_dict(

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
   2150 
   2151         if len(error_msgs) > 0:
-> 2152             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2153                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2154         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for MistralForCausalLM:
	size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([32001, 4096]) from checkpoint, the shape in current model is torch.Size([32000, 4096]).
	size mismatch for lm_head.weight: copying a param with shape torch.Size([32001, 4096]) from checkpoint, the shape in current model is torch.Size([32000, 4096]).

I have tested this code with the final fine-tuned model (not a midway checkpoint) and it worked fine. However, when I try to load a checkpoint, I run into this size-mismatch error. Any suggestions? (I tried adding the ignore_mismatched_sizes parameter, but it no longer seems to work.)
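A likely reason a mismatch-ignoring flag does not help here: the traceback shows the adapter weights being loaded through load_state_dict(peft_model_state_dict, strict=False), and in PyTorch strict=False only tolerates missing or unexpected keys, not tensors whose shapes differ. The following sketch reproduces that with a plain torch.nn.Embedding stand-in (hidden size 8 instead of 4096, purely for illustration) and adds a hypothetical helper, find_shape_mismatches, that lists conflicting shapes before you attempt a load.

```python
import torch
from torch import nn


def find_shape_mismatches(module: nn.Module, state_dict: dict) -> dict:
    """Hypothetical helper: {key: (checkpoint_shape, model_shape)} for shared keys whose shapes differ."""
    current = module.state_dict()
    return {
        key: (tuple(tensor.shape), tuple(current[key].shape))
        for key, tensor in state_dict.items()
        if key in current and tensor.shape != current[key].shape
    }


# Stand-ins: the base model has 32000 rows, the checkpoint was saved with 32001.
model = nn.Embedding(32000, 8)
checkpoint = {"weight": torch.zeros(32001, 8)}

mismatches = find_shape_mismatches(model, checkpoint)
print(mismatches)  # {'weight': ((32001, 8), (32000, 8))}

# strict=False skips missing/unexpected keys, but a size mismatch still raises.
try:
    model.load_state_dict(checkpoint, strict=False)
    load_error = None
except RuntimeError as err:
    load_error = str(err)
print(load_error.splitlines()[0])  # Error(s) in loading state_dict for Embedding:
```

Even if the mismatch could be skipped, doing so would throw away the trained embedding row for the added token, so growing the base model to match the checkpoint is the safer direction.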

Hi,

If you have added tokens, it’s recommended to use the PeftModel class rather than AutoModelForCausalLM, as the former takes the resized embedding matrix into account.

cc @ybelkada


I loaded the base model with AutoModelForCausalLM like this:

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map
)

Then I tried to load the PEFT model (a midway checkpoint, trained with PeftModel, with a resized embedding matrix) like this:

# Load PEFT model
model = PeftModel.from_pretrained(
    model=model,
    model_id=checkpoint_dir,
    peft_config=bnb_config,
    device_map=device_map,
)

And I still run into this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-31-b33168ee43cd> in <cell line: 2>()
      1 # Load PEFT model
----> 2 model = PeftModel.from_pretrained(
      3     model=model,
      4     model_id =checkpoint_dir,
      5     peft_config=bnb_config,

/usr/local/lib/python3.10/dist-packages/peft/peft_model.py in from_pretrained(cls, model, model_id, adapter_name, is_trainable, config, **kwargs)
    352         else:
    353             model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config, adapter_name)
--> 354         model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
    355         return model
    356 

/usr/local/lib/python3.10/dist-packages/peft/peft_model.py in load_adapter(self, model_id, adapter_name, is_trainable, **kwargs)
    696 
    697         # load the weights into the model
--> 698         load_result = set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name)
    699         if (
    700             (getattr(self, "hf_device_map", None) is not None)

/usr/local/lib/python3.10/dist-packages/peft/utils/save_and_load.py in set_peft_model_state_dict(model, peft_model_state_dict, adapter_name)
    239         raise NotImplementedError
    240 
--> 241     load_result = model.load_state_dict(peft_model_state_dict, strict=False)
    242     if config.is_prompt_learning:
    243         model.prompt_encoder[adapter_name].embedding.load_state_dict(

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
   2150 
   2151         if len(error_msgs) > 0:
-> 2152             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2153                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2154         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
	size mismatch for base_model.model.model.embed_tokens.weight: copying a param with shape torch.Size([32001, 4096]) from checkpoint, the shape in current model is torch.Size([32000, 4096]).
	size mismatch for base_model.model.lm_head.weight: copying a param with shape torch.Size([32001, 4096]) from checkpoint, the shape in current model is torch.Size([32000, 4096]).

When I use the final fine-tuned model instead of a midway PEFT checkpoint, I can load it (with the resized embedding matrix) without a problem, even using AutoModelForCausalLM:

# Load final fine-tuned model
model = AutoModelForCausalLM.from_pretrained(
    final_checkpoint,
    quantization_config=bnb_config,
    device_map=device_map
)

I was able to make it work by resizing the original model's embeddings before loading the PEFT model. Thank you! However, I am still unsure why the final model loads out of the box, while a midway checkpoint requires PeftModel. Thanks!

Hi,

A reply from @smangrul:

Were you resizing the embedding layers before? In the merge-and-unload code too, you need to resize the embedding layers before loading the PEFT adapters. Whenever you resize the embedding layers, they get saved along with the adapters, because the new tokens were initialised randomly and the adapters are tuned with respect to them.

Hi,

I have always resized the embeddings, which is why I was wondering why AutoModelForCausalLM works with adapters saved with trainer.model.save_pretrained(new_model) but not with a midway training checkpoint of the adapters. If you happen to know why, please share!
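One way to start investigating that difference is to compare the two directories side by side. The first traceback shows AutoModelForCausalLM.from_pretrained switching to load_adapter once it detects an adapter in the directory, and PEFT adapter directories contain an adapter_config.json whose base_model_name_or_path field records the base model. A minimal standard-library sketch; describe_checkpoint, the placeholder file contents, and the base model name are all hypothetical.

```python
import json
import os
import tempfile


def describe_checkpoint(path):
    """Hypothetical helper: summarise which files a saved directory contains."""
    files = sorted(os.listdir(path))
    info = {"files": files, "is_adapter": "adapter_config.json" in files, "base_model": None}
    if info["is_adapter"]:
        with open(os.path.join(path, "adapter_config.json")) as fh:
            info["base_model"] = json.load(fh).get("base_model_name_or_path")
    return info


# Usage with a fake adapter directory (file contents are placeholders).
with tempfile.TemporaryDirectory() as ckpt:
    with open(os.path.join(ckpt, "adapter_config.json"), "w") as fh:
        json.dump({"base_model_name_or_path": "hypothetical/base-model"}, fh)
    open(os.path.join(ckpt, "adapter_model.safetensors"), "wb").close()
    info = describe_checkpoint(ckpt)

print(info)
# {'files': ['adapter_config.json', 'adapter_model.safetensors'], 'is_adapter': True, 'base_model': 'hypothetical/base-model'}
```

Running it on the directory written by trainer.model.save_pretrained(new_model) and on a midway checkpoint directory shows whether they differ in which files are present or which base model they point to.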

Anyway, I was able to fix my issue, which I am grateful for!
