In the ViT paper, the authors say they use standard learnable 1D position embeddings. I want to implement this using Flax.
If the embeddings are randomly initialized, I think I can just use the
nn.Embed class to create them. But how do I apply the embeddings to the inputs?
This is how I think it could be done. I wonder if it makes any sense.
class PositionEmbed(nn.Module):
    dtype: Any = jnp.float32

    @compact
    def __call__(self, x):
        '''
        x: [N, L, D]
        '''
        embed = nn.Embed(x.shape, x.shape[-1])(jnp.arange(x.shape))
        batch_apply = jax.vmap(lambda x_: x_ + embed)
        return batch_apply(x)