Modeling BART: JAX vs PyTorch/TensorFlow implementation

I am comparing the BART implementation in JAX with the PyTorch and TensorFlow implementations, and I noticed that the JAX implementation does not apply causal masking, unlike the PyTorch and TensorFlow ones. Is there a reason for this?
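For context, causal masking in a decoder means that each position may attend only to itself and to earlier positions. The following is a minimal sketch of that idea in plain `jax.numpy`. The function names `causal_mask` and `masked_attention_weights` are hypothetical and are not taken from the Hugging Face BART code in any framework. The sketch also assumes single-head attention without batching.

```python
import jax
import jax.numpy as jnp

def causal_mask(seq_len: int) -> jnp.ndarray:
    # Hypothetical helper: lower-triangular boolean matrix.
    # mask[i, j] is True when query position i may attend to key position j (j <= i).
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

def masked_attention_weights(q: jnp.ndarray, k: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    # q, k: [seq_len, head_dim]. Scaled dot-product scores: [seq_len, seq_len].
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    # Replace disallowed (future) positions with the most negative float,
    # so their softmax weight becomes 0.
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jax.nn.softmax(scores, axis=-1)

q = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (4, 8))
weights = masked_attention_weights(q, k, causal_mask(4))
# Every entry above the diagonal (a "future" position) is 0,
# and each row still sums to 1.
```

Every row keeps at least the diagonal entry unmasked, so the softmax never sees a fully masked row and cannot produce NaNs.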