SegformerFeatureExtractor not working as expected - Feature extractor not returning the label object

Hello, I am trying to fine-tune a pretrained SegFormer model on my own dataset, following this blog post:

I have one directory containing the original images and another containing the corresponding binary masks for segmentation training.

Here’s how I load the data:

import os

from datasets import Dataset, Image


def to_hf_dataset(img_dir, mask_dir):
    # Pair images with masks by sorted filename.
    pixel_values = sorted(os.path.join(img_dir, fname) for fname in os.listdir(img_dir))
    labels = sorted(os.path.join(mask_dir, fname) for fname in os.listdir(mask_dir))
    # Cast both columns to Image() so the files are decoded to PIL on access.
    ds = Dataset.from_dict({"pixel_values": pixel_values, "label": labels})
    ds = ds.cast_column("pixel_values", Image()).cast_column("label", Image())
    return ds

train_ds = to_hf_dataset(train_raw_image_dir, train_binary_output_dir)
test_ds = to_hf_dataset(validation_raw_image_dir, validation_binary_output_dir)
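
Since the pairing relies purely on sorted filenames, I also run a quick sanity check (my own addition, not from the blog post) that the two listings line up:

import os

# Assumption: each image and its mask share the same base filename, so sorting
# keeps the pairs aligned (extensions may differ, e.g. .jpg images vs .png masks).
img_stems = sorted(os.path.splitext(f)[0] for f in os.listdir(train_raw_image_dir))
mask_stems = sorted(os.path.splitext(f)[0] for f in os.listdir(train_binary_output_dir))
assert img_stems == mask_stems, "image and mask filenames do not line up"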

This is what the dataset looks like when printing train_ds[0]:

{'pixel_values': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=960x1280>,
 'label': <PIL.PngImagePlugin.PngImageFile image mode=L size=960x1280>}
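
The mask decodes as a single-channel (L mode) PNG. To double-check that it really is binary, I look at its unique pixel values (again my own check, not from the blog post):

import numpy as np

# The binary masks should contain only the class ids 0 and 1.
mask = np.array(train_ds[0]["label"])
print(mask.shape, np.unique(mask))  # expect (1280, 960) and [0 1]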

Next I set up the label/id mappings:

id2label = {0: 'unlabeled', 1: 'ring'}
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)

print("Id2label:", id2label)

Next I apply the transforms with the following code; I believe this is where the problem starts:

from torchvision.transforms import ColorJitter
from transformers import SegformerFeatureExtractor

feature_extractor = SegformerFeatureExtractor()
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)  # augmentation (not applied in the transforms below)

def train_transforms(example_batch):
    images = [x for x in example_batch['pixel_values']]  # PIL RGB images
    labels = [x for x in example_batch['label']]         # PIL L-mode masks
    inputs = feature_extractor(images=images, masks=labels)
    return inputs


def val_transforms(example_batch):
    images = [x for x in example_batch['pixel_values']]
    labels = [x for x in example_batch['label']]
    inputs = feature_extractor(images=images, masks=labels)
    return inputs


# Set transforms
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)
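
Based on the blog post, I expected each example to now contain both pixel_values and labels, but a quick key check shows otherwise:

example = train_ds[0]
# Expected (per the blog post): dict_keys(['pixel_values', 'labels'])
# Got: dict_keys(['pixel_values'])
print(example.keys())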

Printing train_ds[0] in full after the transform, this is what I see:

{'pixel_values': array([[[ 1.7180408 ,  1.7180408 ,  1.7180408 , ..., -0.9362959 ,
          -1.1246684 , -1.1931673 ],
         [ 1.7180408 ,  1.7180408 ,  1.7180408 , ..., -1.2102921 ,
          -1.2102921 , -1.1931673 ],
         [ 1.7180408 ,  1.7180408 ,  1.7180408 , ..., -1.073294  ,
          -1.0904187 , -1.1589178 ],
         ...,
         [-0.43967807, -0.45680285, -0.4739276 , ...,  1.9406626 ,
           1.9577874 ,  1.9577874 ],
         [-0.49105233, -0.5081771 , -0.5253019 , ...,  1.9406626 ,
           1.9577874 ,  1.9577874 ],
         [-0.57667613, -0.55955136, -0.55955136, ...,  1.9406626 ,
           1.9577874 ,  1.9577874 ]],
 
        [[ 1.7107843 ,  1.7107843 ,  1.7107843 , ..., -0.810224  ,
          -1.0028011 , -1.0728291 ],
         [ 1.7107843 ,  1.7107843 ,  1.7107843 , ..., -1.0903361 ,
          -1.0903361 , -1.0728291 ],
         [ 1.7107843 ,  1.7107843 ,  1.7107843 , ..., -0.9502801 ,
          -0.9677871 , -1.0378151 ],
         ...,
         [-0.91526604, -0.9502801 , -0.9677871 , ...,  1.9558823 ,
           1.9733893 ,  1.9733893 ],
         [-0.9677871 , -0.9852941 , -1.0028011 , ...,  1.9558823 ,
           1.9733893 ,  1.9733893 ],
         [-1.055322  , -1.0378151 , -1.0378151 , ...,  1.9558823 ,
           1.9733893 ,  1.9733893 ]],
 
        [[ 1.8905448 ,  1.8905448 ,  1.8905448 , ..., -0.82840955,
          -1.0201306 , -1.0898474 ],
         [ 1.8905448 ,  1.8905448 ,  1.8905448 , ..., -1.1072767 ,
          -1.1072767 , -1.0898474 ],
         [ 1.8905448 ,  1.8905448 ,  1.8905448 , ..., -0.9678431 ,
          -0.9852723 , -1.0549891 ],
         ...,
         [-1.1247058 , -1.1595641 , -1.1769934 , ...,  2.0125492 ,
           2.0299783 ,  2.0299783 ],
         [-1.1769934 , -1.1944225 , -1.2118517 , ...,  2.0125492 ,
           2.0299783 ,  2.0299783 ],
         [-1.2641394 , -1.2467101 , -1.2467101 , ...,  2.0125492 ,
           2.0299783 ,  2.0299783 ]]], dtype=float32)}

Here I notice that the label is no longer present in the example after the transformation.

The rest of the code is the same as in the blog post mentioned above.
When I then execute trainer.train(), here is the full traceback of the error I see:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-67-3435b262f1ae> in <cell line: 1>()
----> 1 trainer.train()

20 frames
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1660             self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1661         )
-> 1662         return inner_training_loop(
   1663             args=args,
   1664             resume_from_checkpoint=resume_from_checkpoint,

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1897 
   1898             step = -1
-> 1899             for step, inputs in enumerate(epoch_iterator):
   1900                 total_batched_samples += 1
   1901                 if rng_to_sync:

/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    631                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    632                 self._reset()  # type: ignore[call-arg]
--> 633             data = self._next_data()
    634             self._num_yielded += 1
    635             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    675     def _next_data(self):
    676         index = self._next_index()  # may raise StopIteration
--> 677         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    678         if self._pin_memory:
    679             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     47         if self.auto_collation:
     48             if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
---> 49                 data = self.dataset.__getitems__(possibly_batched_index)
     50             else:
     51                 data = [self.dataset[idx] for idx in possibly_batched_index]

/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in __getitems__(self, keys)
   2805     def __getitems__(self, keys: List) -> List:
   2806         """Can be used to get a batch using a list of integers indices."""
-> 2807         batch = self.__getitem__(keys)
   2808         n_examples = len(batch[next(iter(batch))])
   2809         return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]

/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in __getitem__(self, key)
   2801     def __getitem__(self, key):  # noqa: F811
   2802         """Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
-> 2803         return self._getitem(key)
   2804 
   2805     def __getitems__(self, keys: List) -> List:

/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in _getitem(self, key, **kwargs)
   2786         formatter = get_formatter(format_type, features=self._info.features, **format_kwargs)
   2787         pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
-> 2788         formatted_output = format_table(
   2789             pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
   2790         )

/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py in format_table(table, key, formatter, format_columns, output_all_columns)
    627     python_formatter = PythonFormatter(features=None)
    628     if format_columns is None:
--> 629         return formatter(pa_table, query_type=query_type)
    630     elif query_type == "column":
    631         if key in format_columns:

/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py in __call__(self, pa_table, query_type)
    398             return self.format_column(pa_table)
    399         elif query_type == "batch":
--> 400             return self.format_batch(pa_table)
    401 
    402     def format_row(self, pa_table: pa.Table) -> RowFormat:

/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py in format_batch(self, pa_table)
    513         batch = self.python_arrow_extractor().extract_batch(pa_table)
    514         batch = self.python_features_decoder.decode_batch(batch)
--> 515         return self.transform(batch)
    516 
    517 

<ipython-input-61-559ad64a3301> in train_transforms(example_batch)
     10     images = [x for x in example_batch['pixel_values']]
     11     labels = [x for x in example_batch['label']]
---> 12     inputs = feature_extractor(images=images, masks=labels)
     13     # labels = [np.array(x) for x in example_batch['label']]
     14     # inputs['label'] = labels

/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in __call__(self, images, segmentation_maps, **kwargs)
    313         passed in as positional arguments.
    314         """
--> 315         return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
    316 
    317     def preprocess(

/usr/local/lib/python3.10/dist-packages/transformers/image_processing_utils.py in __call__(self, images, **kwargs)
    456     def __call__(self, images, **kwargs) -> BatchFeature:
    457         """Preprocess an image or a batch of images."""
--> 458         return self.preprocess(images, **kwargs)
    459 
    460     def preprocess(self, images, **kwargs) -> BatchFeature:

/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in preprocess(self, images, segmentation_maps, do_resize, size, resample, do_rescale, rescale_factor, do_normalize, image_mean, image_std, do_reduce_labels, return_tensors, data_format, **kwargs)
    408             raise ValueError("Image mean and std must be specified if do_normalize is True.")
    409 
--> 410         images = [
    411             self._preprocess_image(
    412                 image=img,

/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in <listcomp>(.0)
    409 
    410         images = [
--> 411             self._preprocess_image(
    412                 image=img,
    413                 do_resize=do_resize,

/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in _preprocess_image(self, image, do_resize, size, resample, do_rescale, rescale_factor, do_normalize, image_mean, image_std, data_format)
    260         # All transformations expect numpy arrays.
    261         image = to_numpy_array(image)
--> 262         image = self._preprocess(
    263             image=image,
    264             do_reduce_labels=False,

/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in _preprocess(self, image, do_reduce_labels, do_resize, do_rescale, do_normalize, size, resample, rescale_factor, image_mean, image_std)
    234 
    235         if do_resize:
--> 236             image = self.resize(image=image, size=size, resample=resample)
    237 
    238         if do_rescale:

/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in resize(self, image, size, resample, data_format, **kwargs)
    162         if "height" not in size or "width" not in size:
    163             raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
--> 164         return resize(
    165             image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
    166         )

/usr/local/lib/python3.10/dist-packages/transformers/image_transforms.py in resize(image, size, resample, reducing_gap, data_format, return_numpy)
    298     # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
    299     # The resized image from PIL will always have channels last, so find the input format first.
--> 300     data_format = infer_channel_dimension_format(image) if data_format is None else data_format
    301 
    302     # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use

/usr/local/lib/python3.10/dist-packages/transformers/image_utils.py in infer_channel_dimension_format(image)
    163     elif image.shape[last_dim] in (1, 3):
    164         return ChannelDimension.LAST
--> 165     raise ValueError("Unable to infer channel dimension format")
    166 
    167 

ValueError: Unable to infer channel dimension format

I believe the issue I mentioned above, where the feature extractor does not return the label object, might be the cause.
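
Since the ValueError itself comes from channel-dimension inference, I also scanned the raw images' PIL modes in case one of them is unusual (a debugging check I added; it may be unrelated):

from collections import Counter

# Rebuild the dataset without the transform so that iterating doesn't crash,
# then count the PIL modes of images and masks. An image whose channel count
# can't be inferred (e.g. RGBA or palette mode) is my guess for the trigger.
raw_ds = to_hf_dataset(train_raw_image_dir, train_binary_output_dir)
modes = Counter((ex["pixel_values"].mode, ex["label"].mode) for ex in raw_ds)
print(modes)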

I have been at this for quite some time; could you please help me figure out what I might be missing? Thanks.