Why do models (LLaMA in particular) upcast softmax to fp32?
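
For reference, this is roughly the pattern I'm asking about. It's a minimal self-contained sketch, not the actual transformers source, but the real LlamaAttention does the same `softmax(..., dtype=torch.float32)` followed by a cast back:

```python
import math
import torch
import torch.nn.functional as F

def attention(q, k, v, upcast_softmax=True):
    # q, k, v: [batch, heads, seq, head_dim], e.g. in bf16
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if upcast_softmax:
        # The upcast in question: softmax runs in fp32, which materializes an
        # fp32 copy of the [batch, heads, seq, seq] attention weights before
        # casting back to the input dtype.
        weights = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    else:
        weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v)
```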

Consider the following:

On my laptop's 3080 Ti, this upcast is the difference between getting an OOM at ~1K context on open_llama_3b loaded in bf16 and not getting one, sitting at ~14 GB of VRAM used.
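
Back-of-the-envelope for just the attention-weight tensor (assuming batch 1 and 32 heads; I haven't double-checked open_llama_3b's exact config, so treat the numbers as illustrative):

```python
batch, heads, seq = 1, 32, 1024          # assumed shapes, ~1K context
elems = batch * heads * seq * seq        # attention weights per layer
print(elems * 2 / 2**20, "MiB in bf16")  # ~64 MiB per layer
print(elems * 4 / 2**20, "MiB in fp32")  # ~128 MiB per layer
```

With the upcast, the fp32 softmax output and the bf16 tensor it's cast back to coexist briefly, and if autograd is recording, the fp32 intermediates get saved for backward in every layer, which is presumably where the extra gigabytes come from at longer contexts.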

Is bf16 really that unstable and prone to returning NaNs/infs? I removed the upcast and didn't notice any difference (other than not getting OOM).
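
One quick way to put a number on what the upcast buys — a throwaway sketch comparing bf16 softmax against an fp32 reference on random logits (the scale of 10 is an arbitrary assumption, real attention logits may look different):

```python
import torch

torch.manual_seed(0)
# Random "attention logits" cast to bf16.
scores = (torch.randn(1, 32, 256, 256) * 10).to(torch.bfloat16)

ref = torch.softmax(scores.float(), dim=-1)   # fp32 reference
low = torch.softmax(scores, dim=-1).float()   # pure bf16 softmax
print("max abs diff:", (ref - low).abs().max().item())
print("any NaN/inf:", (~torch.isfinite(low)).any().item())
```

My understanding is that bf16 has the same exponent range as fp32, so overflow to inf isn't really the worry; what you lose is mantissa precision, which is why I'm surprised the upcast seems to matter so much.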

(see this related GH issue)