How to load a model fine-tuned with QLoRA

I fine-tuned Mistral-7B with QLoRA, initializing the model with the following code during training:

## TRAINING
import torch
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForSequenceClassification.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
peft_config = LoraConfig(...)
model = get_peft_model(model, peft_config)

For inference, I tried to load the model with the same settings as in training:

## INFERENCING
from peft import PeftModel

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForSequenceClassification.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
model = PeftModel.from_pretrained(base_model, "saved-adapter-path").eval()

However, the inference process is very slow.

However, I found that when I load the base model with only load_in_4bit=True, as below, inference is much faster.

base_model = AutoModelForSequenceClassification.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
model = PeftModel.from_pretrained(base_model, "saved-adapter-path").eval()

Aren't double quantization and bfloat16 compute supposed to be faster and use less memory than the default settings?

What is the correct way to load the model in my case?

Hi,

Refer to my demo notebook on fine-tuning Mistral-7B; it includes an inference section.

In summary, you can simply use the Auto classes (like AutoModelForCausalLM) to load models fine-tuned with QLoRA, thanks to the PEFT integration in Transformers: it automatically loads the base model together with the adapter weights.
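
For example, here is a minimal sketch of that loading path for your sequence-classification setup (it reuses the "saved-adapter-path" placeholder from your question; the quantization config just mirrors your training settings and is optional):

import torch
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Pointing from_pretrained at the adapter folder makes Transformers read
# adapter_config.json, fetch the base model it names, and attach the LoRA weights.
model = AutoModelForSequenceClassification.from_pretrained(
    "saved-adapter-path",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
).eval()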

Optionally, you can then call the merge_and_unload() method on it to merge the adapters into the base model.
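
Here is a rough sketch of that merge step, assuming you reload the base model in bfloat16 first (merging into a 4-bit base is more involved); "merged-model" is just a hypothetical output directory:

import torch
from transformers import AutoModelForSequenceClassification
from peft import PeftModel

base_model = AutoModelForSequenceClassification.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.bfloat16,
)
model = PeftModel.from_pretrained(base_model, "saved-adapter-path")
model = model.merge_and_unload()  # folds the LoRA deltas into the base weights
model.save_pretrained("merged-model")  # hypothetical path; pick any directory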