I tried to fine-tune a small model as lightly as I could just to test my workflow, and I can’t get past an error.
My quantization config is as follows:
```python
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
```
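As I understand it, `bnb_4bit_compute_dtype` sets the dtype that the 4-bit layers do their matmuls in. Below is a simplified, hypothetical illustration of that idea in plain PyTorch. The class `ToyComputeDtypeLinear` is my own and is not bitsandbytes' implementation (it does no quantization at all). It casts its input to the compute dtype for the matmul and casts the result back, so float32 activations do not trip a dtype check.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyComputeDtypeLinear(nn.Module):
    """Hypothetical stand-in for a layer with a separate compute dtype.

    Not bitsandbytes code: the weight is simply stored in compute_dtype,
    with no quantization.
    """

    def __init__(self, in_features: int, out_features: int,
                 compute_dtype: torch.dtype = torch.bfloat16) -> None:
        super().__init__()
        self.compute_dtype = compute_dtype
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features, dtype=compute_dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # Cast activations to the compute dtype so both matmul operands match.
        out = F.linear(x.to(self.compute_dtype), self.weight)
        # Return the result in the caller's original dtype.
        return out.to(input_dtype)


layer = ToyComputeDtypeLinear(4, 3)
x = torch.randn(2, 4)  # float32 activations
y = layer(x)
print(y.dtype)
```

A plain `nn.Linear` does no such cast, so the same float32 input against a bfloat16 `nn.Linear` would fail.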
LoRA config:
```python
from peft import LoraConfig

peft_params = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)
```
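For reference, with `lora_alpha=16` and `r=64`, the standard LoRA formulation scales the adapter update by `lora_alpha / r = 0.25` (PEFT's default when rank-stabilized scaling is off, as far as I know). Here is a minimal pure-Python sketch of that forward pass using toy rank-1 matrices; only the scaling factor comes from the config above, and `lora_linear` and `lora_scaling` are hypothetical helpers, not PEFT functions. `B` starts at zero (PEFT's default initialization, to my understanding), so the adapter is a no-op until training updates it.

```python
def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_scaling(lora_alpha: float, r: int) -> float:
    """Standard LoRA scaling factor: lora_alpha / r."""
    return lora_alpha / r


def lora_linear(W, A, B, x, scaling):
    """y = W x + scaling * B (A x); W is frozen, only A and B are trained."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    return [b + scaling * d for b, d in zip(base, delta)]


scaling = lora_scaling(16, 64)   # lora_alpha=16, r=64 from the config
W = [[1.0, 0.0], [0.0, 1.0]]     # toy frozen weight
A = [[1.0, 1.0]]                 # toy rank-1 down-projection
x = [2.0, 3.0]

y_init = lora_linear(W, A, [[0.0], [0.0]], x, scaling)     # B starts at zero
y_trained = lora_linear(W, A, [[1.0], [0.0]], x, scaling)  # after some updates
print(scaling, y_init, y_trained)
```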
Training arguments:
```python
from transformers import TrainingArguments

training_params = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    optim="adamw_torch",
    save_steps=25,
    logging_steps=25,
    learning_rate=1e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=True,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    report_to="tensorboard",
)
```
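With `per_device_train_batch_size=1` and `gradient_accumulation_steps=16`, each optimizer step consumes 16 examples per device. A rough back-of-the-envelope sketch follows; the helper names are hypothetical, 10 examples is a made-up dataset size, and Trainer's exact step rounding may differ between versions. On a run this short, a step-based interval like `logging_steps=25` or `save_steps=25` would never fire.

```python
import math


def effective_batch_size(per_device_batch: int, grad_accum_steps: int,
                         num_devices: int = 1) -> int:
    """Examples consumed per optimizer step."""
    return per_device_batch * grad_accum_steps * num_devices


def approx_optimizer_steps(num_examples: int, per_device_batch: int,
                           grad_accum_steps: int, num_devices: int = 1,
                           epochs: int = 1) -> int:
    """Rough optimizer-step count; Trainer's exact rounding may differ."""
    per_step = effective_batch_size(per_device_batch, grad_accum_steps, num_devices)
    return epochs * math.ceil(num_examples / per_step)


def periodic_events(total_steps: int, every: int) -> int:
    """How many times a step-based interval (e.g. logging_steps) would fire."""
    return total_steps // every


# Batch settings from the TrainingArguments above; 10 examples is hypothetical.
steps = approx_optimizer_steps(10, per_device_batch=1, grad_accum_steps=16)
print(effective_batch_size(1, 16), steps, periodic_events(steps, 25))
```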
After briefly fine-tuning on just a few dataset examples, I try to generate some demo completions to make sure everything works, using the following code:
```python
prompt_text = "Hello Hello Hello"
input_ids = tokenizer(prompt_text, return_tensors="pt", padding=True, truncation=True)["input_ids"]
input_ids = input_ids.to(model.device)
output = model.generate(input_ids)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)
```
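To see where float32 and bfloat16 parameters live before calling `generate`, one option is to tally parameter dtypes. This is a small sketch built only on the standard `nn.Module.named_parameters()`; the helpers `parameter_dtypes` and `modules_with_dtype` are hypothetical, and the toy model stands in for the real one. On a 4-bit model, the quantized weights may show up under their storage dtype rather than bfloat16.

```python
from collections import Counter

import torch
import torch.nn as nn


def parameter_dtypes(model: nn.Module) -> Counter:
    """Count parameter elements per dtype."""
    counts = Counter()
    for _, param in model.named_parameters():
        counts[param.dtype] += param.numel()
    return counts


def modules_with_dtype(model: nn.Module, dtype: torch.dtype) -> list:
    """Names of modules owning at least one parameter of the given dtype."""
    return sorted({name.rsplit(".", 1)[0]
                   for name, param in model.named_parameters()
                   if param.dtype == dtype})


# Toy model: first layer stays float32, second layer is cast to bfloat16.
toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
toy[1].to(torch.bfloat16)

counts = parameter_dtypes(toy)
print(counts)
print(modules_with_dtype(toy, torch.bfloat16))
```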
I get the following error no matter what I try:
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-51-1b75867bada9> in <cell line: 4>()
      2 input_ids= tokenizer(prompt_text, return_tensors="pt", padding=True, truncation=True)["input_ids"]
      3 input_ids= input_ids.to(model.device)
----> 4 output= model.generate(input_ids)
      5 generated_text= tokenizer.decode(output[0], skip_special_tokens=True)
      6 print(generated_text)

10 frames

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    112 
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 
    116     def extra_repr(self) -> str:

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
```
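The traceback ends in `torch/nn/modules/linear.py`, i.e. a regular `nn.Linear`, and the same kind of failure can be reproduced in plain PyTorch: `F.linear` does not promote dtypes, so float32 activations against a bfloat16 weight raise a `RuntimeError`. The sketch below uses toy tensors on CPU (not the actual model) and shows two generic workarounds, an explicit cast and `torch.autocast`; I have not confirmed which, if either, is the right fix for the fine-tuned model.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4)                        # float32 activations
w = torch.randn(3, 4, dtype=torch.bfloat16)  # bfloat16 weight

# F.linear requires matching dtypes, so this raises a RuntimeError.
try:
    F.linear(x, w)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)

# Workaround 1: cast the activations to the weight's dtype explicitly.
y1 = F.linear(x.to(w.dtype), w)

# Workaround 2: let autocast run the matmul in bfloat16
# (device_type="cpu" here; on a GPU this would be "cuda").
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y2 = F.linear(x, w)

print(y1.dtype, y2.dtype)
```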
I've tried casting the input_ids to bfloat16, but I still get the same error. Any advice?
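As a side note on casting `input_ids`: they are integer token indices, and an embedding layer's output takes the dtype of the embedding weight, so the dtype of the hidden states is decided inside the model rather than by the ids. A toy sketch in plain PyTorch (a small hypothetical embedding, not the real model) also shows that an embedding rejects bfloat16 indices outright.

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4).to(torch.bfloat16)  # toy embedding with bfloat16 weights
input_ids = torch.tensor([[1, 2, 3]])         # token ids default to torch.int64

hidden = emb(input_ids)
print(input_ids.dtype, hidden.dtype)

# Embedding indices must be integer (Long/Int), so bfloat16 ids are rejected.
try:
    emb(input_ids.to(torch.bfloat16))
    raised = False
except RuntimeError:
    raised = True
```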