Sentence Embeddings From Fine-Tuned BertForSequenceClassification

Hey everyone,

I have a binary classification task for a set of documents, and I’d like to visualize these documents from their embeddings. I’ve previously used the sentence-transformers library to do this, but I wanted to see whether I could improve these embeddings by fine-tuning my own BERT model on this particular task rather than just using a pre-trained model. I read through some guides and discussions online, and it seems like I should be able to use the embedding of the CLS token from the last hidden-state layer as a sentence embedding. However, when I pull those values from the hidden_states of the fine-tuned BertForSequenceClassification model, every embedding is the same.

This is the code I’m using to fine-tune the pre-trained model:

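import torch
from torch.optim import AdamW
from transformers import BertForSequenceClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# output_hidden_states=True makes the forward pass also return the hidden_states tuple
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2,
                                                      output_hidden_states=True).to(device)
optim = AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(3):
    for batch in dataloader_train:
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        labels = batch[2].to(device)
        optim.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]  # classification loss, returned because labels were passed
        loss.backward()
        optim.step()
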
And this is the code I’m using to pull the embeddings:

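import numpy as np

def embeddings(model, dataloader_val):
    model.eval()  # disable dropout so the embeddings are deterministic
    embeddings = np.zeros((0, 768))
    for batch in dataloader_val:
        batch = tuple(b.to(device) for b in batch)
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1]}

        with torch.no_grad():
            outputs = model(**inputs)
        # no labels are passed, so outputs[1] is the hidden_states tuple; [:, 0, :] is the CLS position
        embeddings = np.concatenate((embeddings, outputs[1][0][:,0,:].cpu().numpy()), axis=0)
    return embeddings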
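One thing worth checking here, offered as a likely cause rather than a confirmed diagnosis: in Hugging Face Transformers, hidden_states is a tuple whose first entry (hidden_states[0]) is the output of the embedding layer, while the last encoder layer is hidden_states[-1]. At the CLS position, the embedding-layer output depends only on the [CLS] token id, position 0 and token type 0, so outputs[1][0][:, 0, :] would come out identical for every document. The sketch below checks this on a tiny, randomly initialised BertConfig; the layer sizes and the token ids (id 1 standing in for [CLS]) are arbitrary assumptions chosen so it runs without downloading "bert-base-cased".

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

# Tiny random model (hypothetical sizes); the hidden_states indexing is the same for bert-base-cased.
torch.manual_seed(0)
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64,
                    output_hidden_states=True)
model = BertForSequenceClassification(config)
model.eval()  # turn off dropout so the comparison is deterministic

# Two different "documents"; id 1 plays the role of [CLS] at position 0 (assumption).
input_ids = torch.tensor([[1, 5, 6, 7], [1, 42, 43, 44]])
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)

hidden_states = outputs.hidden_states  # what outputs[1] holds when no labels are passed
print(len(hidden_states))              # 3 here: embedding output + one entry per encoder layer

cls_layer0 = hidden_states[0][:, 0, :]  # embedding-layer output at the CLS position
cls_last = hidden_states[-1][:, 0, :]   # final encoder layer at the CLS position

print(torch.allclose(cls_layer0[0], cls_layer0[1]))  # same vector for both inputs
print(torch.allclose(cls_last[0], cls_last[1]))      # expected to differ between inputs
```

If this is what is happening, indexing the last layer in embeddings(), i.e. outputs[1][-1][:, 0, :] instead of outputs[1][0][:, 0, :], would return the last-layer CLS vector that the guides describe.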

Any ideas or thoughts on why the embeddings for all of the CLS tokens would be the same?

One workaround I was considering is to train on the sequence classification task, then load the trained model as a plain BertModel and use its pooler_output. Has anyone tried something like this?
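The workaround in the last paragraph can be sketched as follows, assuming the fine-tuned classifier is written out with save_pretrained. BertModel.from_pretrained then loads the encoder and pooler from that checkpoint and drops the classification head (with a warning about unused weights); the same encoder is also reachable without saving, as the classifier's bert attribute. Worth noting: in BertForSequenceClassification the pooler output is what feeds the classifier (after dropout), so the pooler is updated during fine-tuning. A tiny random config stands in for the real fine-tuned model; its sizes and token ids are hypothetical.

```python
import tempfile

import torch
from transformers import BertConfig, BertForSequenceClassification, BertModel

# Tiny random classifier standing in for the fine-tuned bert-base-cased model (assumption).
torch.manual_seed(0)
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
clf = BertForSequenceClassification(config)
clf.eval()

input_ids = torch.tensor([[1, 5, 6, 7], [1, 42, 43, 44]])
attention_mask = torch.ones_like(input_ids)

with tempfile.TemporaryDirectory() as save_dir:
    clf.save_pretrained(save_dir)                  # save the fine-tuned classifier
    encoder = BertModel.from_pretrained(save_dir)  # reload encoder + pooler only
    encoder.eval()

    with torch.no_grad():
        out = encoder(input_ids, attention_mask=attention_mask)
        # Equivalent shortcut: the classifier already wraps a BertModel as .bert
        out_direct = clf.bert(input_ids, attention_mask=attention_mask)

pooled = out.pooler_output                 # tanh(dense(last-layer CLS)), shape (batch, hidden)
cls_last = out.last_hidden_state[:, 0, :]  # raw last-layer CLS vector, shape (batch, hidden)
print(pooled.shape, cls_last.shape)
print(torch.allclose(pooled, out_direct.pooler_output, atol=1e-5))
```

Either pooled or cls_last could serve as the document embedding for visualization; which one separates the two classes better is an empirical question for the actual data.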