The quantization code in the "Gentle Introduction to 8-bit Matrix Multiplication for transformers" blog post yields an error

Hi there, I have been working through the "A Gentle Introduction to 8-bit Matrix Multiplication for transformers" blog post to get an idea of how LLM.int8() works. At the end of the post there is a very short demo showing how to pass an FP16 input to an 8-bit model. Unfortunately, it doesn’t work on my system. Here is what I’ve tried:

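import torch
import torch.nn as nn

import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt  # "from bnb.nn" fails: bnb is only an alias

# Plain FP32/FP16-style model whose weights we save to disk
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64)
)
torch.save(fp16_model.state_dict(), "model.pt")

# Same architecture built from 8-bit layers
int8_model = nn.Sequential(
    Linear8bitLt(64, 64, has_fp16_weights=False),
    Linear8bitLt(64, 64, has_fp16_weights=False)
)
int8_model.load_state_dict(torch.load("model.pt"))
int8_model = int8_model.to(0)  # Quantization happens here

input_ = torch.randn(64, dtype=torch.float16)  # 1-D input, as in the blog post
hidden_states = int8_model(input_.to(torch.device('cuda', 0)))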
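For context on the line marked "# Quantization happens here": with has_fp16_weights=False, moving the model to the GPU is when bitsandbytes converts the stored weights to int8. The block below is a simplified pure-Python sketch of per-row absmax int8 quantization, the vector-wise idea the blog post describes. It is an illustration only, assuming plain nested lists; it is not bitsandbytes' kernel and it skips the outlier decomposition that LLM.int8() adds. absmax_quantize_rows and dequantize_rows are hypothetical names.

```python
def absmax_quantize_rows(matrix):
    """Quantize each row to int8 using that row's own absmax scale.

    Simplified sketch: no outlier handling, plain Python lists.
    """
    quantized, scales = [], []
    for row in matrix:
        absmax = max(abs(v) for v in row)
        # Map [-absmax, absmax] onto [-127, 127]; an all-zero row keeps scale 1.0
        scale = 127 / absmax if absmax != 0 else 1.0
        quantized.append([round(v * scale) for v in row])
        scales.append(scale)
    return quantized, scales


def dequantize_rows(quantized, scales):
    """Map int8 rows back to floats by dividing by each row's scale."""
    return [[q / s for q in row] for row, s in zip(quantized, scales)]


if __name__ == "__main__":
    weights = [[0.5, -1.0, 0.25], [2.0, 0.1, -0.4]]
    q, scales = absmax_quantize_rows(weights)
    print(q)
    print(dequantize_rows(q, scales))
```

Each dequantized value differs from the original by at most half a quantization step (0.5 / scale), which is why per-row scales preserve precision better than a single scale for the whole matrix.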

I got the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/", line 217, in forward
    input = module(input)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/nn/", line 242, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/autograd/", line 488, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/autograd/", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/autograd/", line 376, in forward
    C32A, SA = F.transform(CA, "col32")
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/", line 1701, in transform
    if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/bitsandbytes/", line 351, in get_transform_buffer
    return init_func((rows, cols), dtype=dtype, device=device), state
UnboundLocalError: local variable 'rows' referenced before assignment

Any idea what the reason might be? I’m running this on Ubuntu with transformers version 4.27.2, on an instance with a single A10G GPU.

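Hello! The 8-bit layers need the input to have a batch dimension, so try input_ = torch.randn((1, 64), dtype=torch.float16) instead. With a 1-D input of shape (64,), get_transform_buffer in bitsandbytes only assigns rows for 2-D and 3-D shapes, which is why you get the UnboundLocalError.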
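To make the failure concrete, here is a minimal sketch, assuming the bitsandbytes version in the traceback computes rows only for 2-D (rows, cols) and 3-D (batch, seq, cols) shapes. transform_buffer_dims and add_batch_dim are hypothetical helpers for illustration, not bitsandbytes APIs, and they raise a ValueError where the library instead hits the UnboundLocalError.

```python
def transform_buffer_dims(shape):
    """Hypothetical mirror of the rows/cols computation in the traceback.

    Only 2-D and 3-D shapes define `rows`; anything else is rejected
    explicitly instead of failing with UnboundLocalError.
    """
    dims = len(shape)
    if dims == 2:
        rows = shape[0]
    elif dims == 3:
        # (batch, seq, hidden) is flattened to (batch * seq, hidden)
        rows = shape[0] * shape[1]
    else:
        raise ValueError(f"expected a 2-D or 3-D input, got shape {tuple(shape)}")
    cols = shape[-1]
    return rows, cols


def add_batch_dim(shape):
    """Prepend a batch dimension of 1 to a 1-D shape, like (64,) -> (1, 64)."""
    return (1, *shape) if len(shape) == 1 else tuple(shape)


if __name__ == "__main__":
    print(transform_buffer_dims(add_batch_dim((64,))))
    try:
        transform_buffer_dims((64,))
    except ValueError as err:
        print(err)
```

In PyTorch itself, input_.unsqueeze(0) on the original 1-D tensor adds the same leading batch dimension as creating it with torch.randn((1, 64), ...).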