How can I fix this with the GPT-2 based model? At each iteration the size of the target tensor increases.

Error: Expected input batch_size (28) to match target batch_size (456). Changing the batch size also changes the target batch size with the GPT-2 model.

def tokenize_data(total_marks, coding_feeddback):
    """Tokenize question texts and their feedback answers for GPT-2 fine-tuning.

    Args:
        total_marks: list of input strings (questions).
        coding_feeddback: list of target strings (answers).

    Returns:
        (inputs, labels) where `inputs` is the tokenizer's BatchEncoding and
        `labels` is a tensor of token ids with the SAME sequence length as
        `inputs['input_ids']`.
    """
    inputs = tokenizer(total_marks, truncation=True, padding=True,
                       return_tensors="pt")
    # GPT-2's LM loss requires `labels` to have the same shape as `input_ids`.
    # Tokenizing the two lists independently pads each to its own longest
    # sequence, giving mismatched lengths — that is the source of the
    # "Expected input batch_size (28) to match target batch_size (456)" error
    # (28 = batch * input seq_len, 456 = batch * label seq_len).
    # Pad/truncate labels to the input sequence length so the shapes agree.
    seq_len = inputs["input_ids"].shape[1]
    labels = tokenizer(coding_feeddback, truncation=True, padding="max_length",
                       max_length=seq_len, return_tensors="pt")["input_ids"]
    # `return` was dedented outside the function body in the original,
    # which is an IndentationError.
    return inputs, labels

# Prepare the training and validation datasets

 train_inputs, train_labels = tokenize_data(train_df['Question'].tolist(), 
 train_df['ans'].tolist())
 val_inputs, val_labels = tokenize_data(val_df['Question'].tolist(), 
 val_df['ans'].tolist())

 train_dataset = TensorDataset(train_inputs['input_ids'], train_labels)
 val_dataset = TensorDataset(val_inputs['input_ids'], val_labels)

batch_size = 4
train_dataloader = DataLoader(train_dataset, 
                   batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

# Training loop

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Move both tensors of the batch to the training device.
        input_ids, labels = (item.to(device) for item in batch)

        optimizer.zero_grad()

        # GPT-2 computes the LM loss internally when `labels` is supplied;
        # `labels` must have the same shape as `input_ids` or the forward
        # pass raises the batch_size mismatch error from the question.
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()

    # Validation after EACH epoch. In the original, this block sat outside
    # the epoch loop (yet printed `epoch`), so it ran only once after all
    # training, using the final epoch index.
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        for val_batch in val_dataloader:
            val_input_ids, val_labels = (item.to(device) for item in val_batch)
            val_outputs = model(input_ids=val_input_ids, labels=val_labels)
            val_loss += val_outputs.loss.item()

        average_val_loss = val_loss / len(val_dataloader)
        print(f"Epoch: {epoch+1}, Validation Loss: {average_val_loss:.4f}")

    # Restore training mode for the next epoch.
    model.train()