It seems that resize_token_embeddings is currently not available in FlaxPreTrainedModel, although the source code of its subclasses, for example modeling_flax_gpt2.py, mentions "resizing the input embeddings":
GPT2_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
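A quick way to confirm this against the installed version is to test for the attribute. A minimal sketch (the checkpoint name is just an example):

from transformers import FlaxGPT2LMHeadModel

model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
# The PyTorch base class (PreTrainedModel) provides resize_token_embeddings,
# but the Flax base class does not, as of this writing.
print(hasattr(model, "resize_token_embeddings"))  # False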
I implemented this function as below; is this correct or not? Assume that new_size is always greater than old_size, since we usually call this function after adding more tokens to the tokenizer.
import jax
from flax.core.frozen_dict import freeze, unfreeze

def resize_token_embeddings(model, new_size, rnd_key):
    # Nothing to do if the vocabulary size is unchanged.
    if model.config.vocab_size == new_size:
        return
    model.config.vocab_size = new_size
    # Model params are a FrozenDict; unfreeze to allow assignment.
    params = unfreeze(model.params)
    old_embeddings = params['transformer']['wte']['embedding']
    old_size, dim = old_embeddings.shape
    # Initialize the enlarged table the same way the model initializes wte.
    initializer = jax.nn.initializers.normal(stddev=model.config.initializer_range)
    new_embeddings = initializer(rnd_key, (new_size, dim))
    # Copy the pretrained rows back in; only rows [old_size:] stay random.
    new_embeddings = new_embeddings.at[:old_size].set(old_embeddings)
    params['transformer']['wte']['embedding'] = new_embeddings
    model.params = freeze(params)
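For reference, a minimal usage sketch under that assumption; the added token strings are hypothetical placeholders:

import jax
from transformers import FlaxGPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")

# Grow the vocabulary, then grow the embedding table to match.
tokenizer.add_tokens(["<hypothetical_tok_1>", "<hypothetical_tok_2>"])
resize_token_embeddings(model, len(tokenizer), jax.random.PRNGKey(0))

Note that this resizes only the input embedding table (wte). GPT-2 ties the LM head weights to wte by default (tie_word_embeddings=True), so this should be enough for FlaxGPT2LMHeadModel; a model with an untied output projection would need its output matrix resized as well.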