Dataset map returns only lists instead of torch tensors

Hi! `map` ignores tensor formatting when it writes its cache file, so the columns it produces come back as plain Python lists. To get PyTorch tensors in the `input_ids` column, explicitly call `set_format("pt", columns=["input_ids"], output_all_columns=True)` on the dataset object after `map`.
