Pjit basics - flax

Hello everyone,

I am trying to shard a model using pjit and flax, and I have some questions.

The first one has to do with the init_weights function. I think I need to set _do_init=False and then call model.init_weights(model.key, model.input_shape, params). That call needs to be pjit’d, to split the model across devices, and run inside the Mesh context manager. Is that right? I couldn’t find any examples of this.

After getting the sharded params, I then need to get the optimizer state. For that I need an optimizer state spec (opt_state_spec), which I can get using this:

def get_opt_spec(x):
    if isinstance(x, dict):
        return param_spec
    return None

opt_state_spec, param_spec = jax.tree_util.tree_map(
    get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
)

This makes it look like only dict values get sharded and everything else does not. Is that correct?
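A pure-Python sketch of what that tree_map call does, assuming an AdamW-like optax state of the form (ScaleByAdamState(count, mu, nu), EmptyState()). The namedtuples and the small tree_map below are hypothetical stand-ins for optax and jax.tree_util so the example runs without either library. It suggests the answer is yes: whole dict subtrees (mu and nu, which mirror the params) get param_spec, while everything else, such as the step count and EmptyState, gets None, which pjit treats as fully replicated rather than sharded.

```python
from typing import Any, Callable, NamedTuple


# Hypothetical stand-ins for optax.ScaleByAdamState and optax.EmptyState.
class ScaleByAdamState(NamedTuple):
    count: Any
    mu: Any
    nu: Any


class EmptyState(NamedTuple):
    pass


def tree_map(fn: Callable, tree: Any, is_leaf: Callable[[Any], bool]) -> Any:
    """Tiny stand-in for jax.tree_util.tree_map(fn, tree, is_leaf=...)."""
    if is_leaf(tree):
        return fn(tree)
    if isinstance(tree, tuple) and hasattr(tree, "_fields"):  # namedtuple node
        return type(tree)(*(tree_map(fn, t, is_leaf) for t in tree))
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, t, is_leaf) for t in tree)
    if isinstance(tree, dict):
        return {k: tree_map(fn, v, is_leaf) for k, v in tree.items()}
    return fn(tree)  # anything else is an ordinary leaf


param_spec = {"dense": {"kernel": (None, "mp")}}  # hypothetical param spec


def get_opt_spec(x):
    # Same rule as in the post: dict subtrees mirror the params.
    if isinstance(x, dict):
        return param_spec
    return None  # None means "replicate" in pjit's axis resources


params = {"dense": {"kernel": "f32[512,2048]"}}  # shape placeholders
opt_state = (ScaleByAdamState(count="i32[]", mu=params, nu=params), EmptyState())

opt_state_spec, p_spec = tree_map(
    get_opt_spec,
    (opt_state, params),  # like state_shapes = (tuple(state), params)
    is_leaf=lambda x: isinstance(x, (dict, EmptyState)),
)
print(opt_state_spec)
```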

Then I can get the optimizer state by doing this. Why do I need to freeze the params when passing them into p_get_initial_state?

def get_initial_state(params):
    state = optimizer.init(params)
    return tuple(state), params

p_get_initial_state = pjit(
    get_initial_state,
    in_axis_resources=None,
    out_axis_resources=(opt_state_spec, param_spec),
)

with mesh(mesh_devices, ("dp", "mp")):
    opt_state, params = p_get_initial_state(freeze(params))
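A likely explanation for the freeze (not confirmed in this thread): pjit requires the outputs' pytree structure to match out_axis_resources, and pytree structure includes the container type, not just the keys. set_partitions_t5 returns a FrozenDict, so param_spec and the mu/nu entries of opt_state_spec are FrozenDicts, and the params flowing through get_initial_state must be FrozenDicts as well. The sketch below models that check in plain Python; FrozenDictStandIn and the structure helper are hypothetical stand-ins (flax's FrozenDict, like the stand-in, is a Mapping rather than a dict subclass).

```python
from collections.abc import Mapping
from typing import Any


class FrozenDictStandIn(Mapping):
    """Minimal immutable mapping, like flax's FrozenDict (not a dict subclass)."""

    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


def freeze(d: dict) -> FrozenDictStandIn:
    """Recursively convert nested dicts into the frozen stand-in."""
    return FrozenDictStandIn(
        {k: freeze(v) if isinstance(v, dict) else v for k, v in d.items()}
    )


def structure(tree: Any) -> Any:
    """Container types and keys with leaves erased, roughly like a JAX treedef."""
    if isinstance(tree, Mapping):
        return (type(tree).__name__, tuple((k, structure(tree[k])) for k in sorted(tree)))
    return "*"


param_spec = freeze({"dense": {"kernel": (None, "mp")}})
params = {"dense": {"kernel": "f32[512,2048]"}}

print(structure(params) == structure(param_spec))          # dict vs frozen: differ
print(structure(freeze(params)) == structure(param_spec))  # both frozen: match
```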

For the param_spec partitioning, I have these:

import re
from flax.core.frozen_dict import freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.experimental import PartitionSpec as P

# utils copied from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py

# Sentinels
_unmatched = object()

# For specifying empty leaf dict `{}`
empty_dict = object()


def _match(qs, ks):
    """Return True if regexes in qs match any window of strings in tuple ks."""
    # compile regexes and force complete match
    qts = tuple(map(lambda x: re.compile(x + "$"), qs))
    for i in range(len(ks) - len(qs) + 1):
        matches = [x.match(y) for x, y in zip(qts, ks[i:])]
        if matches and all(matches):
            return True
    return False


def _replacement_rules(rules):
    def replace(key, val):
        for rule, replacement in rules:
            if _match(rule, key):
                return replacement
        return val

    return replace


def _get_partition_rules_t5():
    return [
        (("SelfAttention", "relative_attention_bias", "embedding"), None),
        (("shared", "embedding"), P("mp", None)),
        ((r"SelfAttention", "(q|k|v)", "kernel"), P(None, "mp")),
        ((r"SelfAttention", "o", "kernel"), P("mp", None)),
        ((r"EncDecAttention", "(q|k|v)", "kernel"), P(None, "mp")),
        ((r"EncDecAttention", "o", "kernel"), P("mp", None)),
        ((r"DenseReluDense", "wi_0", "kernel"), P(None, "mp")),
        ((r"DenseReluDense", "wi_1", "kernel"), P(None, "mp")),
        ((r"DenseReluDense", "wi", "kernel"), P(None, "mp")),
        ((r"DenseReluDense", "wo", "kernel"), P("mp", None)),
        ((r"layer_norm", "weight"), None),
        ((r"final_layer_norm", "weight"), None),
        (("lm_head", "kernel"), P(None, "mp")),
    ]


def set_partitions_t5(in_dict):
    rules = _get_partition_rules_t5()
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    assert _unmatched not in result.values(), "Incomplete partition spec."
    return freeze(unflatten_dict(result))
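To see which rule each parameter path resolves to, here is a self-contained sketch that reuses the _match and _replacement_rules logic above with three of the T5 rules. A plain tuple stands in for PartitionSpec and a hypothetical flatten helper stands in for flax.traverse_util.flatten_dict so it runs without JAX or flax; it also raises a ValueError instead of asserting on unmatched paths, and keeps the result flat. The parameter paths are made-up T5-like examples.

```python
import re


def P(*axes):
    """Stand-in for jax PartitionSpec: a plain tuple of mesh-axis names."""
    return tuple(axes)


_unmatched = object()  # sentinel for parameters no rule covers


def _match(qs, ks):
    """Return True if regexes in qs match any contiguous window of strings in ks."""
    qts = tuple(re.compile(x + "$") for x in qs)  # "$" plus .match() = full match
    for i in range(len(ks) - len(qs) + 1):
        matches = [x.match(y) for x, y in zip(qts, ks[i:])]
        if matches and all(matches):
            return True
    return False


def _replacement_rules(rules):
    def replace(key, val):
        for rule, replacement in rules:  # first matching rule wins
            if _match(rule, key):
                return replacement
        return val

    return replace


def flatten(d, prefix=()):
    """Hypothetical stand-in for flax.traverse_util.flatten_dict."""
    out = {}
    for k, v in d.items():
        if isinstance(v, dict):
            out.update(flatten(v, prefix + (k,)))
        else:
            out[prefix + (k,)] = v
    return out


def set_partitions(params, rules):
    replace = _replacement_rules(rules)
    result = {k: replace(k, _unmatched) for k in flatten(params)}
    if any(v is _unmatched for v in result.values()):
        raise ValueError("Incomplete partition spec.")
    return result  # kept flat (path tuple -> spec) for easy inspection


RULES = [
    (("SelfAttention", "(q|k|v)", "kernel"), P(None, "mp")),
    (("SelfAttention", "o", "kernel"), P("mp", None)),
    (("layer_norm", "weight"), None),
]
layer = {
    "SelfAttention": {"q": {"kernel": 0}, "o": {"kernel": 0}},
    "layer_norm": {"weight": 0},
}
params = {"encoder": {"block": {"0": {"layer": {"0": layer}}}}}
for path, spec in set_partitions(params, RULES).items():
    print("/".join(path), spec)
```

Because _match slides each rule over every window of the path, a rule does not need to spell out the full prefix (encoder/block/0/layer/0/...), and because the first matching rule wins, more specific rules should come before more general ones.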

I used these as references. The Hugging Face example might need to be updated.
Model parallel language model training example
Dalle-mini training script

@patrickvonplaten , @valhalla , @sanchit-gandhi, your assistance would be much appreciated.


Hey @nbroad!

Sorry for the late reply. I have a little experience with pjit on optimizer states and model params, so I will try to answer to the best of my knowledge.

First of all, as you mentioned, it is important to set the partition rules correctly, and judging from your _get_partition_rules_t5() function you did. Also, make sure the parameter names actually match those rules: if you use flax.layers.DenseGeneral layers (see for example here), check that the kernel axes match the values set in the tuples above.
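To make the kernel-axis matching concrete: P(None, "mp") on a (d_model, d_ff) kernel splits the second dimension across the "mp" devices, while mesh axes a spec does not mention (here "dp") do not split the array, so each shard is repeated along them. The hypothetical helper below computes the per-device shape from a PartitionSpec-like tuple; it assumes evenly divisible dimensions and treats None, or a missing trailing entry, as "not split". The 2 x 4 mesh and the 512/2048 sizes are made-up numbers.

```python
from typing import Dict, Optional, Sequence, Tuple


def shard_shape(
    shape: Sequence[int],
    spec: Optional[Tuple[Optional[str], ...]],
    mesh_axes: Dict[str, int],
) -> Tuple[int, ...]:
    """Per-device shape of an array laid out with a PartitionSpec-like tuple."""
    if spec is None:  # no spec: fully replicated, every device holds all of it
        return tuple(shape)
    if len(spec) > len(shape):
        raise ValueError("spec has more entries than the array has dimensions")
    out = []
    for i, dim in enumerate(shape):
        axis = spec[i] if i < len(spec) else None  # missing entries: not split
        if axis is None:
            out.append(dim)
            continue
        n = mesh_axes[axis]
        if dim % n:
            raise ValueError(f"dim {dim} is not divisible by mesh axis {axis!r} ({n})")
        out.append(dim // n)
    return tuple(out)


mesh_axes = {"dp": 2, "mp": 4}  # hypothetical 8-device ("dp", "mp") mesh
print(shard_shape((512, 2048), (None, "mp"), mesh_axes))  # (512, 512): q/k/v, wi kernels
print(shard_shape((2048, 512), ("mp", None), mesh_axes))  # (512, 512): o, wo kernels
```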

I can also see that you are using the

with mesh(mesh_devices, ("dp", "mp")):

context manager. I think that in the latest version of t5x it is preferable to use the PjitPartitioner from t5x.partitioning. Initializing the partitioner is a bit tricky; see for instance this example, based on the bloom inference script. Once it is initialized, you have to shard your model's parameters with this partitioner by calling its partition method, which takes a function plus the param specs for its inputs and outputs; you can retrieve the param specs from the partitioner (see this example).
In my project I did something like the following:

shard_params = self.student_partitioner.partition(lambda x: x, (self.student_params_spec,), self.student_params_spec)
self.student_params = shard_params(freeze(self.student_params))

You can double-check that this worked by printing the parameters afterwards and checking that the dictionary contains ShardedDeviceArray instead of DeviceArray.
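Rather than eyeballing a printed params tree, a small hypothetical helper can list the paths of leaves that fail a check. What the check should be depends on your JAX version: in this thread it is the array type (ShardedDeviceArray versus DeviceArray), while newer jax.Array values expose a .sharding attribute you could inspect instead. The helper walks any Mapping, so it also handles flax FrozenDicts; the two array classes below are stand-ins so the example runs without JAX.

```python
from collections.abc import Mapping
from typing import Any, Callable, List, Tuple


def failing_leaves(
    tree: Any, ok: Callable[[Any], bool], path: Tuple[str, ...] = ()
) -> List[str]:
    """Return '/'-joined paths of leaves in a nested mapping for which ok(leaf) is False."""
    if isinstance(tree, Mapping):
        bad: List[str] = []
        for key, value in tree.items():
            bad.extend(failing_leaves(value, ok, path + (str(key),)))
        return bad
    return [] if ok(tree) else ["/".join(path)]


# Hypothetical stand-ins for a sharded and an unsharded array type.
class ShardedDeviceArray:
    pass


class DeviceArray:
    pass


params = {
    "encoder": {"q": {"kernel": ShardedDeviceArray()}},
    "lm_head": {"kernel": DeviceArray()},
}
print(failing_leaves(params, lambda x: isinstance(x, ShardedDeviceArray)))
```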

That covers partitioning the parameters; for the optimizer states the process is similar. First of all, please use the latest version of t5x: flax.optim was deprecated in favor of optax, and t5x went through a large refactoring as a result, adding its own optimizer support.

Next, import t5x.optimizers as optimizers and get your preferred optimizer_def from that package (I used adam, for instance). You can declare it with something like:

optimizer_def = getattr(optimizers, "adam")(self.params.learning_rate)

Afterwards, import the AxisMetadata class from flax.linen.partitioning:

from flax.linen import partitioning
AxisMetadata = partitioning.AxisMetadata

You need it because initializing the t5x optimizer state requires a model_variables argument with a params_axes key containing the axis specification, where each entry has to be an AxisMetadata instance. Overall, this should do the trick:

model_variables = flax.core.freeze({
    "params": model_params,
    "params_axes": jax.tree_map(lambda x: AxisMetadata(names=tuple(x)), model_params_spec),
})
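The params_axes entry above wraps every PartitionSpec leaf of model_params_spec in an AxisMetadata. A pure-Python sketch of that mapping, with a hypothetical AxisMetadata namedtuple and plain tuples for the specs: it mirrors the fact that jax.tree_map treats None as an empty subtree, so replicated parameters whose spec is None (e.g. layer norms) are passed through as None instead of being wrapped.

```python
from typing import Any, NamedTuple, Optional, Tuple


class AxisMetadata(NamedTuple):
    """Hypothetical stand-in for flax.linen.partitioning.AxisMetadata."""
    names: Tuple[Optional[str], ...]


def to_axis_metadata(spec_tree: Any) -> Any:
    """Wrap each spec leaf in AxisMetadata, leaving None entries untouched."""
    if isinstance(spec_tree, dict):
        return {k: to_axis_metadata(v) for k, v in spec_tree.items()}
    if spec_tree is None:
        return None  # mirrors jax.tree_map, which never calls f on None
    return AxisMetadata(names=tuple(spec_tree))


model_params_spec = {
    "SelfAttention": {"q": {"kernel": (None, "mp")}},
    "layer_norm": {"weight": None},
}
params_axes = to_axis_metadata(model_params_spec)
print(params_axes)
```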

Then create your optimizer state by calling:

from t5x.train_state import FlaxOptimTrainState

def init_state(model_variables):
    return FlaxOptimTrainState.create(optimizer_def, model_variables)

state = init_state(model_variables)

Again, you can confirm everything is correct by checking state.params and making sure it contains ShardedDeviceArray.

The training procedure is nothing special: compute your gradients the usual JAX way, then apply the updates with:

state = state.apply_gradient(grad, learning_rate=self.params.learning_rate)

Then confirm that the parameters are still ShardedDeviceArray.

Hope this helps! A good reference for inference is the Flax bloom inference script: bloom-jax-inference/sharding_example.py at main · huggingface/bloom-jax-inference · GitHub; for training, I used to check the t5x test script: t5x/optimizers_test.py at main · google-research/t5x · GitHub

@nbroad Did you solve this problem? I’m trying to shard a Flax model as well. Do you have an open-source project where you share what you learned? Thanks.