ValueError: operands could not be broadcast together with shapes (1,2048,51200) (20,2,1,16,2048,64)

Hi, I am trying to train a code completion model using salesforce/codegen-350M-mono as a base. However my code gives me a
ValueError: operands could not be broadcast together with shapes (1,2048,51200) (20,2,1,16,2048,64)

Here is the code:

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from tensorflow.keras.optimizers import Adam
import numpy as np
from datasets import load_dataset
from transformers import TrainingArguments, Trainer
import evaluate 

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
ds = load_dataset("codeparrot/github-code", streaming=True, split="train")

dataset = []
samples = 8
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

for x in iter(ds):
  if samples > 0:
    samples -= 1
    dataset.append(x)
  else:
    break

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

n = len(dataset)
k = int(0.8 * n)
l = int(0.2 * n)

train_dataset = dataset[0:k]
eval_dataset = dataset[k:-1]

def tokenize_function(examples):
    return tokenizer(examples['code'], examples['code'], padding='max_length', truncation=True)

train = []
eval = []

for element in train_dataset:
  train.append(tokenize_function(element))

for element in eval_dataset:
  eval.append(tokenize_function(element))

training_args = TrainingArguments(output_dir="cooder-large", evaluation_strategy="epoch", per_device_train_batch_size=1, per_device_eval_batch_size=1)
metric = evaluate.load("bleu")
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

trainer.train()

Thanks in advance.

Hello, I was wondering if you found the solution for this error.