Hi,
In the ViT paper, the authors say they use standard learnable 1D position embeddings. I want to implement this using Flax.
If the embeddings start from a random initialization, I think I can just use the nn.Embed
class to create them. But how do I apply the embeddings to the inputs?
This is how I think it could be done. Does it make sense?
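from typing import Any
import jax
import jax.numpy as jnp
import flax.linen as nn

class PositionEmbed(nn.Module):
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        '''
        x: [N, L, D]
        '''
        # One learnable D-dimensional vector per position 0..L-1, shape [L, D].
        embed = nn.Embed(x.shape[1], x.shape[-1], dtype=self.dtype)(jnp.arange(x.shape[1]))
        # Add the same [L, D] table to every example in the batch.
        batch_apply = jax.vmap(lambda x_: x_ + embed)
        return batch_apply(x)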