Hi there, I have been working through the "A Gentle Introduction to 8-bit Matrix Multiplication for transformers" blog post to get an idea of how LLM.int8() works. At the end of the blog post there is a very short demo showing how to pass an FP16 input to an 8-bit model. Unfortunately, it doesn't work on my system. Here is what I've tried:
```python
import torch
import torch.nn as nn

import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt

# Create a small model and save its checkpoint
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64)
)
torch.save(fp16_model.state_dict(), "model.pt")

# Build the 8-bit counterpart and load the saved weights
int8_model = nn.Sequential(
    Linear8bitLt(64, 64, has_fp16_weights=False),
    Linear8bitLt(64, 64, has_fp16_weights=False)
)
int8_model.load_state_dict(torch.load("model.pt"))
int8_model = int8_model.to(0)  # Quantization happens here

# Pass an FP16 input through the quantized model
input_ = torch.randn(64, dtype=torch.float16)
hidden_states = int8_model(input_.to(torch.device('cuda', 0)))
```
I got the following error:
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 242, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 488, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 376, in forward
    C32A, SA = F.transform(CA, "col32")
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/functional.py", line 1701, in transform
    if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/functional.py", line 351, in get_transform_buffer
    return init_func((rows, cols), dtype=dtype, device=device), state
UnboundLocalError: local variable 'rows' referenced before assignment
```
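From the traceback, the failure seems to happen when `F.transform` receives my input after quantization. In case it helps, here is a minimal sketch that, as far as I can tell, isolates the failing call (a single `Linear8bitLt` layer with the same shapes, dtype, and device as above; I'm assuming the single-layer call goes down the same code path):

```python
import torch
from bitsandbytes.nn import Linear8bitLt

# One quantized layer, same settings as in the full example
layer = Linear8bitLt(64, 64, has_fp16_weights=False).to(0)

# The same 1-D FP16 input as above
x = torch.randn(64, dtype=torch.float16, device=torch.device('cuda', 0))

# This call appears to reach the same bnb.matmul path as in the traceback
out = layer(x)
```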
Any idea what the reason might be? I'm running this on Ubuntu with transformers version 4.27.2, on an instance with a single A10G GPU.