Pjit basics - flax

Hello everyone,

I am trying to shard a model using pjit and flax, and I have some questions.

The first one concerns the init_weights function. I think I need to load the model with _do_init=False and then call model.init_weights(model.key, model.input_shape, params). That call needs to be pjit’d, so that the params are split across devices, and it needs to run inside the Mesh context manager. Is that right? I couldn’t find any examples of this.
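For comparison, here is a minimal sketch of sharded initialization with the newer JAX API, where pjit has been merged into jax.jit and out_axis_resources is spelled out_shardings. It assumes a hypothetical init_params function standing in for model.init_weights, a toy two-layer param tree, and a 1 x N mesh with the same ("dp", "mp") axis names. Because the shardings are NamedSharding objects bound to the mesh, no mesh context manager is needed, and the returned params are already sharded according to param_spec.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def init_params(key):
    """Hypothetical stand-in for model.init_weights: builds a params dict."""
    k1, k2 = jax.random.split(key)
    return {
        "dense": {
            "kernel": jax.random.normal(k1, (16, 32)),
            "bias": jnp.zeros((32,)),
        },
        "out": {"kernel": jax.random.normal(k2, (32, 16))},
    }


# 1 x N mesh: "dp" has size 1, "mp" spans all devices.
# The sharded dims (16 and 32) must be divisible by the "mp" size.
devices = np.array(jax.devices()).reshape(1, jax.device_count())
mesh = Mesh(devices, ("dp", "mp"))

# One PartitionSpec per leaf; P() would mean fully replicated.
param_spec = {
    "dense": {"kernel": P(None, "mp"), "bias": P("mp")},
    "out": {"kernel": P("mp", None)},
}
param_shardings = jax.tree_util.tree_map(
    lambda spec: NamedSharding(mesh, spec),
    param_spec,
    is_leaf=lambda x: isinstance(x, P),
)

# The jitted init returns arrays laid out according to param_shardings.
sharded_init = jax.jit(init_params, out_shardings=param_shardings)
params = sharded_init(jax.random.PRNGKey(0))
```

With 8 devices, for example, each shard of dense/kernel would have shape (16, 4).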

After getting the sharded params, I need to create the optimizer state. For that I need an opt_state_spec, which I compute like this:

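```python
# state_shapes: the (opt_state, params) structure returned by get_initial_state
def get_opt_spec(x):
    # dict subtrees mirror the params (e.g. Adam's mu/nu) and reuse param_spec
    if isinstance(x, dict):
        return param_spec
    # any other leaf (step count, EmptyState) gets None, i.e. is not partitioned
    return None

opt_state_spec, param_spec = jax.tree_util.tree_map(
    get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
)
```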

This makes it look like only the dict-valued entries (the ones shaped like the params) get sharded with param_spec, and everything else is left unsharded. Is that correct?

Then I get the optimizer state as follows. Why do I need to freeze the params when passing them to p_get_initial_state?

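```python
from flax.core.frozen_dict import freeze
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit

def get_initial_state(params):
    state = optimizer.init(params)
    return tuple(state), params

# Inputs are replicated; outputs are sharded according to the specs above.
p_get_initial_state = pjit(
    get_initial_state,
    in_axis_resources=None,
    out_axis_resources=(opt_state_spec, param_spec),
)

with mesh(mesh_devices, ("dp", "mp")):
    opt_state, params = p_get_initial_state(freeze(params))
```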

For the param_spec partitioning, I use these utilities:

```python
import re

from flax.core.frozen_dict import freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.experimental import PartitionSpec as P

# Utils copied from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py

# Sentinels
_unmatched = object()

# For specifying empty leaf dict `{}`
empty_dict = object()


def _match(qs, ks):
    """Return True if regexes in qs match any window of strings in tuple ks."""
    # compile regexes and force complete match
    qts = tuple(map(lambda x: re.compile(x + "$"), qs))
    for i in range(len(ks) - len(qs) + 1):
        matches = [x.match(y) for x, y in zip(qts, ks[i:])]
        if matches and all(matches):
            return True
    return False


def _replacement_rules(rules):
    def replace(key, val):
        for rule, replacement in rules:
            if _match(rule, key):
                return replacement
        return val

    return replace


def _get_partition_rules_t5():
    return [
        (("SelfAttention", "relative_attention_bias", "embedding"), None),
        (("shared", "embedding"), P("mp", None)),
        ((r"SelfAttention", "(q|k|v)", "kernel"), P(None, "mp")),
        ((r"SelfAttention", "o", "kernel"), P("mp", None)),
        ((r"EncDecAttention", "(q|k|v)", "kernel"), P(None, "mp")),
        ((r"EncDecAttention", "o", "kernel"), P("mp", None)),
        ((r"DenseReluDense", "wi_0", "kernel"), P(None, "mp")),
        ((r"DenseReluDense", "wi_1", "kernel"), P(None, "mp")),
        ((r"DenseReluDense", "wi", "kernel"), P(None, "mp")),
        ((r"DenseReluDense", "wo", "kernel"), P("mp", None)),
        ((r"layer_norm", "weight"), None),
        ((r"final_layer_norm", "weight"), None),
        (("lm_head", "kernel"), P(None, "mp")),
    ]


def set_partitions_t5(in_dict):
    rules = _get_partition_rules_t5()
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    assert _unmatched not in result.values(), "Incomplete partition spec."
    return freeze(unflatten_dict(result))
```
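As a cross-check on the rules, here is a looser, path-based variant. It is a swapped-in technique, not the t5x utilities above: it joins each parameter path with "/" and applies the first regex that matches, via jax.tree_util.tree_map_with_path. The rule subset and the toy encoder tree are hypothetical. Unlike _match, which requires each rule component to fully match a consecutive window of key names, re.search on the joined path can also hit partial names, so the patterns need care. Raising on an unmatched path plays the role of the "Incomplete partition spec." assertion.

```python
import re

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# Hypothetical subset of the T5 rules, written against "/"-joined param paths.
RULES = [
    (r"SelfAttention/(q|k|v)/kernel$", P(None, "mp")),
    (r"SelfAttention/o/kernel$", P("mp", None)),
    (r"layer_norm/weight$", P()),  # P() = fully replicated
]


def spec_for(path, leaf):
    # path is a tuple of DictKey entries; k.key is the dict key at each level.
    name = "/".join(str(k.key) for k in path)
    for pattern, spec in RULES:
        if re.search(pattern, name):
            return spec
    raise ValueError(f"Incomplete partition spec: no rule matches {name}")


# Hypothetical toy parameter tree.
params = {
    "encoder": {
        "SelfAttention": {
            "q": {"kernel": jnp.zeros((8, 8))},
            "o": {"kernel": jnp.zeros((8, 8))},
        },
        "layer_norm": {"weight": jnp.ones((8,))},
    }
}
param_spec = jax.tree_util.tree_map_with_path(spec_for, params)
```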

I used these as references. The Hugging Face example might need to be updated.
Model parallel language model training example
Dalle-mini training script

@patrickvonplaten , @valhalla , @sanchit-gandhi, your assistance would be much appreciated.