The sentence-transformers models were recently added to the model hub. I wanted to know the following:
- Were all these models trained on NLI (and on which dataset? MNLI?)
- For how many epochs?
- If I want to use one for inference on an NLI task, do I need to train an additional linear layer? That is, freeze the AutoModel, train an nn.Linear for 1 epoch on the mean-pooled output, and then perform inference. Can you please confirm this usage?
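import torch.nn as nn
from transformers import AutoModel

tfmr = AutoModel.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens")

# mean_pooling was not shown in the original post; this version is assumed
# to match the usual attention-mask-weighted average of token embeddings.
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # (batch, seq_len, hidden)
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

class Model(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.encoder = model
        # Freeze the encoder so that only the classifier is trained
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier = nn.Linear(768, 3)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs):  # inputs will come from a dataloader using default_data_collator
        labels = inputs.pop('labels')
        model_output = self.encoder(**inputs)
        sentence_embeddings = mean_pooling(model_output, inputs['attention_mask'])
        logits = self.classifier(sentence_embeddings)
        loss = self.criterion(logits, labels)
        return loss, logits

model = Model(tfmr)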
I trained it on MNLI with the encoder frozen, and it gets around 46% accuracy after 1 epoch. It seems like it has not been trained on MNLI. Can anyone confirm?
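One way to sanity-check a frozen setup like this is to confirm that only the linear head is being updated. The sketch below is a rough illustration, not the actual training script: TinyEncoder is a hypothetical stand-in for the pretrained model (so it runs without downloading weights), and the optimizer, learning rate, and batch shapes are all assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyEncoder(nn.Module):
    """Hypothetical stand-in for the pretrained encoder: one 768-d vector per example."""
    def __init__(self, in_dim=16, hidden=768):
        super().__init__()
        self.proj = nn.Linear(in_dim, hidden)

    def forward(self, x):
        return self.proj(x)

encoder = TinyEncoder()
for param in encoder.parameters():
    param.requires_grad = False  # freeze the encoder

classifier = nn.Linear(768, 3)  # 3 NLI classes
criterion = nn.CrossEntropyLoss()
# Only the classifier's parameters are handed to the optimizer
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

features = torch.randn(8, 16)        # fake batch of 8 examples
labels = torch.randint(0, 3, (8,))   # fake labels

encoder_before = encoder.proj.weight.detach().clone()
classifier_before = classifier.weight.detach().clone()

# One training step
logits = classifier(encoder(features))
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The encoder should be untouched, the classifier should have moved
encoder_unchanged = torch.equal(encoder_before, encoder.proj.weight.detach())
classifier_changed = not torch.equal(classifier_before, classifier.weight.detach())
```

If encoder_unchanged is False in a real run, the freeze did not take effect, and the accuracy number would not be measuring what was intended.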
From the paper, section 3.1:
We train SBERT on the combination of the SNLI (Bowman et al., 2015) and the Multi-Genre NLI dataset.
So yes, it’s trained on MNLI.
My guess is that, as it’s trained using triplet loss to produce embeddings that are semantically meaningful and can be compared with cosine similarity, it might not perform well on plain classification.
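To make "compared with cosine similarity" concrete, here is a minimal sketch. The vectors are random stand-ins for sentence embeddings (real ones would come from the mean-pooled encoder output), and the triplet loss shown is PyTorch's built-in nn.TripletMarginLoss with an assumed margin of 1.0, not necessarily the exact setup used to train these checkpoints.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Made-up 768-d "embeddings": one anchor, one near-duplicate, one unrelated
anchor = torch.randn(1, 768)
similar = anchor + 0.1 * torch.randn(1, 768)
unrelated = torch.randn(1, 768)

# Cosine similarity is high for the near-duplicate, near zero for the unrelated vector
sim_close = F.cosine_similarity(anchor, similar).item()
sim_far = F.cosine_similarity(anchor, unrelated).item()

# Triplet loss: max(d(anchor, positive) - d(anchor, negative) + margin, 0)
triplet = nn.TripletMarginLoss(margin=1.0)(anchor, similar, unrelated).item()
```

An embedding space shaped this way is good at ranking sentence pairs by similarity, which is a different thing from being linearly separable into entailment / neutral / contradiction from a single vector.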
Thanks for the information. When I read the paper, I initially had the impression that it was trained on the STS task. Given that it has been trained on both SNLI and MNLI (not sure about the proportions), I’m getting 46% accuracy when I train an MLP for one epoch. I know the objective is a bit different (not classification), but I’d expect it to do well with the encoder frozen, given that it has seen both NLI datasets. Things don’t improve much even after 3 epochs.
On what objective is this model trained? Section 3 of the paper mentions three objective functions: classification, regression, and triplet. As per Fig. 2, they seem to be using regression for computing similarity scores at inference. As per Fig. 1, they used the classification objective for fine-tuning. Based on the performance I’m getting, I doubt that these weights come from a model trained on classification. Can someone please clarify?
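For reference, the classification objective in Section 3 of the paper encodes the two sentences separately, concatenates the embeddings u and v with their element-wise difference |u - v|, and feeds that to a softmax classifier. A minimal sketch of such a head is below; PairClassifier is a hypothetical name, the inputs are random stand-ins for mean-pooled embeddings, and the bias term that nn.Linear adds is an assumption.

```python
import torch
import torch.nn as nn

class PairClassifier(nn.Module):
    """Hypothetical head: linear classifier over the features (u, v, |u - v|)."""
    def __init__(self, dim=768, num_labels=3):
        super().__init__()
        self.linear = nn.Linear(3 * dim, num_labels)

    def forward(self, u, v):
        # Concatenate both embeddings and their absolute difference
        features = torch.cat([u, v, torch.abs(u - v)], dim=-1)
        return self.linear(features)  # logits; CrossEntropyLoss applies the softmax

torch.manual_seed(0)
u = torch.randn(4, 768)  # stand-in premise embeddings
v = torch.randn(4, 768)  # stand-in hypothesis embeddings
labels = torch.tensor([0, 1, 2, 1])

head = PairClassifier()
logits = head(u, v)
loss = nn.CrossEntropyLoss()(logits, labels)
```

This differs from putting nn.Linear(768, 3) on a single pooled embedding of the whole input, which might account for part of the accuracy gap, though that is only a guess.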
I have the same confusion; it’s not clear which of those 3 was the final objective.
@joeddav If you have any idea, can you please confirm this?
I’m afraid I don’t have any extra insight here. I’d head over to the sentence-transformers repo and see if you can find an answer or open an issue there (and then loop us back in over here once you have an answer).
I’ve opened an issue. You can follow it if interested.