An example collate function may look like this in your case:
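```python
# `tokenizer` is assumed to be a Hugging Face tokenizer you have already loaded
def collate_tokenize(data):
    # `data` is the list of dataset items the DataLoader gathered for one batch
    text_batch = [element["text"] for element in data]
    # Pad to the longest sequence in this batch and return PyTorch tensors
    tokenized = tokenizer(text_batch, padding='longest', truncation=True, return_tensors='pt')
    return tokenized
```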
Then pass it to the DataLoader through the collate_fn argument:
```python
import torch

dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_tokenize,  # called on each list of sampled items
)
```
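To see what padding='longest' does inside a collate function without downloading a tokenizer, here is a small dependency-free sketch. It assumes the texts have already been converted to token ids; PAD_ID, collate_pad_longest and iterate_batches are hypothetical names, and iterate_batches is only a simplified, non-shuffling stand-in for torch.utils.data.DataLoader, not its actual implementation.

```python
from typing import Dict, Iterator, List

PAD_ID = 0  # hypothetical padding id; a real tokenizer exposes tokenizer.pad_token_id


def collate_pad_longest(data: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
    """Pad every sequence in the batch to the length of the longest one."""
    seqs = [element["input_ids"] for element in data]
    max_len = max(len(s) for s in seqs)
    input_ids = [s + [PAD_ID] * (max_len - len(s)) for s in seqs]
    # 1 marks a real token, 0 marks padding
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in seqs]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def iterate_batches(dataset, batch_size, collate_fn) -> Iterator[Dict[str, List[List[int]]]]:
    """Hypothetical stand-in for DataLoader (no shuffling): slice, then collate."""
    for start in range(0, len(dataset), batch_size):
        yield collate_fn(dataset[start:start + batch_size])


dataset = [
    {"input_ids": [5, 6, 7]},
    {"input_ids": [8]},
    {"input_ids": [9, 10]},
]

batches = list(iterate_batches(dataset, batch_size=2, collate_fn=collate_pad_longest))
for batch in batches:
    print(batch)
```

Each batch is padded only to its own longest sequence (3 tokens in the first batch, 2 in the second), which is why tokenizing in the collate function avoids padding the whole dataset to one global length.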
Also, here’s a somewhat outdated article with an example of a collate function.