How to enable BF16 on TPUs?

With notebook_launcher(main, use_fp16=True) my data is still fp32.

Yes, that would be expected: use_fp16 does not control bfloat16. Support for bfloat16 on TPUs is not in Accelerate yet.

Is there any way I can do so? The Torch XLA docs say to use an environment variable, but that doesn't work either.

I am trying to figure out how to use bf16 on TPU. From what I read in this GH issue, setting the environment variable XLA_USE_BF16=1 should automagically work?

Can anyone confirm whether this does in fact let one train on TPU with Accelerate in bf16, or am I misreading this?
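
Concretely, this is roughly what I have in mind (a minimal sketch; `main` stands in for my own training function and `num_processes=8` is just an example):

```python
import os

# Per the GH issue, this should be set before the TPU processes are spawned:
# XLA_USE_BF16=1 tells torch_xla to store float32 tensors as bfloat16 on the TPU.
os.environ["XLA_USE_BF16"] = "1"

from accelerate import notebook_launcher


def main():
    # ... build model, optimizer and dataloaders, then train as usual ...
    pass


notebook_launcher(main, num_processes=8)
```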

That’s one way, yes. But Accelerate also now supports mixed_precision="bf16". You can also specify bf16 when running accelerate config.
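
For example, something like this (a minimal sketch; the number of processes and the training body are placeholders):

```python
from accelerate import Accelerator, notebook_launcher


def main():
    # Request bf16 mixed precision in code, as an alternative to the
    # XLA_USE_BF16 environment variable or the `accelerate config` setting.
    accelerator = Accelerator(mixed_precision="bf16")

    # ... create model, optimizer, dataloaders, then:
    # model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
    # and call accelerator.backward(loss) in the training loop.


notebook_launcher(main, num_processes=8)
```

If you set bf16 when running accelerate config, you can instead create Accelerator() with no arguments and the saved config takes care of it.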

Read more here on our docs: Training on TPUs with 🤗 Accelerate
