Vision Transformer Fine Tuning Issues

Hi,
I’m fine-tuning Google’s Vision Transformer on my custom dataset. I am converting the data into Hugging Face’s Dataset format in the following manner:

This is the initial CSV: a table with an Image_Path column holding the file paths and a Class column holding the labels (1–4).
Each example from the dataset should have 3 features:

  • image: A PIL Image
  • image_file_path: The str path to the image file that was loaded as the image above
  • labels: A datasets.ClassLabel feature, which is an integer representation of the label.
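
For reference, a single example would then look roughly like this (my sketch, not from the original post; the path is copied from the output further below, and pil_image stands in for an actual decoded image):

pil_image = ...  # hypothetical: a PIL.Image.Image instance
example = {
    'image': pil_image,
    'image_file_path': '/kaggle/working/PendantImages/Pendant_UG00198-1Y0000_1_lar.jpg',
    'labels': 0,  # integer index into the ClassLabel names
}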

I am casting it to an image dataset as instructed in the official Hugging Face docs:

paths_dict = {'image': df['Image_Path'].tolist()}

with output:

{'image': ['/kaggle/working/PendantImages/Pendant_UG00198-1Y0000_1_lar.jpg',
  '/kaggle/working/EarringsImages/Earring_BISJ0029S17_YAA18DIG6XXXXXXXX_ABCD00-PICS-00004-1024-5542.png',
  '/kaggle/working/BangleImages/Bangle_513220VIZ2A00_1.jpg',
  '/kaggle/working/EarringsImages/Earring_BISP0419S13_YAA18DIG6XXXXXXXX_ABCD00-PICS-00004-1024-37602.png',
  '/kaggle/working/NecklaceImages/Necklace_BIDK0846N11_YAA18DIG6XXXXXXXX_ABCD00-PICS-00003-1024-65037.png',
  '/kaggle/working/EarringsImages/Earring_BICM0379S07_YAA18DIG6XXXXXXXX_ABCD00-PICS-00004-1024-53001.png', ...}

Here’s the cast to image format operation:

from datasets import Dataset, Image
hub_dataset = Dataset.from_dict(paths_dict, split="train").cast_column("image", Image())

with output:

Dataset({
    features: ['image'],
    num_rows: 6278
})
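
At this point a quick sanity check (my addition, not in the original post) confirms the cast worked: accessing an example should lazily decode the file into a PIL image.

img = hub_dataset[0]['image']
print(type(img))  # e.g. <class 'PIL.JpegImagePlugin.JpegImageFile'>
print(img.size)   # (width, height) of the decoded image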

And here are the cast operations for the remaining columns, with the final features:

from datasets import ClassLabel

hub_dataset = hub_dataset.add_column('image_file_path', df['Image_Path'].tolist())
hub_dataset = hub_dataset.add_column('labels', df['Class'].tolist())
hub_dataset = hub_dataset.cast_column('labels', ClassLabel(names=[1, 2, 3, 4]))
hub_dataset.features

with output:

{'image': Image(decode=True, id=None),
 'image_file_path': Value(dtype='string', id=None),
 'labels': ClassLabel(num_classes=4, names=[1, 2, 3, 4], id=None)}
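
One observation here (mine, not from the original post): ClassLabel names are conventionally strings, so integer names like [1, 2, 3, 4] work but can make the id2label/label2id mappings below harder to read. The feature itself provides the index/name mapping directly:

label_feature = hub_dataset.features['labels']
print(label_feature.num_classes)  # 4
print(label_feature.int2str(0))   # name of the first class
print(label_feature.names)        # [1, 2, 3, 4]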

Here are the transform function and the collate function that preprocess the data for ViT:

from transformers import ViTImageProcessor

model_name = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name)


def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs

prepared_df = hub_dataset.with_transform(transform)

import torch

def collate_fn(batch):
    # Stack per-example (C, H, W) tensors into a (B, C, H, W) batch
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }
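
A quick shape check on one transformed example (my addition; 224×224 is the processor’s default resolution for this checkpoint) confirms the transform runs on the fly and is unbatched per example:

sample = prepared_df[0]
print(sample['pixel_values'].shape)  # expected: torch.Size([3, 224, 224])
print(sample['labels'])              # an integer class index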


Here’s the model that I created:

labels = hub_dataset.features['labels'].names #[1,2,3,4]

from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

The training args:

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir = "/kaggle/working/vit-base-jwlry",
#   per_device_train_batch_size = 16,
  evaluation_strategy = "steps",
#   num_train_epochs = 4,
  fp16 = True, 
#   use_cpu = True,
#   save_steps = 100,
#   eval_steps = 100,
#   logging_steps = 10,
  learning_rate = 2e-4,
  save_total_limit = 2,
  remove_unused_columns = False,
  push_to_hub = False,
  report_to = 'tensorboard',
  load_best_model_at_end = True
)
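
The Trainer below also references a compute_metrics function that the post doesn’t show. A minimal sketch in the spirit of the guide, using the accuracy metric (assumes the evaluate library is installed):

import numpy as np
import evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    # eval_pred.predictions are the raw logits; take the argmax per example
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)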


from transformers import Trainer

trainer = Trainer(
    model = model,
    args = training_args,
    data_collator = collate_fn,
    compute_metrics = compute_metrics,
    train_dataset = prepared_df[0:100],
    eval_dataset = prepared_df[0:100],
    tokenizer = processor
)

I am encountering the error in the train() step:

train_results = trainer.train()

This is the exact error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[38], line 1
----> 1 train_results = trainer.train()

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1537, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1535         hf_hub_utils.enable_progress_bars()
   1536 else:
-> 1537     return inner_training_loop(
   1538         args=args,
   1539         resume_from_checkpoint=resume_from_checkpoint,
   1540         trial=trial,
   1541         ignore_keys_for_eval=ignore_keys_for_eval,
   1542     )

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1821, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1818     rng_to_sync = True
   1820 step = -1
-> 1821 for step, inputs in enumerate(epoch_iterator):
   1822     total_batched_samples += 1
   1824     if self.args.include_num_input_tokens_seen:

File /opt/conda/lib/python3.10/site-packages/accelerate/data_loader.py:448, in DataLoaderShard.__iter__(self)
    446 # We iterate one batch ahead to check when we are at the end
    447 try:
--> 448     current_batch = next(dataloader_iter)
    449 except StopIteration:
    450     yield

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:634, in _BaseDataLoaderIter.__next__(self)
    631 if self._sampler_iter is None:
    632     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    633     self._reset()  # type: ignore[call-arg]
--> 634 data = self._next_data()
    635 self._num_yielded += 1
    636 if self._dataset_kind == _DatasetKind.Iterable and \
    637         self._IterableDataset_len_called is not None and \
    638         self._num_yielded > self._IterableDataset_len_called:

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:678, in _SingleProcessDataLoaderIter._next_data(self)
    676 def _next_data(self):
    677     index = self._next_index()  # may raise StopIteration
--> 678     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    679     if self._pin_memory:
    680         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:51, in <listcomp>(.0)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

File /opt/conda/lib/python3.10/site-packages/transformers/feature_extraction_utils.py:88, in BatchFeature.__getitem__(self, item)
     86     return self.data[item]
     87 else:
---> 88     raise KeyError("Indexing with integers is not available when using Python based feature extractors")

KeyError: 'Indexing with integers is not available when using Python based feature extractors'

Can anyone help me identify the issue and suggest some workarounds? I am running this on a Kaggle notebook with the following specs:

Disk: 73 GB
RAM: 29 GB
GPU: 2× T4 (15 GB each)

Here’s the guide I am following for fine-tuning ViT on my dataset: Fine-Tune ViT for Image Classification with 🤗 Transformers (huggingface.co)

Hey, were you able to solve this issue? @raunak45

Hi,

The issue comes from passing prepared_df[0:100] to the Trainer: slicing a dataset that has with_transform applied runs the transform and returns a BatchFeature dict, not a Dataset. The Trainer’s dataloader then indexes that BatchFeature with integers, which is exactly the KeyError in your traceback (BatchFeature.__getitem__ only supports string keys).
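
A minimal fix sketch (assuming you only want a 100-example subset for now): select the subset while it is still a Dataset, then attach the on-the-fly transform, and pass that Dataset to the Trainer:

small_ds = hub_dataset.select(range(100)).with_transform(transform)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=small_ds,
    eval_dataset=small_ds,
    tokenizer=processor,
)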

If things still fail, the best way to debug is to create a PyTorch DataLoader and inspect what comes out of the first batch:

from torch.utils.data import DataLoader

dataloader = DataLoader(your_pytorch_dataset, batch_size=2, shuffle=True)

batch = next(iter(dataloader))
for k, v in batch.items():
    print(k, v.shape)

Here you should see two things: pixel_values with shape (batch_size, num_channels, height, width) and labels with shape (batch_size,).