I am following this page to load the DIV2K dataset with super-image, and I want to convert it to PyTorch tensors so that I can train my model, which is built in PyTorch.
To do that, I am following the "Using a Dataset with PyTorch/Tensorflow" page of the datasets 1.11.0 documentation. However, I got the following error:
TypeError: new(): invalid data type 'numpy.str_'
My code is:
# Reproduction script: load the Div2k 'bicubic_x4' train/validation splits,
# wrap them in the super_image dataset helpers, then switch the augmented
# dataset to the 'torch' output format and fetch its first row.
from datasets import load_dataset
from super_image.data import EvalDataset, TrainDataset, augment_five_crop
import torch
# The trailing backslash continues the statement, so .map(...) on the next
# line is applied to the train split returned by load_dataset(...).
augmented_dataset = load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='train')\
.map(augment_five_crop, batched=True, desc="Augmenting Dataset") # download and augment the data with the five_crop method
train_dataset = TrainDataset(augmented_dataset) # prepare the train dataset for loading PyTorch DataLoader
eval_dataset = EvalDataset(load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='validation')) # prepare the eval dataset for the PyTorch DataLoader
# Ask `datasets` to return the "lr" and "hr" columns as torch tensors.
# NOTE(review): the traceback ends in torch.tensor(value, ...) rejecting a
# 'numpy.str_', so these two columns appear to hold strings (presumably image
# file paths -- confirm against the dataset card), which the torch formatter
# cannot tensorize.
# NOTE(review): train_dataset / eval_dataset above are never used in this
# snippet; presumably they, not augmented_dataset, are what should be handed
# to a DataLoader -- TODO confirm against the super-image documentation.
augmented_dataset.set_format(type='torch', columns= ["lr", "hr"])
# Fetch row 0 (same as augmented_dataset[0]); this is the call that raises
# the TypeError shown in the traceback.
augmented_dataset.__getitem__(0)
and the full traceback is:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-1-5b6b28c36bc5> in <module>
12
13 augmented_dataset.set_format(type='torch', columns= ["lr", "hr"])
---> 14 augmented_dataset.__getitem__(0)
~\anaconda3\lib\site-packages\datasets\arrow_dataset.py in __getitem__(self, key)
1515 def __getitem__(self, key: Union[int, slice, str]) -> Union[Dict, List]:
1516 """Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
-> 1517 return self._getitem(
1518 key,
1519 format_type=self._format_type,
~\anaconda3\lib\site-packages\datasets\arrow_dataset.py in _getitem(self, key, format_type, format_columns, output_all_columns, format_kwargs)
1508 formatter = get_formatter(format_type, **format_kwargs)
1509 pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
-> 1510 formatted_output = format_table(
1511 pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
1512 )
~\anaconda3\lib\site-packages\datasets\formatting\formatting.py in format_table(table, key, formatter, format_columns, output_all_columns)
420 else:
421 pa_table_to_format = pa_table.drop(col for col in pa_table.column_names if col not in format_columns)
--> 422 formatted_output = formatter(pa_table_to_format, query_type=query_type)
423 if output_all_columns:
424 if isinstance(formatted_output, MutableMapping):
~\anaconda3\lib\site-packages\datasets\formatting\formatting.py in __call__(self, pa_table, query_type)
192 def __call__(self, pa_table: pa.Table, query_type: str) -> Union[RowFormat, ColumnFormat, BatchFormat]:
193 if query_type == "row":
--> 194 return self.format_row(pa_table)
195 elif query_type == "column":
196 return self.format_column(pa_table)
~\anaconda3\lib\site-packages\datasets\formatting\torch_formatter.py in format_row(self, pa_table)
57 def format_row(self, pa_table: pa.Table) -> dict:
58 row = self.numpy_arrow_extractor().extract_row(pa_table)
---> 59 return self.recursive_tensorize(row)
60
61 def format_column(self, pa_table: pa.Table) -> "torch.Tensor":
~\anaconda3\lib\site-packages\datasets\formatting\torch_formatter.py in recursive_tensorize(self, data_struct)
53
54 def recursive_tensorize(self, data_struct: dict):
---> 55 return map_nested(self._recursive_tensorize, data_struct, map_list=False)
56
57 def format_row(self, pa_table: pa.Table) -> dict:
~\anaconda3\lib\site-packages\datasets\utils\py_utils.py in map_nested(function, data_struct, dict_only, map_list, map_tuple, map_numpy, num_proc, types)
202 num_proc = 1
203 if num_proc <= 1 or len(iterable) <= num_proc:
--> 204 mapped = [
205 _single_map_nested((function, obj, types, None, True))
206 for obj in utils.tqdm(iterable, disable=disable_tqdm)
~\anaconda3\lib\site-packages\datasets\utils\py_utils.py in <listcomp>(.0)
203 if num_proc <= 1 or len(iterable) <= num_proc:
204 mapped = [
--> 205 _single_map_nested((function, obj, types, None, True))
206 for obj in utils.tqdm(iterable, disable=disable_tqdm)
207 ]
~\anaconda3\lib\site-packages\datasets\utils\py_utils.py in _single_map_nested(args)
141 # Singleton first to spare some computation
142 if not isinstance(data_struct, dict) and not isinstance(data_struct, types):
--> 143 return function(data_struct)
144
145 # Reduce logging to keep things readable in multiprocessing with tqdm
~\anaconda3\lib\site-packages\datasets\formatting\torch_formatter.py in _recursive_tensorize(self, data_struct)
50 if data_struct.dtype == np.object: # pytorch tensors cannot be instantied from an array of objects
51 return [self.recursive_tensorize(substruct) for substruct in data_struct]
---> 52 return self._tensorize(data_struct)
53
54 def recursive_tensorize(self, data_struct: dict):
~\anaconda3\lib\site-packages\datasets\formatting\torch_formatter.py in _tensorize(self, value)
42 default_dtype = {"dtype": torch.float32}
43
---> 44 return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
45
46 def _recursive_tensorize(self, data_struct: dict):
TypeError: new(): invalid data type 'numpy.str_'
The same error is raised if I use a PyTorch DataLoader:
# Second reproduction: wrap the torch-formatted augmented_dataset (not the
# TrainDataset wrapper) in a DataLoader with batch size 1 and no shuffling.
my_ldr = torch.utils.data.DataLoader(augmented_dataset, batch_size=1, shuffle=False)
# Pull the first batch; per the report this raises the same TypeError as
# indexing the dataset directly.
next(iter(my_ldr))