Fine-tune a fine-tuned model

Hello, I fine-tuned the Llama-2-7b model on a dataset of 15K samples, using SFTTrainer and peft (LoRA). I saved this model and pushed it to the Hugging Face Hub.
I now want to fine-tune this already fine-tuned model of mine (and not the original pretrained Llama) on a small number (~20) of new samples.
When I load the model from the Hub and run inference on the old samples (the 15K dataset), it succeeds on them as before. But when I initialize SFTTrainer with this fine-tuned model and train, it “forgets” everything it learned and no longer succeeds on the old task.
I want to be able to initialize SFTTrainer & peft with the fine-tuned model, train it on the 20 new samples, and have it still remember what it learned from the 15K dataset.
I know one solution is to just train the pretrained model once on a dataset that includes both the new samples and the old dataset, and I did try it, but this solution does not work well for my problem, as the 20 new samples “get lost” in all the data and the model does not learn them well.

What could be the reason that the trainer is erasing what the model previously learned? How can I prevent this behavior?

Attaching part of my training code below, where model_name is the location of the model on the Hub.

import torch
from datetime import datetime

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig
from trl import SFTTrainer

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

device_map = {"": 0}

base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,
    trust_remote_code=True,
    token=hf_access_token,
    cache_dir="Llama-2/Cache2/",
)
base_model.config.use_cache = False

base_model.config.pretraining_tp = 1

peft_config = LoraConfig(
    lora_alpha=16 if not lora_args else lora_args['alpha'],
    lora_dropout=0.1,
    r=64 if not lora_args else lora_args['rank'],
    bias="none",
    task_type="CAUSAL_LM",
)

tokenizer = AutoTokenizer.from_pretrained(model_name,
                                          trust_remote_code=True,
                                          token=hf_access_token,
                                          cache_dir="Llama-2/Cache/")
tokenizer.pad_token = tokenizer.eos_token

if not output_dir:
    output_dir = f"../Models/{model_name.split('/')[1]}-{train_dataset_name}-" \
                 f"{datetime.now().date()}"

if not training_args:
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=train_args['train_batch_size'],
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        logging_steps=10,
        max_steps=train_args['max_steps'],
        report_to='none',
        num_train_epochs=train_args['num_of_epochs'],
        gradient_checkpointing=True,
    )

# Change padding side to right for training
tokenizer.padding_side = "right"
assert tokenizer.padding_side == "right"

train_dataset = dataset['train'] if data_dict is None else data_dict['train']
print('$$$ Training on dataset with size:', len(train_dataset), '$$$')
print(f'&&& max seq length = {train_args["max_seq_length"]} &&&')

trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    peft_config=peft_config,
    dataset_text_field="instruction",
    max_seq_length=train_args['max_seq_length'],
    tokenizer=tokenizer,
    args=training_args,
)

trainer.train()

Thank you!

Why the Model “Forgets” Previous Knowledge

  1. Overwriting Weights:
  • When fine-tuning on new data, the optimizer updates the model’s weights primarily to fit the new dataset. Since this dataset is small (20 samples), the updates can disproportionately affect the model’s behavior, overwriting previous learning.
  2. Quantization Effects:
  • Using 4-bit quantization (via bnb_config) can exacerbate catastrophic forgetting, as the precision of the weight updates is reduced.
  3. LoRA-Specific Configuration:
  • LoRA (Low-Rank Adaptation) introduces low-rank matrices for fine-tuning. However, improperly managing LoRA adapters across multiple fine-tuning runs can lead to conflicts or loss of previously learned adaptations.

Solutions

1. Merge LoRA Layers Before Re-Fine-Tuning

Before you fine-tune on the 20 new samples, merge the existing LoRA layers into the base model weights. This will consolidate the knowledge learned from the 15K dataset:

from peft import PeftModel

# Load the existing LoRA adapter on top of the base model and merge it
# into the base weights. Merging works best on a non-quantized load of
# the base model; older peft versions cannot merge into 4-bit weights.
model = PeftModel.from_pretrained(base_model, model_name)
base_model = model.merge_and_unload()

After merging, reinitialize LoRA for the new fine-tuning task.
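
A minimal sketch of the re-fine-tuning step after the merge, reusing the imports, tokenizer, and training_args from the question code; new_train_dataset (the ~20 new samples), the max_seq_length value, and the LoraConfig values are illustrative assumptions, not prescriptions:

new_peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model=base_model,                  # the merged model from the previous step
    train_dataset=new_train_dataset,   # hypothetical: the ~20 new samples
    peft_config=new_peft_config,       # fresh adapter for the new task
    dataset_text_field="instruction",
    max_seq_length=512,                # assumption; match your original setting
    tokenizer=tokenizer,
    args=training_args,
)
trainer.train()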


2. Use Continual Fine-Tuning with LoRA

Ensure that you load the existing LoRA adapters correctly before fine-tuning further. This approach retains prior knowledge in the adapter layers:

from peft import PeftModel

# Load the existing LoRA adapter on top of the base model.
# is_trainable=True keeps the adapter weights trainable so they can be
# updated further, instead of being frozen for inference.
base_model = PeftModel.from_pretrained(
    base_model,
    model_name,
    is_trainable=True,
)

You can then fine-tune the model on the new samples while keeping prior adaptations intact.
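
A minimal sketch of continuing training with the loaded adapter, again reusing the tokenizer and training_args from the question code; new_train_dataset is a hypothetical name for the ~20 new samples. peft_config is intentionally not passed to SFTTrainer here, so the adapter that was just loaded is the one being trained rather than a freshly initialized one:

trainer = SFTTrainer(
    model=base_model,                  # PeftModel with the existing adapter
    train_dataset=new_train_dataset,   # hypothetical: the ~20 new samples
    dataset_text_field="instruction",
    max_seq_length=512,                # assumption; match your original setting
    tokenizer=tokenizer,
    args=training_args,
)
trainer.train()

Combining this with a lower learning rate (see Solution 3) keeps the updates to the existing adapter small.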


3. Apply Regularization to Preserve Knowledge

Introduce techniques that reduce catastrophic forgetting:

  • Regularization Loss (Elastic Weight Consolidation): Penalize updates that deviate too far from the original weights of the fine-tuned model.
  • Lower Learning Rate: Use a smaller learning rate for fine-tuning on the 20 new samples to minimize overwriting previous knowledge (see the sketch after this list).
  • Freezing Certain Layers: Freeze some layers of the model (e.g., embedding layers or lower transformer blocks) to limit updates to only a subset of the model (also sketched below).
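
A minimal sketch of the last two ideas. The learning rate, epoch count, and layer cut-off are illustrative assumptions; layers_to_transform restricts LoRA to the upper decoder blocks (Llama-2-7b has 32) so the lower ones are left untouched:

# Smaller learning rate than the original 2e-4 run
gentle_training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    num_train_epochs=3,
    logging_steps=1,
    report_to="none",
)

# Apply LoRA only to the upper half of the decoder stack
restricted_peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    layers_to_transform=list(range(16, 32)),
)

Elastic Weight Consolidation is not built into Trainer/SFTTrainer, so it would require a custom loss term that penalizes deviation from the previously fine-tuned weights.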