How to get [CLS] embeddings from BertForTokenClassification model

Sorry if this is a basic question. I don’t really write code myself; I only use the example scripts as a tool.

I trained a model on my own NER dataset using the transformers example code.

I want to get a sentence embedding from the model I trained with the token classification example code (the older version of the example code, by the way). I think the output embedding of the [CLS] token should be one way to get it.

This GitHub issue answer explains exactly how to get embeddings from a BertModel (and I can take the [CLS] token as the first token of the sentence).
The answer's code is copied below:

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
```

So here is my problem: how do I get embeddings from BertForTokenClassification instead of BertModel? Can I simply replace BertModel with BertForTokenClassification in the code and get the output I want?

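I suggest looking at the hidden_states output of the model, which is returned when output_hidden_states=True is passed either to the forward call or to from_pretrained when loading the model. Simply swapping the class is not enough, because the first output of BertForTokenClassification is the per-token classification logits, not the hidden states. The [CLS] embedding is the slice at token position 0 of the last layer's output tensor.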
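As a minimal sketch of what this looks like, the snippet below builds a tiny, randomly initialized BertForTokenClassification from a BertConfig. The small sizes and num_labels=5 are arbitrary assumptions so it runs without downloading weights; in practice you would load your fine-tuned checkpoint with from_pretrained. It shows that outputs.logits holds per-token label scores, while hidden_states holds one tensor for the embedding layer plus one per Transformer layer.

```python
import torch
from transformers import BertConfig, BertForTokenClassification

# Tiny, randomly initialized config (assumed sizes) standing in for a trained NER model
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=5,
)
model = BertForTokenClassification(config)
model.eval()  # disable dropout so outputs are deterministic

# Fake batch: 2 sequences of 7 token ids; position 0 plays the role of [CLS]
input_ids = torch.randint(0, config.vocab_size, (2, 7))

with torch.no_grad():
    outputs = model(input_ids=input_ids, output_hidden_states=True)

# The first output is the per-token classification logits, not embeddings
logits = outputs.logits                 # shape: [2, 7, num_labels]
# hidden_states = embedding output + one tensor per Transformer layer
hidden_states = outputs.hidden_states   # tuple of length num_hidden_layers + 1
last_hidden_states = hidden_states[-1]  # shape: [2, 7, hidden_size]
cls_embeddings = last_hidden_states[:, 0, :]  # shape: [2, hidden_size]

print(logits.shape, len(hidden_states), cls_embeddings.shape)
```

Indexing with [-1] picks the last layer regardless of how many layers the model has, which is why it is safer than hard-coding a layer index.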

Hi @slecraphi

Just to elaborate on @ehalit’s correct approach, here is your example adapted for token classification:

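```python
from transformers import BertForTokenClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Point this at your fine-tuned checkpoint; 'bert-base-uncased' gets a freshly initialized classification head
model = BertForTokenClassification.from_pretrained('bert-base-uncased')
inputs = tokenizer("Hello, my dog is cute", return_tensors='pt')
outputs = model(**inputs, output_hidden_states=True)
last_hidden_states = outputs.hidden_states[-1]
```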

The shape of last_hidden_states is [batch_size, tokens, hidden_dim], so to get the [CLS] embedding of the first element in the batch, use last_hidden_states[0,0,:].
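Under the same assumption (a tiny randomly initialized model in place of your trained checkpoint), the sketch below shows two things: [:, 0, :] gives the [CLS] vector for every sentence in the batch rather than only the first, and the encoder that BertForTokenClassification wraps is available as model.bert, so calling it directly returns the same last hidden state without going through the classification head.

```python
import torch
from transformers import BertConfig, BertForTokenClassification

# Tiny random config (assumed sizes) standing in for a fine-tuned checkpoint
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64, num_labels=5)
model = BertForTokenClassification(config).eval()  # eval() disables dropout

input_ids = torch.randint(0, config.vocab_size, (3, 6))  # batch of 3 sequences

with torch.no_grad():
    outputs = model(input_ids=input_ids, output_hidden_states=True)
    last_hidden_states = outputs.hidden_states[-1]  # [batch_size, tokens, hidden_dim]
    first_cls = last_hidden_states[0, 0, :]          # [CLS] of the first sequence only
    all_cls = last_hidden_states[:, 0, :]            # [CLS] of every sequence

    # The underlying BertModel (no classification head) yields the same tensor
    encoder_out = model.bert(input_ids=input_ids).last_hidden_state

print(first_cls.shape, all_cls.shape)
print(torch.allclose(encoder_out, last_hidden_states))
```

Calling model.bert avoids computing the logits when you only need embeddings; the comparison only matches exactly in eval mode, since dropout would otherwise make the two forward passes differ.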

Hope this helps!
