Understanding set_transform

I’ve been working on a side project that uses phonetic English language models for text generation. Since I’m not aware of any existing phonetic English datasets, I’ve been preprocessing existing English text datasets with my phonemization script to give myself enough training data. Mainly OSCAR for pretraining the model, and then my own small datasets for fine-tuning on specific tasks.

My workflow has been:

  1. downloading the 2.5TB oscar_en shuffled text file
  2. processing it (in chunks) to its phonetic representation and saving those text files to disk
  3. batch tokenizing those files and saving them to a local HuggingFace dataset, because it takes hours (or days) to tokenize the whole thing at the beginning of a training

Even with only 3% of the original OSCAR corpus phonemized, my dataset is over 1.2TB on disk. Which I was okay with – I’m running out of local storage, but I was never going to be able to use the whole OSCAR corpus anyway on my rinky-dink home setup.

But this month has brought two things to HuggingFace – OSCAR in the datasets library, and on-the-fly transforms.

Am I right in understanding that I could load the oscar_en corpus from the HF Dataset, and then pass to set_transform a function that would phonemize and tokenize the samples, and the only hit to my disk would be the arrow cache of the original OSCAR dataset? And that I would be able to quickly resume training from my most recent checkpoint, since it’d just be loading from that cache?

I imagine it’ll slow the overall training down and I might not be able to feed my GPUs as quickly as I’d like, but the simplified workflow might be worth the performance hit (especially if I find another bug in my phonemizer script that makes me want to redo everything).

Building off of that, if one wanted to do a BART-style pretraining, would it be possible to start with a single-column dataset, and pass to set_transform a function that returns the tokenized original dataset as the targets, and a randomly-masked version of the original tokens as the inputs, all on the fly? [Forgive me if this last question is dumb or nonsensical, I have a very limited understanding of seq2seq training / how BART works]

set_transform does not cache the resulting data. Depending on the data/storage you have available, you may want to opt for map instead. Both have a low memory footprint.

Great, thanks for confirming that it doesn’t cache to disk. That’s exactly what I was hoping.

I guess now I’ll have to update to the latest masters and start testing how much on-the-fly tokenization and other data transforms slow my training down.

Indeed if you use set_transform then the resulting phonemized data are created on-the-fly and not stored/cached. Only the original OSCAR data are stored on your disk as an arrow file.

And you’re right about your second point on BART-style pretraining: you can pass a function to set_transform that returns two fields, one that is the original text and one that is randomly masked, even if you have only one column in your dataset.

Thanks, @lhoestq! That’s an incredibly cool and useful feature. Can’t wait to play with it.

Hi @lhoestq! I’m finally getting around to testing some set_transform workflows and I have a question.

I’ve passed a fairly CPU-heavy preprocessing function to set_transform. After about an hour of training, I forced the training to stop and then tried to resume from the last checkpoint.

It’s been over 15 minutes of heavy CPU activity since I resumed, and the training progress indicator is still on step zero. [UPDATE: training finally resumed after 46 minutes] Is it possible that my transform function is being called on every sample as the trainer advances to the last checkpoint step?

If that is what’s happening, I’m not sure if it’s due to the dataset or trainer. Is there any way to avoid it?

You’re right, your transform is being called for every sample until the last checkpoint step.
This is done in the Trainer code here

Maybe at that point in the code you can remove the transform and then re-add it?
To remove the formatting transform you can use dataset.reset_format(), and then you can set the transform again.

Thanks for that pointer.

If I’m reading the code right, I could save the current transform function by storing the dict returned by dataset.format, remove the transform and let the trainer advance to the current checkpoint, and then call dataset.set_format() with the stored format values as arguments to reload the original transform.

Would that get me back to the checkpoint faster and still give me the same values from the dataset?

Yes, exactly.
You can get the format with dataset.format, then remove the formatting transform with dataset.reset_format(). At that point you can run the for loop that iterates over the dataloader until it reaches the requested checkpoint. Finally, set the transform back with dataset.set_format().

Hope that helps!


Do you think this might be a generally useful update to the trainer? Or are there use cases where you’d want to maintain a dataset’s formatting function while iterating up to the last checkpoint?


This may be a nice addition indeed! Feel free to open an issue on the transformers repo on GitHub.
You can tag me @lhoestq and also @sgugger, who’s in charge of the Trainer-related stuff 🙂
