Hello, I am trying to fine tune a pretrained segformer model following this blog post, I am using my own dataset: Fine-Tune a Semantic Segmentation Model with a Custom Dataset
I have a directory containing my original images and another containing the binary masks for the segmentation training.
Here’s how to loading the data:
def to_hf_dataset(img_dir, mask_dir):
file_list = os.listdir(img_dir)
mask_file_list = os.listdir(mask_dir)
pixel_values = [os.path.join(img_dir, fname) for fname in file_list]
pixel_values.sort()
labels = [os.path.join(mask_dir, fname) for fname in mask_file_list]
labels.sort()
ds = Dataset.from_dict({"pixel_values": pixel_values, "label": labels}).cast_column("pixel_values", Image())
ds = ds.cast_column("label", Image())
return ds
train_ds = to_hf_dataset(train_raw_image_dir, train_binary_output_dir)
test_ds = to_hf_dataset(validation_raw_image_dir, validation_binary_output_dir)
This is what the dataset looks like on printing train_ds[0]
{'pixel_values': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=960x1280>,
'label': <PIL.PngImagePlugin.PngImageFile image mode=L size=960x1280>}
Next I generate the labels and id mapping:
id2label = {0: 'unlabeled', 1: 'ring'}
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)
print("Id2label:", id2label)
Next I apply transformation with the following code, I believe this is where the problem start:
from torchvision.transforms import ColorJitter
from transformers import (
SegformerFeatureExtractor,
)
feature_extractor = SegformerFeatureExtractor()
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
def train_transforms(example_batch):
images = [x for x in example_batch['pixel_values']]
labels = [x for x in example_batch['label']]
inputs = feature_extractor(images=images, masks=labels)
return inputs
def val_transforms(example_batch):
images = [x for x in example_batch['pixel_values']]
labels = [x for x in example_batch['label']]
inputs = feature_extractor(images=images, masks=labels)
return inputs
# Set transforms
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)
Upon printing train_ds[0]
this is what I see:
{'pixel_values': array([[[ 1.7180408 , 1.7180408 , 1.7180408 , ..., -0.9362959 ,
-1.1246684 , -1.1931673 ],
[ 1.7180408 , 1.7180408 , 1.7180408 , ..., -1.2102921 ,
-1.2102921 , -1.1931673 ],
[ 1.7180408 , 1.7180408 , 1.7180408 , ..., -1.073294 ,
-1.0904187 , -1.1589178 ],
...,
[-0.43967807, -0.45680285, -0.4739276 , ..., 1.9406626 ,
1.9577874 , 1.9577874 ],
[-0.49105233, -0.5081771 , -0.5253019 , ..., 1.9406626 ,
1.9577874 , 1.9577874 ],
[-0.57667613, -0.55955136, -0.55955136, ..., 1.9406626 ,
1.9577874 , 1.9577874 ]],
[[ 1.7107843 , 1.7107843 , 1.7107843 , ..., -0.810224 ,
-1.0028011 , -1.0728291 ],
[ 1.7107843 , 1.7107843 , 1.7107843 , ..., -1.0903361 ,
-1.0903361 , -1.0728291 ],
[ 1.7107843 , 1.7107843 , 1.7107843 , ..., -0.9502801 ,
-0.9677871 , -1.0378151 ],
...,
[-0.91526604, -0.9502801 , -0.9677871 , ..., 1.9558823 ,
1.9733893 , 1.9733893 ],
[-0.9677871 , -0.9852941 , -1.0028011 , ..., 1.9558823 ,
1.9733893 , 1.9733893 ],
[-1.055322 , -1.0378151 , -1.0378151 , ..., 1.9558823 ,
1.9733893 , 1.9733893 ]],
[[ 1.8905448 , 1.8905448 , 1.8905448 , ..., -0.82840955,
-1.0201306 , -1.0898474 ],
[ 1.8905448 , 1.8905448 , 1.8905448 , ..., -1.1072767 ,
-1.1072767 , -1.0898474 ],
[ 1.8905448 , 1.8905448 , 1.8905448 , ..., -0.9678431 ,
-0.9852723 , -1.0549891 ],
...,
[-1.1247058 , -1.1595641 , -1.1769934 , ..., 2.0125492 ,
2.0299783 , 2.0299783 ],
[-1.1769934 , -1.1944225 , -1.2118517 , ..., 2.0125492 ,
2.0299783 , 2.0299783 ],
[-1.2641394 , -1.2467101 , -1.2467101 , ..., 2.0125492 ,
2.0299783 , 2.0299783 ]]], dtype=float32)}
Here I notice that the label object is no longer present after the transformation
The rest of the code is same as shown in the blog post that I mentioned above.
Next when I execute trainer.train()
Here’s the full trace back of the error I see:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-67-3435b262f1ae> in <cell line: 1>()
----> 1 trainer.train()
20 frames
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1660 self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
1661 )
-> 1662 return inner_training_loop(
1663 args=args,
1664 resume_from_checkpoint=resume_from_checkpoint,
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
1897
1898 step = -1
-> 1899 for step, inputs in enumerate(epoch_iterator):
1900 total_batched_samples += 1
1901 if rng_to_sync:
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in __next__(self)
631 # TODO(https://github.com/pytorch/pytorch/issues/76750)
632 self._reset() # type: ignore[call-arg]
--> 633 data = self._next_data()
634 self._num_yielded += 1
635 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
675 def _next_data(self):
676 index = self._next_index() # may raise StopIteration
--> 677 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
678 if self._pin_memory:
679 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
47 if self.auto_collation:
48 if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
---> 49 data = self.dataset.__getitems__(possibly_batched_index)
50 else:
51 data = [self.dataset[idx] for idx in possibly_batched_index]
/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in __getitems__(self, keys)
2805 def __getitems__(self, keys: List) -> List:
2806 """Can be used to get a batch using a list of integers indices."""
-> 2807 batch = self.__getitem__(keys)
2808 n_examples = len(batch[next(iter(batch))])
2809 return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]
/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in __getitem__(self, key)
2801 def __getitem__(self, key): # noqa: F811
2802 """Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
-> 2803 return self._getitem(key)
2804
2805 def __getitems__(self, keys: List) -> List:
/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in _getitem(self, key, **kwargs)
2786 formatter = get_formatter(format_type, features=self._info.features, **format_kwargs)
2787 pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
-> 2788 formatted_output = format_table(
2789 pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
2790 )
/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py in format_table(table, key, formatter, format_columns, output_all_columns)
627 python_formatter = PythonFormatter(features=None)
628 if format_columns is None:
--> 629 return formatter(pa_table, query_type=query_type)
630 elif query_type == "column":
631 if key in format_columns:
/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py in __call__(self, pa_table, query_type)
398 return self.format_column(pa_table)
399 elif query_type == "batch":
--> 400 return self.format_batch(pa_table)
401
402 def format_row(self, pa_table: pa.Table) -> RowFormat:
/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py in format_batch(self, pa_table)
513 batch = self.python_arrow_extractor().extract_batch(pa_table)
514 batch = self.python_features_decoder.decode_batch(batch)
--> 515 return self.transform(batch)
516
517
<ipython-input-61-559ad64a3301> in train_transforms(example_batch)
10 images = [x for x in example_batch['pixel_values']]
11 labels = [x for x in example_batch['label']]
---> 12 inputs = feature_extractor(images=images, masks=labels)
13 # labels = [np.array(x) for x in example_batch['label']]
14 # inputs['label'] = labels
/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in __call__(self, images, segmentation_maps, **kwargs)
313 passed in as positional arguments.
314 """
--> 315 return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
316
317 def preprocess(
/usr/local/lib/python3.10/dist-packages/transformers/image_processing_utils.py in __call__(self, images, **kwargs)
456 def __call__(self, images, **kwargs) -> BatchFeature:
457 """Preprocess an image or a batch of images."""
--> 458 return self.preprocess(images, **kwargs)
459
460 def preprocess(self, images, **kwargs) -> BatchFeature:
/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in preprocess(self, images, segmentation_maps, do_resize, size, resample, do_rescale, rescale_factor, do_normalize, image_mean, image_std, do_reduce_labels, return_tensors, data_format, **kwargs)
408 raise ValueError("Image mean and std must be specified if do_normalize is True.")
409
--> 410 images = [
411 self._preprocess_image(
412 image=img,
/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in <listcomp>(.0)
409
410 images = [
--> 411 self._preprocess_image(
412 image=img,
413 do_resize=do_resize,
/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in _preprocess_image(self, image, do_resize, size, resample, do_rescale, rescale_factor, do_normalize, image_mean, image_std, data_format)
260 # All transformations expect numpy arrays.
261 image = to_numpy_array(image)
--> 262 image = self._preprocess(
263 image=image,
264 do_reduce_labels=False,
/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in _preprocess(self, image, do_reduce_labels, do_resize, do_rescale, do_normalize, size, resample, rescale_factor, image_mean, image_std)
234
235 if do_resize:
--> 236 image = self.resize(image=image, size=size, resample=resample)
237
238 if do_rescale:
/usr/local/lib/python3.10/dist-packages/transformers/models/segformer/image_processing_segformer.py in resize(self, image, size, resample, data_format, **kwargs)
162 if "height" not in size or "width" not in size:
163 raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
--> 164 return resize(
165 image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
166 )
/usr/local/lib/python3.10/dist-packages/transformers/image_transforms.py in resize(image, size, resample, reducing_gap, data_format, return_numpy)
298 # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
299 # The resized image from PIL will always have channels last, so find the input format first.
--> 300 data_format = infer_channel_dimension_format(image) if data_format is None else data_format
301
302 # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
/usr/local/lib/python3.10/dist-packages/transformers/image_utils.py in infer_channel_dimension_format(image)
163 elif image.shape[last_dim] in (1, 3):
164 return ChannelDimension.LAST
--> 165 raise ValueError("Unable to infer channel dimension format")
166
167
ValueError: Unable to infer channel dimension format
I believe the issue I mentioned before this where the feature extractor is not returning the label object might be the problem.
I have been at this for quite some time, could you please help me figure out what I might be missing. Thanks.