Pjit basics - flax

Hey @nbroad !

Sorry for being late here. I have a little experience with pjit on optimizer states and model params, so I will try to answer you to the best of my knowledge.

First of all, as you mentioned, it is important to set the partition rules correctly, and I can see from the function _get_partition_rules_t5() that you did. Also, please make sure that these names are set correctly by the DenseGeneral layers you use (see for example here), matching the kernel axes with the values set in the tuples above.
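
For readers following along, rules of this kind are usually a list of (parameter-name pattern, PartitionSpec) tuples. The entries below are purely illustrative and are not your actual _get_partition_rules_t5(); they just show the general shape:

from jax.sharding import PartitionSpec as P  # jax.experimental.PartitionSpec on older JAX

# Illustrative only: "mp" shards that dimension across the model-parallel mesh axis,
# None leaves it replicated.
def _get_partition_rules_t5():
    return [
        (("SelfAttention", "(q|k|v)", "kernel"), P(None, "mp")),  # shard output features
        (("SelfAttention", "o", "kernel"), P("mp", None)),        # shard input features
        (("DenseReluDense", "wi", "kernel"), P(None, "mp")),
        (("DenseReluDense", "wo", "kernel"), P("mp", None)),
        (("shared", "embedding"), P("mp", None)),                 # shard the vocab dim
        ((r".*", "layer_norm", "weight"), None),                  # replicate layer norms
    ]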

I can also see that you are using the

with mesh(mesh_devices, ("dp", "mp")):

context manager. I think that in the latest version of t5x it is preferable to use the PjitPartitioner from t5x.partitioning. The way you initialize the partitioner is a bit tricky; see for instance this example, based on the bloom inference script (a minimal initialization sketch is also included after the snippet below). Once it is initialized, you have to shard the parameters of your model using this partitioner by calling its partition method, which takes as arguments a function and the param_specs that you can retrieve from the partitioner (see this example).
In my project I did something like the following:

# pjit an identity function through the partitioner; calling it shards the params
# according to student_params_spec
shard_params = self.student_partitioner.partition(lambda x: x, (self.student_params_spec,), self.student_params_spec)
self.student_params = shard_params(freeze(self.student_params))
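
For completeness, here is a minimal sketch of how the partitioner itself might be initialized, loosely following the bloom inference script; the number of partitions and the rule set are placeholders, not the exact values from my project:

from t5x import partitioning

# Start from t5x's standard logical-axis rules; the bloom script extends these
# with model-specific entries before passing them in.
partitioner = partitioning.PjitPartitioner(
    num_partitions=2,  # placeholder: how many ways to split the model over "mp"
    logical_axis_rules=partitioning.standard_logical_axis_rules(),
)

# The parameter PartitionSpecs used above (student_params_spec in my case) can then be
# retrieved from the partitioner, e.g. via partitioner.get_mesh_axes(...) on an abstract
# train state, as done in the bloom sharding example.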

You can double check that the sharding has been done successfully by printing the parameters afterwards and checking that the dictionary contains ShardedDeviceArray leaves instead of DeviceArray.
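
For example, one quick way to inspect the leaf types (a small sketch, assuming the sharded params from the snippet above):

# Print the type of every parameter leaf; after sharding you should see
# ShardedDeviceArray (or a jax.Array with a sharding on newer JAX versions).
print(jax.tree_map(lambda x: type(x).__name__, self.student_params))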

So much for the parameter partitioning; for the optimizer states the process is quite similar. First of all, please use the latest version of t5x: a large refactoring has been done since the deprecation of flax.optim in favor of optax, and this impacted the t5x library, which now ships dedicated optimizer support.

Next, import t5x.optimizers as optimizers and get your favorite optimizer_def from this package; I used adam, for instance. You can declare it with something like:

optimizer_def = getattr(optimizers, "adam")(self.params.learning_rate)

Afterwards, import the AxisMetadata class:

from flax.linen import partitioning
AxisMetadata = partitioning.AxisMetadata

This is because initializing the t5x optimizer state requires a model_variables argument with a params_axes key that contains the axes specification as AxisMetadata instances. Overall, these lines should do the trick:

# model_params_spec mirrors the parameter tree and holds the axis names of each parameter
model_variables = flax.core.freeze({
    "params": model_params,
    "params_axes": jax.tree_map(lambda x: AxisMetadata(names=tuple(x)), model_params_spec),
})
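
To illustrate what that tree_map produces, here is a toy leaf (hypothetical axis names, assuming model_params_spec stores the axes of each parameter as a tuple-like value):

spec_leaf = ("embed", "mlp")                      # hypothetical leaf of model_params_spec
axis_leaf = AxisMetadata(names=tuple(spec_leaf))  # -> AxisMetadata(names=("embed", "mlp"))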

Then create your optimizer state by calling:


from t5x.train_state import FlaxOptimTrainState

def init_state(model_variables):
    return FlaxOptimTrainState.create(
        optimizer_def=optimizer_def,
        model_variables=model_variables,
    )

state = init_state(model_variables)

Again, you can confirm everything is correct by checking state.params and making sure it contains ShardedDeviceArray leaves.

For the training procedure, nothing crazy: compute your gradients in the usual JAX style, then apply the updates with:

state = state.apply_gradient(grad, learning_rate=self.params.learning_rate)

And confirm that the parameters are still ShardedDeviceArrays.
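
To put it all together, here is a rough sketch of a single training step under these assumptions; model, cross_entropy, batch, learning_rate and the *_spec values are placeholders from my setup, not a fixed API:

def train_step(state, batch):
    def loss_fn(params):
        logits = model.apply({"params": params}, batch["input_ids"])  # placeholder model call
        return cross_entropy(logits, batch["labels"])                 # placeholder loss
    grad = jax.grad(loss_fn)(state.params)
    # same call as above: apply_gradient returns a new (still sharded) train state
    return state.apply_gradient(grad, learning_rate=learning_rate)

# pjit the whole step through the partitioner so the state keeps its sharding
p_train_step = partitioner.partition(
    train_step,
    (state_spec, batch_spec),  # in_axis_resources (placeholder PartitionSpecs)
    state_spec,                # out_axis_resources
)
state = p_train_step(state, batch)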

Hope this helps! Good references are, for inference, the Flax bloom inference script: bloom-jax-inference/sharding_example.py at main · huggingface/bloom-jax-inference · GitHub, and for training, the testing script from t5x: t5x/optimizers_test.py at main · google-research/t5x · GitHub.