Hi, I am trying to enable Flash Attention 2 on a model, but I get this error:
```
ValueError: past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got torch.Size([4, 8, 3968, 128])
```
I am using OpenChat's openchat_3.5 7B model, which I believe is based on Mistral (openchat/openchat_3.5 · Hugging Face). I am loading the model like this:
```python
import torch
from transformers import AutoModelForCausalLM

model_name = "openchat/openchat_3.5"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # current API; use_flash_attention_2=True is the deprecated form
    low_cpu_mem_usage=True,
)
```
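In case versions matter: the error mentions `self.config.sliding_window`, so I assume it is tied to Mistral's sliding-window attention. Here is the sanity check I would run on my setup (a minimal sketch, assuming `flash_attn` is importable and the config exposes a `sliding_window` field):

```python
import flash_attn
import transformers
from transformers import AutoConfig

# Version info, since sliding-window support only landed in
# flash-attn 2.3.0 and the Mistral code paths have changed
# across transformers releases
print("flash-attn:", flash_attn.__version__)
print("transformers:", transformers.__version__)

# The error message points at config.sliding_window, so check
# what the model actually ships with (4096 for Mistral-style models)
config = AutoConfig.from_pretrained("openchat/openchat_3.5")
print("sliding_window:", config.sliding_window)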
```

Can someone explain this error to me, or point me to a resource that can help me understand the problem? Thank you.