I think you also need to specify which columns you'd like to keep when calling .set_format(type='torch'). If you don't, the text columns are still part of the dataset, and converting strings to PyTorch tensors causes an error.
So I think you just need to update that line to:
train_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
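
One caveat: not every tokenizer produces token_type_ids, and as far as I know asking set_format for a column that isn't in the dataset raises an error as well. If that bites you, here is a rough sketch of filtering the list against the columns that actually exist. The helper columns_to_keep, the MODEL_COLUMNS list, and the example column names are all made up for illustration and are not part of the datasets library.

```python
# Columns a BERT-style model typically consumes (an assumption; adjust for your model).
MODEL_COLUMNS = ["input_ids", "token_type_ids", "attention_mask", "label"]


def columns_to_keep(column_names, wanted=MODEL_COLUMNS):
    """Return the wanted columns that are actually present, in the order of `wanted`."""
    present = set(column_names)
    return [name for name in wanted if name in present]


# Hypothetical column names: a raw text column plus tokenizer outputs,
# with no token_type_ids because this imaginary tokenizer does not emit them.
example_columns = ["sentence", "label", "input_ids", "attention_mask"]
kept = columns_to_keep(example_columns)
print(kept)  # ['input_ids', 'attention_mask', 'label']

# With a real dataset, the call would look roughly like this:
# train_dataset.set_format(type="torch", columns=columns_to_keep(train_dataset.column_names))
```

The text column ("sentence" here) is dropped from the formatted output because it isn't in the wanted list, which is the same effect as passing the columns by hand.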