The quantization code in the "A Gentle Introduction to 8-bit Matrix Multiplication for transformers" blog post yields an error

Hi there, I have been walking through the "A Gentle Introduction to 8-bit Matrix Multiplication for transformers" blog post to get an idea of how LLM.int8() works. At the end of the blog post, there is a very short demo showing how to pass an FP16 input to an 8-bit model. Unfortunately, it doesn't work on my system. Here is what I've tried:

import torch
import torch.nn as nn

import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64)
)
torch.save(fp16_model.state_dict(), "model.pt")
int8_model = nn.Sequential(
    Linear8bitLt(64, 64, has_fp16_weights=False),
    Linear8bitLt(64, 64, has_fp16_weights=False)
)
int8_model.load_state_dict(torch.load("model.pt"))
int8_model = int8_model.to(0) # Quantization happens here
input_ = torch.randn(64, dtype=torch.float16)
hidden_states = int8_model(input_.to(torch.device('cuda', 0)))

I got the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 242, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 488, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 376, in forward
    C32A, SA = F.transform(CA, "col32")
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/functional.py", line 1701, in transform
    if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/functional.py", line 351, in get_transform_buffer
    return init_func((rows, cols), dtype=dtype, device=device), state
UnboundLocalError: local variable 'rows' referenced before assignment

Any idea what the reason could be? I'm running this on Ubuntu on an instance with a single A10G GPU, and my transformers version is 4.27.2.

Hello! The shape of your input is the problem: the 8-bit matmul path expects an activation with a batch dimension (at least 2-D), and with a 1-D tensor the buffer-shape logic in get_transform_buffer never assigns rows, which is exactly the UnboundLocalError in your traceback. Try input_ = torch.randn((1, 64), dtype=torch.float16)
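
For completeness, here is the tail of your snippet with that one change applied; a minimal sketch assuming the int8_model defined above has already been moved to the GPU:

# A 2-D input of shape (batch_size, in_features) instead of a bare (in_features,) vector.
input_ = torch.randn((1, 64), dtype=torch.float16)

# Same forward pass as before; the int8 matmul now gets the batched shape it expects.
hidden_states = int8_model(input_.to(torch.device('cuda', 0)))
print(hidden_states.shape)  # torch.Size([1, 64])

If you already have a 1-D tensor, input_.unsqueeze(0) is an equivalent way to add the batch dimension.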