Loading div2k from super-image into Pytorch

I am following this page to load div2k from super-image, and I want to convert it so that I can train my model built in PyTorch.

To do that, I am following “Using a Dataset with PyTorch/Tensorflow — datasets 1.11.0 documentation”. However, I got an error:

TypeError: new(): invalid data type 'numpy.str_'

My code is

from datasets import load_dataset
from super_image.data import EvalDataset, TrainDataset, augment_five_crop
import torch

augmented_dataset = load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='train')\
    .map(augment_five_crop, batched=True, desc="Augmenting Dataset")                                # download and augment the data with the five_crop method
train_dataset = TrainDataset(augmented_dataset)                                                     # prepare the train dataset for loading PyTorch DataLoader
eval_dataset = EvalDataset(load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='validation'))      # prepare the eval dataset for the PyTorch DataLoader

augmented_dataset.set_format(type='torch', columns= ["lr", "hr"])
augmented_dataset.__getitem__(0)

and the full error

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-5b6b28c36bc5> in <module>
     12 
     13 augmented_dataset.set_format(type='torch', columns= ["lr", "hr"])
---> 14 augmented_dataset.__getitem__(0)

~\anaconda3\lib\site-packages\datasets\arrow_dataset.py in __getitem__(self, key)
   1515     def __getitem__(self, key: Union[int, slice, str]) -> Union[Dict, List]:
   1516         """Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
-> 1517         return self._getitem(
   1518             key,
   1519             format_type=self._format_type,

~\anaconda3\lib\site-packages\datasets\arrow_dataset.py in _getitem(self, key, format_type, format_columns, output_all_columns, format_kwargs)
   1508         formatter = get_formatter(format_type, **format_kwargs)
   1509         pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
-> 1510         formatted_output = format_table(
   1511             pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
   1512         )

~\anaconda3\lib\site-packages\datasets\formatting\formatting.py in format_table(table, key, formatter, format_columns, output_all_columns)
    420     else:
    421         pa_table_to_format = pa_table.drop(col for col in pa_table.column_names if col not in format_columns)
--> 422         formatted_output = formatter(pa_table_to_format, query_type=query_type)
    423         if output_all_columns:
    424             if isinstance(formatted_output, MutableMapping):

~\anaconda3\lib\site-packages\datasets\formatting\formatting.py in __call__(self, pa_table, query_type)
    192     def __call__(self, pa_table: pa.Table, query_type: str) -> Union[RowFormat, ColumnFormat, BatchFormat]:
    193         if query_type == "row":
--> 194             return self.format_row(pa_table)
    195         elif query_type == "column":
    196             return self.format_column(pa_table)

~\anaconda3\lib\site-packages\datasets\formatting\torch_formatter.py in format_row(self, pa_table)
     57     def format_row(self, pa_table: pa.Table) -> dict:
     58         row = self.numpy_arrow_extractor().extract_row(pa_table)
---> 59         return self.recursive_tensorize(row)
     60 
     61     def format_column(self, pa_table: pa.Table) -> "torch.Tensor":

~\anaconda3\lib\site-packages\datasets\formatting\torch_formatter.py in recursive_tensorize(self, data_struct)
     53 
     54     def recursive_tensorize(self, data_struct: dict):
---> 55         return map_nested(self._recursive_tensorize, data_struct, map_list=False)
     56 
     57     def format_row(self, pa_table: pa.Table) -> dict:

~\anaconda3\lib\site-packages\datasets\utils\py_utils.py in map_nested(function, data_struct, dict_only, map_list, map_tuple, map_numpy, num_proc, types)
    202         num_proc = 1
    203     if num_proc <= 1 or len(iterable) <= num_proc:
--> 204         mapped = [
    205             _single_map_nested((function, obj, types, None, True))
    206             for obj in utils.tqdm(iterable, disable=disable_tqdm)

~\anaconda3\lib\site-packages\datasets\utils\py_utils.py in <listcomp>(.0)
    203     if num_proc <= 1 or len(iterable) <= num_proc:
    204         mapped = [
--> 205             _single_map_nested((function, obj, types, None, True))
    206             for obj in utils.tqdm(iterable, disable=disable_tqdm)
    207         ]

~\anaconda3\lib\site-packages\datasets\utils\py_utils.py in _single_map_nested(args)
    141     # Singleton first to spare some computation
    142     if not isinstance(data_struct, dict) and not isinstance(data_struct, types):
--> 143         return function(data_struct)
    144 
    145     # Reduce logging to keep things readable in multiprocessing with tqdm

~\anaconda3\lib\site-packages\datasets\formatting\torch_formatter.py in _recursive_tensorize(self, data_struct)
     50             if data_struct.dtype == np.object:  # pytorch tensors cannot be instantied from an array of objects
     51                 return [self.recursive_tensorize(substruct) for substruct in data_struct]
---> 52         return self._tensorize(data_struct)
     53 
     54     def recursive_tensorize(self, data_struct: dict):

~\anaconda3\lib\site-packages\datasets\formatting\torch_formatter.py in _tensorize(self, value)
     42             default_dtype = {"dtype": torch.float32}
     43 
---> 44         return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
     45 
     46     def _recursive_tensorize(self, data_struct: dict):

TypeError: new(): invalid data type 'numpy.str_'

The same error is thrown if I use a PyTorch DataLoader:

my_ldr = torch.utils.data.DataLoader(augmented_dataset, batch_size=1, shuffle=False)
next(iter(my_ldr))

Hi! What are the feature types of your dataset?

print(augmented_dataset.features)

The ‘torch’ format only works for columns of numeric data, since strings can’t be converted to PyTorch tensors. Could it be that one of “lr” or “hr” contains text data?

Hi @lhoestq,

Yes, the features are “lr” and “hr”. augmented_dataset is a datasets.Dataset object once I load it. This is what augmented_dataset looks like:

Dataset({
    features: ['hr', 'lr'],
    num_rows: 4000
})

If I try using numbers instead,

augmented_dataset.set_format(type='torch', columns= [0,1])

I get this error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-2a3051e19cd6> in <module>
----> 1 augmented_dataset.set_format(type='torch', columns= [0,1])

~\anaconda3\lib\site-packages\datasets\fingerprint.py in wrapper(*args, **kwargs)
    395             # Call actual function
    396 
--> 397             out = func(self, *args, **kwargs)
    398 
    399             # Update fingerprint of in-place transforms + update in-place history of transforms

~\anaconda3\lib\site-packages\datasets\arrow_dataset.py in set_format(self, type, columns, output_all_columns, **format_kwargs)
   1349             columns = [columns]
   1350         if columns is not None and any(col not in self._data.column_names for col in columns):
-> 1351             raise ValueError(
   1352                 "Columns {} not in the dataset. Current columns in the dataset: {}".format(
   1353                     list(filter(lambda col: col not in self._data.column_names, columns)), self._data.column_names

ValueError: Columns [0, 1] not in the dataset. Current columns in the dataset: ['hr', 'lr']

I’ve reached the point where I should set columns=['hr', 'lr'], but I also can’t use the string type to represent them. This seems a bit contradictory.

As you said, you have to pass the column names as strings in set_format: columns=['hr', 'lr'].

However, it is the feature type of the data inside each column that may actually be causing the issue.
When I was talking about string type, I was referring to the type of the data inside each column (for example the type of augmented_dataset[0]["hr"])

You can check the feature type of each column with

print(augmented_dataset.features)

or you can directly check the first row augmented_dataset[0]