Deepcopy error when copying Dataset in Training

Hi everyone,

I want to create a class to be used as an interface for running trainings with a Segformer model (instead of writing every step manually each time).

However, when I run a test script using this class, I get an error caused by a deepcopy() instruction.

Below I report the code I developed to create and test this class, followed by the error I got in the terminal.

## Standard modules
import logging
import numpy as np
import json
import torch
from torch import nn
import cv2
import copy
from typing import Dict, Any

## Custom modules
import py_utils.log

## Huggingface modules
from huggingface_hub import login, hf_hub_download
from datasets import load_dataset, load_dataset_builder
from transformers import (SegformerImageProcessor,
                          SegformerForSemanticSegmentation,
                          TrainingArguments,
                          Trainer,
                          BatchFeature)
import evaluate
from transformers.trainer import EvalPrediction

## Albumentations modules
import albumentations as Albu

class HfSegformerTrainer:
    """ Utility class providing a training interface for an Huggingface based Segformer model. """

    def __init__(self,
                 hf_token_         : str,
                 dataset_name_     : str,
                 pretrained_model_ : str,
                 train_augm_       : Albu.Compose,
                 valid_augm_       : Albu.Compose,
                 segformer_proc_   : SegformerImageProcessor,
                 training_args_    : TrainingArguments,
                 log_level_        : int = logging.INFO) -> None:
        """ Initialize a Segformer model and everything necessary to run the training. """

        ## Logger setup
        self.logger = py_utils.log.getCustomLogger(logger_name_=__name__,
                                                   node_name_="HfTrainer",
                                                   log_handler_=logging.StreamHandler(),
                                                   logging_level_=log_level_)

        self.logger.debug("__init__() begin!")

        # Login to Huggingface
        login(hf_token_)

        ## Retrieve dataset information (optional)
        ds_builder = load_dataset_builder(dataset_name_)

        self.logger.debug(f"ds_builder description: {ds_builder.info.description}")
        self.logger.debug(f"ds_builder features: {ds_builder.info.features}")

        ## Load dataset splits
        self.train_ds = load_dataset(dataset_name_, split="train")
        self.valid_ds = load_dataset(dataset_name_, split="valid")

        self.logger.debug(f"Train dataset: {self.train_ds}")
        self.logger.debug(f"Valid dataset: {self.valid_ds}")

        self.logger.debug(f"Image: {self.train_ds[0]}")

        # Segformer Image processor
        self.segformer_processor = segformer_proc_
        
        ## Set transformation pipelines for each dataset split
        self.train_ds.set_transform(self.trainTransform)
        self.valid_ds.set_transform(self.validTransform)

        self.logger.debug(f"Train dataset format: {self.train_ds.format}")

        ## Augmentation pipelines (based on Albumentations)
        self.train_augm = train_augm_
        self.valid_augm = valid_augm_

        ## Retrieve id and labels of the dataset (assuming there is an id2label.json file)
        self.id2label = json.load(open(hf_hub_download(repo_id=dataset_name_,
                                                       filename="id2label.json",
                                                       repo_type="dataset"), "r"))
        self.id2label = {int(k): v for k, v in self.id2label.items()}
        label2id = {v: k for k, v in self.id2label.items()}

        self.logger.debug(f"self.id2label: {self.id2label}")

        # Index to be ignored by evaluation metrics (not suggested to change)
        self.IGNORE_IDX = 255

        ## Segformer model instantiation
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_,
            id2label=self.id2label,
            label2id=label2id
        )

        ## Initialize metrics (mIoU)
        self.mean_iou_metric = evaluate.load("mean_iou")

        ## Initialize the Trainer
        self.trainer = Trainer(model=self.model,
                               args=training_args_,
                               train_dataset=self.train_ds,
                               eval_dataset=self.valid_ds,
                               compute_metrics=self.computeMetrics)

        self.logger.debug("__init__() completed!")

    def train(self) -> None:
        """ Run training pipeline. """
        self.trainer.train()

    def trainTransform(self,
                       batch_ : dict) -> BatchFeature:
        """ On the fly preparation of a batch of train data as expected by Segformer model. """

        return self.batchTransform(batch_=batch_,
                                   augm_pipeline_=self.train_augm)
        
    def validTransform(self,
                       batch_ : dict) -> BatchFeature:
        """ On the fly preparation of a batch of validation data as expected by Segformer model. """

        return self.batchTransform(batch_=batch_,
                                   augm_pipeline_=self.valid_augm)
    
    def batchTransform(self,
                       batch_ : dict,
                       augm_pipeline_ : Albu.Compose) -> BatchFeature:
        """ Apply augmentations and Segformer processor transformations to the given batch. """

        self.logger.debug(f"Batch: {batch_}")

        assert(len(batch_["pixel_values"]) == len(batch_["label"]))

        images = []
        labels = []

        ## Parse and augment both images and labels
        for (img_pil, label_pil) in zip(batch_["pixel_values"], batch_["label"]):
            augmented = augm_pipeline_(image=np.array(img_pil.convert("RGB")),
                                       mask=np.array(label_pil))
            images.append(augmented["image"])
            labels.append(augmented["mask"])

        assert(len(images) == len(labels))

        # Complete preprocessing to provide data as expected by Segformer
        segformer_inputs = self.segformer_processor(images, labels)

        return segformer_inputs
    
    def computeMetrics(self, eval_pred_ : EvalPrediction) -> Dict[str, Any]:
        """ Compute evaluation metrics for given predictions data. """
        
        with torch.no_grad():
            logits, labels = eval_pred_
            logits_tensor = torch.from_numpy(logits)
            # Upscale the logits to the size of the label
            logits_tensor = nn.functional.interpolate(input=logits_tensor,
                                                      size=labels.shape[-2:],
                                                      mode="bilinear",
                                                      align_corners=False).argmax(dim=1)
            pred_labels = logits_tensor.detach().cpu().numpy()

            mean_iou_results = self.mean_iou_metric._compute(predictions=pred_labels,
                                                             references=labels,
                                                             num_labels=len(self.id2label),
                                                             ignore_index=self.IGNORE_IDX,
                                                             reduce_labels=False) # we've already reduced the labels ourselves
            
        # add per category metrics as individual key-value pairs
        per_category_accuracy = mean_iou_results.pop("per_category_accuracy").tolist()
        per_category_iou = mean_iou_results.pop("per_category_iou").tolist()

        self.logger.debug(f"per_category_accuracy: {per_category_accuracy}")

        mean_iou_results.update(
            {f"accuracy_{self.id2label[i]}": v for i, v in enumerate(per_category_accuracy)}
        )
        mean_iou_results.update(
            {f"iou_{self.id2label[i]}": v for i, v in enumerate(per_category_iou)}
        )

        return mean_iou_results

##### ----- Test Script ----- #####
def test():
    HF_TOKEN         = "" # put your token here
    HF_DATASET       = "" # put dataset name here
    PRETRAINED_MODEL = "nvidia/mit-b0"
    AUGM             = Albu.Compose([ Albu.Resize(128,
                                                  256,
                                                  interpolation=cv2.INTER_AREA,
                                                  p=1.0) ])
    SEG_PROCESSOR    = SegformerImageProcessor(do_resize=False,
                                               do_rescale=True,
                                               do_normalize=True,
                                               do_reduce_labels=True)
    TRAINING_ARGS    = TrainingArguments(
        output_dir="241031_TEST",
        learning_rate=0.00006,
        num_train_epochs=10,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        save_total_limit=3,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        # save_steps=20,
        # eval_steps=20,
        logging_steps=1,
        eval_accumulation_steps=5,
        load_best_model_at_end=True,
        push_to_hub=True,
        hub_model_id="241028_TEST",
        hub_strategy="every_save",
        hub_private_repo=True
    )

    logger = py_utils.log.getCustomLogger(logger_name_=__name__ + "_test",
                                          node_name_="hf_trainer_test",
                                          log_handler_=logging.StreamHandler(),
                                          logging_level_=logging.INFO)
    
    logger.info("Starting test script...")

    hf_trainer = HfSegformerTrainer(hf_token_=HF_TOKEN,
                                    dataset_name_=HF_DATASET,
                                    pretrained_model_=PRETRAINED_MODEL,
                                    train_augm_=AUGM,
                                    valid_augm_=AUGM,
                                    segformer_proc_=SEG_PROCESSOR,
                                    training_args_=TRAINING_ARGS,
                                    log_level_=logging.DEBUG)
    
    logger.info("HfSegformerTrainer correctly initialized! Running training now...")

    hf_trainer.train()

    logger.info("Test script completed!")

if __name__ == "__main__":
    test()

Error:

Traceback (most recent call last):
  File "hf_segformer_trainer.py", line 312, in <module>
    test()
  File "hf_segformer_trainer.py", line 307, in test
    hf_trainer.train()
  File "hf_segformer_trainer.py", line 175, in train
    self.trainer.train()
  File "/home/andrea/.local/lib/python3.8/site-packages/transformers/trainer.py", line 1778, in train
    return inner_training_loop(
  File "/home/andrea/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2220, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
  File "/home/andrea/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2584, in _maybe_log_save_evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/home/andrea/.local/lib/python3.8/site-packages/transformers/trainer.py", line 3365, in evaluate
    eval_dataloader = self.get_eval_dataloader(eval_dataset)
  File "/home/andrea/.local/lib/python3.8/site-packages/transformers/trainer.py", line 911, in get_eval_dataloader
    eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
  File "/home/andrea/.local/lib/python3.8/site-packages/transformers/trainer.py", line 787, in _remove_unused_columns
    return dataset.remove_columns(ignored_columns)
  File "/home/andrea/.local/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 602, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/andrea/.local/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 567, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/andrea/.local/lib/python3.8/site-packages/datasets/fingerprint.py", line 482, in wrapper
    out = func(dataset, *args, **kwargs)
  File "/home/andrea/.local/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 2206, in remove_columns
    dataset = copy.deepcopy(self)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 237, in _deepcopy_method
    return type(x)(x.__func__, deepcopy(x.__self__, memo))
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 205, in _deepcopy_list
    append(deepcopy(a, memo))
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 161, in deepcopy
    rv = reductor(4)
TypeError: cannot pickle '_thread.lock' object

However, I've noticed that if I create the train and valid transform functions outside of the class (as standard module-level functions), this problem doesn't occur. I'd like to avoid this, though, because it would force me to hard-code the augmentation pipeline instead of passing it to the class directly (see the sketch below).
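For reference, this is roughly what that module-level variant looks like, just as a sketch based on the class above (TRAIN_AUGM, PROCESSOR and train_transform are illustrative names; note how the pipeline and processor end up hard-coded as globals):

import numpy as np
import cv2
import albumentations as Albu
from transformers import SegformerImageProcessor, BatchFeature

## Hard-coded, module-level augmentation pipeline and processor
TRAIN_AUGM = Albu.Compose([Albu.Resize(128, 256, interpolation=cv2.INTER_AREA, p=1.0)])
PROCESSOR  = SegformerImageProcessor(do_resize=False, do_reduce_labels=True)

def train_transform(batch: dict) -> BatchFeature:
    """ Module-level transform: same logic as batchTransform(), but not a bound method. """
    images, labels = [], []
    for img_pil, label_pil in zip(batch["pixel_values"], batch["label"]):
        augmented = TRAIN_AUGM(image=np.array(img_pil.convert("RGB")),
                               mask=np.array(label_pil))
        images.append(augmented["image"])
        labels.append(augmented["mask"])
    return PROCESSOR(images, labels)

# ...and inside __init__():
# self.train_ds.set_transform(train_transform)  # plain function: deepcopy of the dataset
#                                               # no longer recurses into the trainer instance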

I've also noticed that if I override __deepcopy__() for my class as follows:

def __deepcopy__(self, memodict={}):
    # Effectively skips the copy and returns the very same object
    cpyobj = self
    return cpyobj

the problem doesn't occur. However, I'm not totally sure this is a safe operation: if I'm not wrong, this doesn't even make a shallow copy, it just returns the same object, so the "deep copy" ends up sharing state with the original. I'd like to know what the datasets experts think about it.
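For clarity, here is a tiny self-contained sketch of what I mean (Example is just a made-up stand-in for my trainer class, not real library code):

import copy

class Example:
    """ Illustrative stand-in for HfSegformerTrainer. """
    def __init__(self):
        self.data = [1, 2, 3]

    def __deepcopy__(self, memodict={}):
        # Variant used above: skip the copy and hand back the very same object.
        return self
        # An actual shallow copy would instead be:
        #     return copy.copy(self)
        # i.e. a new object whose attributes are still shared with the original.

a = Example()
b = copy.deepcopy(a)
print(b is a)   # True: with the override above, nothing is copied at all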

The best solution would be to solve the deepcopy issue without any potentially unsafe operation like my __deepcopy__() override.

Thanks to anyone who will help me!


It looks like a rather recently resolved issue in transformers, but I'm not sure whether it is the actual cause here.
If it is, maybe just updating the package will fix it?
Also, it seems that a model with a lambda mixed into it cannot be pickled, but that is unlikely to happen with the official model classes.
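For example, this minimal case already fails (just an illustration of the lambda point, nothing to do with your actual model):

import pickle

f = lambda x: x
pickle.dumps(f)   # raises pickle.PicklingError: a lambda cannot be pickled by reference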

I've updated the transformers package, as well as the datasets package, but I still get the same error.
Looking at the error report, the failure seems to happen in the arrow_dataset.py module, in the remove_columns() function, while the fix you mentioned was in the transformers package, so I wasn't expecting it to solve the issue anyway.

I'm not an expert, but reading the error message:

TypeError: cannot pickle '_thread.lock' object

it seems that something is trying to pickle a thread lock object. What I cannot understand is where this object comes from and why it "belongs" to the dataset object being deep-copied.

This Reddit thread is not directly related to my scenario, but it seems to explain the root cause of the problem (trying to pickle a thread object, e.g. to freeze it to disk).
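My current guess, as a minimal sketch (Owner and FakeDataset are made-up stand-ins, not the real transformers/datasets classes): the dataset stores the transform passed to set_transform(); since my transform is a bound method of HfSegformerTrainer, deepcopy reconstructs it by deep-copying the instance it is bound to, and that instance holds objects (logger handlers, the Trainer, the model) that contain _thread.lock objects, which cannot be pickled/copied.

import copy
import threading

class Owner:
    """ Stand-in for HfSegformerTrainer: holds something with a lock inside. """
    def __init__(self):
        self.lock = threading.Lock()   # e.g. hidden inside a logger handler or the Trainer
    def transform(self, batch):
        return batch

class FakeDataset:
    """ Stand-in for datasets.Dataset: it stores the transform it was given. """
    def __init__(self, transform):
        self._transform = transform

ds = FakeDataset(Owner().transform)    # bound method -> carries the Owner instance in __self__
copy.deepcopy(ds)                      # TypeError: cannot pickle '_thread.lock' object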

Do you have any idea what could be the source of this issue?


I've read the source and followed the function flow, and I think there is a possible workaround. If it can be worked around, it's a bug in the library; if it can't, then there is a real problem with how the dataset is processed.

pip install datasets==1.4.0

I haven't tried the solution you proposed, but I may have found a workaround to the problem.
Basically, I moved the batchTransform() method out of the class and passed it to the training and validation datasets as a lambda.

It's also no longer necessary to override the __deepcopy__ function, which I think is a good thing, since that may be dangerous. If I understand copy.py correctly, plain functions and lambdas are treated as atomic by deepcopy (they are copied by reference), while a bound method is reconstructed by deep-copying the instance it is bound to, which is what dragged the whole trainer (and its thread locks) into the copy.

I've tested it a bit and it seems to be working. I'll share the code here for reference, for anyone else interested:

## MOVED OUTSIDE OF THE CLASS!!!
def batchTransform(batch               : dict,
                   augm_pipeline       : Albu.Compose,
                   segformer_processor : SegformerImageProcessor) -> BatchFeature:
    """
    Transformation function applying augmentations and preprocessing to a given batch.

    Args:
        batch (dict): A batch containing pixel values and labels.
        augm_pipeline (Albu.Compose): Augmentation pipeline to apply.
        segformer_processor (SegformerImageProcessor): Processor object for preprocessing
            the image data.

    Returns:
        BatchFeature: A dictionary containing the transformed images and labels.
    """
    
    assert(len(batch["pixel_values"]) == len(batch["label"]))

    images = []
    labels = []

    ## Parse and augment both images and labels
    for (img_pil, label_pil) in zip(batch["pixel_values"], batch["label"]):
        augmented = augm_pipeline(image=np.array(img_pil.convert("RGB")),
                                  mask=np.array(label_pil))
        images.append(augmented["image"])
        labels.append(augmented["mask"])

    assert(len(images) == len(labels))

    # Complete preprocessing to provide data as expected by Segformer
    segformer_inputs = segformer_processor(images, labels)

    return segformer_inputs

class HfSegformerTrainer:
    """
    Utility class providing a training interface for a Hugging Face-based Segformer model.

    This class sets up and manages the training pipeline for a semantic segmentation model
    using the Segformer architecture, including model instantiation, data augmentation, 
    metric computation, and result uploads to the Hugging Face Hub.
    """

    def __init__(self,
                 hf_token_         : str,
                 dataset_name_     : str,
                 pretrained_model_ : str,
                 train_augm_       : Albu.Compose,
                 valid_augm_       : Albu.Compose,
                 segformer_proc_   : SegformerImageProcessor,
                 training_args_    : TrainingArguments,
                 log_level_        : int = logging.INFO) -> None:
        """
        Initializes a Segformer model and related configurations for training.

        Args:
            hf_token_ (str): Hugging Face authentication token.
            dataset_name_ (str): Name of the dataset to use for training.
            pretrained_model_ (str): Name or path of the pretrained Segformer model.
            train_augm_ (Albu.Compose): Augmentation pipeline for training dataset.
            valid_augm_ (Albu.Compose): Augmentation pipeline for validation dataset.
            segformer_proc_ (SegformerImageProcessor): Processor for input data transformations.
            training_args_ (TrainingArguments): Training arguments and configuration.
            log_level_ (int, optional): Logging level, default is `logging.INFO`.

        Raises:
            FileNotFoundError: If the Hugging Face authentication token is invalid.
        """

        ## Logger setup
        self.logger = py_utils.log.getCustomLogger(logger_name_=__name__,
                                                   node_name_="HfTrainer",
                                                   log_handler_=logging.StreamHandler(),
                                                   logging_level_=log_level_)

        self.logger.debug("__init__() begin!")

        # Login to Huggingface
        login(hf_token_)

        ## Save data
        self.dataset_name     = dataset_name_
        self.pretrained_model = pretrained_model_
        self.out_model_name   = training_args_.hub_model_id

        ## Retrieve dataset information (optional)
        ds_builder = load_dataset_builder(self.dataset_name)

        self.logger.debug(f"ds_builder description: {ds_builder.info.description}")
        self.logger.debug(f"ds_builder features: {ds_builder.info.features}")

        ## Load dataset splits
        self.train_ds = load_dataset(self.dataset_name, split="train")
        self.valid_ds = load_dataset(self.dataset_name, split="valid")

        self.logger.info(f"Train dataset: {self.train_ds}")
        self.logger.info(f"Valid dataset: {self.valid_ds}")

        self.logger.debug(f"Image: {self.train_ds[0]}")

        ## Augmentation pipelines (based on Albumentations)
        self.train_augm = train_augm_
        self.valid_augm = valid_augm_

        # Segformer Image processor
        self.segformer_processor = segformer_proc_
        
        ## Set transformation pipelines for each dataset split
        ## PASSING TRANSFORM FUNCTIONS AS LAMBDAS!!!
        self.train_ds.set_transform(
            lambda batch: batchTransform(batch, self.train_augm, self.segformer_processor)
        )
        self.valid_ds.set_transform(
            lambda batch: batchTransform(batch, self.valid_augm, self.segformer_processor)
        )

        self.logger.debug(f"Train dataset format: {self.train_ds.format}")

        ## Retrieve id and labels of the dataset (assuming there is an id2label.json file)
        self.id2label = json.load(open(hf_hub_download(repo_id=self.dataset_name,
                                                       filename="id2label.json",
                                                       repo_type="dataset"), "r"))
        self.id2label = {int(k): v for k, v in self.id2label.items()}
        label2id = {v: k for k, v in self.id2label.items()}

        self.logger.info(f"self.id2label: {self.id2label}")

        # Index to be ignored by evaluation metrics (not suggested to change)
        self.IGNORE_IDX = 255

        ## Segformer model instantiation
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            pretrained_model_name_or_path=self.pretrained_model,
            id2label=self.id2label,
            label2id=label2id
        )

        ## Initialize metrics (mIoU)
        self.mean_iou_metric = evaluate.load("mean_iou")

        ## Initialize the Trainer
        self.trainer = Trainer(model=self.model,
                               args=training_args_,
                               train_dataset=self.train_ds,
                               eval_dataset=self.valid_ds,
                               compute_metrics=self.computeMetrics)

        self.logger.debug("__init__() completed!")

    def train(self) -> None:
        """
        Runs the training pipeline for the Segformer model.
        """

        self.trainer.train()

    def uploadResultstoHfHub(self) -> None:
        """
        Uploads the training results and model to the Hugging Face Hub.

        Notes:
            The Segformer processor and trainer will both push their outputs to the Hub.
        """

        self.logger.debug("Publishing results to Huggingface Hub...")

        kwargs = {
            "tags": ["vision", "image-segmentation"],
            "finetuned_from": self.pretrained_model,
            "dataset": self.dataset_name,
        }

        self.segformer_processor.push_to_hub(self.out_model_name, private=True)
        self.trainer.push_to_hub(**kwargs)

        self.logger.info("Training results saved to Huggingface Hub!")
    
    def computeMetrics(self, eval_pred_ : EvalPrediction) -> Dict[str, Any]:
        """
        Computes evaluation metrics for the predictions.

        Args:
            eval_pred_ (EvalPrediction): Object containing model predictions and labels.

        Returns:
            Dict[str, Any]: A dictionary with calculated metrics, including per-category accuracy
            and IoU values.
        """
        
        with torch.no_grad():
            logits, labels = eval_pred_
            logits_tensor = torch.from_numpy(logits)
            # Upscale the logits to the size of the label
            logits_tensor = nn.functional.interpolate(input=logits_tensor,
                                                      size=labels.shape[-2:],
                                                      mode="bilinear",
                                                      align_corners=False).argmax(dim=1)
            pred_labels = logits_tensor.detach().cpu().numpy()

            mean_iou_results = self.mean_iou_metric._compute(
                predictions=pred_labels,
                references=labels,
                num_labels=len(self.id2label),
                ignore_index=self.IGNORE_IDX,
                reduce_labels=False # we've already reduced the labels ourselves
            )
            
        # add per category metrics as individual key-value pairs
        per_category_accuracy = mean_iou_results.pop("per_category_accuracy").tolist()
        per_category_iou = mean_iou_results.pop("per_category_iou").tolist()

        self.logger.debug(f"per_category_accuracy: {per_category_accuracy}")

        mean_iou_results.update(
            {f"accuracy_{self.id2label[i]}": v for i, v in enumerate(per_category_accuracy)}
        )
        mean_iou_results.update(
            {f"iou_{self.id2label[i]}": v for i, v in enumerate(per_category_iou)}
        )

        return mean_iou_results
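Usage stays the same as in the earlier test script (a sketch reusing the constants defined there):

hf_trainer = HfSegformerTrainer(hf_token_=HF_TOKEN,
                                dataset_name_=HF_DATASET,
                                pretrained_model_=PRETRAINED_MODEL,
                                train_augm_=AUGM,
                                valid_augm_=AUGM,
                                segformer_proc_=SEG_PROCESSOR,
                                training_args_=TRAINING_ARGS)
hf_trainer.train()
hf_trainer.uploadResultstoHfHub()   # optional: push the processor and trainer outputs to the Hub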
