Sentence Pair Classification

I am doing sentence pair classification: given two sentences, I have to predict a single label for the pair.

After I created my train and test data, I converted both sentence columns to lists and applied the BERT tokenizer as

train_encode = tokenizer(train1, train2, padding="max_length", truncation=True)
test_encode = tokenizer(test1, test2, padding="max_length", truncation=True)
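For reference, with padding="max_length" and no explicit max_length, every pair is padded to the tokenizer's model maximum, which is 512 tokens for the standard BERT checkpoints. A quick sanity check (variable names as above):

print(train_encode.keys())                # input_ids, token_type_ids, attention_mask
print(len(train_encode['input_ids'][0]))  # 512 unless max_length is set explicitly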

After that, I extracted only the required fields as

train_seq = torch.tensor(train_encode['input_ids'])
train_mask = torch.tensor(train_encode['attention_mask'])
train_token = torch.tensor(train_encode['token_type_ids'])
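Equivalently, the tokenizer can hand back PyTorch tensors directly with return_tensors="pt", which avoids the manual torch.tensor(...) conversions:

train_encode = tokenizer(train1, train2, padding="max_length", truncation=True, return_tensors="pt")
train_seq = train_encode['input_ids']
train_mask = train_encode['attention_mask']
train_token = train_encode['token_type_ids']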

I did the same for the test data and created a data loader as

from torch.utils.data import TensorDataset, DataLoader

# train_y is the tensor of integer labels for each sentence pair
train_data = TensorDataset(train_seq, train_mask, train_token, train_y)
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
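The test loader is built the same way (shuffling isn't needed at evaluation time; the test_* names follow the pattern above):

test_data = TensorDataset(test_seq, test_mask, test_token, test_y)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False)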

Created model as

model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=5)
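The training loop below also refers to optimizer, scheduler and device. A minimal setup along these lines would work (AdamW with a linear-warmup schedule is an assumption here, not necessarily what was used originally):

import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)
total_steps = 5 * len(train_dataloader)  # EPOCHS * number of batches
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)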

and the training loop is

train_losses = []
num_mb_train = len(train_dataloader)

import torch.nn as nn

EPOCHS = 5
criterion = nn.CrossEntropyLoss()  # not used below: the model computes the loss itself when labels are passed

# empty list to save model predictions
total_preds = []

model.train()
for epoch in range(EPOCHS):
  train_loss = 0

  for step, batch in enumerate(train_dataloader):
    optimizer.zero_grad()

    # move the batch to the GPU and unpack it
    batch = [r.to(device) for r in batch]
    input_id, attention_mask, token_type_id, y = batch

    prediction = model(input_id, attention_mask=attention_mask, token_type_ids=token_type_id, labels=y)

    loss = prediction[0]
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()

    train_loss += loss.item() / num_mb_train

  print("\nTrain loss after epoch %i: %f" % (epoch + 1, train_loss))
  train_losses.append(train_loss)

When I run the training loop, it fails with

RuntimeError: CUDA out of memory.

I am not sure what I am missing and what I should do. Any help would be much appreciated.

@sgugger

Have you tried reducing your batch_size and/or max_length? It looks like you’re running out of memory.
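For example, re-tokenizing with a shorter max_length and dropping the batch size (the exact numbers are just starting points to tune):

# shorter sequences: attention memory grows quadratically with sequence length
train_encode = tokenizer(train1, train2, padding="max_length", truncation=True, max_length=128)
test_encode = tokenizer(test1, test2, padding="max_length", truncation=True, max_length=128)

# smaller batches
train_dataloader = DataLoader(train_data, batch_size=8, shuffle=True)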