T5x Model Checkpoint Surgery

I would like to use the t5x framework to pre-finetune some models. Before doing that, I need to modify the embedding weights of a given checkpoint. For example, take the gs://t5-data/pretrained_models/t5x/t5_small/checkpoint_1000000 checkpoint: I would like to modify t5_small_state_dict['target']['token_embedder']['embedding'] and save t5_small_state_dict back in a form that t5x can use for training.

I also opened this issue, but there has been no response yet, so I wanted to try my luck here in case anyone has experience with this. I am new to the jax, flax, t5x ecosystem and to partitioning, so I am trying to be careful not to do something wrong along the way that might introduce a silent bug during training.

Here is what I've tried so far, using the Kaggle TPU VM (which recently started to fail, probably due to some jax nightly conflicts).

  • Load the t5x checkpoint using
from t5x import checkpoints
t5_small_state_dict = checkpoints.load_t5x_checkpoint("gs://t5-data/pretrained_models/t5x/t5_small/checkpoint_1000000")
  • Modify the jax array
# set the new token embedding array (final_new_emb_array is prepared beforehand)
t5_small_state_dict['target']['token_embedder']['embedding'] = final_new_emb_array
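
To make the modification step concrete, here is a minimal sketch of the swap with a few sanity checks against silent bugs. It runs on a toy stand-in dict rather than the real checkpoint. The helper name replace_embedding and the toy shapes are hypothetical, and the sketch assumes that only the vocabulary dimension may change while the embedding dimension and dtype stay as they were. It does not touch any optimizer state stored elsewhere in the checkpoint, which may also depend on the embedding shape.

```python
import numpy as np


def replace_embedding(state_dict, new_embedding):
    """Return a copy of state_dict with the token embedding replaced.

    Hypothetical helper. Assumes the layout
    state_dict['target']['token_embedder']['embedding'] from the post.
    """
    old = np.asarray(state_dict["target"]["token_embedder"]["embedding"])
    new = np.asarray(new_embedding)

    # Assumption: only the vocabulary size (axis 0) may change. The embedding
    # dimension (axis 1) has to match the rest of the model.
    if new.ndim != 2 or new.shape[1] != old.shape[1]:
        raise ValueError(
            f"expected shape (vocab, {old.shape[1]}), got {new.shape}"
        )

    # Keep the original dtype so the patched array matches what was loaded.
    new = new.astype(old.dtype)

    # Copy only the dicts along the path, so the input is left unmodified.
    target = dict(state_dict["target"])
    embedder = dict(target["token_embedder"])
    embedder["embedding"] = new
    target["token_embedder"] = embedder
    patched = dict(state_dict)
    patched["target"] = target
    return patched


# Toy stand-in for the loaded checkpoint (not the real t5_small shapes).
toy_state_dict = {
    "state": {"step": 1000000},
    "target": {
        "token_embedder": {"embedding": np.zeros((8, 4), dtype=np.float32)}
    },
}
final_new_emb_array = np.ones((10, 4), dtype=np.float64)

patched = replace_embedding(toy_state_dict, final_new_emb_array)
print(patched["target"]["token_embedder"]["embedding"].shape)
```

Returning a patched copy instead of assigning in place keeps the originally loaded dict around, so the old and new arrays can be compared before anything is written back.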

Honestly, I'm not sure what to do next, in particular how to save the modified state dict back so that t5x can restore it.

Ideally, I would load only the embedder partition file, modify it, and save it back. Since I didn’t know exactly how to do that, I used the checkpoints.load_t5x_checkpoint utility function to load all the weights instead.