Is resize_token_embeddings available to the FlaxPreTrainedModel?

It seems that resize_token_embeddings is currently not available on FlaxPreTrainedModel, even though the docstrings of its subclasses, for example in modeling_flax_gpt2.py, mention "resizing the input embeddings":
GPT2_START_DOCSTRING = r"""

This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
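
A quick way to confirm this at runtime is to check for the attribute; a minimal sketch, assuming the stock gpt2 checkpoint:

from transformers import FlaxGPT2LMHeadModel

model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
# At the time of writing this prints False: the Flax base class
# does not define resize_token_embeddings.
print(hasattr(model, "resize_token_embeddings"))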

I implemented this function as below; is this correct? Assume that new_size is always greater than old_size, since we usually call this function after adding more tokens to the tokenizer.

import jax
from flax.core.frozen_dict import freeze, unfreeze

def resize_token_embeddings(model, new_size, rnd_key):
    # No-op if the vocabulary already has the requested size.
    if model.config.vocab_size == new_size:
        return
    model.config.vocab_size = new_size
    # Unfreeze the FrozenDict so the parameter tree can be modified.
    params = unfreeze(model.params)
    old_embeddings = params['transformer']['wte']['embedding']
    old_size, dim = old_embeddings.shape
    # Initialize the enlarged matrix the same way the model initializes embeddings.
    initializer = jax.nn.initializers.normal(stddev=model.config.initializer_range)
    new_embeddings = initializer(rnd_key, (new_size, dim), old_embeddings.dtype)
    # Copy the pretrained rows back into the first old_size slots.
    new_embeddings = new_embeddings.at[:old_size].set(old_embeddings)
    params['transformer']['wte']['embedding'] = new_embeddings
    model.params = freeze(params)
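
Usage could then look like this (a sketch: the added tokens are hypothetical, and it assumes a FlaxGPT2LMHeadModel whose output head is tied to the input embedding, which is GPT-2's default):

import jax
from transformers import FlaxGPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")

# Grow the embedding matrix to match the enlarged tokenizer.
tokenizer.add_tokens(["<obs>", "<act>"])  # hypothetical new tokens
resize_token_embeddings(model, len(tokenizer), jax.random.PRNGKey(0))

Note that this only resizes the input embedding (wte); if config.tie_word_embeddings were False, the output head's kernel would need the same treatment.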