Multi-class text classification tutorial: how does he get away with one out_feature on the linear layer?

I’ve read this tutorial: https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb

As I see it, the dataset uses 4 labels. I thought that would imply having 4 out_features on the last linear layer. But when I check his model, he seems to use just 1.

Please help me fix my misunderstanding 🙂


Hi @dickdanieljr, are you referring to this class?

import torch
from transformers import DistilBertModel

class DistillBERTClass(torch.nn.Module):
    def __init__(self):
        super(DistillBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        # 4 out_features: one logit per class
        self.classifier = torch.nn.Linear(768, 4)

    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]        # last hidden states, shape (batch, seq_len, 768)
        pooler = hidden_state[:, 0]       # embedding of the [CLS] token
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)  # raw logits, shape (batch, 4)
        return output

Here you can see that the last layer is a torch.nn.Linear of shape [hidden_dim, num_labels], i.e. [768, 4] - or am I missing something?
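Those 4 logits are exactly what a multi-class loss expects. If I remember correctly the tutorial pairs this head with torch.nn.CrossEntropyLoss, which takes one raw logit per class and an integer class id per example. A minimal sketch with made-up tensors (not from the tutorial, shapes are illustrative):

import torch

logits = torch.randn(8, 4)             # model output for a batch of 8, 4 classes
targets = torch.randint(0, 4, (8,))    # integer class ids in [0, 4)

loss_fn = torch.nn.CrossEntropyLoss()  # softmax + NLL over the 4 logits
loss = loss_fn(logits, targets)

preds = logits.argmax(dim=1)           # predicted class per example

So a single out_feature would only make sense for regression or binary classification with BCEWithLogitsLoss, not for this 4-label setup.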

Yeah, that looks OK. After such a long time I no longer remember what issue I had, and I described it very poorly. Thank you for the message though!