"too many values to unpack (expected 4)" but pixel_values dimension is correct

I’m running into this error and unsure about how to resolve it.

from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel, DefaultDataCollator, TrainingArguments, Trainer, EarlyStoppingCallback
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

model_name = 'google/vit-base-patch16-224' 
processor = ViTImageProcessor()

configuration = ViTConfig()
configuration.num_labels = 1
configuration.output_hidden_states = True
model = ViTForImageClassification(configuration)#.to(device)

scores = pd.read_csv('/Users/paigemin/Desktop/thesis/resources/scores.csv', index_col=[0])
X_VIT = scores['filename']
y_VIT = scores['score']

X_train_VIT, X_test_VIT, y_train_VIT, y_test_VIT = train_test_split(X_VIT, y_VIT, test_size=0.2)
X_train_VIT = X_train_VIT.reset_index(drop=True)
y_train_VIT = y_train_VIT.reset_index(drop=True).astype('float32')
X_test_VIT = X_test_VIT.reset_index(drop=True)
y_test_VIT = y_test_VIT.reset_index(drop=True).astype('float32')

# Methods to define torch UIDataset to convert Pandas dataframe to torch Dataset 

from PIL import Image

class UIDataset(Dataset):
    def __init__(self, X, y, processor):
        print("inited")
        self.X = X 
        self.y = y
        self.processor = processor

    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, index): 
        # print('Input index', index)
        path = '/Users/paigemin/Desktop/thesis/resources/music_fin/'+ str(self.X[index]) +'.jpg'
        # print('Getting image from path', path, '............')
        image = Image.open(path)

        image = processor(image, return_tensors="pt")  # this is a dictionary with 'pixel_values' as the tensor
        # print('Image is type', type(image), 'and the shape of pt is', str(image['pixel_values'].shape))

        image["labels"] = torch.tensor(self.y[index])  # here we simply add another 'labels' value to the dictionary aka the label of the image

        return image

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

trainVIT_FT = UIDataset(X_train_VIT, y_train_VIT, processor)
valVIT_FT = UIDataset(X_test_VIT, y_test_VIT, processor)

train_loader = DataLoader(trainVIT_FT, batch_size=16, shuffle=True)
val_loader = DataLoader(valVIT_FT, batch_size=16, shuffle=False)

# Define Trainer parameters
def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)#.to(device)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)#.to(device)
    recall = recall_score(y_true=labels, y_pred=pred)#.to(device)
    precision = precision_score(y_true=labels, y_pred=pred)#.to(device)
    f1 = f1_score(y_true=labels, y_pred=pred)#.to(device)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Using basic TrainingArguments to start atleast 
args = TrainingArguments('/Users/paigemin/Desktop/thesis/resources/ft_outputs_VIT', label_names=['labels'],)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=trainVIT_FT,
    eval_dataset=valVIT_FT,
    compute_metrics=compute_metrics,
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    data_collator=collate_fn
)

# Train pre-trained model
trainer.train()

Below are images showing X_train_VIT and y_train_VIT:

Below is my output for transformers-cli env:

  • transformers version: 4.31.0
  • Platform: macOS-14.3-arm64-i386-64bit
  • Python version: 3.10.0
  • Huggingface_hub version: 0.18.0
  • Safetensors version: 0.4.0
  • Accelerate version: 0.26.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Here is an image of the error

Here is what I see as the shape of pixel_values: