In Transformers 3.0.2 and earlier, a small caveat accompanied the description of BertModel.forward():
Returns
pooler_output (torch.FloatTensor: of shape (batch_size, hidden_size)):
Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pre-training.
This output is usually not a good summary of the semantic content of the input, you’re often better with averaging or pooling the sequence of hidden-states for the whole input sequence.
However, the caveat was removed in 3.1.0 (current master). The description of BertModel.forward() in 3.1.0 now just says:
Returns
pooler_output (torch.FloatTensor of shape (batch_size, hidden_size)) - Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pretraining.
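For reference, as far as I can tell, that pooler_output boils down to roughly the following (a minimal sketch of the pooler on top of the encoder, not the library's exact code):

```python
import torch
from torch import nn

class PoolerSketch(nn.Module):
    """Rough sketch of what produces pooler_output (not the library's exact implementation)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, last_hidden_state: torch.Tensor) -> torch.Tensor:
        # Take the last-layer hidden state of the first token ([CLS]) ...
        cls_hidden = last_hidden_state[:, 0]
        # ... and run it through a Linear layer and a Tanh, as the docs describe.
        return self.activation(self.dense(cls_hidden))
```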
Is there any deeply meaningful reason why that line was removed? Is there a new important secret to that [CLS] token that you’re not telling us?
That caveat was removed because this output is actually what is used for classification in the Sequence Classification model, and experiments showed it gave the same kind of results as the mean/average.
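Concretely, the sequence classification model just puts a classifier on top of that output, along these lines (a simplified sketch rather than the exact library code; the checkpoint name is only an example):

```python
import torch
from torch import nn
from transformers import BertModel

class BertClassifierSketch(nn.Module):
    """Simplified sketch: a classification head on top of pooler_output."""

    def __init__(self, num_labels: int = 2):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, return_dict=True)
        # pooler_output = Tanh(Linear(hidden state of [CLS])), fine-tuned end to end here.
        pooled = self.dropout(outputs.pooler_output)
        return self.classifier(pooled)  # logits of shape (batch_size, num_labels)
```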
It’s probably because of the original paper. In the BERT paper, they discuss how NSP is beneficial for QA and NLI but that “[t]he vector C is not a meaningful sentence representation without fine-tuning, since it was trained with NSP”, where C is “the final hidden vector of the special [CLS] token” (p. 4).
In other words, the [CLS] token without finetuning might not be a semantic representation of a given sentence, because it was pretrained only on the NSP task. To use it in your own downstream task, you should finetune it. If you only use BERT as a feature extractor (e.g. for sentence embeddings), then you are probably better off using mean/avg pooling. See this Twitter thread and the paper referenced in there: https://twitter.com/JohnMGiorgi/status/1295472684353105920
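For example, a simple way to get such sentence embeddings is masked mean pooling over the last hidden states (a sketch; the checkpoint and sentences are just placeholders):

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

sentences = ["The cat sat on the mat.", "A feline rested on the rug."]
encoded = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**encoded, return_dict=True)

# Average the token hidden states, ignoring padding via the attention mask.
last_hidden = outputs.last_hidden_state                 # (batch, seq_len, hidden)
mask = encoded["attention_mask"].unsqueeze(-1).float()  # (batch, seq_len, 1)
sentence_embeddings = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1)

print(sentence_embeddings.shape)  # torch.Size([2, 768])
```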
Well, intuitively you would think that taking the average over the last hidden states of all tokens captures the essence of the sentence a bit better than just taking the [CLS] token. The meanings of the tokens after pretraining are usable (just take the average and you have a relatively decent semantic representation), but the [CLS] token is specific to the NSP task.