Imagine the following (common) training-inference scenario:
Training
- You load a Llama-2 model in 4-bit and train it with PEFT (QLoRA)
- You train an adapter and push it to the Hub (a rough sketch of this setup is below)
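For context, this is roughly what the training setup looks like. It is a simplified sketch, not my exact script; the LoRA hyperparameters and the adapter repo id are placeholders, and the training loop itself is omitted:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

base_id = "meta-llama/Llama-2-13b-hf"

# 4-bit (NF4) quantization config used for QLoRA training
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_id,
    quantization_config=bnb_config,
    device_map="auto",
)
base_model = prepare_model_for_kbit_training(base_model)

# LoRA adapter trained on top of the frozen 4-bit base
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(base_model, lora_config)

# ... training loop (e.g. with the transformers Trainer) omitted ...

# Push only the adapter weights to the Hub
model.push_to_hub("your-username/llama2-13b-qlora-adapter")  # placeholder repo id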
Inference options:
Option 1: Use the quantized base model with the adapter
This seems the most logical choice, since the adapter was trained on top of the base model at this precision. However, doing so throws the following error (the snippet that reproduces it is further down in this post):
RuntimeError: expected scalar type Float but found Half
Option 2: Use the non-quantized base model with the adapter
Loading the non-quantized model does work (see the sketch below for what I mean), but wouldn't applying an adapter that was trained against 4-bit weights to a higher-precision base model leave the adapter miscalibrated?
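A minimal sketch of what I mean by Option 2, again with a placeholder adapter repo id:

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_id = "meta-llama/Llama-2-13b-hf"
adapter_id = "your-username/llama2-13b-qlora-adapter"  # placeholder: the adapter pushed during training

# Load the base model without any quantization, in fp16
base_model = AutoModelForCausalLM.from_pretrained(
    base_id,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Attach the LoRA adapter that was trained against the 4-bit base
model = PeftModel.from_pretrained(base_model, adapter_id)
model.eval()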
For reference, here is the Option 1 code that throws the error:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_name = "meta-llama/Llama-2-13b-hf"
adapter_name = "your-username/llama2-13b-qlora-adapter"  # placeholder for the adapter repo I pushed

# Same 4-bit config that was used during QLoRA training
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=bnb_config,  # this line seems to be what triggers the error
    device_map="auto",
)
model.config.use_cache = False

# Attach the trained adapter to the quantized base model
model.load_adapter(adapter_name)
Could somebody help me understand this?