Hey there everyone, I haven’t used this amazing API for a while. Is the code below the right way to pool the [CLS] output embedding from DistilBERT?
from transformers import DistilBertModel

distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased', return_dict=True)
# more code ...
# input_ids & attention_mask from a batch
outputs = distilbert(input_ids, attention_mask=attention_mask)
seq_embeddings = outputs.last_hidden_state  # (batch, seq_len, hidden) DistilBERT output embeddings
cls_embedding = seq_embeddings[:, 0, :]  # <--- [CLS] output embedding, shape (batch, hidden)
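To illustrate what that slice does independently of transformers, here is a minimal sketch using a NumPy array as a stand-in for the model's last hidden state (the array contents are made up, only the shapes matter):

```python
import numpy as np

# Toy stand-in for DistilBERT's last_hidden_state: (batch, seq_len, hidden)
batch, seq_len, hidden = 2, 4, 3
seq_embeddings = np.arange(batch * seq_len * hidden, dtype=np.float32)
seq_embeddings = seq_embeddings.reshape(batch, seq_len, hidden)

# [:, 0, :] keeps, for every sequence in the batch, the vector at position 0.
# With a BERT-style tokenizer that position is the [CLS] token, since it is
# always prepended to the input.
cls_embedding = seq_embeddings[:, 0, :]
print(cls_embedding.shape)  # (2, 3): one hidden-size vector per sequence
```

So the slice selects the first-token vector for every sequence in the batch, which is exactly the [CLS] output.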
I’ve taken this idea from the following DPR code:
Did you find a solution to your question, please?