[Solved] Error using dataset map function to get hidden states from Chapter 2

I’m trying to replicate the extract_hidden_states function from Chapter 2 of the Transformers book, and I’m getting the following error. I’m not sure what the issue is:
TypeError: can't convert np.ndarray of type numpy.str_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

The code is identical to the text; the only difference is that I’m using a local dataset loaded from CSVs rather than an in-built dataset.
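
For completeness, here’s roughly how I load and tokenize the data (a sketch; the file and column names are placeholders standing in for my actual CSVs, which have a text column and a string sentiment column):

from datasets import load_dataset
from transformers import AutoTokenizer

# Hypothetical file names standing in for my local CSVs
data_files = {"train": "train.csv", "validation": "valid.csv"}
dataset = load_dataset("csv", data_files=data_files)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize(batch):
    # Tokenize the raw text column; padding/truncation as in the book
    return tokenizer(batch["text"], padding=True, truncation=True)

dataset_encoded = dataset.map(tokenize, batched=True)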

Here’s the code that causes the error, along with the full error trace:

dataset_encoded.set_format("torch", columns=["input_ids", "attention_mask", "sentiment"])
dataset_hidden = dataset_encoded.map(extract_hidden_states, batched=True) 
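
For reference, my extract_hidden_states follows the book; roughly this sketch, assuming the DistilBERT checkpoint used in the chapter:

import torch
from transformers import AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained("distilbert-base-uncased").to(device)

def extract_hidden_states(batch):
    # Keep only the columns the model expects (input_ids, attention_mask)
    inputs = {k: v.to(device) for k, v in batch.items()
              if k in tokenizer.model_input_names}
    with torch.no_grad():
        last_hidden_state = model(**inputs).last_hidden_state
    # Return the hidden state of the [CLS] token for each example
    return {"hidden_state": last_hidden_state[:, 0].cpu().numpy()}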

Error trace:

--------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/home/harish/Desktop/SeaWord/action_extraction/hf/nbs/text_classification.ipynb Cell 29' in <cell line: 5>()
      1 # converting the input_ids and attention_mask columns to the "torch" format
      2 dataset_encoded.set_format("torch",
      3                             columns=["input_ids", "attention_mask", "sentiment"])
----> 5 dataset_hidden = dataset_encoded.map(extract_hidden_states, batched=True)

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/dataset_dict.py:494, in DatasetDict.map(self, function, with_indices, input_columns, batched, batch_size, remove_columns, keep_in_memory, load_from_cache_file, cache_file_names, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, desc)
    491 if cache_file_names is None:
    492     cache_file_names = {k: None for k in self}
    493 return DatasetDict(
--> 494     {
    495         k: dataset.map(
    496             function=function,
    497             with_indices=with_indices,
    498             input_columns=input_columns,
    499             batched=batched,
    500             batch_size=batch_size,
    501             remove_columns=remove_columns,
    502             keep_in_memory=keep_in_memory,
    503             load_from_cache_file=load_from_cache_file,
    504             cache_file_name=cache_file_names[k],
    505             writer_batch_size=writer_batch_size,
    506             features=features,
    507             disable_nullable=disable_nullable,
    508             fn_kwargs=fn_kwargs,
    509             num_proc=num_proc,
    510             desc=desc,
    511         )
    512         for k, dataset in self.items()
    513     }
    514 )

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/dataset_dict.py:495, in <dictcomp>(.0)
    491 if cache_file_names is None:
    492     cache_file_names = {k: None for k in self}
    493 return DatasetDict(
    494     {
--> 495         k: dataset.map(
    496             function=function,
    497             with_indices=with_indices,
    498             input_columns=input_columns,
    499             batched=batched,
    500             batch_size=batch_size,
    501             remove_columns=remove_columns,
    502             keep_in_memory=keep_in_memory,
    503             load_from_cache_file=load_from_cache_file,
    504             cache_file_name=cache_file_names[k],
    505             writer_batch_size=writer_batch_size,
    506             features=features,
    507             disable_nullable=disable_nullable,
    508             fn_kwargs=fn_kwargs,
    509             num_proc=num_proc,
    510             desc=desc,
    511         )
    512         for k, dataset in self.items()
    513     }
    514 )

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/arrow_dataset.py:2092, in Dataset.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
   2089 disable_tqdm = bool(logging.get_verbosity() == logging.NOTSET) or not utils.is_progress_bar_enabled()
   2091 if num_proc is None or num_proc == 1:
-> 2092     return self._map_single(
   2093         function=function,
   2094         with_indices=with_indices,
   2095         with_rank=with_rank,
   2096         input_columns=input_columns,
   2097         batched=batched,
   2098         batch_size=batch_size,
   2099         drop_last_batch=drop_last_batch,
   2100         remove_columns=remove_columns,
   2101         keep_in_memory=keep_in_memory,
   2102         load_from_cache_file=load_from_cache_file,
   2103         cache_file_name=cache_file_name,
   2104         writer_batch_size=writer_batch_size,
   2105         features=features,
   2106         disable_nullable=disable_nullable,
   2107         fn_kwargs=fn_kwargs,
   2108         new_fingerprint=new_fingerprint,
   2109         disable_tqdm=disable_tqdm,
   2110         desc=desc,
   2111     )
   2112 else:
   2114     def format_cache_file_name(cache_file_name, rank):

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/arrow_dataset.py:518, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
    516     self: "Dataset" = kwargs.pop("self")
    517 # apply actual function
--> 518 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    519 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    520 for dataset in datasets:
    521     # Remove task templates if a column mapping of the template is no longer valid

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/arrow_dataset.py:485, in transmit_format.<locals>.wrapper(*args, **kwargs)
    478 self_format = {
    479     "type": self._format_type,
    480     "format_kwargs": self._format_kwargs,
    481     "columns": self._format_columns,
    482     "output_all_columns": self._output_all_columns,
    483 }
    484 # apply actual function
--> 485 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    486 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    487 # re-apply format to the output

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/fingerprint.py:411, in fingerprint_transform.<locals>._fingerprint.<locals>.wrapper(*args, **kwargs)
    405             kwargs[fingerprint_name] = update_fingerprint(
    406                 self._fingerprint, transform, kwargs_for_fingerprint
    407             )
    409 # Call actual function
--> 411 out = func(self, *args, **kwargs)
    413 # Update fingerprint of in-place transforms + update in-place history of transforms
    415 if inplace:  # update after calling func so that the fingerprint doesn't change if the function fails

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/arrow_dataset.py:2461, in Dataset._map_single(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset, disable_tqdm, desc, cache_only)
   2459 if drop_last_batch and i + batch_size > input_dataset.num_rows:
   2460     continue
-> 2461 batch = input_dataset._getitem(
   2462     slice(i, i + batch_size),
   2463     decoded=False,
   2464 )
   2465 indices = list(
   2466     range(*(slice(i, i + batch_size).indices(input_dataset.num_rows)))
   2467 )  # Something simpler?
   2468 try:

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/arrow_dataset.py:1900, in Dataset._getitem(self, key, decoded, **kwargs)
   1898 formatter = get_formatter(format_type, features=self.features, decoded=decoded, **format_kwargs)
   1899 pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
-> 1900 formatted_output = format_table(
   1901     pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
   1902 )
   1903 return formatted_output

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/formatting/formatting.py:539, in format_table(table, key, formatter, format_columns, output_all_columns)
    537 else:
    538     pa_table_to_format = pa_table.drop(col for col in pa_table.column_names if col not in format_columns)
--> 539     formatted_output = formatter(pa_table_to_format, query_type=query_type)
    540     if output_all_columns:
    541         if isinstance(formatted_output, MutableMapping):

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/formatting/formatting.py:284, in Formatter.__call__(self, pa_table, query_type)
    282     return self.format_column(pa_table)
    283 elif query_type == "batch":
--> 284     return self.format_batch(pa_table)

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/formatting/torch_formatter.py:67, in TorchFormatter.format_batch(self, pa_table)
     65 def format_batch(self, pa_table: pa.Table) -> dict:
     66     batch = self.numpy_arrow_extractor().extract_batch(pa_table)
---> 67     return self.recursive_tensorize(batch)

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/formatting/torch_formatter.py:55, in TorchFormatter.recursive_tensorize(self, data_struct)
     54 def recursive_tensorize(self, data_struct: dict):
---> 55     return map_nested(self._recursive_tensorize, data_struct, map_list=False)

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/utils/py_utils.py:261, in map_nested(function, data_struct, dict_only, map_list, map_tuple, map_numpy, num_proc, types, disable_tqdm)
    259     num_proc = 1
    260 if num_proc <= 1 or len(iterable) <= num_proc:
--> 261     mapped = [
    262         _single_map_nested((function, obj, types, None, True))
    263         for obj in utils.tqdm(iterable, disable=disable_tqdm)
    264     ]
    265 else:
    266     split_kwds = []  # We organize the splits ourselve (contiguous splits)

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/utils/py_utils.py:262, in <listcomp>(.0)
    259     num_proc = 1
    260 if num_proc <= 1 or len(iterable) <= num_proc:
    261     mapped = [
--> 262         _single_map_nested((function, obj, types, None, True))
    263         for obj in utils.tqdm(iterable, disable=disable_tqdm)
    264     ]
    265 else:
    266     split_kwds = []  # We organize the splits ourselve (contiguous splits)

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/utils/py_utils.py:197, in _single_map_nested(args)
    195 # Singleton first to spare some computation
    196 if not isinstance(data_struct, dict) and not isinstance(data_struct, types):
--> 197     return function(data_struct)
    199 # Reduce logging to keep things readable in multiprocessing with tqdm
    200 if rank is not None and logging.get_verbosity() < logging.WARNING:

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/formatting/torch_formatter.py:52, in TorchFormatter._recursive_tensorize(self, data_struct)
     50     if data_struct.dtype == np.object:  # pytorch tensors cannot be instantied from an array of objects
     51         return [self.recursive_tensorize(substruct) for substruct in data_struct]
---> 52 return self._tensorize(data_struct)

File ~/miniconda3/envs/dl/lib/python3.9/site-packages/datasets/formatting/torch_formatter.py:44, in TorchFormatter._tensorize(self, value)
     41 elif np.issubdtype(value.dtype, np.floating):
     42     default_dtype = {"dtype": torch.float32}
---> 44 return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})

TypeError: can't convert np.ndarray of type numpy.str_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.
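
The last frame is torch.tensor being called on the raw column values. You can reproduce the same failure in isolation (a minimal sketch, assuming a string label column like mine):

import numpy as np
import torch

# A non-numericalized label column is just an array of strings,
# and torch.tensor has no dtype for strings
torch.tensor(np.array(["positive", "negative"]))
# -> TypeError: can't convert np.ndarray of type numpy.str_. ...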

Turns out I didn’t numericalize my labels! A simple error, and the trace was quite evidently pointing at it, but it didn’t strike me! -.-
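
For anyone who lands here: map the string labels to integers before calling set_format. A sketch of the fix, assuming the labels live in the sentiment column and a reasonably recent datasets version (a manual label2id map works just as well):

# Cast the string "sentiment" column to integer class ids
dataset_encoded = dataset_encoded.class_encode_column("sentiment")

# Manual alternative with a hypothetical label2id mapping:
# label2id = {"negative": 0, "positive": 1}
# dataset_encoded = dataset_encoded.map(lambda x: {"sentiment": label2id[x["sentiment"]]})

dataset_encoded.set_format("torch", columns=["input_ids", "attention_mask", "sentiment"])
dataset_hidden = dataset_encoded.map(extract_hidden_states, batched=True)

After this, the torch formatter can tensorize the sentiment column and the map runs without the TypeError.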