Has anyone else had trouble getting mypy to infer `DatasetDict` objects correctly?
The following function should return a `DatasetDict`, but mypy treats the return value as something quite different:
```python
from typing import List

from datasets import DatasetDict, load_dataset


# Method of a class that defines train_task_inputs and test_task_inputs.
def load_data_from_files(self, inputs_per_task: List[str]) -> DatasetDict:
    data_files = {}
    for data_file in inputs_per_task:
        if self.train_task_inputs and data_file in self.train_task_inputs:
            data_files["train"] = data_file
        elif self.test_task_inputs and data_file in self.test_task_inputs:
            data_files["test"] = data_file
        else:
            data_files["validation"] = data_file
    data = load_dataset("csv", data_files=data_files)
    print()
    print(f"[+]{data=}")
    print(f"[+]{type(data)=}")
    return data
```
Running it prints:

```
[+]data=DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 20
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 20
    })
    validation: Dataset({
        features: ['text', 'labels'],
        num_rows: 20
    })
})
[+]type(data)=<class 'datasets.dataset_dict.DatasetDict'>
```
However, mypy reports the following error:
```
data.py:251: error: Incompatible return value type (got "Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]", expected "DatasetDict")
Found 1 error in 1 file (checked 1 source file)
```
As the output shows, the runtime type of `data` is clearly `<class 'datasets.dataset_dict.DatasetDict'>`. So why does mypy infer its type as `Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]`?
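A minimal, self-contained sketch of the likely cause, assuming the error comes from how `load_dataset` is annotated: mypy only sees the *declared* return type of the call, not the runtime object. `load_dataset` can return a `Dataset` when `split` is given, or an iterable variant when `streaming=True`, so its annotation is a `Union`, and returning that union from a function annotated `-> DatasetDict` fails type checking. The classes and `fake_load_dataset` below are hypothetical stand-ins, not the real `datasets` API; only the type relationships matter. An `isinstance` check narrows the union so mypy accepts the return.

```python
from typing import Dict, Optional, Union


# Hypothetical stand-ins for the datasets classes (not the real library).
class Dataset:
    pass


class DatasetDict(Dict[str, Dataset]):
    pass


def fake_load_dataset(split: Optional[str] = None) -> Union[DatasetDict, Dataset]:
    """Hypothetical loader whose declared return type is a Union,
    mirroring the shape of load_dataset's annotation."""
    ds = DatasetDict(train=Dataset(), test=Dataset())
    # With a split name a single Dataset comes back; otherwise the whole dict.
    return ds[split] if split is not None else ds


def load_data() -> DatasetDict:
    data = fake_load_dataset()
    # `return data` here would be rejected by mypy: the static type of `data`
    # is the declared Union, even though the runtime object is a DatasetDict.
    # The isinstance check narrows the Union for mypy and also fails loudly
    # at runtime if the assumption is ever wrong.
    if not isinstance(data, DatasetDict):
        raise TypeError(f"expected DatasetDict, got {type(data).__name__}")
    return data
```

`typing.cast(DatasetDict, data)` would also silence mypy, but it performs no runtime check, so the `isinstance` version is the safer of the two.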
Any advice or pointers would be much appreciated!