Hi! `map` ignores tensor formatting when writing a cache file, so to get PyTorch tensors in the `input_ids` column, you need to explicitly call `set_format("pt", columns=["input_ids"], output_all_columns=True)` on the dataset object (after `map`).