Task Guides - Image segmentation

I encountered an unexpected error while copying the image segmentation part of the Task Guides. I did it as it was in the Task Guides, but why does the error occur? How can I solve this problem?

Code

from datasets import load_dataset

dataset = load_dataset('scene_parse_150', trust_remote_code=True)

train_dataset = dataset['train']
test_dataset = dataset['test']

import json
from huggingface_hub import hf_hub_download

repo_id = 'huggingface/label-files'
file_name = 'ade20k-id2label.json'
id2label = json.load(open(hf_hub_download(repo_id, file_name, repo_type='dataset'), mode='r'))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
n_labels = len(id2label)

from transformers import AutoImageProcessor

checkpoint = 'nvidia/mit-b0'
img_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)

from torchvision.transforms import ColorJitter

jitter = ColorJitter(brightness=.25, contrast=.24, saturation=.25, hue=.1)

def train_transforms(exmaple_batch):
    imgs = [jitter(x) for x in exmaple_batch['image']]
    labels = [x for x in exmaple_batch['annotation']]
    inputs = img_processor(imgs, labels)

    return inputs

def val_transforms(example_batch):
    imgs = [x for x in example_batch['image']]
    labels = [x for x in example_batch['annotation']]
    inputs = img_processor(imgs, labels)

    return inputs

train_dataset.set_transform(train_transforms)
test_dataset.set_transform(val_transforms)

import evaluate

metric = evaluate.load('mean_iou')

import numpy as np
import torch
import torch.nn as nn

def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        logits_tensor = nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode='bilinear',
            align_corners=False,
        ).argmax(dim=1)
        
        pred_labels = logits_tensor.detach().cpu().numpy()
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=n_labels,
            ignore_index=255,
            reduce_labels=False
        )
        
        for k, v in metrics.items():
            if isinstance(v, np.ndarray):
                metrics[k] = v.tolist()
        
        return metrics

from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer

model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)

training_args = TrainingArguments(
    output_dir='segformer-b0-scene-parse-150',
    learning_rate=6e-5,
    num_train_epochs=50,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    save_total_limit=5,
    evaluation_strategy='steps',
    save_strategy='steps',
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

Error Message

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[67], line 25
      1 training_args = TrainingArguments(
      2     output_dir='segformer-b0-scene-parse-150',
      3     learning_rate=6e-5,
   (...)
     14     remove_unused_columns=False,
     15 )
     17 trainer = Trainer(
     18     model=model,
     19     args=training_args,
   (...)
     22     compute_metrics=compute_metrics,
     23 )
---> 25 trainer.train()

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/transformers/trainer.py:1780, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1778         hf_hub_utils.enable_progress_bars()
   1779 else:
-> 1780     return inner_training_loop(
   1781         args=args,
   1782         resume_from_checkpoint=resume_from_checkpoint,
   1783         trial=trial,
   1784         ignore_keys_for_eval=ignore_keys_for_eval,
   1785     )

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/transformers/trainer.py:2193, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2190     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
   2191     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2193     self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2194 else:
   2195     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/transformers/trainer.py:2577, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2575 metrics = None
   2576 if self.control.should_evaluate:
-> 2577     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2578     self._report_to_hp_search(trial, self.state.global_step, metrics)
   2580     # Run delayed LR scheduler now that metrics are populated

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/transformers/trainer.py:3365, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   3362 start_time = time.time()
   3364 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3365 output = eval_loop(
   3366     eval_dataloader,
   3367     description="Evaluation",
   3368     # No point gathering the predictions if there are no metrics, otherwise we defer to
   3369     # self.args.prediction_loss_only
   3370     prediction_loss_only=True if self.compute_metrics is None else None,
   3371     ignore_keys=ignore_keys,
   3372     metric_key_prefix=metric_key_prefix,
   3373 )
   3375 total_batch_size = self.args.eval_batch_size * self.args.world_size
   3376 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/transformers/trainer.py:3544, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   3542 observed_num_examples = 0
   3543 # Main evaluation loop
-> 3544 for step, inputs in enumerate(dataloader):
   3545     # Update the observed num examples
   3546     observed_batch_size = find_batch_size(inputs)
   3547     if observed_batch_size is not None:

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/accelerate/data_loader.py:452, in DataLoaderShard.__iter__(self)
    450 # We iterate one batch ahead to check when we are at the end
    451 try:
--> 452     current_batch = next(dataloader_iter)
    453 except StopIteration:
    454     yield

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/torch/utils/data/dataloader.py:631, in _BaseDataLoaderIter.__next__(self)
    628 if self._sampler_iter is None:
    629     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    630     self._reset()  # type: ignore[call-arg]
--> 631 data = self._next_data()
    632 self._num_yielded += 1
    633 if self._dataset_kind == _DatasetKind.Iterable and \
    634         self._IterableDataset_len_called is not None and \
    635         self._num_yielded > self._IterableDataset_len_called:

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/torch/utils/data/dataloader.py:675, in _SingleProcessDataLoaderIter._next_data(self)
    673 def _next_data(self):
    674     index = self._next_index()  # may raise StopIteration
--> 675     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    676     if self._pin_memory:
    677         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:49, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     47 if self.auto_collation:
     48     if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
---> 49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
     51         data = [self.dataset[idx] for idx in possibly_batched_index]

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/datasets/arrow_dataset.py:2814, in Dataset.__getitems__(self, keys)
   2812 def __getitems__(self, keys: List) -> List:
   2813     """Can be used to get a batch using a list of integers indices."""
-> 2814     batch = self.__getitem__(keys)
   2815     n_examples = len(batch[next(iter(batch))])
   2816     return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/datasets/arrow_dataset.py:2810, in Dataset.__getitem__(self, key)
   2808 def __getitem__(self, key):  # noqa: F811
   2809     """Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
-> 2810     return self._getitem(key)

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/datasets/arrow_dataset.py:2795, in Dataset._getitem(self, key, **kwargs)
   2793 formatter = get_formatter(format_type, features=self._info.features, **format_kwargs)
   2794 pa_subtable = query_table(self._data, key, indices=self._indices)
-> 2795 formatted_output = format_table(
   2796     pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
   2797 )
   2798 return formatted_output

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/datasets/formatting/formatting.py:629, in format_table(table, key, formatter, format_columns, output_all_columns)
    627 python_formatter = PythonFormatter(features=formatter.features)
    628 if format_columns is None:
--> 629     return formatter(pa_table, query_type=query_type)
    630 elif query_type == "column":
    631     if key in format_columns:

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/datasets/formatting/formatting.py:400, in Formatter.__call__(self, pa_table, query_type)
    398     return self.format_column(pa_table)
    399 elif query_type == "batch":
--> 400     return self.format_batch(pa_table)

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/datasets/formatting/formatting.py:515, in CustomFormatter.format_batch(self, pa_table)
    513 batch = self.python_arrow_extractor().extract_batch(pa_table)
    514 batch = self.python_features_decoder.decode_batch(batch)
--> 515 return self.transform(batch)

Cell In[61], line 11
      9 imgs = [x for x in example_batch['image']]
     10 labels = [x for x in example_batch['annotation']]
---> 11 inputs = img_processor(imgs, labels)
     13 return inputs

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/transformers/models/segformer/image_processing_segformer.py:321, in SegformerImageProcessor.__call__(self, images, segmentation_maps, **kwargs)
    314 def __call__(self, images, segmentation_maps=None, **kwargs):
    315     """
    316     Preprocesses a batch of images and optionally segmentation maps.
    317 
    318     Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
    319     passed in as positional arguments.
    320     """
--> 321     return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/transformers/image_processing_utils.py:551, in BaseImageProcessor.__call__(self, images, **kwargs)
    549 def __call__(self, images, **kwargs) -> BatchFeature:
    550     """Preprocess an image or a batch of images."""
--> 551     return self.preprocess(images, **kwargs)

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/transformers/models/segformer/image_processing_segformer.py:404, in SegformerImageProcessor.preprocess(self, images, segmentation_maps, do_resize, size, resample, do_rescale, rescale_factor, do_normalize, image_mean, image_std, do_reduce_labels, return_tensors, data_format, input_data_format, **kwargs)
    401 validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
    403 if segmentation_maps is not None:
--> 404     segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
    406 if not valid_images(images):
    407     raise ValueError(
    408         "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
    409         "torch.Tensor, tf.Tensor or jax.ndarray."
    410     )

File ~/anaconda3/envs/torch/lib/python3.11/site-packages/transformers/image_utils.py:162, in make_list_of_images(images, expected_ndims)
    157         raise ValueError(
    158             f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
    159             f" {images.ndim} dimensions."
    160         )
    161     return images
--> 162 raise ValueError(
    163     "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
    164     f"jax.ndarray, but got {type(images)}."
    165 )

ValueError: Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray, but got <class 'list'>.