Loading a LoRA adapter trained on a quantized model onto a non-quantized model

Imagine the following (common) training-inference scenario:

Training

  • You load a Llama-2 model in 4-bit and train it using PEFT (QLoRA)
  • You train an adapter and push it to the Hub (a rough sketch of this setup is shown below)
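
For concreteness, the training side looks roughly like this. The LoRA hyperparameters and the adapter repo id ("my-user/llama2-qlora-adapter") are placeholders, not the exact values used:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Base model is loaded in 4-bit (NF4) with fp16 compute
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)

# Illustrative LoRA config; the actual r/alpha/target_modules may differ
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# ... run training with a Trainer/SFTTrainer of your choice, then push only the adapter:
model.push_to_hub("my-user/llama2-qlora-adapter")  # placeholder repo id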

Inference options:
Option 1: use the quantized base model with the adapter
This seems the most logical choice, since the adapter was calibrated against the model at this precision. However, doing so throws the following error:

RuntimeError: expected scalar type Float but found Half
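
Roughly, Option 1 looks like the sketch below (the adapter repo id is a placeholder; peft.PeftModel is one way to attach the adapter, model.load_adapter as in the snippet further down is another). This is the setup that raises the RuntimeError above:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the base model in the same 4-bit configuration used for training
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)

# Attach the trained adapter; "my-user/llama2-qlora-adapter" is a placeholder repo id
model = PeftModel.from_pretrained(base, "my-user/llama2-qlora-adapter")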

Option 2: use the non-quantized base model with the adapter
Loading the non-quantized model does work, but wouldn't using the same adapter on a higher-precision base model leave the adapter miscalibrated? (A sketch of this option is shown after the snippet below.)

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_name = "meta-llama/Llama-2-13b-hf"
adapter_name = "my-user/llama2-qlora-adapter"  # placeholder: the adapter repo pushed during training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=bnb_config,  # This line apparently causes an error?
    device_map="auto",
)
model.config.use_cache = False
model.load_adapter(adapter_name)
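
For comparison, Option 2 (no quantization, base model in fp16) would look roughly like this; the adapter repo id is again a placeholder:

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base model without bitsandbytes quantization, in half precision
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Attach the same adapter; "my-user/llama2-qlora-adapter" is a placeholder repo id
model = PeftModel.from_pretrained(base, "my-user/llama2-qlora-adapter")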

Could somebody help me understand this?
