How to Resolve batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64) TypeError: not a sequence

I am training a simple NER model on custom data-set. I get error while using data-loader.

Here is my complete code.

import os
import warnings
from collections import Counter
import tqdm
import random
warnings.filterwarnings('ignore')
os.environ["WANDB_DISABLED"] = "true"
os.environ["TOKENIZERS_PARALLELISM"]= "true"
from torchcrf import CRF
from transformers import BertTokenizerFast as BertTokenizer
from transformers import BertForTokenClassification
import torch.nn as nn
import torch.nn.functional as F
log_soft = F.log_softmax
from transformers import (Trainer,TrainingArguments)
from torch.utils.data import TensorDataset
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
torch.cuda.empty_cache()
train_data=[
{'text': "My name is Jon. I live in Germany.",
'spans': [{'start': 12, 'end': 14, 'label': 'name', 'ngram': 'Jon'},
          {'start': 27, 'end': 33, 'label': 'country', 'ngram': 'Germany'}
          ]
 },

{'text': "My name is Jony. I live in Russia.",
'spans': [{'start': 12, 'end': 15, 'label': 'name', 'ngram': 'Jony'},
          {'start': 28, 'end': 33, 'label': 'country', 'ngram': 'Russia'}
          ]
 },
{'text': "My name is Tony. I live in Poland.",
'spans': [{'start': 12, 'end': 15, 'label': 'name', 'ngram': 'Tony'},
          {'start': 28, 'end': 33, 'label': 'country', 'ngram': 'Poland'}
          ]
 },
{'text': "My name is Yun. I live in Holland.",
'spans': [{'start': 12, 'end': 14, 'label': 'name', 'ngram': 'Yun'},
          {'start': 27, 'end': 33, 'label': 'country', 'ngram': 'Holland'}
          ]
 }
]

model_checkpoint = "SpanBERT/spanbert-base-cased"
tokenizer = BertTokenizer.from_pretrained(model_checkpoint,add_prefix_space=True)

def isin(a, b):
    return a[1] > b[0] and a[0] < b[1]
def tokenize_and_align_labels(examples, label2id, max_length=10):
    tokenized_inputs = tokenizer(examples["texts"], truncation=True,padding='max_length', max_length=max_length,
                                 return_offsets_mapping=True,return_tensors="pt")

    labels = []
    for i, label_idx_for_single_input in enumerate(tqdm.tqdm(examples["tag_names"])):

        labels_for_single_input = ['O' for _ in range(max_length)]      
        text_offsets = tokenized_inputs['offset_mapping'][i]
     
        for entity in label_idx_for_single_input:
           
            tag = entity['tag']
            tag_offset = [entity['start'], entity['end']]
            affected_token_ids = [j for j in range(max_length) if isin(tag_offset, text_offsets[j])]

            if len(affected_token_ids) < 1:     
                continue
            if any(labels_for_single_input[j] != 'O' for j in affected_token_ids):
                continue

            for j in affected_token_ids:
                labels_for_single_input[j] = 'I-' + tag

            labels_for_single_input[affected_token_ids[-1]] = 'I-' + tag
           labels_for_single_input[affected_token_ids[0]] = 'B-' + tag

        label_ids = [label2id[x] for x in labels_for_single_input]
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    print(tokenized_inputs.keys())
    return tokenized_inputs
train_set = [
    [
        x['text'],
        [{'start': y["start"], 'end': y["end"], 'tag': y["label"], 'text': y["ngram"]} for y in x['spans']]
    ] for x in train_data
]
ori_label_list = []
for line in train_set:
    ori_label_list += [entity['tag'] for entity in line[1]]

ori_label_list = sorted(list(set(ori_label_list)))

label_list = []
for prefix in 'BI':
    label_list += [prefix + '-' + x for x in ori_label_list]
label_list += ['O']
label_list = sorted(list(set(label_list)))
label2id = {n:i for i,n in enumerate(label_list)}
id2label= {i:n for i,n in enumerate(label_list)}

train_examples ={'texts':[x[0] for x in train_set],'tag_names':[x[1] for x in train_set]}
train_data = tokenize_and_align_labels(train_examples,label2id)
_=train_data.pop('offset_mapping')
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, examples):
        self.encodings = examples
        print(self.encodings)
        print()
        self.labels = examples['labels']

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        print(item)
        item["labels"] = torch.tensor([self.labels[idx]])
        return item

    def __len__(self):

        return len(self.labels)
train_data2=MyDataset(train_data)
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer,return_tensors="pt")
bert_model = BertForTokenClassification.from_pretrained(
                        model_checkpoint,
                        id2label=id2label,
                        label2id=label2id
)
bert_model.config.output_hidden_states=True
class BERT_CRF(nn.Module):
    

    def __init__(self, bert_model, num_labels):
       
        super(BERT_CRF, self).__init__()
        self.bert = bert_model
        self.config = self.bert.config
        self.dropout = nn.Dropout(0.25)     
        self.classifier = nn.Linear(768, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
       

        sequence_output = torch.stack((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4])).mean(dim=0)
        
        sequence_output = self.dropout(sequence_output)
      
        emission = self.classifier(sequence_output)  # [32,256,17]

        labels = labels.reshape(attention_mask.size()[0], attention_mask.size()[1])
      

        if labels is not None:
        
            loss = -self.crf(log_soft(emission, 2), labels, mask=attention_mask.type(torch.uint8), reduction='mean')
       
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
          
            print([loss, prediction])
            return [loss, prediction]

        else:
      
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
            return prediction
model = BERT_CRF(bert_model, num_labels=len(label2id))
model.to(device)
args = TrainingArguments(
    "spanbert_",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=2,
    weight_decay=0.01,
    per_device_train_batch_size=2,

)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_data2,
    data_collator=data_collator,
    tokenizer=tokenizer)

trainer.train()

Error

trainer.train()
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1527, in train
    return inner_training_loop(
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1749, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/transformers/trainer_utils.py", line 696, in __call__
{'input_ids': tensor([ 101, 1139, 1271, 1110,  194, 3488,  119,  178, 1686,  102]), 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 'labels': tensor([4, 4, 4, 4, 4, 1, 4, 4, 4, 4])}
{'input_ids': tensor([ 101, 1139, 1271, 1110,  179, 1320,  119,  178, 1686,  102]), 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 'labels': tensor([4, 4, 4, 4, 4, 1, 4, 4, 4, 4])}
    return self.data_collator(features)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/transformers/data/data_collator.py", line 45, in __call__
    return self.torch_call(features)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/transformers/data/data_collator.py", line 339, in torch_call
    batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
TypeError: not a sequence
  0%|          | 0/4 [00:00<?, ?it/s]