Flash Attention 2 Error on Mistral Based Model

Hi, I am trying to enable Flash Attention 2 on a model, but I got this error:

```
ValueError: past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got torch.Size([4, 8, 3968, 128])
```

I am using openchat's openchat_3.5 7B model ([openchat/openchat_3.5 · Hugging Face](https://huggingface.co/openchat/openchat_3.5)), which I believe is based on Mistral. I am loading the model as such:

```python
model = AutoModelForCausalLM.from_pretrained(model_name,
```

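(My snippet got cut off above; for context, this is roughly how I am enabling Flash Attention 2 — a sketch, assuming the standard `attn_implementation` keyword of recent `transformers` versions and a half-precision dtype, which FA2 requires. The actual `from_pretrained` call is shown commented out since it triggers a full model download.)

```python
model_name = "openchat/openchat_3.5"

# Keyword arguments assumed for enabling Flash Attention 2
# (transformers >= 4.34; FA2 only supports fp16/bf16 weights):
load_kwargs = {
    "torch_dtype": "bfloat16",                   # from_pretrained also accepts dtype strings
    "attn_implementation": "flash_attention_2",  # request the FA2 attention backend
}

# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
```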
Can someone explain the error for me, or point me to a resource that can help me understand the problem? Thank you.