Custom Data Collator Gives Error

I have a vision dataset that includes mammography images.

This is the model:

# Configuration for the ViT model: single-channel input (num_channels=1),
# 16x16 patches, image size taken from cfg["image_size"].
vit_config = ViTConfig(
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act='gelu',
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    image_size=cfg["image_size"],
    patch_size=16,
    num_channels=1,
    qkv_bias=True,
    encoder_stride=16,
)


# The model is constructed directly from the config above.
model = ViTForImageClassification(config=vit_config)

# Replace the classification head with a 4-output linear layer.
# NOTE(review): only the layer is swapped here; no label count is passed to
# ViTConfig, so the model's own num_labels is presumably still the library
# default. The loss line in the traceback reshapes the logits with
# self.num_labels, so a mismatch there would explain the 8-vs-4 batch-size
# error -- confirm.
model.classifier = torch.nn.Linear(768, 4)

# Image preprocessor: resizes to cfg["image_size"] and normalizes the single
# channel with mean 0.5 and std 0.5.
feature_extractor = ViTFeatureExtractor(
    do_resize=True,
    size=cfg["image_size"],
    image_mean=[0.5],
    image_std=[0.5],
)

Here I preprocess the dataset:

def process_example(batch):
    """Convert a raw dataset batch into model inputs for the selected view.

    The image column is ``"image_" + cfg["current_view"]``. Its images are run
    through ``feature_extractor`` (returning PyTorch tensors) and a new axis
    is inserted at position 1 of ``pixel_values``. Labels are read from the
    ``cfg["label"]`` column and mapped to ids through ``label2id``.
    """
    image_column = "image_" + cfg["current_view"]
    images = list(batch[image_column])
    encoded = feature_extractor(images, return_tensors="pt")
    # Insert the extra axis right after the batch axis.
    encoded["pixel_values"] = encoded["pixel_values"].unsqueeze(1)
    label_names = batch[cfg["label"]]
    encoded["labels"] = [label2id[name] for name in label_names]
    return encoded

# Register process_example as the dataset's transform.
prepared_ds = dataset.with_transform(process_example)

# Sanity check on the first training example; as the last expression of the
# notebook cell, the (shape, label) tuple below is shown as the cell output.
example = prepared_ds["train"][0]
example["pixel_values"].shape, example["labels"]

Above cell outputs this:

(torch.Size([1, 512, 512]), 2)

Here is the data collator I tried:

def collate_fn(examples):
    """Collate a list of processed examples into a single batch dict.

    Each example must provide a ``"pixel_values"`` tensor and a ``"labels"``
    value. The tensors are stacked along a new leading axis and the labels
    are gathered into one tensor.
    """
    batch = {}
    batch["pixel_values"] = torch.stack([item["pixel_values"] for item in examples])
    batch["labels"] = torch.tensor([item["labels"] for item in examples])

    # Earlier attempt, left disabled:
    #     pixel_values = pixel_values.flatten(0, 1)
    #     labels = labels.flatten(0, 1)
    return batch

Here is the trainer:

# Training configuration: 10 epochs, evaluation and checkpointing once per
# epoch, with the same per-device batch size for training and evaluation.
training_args = TrainingArguments(
    output_dir="./results",          # output directory
    num_train_epochs=10,              # total number of training epochs
    per_device_train_batch_size=cfg["batch_size"],  # batch size per device during training
    per_device_eval_batch_size=cfg["batch_size"],   # batch size for evaluation
    logging_dir="./logs",            # directory for storing logs
    evaluation_strategy="epoch",   
    save_strategy="epoch",
    # Keep every dataset column. NOTE(review): presumably required because
    # process_example reads the raw image/label columns on the fly, and they
    # would otherwise be dropped before the transform runs -- confirm.
    remove_unused_columns=False,
)

# Wire the model, the custom collate_fn and the transformed train/test splits
# into the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
    # NOTE(review): the image preprocessor is handed over through the
    # `tokenizer` argument -- presumably so it is saved with checkpoints; confirm.
    tokenizer=feature_extractor,
)

# Start training; this is the call that raises the ValueError quoted below.
trainer.train()

But this gives me this weird error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_23/3707666235.py in <module>
     27 )
     28 
---> 29 trainer.train()

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1411             resume_from_checkpoint=resume_from_checkpoint,
   1412             trial=trial,
-> 1413             ignore_keys_for_eval=ignore_keys_for_eval,
   1414         )
   1415 

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1649                         tr_loss_step = self.training_step(model, inputs)
   1650                 else:
-> 1651                     tr_loss_step = self.training_step(model, inputs)
   1652 
   1653                 if (

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in training_step(self, model, inputs)
   2343 
   2344         with self.compute_loss_context_manager():
-> 2345             loss = self.compute_loss(model, inputs)
   2346 
   2347         if self.args.n_gpu > 1:

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs)
   2375         else:
   2376             labels = None
-> 2377         outputs = model(**inputs)
   2378         # Save past state if it exists
   2379         # TODO: this needs to be fixed and made cleaner later.

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/transformers/models/vit/modeling_vit.py in forward(self, pixel_values, head_mask, labels, output_attentions, output_hidden_states, interpolate_pos_encoding, return_dict)
    798             elif self.config.problem_type == "single_label_classification":
    799                 loss_fct = CrossEntropyLoss()
--> 800                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
    801             elif self.config.problem_type == "multi_label_classification":
    802                 loss_fct = BCEWithLogitsLoss()

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
   1164         return F.cross_entropy(input, target, weight=self.weight,
   1165                                ignore_index=self.ignore_index, reduction=self.reduction,
-> 1166                                label_smoothing=self.label_smoothing)
   1167 
   1168 

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3012     if size_average is not None or reduce is not None:
   3013         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3014     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   3015 
   3016 

ValueError: Expected input batch_size (8) to match target batch_size (4).

However, I can't get the data collator to work with the Hugging Face Trainer. I tried many different versions of the data collator, but I don't understand why I am getting this error. The expected input batch size is twice the intended batch size.

Thank you for reading.

OK, I found the solution. The problem was that

I tried to set the number of classes on the model just by replacing its last layer.
Even though I could not find this in the documentation, I found out that passing the number of classes to the config, instead of replacing the classifier layer, was the solution:

vit_config = ViTConfig(
    ...
    num_classes=4,
    ...
)