With notebook_launcher(main, use_fp16=True) my data is still fp32
Yes, that's expected; use_fp16 does not control bfloat16. Support for bfloat16 on TPUs is not in Accelerate yet.
Is there any way I can do so? The torch_xla docs say to use an env var, but that doesn't work either.
I am trying to figure out how to use bf16 on TPU. What I read from this GH issue is that setting the environment variable XLA_USE_BF16=1 should automagically work?
Can anyone confirm that this does in fact let one train on TPU with Accelerate in bf16? Or am I misreading it.
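For anyone trying this, a minimal sketch of the env-var approach. The key detail (an assumption worth double-checking against the torch_xla docs) is that XLA reads the variable at import time, so it has to be set before torch_xla is imported in the process:

```python
import os

# Set this BEFORE importing torch_xla / creating any XLA tensors;
# setting it after import typically has no effect.
os.environ["XLA_USE_BF16"] = "1"

# With the flag set, float32 tensors materialized on the XLA device are
# stored as bfloat16 under the hood, though they still report as float32.
print(os.environ["XLA_USE_BF16"])
```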
That’s one way, yes. But Accelerate also now supports mixed_precision="bf16" when creating your Accelerator. You can also specify bf16 during accelerate config.
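For reference, answering bf16 at the `accelerate config` prompts ends up as a line in the generated config file. A rough sketch of the relevant fragment (exact keys and surrounding fields may vary by Accelerate version, so treat this as illustrative):

```yaml
# Fragment of the YAML written by `accelerate config`
compute_environment: LOCAL_MACHINE
distributed_type: TPU
mixed_precision: bf16
```

Equivalently, you can pass mixed_precision="bf16" directly to Accelerator() in code, which overrides the config file for that run.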
Read more here on our docs: Training on TPUs with 🤗 Accelerate