Fine-tuning T5 with custom datasets

Hi folks,

I am a newbie to T5 and transformers in general, so apologies in advance for any stupidity or incorrect assumptions on my part!

I am trying to put together an example of fine-tuning the T5 model on a custom dataset for a custom task. I have the “How to fine-tune a model on summarization” example notebook working, but that example uses a preconfigured HF dataset via “load_dataset()” rather than a custom dataset loaded from disk. So I wanted to combine that example with the guidance in “Fine-tuning with custom datasets”, but using T5 instead of DistilBERT as in that tutorial.

I think my main problem is knowing how to construct a dataset object that the pretrained T5 model can consume. So here is how I use the tokenizer, and my attempt at formatting the tokenized sequences into datasets:

But I get the following error when I call trainer.train():


I have seen the post “Defining a custom dataset for fine-tuning translation”, but the solution offered there seems to be to write your own custom Dataset loading class rather than a direct fix for the problem. I can try to learn how to do this, but it would be great to get something working that is equivalent to “Fine-tuning with custom datasets” for the T5 model I want to use.

I also found “Fine Tuning Transformer for Summary Generation”, which is where I got the idea to change the __getitem__ method of my ToxicDataset class to return “input_ids”, “input_mask”, “output_ids” and “output_mask”, but I am really just guessing; I can’t find any documentation of what is needed (sorry!).
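For reference on the key names: the Trainer calls the model as model(**batch), so every key returned by __getitem__ must match an argument of the model's forward(). For T5ForConditionalGeneration the relevant ones are input_ids, attention_mask and labels; when labels are given without decoder_input_ids, the model builds the decoder inputs itself by shifting the labels right. The helper below is a hypothetical sketch (to_t5_features is not a library function) that renames the tutorial-style keys and replaces padding in the labels with -100, the index the T5 loss ignores. It assumes a pad id of 0, which is what the T5 tokenizer uses; the token ids in the example are illustrative only.

```python
from typing import Dict, List

# Assumption: T5's tokenizer pads with id 0 (check tokenizer.pad_token_id to be sure).
T5_PAD_TOKEN_ID = 0
# Label positions set to -100 are ignored by the cross-entropy loss in T5ForConditionalGeneration.
IGNORE_INDEX = -100


def to_t5_features(item: Dict[str, List[int]],
                   pad_token_id: int = T5_PAD_TOKEN_ID) -> Dict[str, List[int]]:
    """Hypothetical helper: map tutorial-style keys onto argument names T5's forward() accepts."""
    # Padded target positions must not contribute to the loss, so swap them for -100.
    labels = [IGNORE_INDEX if tok == pad_token_id else tok for tok in item["output_ids"]]
    return {
        "input_ids": item["input_ids"],
        "attention_mask": item["input_mask"],  # forward() has no "input_mask" argument
        "labels": labels,  # decoder_input_ids are derived from labels inside the model
        # "output_mask" is simply dropped in this sketch.
    }


# Illustrative ids, not real T5 vocabulary (1 stands in for the end-of-sequence token).
example = {
    "input_ids": [1712, 10, 37, 1],
    "input_mask": [1, 1, 1, 1],
    "output_ids": [2841, 1, 0, 0],
    "output_mask": [1, 1, 0, 0],
}
print(to_t5_features(example))
# {'input_ids': [1712, 10, 37, 1], 'attention_mask': [1, 1, 1, 1], 'labels': [2841, 1, -100, -100]}
```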

Any help or pointers to find what I need would be very much appreciated!

I think I may have found a way around this issue (or at least the trainer starts and completes!). The Dataset subclass from the DistilBERT example in “Fine-tuning with custom datasets” needs changing as follows. I guess this is because the labels in the DistilBERT example are just a list of integers, whereas T5 has output texts, and I assume DataCollatorForSeq2Seq() takes care of preprocessing the labels (the output encodings) into the features needed by the T5 model's forward function (I am guessing, but this is what I gather from what I have read). Code changes below:
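A minimal sketch of a dataset class along those lines, modeled on the tutorial's pattern but with text targets (not necessarily the exact changes described). Assumptions: the class name Seq2SeqTextDataset is hypothetical, and encodings and target_encodings are the dict-like results of calling the tokenizer on lists of input and output strings, for example with truncation=True and no padding. It returns plain lists rather than tensors so that DataCollatorForSeq2Seq can pad each batch dynamically, filling label padding with -100. It implements only __len__ and __getitem__, the map-style protocol torch's DataLoader relies on; subclassing torch.utils.data.Dataset works the same way.

```python
from typing import Dict, List, Mapping, Sequence


class Seq2SeqTextDataset:
    """Hypothetical map-style dataset pairing tokenized inputs with tokenized target texts."""

    def __init__(self, encodings: Mapping[str, Sequence[List[int]]],
                 target_encodings: Mapping[str, Sequence[List[int]]]):
        if len(encodings["input_ids"]) != len(target_encodings["input_ids"]):
            raise ValueError("inputs and targets must have the same number of examples")
        self.encodings = encodings
        self.target_encodings = target_encodings

    def __len__(self) -> int:
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
        # Copy every input field (input_ids, attention_mask, ...) for this example.
        item = {key: list(values[idx]) for key, values in self.encodings.items()}
        # Target token ids become "labels"; padding is left to the data collator.
        item["labels"] = list(self.target_encodings["input_ids"][idx])
        return item


# Dummy token ids standing in for real tokenizer output.
train_dataset = Seq2SeqTextDataset(
    {"input_ids": [[100, 200, 1], [300, 1]], "attention_mask": [[1, 1, 1], [1, 1]]},
    {"input_ids": [[5, 1], [6, 7, 1]]},
)
print(len(train_dataset))  # 2
print(train_dataset[1])  # {'input_ids': [300, 1], 'attention_mask': [1, 1], 'labels': [6, 7, 1]}
```

With a real tokenizer and model loaded, such a dataset would be passed to Seq2SeqTrainer (or Trainer) as train_dataset, together with data_collator=DataCollatorForSeq2Seq(tokenizer, model=model). Giving the collator the model lets it prepare decoder_input_ids from the padded labels.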

Response via GitHub from sgugger:

This tutorial is out of date and will be rewritten soon. You should have a look at the maintained examples or the example notebooks instead.


So which of the examples in the links above is appropriate for fine-tuning T5?

Clone the transformers repo; I used transformers/examples/pytorch/summarization/. It is not limited to summarization with T5: just configure the input/output text to whatever you need.
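The run_summarization.py script in that folder reads local CSV or JSON files through its --train_file and --validation_file options, and --text_column / --summary_column tell it which columns hold the input and output text. A sketch of writing custom input/output pairs to a JSON Lines file it could consume; the helper write_jsonl, the column names "source" and "target", the file name and the example pairs are all hypothetical.

```python
import json
from pathlib import Path
from typing import Iterable, Tuple


def write_jsonl(pairs: Iterable[Tuple[str, str]], path: Path,
                text_column: str = "source", summary_column: str = "target") -> int:
    """Hypothetical helper: write (input text, output text) pairs as one JSON object per line."""
    count = 0
    with path.open("w", encoding="utf-8") as f:
        for source, target in pairs:
            f.write(json.dumps({text_column: source, summary_column: target}) + "\n")
            count += 1
    return count  # number of rows written


# Hypothetical text-to-text pairs for a toxicity-style task.
pairs = [
    ("This comment is rude and insulting.", "toxic"),
    ("Thanks, that fixed it!", "not toxic"),
]
n = write_jsonl(pairs, Path("train.json"))
print(n)  # 2
```

The file could then be passed as --train_file train.json --text_column source --summary_column target, alongside --model_name_or_path (e.g. t5-small) and, for T5, a --source_prefix naming the task.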