Llama-2 Sequence Classification: Much lower accuracy on inference from checkpoint compared to model

I’m fine-tuning a Llama-2 sequence classification model with peft and QLoRA, evaluating every 100 steps and also saving a checkpoint every 100 steps. When I load a checkpoint and run inference on the same validation set used during training, the accuracy is much lower than the accuracy reported during training. Here’s the relevant code:
Training:

import torch
from transformers import (
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

q_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForSequenceClassification.from_pretrained(
    "meta-llama/Llama-2-13b-hf",
    quantization_config=q_config,
    device_map="auto", 
    num_labels=n_labels,
)
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False 

peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_CLS,
    target_modules=['v_proj', 'down_proj', 'up_proj', 'q_proj', 'gate_proj', 'k_proj', 'o_proj'],
)

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

training_args = TrainingArguments(...)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_test,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

trainer.train()
trainer.save_model("final-checkpoint")

Then for inference, I load the model as follows:

model = AutoModelForSequenceClassification.from_pretrained(
    "final-checkpoint",
    device_map="auto", 
    num_labels=n_labels,
    quantization_config=q_config
)

Doing inference with this reloaded model gives much worse predictions on the test set than during training. I’ve tried loading with other instantiation methods (AutoPeftModelForSequenceClassification…), but the result is the same.
What am I doing wrong? Is it the saving that is wrong, or the loading? Something in the parameters?
The thing is, if a model has been training for days and you cannot save it or load it again, then …?
Thank you for your help.

Well well well, it’s a bug… Here is what the score module looks like in the peft-wrapped model:

(score): ModulesToSaveWrapper(
  (original_module): Linear(in_features=5120, out_features=647, bias=False)
  (modules_to_save): ModuleDict(
    (default): Linear(in_features=5120, out_features=647, bias=False)
  )
)

The module that gets saved is score.modules_to_save.default; however, the trained weights are in score.original_module.

The workaround is to save and load this module explicitly.

See Trainer of AutoModelForSequenceClassification is saving the wrong score module (or trained parameters are in the wrong module) · Issue #26160 · huggingface/transformers · GitHub
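If you only need to fix the final save (rather than every intermediate checkpoint), a minimal sketch is to write the score weights out next to the saved adapter, assuming model is the peft-wrapped model returned by get_peft_model:

# Save the adapter as usual, then save the trained classification head explicitly.
trainer.save_model("final-checkpoint")
torch.save(
    model.model.score.original_module.state_dict(),
    "final-checkpoint/score.original_module.pt",
)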

As a workaround, you can use a callback for saving during training:

import torch
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

class SaveScoreCallback(TrainerCallback):
    def __init__(self, model) -> None:
        super().__init__()
        self.model = model

    def on_save(self,
                args: TrainingArguments,
                state: TrainerState,
                control: TrainerControl,
                **kwargs):
        # The Trainer only saves score.modules_to_save.default; the trained weights
        # live in score.original_module, so save that module explicitly here.
        fname = f"{args.output_dir}/checkpoint-{state.global_step}/score.original_module.pt"
        torch.save(self.model.model.score.original_module.state_dict(), fname)

trainer.add_callback(SaveScoreCallback(model))

and when you load a checkpoint:

from peft import AutoPeftModelForSequenceClassification

model = AutoPeftModelForSequenceClassification.from_pretrained(
    "path/to/checkpoint",
    device_map="auto",
    num_labels=n_labels,
    quantization_config=q_config,
)
# Restore the trained classification head saved by the callback above.
score_weights = torch.load("path/to/checkpoint/score.original_module.pt", map_location='cpu')
model.score.original_module.load_state_dict(score_weights)
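To check that the reload worked, you can re-run the same evaluation on the reloaded model and compare it with the accuracy logged during training. A minimal sketch, assuming ds_test, compute_metrics, and tokenizer are the same objects used during training:

eval_trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="eval-tmp", per_device_eval_batch_size=8),
    eval_dataset=ds_test,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)
print(eval_trainer.evaluate())  # should now match the training-time eval accuracy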

This is a really severe bug. I looked at the issue you created on GitHub, but it hasn’t even been assigned to anyone for fixing…
Do you know if there is an older version of transformers where this bug doesn’t happen?


No, sorry I don’t know.

What I do know is that the bug does not occur when you don’t specify target_modules in the LoraConfig, so that effectively only the default q_proj and v_proj modules of the attention blocks are adapted.

That is probably the most common use case, which is likely why the bug doesn’t get much attention.
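For anyone who can live with the defaults, a LoraConfig along these lines (same hyperparameters as in the original post, just without target_modules) avoids the issue:

# Without target_modules, peft falls back to its default LoRA targets for Llama
# (q_proj and v_proj), and the score head is then saved and reloaded correctly.
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_CLS,
)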

Thanks for sharing the code. You could also use torch.save(kwargs['model'].score.original_module.state_dict(), fname) inside on_save instead of passing the model to the callback.
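That way the callback doesn’t need to hold a reference to the model at all; a minimal sketch of that variant (same save path as above):

class SaveScoreCallback(TrainerCallback):
    def on_save(self,
                args: TrainingArguments,
                state: TrainerState,
                control: TrainerControl,
                **kwargs):
        # The Trainer passes the wrapped model to callbacks via kwargs.
        fname = f"{args.output_dir}/checkpoint-{state.global_step}/score.original_module.pt"
        torch.save(kwargs["model"].score.original_module.state_dict(), fname)

trainer.add_callback(SaveScoreCallback())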

Any chance of more detailed code for the workaround? I’m also stuck on this bug (ESM / sequence classification), and I don’t know what needs to be imported from Hugging Face vs. taken from the model in order to use the callback workaround you mentioned.
Thanks!