Enabling load_in_8bit makes inference much slower

I loaded the 7B LLaMA model on an A100 like this:

import torch
import transformers
from transformers import LlamaForCausalLM

quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=False, llm_int8_threshold=0.0)
model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

With load_in_8bit=False the model generates 16.7 tokens/sec, whereas with load_in_8bit=True it generates only 6.7 tokens/sec. I suspect I have misconfigured something for the load_in_8bit=True case.
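For reference, the tokens/sec numbers above can be measured with a small timing helper like the one below. This is just a sketch: the name `tokens_per_second` and the commented usage are hypothetical, and in practice the callable would wrap `model.generate` on the model loaded above.

```python
import time

def tokens_per_second(generate_fn, num_new_tokens):
    """Time one generation call and return throughput in tokens/sec.

    generate_fn: zero-argument callable that runs the generation,
    e.g. lambda: model.generate(**inputs, max_new_tokens=num_new_tokens).
    num_new_tokens: how many tokens that call produces.
    """
    start = time.perf_counter()
    generate_fn()
    elapsed = time.perf_counter() - start
    return num_new_tokens / elapsed

# Hypothetical usage with the model loaded above:
# tps = tokens_per_second(
#     lambda: model.generate(**inputs, max_new_tokens=128),
#     num_new_tokens=128,
# )
```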

My transformers version is 4.29.0.dev0. Did I miss anything? Thanks.


I am facing a similar issue. Did you find a solution?

Me too, and training is slower as well.