Gradient verification between JAX/Flax models and PyTorch

To make sure gradients are computed identically in Flax and PyTorch, I ran a script that checks whether the logits, the loss, and the gradient norm of every parameter match between the two frameworks:

Bash script: run_models.sh (patrickvonplaten/codesnippets, main branch)
Python script: check_gradients_pt_flax.py (patrickvonplaten/codesnippets, main branch)

I ran it on the tiny random-weight versions of BERT, RoBERTa, T5, and BART as well as on pretrained weights to see whether there are any mismatches. This revealed a bug in FlaxBart that caused a strong mismatch in the gradient of the final logits bias (final_logits_bias). It is fixed in huggingface/transformers Pull Request #16345, "[FlaxBart] make sure no grads are computed an bias" by patrickvonplaten.
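The per-parameter comparison the script performs can be sketched roughly as follows in plain Python/NumPy. This is a minimal sketch, not the code of check_gradients_pt_flax.py: it assumes both frameworks' gradients have already been collected into dicts keyed by Flax-style parameter paths (tuples such as ('model', 'decoder', 'layers', '0', 'fc1', 'kernel'), as printed in the logs). The helper names grad_norms and compare_grad_norms, the 1e-2 tolerances, and the relative-difference formula are hypothetical choices for illustration.

```python
import numpy as np


def grad_norms(grads):
    """Map each parameter path to the L2 (Frobenius) norm of its gradient."""
    return {path: float(np.linalg.norm(g)) for path, g in grads.items()}


def compare_grad_norms(pt_norms, flax_norms, atol=1e-2, rtol=1e-2):
    """Return (absolute failures, relative failures) as lists of parameter paths.

    Both inputs map a Flax-style parameter path, e.g.
    ('model', 'decoder', 'layers', '0', 'fc1', 'kernel'), to a gradient norm.
    The tolerances are placeholder values (hypothetical), not the real script's.
    """
    abs_failures, rel_failures = [], []
    for path, pt in pt_norms.items():
        fx = flax_norms[path]
        diff = abs(pt - fx)
        if diff >= atol:
            abs_failures.append(path)
        # Relative difference w.r.t. the larger of the two norms. Treat 0 vs 0 as
        # a match so that parameters with no gradient in either framework pass.
        scale = max(abs(pt), abs(fx))
        rel = diff / scale if scale > 0 else 0.0
        if rel >= rtol:
            rel_failures.append(path)
    return abs_failures, rel_failures


# Layout check: a PyTorch-style (out, in) weight and its Flax-style (in, out)
# transpose have the same norm.
w_pt = np.arange(6.0).reshape(2, 3)
print(grad_norms({("w",): w_pt}) == grad_norms({("w",): w_pt.T}))  # True

# Gradient norms copied from the facebook/bart-large log after the fix.
fc1 = ("model", "decoder", "layers", "0", "fc1", "kernel")
k_bias = ("model", "decoder", "layers", "0", "self_attn", "k_proj", "bias")
logits_bias = ("final_logits_bias",)
pt_norms = {fc1: 11.655710220336914, k_bias: 2.2391466458770992e-08, logits_bias: 0.0}
flax_norms = {fc1: 11.6015625, k_bias: 0.0, logits_bias: 0.0}

abs_fail, rel_fail = compare_grad_norms(pt_norms, flax_norms)
print(abs_fail)  # [('model', 'decoder', 'layers', '0', 'fc1', 'kernel')]
print(rel_fail)  # [('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias')]
```

Comparing norms rather than raw arrays means layout differences do not matter: a PyTorch nn.Linear weight is stored as (out_features, in_features) while a Flax Dense kernel is (in_features, out_features), and a matrix and its transpose have the same Frobenius norm. With the facebook/bart-large numbers, the large fc1 kernel fails only the absolute check (relative difference below 0.5%), while the tiny key bias fails only the relative check.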

After fixing this bug, the gradients for both the tiny and the pretrained weights appear to be computed correctly. There are only two notable differences:

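  • facebook/bart-large has very high gradient norms, so even small relative deviations between PyTorch and Flax show up as failures in the gradient check (e.g. 11.6557 vs. 11.6016 for the fc1 kernel of the first decoder layer, a relative difference below 0.5%).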
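  • The biases of some attention projections, such as the key (k_proj / key) biases shown in the logs, have very small gradient norms (on the order of 1e-13 to 1e-7) with noticeable relative differences between PyTorch and Flax. This should however be completely fine: for the key biases the exact gradient is zero, because adding the same vector to every key shifts all attention scores of a query by the same constant, which softmax ignores, so the reported values are just rounding noise.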
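The zero-gradient argument for key biases can be checked numerically. Adding a bias b_k to every key adds the constant q_i · b_k / sqrt(d) to all attention scores of query i, and softmax is invariant to such a shift, so the true gradient with respect to b_k is zero. The NumPy sketch below verifies this on a toy single-head attention layer with random weights; the sizes, seed, loss, and the use of finite differences (instead of either framework's autodiff) are arbitrary illustrative choices, not taken from BART. If this reasoning is right, it would also fit T5 passing the relative check, since T5's attention projections have no bias terms.

```python
import numpy as np


def attention(x, wq, bq, wk, bk, wv, bv):
    """Single-head scaled dot-product self-attention without masking."""
    q = x @ wq + bq
    k = x @ wk + bk
    v = x @ wv + bv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return probs @ v


def loss(params, x, target):
    """Arbitrary scalar loss so that gradients are well defined."""
    return float(np.sum(attention(x, **params) * target))


def numerical_grad(params, name, x, target, eps=1e-6):
    """Central finite-difference gradient of `loss` w.r.t. params[name]."""
    grad = np.zeros_like(params[name])
    for i in range(grad.size):
        plus = {k: v.copy() for k, v in params.items()}
        minus = {k: v.copy() for k, v in params.items()}
        plus[name].flat[i] += eps
        minus[name].flat[i] -= eps
        grad.flat[i] = (loss(plus, x, target) - loss(minus, x, target)) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
seq_len, d = 5, 4
x = rng.normal(size=(seq_len, d))
target = rng.normal(size=(seq_len, d))
params = {n: rng.normal(scale=0.5, size=(d, d)) for n in ("wq", "wk", "wv")}
params.update({n: rng.normal(scale=0.5, size=d) for n in ("bq", "bk", "bv")})

# Shifting the key bias leaves the attention output unchanged ...
shifted = dict(params, bk=params["bk"] + 3.0)
print(np.allclose(attention(x, **params), attention(x, **shifted)))  # True

# ... so its gradient is zero up to rounding, unlike the query bias gradient.
grad_bk = numerical_grad(params, "bk", x, target)
grad_bq = numerical_grad(params, "bq", x, target)
print(np.linalg.norm(grad_bk))  # tiny: finite-difference / rounding noise only
print(np.linalg.norm(grad_bq))  # clearly non-zero
```

Backpropagation in PyTorch and Flax reaches the same zero only up to floating-point rounding, and the two frameworks round differently, which is why the tiny key-bias norms disagree in relative terms.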

=> In conclusion, once huggingface/transformers Pull Request #16345 ("[FlaxBart] make sure no grads are computed an bias") is merged, Flax and PyTorch gradients for BERT, RoBERTa, T5, and BART can be considered equal, but we should be careful with facebook/bart-large given its very high gradient norms.

Results for BART before the fix:

Tiny-random

=========================================
Check hf-internal-testing/tiny-random-bart ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 1000), PyTorch logits shape: torch.Size([2, 64, 1000])
✅ Difference between Flax and PyTorch is 8.940696716308594e-08 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 6.923163414001465, PyTorch loss: 6.923163414001465
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
--------------------------Checking gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09163407981395721.
--------------------------Checking rel gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09163407981395721.
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 9.3028212357852e-14 and flax grad norm 1.6552796459901042e-13.
...
=========================================

Real bart-large and bart-large-cnn

=========================================
Check facebook/bart-large ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 50265), PyTorch logits shape: torch.Size([2, 64, 50265])
✅ Difference between Flax and PyTorch is 0.00039315223693847656 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 15.027304649353027, PyTorch loss: 15.027304649353027
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
--------------------------Checking gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09944064915180206.
❌ Layer ('model', 'decoder', 'layers', '0', 'fc1', 'kernel') has PT grad norm 13.111018180847168 and flax grad norm 13.0546875.
❌ Layer ('model', 'decoder', 'layers', '0', 'fc2', 'kernel') has PT grad norm 8.751346588134766 and flax grad norm 8.71875.
...
❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel') has PT grad norm 18.60892105102539 and flax grad norm 18.59375.
...
❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel') has PT grad norm 96.85579681396484 and flax grad norm 96.8125.
...
❌ Layer ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel') has PT grad norm 199.41278076171875 and flax grad norm 199.25.
...
--------------------------Checking rel gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09944064915180206.
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 1.4212106691502413e-07 and flax grad norm 0.0.
❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.0100719311244575e-08 and flax grad norm 0.0.
...
=========================================
Check facebook/bart-large-cnn ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 50264), PyTorch logits shape: torch.Size([2, 64, 50264])
✅ Difference between Flax and PyTorch is 0.0001919269561767578 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 13.262251853942871, PyTorch loss: 13.262249946594238
✅ Difference between Flax and PyTorch is 1.9073486328125e-06 (< 0.01)
--------------------------Checking gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09764379262924194.
--------------------------Checking rel gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.09764379262924194.
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 2.1513474735002092e-07 and flax grad norm 1.5481474235912174e-07.
❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 3.8047311079481005e-08 and flax grad norm 3.508952062247772e-08.
...
=========================================

After the fix

Tiny-random

Check hf-internal-testing/tiny-random-roberta ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 1000), PyTorch logits shape: torch.Size([2, 64, 1000])
✅ Difference between Flax and PyTorch is 1.7881393432617188e-07 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 6.887884140014648, PyTorch loss: 6.887884616851807
✅ Difference between Flax and PyTorch is 4.76837158203125e-07 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('roberta', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 7.584575871001642e-13 and flax grad norm 6.388195094436666e-13.
...
=========================================
Check hf-internal-testing/tiny-random-bert ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 1124), PyTorch logits shape: torch.Size([2, 64, 1124])
✅ Difference between Flax and PyTorch is 1.7881393432617188e-07 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 7.036032199859619, PyTorch loss: 7.036032676696777
✅ Difference between Flax and PyTorch is 4.76837158203125e-07 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 5.234438642080785e-13 and flax grad norm 4.935363641205004e-13.
...
=========================================
Check hf-internal-testing/tiny-random-t5 ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 1103), PyTorch logits shape: torch.Size([2, 64, 1103])
✅ Difference between Flax and PyTorch is 3.725290298461914e-09 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 7.006012916564941, PyTorch loss: 7.006012916564941
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
✅ All rel grads pass
=========================================
Check hf-internal-testing/tiny-random-bart ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 1000), PyTorch logits shape: torch.Size([2, 64, 1000])
✅ Difference between Flax and PyTorch is 8.940696716308594e-08 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 6.919522285461426, PyTorch loss: 6.919522285461426
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.0.
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 1.1293364247239035e-13 and flax grad norm 7.444291358479557e-14.
❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 1.9028742882613858e-13 and flax grad norm 1.0847509820726894e-13.
...
=========================================

Real models

Check roberta-base ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 50265), PyTorch logits shape: torch.Size([2, 64, 50265])
✅ Difference between Flax and PyTorch is 0.00013017654418945312 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 14.801228523254395, PyTorch loss: 14.801219940185547
✅ Difference between Flax and PyTorch is 8.58306884765625e-06 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('roberta', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 6.889232651019483e-08 and flax grad norm 5.7956174970286156e-08.
...
=========================================
Check bert-base-cased ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 28996), PyTorch logits shape: torch.Size([2, 64, 28996])
✅ Difference between Flax and PyTorch is 5.4836273193359375e-05 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 13.967159271240234, PyTorch loss: 13.967162132263184
✅ Difference between Flax and PyTorch is 2.86102294921875e-06 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias') has PT grad norm 8.025740783068613e-08 and flax grad norm 8.381563532111613e-08.
...
=========================================
Check t5-small ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 32128), PyTorch logits shape: torch.Size([2, 64, 32128])
✅ Difference between Flax and PyTorch is 7.62939453125e-05 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 20.534835815429688, PyTorch loss: 20.534835815429688
✅ Difference between Flax and PyTorch is 0.0 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
✅ All rel grads pass
=========================================
Check facebook/bart-large ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 50265), PyTorch logits shape: torch.Size([2, 64, 50265])
✅ Difference between Flax and PyTorch is 0.0004191398620605469 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 13.993148803710938, PyTorch loss: 13.993138313293457
✅ Difference between Flax and PyTorch is 1.049041748046875e-05 (< 0.01)
--------------------------Checking gradients match--------------------------
❌ Layer ('model', 'decoder', 'layers', '0', 'fc1', 'kernel') has PT grad norm 11.655710220336914 and flax grad norm 11.6015625.
❌ Layer ('model', 'decoder', 'layers', '0', 'fc2', 'kernel') has PT grad norm 7.740886211395264 and flax grad norm 7.71484375.
❌ Layer ('model', 'decoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel') has PT grad norm 6.97633171081543 and flax grad norm 6.96484375.
...
--------------------------Checking rel gradients match--------------------------
❌ Layer ('final_logits_bias',) has PT grad norm 0.0 and flax grad norm 0.0.
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 8.274865592738934e-08 and flax grad norm 0.0.
❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 2.2391466458770992e-08 and flax grad norm 0.0.
❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 8.3155640595578e-08 and flax grad norm 0.0.
...
=========================================
Check facebook/bart-large-cnn ...
--------------------------Checking logits match--------------------------
Flax logits shape: (2, 64, 50264), PyTorch logits shape: torch.Size([2, 64, 50264])
✅ Difference between Flax and PyTorch is 0.0003502368927001953 (< 0.01)
--------------------------Checking losses match--------------------------
Flax loss: 13.418181419372559, PyTorch loss: 13.418176651000977
✅ Difference between Flax and PyTorch is 4.76837158203125e-06 (< 0.01)
--------------------------Checking gradients match--------------------------
✅ All grads pass
--------------------------Checking rel gradients match--------------------------
❌ Layer ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias') has PT grad norm 3.5387660091146245e-07 and flax grad norm 4.874667069998395e-07.
❌ Layer ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 6.254911966152576e-08 and flax grad norm 6.927437112835833e-08.
❌ Layer ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias') has PT grad norm 5.864935914701164e-08 and flax grad norm 6.345069891722233e-08.
...
=========================================

Why do we create bias parameters then if we won’t update them?
Can we just remove them?

Very good point, and removing them is clearly the cleaner solution! The problem is that we would then no longer be able to convert PT-BART checkpoints to JAX-BART, and sadly we cannot update all existing BART checkpoints to correct this mistake.
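One generic JAX way to keep a parameter in the parameter tree (so existing checkpoints still load and convert) while guaranteeing that it receives no gradient is jax.lax.stop_gradient. The sketch below applies it to a hypothetical toy output head; the names lm_head and final_logits_bias, the shapes, and the loss are made up for illustration, and this is not necessarily how the FlaxBart fix in PR #16345 is implemented.

```python
import jax
import jax.numpy as jnp


def logits_fn(params, hidden):
    # Hypothetical toy head: the bias stays in `params` (checkpoint-compatible)
    # but stop_gradient excludes it from differentiation.
    frozen_bias = jax.lax.stop_gradient(params["final_logits_bias"])
    return hidden @ params["lm_head"] + frozen_bias


def loss_fn(params, hidden, labels):
    logits = logits_fn(params, hidden)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Mean negative log-likelihood of the target tokens.
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))


key_h, key_w = jax.random.split(jax.random.PRNGKey(0))
hidden = jax.random.normal(key_h, (4, 8))
params = {
    "lm_head": jax.random.normal(key_w, (8, 10)),
    "final_logits_bias": jnp.zeros((10,)),
}
labels = jnp.array([1, 3, 5, 7])

grads = jax.grad(loss_fn)(params, hidden, labels)
print(jnp.abs(grads["final_logits_bias"]).max())  # 0.0
print(bool(jnp.any(grads["lm_head"] != 0)))  # True
```

Because the bias is still part of params, loading and converting existing PyTorch checkpoints is unaffected; only the gradient flowing into it is blocked.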

I really don’t know why or how this bias was added to BART in the first place.