In the upcoming book “Natural Language Processing with Transformers”, the authors show how to classify sentences by using transformers as feature extractors: we pass the sentences through a transformer to get the hidden states.
To train a classifier, we take the embedding of just the first token, namely “[CLS]”, and ignore the rest of the sentence. The book says that it’s common practice to do that.
It doesn’t make much sense to me to ignore the rest of the embeddings. Shouldn’t they be averaged or something?
The only reasoning I can think of is that the attention layers of the encoder make the CLS token absorb the meaningful context?
Thank you! The book is awesome by the way, highly recommended!
Hi @carlosaguayo, thanks for your question, and I’m glad you’re enjoying the book!
In general, we need a way to represent the sequence of embeddings as a single vector, and there are several “pooling” techniques that people use in the literature:
- [CLS] pooling: just take the embedding of the [CLS] token as the representation for the whole sequence
- mean pooling: take the average of token embeddings
- max pooling: take the element-wise maximum over the token embeddings, i.e. the largest value in each dimension
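The three pooling strategies above can be sketched in plain Python. This is a minimal illustration, not the book's code: it assumes `hidden` is the list of per-token embedding vectors for one sequence (with [CLS] at position 0) and that `mask` is an attention mask with 1 for real tokens and 0 for padding. In practice you'd do the same operations on tensors from your framework of choice.

```python
from typing import List

Vector = List[float]


def cls_pool(hidden: List[Vector]) -> Vector:
    """[CLS] pooling: return the embedding of the first token."""
    return list(hidden[0])


def mean_pool(hidden: List[Vector], mask: List[int]) -> Vector:
    """Mean pooling: average the embeddings of non-padding tokens."""
    dim = len(hidden[0])
    totals = [0.0] * dim
    count = 0
    for vec, keep in zip(hidden, mask):
        if keep:
            for i, x in enumerate(vec):
                totals[i] += x
            count += 1
    return [t / count for t in totals]


def max_pool(hidden: List[Vector], mask: List[int]) -> Vector:
    """Max pooling: element-wise maximum over non-padding tokens."""
    kept = [vec for vec, keep in zip(hidden, mask) if keep]
    # zip(*kept) groups the values of each dimension across tokens.
    return [max(column) for column in zip(*kept)]


# Toy "last hidden state": 4 tokens ([CLS], two words, one padding), 3 dims each.
hidden = [
    [0.1, 0.2, 0.3],   # [CLS]
    [1.0, -1.0, 0.5],  # token 1
    [0.5, 0.4, -0.2],  # token 2
    [9.0, 9.0, 9.0],   # padding (ignored by mean/max pooling via the mask)
]
mask = [1, 1, 1, 0]

print(cls_pool(hidden))        # [0.1, 0.2, 0.3]
print(mean_pool(hidden, mask))
print(max_pool(hidden, mask))  # [1.0, 0.4, 0.5]
```

Masking out padding matters for mean and max pooling, since padded positions would otherwise skew the result. With right-padding (the usual setup for BERT-style models), [CLS] pooling sidesteps this because position 0 is never a padding token.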
A related question is whether pooling should be applied to the last hidden states, or some earlier layers (or a concatenation thereof).
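If you want to experiment with earlier layers, one simple option is to take the [CLS] vector from each of the last few layers and concatenate them. The sketch below assumes `all_layers` is a list with one entry per layer, each entry being a list of per-token vectors with [CLS] first (roughly the shape of what Hugging Face models return with `output_hidden_states=True`, minus the batch dimension). The helper `concat_cls_from_last_layers` and its `k` parameter are hypothetical names for illustration; `k` is something you'd tune.

```python
from typing import List

Vector = List[float]
Layer = List[Vector]  # one vector per token, [CLS] first


def concat_cls_from_last_layers(all_layers: List[Layer], k: int) -> Vector:
    """Concatenate the [CLS] embedding from each of the last k layers.

    The result has k * hidden_dim entries, ordered from the
    k-th-from-last layer up to the final layer.
    """
    if not 1 <= k <= len(all_layers):
        raise ValueError(f"k must be between 1 and {len(all_layers)}")
    features: Vector = []
    for layer in all_layers[-k:]:
        features.extend(layer[0])  # position 0 holds the [CLS] token
    return features


# Toy example: 3 layers, 2 tokens per layer, 2 dims per token.
all_layers = [
    [[0.0, 0.0], [0.1, 0.1]],  # layer 0 (e.g. embedding output)
    [[1.0, 2.0], [0.2, 0.2]],  # layer 1
    [[3.0, 4.0], [0.3, 0.3]],  # layer 2 (last hidden state)
]
print(concat_cls_from_last_layers(all_layers, k=2))  # [1.0, 2.0, 3.0, 4.0]
```

The same idea works with mean or max pooling per layer instead of [CLS]; concatenation just grows the feature vector by one hidden dimension per extra layer.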
Now, which pooling method + layer(s) provides the best feature representation tends to depend on the task at hand, the domain of the data, length of the texts and so on. We picked [CLS] pooling in this early chapter because it’s simple and tends to be “good enough” for text classification tasks. You can find a nice ablation study that examined some of these issues here.
As to why this even works, your insight that it’s due to self-attention is spot on! Each token embedding in the sequence is contextualised through the attention mechanism, so the [CLS] token does contain information about subsequent tokens in the sequence (we explain this in more detail in Chapter 3).
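To see why the [CLS] vector ends up carrying information about the rest of the sequence, here is a toy single-head self-attention step in plain Python. It's a simplified sketch that assumes the query, key, and value projections are all the identity (real encoders learn separate weight matrices and stack many heads and layers). The point it shows: each output vector is a softmax-weighted average over every token's vector, so changing a later token changes the [CLS] output even though the [CLS] input is the same.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x: List[Vector]) -> List[Vector]:
    """Scaled dot-product self-attention with identity Q/K/V projections."""
    dim = len(x[0])
    out = []
    for query in x:
        # Score this token against every token in the sequence (including itself).
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(dim) for key in x]
        weights = softmax(scores)
        # Output = attention-weighted average of all token vectors.
        out.append([sum(w * v[i] for w, v in zip(weights, x)) for i in range(dim)])
    return out


cls = [1.0, 0.0]
sentence_a = [cls, [0.5, 0.5], [0.0, 1.0]]
sentence_b = [cls, [0.5, 0.5], [2.0, -1.0]]  # only the last token differs

# The [CLS] input is identical, but its contextualised output is not.
print(self_attention(sentence_a)[0])
print(self_attention(sentence_b)[0])
```

In a trained model, the learned projections and the classification objective shape what [CLS] attends to, which is why its final hidden state can serve as a reasonable summary of the whole sentence.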
Hope that helps!
It does help! Thank you for the explanation and the link!