In the original mesh-transformer-jax code, the embedding layer is implemented as a Haiku Linear layer, which has a bias parameter by default:
However, in the HF transformers code, the embedding layer is implemented as a plain nn.Embedding, which has no bias:
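The numerical difference between the two formulations can be sketched in plain numpy (the vocabulary size, dimensions, and weights below are made up for illustration; they are not the actual gpt-j-6B parameters). A Linear layer applied to one-hot token vectors computes `onehot @ W + b`, while an embedding table is a bare row lookup `W[ids]`, so the outputs differ by exactly the bias vector:

```python
import numpy as np

# Hypothetical tiny vocab/dim for illustration only.
vocab, dim = 5, 3
rng = np.random.default_rng(0)
W = rng.standard_normal((vocab, dim))  # shared weight matrix
b = rng.standard_normal(dim)           # bias present in the Linear version

ids = np.array([1, 4])  # some token ids

# mesh-transformer-jax style: Linear over one-hot inputs (bias included)
onehot = np.eye(vocab)[ids]
linear_embed = onehot @ W + b

# HF nn.Embedding style: plain table lookup, no bias
lookup_embed = W[ids]

# The two differ by exactly the bias vector for every token
print(np.allclose(linear_embed - lookup_embed, b))  # True
```

If the original checkpoint learned a nonzero embedding bias, dropping it in the port would shift every input embedding by that constant vector.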
Is this a bug when porting the mesh-transformer-jax gpt-j-6B model to HF?