TensorFlow Hugging Face Datasets Equivalent to PyTorch

Hi!

You can use the to_tf_dataset() method, which recently got a nice rework. If all your elements are the same length, the built-in collator will batch them for you; otherwise you’ll need to pass a custom collator through the collate_fn argument. You can just do:

tf_train = dataset.to_tf_dataset(columns=["input"], 
                                 label_cols=["labels"],
                                 batch_size=8,
                                 shuffle=True)
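
If your elements are not all the same length, here is a rough sketch of what a custom collator could look like. It assumes "input" holds lists of integer token IDs and "labels" holds one integer per example; the name pad_collate and the pad value of 0 are made up for illustration, so adjust them to your data.

```python
import numpy as np


def pad_collate(features, pad_value=0):
    """Pad the variable-length "input" column so the batch stacks into one array.

    `features` is a list of per-example dicts. The column names follow the
    snippet above; the integer dtypes are an assumption about the data.
    """
    max_len = max(len(f["input"]) for f in features)
    # Start from a (batch, max_len) array filled with the pad value.
    inputs = np.full((len(features), max_len), pad_value, dtype=np.int64)
    for row, f in enumerate(features):
        inputs[row, : len(f["input"])] = f["input"]
    # Labels are assumed to be one scalar per example, so they stack directly.
    labels = np.array([f["labels"] for f in features], dtype=np.int64)
    return {"input": inputs, "labels": labels}


# Quick check on two hand-written examples of different lengths.
batch = pad_collate([
    {"input": [5, 6, 7], "labels": 1},
    {"input": [8], "labels": 0},
])
print(batch["input"].shape)

# With a real `dataset`, the collator would be passed like this:
# tf_train = dataset.to_tf_dataset(columns=["input"],
#                                  label_cols=["labels"],
#                                  batch_size=8,
#                                  shuffle=True,
#                                  collate_fn=pad_collate)
```

Padding to the longest element in each batch (rather than a fixed global length) keeps batches small, but it means batch shapes vary from step to step.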

Check out the docs for more details about to_tf_dataset() :)
