Example of how to pretrain T5?

Hi, I converted parameters trained with JAX/Flax to the PyTorch version:
model = FlaxT5ForConditionalGeneration.from_pretrained(pretrained_path)
pt_model = T5ForConditionalGeneration.from_pretrained(tmp_path, from_flax=True)

However, some weights of T5ForConditionalGeneration were not initialized from the Flax model.
Here is the full log output:
All Flax model weights were used when initializing T5ForConditionalGeneration.
Some weights of T5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['decoder.embed_tokens.weight', 'encoder.embed_tokens.weight', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

I guess these three weights are all tied to the shared embedding.
So I added three lines before saving the parameters:
pt_model.encoder.embed_tokens.weight.data = model.params['shared']['embedding']._value
pt_model.decoder.embed_tokens.weight.data = model.params['shared']['embedding']._value
pt_model.lm_head.weight.data = model.params['shared']['embedding']._value
pt_model.save_pretrained(tmp_path)
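
To sanity-check the guess that the warning only lists tied parameters, something like the sketch below might help. It is plain Python and does not touch the model. The helper name `unexplained_missing_keys` and the two key sets are my own and are not part of transformers. They encode my assumption that the encoder and decoder `embed_tokens` always share storage with `shared.weight` in the PyTorch T5 classes, and that `lm_head.weight` does so only when `config.tie_word_embeddings` is True.

```python
# Assumption: these parameters reuse `shared.weight` in the PyTorch T5 model,
# so they may be reported as "newly initialized" even when nothing is missing.
ALWAYS_TIED = {"encoder.embed_tokens.weight", "decoder.embed_tokens.weight"}
# Assumption: the LM head is tied only if config.tie_word_embeddings is True.
TIED_IF_WORD_EMBEDDINGS_TIED = {"lm_head.weight"}


def unexplained_missing_keys(missing_keys, tie_word_embeddings=True):
    """Return the reported keys that tying to `shared.weight` does not explain."""
    explained = set(ALWAYS_TIED)
    if tie_word_embeddings:
        explained |= TIED_IF_WORD_EMBEDDINGS_TIED
    return sorted(set(missing_keys) - explained)


# Keys copied from the warning in the log output.
reported = [
    "decoder.embed_tokens.weight",
    "encoder.embed_tokens.weight",
    "lm_head.weight",
]

print(unexplained_missing_keys(reported))                             # []
print(unexplained_missing_keys(reported, tie_word_embeddings=False))  # ['lm_head.weight']
```

If the first call returns an empty list and my assumption about tying holds, the warning may be harmless and the manual copy redundant. If the config has `tie_word_embeddings` set to False, `lm_head.weight` would be a separate matrix, and overwriting it with the embedding is probably not what I want.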

Is this the right way to do it?