Which loss function does BertForSequenceClassification use for regression?

This is the GitHub link

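At line 1354, there is a condition that checks the number of labels. With a single label (num_labels == 1) the task is treated as regression and MSELoss is used; with more than one label it is treated as classification and CrossEntropyLoss is used: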
if self.num_labels == 1:
    # We are doing regression
    loss_fct = MSELoss()
    loss = loss_fct(logits.view(-1), labels.view(-1))
else:
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
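
To make the two branches concrete, here is a minimal dependency-free sketch of what each loss computes, assuming the default "mean" reduction of MSELoss and CrossEntropyLoss. It is not the library's code: the function names (mse_loss, cross_entropy_loss, sequence_classification_loss) are hypothetical, and plain Python lists stand in for tensors.

```python
import math

def mse_loss(preds, targets):
    # Mean of squared differences (what MSELoss returns with reduction="mean").
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)

def cross_entropy_loss(logits, labels):
    # Mean over the batch of -log(softmax(logits)[label]).
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
    return total / len(labels)

def sequence_classification_loss(logits, labels, num_labels):
    # Hypothetical helper mirroring the branch quoted above.
    if num_labels == 1:
        # Regression: flatten (batch, 1) logits to (batch,), like logits.view(-1).
        flat = [row[0] for row in logits]
        return mse_loss(flat, labels)
    # Classification: one row of num_labels logits per example.
    return cross_entropy_loss(logits, labels)

# Regression: one output per example, float targets.
reg_loss = sequence_classification_loss([[2.5], [0.0]], [3.0, 1.0], num_labels=1)

# Classification: two classes, integer class indices as targets.
cls_loss = sequence_classification_loss([[0.0, 0.0], [0.0, 0.0]], [0, 1], num_labels=2)

print(reg_loss, cls_loss)
```

In practice this means that for regression the model should be configured with num_labels=1 and the labels should be floats, since MSELoss compares them directly with the single output logit.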
