I have created my own BertClassifier model, starting from a pretrained and then added my own classification heads composed by different layers. After the training I want to save the model using model.save_pretrained() but when I print it upload it from pretrained i don’t see my classifier head.
The code is the following. How can I save the all structure on my model and make it full accessible with AutoModel.from_preatrained('folder_path')
?
. Thanks!
class BertClassifier(PreTrainedModel):
"""Bert Model for Classification Tasks."""
config_class = AutoConfig
def __init__(self,config, freeze_bert=True): #tuning only the head
"""
@param bert: a BertModel object
@param classifier: a torch.nn.Module classifier
@param freeze_bert (bool): Set `False` to fine-tune the BERT model
"""
#super(BertClassifier, self).__init__()
super().__init__(config)
# Instantiate BERT model
# Specify hidden size of BERT, hidden size of our classifier, and number of labels
self.D_in = 1024 #hidden size of Bert
self.H = 512
self.D_out = 2
# Instantiate the classifier head with some one-layer feed-forward classifier
self.classifier = nn.Sequential(
nn.Linear(self.D_in, 512),
nn.Tanh(),
nn.Linear(512, self.D_out),
nn.Tanh()
)
def forward(self, input_ids, attention_mask):
# Feed input to BERT
outputs = self.bert(input_ids=input_ids,
attention_mask=attention_mask)
# Extract the last hidden state of the token `[CLS]` for classification task
last_hidden_state_cls = outputs[0][:, 0, :]
# Feed input to classifier to compute logits
logits = self.classifier(last_hidden_state_cls)
return logits
configuration=AutoConfig.from_pretrained('Rostlab/prot_bert_bfd')
model = BertClassifier(config=configuration,freeze_bert=False)
after training
model.save_pretrained('path')
If I print the model after model = AutoModel.from_pretrained(‘path’) I have as the last layer the following and missing my 2 linear layer:
(output): BertOutput(
(dense): Linear(in_features=4096, out_features=1024, bias=True)
(LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(adapters): ModuleDict()
(adapter_fusion_layer): ModuleDict()
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=1024, out_features=1024, bias=True)
(activation): Tanh()
)
(prefix_tuning): PrefixTuningPool(
(prefix_tunings): ModuleDict()
)
)