Create custom head? I think that's what I need to use custom features

Loving Transformers, and so far having some success.
I want to do something along the lines of sequence classification, and when I use AutoModelForSequenceClassification with bert-base-uncased, tokenize, and then train with Trainer, all is great. :smiley:
However, I want to see whether I can get better results by creating some custom features to augment the 768-dimensional output from BERT and training on that.
So my thinking is that I can use AutoModel with the pretrained weights and take it from there, presumably by creating my own head?!?

Comparing the AutoModel model with the AutoModelForSequenceClassification model, I see the extra part is a dropout followed by a linear layer from 768 to 2. Makes sense.
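For BERT specifically, BertForSequenceClassification applies that dropout + linear head to the pooled output (`pooler_output`, the final hidden state of the [CLS] token passed through a dense layer and tanh), not to `last_hidden_state`. The distinction matters because of the shapes involved. A small sketch, using random tensors as stand-ins for real BERT outputs (the batch size and sequence length are arbitrary assumptions):

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration: 4 sequences of 16 tokens, hidden size 768 as in bert-base-uncased.
batch_size, seq_len, hidden_size, num_labels = 4, 16, 768, 2

# The head AutoModelForSequenceClassification adds on top of BERT:
# dropout followed by a linear layer from hidden_size to num_labels.
head = nn.Sequential(nn.Dropout(0.1), nn.Linear(hidden_size, num_labels))

# Random stand-ins for the two main outputs of AutoModel (BertModel).
last_hidden_state = torch.randn(batch_size, seq_len, hidden_size)  # one vector per token
pooler_output = torch.randn(batch_size, hidden_size)               # one vector per sequence

# On the pooled output, the head gives one row of logits per sequence...
print(head(pooler_output).shape)      # torch.Size([4, 2])
# ...but nn.Linear acts on the last dimension, so on last_hidden_state it gives logits per token.
print(head(last_hidden_state).shape)  # torch.Size([4, 16, 2])
```

Applying the head to `last_hidden_state` therefore produces per-token logits rather than one prediction per sequence.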

So as a first step I thought I would try to reproduce that, and that is when I realised I didn’t understand as much as I thought I did :frowning:

I thought I could do something like this (inspired by other topics on this forum):

import torch.nn as nn
from transformers import AutoModel

checkpoint = "bert-base-uncased"  # the checkpoint I was already using
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.base_model = AutoModel.from_pretrained(checkpoint)
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(768, 2)
        
    def forward(self, input_ids, token_type_ids, attention_mask, labels):
        outputs = self.base_model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

        outputs = self.dropout(outputs['last_hidden_state'])
        outputs = self.linear(outputs)
        
        return outputs

model = MyModel()
model.to('cuda')

But when I train it, I get this error:

RuntimeError: grad can be implicitly created only for scalar outputs

And I realise I really don’t know where to look next.
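One place to look: as far as I can tell from the Trainer source, when the model returns a plain tensor instead of a dict (or ModelOutput) containing a "loss" key, Trainer takes `outputs[0]` as the loss. With the code above that is the token-level logits for the first example, shape (seq_len, 2), so calling backward() on it fails. A minimal sketch of the error and of one possible fix, assuming the head is moved onto the pooled output and computes a cross-entropy loss itself (`ClassificationHead` is a hypothetical name, not a transformers class):

```python
import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    """Hypothetical head: dropout + linear on the pooled output, plus a scalar loss."""

    def __init__(self, hidden_size=768, num_labels=2, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, pooled_output, labels=None):
        logits = self.linear(self.dropout(pooled_output))  # (batch, num_labels)
        output = {"logits": logits}
        if labels is not None:
            # CrossEntropyLoss averages over the batch by default, so the loss is a 0-dim tensor
            output["loss"] = self.loss_fn(logits, labels)
        return output


# Reproducing the error: backward() without an explicit gradient needs a scalar.
token_logits = torch.randn(16, 2, requires_grad=True)  # like outputs[0] in the original code
caught = None
try:
    token_logits.backward()
except RuntimeError as err:
    caught = err
    print(err)  # grad can be implicitly created only for scalar outputs

# With a scalar loss, backward() works.
head = ClassificationHead()
pooled = torch.randn(4, 768)          # stand-in for outputs["pooler_output"]
labels = torch.tensor([0, 1, 1, 0])   # one class index per sequence
out = head(pooled, labels)
out["loss"].backward()
print(out["loss"].dim(), out["logits"].shape)  # 0 torch.Size([4, 2])
```

Inside MyModel, `self.head = ClassificationHead()` could replace the dropout and linear layers in `__init__`, and forward() could end with something like `return self.head(outputs['pooler_output'], labels)`. Trainer accepts a dict output and would then find the scalar "loss" entry.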

Any guidance would be greatly appreciated, either on this approach or, more generally, on how best to augment the model with additional custom features.
For example, even if this works, how would I ‘inject’ my custom features into this model?
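One common way to "inject" custom features, assuming they can be computed as a fixed-length numeric vector per example, is to concatenate them with BERT's pooled output before the final linear layer. A sketch of such a head; `FeatureAugmentedHead`, `extra_features` and `num_extra_features` are hypothetical names, and the tensors are random stand-ins:

```python
import torch
import torch.nn as nn


class FeatureAugmentedHead(nn.Module):
    """Hypothetical head: concatenates BERT's pooled output with extra hand-made features."""

    def __init__(self, hidden_size=768, num_extra_features=5, num_labels=2, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # The linear layer now sees hidden_size + num_extra_features inputs
        self.linear = nn.Linear(hidden_size + num_extra_features, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, pooled_output, extra_features, labels=None):
        # (batch, hidden_size) and (batch, num_extra_features) -> (batch, hidden_size + num_extra_features)
        combined = torch.cat([self.dropout(pooled_output), extra_features], dim=-1)
        logits = self.linear(combined)  # (batch, num_labels)
        output = {"logits": logits}
        if labels is not None:
            output["loss"] = self.loss_fn(logits, labels)  # scalar, as Trainer expects
        return output


head = FeatureAugmentedHead(num_extra_features=5)
pooled = torch.randn(4, 768)         # stand-in for outputs["pooler_output"]
extra = torch.randn(4, 5)            # stand-in for 5 custom float features per example
labels = torch.tensor([0, 1, 1, 0])
out = head(pooled, extra, labels)
print(out["logits"].shape)  # torch.Size([4, 2])
```

The full model's forward() would then take an `extra_features` argument and pass `outputs['pooler_output']` and `extra_features` to this head. With Trainer's default `remove_unused_columns=True`, only dataset columns whose names match forward()'s parameters are kept, so a column named `extra_features` should reach the model; it is still worth checking that the data collator turns it into a float tensor. Scaling the features to a range comparable with BERT's outputs may help, and applying dropout only to the BERT part is a design choice rather than a requirement.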

Thanks :slight_smile:


@mkeywood Have you found a solution? If you have, please share it; I need a solution for this exact problem.

Thanks.