I am doing sentence-pair classification with BertForSequenceClassification.
My model is defined as follows:
model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=5)
And my training loop looks like this:
import numpy as np
import torch
import torch.nn as nn

EPOCHS = 5
criterion = nn.CrossEntropyLoss()
total_loss, total_accuracy = 0, 0
# empty list to save model predictions
total_preds = []

for epoch in range(EPOCHS):
    model.train()
    total_train_loss = 0
    total_train_acc = 0
    for step, batch in enumerate(train_dataloader):
        batch = [r.to(device) for r in batch]
        input_id, attention_mask, token_type_id, y = batch
        model.zero_grad()
        prediction = model(input_id, attention_mask, token_type_id)
        loss = criterion(prediction, y)
        total_loss = total_loss + loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        preds = prediction.detach().cpu().numpy()
        total_preds.append(preds)

avg_loss = total_loss / len(train_dataloader)
total_preds = np.concatenate(total_preds, axis=0)
print(avg_loss)
When I train the model, I get the following error:
TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not SequenceClassifierOutput
I am not able to figure out what is wrong here. Any suggestions?
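For context on what the error is complaining about, here is a minimal, self-contained sketch. Hugging Face models like BertForSequenceClassification return a SequenceClassifierOutput object rather than a raw tensor, so the class scores have to be pulled out of its .logits attribute before they are passed to CrossEntropyLoss. The stand-in dataclass below only mimics that wrapper so the snippet runs without downloading a model; the shapes and names are illustrative:

```python
import torch
import torch.nn as nn
from dataclasses import dataclass

# Stand-in for transformers' SequenceClassifierOutput: the model's forward
# pass returns an object that wraps the logits tensor, not the tensor itself.
@dataclass
class SequenceClassifierOutput:
    logits: torch.Tensor

criterion = nn.CrossEntropyLoss()
batch_size, num_labels = 4, 5  # illustrative sizes matching num_labels=5

output = SequenceClassifierOutput(logits=torch.randn(batch_size, num_labels))
y = torch.randint(0, num_labels, (batch_size,))

# criterion(output, y) would raise a TypeError like the one in the question,
# because the loss expects a Tensor, not the wrapper object.
# Extracting .logits gives CrossEntropyLoss the (batch, num_labels) tensor:
loss = criterion(output.logits, y)
print(loss.item())
```

As an alternative, if you pass labels=y in the model's forward call, the returned output also carries a precomputed cross-entropy loss in output.loss.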