As bfloat16 hardware support becomes more widely available, there is an emerging trend of training models in bfloat16. This creates a problem: such models can't be finetuned in mixed precision (or evaluated in fp16), be it with amp, apex or deepspeed/fairscale.
Last week I spent some time sitting with the NaN issues reported in t5/mt5 (and apparently pegasus too), watching the activation values (huggingface/transformers Pull Request #10956 by stas00, "[T5/MT5] resolve inf/nan under amp (mixed precision)") and studying the numerical properties of bfloat16 vs float16 (stas00/ml-ways: bfloat16-vs-float16-study.ipynb).
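One way to do this kind of activation watching on any model is with PyTorch forward hooks. The following is a minimal sketch, assuming a plain `nn.Module` run in fp32 or bf16. `record_max_activations` is a hypothetical helper, and the toy model with deliberately inflated weights only stands in for a real bf16-pretrained checkpoint.

```python
import torch
import torch.nn as nn

FP16_MAX = torch.finfo(torch.float16).max  # 65504.0, the largest finite fp16 value


def record_max_activations(model: nn.Module, inputs: torch.Tensor) -> dict:
    """Run one forward pass; return max |output| of every leaf module, keyed by module name."""
    maxima = {}
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            if isinstance(output, torch.Tensor):
                maxima[name] = output.detach().abs().max().item()
        return hook

    for name, module in model.named_modules():
        if next(module.children(), None) is None:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        for h in handles:
            h.remove()  # always detach the hooks, even if the forward pass fails
    return maxima


# Toy stand-in for a bf16-trained model: the first layer's weights are inflated on purpose.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
with torch.no_grad():
    model[0].weight.mul_(1e5)

maxima = record_max_activations(model, torch.randn(4, 16))
for name, m in maxima.items():
    print(f"{name}: max|act| = {m:.3e}", "-> overflows fp16" if m > FP16_MAX else "")
```

The forward pass itself runs in fp32, so nothing overflows while measuring. The printout simply flags which layers would exceed the fp16 range if the same computation were done in float16.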
So my conclusion/understanding is this: since bfloat16 has very limited precision, the model basically compensates by training itself to use huge numbers. Rather than having small activation values, it operates in the 1e5 to 1e10+ range, which is beyond what float16 can handle (its largest finite value is 65504, the "64K limit"). float16 therefore overflows to inf, which then immediately leads to nan (see my notebook for how inf/nan comes about).
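The overflow path can be reproduced in a few lines of PyTorch. This is a small sketch; the values 1e5 and 300 are arbitrary illustrative picks, not taken from a real model.

```python
import torch

fp16_max = torch.finfo(torch.float16).max   # 65504.0, the "64K limit"
bf16_max = torch.finfo(torch.bfloat16).max  # ~3.39e38, same exponent range as fp32

# An activation of ~1e5 is representable in bf16 but overflows to inf when cast to fp16.
x_bf16 = torch.tensor([1e5], dtype=torch.bfloat16)
x_fp16 = x_bf16.to(torch.float16)
print(x_bf16, x_fp16)   # bf16 stays finite, fp16 becomes inf

# Once an inf appears, ordinary arithmetic quickly turns it into nan.
print(x_fp16 - x_fp16)  # inf - inf -> nan
print(x_fp16 * 0)       # inf * 0   -> nan

# Values that fit individually can still overflow when combined, e.g. when squared.
y = torch.tensor([300.0], dtype=torch.float16)
print(y * y)            # 90000 > 65504 -> inf
```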
To make things worse, bfloat16's huge number range has huge gaps with no representable numbers in them:
torch.tensor(283, dtype=torch.bfloat16) * 10  # 2848 instead of 2830! (283 itself is stored as 284)
so the model trains to compensate for that handicap as well. And so when float16 comes along, with its much smaller gaps, it obviously won't produce the same results. See my notebook for a demo of the gaps.
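The gaps can be seen directly by round-tripping a run of consecutive integers through each dtype. A minimal sketch; `round_trip` is just an illustrative helper name.

```python
import torch


def round_trip(values, dtype):
    """Store each value in `dtype` and read it back as a Python float."""
    return [torch.tensor(float(v), dtype=dtype).item() for v in values]


ints = list(range(280, 290))
bf16_vals = round_trip(ints, torch.bfloat16)  # 8 significant bits -> spacing of 2 in [256, 512)
fp16_vals = round_trip(ints, torch.float16)   # 11 significant bits -> spacing of 0.25 in [256, 512)
print(bf16_vals)  # several neighbours collapse onto the same value
print(fp16_vals)  # every integer survives unchanged

# The example from the text: 283 is stored as 284, and 284 * 10 = 2840 is rounded to 2848,
# because in [2048, 4096) bf16 can only represent multiples of 16.
print(torch.tensor(283, dtype=torch.bfloat16) * 10)
```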
Ideally there would be some simple transform that could take the weights trained in bfloat16 and convert them into the numerical domain of float16. A naive approach would be to divide everything by ~100000 to shift to a different effective range. But because the model is non-linear, I can't see how this would be possible, other than via some DNN trained to perform such a transform.
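A toy illustration of why a single global rescale doesn't carry over: scaling commutes with linear maps (and with ReLU, which is positively homogeneous), but not with non-linearities such as GELU or softmax. This only uses random tensors and the arbitrary scale `c`; it is not a test of any actual T5/mt5/pegasus weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c = 1e5                      # the naive "divide everything by ~100000" factor
x = torch.randn(8) * 3.0     # stand-in activations
W = torch.randn(8, 8)        # stand-in weight matrix

# Linear maps commute with scaling: W (x / c) == (W x) / c
linear_ok = torch.allclose(W @ (x / c), (W @ x) / c)
# ReLU is positively homogeneous, so it also commutes with a positive scale.
relu_ok = torch.allclose(F.relu(x / c), F.relu(x) / c)
# GELU and softmax do not: scaling the input changes what the function computes.
gelu_ok = torch.allclose(F.gelu(x / c) * c, F.gelu(x))
softmax_ok = torch.allclose(torch.softmax(x / c, dim=0), torch.softmax(x, dim=0))

print(linear_ok, relu_ok, gelu_ok, softmax_ok)  # True True False False
```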
As you can see from the PR, some workarounds may work, but it's hard to keep the numbers in check when the model constantly wants to operate in a range float16 wasn't designed for. One user has already reported NaNs after 3 hours of training with this PR, but hasn't yet shared a way to reproduce it.
@sshleifer suggested here that finetuning with a penalty for large activations could perhaps do the trick. It's unclear how much such finetuning it would take, since the weights need to be lowered by several orders of magnitude so that the activations and accumulated math operations don't break the 64K barrier.
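A minimal sketch of what such a penalty could look like, assuming it is applied to every `nn.Linear` output via forward hooks. The helper names, the `threshold` of 1e3 and the `penalty_weight` are all hypothetical choices for illustration; a real finetuning run would need many steps and tuning, and nothing here shows that it actually succeeds on t5/mt5/pegasus.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def hook_linear_outputs(model: nn.Module, store: list) -> list:
    """Register hooks that append every nn.Linear output to `store` (graph kept for backprop)."""
    return [
        m.register_forward_hook(lambda mod, args, out: store.append(out))
        for m in model.modules()
        if isinstance(m, nn.Linear)
    ]


def large_activation_penalty(acts: list, threshold: float = 1e3) -> torch.Tensor:
    """Hypothetical penalty: mean squared amount by which |activation| exceeds `threshold`."""
    penalty = torch.zeros(())
    for a in acts:
        penalty = penalty + F.relu(a.abs() - threshold).pow(2).mean()
    return penalty


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
with torch.no_grad():
    model[0].weight.mul_(1e4)  # toy stand-in for a model that learned huge activations

acts: list = []
hook_linear_outputs(model, acts)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
penalty_weight = 1e-6  # arbitrary; would need tuning against the task loss

# One training step: task loss plus the weighted activation penalty.
x, target = torch.randn(32, 16), torch.randn(32, 16)
acts.clear()
task_loss = F.mse_loss(model(x), target)
penalty = large_activation_penalty(acts)
loss = task_loss + penalty_weight * penalty
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"task loss {task_loss.item():.3e}, activation penalty {penalty.item():.3e}")
```

The hinge at `threshold` leaves normally sized activations unpenalized, so the extra gradient only pushes on the values that are large enough to matter for fp16.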
So currently t5/mt5/pegasus models are affected, but I'm sure more will follow as new hardware supporting bfloat16 quickly becomes available, so I believe we will have to deal with this a lot more, very soon.
Of course, if we wait long enough, mixed precision will move to fp32/bf16, or may not even be needed anymore.
If any of you have experimented with such bf16-to-fp16 finetuning and had good results, please share. If a solid approach is found, we may need to make a second set of these models whose weights are finetuned for fp16.
Thank you.