Hello everyone,
I am trying to shard a model using pjit and flax, and I have some questions.
The first one has to do with the init_weights
function. I think I need to set _do_init=False
and then call model.init_weights(model.key, model.input_shape, params)
. That call needs to be pjit’d to split the model across devices and within the Mesh context manager. Is that right? I couldn’t find any examples of this.
After getting the sharded params, I then need to get the optimizer state. To get the optimizer state, I need to have an optimizer_state_spec which I can get using this.
def get_opt_spec(x):
    """Return the sharding spec for one optimizer-state leaf.

    Dict leaves (parameter subtrees) reuse the outer ``param_spec``; any
    other leaf (e.g. ``optax.EmptyState``, per the ``is_leaf`` predicate
    used at the call site) is left unpartitioned (``None``).
    """
    return param_spec if isinstance(x, dict) else None
# Map partition specs over the (opt_state, params) shape trees: dict subtrees
# become leaves (sharded like the params), optax.EmptyState stays a leaf and
# gets no partitioning (None).
# NOTE(review): this assumes `state_shapes` is a 2-tuple of
# (opt_state shapes, param shapes) so the mapped result unpacks into exactly
# two values — confirm against how `state_shapes` is built upstream.
opt_state_spec, param_spec = jax.tree_util.tree_map(
    get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
)
This makes it look like only dict values get sharded and everything else is left unpartitioned. Is that correct?
Then I can get the optimizer state by doing this. Why do I need to freeze the params when passing into p_get_initial_state?
def get_initial_state(params):
    """Build the initial optimizer state for ``params``.

    Returns ``(opt_state, params)`` where ``opt_state`` is the optax state
    converted to a plain tuple (matching ``opt_state_spec``'s structure).
    """
    initial = optimizer.init(params)
    return tuple(initial), params
# pjit the state-initialization step so the optimizer state is created
# directly with the desired sharding instead of being materialized on one
# device and then re-sharded.
# NOTE(review): `in_axis_resources`/`out_axis_resources` is the legacy pjit
# API; newer JAX versions use `in_shardings`/`out_shardings` — confirm the
# JAX version in use.
p_get_initial_state = pjit(
    get_initial_state,
    in_axis_resources=None,
    out_axis_resources=(opt_state_spec, param_spec),
)
# Run inside the mesh context so pjit can place shards on the ("dp", "mp")
# device mesh.
with mesh(mesh_devices, ("dp", "mp")):
    opt_state, params = p_get_initial_state(freeze(params))
For the param_spec partitioning, I have these:
import re
from flax.core.frozen_dict import freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.experimental import PartitionSpec as P
# utils copied from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
# Sentinels
# Unique marker for flattened parameter keys that no partition rule matched;
# checked in set_partitions_t5 to detect an incomplete spec.
_unmatched = object()
# For specifying empty leaf dict `{}`
empty_dict = object()
def _match(qs, ks):
"""Return True if regexes in qs match any window of strings in tuple ks."""
# compile regexes and force complete match
qts = tuple(map(lambda x: re.compile(x + "$"), qs))
for i in range(len(ks) - len(qs) + 1):
matches = [x.match(y) for x, y in zip(qts, ks[i:])]
if matches and all(matches):
return True
return False
def _replacement_rules(rules):
def replace(key, val):
for rule, replacement in rules:
if _match(rule, key):
return replacement
return val
return replace
def _get_partition_rules_t5():
return [
(("SelfAttention", "relative_attention_bias", "embedding"), None),
(("shared", "embedding"), P("mp", None)),
((r"SelfAttention", "(q|k|v)", "kernel"), P(None, "mp")),
((r"SelfAttention", "o", "kernel"), P("mp", None)),
((r"EncDecAttention", "(q|k|v)", "kernel"), P(None, "mp")),
((r"EncDecAttention", "o", "kernel"), P("mp", None)),
((r"DenseReluDense", "wi_0", "kernel"), P(None, "mp")),
((r"DenseReluDense", "wi_1", "kernel"), P(None, "mp")),
((r"DenseReluDense", "wi", "kernel"), P(None, "mp")),
((r"DenseReluDense", "wo", "kernel"), P("mp", None)),
((r"layer_norm", "weight"), None),
((r"final_layer_norm", "weight"), None),
(("lm_head", "kernel"), P(None, "mp")),
]
def set_partitions_t5(in_dict):
    """Build a PartitionSpec tree matching the structure of ``in_dict``.

    Every flattened parameter key is matched against the T5 partition
    rules; keys with no matching rule trip the assertion so an incomplete
    spec fails loudly instead of silently replicating a parameter.

    Returns a frozen dict with the same nesting as ``in_dict`` whose leaves
    are PartitionSpecs (or None for replicated parameters).
    """
    rules = _get_partition_rules_t5()
    replace = _replacement_rules(rules)
    # Seed every flattened key with the _unmatched sentinel so unmatched
    # parameters are detectable after the rules run.
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    # Fix: removed a leftover debug `print(_unmatched)` that spammed stdout
    # on every call.
    assert _unmatched not in result.values(), "Incomplete partition spec."
    return freeze(unflatten_dict(result))
I used these as references. The Hugging Face example might need to be updated.
Model parallel language model training example
Dalle-mini training script
@patrickvonplaten , @valhalla , @sanchit-gandhi, your assistance would be much appreciated.