Flash Attention 2 Error on Mistral-Based Model

Hi, I am trying to enable Flash Attention 2 on a model, but I get this error:

ValueError: past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got torch.Size([4, 8, 3968, 128])

I am using OpenChat's openchat_3.5 7B model (openchat/openchat_3.5 on Hugging Face), which I believe is based on Mistral. I am loading the model like this:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # use_flash_attention_2 is deprecated in favor of this
    low_cpu_mem_usage=True,
)
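
The model itself loads without complaints; I believe the ValueError is raised later, when I call generate. Here is a rough sketch of the kind of call I am making (the tokenizer setup, prompt contents, and generation settings are illustrative, not my exact script):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # illustrative: batched prompts need a pad token

# a batch of 4 long prompts, matching the batch dimension of torch.Size([4, 8, 3968, 128])
prompts = ["(a prompt several thousand tokens long)"] * 4
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=512)  # the ValueError comes up during this call
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])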

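For reference, the shape check in the error seems tied to Mistral's sliding-window attention, so I also printed the config values that look relevant (these are standard MistralConfig attributes):

print(model.config.sliding_window)          # the window length the error message refers to
print(model.config.num_key_value_heads)     # should match the second dimension (8) of the reported shape
print(model.config.max_position_embeddings)
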
Can someone explain this error, or point me to a resource that would help me understand the problem? Thank you.