How to implement learnable position embeddings in Flax?


In the ViT paper, the authors say they use standard learnable 1D position embeddings. I want to implement this in Flax.

If the embeddings are randomly initialized, I think I can just use the nn.Embed class to create them. But how do I apply the embeddings to the inputs?

This is how I think it could be done. Does it make sense?

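import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Any

class PositionEmbed(nn.Module):
    dtype: Any = jnp.float32

    @nn.compact  # required because nn.Embed is created inline in __call__
    def __call__(self, x):
        # x: [N, L, D]
        positions = jnp.arange(x.shape[1])
        embed = nn.Embed(x.shape[1], x.shape[-1], dtype=self.dtype)(positions)  # [L, D]

        # Add the same [L, D] table to every example in the batch.
        batch_apply = jax.vmap(lambda x_: x_ + embed)

        return batch_apply(x)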