BERT Model: IndexError: too many indices for tensor of dimension 2

I am fine-tuning BERT on SST-2 by freezing the pretrained model and training only a newly added classifier layer. The classifier's input and output dimensions look correct, but the forward pass still fails with the error below.

Error:

IndexError                                Traceback (most recent call last)
last_layer_finetune\llf_sst2.ipynb Cell 16 line 1
----> 1 train(model, train_dataloader, optimizer, 1, device)

e:\Internships\Applications & Cover Letters\CISPA Helmhotz\CISPA_FB_PPLLMs\experiments\last_layer_finetune\llf_sst2.ipynb Cell 16 line 1
     15 batch = {k: v.to(device) for k, v in batch.items()}
     16 optimizer.zero_grad()
---> 18 outputs = model(**batch)
     19 loss = outputs.loss
     20 loss.backward()

File c:\Users\KXIF\anaconda3\envs\workenv\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File c:\Users\KXIF\anaconda3\envs\workenv\lib\site-packages\torch\nn\modules\module.py:1568, in Module._call_impl(self, *args, **kwargs)
   1565     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1566     args = bw_hook.setup_input_hook(args)
-> 1568 result = forward_call(*args, **kwargs)
   1569 if _global_forward_hooks or self._forward_hooks:
   1570     for hook_id, hook in (
   1571         *_global_forward_hooks.items(),
   1572         *self._forward_hooks.items(),
   1573     ):
   1574         # mark that always called hook is run

File c:\Users\KXIF\anaconda3\envs\workenv\lib\site-packages\opacus\grad_sample\grad_sample_module.py:148, in GradSampleModule.forward(self, *args, **kwargs)
    147 def forward(self, *args, **kwargs):
--> 148     return self._module(*args, **kwargs)

File c:\Users\KXIF\anaconda3\envs\workenv\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File c:\Users\KXIF\anaconda3\envs\workenv\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

last_layer_finetune\llf_sst2.ipynb Cell 16 line 2
     23 last_hidden_state = outputs[0]
     25 sequence_outputs = self.dropout(last_hidden_state)
---> 26 logits = self.classifier(sequence_outputs[:, 0, :].view(-1, 768))
     28 loss = None
     30 if labels is not None:

IndexError: too many indices for tensor of dimension 2

Code:

from datasets import load_dataset

import torch
import torch.nn as nn
import numpy as np

from tqdm.notebook import tqdm
from torch.optim import SGD
from torch.utils.data import DataLoader

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, DataCollatorWithPadding
from transformers.modeling_outputs import TokenClassifierOutput

from sklearn.metrics import accuracy_score

model_name = "bert-base-uncased"
EPOCHS = 22
BATCH_SIZE = 256
LR = 0.0001

# Prepare data
dataset = load_dataset("glue", "sst2")
num_labels = dataset["train"].features["label"].num_classes

tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenized_dataset = dataset.map(
    lambda example: tokenizer(example["sentence"], max_length=128, padding='max_length', truncation=True),
    batched=True
)

tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

tokenized_dataset = tokenized_dataset.remove_columns(['idx'])

train_dataloader = DataLoader(
    tokenized_dataset['train'],
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
    shuffle=True
)

test_dataloader = DataLoader(
    tokenized_dataset['validation'],
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
)
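
# Sanity check (not part of the training run): one collated batch has the expected
# keys and shapes -- input_ids / attention_mask: (BATCH_SIZE, 128), labels: (BATCH_SIZE,)
sample_batch = next(iter(train_dataloader))
print({k: v.shape for k, v in sample_batch.items()})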

class ClassifierHeadLayer(nn.Module):

    def __init__(self, model_name, num_labels):
        super(ClassifierHeadLayer, self).__init__()
        self.num_labels = num_labels

        config = AutoConfig.from_pretrained(
            model_name,
            num_labels=num_labels,
            output_attentions=True,
            output_hidden_states=True,
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
        
        # Freeze all original layers
        for param in self.model.parameters():
            param.requires_grad = False

        # New, trainable classification head
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None):

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs[0]

        sequence_outputs = self.dropout(last_hidden_state)
        # Intended: take the [CLS] token representation and classify it.
        # This is the line that raises the IndexError in the traceback above.
        logits = self.classifier(sequence_outputs[:, 0, :].view(-1, 768))

        loss = None

        if labels is not None:
            loss_func = nn.CrossEntropyLoss()
            loss = loss_func(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ClassifierHeadLayer(model_name=model_name, num_labels=num_labels)

model.to(device)
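
# Sanity check (not part of the training run): after freezing, only the new head
# should be trainable.
print([n for n, p in model.named_parameters() if p.requires_grad])
# expected: ['classifier.weight', 'classifier.bias']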

optimizer = SGD(params=model.parameters(), lr=LR)
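
# The traceback shows the model is wrapped in an Opacus GradSampleModule, so a
# PrivacyEngine is attached before training. That part of the script is omitted
# above; a minimal sketch of it follows -- noise_multiplier, max_grad_norm and
# DELTA are placeholder values, not the ones actually used.
from opacus import PrivacyEngine

DELTA = 1e-5  # placeholder target delta, consumed by get_epsilon() in train()

privacy_engine = PrivacyEngine()
model, optimizer, train_dataloader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_dataloader,
    noise_multiplier=1.0,  # placeholder
    max_grad_norm=1.0,     # placeholder
)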

def train(model, train_dataloader, optimizer, epoch, device):
    model.train()

    losses = []
    epsilon = None

    for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Training Epoch: {epoch}"):
        
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()

        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        losses.append(loss.item())

        if i % 8000 == 0:
            epsilon = privacy_engine.get_epsilon(DELTA)
            print(f"Training Epoch: {epoch} | Loss: {np.mean(losses):.6f} | (ε = {epsilon:.2f}, δ = {DELTA})")

train(model, train_dataloader, optimizer, 1, device)
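
The failing line assumes outputs[0] is the 3-D last hidden state with shape (batch_size, sequence_length, 768), but the error says the tensor only has 2 dimensions. Below is a minimal shape probe (a sketch, using bert-base-uncased and a dummy sentence, run outside the training loop) that compares what AutoModel and AutoModelForSequenceClassification return as outputs[0]:

from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
import torch

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tok("a quick test sentence", return_tensors="pt")

with torch.no_grad():
    base = AutoModel.from_pretrained("bert-base-uncased")
    seq_clf = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    print(base(**enc)[0].shape)     # last_hidden_state: (1, seq_len, 768) -> 3 dimensions
    print(seq_clf(**enc)[0].shape)  # logits: (1, 2) -> 2 dimensions, so [:, 0, :] would fail

Is the IndexError happening because the AutoModelForSequenceClassification inside ClassifierHeadLayer already returns logits as outputs[0] (so I should wrap AutoModel instead), or is the Opacus wrapping changing the output shape? What is the correct way to train only the newly added classifier layer here?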