Adding a linear layer to a transformer model (+ save_pretrained and from_pretrained)

I want to extend a transformer model (e.g. BERT or ELECTRA) with a linear layer and initialize that layer with the same initializer the transformer model uses. I also want save_pretrained and from_pretrained to work smoothly, i.e. save and load the model WITH the linear layer, not separately. This is what I'm currently doing, but it's not working:

```python
import torch
from transformers import AutoModel, PreTrainedModel

class ExtendedTransformer(PreTrainedModel):
    base_model_prefix = "model"
    def __init__(self, config):
        self.model = AutoModel.from_pretrained('bert-base-uncased')
        self.linear = torch.nn.Linear(self.config.hidden_size, 128, bias=False)
```

Any advice on how to do this properly?


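If all you need is a 128-dimensional output on top of BERT, you can simply use:

```python
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=128)
```

Note that this head sits on the pooled [CLS] output and has a bias, unlike the bias-free layer in your code.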
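To see what that head looks like without downloading anything, here is a minimal sketch that builds BertForSequenceClassification from a tiny, made-up BertConfig (the small sizes are arbitrary assumptions, not bert-base-uncased's) and checks the output shape:

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

torch.manual_seed(0)
# Tiny, arbitrary config (assumption) so the example runs offline
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64, num_labels=128)
model = BertForSequenceClassification(config).eval()

# The classification head is a plain Linear(hidden_size, num_labels) with bias
print(model.classifier)  # Linear(in_features=32, out_features=128, bias=True)

input_ids = torch.randint(0, config.vocab_size, (2, 8))  # batch of 2, 8 tokens each
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape)  # torch.Size([2, 128])
```

With the real checkpoint the structure is the same; only the sizes and the pretrained encoder weights differ.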

But here’s an explanation of what I think the issue is with your code.

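If I'm not mistaken, your `__init__` only creates the BERT model (which `from_pretrained` has already initialized) and never calls `super().__init__(config)`. Without that call, `self.config` does not exist and PyTorch refuses to register submodules (`AttributeError: cannot assign module before Module.__init__() call`). As far as I know, you should subclass a specific architecture's base class such as BertPreTrainedModel, because the weight-initialization scheme (`_init_weights`) is implemented per architecture rather than in PreTrainedModel itself, and models may use different schemes. Since the end of last year, it's probably best to call `post_init()` instead of `init_weights()`, too. Maybe you can achieve this for a variety of models with AutoModel as well, but I am not sure.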
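To illustrate why the architecture matters for initialization, this sketch compares PyTorch's default `nn.Linear` init with BERT's scheme. It calls `BertModel._init_weights` directly, which is an internal hook (the one `post_init()` applies to submodules), so treat it as illustration only; the tiny config is an arbitrary assumption to avoid a download:

```python
import torch
from transformers import BertConfig, BertModel

torch.manual_seed(0)
# Tiny, arbitrary config (assumption); initializer_range is BERT's default 0.02
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64,
                    initializer_range=0.02)
bert = BertModel(config)

layer = torch.nn.Linear(config.hidden_size, 128, bias=False)
# PyTorch default: kaiming-uniform, std = 1/sqrt(3 * fan_in), roughly 0.10 here
default_std = layer.weight.std().item()

# BERT's scheme (internal hook): weights ~ normal(0, config.initializer_range)
bert._init_weights(layer)
bert_std = layer.weight.std().item()

print(f"default init std: {default_std:.3f}, BERT init std: {bert_std:.3f}")
```

A layer created inside a BertPreTrainedModel subclass gets the second behavior automatically once `post_init()` runs at the end of `__init__`.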

So, the following should work. It is basically the same as BertForSequenceClassification.

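```python
import torch
from transformers import BertModel, BertPreTrainedModel

class ExtendedTransformer(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)  # sets self.config and nn.Module state
        # Build from config (no download); from_pretrained fills in the weights.
        # The attribute name must be `bert` to match base_model_prefix.
        self.bert = BertModel(config)
        self.linear = torch.nn.Linear(config.hidden_size, 128, bias=False)
        self.post_init()  # applies BERT's weight init to self.linear

    def forward(self, input_ids, attention_mask=None):
        hidden = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        return self.linear(hidden)

if __name__ == '__main__':
    # Loads BERT's weights into self.bert; self.linear is newly initialized
    inst = ExtendedTransformer.from_pretrained('bert-base-uncased')
```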
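To check that save_pretrained and from_pretrained keep the extra layer, here is a sketch of a round trip through a temporary directory. It uses a tiny, made-up BertConfig (an assumption, so nothing is downloaded) and the same class structure as above:

```python
import tempfile

import torch
from transformers import BertConfig, BertModel, BertPreTrainedModel


class ExtendedTransformer(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.linear = torch.nn.Linear(config.hidden_size, 128, bias=False)
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        hidden = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        return self.linear(hidden)


torch.manual_seed(0)
# Tiny, arbitrary config (assumption) so the example runs offline
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64)
model = ExtendedTransformer(config).eval()
input_ids = torch.randint(0, config.vocab_size, (2, 8))

with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)  # config + all weights, linear layer included
    reloaded = ExtendedTransformer.from_pretrained(tmp).eval()
    with torch.no_grad():
        out = model(input_ids)
        out_reloaded = reloaded(input_ids)
    same_linear = torch.equal(model.linear.weight, reloaded.linear.weight)

print(out.shape, same_linear, torch.allclose(out, out_reloaded))
```

With the real checkpoint, `ExtendedTransformer.from_pretrained('bert-base-uncased')` should warn that `linear.weight` was newly initialized; that is expected, since the original checkpoint has no such layer.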