Compatibility of Flash Attention 2 with the dtype conversion done by accelerator.prepare

Hi, I’m trying to fine-tune a BLIP-2 model with Flash Attention 2 enabled on its OPT-2.7B language model, but FA2 produces significantly higher loss than eager attention. This looks similar to previously reported issues (#26498, #28925, #28142).
According to the comments in those issues, the recommended way to use FA2 is to load the model in full precision and train it inside an autocast context.
However, when using the accelerate library, the accelerator.prepare function converts the model to the specified dtype (bf16 in my case), including the layer norms.
I suspect this is what causes the problem for me, but I’m not sure.
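
To make it concrete, here is a minimal sketch of what I mean (not my exact training script; the checkpoint name and the submodule path used for the dtype check are just taken from the public BLIP-2/OPT-2.7B config):

```python
import torch
from transformers import Blip2ForConditionalGeneration
from accelerate import Accelerator

# Load in full precision with FA2, as recommended in the linked issues.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    torch_dtype=torch.float32,
    attn_implementation="flash_attention_2",
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Layer norms are fp32 at this point, as expected.
print(model.language_model.model.decoder.layers[0].self_attn_layer_norm.weight.dtype)

accelerator = Accelerator(mixed_precision="bf16")
model, optimizer = accelerator.prepare(model, optimizer)

# In my setup, the weights (including the layer norms) show up as bf16 after
# prepare(), instead of staying fp32 with only the autocast regions in bf16.
print(model.language_model.model.decoder.layers[0].self_attn_layer_norm.weight.dtype)
```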

Could you check this behavior and offer any suggestions? I’m using transformers==4.40.0.dev0, accelerate==0.23.0, and flash_attn==2.5.5.
If there are any details I should elaborate on, please let me know.
Thanks in advance :slight_smile: