Hi!
You can use the to_tf_dataset() function, which just got a nice rework. If your elements are all the same length, the built-in collator will handle the batching; otherwise you'll need a custom collator. You can just do:
tf_train = dataset.to_tf_dataset(
    columns=["input"],
    label_cols=["labels"],
    batch_size=8,
    shuffle=True,
)
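If your elements are not all the same length, here is a rough sketch of what a custom collator could look like. It assumes "input" holds lists of integer token IDs and "labels" holds one integer per example; the name pad_collate and the pad_value of 0 are just illustrative choices, so adjust them for your data. The collator gets a list of examples (one dict each) and returns a dict of padded NumPy arrays.

```python
import numpy as np


def pad_collate(features, pad_value=0):
    """Pad each example's "input" list to the longest one in the batch.

    Assumes integer token IDs in "input" and a scalar integer in "labels".
    """
    max_len = max(len(f["input"]) for f in features)
    # Start from an array filled with the pad value, then copy each row in.
    inputs = np.full((len(features), max_len), pad_value, dtype=np.int64)
    for row, f in enumerate(features):
        inputs[row, : len(f["input"])] = f["input"]
    labels = np.array([f["labels"] for f in features], dtype=np.int64)
    return {"input": inputs, "labels": labels}


# Quick check on two examples of different lengths.
batch = pad_collate(
    [
        {"input": [1, 2, 3], "labels": 0},
        {"input": [4], "labels": 1},
    ]
)
print(batch["input"].shape)

# It should then be possible to pass it via the collate_fn argument:
# tf_train = dataset.to_tf_dataset(
#     columns=["input"],
#     label_cols=["labels"],
#     batch_size=8,
#     shuffle=True,
#     collate_fn=pad_collate,
# )
```

Padding per batch (rather than to a global maximum) keeps batches small, but it means the sequence dimension varies from batch to batch, so your model needs to accept variable-length inputs.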
Check out the docs here for more details about