To verify that gradients are computed identically in Flax and PyTorch, I ran a script that compares gradient norms:
- Bash script: run_models.sh · patrickvonplaten/codesnippets at main
- Python script: check_gradients_pt_flax.py · patrickvonplaten/codesnippets at main
The checks ran on small random weights of BERT, RoBERTa, T5, and BART as well as on pretrained weights to see if there are any mismatches. I noticed a bug in FlaxBart ([FlaxBart] make sure no grads are computed an bias by patrickvonplaten · Pull Request #16345 · huggingface/transformers · GitHub) that caused a strong mismatch in the gradient of the logits bias.
After fixing this bug, the gradients for both tiny random and pretrained weights appear to be computed correctly. There are only two notable differences:
- facebook/bart-large has very high gradient norms that quickly start to differ between PyTorch and Flax.
- The biases of some attention matrices have very low gradient norms with small relative differences between PyTorch and Flax. This should, however, be completely fine.

=> In conclusion, once [FlaxBart] make sure no grads are computed an bias by patrickvonplaten · Pull Request #16345 · huggingface/transformers · GitHub is merged, the Flax and PyTorch gradients for BERT, RoBERTa, T5, and BART can be considered equal, and we should be careful with facebook/bart-large given its very high gradients.
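To illustrate the two kinds of checks appearing in the logs below ("gradients match" vs. "rel gradients match"), here is a minimal sketch of the comparison logic. The function name, the relative-tolerance value, and the example norms are assumptions for illustration, not the actual code or thresholds from check_gradients_pt_flax.py; only the 0.01 absolute tolerance is taken from the logs.

```python
# Hypothetical helper sketching the two checks: an absolute comparison of
# per-layer gradient norms, and a relative comparison that also flags tiny
# norms whose ratio differs noticeably between frameworks.

ABS_TOL = 0.01   # absolute tolerance, mirroring the "< 0.01" checks in the logs
REL_TOL = 0.1    # assumed relative tolerance; the real script may use another value

def compare_grad_norms(pt_norms, flax_norms, abs_tol=ABS_TOL, rel_tol=REL_TOL):
    """Return (abs_failures, rel_failures): layer names failing each check."""
    abs_fail, rel_fail = [], []
    for name, pt in pt_norms.items():
        fx = flax_norms[name]
        diff = abs(pt - fx)
        # absolute check: large norms (e.g. bart-large kernels) can fail here
        if diff > abs_tol:
            abs_fail.append(name)
        # relative check: normalize by the larger norm, so tiny attention-bias
        # norms with small absolute but large relative differences are flagged
        denom = max(abs(pt), abs(fx))
        if denom > 0 and diff / denom > rel_tol:
            rel_fail.append(name)
    return abs_fail, rel_fail

# Example with norms resembling the logs: a large fc1 kernel norm and a
# tiny k_proj bias norm (values are illustrative).
pt = {"fc1.kernel": 13.111, "k_proj.bias": 9.3e-14}
fx = {"fc1.kernel": 13.055, "k_proj.bias": 1.66e-13}
abs_fail, rel_fail = compare_grad_norms(pt, fx)
# fc1.kernel fails the absolute check but not the relative one;
# k_proj.bias passes absolutely but fails relatively.
```

This matches the pattern in the logs: bart-large's large kernel norms trip the absolute check while their relative difference stays small, and the tiny attention-bias norms pass absolutely but show up in the relative check.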
Results for Bart before fix:
Tiny-random
=========================================
Check hf-internal-testing/tiny-random-bart ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 1000), PyTorch logits shape: torch.Size([2, 64, 1000])
✅ Difference between Flax and PyTorch is 8.940696716308594e-08 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 6.923163414001465, PyTorch loss: 6.923163414001465
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
--------------------------Checking gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09163407981395721.
--------------------------Checking rel gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09163407981395721.
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 9.3028212357852e-14 and flax grad norm 1.6552796459901042e-13.
...
=========================================
Real bart-large and bart-large-cnn
=========================================
Check facebook/bart-large ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 50265), PyTorch logits shape: torch.Size([2, 64, 50265])
✅ Difference between Flax and PyTorch is 0.00039315223693847656 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 15.027304649353027, PyTorch loss: 15.027304649353027
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
--------------------------Checking gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09944064915180206.
❌ Layer ('model', 'decoder', 'layers', '0', 'fc1', 'kernel') has PT grad norm 13.111018180847168 and flax grad norm 13.0546875.
❌ Layer ('model', 'decoder', 'layers', '0', 'fc2', 'kernel') has PT grad norm 8.751346588134766 and flax grad norm 8.71875.
...
❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel') has PT grad norm 18.60892105102539 and flax grad norm 18.59375.
...
❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel') has PT grad norm 96.85579681396484 and flax grad norm 96.8125.
...
❌ Layer ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel') has PT grad norm 199.41278076171875 and flax grad norm 199.25.
...
--------------------------Checking rel gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09944064915180206.
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 1.4212106691502413e-07 and flax grad norm 0.0.
❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.0100719311244575e-08 and flax grad norm 0.0.
...
=========================================
Check facebook/bart-large-cnn ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 50264), PyTorch logits shape: torch.Size([2, 64, 50264])
✅ Difference between Flax and PyTorch is 0.0001919269561767578 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 13.262251853942871, PyTorch loss: 13.262249946594238
✅ Difference between Flax and PyTorch is 1.9073486328125e-06 (< 0.01)
--------------------------Checking gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09764379262924194.
--------------------------Checking rel gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09764379262924194.
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.1513474735002092e-07 and flax grad norm 1.5481474235912174e-07.
❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 3.8047311079481005e-08 and flax grad norm 3.508952062247772e-08.
...
=========================================
After fix
Tiny-random
Check hf-internal-testing/tiny-random-roberta ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 1000), PyTorch logits shape: torch.Size([2, 64, 1000])
✅ Difference between Flax and PyTorch is 1.7881393432617188e-07 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 6.887884140014648, PyTorch loss: 6.887884616851807
✅ Difference between Flax and PyTorch is 4.76837158203125e-07 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('roberta', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 7.584575871001642e-13 and flax grad norm 6.388195094436666e-13.
...
=========================================
Check hf-internal-testing/tiny-random-bert ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 1124), PyTorch logits shape: torch.Size([2, 64, 1124])
✅ Difference between Flax and PyTorch is 1.7881393432617188e-07 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 7.036032199859619, PyTorch loss: 7.036032676696777
✅ Difference between Flax and PyTorch is 4.76837158203125e-07 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 5.234438642080785e-13 and flax grad norm 4.935363641205004e-13.
...
=========================================
Check hf-internal-testing/tiny-random-t5 ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 1103), PyTorch logits shape: torch.Size([2, 64, 1103])
✅ Difference between Flax and PyTorch is 3.725290298461914e-09 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 7.006012916564941, PyTorch loss: 7.006012916564941
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
✅ All rel grads pass
=========================================
Check hf-internal-testing/tiny-random-bart ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 1000), PyTorch logits shape: torch.Size([2, 64, 1000])
✅ Difference between Flax and PyTorch is 8.940696716308594e-08 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 6.919522285461426, PyTorch loss: 6.919522285461426
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.0.
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 1.1293364247239035e-13 and flax grad norm 7.444291358479557e-14.
❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.9028742882613858e-13 and flax grad norm 1.0847509820726894e-13.
...
=========================================
Real models
Check roberta-base ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 50265), PyTorch logits shape: torch.Size([2, 64, 50265])
✅ Difference between Flax and PyTorch is 0.00013017654418945312 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 14.801228523254395, PyTorch loss: 14.801219940185547
✅ Difference between Flax and PyTorch is 8.58306884765625e-06 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('roberta', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 6.889232651019483e-08 and flax grad norm 5.7956174970286156e-08.
...
=========================================
Check bert-base-cased ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 28996), PyTorch logits shape: torch.Size([2, 64, 28996])
✅ Difference between Flax and PyTorch is 5.4836273193359375e-05 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 13.967159271240234, PyTorch loss: 13.967162132263184
✅ Difference between Flax and PyTorch is 2.86102294921875e-06 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 8.025740783068613e-08 and flax grad norm 8.381563532111613e-08.
...
=========================================
Check t5-small ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 32128), PyTorch logits shape: torch.Size([2, 64, 32128])
✅ Difference between Flax and PyTorch is 7.62939453125e-05 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 20.534835815429688, PyTorch loss: 20.534835815429688
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
✅ All rel grads pass
=========================================
Check facebook/bart-large ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 50265), PyTorch logits shape: torch.Size([2, 64, 50265])
✅ Difference between Flax and PyTorch is 0.0004191398620605469 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 13.993148803710938, PyTorch loss: 13.993138313293457
✅ Difference between Flax and PyTorch is 1.049041748046875e-05 (< 0.01)
--------------------------Checking gradients match--------------------------
❌ Layer ('model', 'decoder', 'layers', '0', 'fc1', 'kernel') has PT grad norm 11.655710220336914 and flax grad norm 11.6015625.
❌ Layer ('model', 'decoder', 'layers', '0', 'fc2', 'kernel') has PT grad norm 7.740886211395264 and flax grad norm 7.71484375.
❌ Layer ('model', 'decoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel') has PT grad norm 6.97633171081543 and flax grad norm 6.96484375.
...
--------------------------Checking rel gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.0.
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 8.274865592738934e-08 and flax grad norm 0.0.
❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.2391466458770992e-08 and flax grad norm 0.0.
❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 8.3155640595578e-08 and flax grad norm 0.0.
...
=========================================
Check facebook/bart-large-cnn ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 50264), PyTorch logits shape: torch.Size([2, 64, 50264])
✅ Difference between Flax and PyTorch is 0.0003502368927001953 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 13.418181419372559, PyTorch loss: 13.418176651000977
✅ Difference between Flax and PyTorch is 4.76837158203125e-06 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 3.5387660091146245e-07 and flax grad norm 4.874667069998395e-07.
❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 6.254911966152576e-08 and flax grad norm 6.927437112835833e-08.
❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 5.864935914701164e-08 and flax grad norm 6.345069891722233e-08.
...
=========================================