Having saved a model in 8-bit:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    device_map="auto",
)
model.save_pretrained("model_weights_test")
```
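As far as I understand, when a quantized model is saved, transformers records the quantization settings in the checkpoint's `config.json`, so one way to see what the checkpoint actually contains is to read that field back (the helper below is just a sketch; the commented-out call uses the directory from my script above):

```python
import json
from pathlib import Path

def saved_quant_config(save_dir):
    """Return the quantization_config recorded in a checkpoint's
    config.json, or None if the checkpoint is unquantized."""
    cfg = json.loads((Path(save_dir) / "config.json").read_text())
    return cfg.get("quantization_config")

# e.g. against the directory saved above:
# print(saved_quant_config("model_weights_test"))
```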
I then loaded it in 4-bit (in a separate script):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "model_weights_test",
    quantization_config=double_quant_config,
    device_map="auto",
)
```
However, I then run into OOM errors that I was not seeing when I simply loaded the original model in 4-bit directly:

```python
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=double_quant_config,
    device_map="auto",
)
```
My question is therefore: does the 4-bit config fail to override the quantization already baked into the saved 8-bit checkpoint? Am I essentially just training an 8-bit model, which would explain the OOM errors?
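For context on why the bit width matters here, a back-of-envelope calculation of the memory taken by the weights alone at each precision (the 7B parameter count below is a hypothetical example, not the model from my scripts):

```python
def weight_bytes(n_params, bits):
    """Approximate memory for model weights alone at a given bit width."""
    return n_params * bits / 8

n = 7e9  # hypothetical 7B-parameter model
gib = 1024 ** 3
print(f"8-bit: {weight_bytes(n, 8) / gib:.1f} GiB")  # 8-bit: 6.5 GiB
print(f"4-bit: {weight_bytes(n, 4) / gib:.1f} GiB")  # 4-bit: 3.3 GiB
```

This ignores activations, optimizer state, and quantization overhead, but it shows that ending up with 8-bit weights roughly doubles the weight footprint relative to 4-bit.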