Once I load a model (e.g. bert-base-uncased) with AutoModel, how can I access the internal weight and bias values? What does internal model parameter access look like?
I can access the encoder parameters with model.encoder.named_parameters(), but how do I get one of the layers' weight values?
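For reference, a minimal sketch of what I mean by accessing the encoder parameters (the parameter name in the comment is just what named_parameters() prints for this model):

from transformers import AutoModel

model = AutoModel.from_pretrained('bert-base-uncased')

# I can iterate over the encoder parameters and see their names and shapes:
for name, param in model.encoder.named_parameters():
    print(name, tuple(param.shape))
# e.g. layer.0.attention.self.query.weight (768, 768)
# ... but how do I get at one specific layer's weight values from here?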
Hello,
To look at the structure of a class instance, you can inspect its __dict__ attribute.
from transformers import AutoModel

model = AutoModel.from_pretrained('bert-base-uncased')
print(model.__dict__)
'''
'_state_dict_pre_hooks': OrderedDict(),
'_load_state_dict_pre_hooks': OrderedDict(),
'_load_state_dict_post_hooks': OrderedDict(),
'_modules': OrderedDict([('embeddings',
  BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )),
  ('encoder', (...)
'''
From this you can see that the weights are stored in the _modules attribute. _modules is an OrderedDict, and model._modules.keys() == odict_keys(['embeddings', 'encoder', 'pooler']).
Repeating this exploration method over and over leads to this horrible piece of code, which you can use to get the weights of one layer:
layer0_attention_query_weight = (
model
._modules['encoder']
._modules['layer']
._modules['0']
._modules['attention']
._modules['self']
._modules['query']
._parameters['weight']
.detach()
.numpy())
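For completeness, the same tensor can be reached much more readably with plain attribute access on the submodules (a minimal equivalent sketch, same model as above):

# Equivalent, but readable: drill down via attributes instead of _modules
layer0_attention_query_weight = (
    model.encoder.layer[0].attention.self.query.weight
    .detach()
    .numpy())

# Or look it up by name in the named_parameters() dict
same_weight = dict(model.named_parameters())['encoder.layer.0.attention.self.query.weight']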
If that answers your question…
If someone is looking for a simpler solution, use
model.state_dict().keys()
to see the names of the layers.
odict_keys(['bert.embeddings.word_embeddings.weight', 'bert.embeddings.position_embeddings.weight', 'bert.embeddings.token_type_embeddings.weight', 'bert.embeddings.LayerNorm.weight', 'bert.embeddings.LayerNorm.bias', 'bert.encoder.layer.0.attention.self.query.weight', 'bert.encoder.layer.0.attention.self.query.bias', ...
Then use
model.state_dict()['layer_name'].detach().numpy()
where layer_name is one of the keys listed above, to get the weights. (Note: the bert. prefix in these keys appears when the checkpoint is loaded into a task-specific class such as BertForMaskedLM; with plain AutoModel the keys start directly with embeddings. / encoder. / pooler.)
ex:
model.state_dict()['bert.encoder.layer.1.attention.self.query.weight'].detach().numpy()
[[ 0.03132744 -0.02555237 0.00497124 ... -0.03382339 -0.05933356
0.06496567]
[-0.01340016 -0.02452595 -0.02806653 ... -0.03486649 0.05245861
0.04957511]
[-0.07761582 -0.0500178 -0.00823568 ... 0.01840313 0.01155226
-0.04699306]
...
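Putting the state_dict() approach together, a minimal end-to-end sketch (the layer chosen is just an example; adjust the key to whatever model.state_dict().keys() reports for your model class):

from transformers import AutoModel

model = AutoModel.from_pretrained('bert-base-uncased')

# List all parameter names so you can pick the one you want
for name in model.state_dict().keys():
    print(name)

# Grab a specific layer's weight and bias as numpy arrays
# (with plain AutoModel the keys have no 'bert.' prefix)
key = 'encoder.layer.1.attention.self.query'
weight = model.state_dict()[key + '.weight'].detach().numpy()
bias = model.state_dict()[key + '.bias'].detach().numpy()
print(weight.shape, bias.shape)  # (768, 768) (768,)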