Finetuning Vision Encoder Decoder Models with huggingface causes ValueError: expected sequence of length 11 at dim 2 (got 12)

Input code that causes code failing:

from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, ViTFeatureExtractor, AutoTokenizer
from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel, default_data_collator
from datasets import load_dataset, DatasetDict
encoder_checkpoint = "google/vit-base-patch16-224-in21k"
decoder_checkpoint = "bert-base-uncased"


model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_checkpoint, decoder_checkpoint
)
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = tokenizer.sep_token_id
model.config.max_length = 512
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
model.decoder.resize_token_embeddings(len(tokenizer))

feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)

Preparing dataset

dataset = load_dataset("svjack/pokemon-blip-captions-en-zh").remove_columns("zh_text")
dataset = dataset.map(lambda example: {'pixel_values': feature_extractor(example['image'], return_tensors='pt').pixel_values})
dataset = dataset.remove_columns("image")
dataset = dataset.map(lambda example: {'labels': tokenizer(example['en_text'], return_tensors='pt').input_ids })
dataset = dataset.remove_columns("en_text")
"""
dataset = DatasetDict({
train: Dataset({
    features: ['pixel_values', 'labels'],
    num_rows: 833
})
"""
train_testvalid = dataset["train"].train_test_split(0.1)
test_valid = train_testvalid['test'].train_test_split(0.5)
train_test_valid_dataset = DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'valid': test_valid['train']})

Setting parameters:

for param in model.encoder.parameters():
    param.requires_grad = False

output_dir = "./checkpoints"
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    overwrite_output_dir=True,
    fp16=True,
    run_name="first_run",
    load_best_model_at_end=True,
    output_dir=output_dir,
    logging_steps=2000,
    save_steps=2000,
    eval_steps=2000,
)

Trying to finetune models:

trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_test_valid_dataset['train'],
        eval_dataset=train_test_valid_dataset['valid'],
        data_collator=default_data_collator,
        
    )
trainer.train()

Output error:

/usr/local/lib/python3.9/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1541             self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1542         )
-> 1543         return inner_training_loop(
   1544             args=args,
   1545             resume_from_checkpoint=resume_from_checkpoint,

/usr/local/lib/python3.9/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1763 
   1764             step = -1
-> 1765             for step, inputs in enumerate(epoch_iterator):
   1766 
   1767                 # Skip past any already trained steps if resuming training

/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    626                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    627                 self._reset()  # type: ignore[call-arg]
--> 628             data = self._next_data()
    629             self._num_yielded += 1
    630             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    669     def _next_data(self):
    670         index = self._next_index()  # may raise StopIteration
--> 671         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    672         if self._pin_memory:
    673             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     59         else:
     60             data = self.dataset[possibly_batched_index]
---> 61         return self.collate_fn(data)

/usr/local/lib/python3.9/dist-packages/transformers/data/data_collator.py in default_data_collator(features, return_tensors)
     68 
     69     if return_tensors == "pt":
---> 70         return torch_default_data_collator(features)
     71     elif return_tensors == "tf":
     72         return tf_default_data_collator(features)

/usr/local/lib/python3.9/dist-packages/transformers/data/data_collator.py in torch_default_data_collator(features)
    134                 batch[k] = torch.tensor(np.stack([f[k] for f in features]))
    135             else:
--> 136                 batch[k] = torch.tensor([f[k] for f in features])
    137 
    138     return batch

ValueError: expected sequence of length 11 at dim 2 (got 12)

How to fix the code?