MyPy and DatasetDict. Error: Incompatible return value type (got "Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]", expected "DatasetDict")

Has anyone else had trouble getting mypy to infer `DatasetDict` objects correctly?

The following code snippet should return a `DatasetDict`, but mypy treats it as something very different.

Here is the snippet and what it prints:

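```python
from typing import Dict, List

from datasets import DatasetDict, load_dataset


# Method excerpt: `self` is an instance of a class that defines
# `train_task_inputs` and `test_task_inputs` (the class itself is not shown).
def load_data_from_files(self, inputs_per_task: List[str]) -> DatasetDict:
    # Map each input file to a split name, the format `data_files` expects.
    data_files: Dict[str, str] = {}
    for data_file in inputs_per_task:
        if self.train_task_inputs and data_file in self.train_task_inputs:
            data_files["train"] = data_file
        elif self.test_task_inputs and data_file in self.test_task_inputs:
            data_files["test"] = data_file
        else:
            data_files["validation"] = data_file
    data = load_dataset("csv", data_files=data_files)
    print()
    print(f"[+]{data=}")
    print(f"[+]{type(data)=}")
    return data
```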
Output:

```
[+]data=DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 20
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 20
    })
    validation: Dataset({
        features: ['text', 'labels'],
        num_rows: 20
    })
})
[+]type(data)=<class 'datasets.dataset_dict.DatasetDict'>
```

But mypy reports the following error:

```
data.py:251: error: Incompatible return value type (got "Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]", expected "DatasetDict")
Found 1 error in 1 file (checked 1 source file)
```

As the output shows, at runtime `data` is clearly a `<class 'datasets.dataset_dict.DatasetDict'>`. So why does mypy infer the type `Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]` for this dataset?

Any advice or pointers would be much appreciated! :slight_smile:

The problem appears to come from this part of the `load_dataset` docstring:

```
    Returns:
        :class:`Dataset` or :class:`DatasetDict`:
        - if `split` is not None: the dataset requested,
        - if `split` is None, a ``datasets.DatasetDict`` with each split.

        or :class:`IterableDataset` or :class:`IterableDatasetDict`: if streaming=True

        - if `split` is not None: the dataset requested,
        - if `split` is None, a ``datasets.streaming.IterableDatasetDict`` with each split.
```

Because the return type depends on the values of the input parameters (`split` and `streaming`), the function is a poor candidate for static typing :frowning:

Is there a way to ensure that a `DatasetDict` is returned given my input values?
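Until the signature itself is more precise, one common workaround is to narrow the union yourself. A minimal sketch using only the standard library: the hypothetical helper `expect_type` (not part of `datasets`) checks the value with `isinstance` at runtime, and that same check lets mypy narrow the static type to the class you pass in. If you would rather skip the runtime check, `typing.cast(DatasetDict, data)` silences mypy instead, at the cost of no verification.

```python
from typing import Type, TypeVar

T = TypeVar("T")


def expect_type(value: object, expected: Type[T]) -> T:
    """Return `value` unchanged, narrowed to `expected`, or raise TypeError.

    Hypothetical helper: after the isinstance check, mypy treats `value`
    as `T`, so the return statement matches the annotation.
    """
    if not isinstance(value, expected):
        raise TypeError(
            f"expected {expected.__name__}, got {type(value).__name__}"
        )
    return value


# Usage inside load_data_from_files (requires the `datasets` package):
#     data = expect_type(load_dataset("csv", data_files=data_files), DatasetDict)
#     return data  # mypy now sees DatasetDict instead of the Union
```

Since the call in the question passes neither `split` nor `streaming=True`, the docstring says a `DatasetDict` comes back, so the check should not fire in practice; it mainly records that assumption in a form mypy understands.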

Hi! I think we can fix this by declaring the return type for each combination of input types with `@overload`, so that mypy can pick the right one :slight_smile: See the example on the "More types" page of the Mypy 0.942 documentation.
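The `@overload` approach could look roughly like the sketch below. Everything in it is hypothetical: `load` is a stand-in whose `split` and `streaming` parameters mirror those of `load_dataset`, and the four classes are empty placeholders for the real `datasets` types, so this is not the library's actual signature. Each overload maps one combination of `split` (None or a string) and `streaming` (`Literal[False]` or `Literal[True]`) to a single return type, following the return rules in the docstring quoted earlier in the thread, with a final catch-all for calls that pass a plain `bool`.

```python
from typing import Literal, Optional, Union, overload


# Hypothetical empty stand-ins for the four datasets return types.
class Dataset: ...
class DatasetDict(dict): ...
class IterableDataset: ...
class IterableDatasetDict(dict): ...


AnyDataset = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]


# mypy tries the overloads in order and uses the first one that matches the call.
@overload
def load(path: str, split: None = None, streaming: Literal[False] = False) -> DatasetDict: ...
@overload
def load(path: str, split: str, streaming: Literal[False] = False) -> Dataset: ...
@overload
def load(path: str, split: None = None, *, streaming: Literal[True]) -> IterableDatasetDict: ...
@overload
def load(path: str, split: str, streaming: Literal[True]) -> IterableDataset: ...
@overload
def load(path: str, split: Optional[str] = None, streaming: bool = False) -> AnyDataset: ...


def load(path: str, split: Optional[str] = None, streaming: bool = False) -> AnyDataset:
    # Single runtime implementation; the overloads above only inform the type checker.
    if streaming:
        return IterableDatasetDict() if split is None else IterableDataset()
    return DatasetDict() if split is None else Dataset()
```

With overloads like these, mypy would infer `DatasetDict` for a call like `load("csv")`, so a `-> DatasetDict` annotation such as the one on `load_data_from_files` would type-check without a cast. The catch-all overload keeps calls with a non-literal `streaming` flag valid, returning the full union as before.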

If you are interested in contributing to `datasets`, feel free to open an issue and/or a pull request on the huggingface/datasets GitHub repository.