Fine-Tune BERT with two Classification Heads "next to each other"?

I am currently working on a project to fine-tune BERT models on a multi-class classification task, with the goal of classifying job ads into broader categories like “doctors” or “sales” via AutoModelForSequenceClassification (which works quite well :slight_smile:). Now I am wondering whether it would be possible to add a second classification head “next” to the first one (not in sequence) to classify the minimum educational level that is required for the job. I imagine that each head is directly connected to the pooler output and then makes a prediction independent of the other’s prediction. I think my use case is slightly different from multi-label classification, since the two labels describe different aspects of the job ad.

Similar to this CV example: What is a multi-headed model? And what exactly is a ‘head’ in a model?

I hope it’s not total nonsense that I’m asking here :smiley:


Sure, you can just use any default model, e.g. BertModel, and add your custom classification heads on top of that. Have a look at the existing classification implementation. You can basically duplicate that, but add another classifier layer. Of course you’ll also have to adapt the forward method accordingly.

@BramVanroy thanks for the swift reply! So I tried to get my head around the first part, which is setting up the custom model, but I am not sure if I fully understand what I have to do (still working on my PyTorch and transformers skills…). These are my current adjustments to the class:

import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel

class CustomBertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)  # must run before any submodule is assigned
        self.num_labels = config.num_labels
        self.config = config
        model_name = 'bert-base-german-cased'
        # BertModel is the actual encoder; BertPreTrainedModel is only the abstract base class
        self.bert = BertModel.from_pretrained(model_name)

        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        # Two independent heads, both fed from the same pooled BERT output
        self.classifier_job_type = nn.Linear(config.hidden_size, config.num_labels)
        self.classifier_edu_level = nn.Linear(config.hidden_size, config.num_labels)


I added two linear layers, self.classifier_job_type and self.classifier_edu_level. The model looks like this now:

  (bert): BertPreTrainedModel()
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier_job_type): Linear(in_features=768, out_features=2, bias=True)
  (classifier_edu_level): Linear(in_features=768, out_features=2, bias=True)

So maybe you can tell me whether I am on the right track…

Something I don’t quite get is why the out_features are 2. Are they coming from the config? How can I adjust the number of labels for both classification layers?

Thanks in advance!

Yes, you are on the right track! I encourage you to look at the PyTorch documentation to better understand what you are doing. You’ll find that the second argument to nn.Linear is the output dimension, out_features. You have implemented this so that both layers take their size from config.num_labels, which is 2 by default. Instead, you can set the output size of each layer as you wish. E.g.

self.classifier_job_type = nn.Linear(config.hidden_size, 5)
self.classifier_edu_level = nn.Linear(config.hidden_size, 3)
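
To make the effect of that second argument concrete, here is a minimal sketch using only torch. The batch size of 4, the label counts 5 and 3, and the random tensor standing in for BERT's pooled output are assumptions for illustration, not values from the actual project.

```python
import torch
from torch import nn

hidden_size = 768  # hidden size of bert-base models
# Random stand-in for the pooled BERT output of a batch of 4 job ads (assumption)
pooled = torch.randn(4, hidden_size)

classifier_job_type = nn.Linear(hidden_size, 5)   # 5 job categories (illustrative)
classifier_edu_level = nn.Linear(hidden_size, 3)  # 3 education levels (illustrative)

# Each head reads the same input and produces its own logits
job_logits = classifier_job_type(pooled)
edu_logits = classifier_edu_level(pooled)

print(job_logits.shape)  # torch.Size([4, 5])
print(edu_logits.shape)  # torch.Size([4, 3])
```

Each head produces one logit per class, so the two heads can have different numbers of labels even though they share the same input.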

It may seem from that printed model that those classification layers are sequential in the model, but do not worry, that is not the case. The printout simply reflects the order in which you defined the layers. The “actual” data flow order is controlled by the forward method, which you still have to implement. It should go like this (in words, so that you can implement it yourself):

  • Push the encoded inputs through self.bert
  • Push the output of self.bert through a dropout and then through the first classification layer
  • Push the output of self.bert through a dropout and then through the second classification layer
  • Calculate the loss for the output of the first classification layer and for the output of the second one, then sum the two losses
  • During training you can then do a single backward pass on the summed loss
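
For reference, the steps above could be sketched roughly as follows. This is one possible implementation under stated assumptions, not the only way to do it: the class name BertWithTwoHeads, the arguments num_job_types, num_edu_levels, job_labels and edu_labels, and the unweighted sum of two cross-entropy losses are all illustrative choices. The tiny randomly initialised BertConfig is only there so the example runs without downloading weights.

```python
import torch
from torch import nn
from transformers import BertConfig, BertModel, BertPreTrainedModel


class BertWithTwoHeads(BertPreTrainedModel):
    """Hypothetical two-head model; all names here are illustrative."""

    def __init__(self, config, num_job_types=5, num_edu_levels=3):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier_job_type = nn.Linear(config.hidden_size, num_job_types)
        self.classifier_edu_level = nn.Linear(config.hidden_size, num_edu_levels)
        self.post_init()  # standard weight initialisation hook

    def forward(self, input_ids, attention_mask=None, job_labels=None, edu_labels=None):
        # 1. Push the encoded inputs through BERT
        pooled = self.bert(input_ids, attention_mask=attention_mask).pooler_output
        # 2. and 3. Dropout, then each head independently on the same pooled output
        job_logits = self.classifier_job_type(self.dropout(pooled))
        edu_logits = self.classifier_edu_level(self.dropout(pooled))
        # 4. One loss per head, summed (equal weighting is an assumption)
        loss = None
        if job_labels is not None and edu_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(job_logits, job_labels) + loss_fct(edu_logits, edu_labels)
        return loss, job_logits, edu_logits


# Tiny random config so the sketch runs quickly and offline (assumption)
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64)
model = BertWithTwoHeads(config)

input_ids = torch.randint(0, 100, (4, 16))  # fake batch: 4 sequences of 16 token ids
job_labels = torch.randint(0, 5, (4,))
edu_labels = torch.randint(0, 3, (4,))

loss, job_logits, edu_logits = model(input_ids, job_labels=job_labels, edu_labels=edu_labels)
loss.backward()  # 5. a single backward pass updates BERT and both heads
```

Because both heads read the same pooled output, the summed loss sends gradients from both tasks into the shared encoder. If one task dominates training, weighting the two losses differently is a common adjustment.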