Why do I have to use "model.half()" when I load an int4 model?

I’m trying to test THUDM/chatglm-6b-int4 (huggingface.co) on my PC.

The code below works fine.

# exp1
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
print(model.dtype) # output "torch.float16"
model = model.half().cuda()  
print(model.dtype) # output "torch.float16"
response, history = model.chat(tokenizer, "hi", history=[])
print(response)

But if I remove .half(), an error occurs.

#exp2
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
print(model.dtype) # output "torch.float16"
model = model.cuda()  
print(model.dtype) # output "torch.float16"
response, history = model.chat(tokenizer, "hi", history=[])
print(response)
Traceback (most recent call last):
  File "test_infer_int4.py", line 17, in <module>
    response, history = model.chat(tokenizer, "hi", history=[])
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/ChatGLM-6B-Int4/modeling_chatglm.py", line 1253, in chat
    outputs = self.generate(**inputs, **gen_kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 1452, in generate
    return self.sample(
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 2468, in sample
    outputs = self(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/ChatGLM-6B-Int4/modeling_chatglm.py", line 1158, in forward
    transformer_outputs = self.transformer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/ChatGLM-6B-Int4/modeling_chatglm.py", line 971, in forward
    layer_ret = layer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/ChatGLM-6B-Int4/modeling_chatglm.py", line 609, in forward
    attention_input = self.input_layernorm(hidden_states)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/normalization.py", line 190, in forward
    return F.layer_norm(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 2515, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: expected scalar type Half but found Float

It’s utterly confusing to me that model.dtype is torch.float16 both before and after calling model.half(), so why did the second experiment fail?

ps. exp env:
Ubuntu 20.04
Python 3.8.10
torch 1.13.1+cu117
transformers 4.27.1
sentencepiece 0.2.0

By the way, passing the parameter torch_dtype=torch.float16 instead of calling model.half() also works fine.

import torch  # needed for torch.float16

model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True, torch_dtype=torch.float16).cuda()
response, history = model.chat(tokenizer, "hi", history=[])

Hi @dannyp024

It looks like you’re working with THUDM/chatglm-6b-int4 and encountering an issue when removing .half() from your code. It’s interesting how .half() seems crucial for your setup.

Have you experimented with different models or configurations to see if this behavior is consistent across other setups?

I figured it out.
Not all the parameters in THUDM/chatglm-6b-int4 are torch.float16; model.dtype just returns the dtype of the first floating-point parameter it finds, if there is one. The LayerNorm weights and biases are still torch.float32, which is what triggers the runtime error, so model.half() is needed to cast all floating-point parameters to torch.float16.
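
A minimal sketch along these lines prints the per-tensor dtypes (iterating over state_dict() so the int8 quantized weights show up too):

# sketch: inspect the dtype of each tensor in the checkpoint
from transformers import AutoModel

model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
for name, tensor in list(model.state_dict().items())[:5]:
    print(f"layer name: {name}, dtype: {tensor.dtype}")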

Without model.half(), the dtypes of the first five entries look like this:

layer name: transformer.word_embeddings.weight, dtype: torch.float16
layer name: transformer.layers.0.input_layernorm.weight, dtype: torch.float32
layer name: transformer.layers.0.input_layernorm.bias, dtype: torch.float32
layer name: transformer.layers.0.attention.query_key_value.bias, dtype: torch.float16
layer name: transformer.layers.0.attention.query_key_value.weight, dtype: torch.int8

and after model.half(), they become:

layer name: transformer.word_embeddings.weight, dtype: torch.float16
layer name: transformer.layers.0.input_layernorm.weight, dtype: torch.float16
layer name: transformer.layers.0.input_layernorm.bias, dtype: torch.float16
layer name: transformer.layers.0.attention.query_key_value.bias, dtype: torch.float16
layer name: transformer.layers.0.attention.query_key_value.weight, dtype: torch.int8
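
Note that model.half() only casts floating-point tensors; the int8 quantized weight stays int8. A tiny standalone sketch (plain PyTorch, nothing ChatGLM-specific) shows the same behavior:

import torch
import torch.nn as nn

# Module.half() casts only floating-point tensors; integer tensors are left as-is
m = nn.Module()
m.ln = nn.LayerNorm(4)  # LayerNorm parameters are float32 by default
m.register_buffer("w_int8", torch.zeros(4, dtype=torch.int8))
m.half()
print(m.ln.weight.dtype)  # torch.float16
print(m.w_int8.dtype)     # torch.int8 (unchanged)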

Both model.half() and model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True, torch_dtype=torch.float16) work.

However, model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True, torch_dtype="auto") does not work, even though torch_dtype in config.json is float16, because torch_dtype="auto" is effectively ignored. I am not sure if it is a bug in transformers.

Here torch_dtype="auto" is meaningless: AutoConfig.from_pretrained returns a config with torch_dtype=torch.float16 but an empty "()" kwargs (line 441).

The model dtype therefore never gets set to torch.float16, because those empty kwargs are passed on to model.from_pretrained, where torch_dtype = kwargs.pop("torch_dtype", None) comes out as None.

Upgrading transformers from 4.27.1 to 4.33.2 makes torch_dtype="auto" work as expected.
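
If you want to check whether torch_dtype="auto" is honored on your transformers version, a quick probe like this can be used (the layer name is taken from the listing above):

from transformers import AutoModel

model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b-int4", trust_remote_code=True, torch_dtype="auto"
)
# torch.float32 on 4.27.1 (auto is ignored); should be torch.float16 after upgrading to 4.33.2
print(model.state_dict()["transformer.layers.0.input_layernorm.weight"].dtype)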
