Dataset.map returns lists instead of torch tensors

When I use map on a Dataset object, the input_ids column comes back as lists instead of tensors:

def tokenize(batch):
    return tokenizer(batch['text'], padding=True, return_tensors='pt', truncation=True).to(DEVICE)

dataset = dataset.map(tokenize, batched=False, batch_size=None)  # input_ids come back as lists

tokenizer(data['text'], padding=True, return_tensors='pt', truncation=True).to(DEVICE)  # returns tensors


Hi! map ignores tensor formatting when it writes the cache file, so to get PyTorch tensors in the input_ids column you need to explicitly call set_format("pt", columns=["input_ids"], output_all_columns=True) on the dataset object after map.


I got burned by this as well. Maybe the documentation (examples) for map could be updated to explain this?


I got burned by this too! Definitely seems like an issue in the library that .map doesn’t return the objects that the mapped function returns.

And I’d just finished the course which has a whole section about bugs caused by not setting return_tensors, but the examples never mention this issue with .map.

Hi! When it comes to tensors, PyArrow (the storage format we use) only understands 1D arrays, so we would have to store a potentially significant amount of metadata to fully restore the types after map. Also, a map transform can return different value types for the same column (e.g. PyTorch tensors or Python lists), which would make this process ambiguous (should we return mixed types or a single type when indexing the dataset afterwards?). So, for the sake of simplicity, we return Python objects by default, and for other formats one can use with_format/with_transform.


I ended up using a data collator to take the lists and turn them back into tensors.
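The post doesn't show the actual collator, so this is only a sketch of the idea using plain PyTorch: a collate_fn that turns each example's input_ids list into a tensor and pads the batch. PAD_ID = 0 is an assumption (in practice you'd use your tokenizer's pad_token_id), and the token ids are made up. If a tokenizer is at hand, transformers' DataCollatorWithPadding does the same job.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

PAD_ID = 0  # assumption: replace with your tokenizer's pad_token_id

def collate(examples):
    # After map, each example's "input_ids" is a plain Python list of ints.
    seqs = [torch.tensor(ex["input_ids"], dtype=torch.long) for ex in examples]
    lengths = torch.tensor([len(s) for s in seqs])
    # Pad every sequence to the longest one in this batch.
    input_ids = pad_sequence(seqs, batch_first=True, padding_value=PAD_ID)
    # 1 for real tokens, 0 for padding, derived from the true lengths
    # (so a real token that happens to equal PAD_ID is not masked out).
    positions = torch.arange(input_ids.size(1))
    attention_mask = (positions[None, :] < lengths[:, None]).long()
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# A plain list of dicts stands in for the mapped dataset here.
examples = [{"input_ids": [101, 7592, 102]}, {"input_ids": [101, 102]}]
loader = DataLoader(examples, batch_size=2, collate_fn=collate)
batch = next(iter(loader))
print(batch["input_ids"].tolist())       # [[101, 7592, 102], [101, 102, 0]]
print(batch["attention_mask"].tolist())  # [[1, 1, 1], [1, 1, 0]]
```

A mapped datasets.Dataset can be passed to DataLoader the same way, since indexing it returns dicts like these.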

Yep, I also got burned by this. What is the point of returning a list of lists when it’s a torch tensor? Inside the map function it’s a tensor, and afterwards it’s a list?

Also burned by this coercion. It would be helpful to note the PyArrow limitation (although PyArrow’s Tensor type may help here?) and the workarounds in the docs…

Edit: it looks like there’s some discussion on how to use PyArrow’s Tensor type in "Use pyarrow Tensor dtype" (Issue #5272, huggingface/datasets on GitHub).