I have the following Flax module definition:
import flax
import jax
import jax.numpy as jnp
from flax import linen as nn
from transformers import AutoTokenizer, FlaxBertModel
from transformers.utils import logging
class Classifier(nn.Module):
    def setup(self):
        self.bert = FlaxBertModel.from_pretrained('bert-base-cased')
        self.fc = nn.Dense(features=2)

    def __call__(self, input_ids, attention_mask):
        out = self.bert(input_ids, attention_mask)
        out = out.pooler_output
        out = self.fc(out)
        out = jax.nn.log_softmax(out, axis=-1)
        return out
With an equivalent PyTorch definition, the BERT parameters would be registered on the Classifier and trainable by default (unless explicitly frozen in __init__; see the sketch after the params dump below). However, in Flax, even after declaring bert in setup, I only get the following params from the module pytree:
FrozenDict({
    params: {
        fc: {
            kernel: DeviceArray([[ 0.00646114,  0.00364647],
                                 [ 0.07954237, -0.06435367],
                                 [-0.01362615,  0.02644117],
                                 ...,
                                 [ 0.00158925,  0.02556718],
                                 [-0.07334478,  0.03936611],
                                 [-0.0254313 , -0.01591106]], dtype=float32),
            bias: DeviceArray([0., 0.], dtype=float32),
        },
    },
})
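For comparison, this is roughly the PyTorch definition I mean. It's only a sketch: TorchClassifier and the freeze_bert flag are illustrative names, not code I actually run.

import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel

class TorchClassifier(nn.Module):
    def __init__(self, freeze_bert: bool = False):
        super().__init__()
        # Assigning a submodule as an attribute registers all of its parameters
        # on the parent, so BERT's weights show up in self.parameters() and are
        # trainable by default.
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.fc = nn.Linear(self.bert.config.hidden_size, 2)
        if freeze_bert:
            # Freezing is an explicit opt-out in PyTorch.
            for p in self.bert.parameters():
                p.requires_grad = False

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.fc(out.pooler_output)
        return F.log_softmax(logits, dim=-1)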
Am I missing something here, or is there a Flax-specific way to add the bert parameters to the pytree? For reference, I'm initializing the params as follows.
clf = Classifier()
key = jax.random.PRNGKey(0)

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
dummy_input = tokenizer.encode_plus(
    "This is some dummy text",
    add_special_tokens=True,
    max_length=512,
    truncation=True,
    return_token_type_ids=False,
    padding="max_length",
    return_attention_mask=True,
    return_tensors="np"
)

params = clf.init(key, dummy_input["input_ids"], dummy_input["attention_mask"])
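To confirm what the dump above shows, this is how I'm inspecting the pytree (a sketch; I just map jnp.shape over the leaves so the structure is readable without printing full arrays):

# Map jnp.shape over every leaf of the initialized variables.
print(jax.tree_util.tree_map(jnp.shape, params))
# Prints something like:
# FrozenDict({'params': {'fc': {'kernel': (768, 2), 'bias': (2,)}}})
# i.e. only the Dense layer; nothing from bert appears.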