Further finetuning a LoRA finetuned CausalLM Model

Hi,

I think there are several possibilities:

  • adding an adapter that you train on task 1, merging it into the base model, then adding another adapter that you train on task 2.
  • adding a single adapter that you train on one dataset, then continue training on another dataset, and so on (so only 1 adapter), which you could eventually merge into the base model.
  • adding and training several adapters (one per task/dataset) separately, then combining them simultaneously (see the sketch after this list).
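
For the third option, PEFT lets you load several adapters onto one model and combine them via add_weighted_adapter. Below is a minimal sketch, assuming two hypothetical adapter repos (user/opt-350m-lora-task1 and user/opt-350m-lora-task2) that were trained on the same base model with the same LoRA rank:

from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# load the first adapter under an explicit name
model = PeftModel.from_pretrained(base_model, "user/opt-350m-lora-task1", adapter_name="task1")
# load the second adapter alongside it
model.load_adapter("user/opt-350m-lora-task2", adapter_name="task2")

# combine both adapters into a single new adapter with equal weights
model.add_weighted_adapter(
    adapters=["task1", "task2"],
    weights=[0.5, 0.5],
    adapter_name="task1_and_task2",
    combination_type="linear",
)
model.set_adapter("task1_and_task2")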

For the first case, if you fine-tuned an xxxForCausalLM model using PEFT (i.e., by adding an adapter), you can load the model together with its adapter using the AutoPeftModelForCausalLM class. Note that the adapter weights are still kept separate from the base model weights. You can merge the adapter weights into the base model by calling the merge_and_unload method, then add another adapter and apply PEFT again.

Let’s show this in code.

Step 1: load base model + adapter weights

First of all, note that there are 2 ways to load a base model with its adapter weights:

from peft import PeftModel, PeftConfig, AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM

# let's say you fine-tuned OPT using PEFT

# method 1: separately
base_model_id = "facebook/opt-350m"
adapter_id = "ybelkada/opt-350m-lora"
base_model = AutoModelForCausalLM.from_pretrained(base_model_id)
base_with_adapters_model = PeftModel.from_pretrained(base_model, adapter_id)

# method 2: conveniently with the AutoPeftModelForCausalLM class
base_with_adapters_model = AutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora")

Note that in both cases, the adapter weights are still stored separately from the base model (you can see this in the state dict, which contains separate lora_A and lora_B keys for each adapted layer).
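
For example, you can list the adapter-specific keys to verify this:

# the LoRA update matrices show up as separate lora_A / lora_B entries
lora_keys = [k for k in base_with_adapters_model.state_dict() if "lora" in k]
print(lora_keys[:5])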

Step 2: merge adapter weights into the base model

You can therefore merge the adapter parameters into the base model:

# now we just have a regular AutoModelForCausalLM Transformers model
model = base_with_adapters_model.merge_and_unload()
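
If you want, you can also save the merged model at this point, so it can be reloaded later as a regular Transformers model ("opt-350m-merged" is just a hypothetical local path):

# persist the merged weights; they load back with AutoModelForCausalLM.from_pretrained
model.save_pretrained("opt-350m-merged")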

Step 3: add another adapter

Next, we could apply PEFT again by adding another adapter:

from peft import get_peft_model, LoraConfig, TaskType

lora_config = LoraConfig(
    r=16,                                 # rank of the LoRA update matrices
    target_modules=["q_proj", "v_proj"],  # attention projections in OPT
    task_type=TaskType.CAUSAL_LM,
    lora_alpha=32,                        # scaling factor for the LoRA updates
    lora_dropout=0.05,
)

base_model_with_new_adapter = get_peft_model(model, lora_config)
base_model_with_new_adapter.print_trainable_parameters()

You can then fine-tune base_model_with_new_adapter using the Trainer API or a regular PyTorch training loop.
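
As a minimal sketch with the Trainer API (assuming you already have a tokenized train_dataset with input_ids, attention_mask and labels columns; the output directory is hypothetical):

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="opt-350m-task2",  # hypothetical output directory
    per_device_train_batch_size=4,
    learning_rate=2e-4,
    num_train_epochs=1,
    logging_steps=10,
)

trainer = Trainer(
    model=base_model_with_new_adapter,
    args=training_args,
    train_dataset=train_dataset,  # hypothetical tokenized dataset
)
trainer.train()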

See this guide for more info: PEFT integrations.
