Multi-stage finetuning of transformer models

I’ve been doing some fine-tuning work with GPT-J.

I have three primary sources of data:

  • documents we have generated internally
  • publicly accessible documents we think are NOT included in GPT-J’s training data
  • a larger database, much of which may have been in GPT-J’s training data, albeit in the “wrong” format for our use. (Our documents are highly structured.)

The public databases are related to our documents but cover a superset of them. (E.g., if the public database were “Wikipedia articles about animals,” our application might be the subset “Wikipedia articles about aquatic mammals weighing less than a ton.”)

The public documents are much closer to our application than GPT-J’s overall training data, but few of them are an exact match. I’ve done my best to extract those. There’s also a time-consuming manual curation process.

I’ve recently written some NLTK code to automate getting documents into something close to our format. I’ve applied this to all of the uncurated data from all sources. Documents in category C (listed below) must still be manually curated before they move to category D, but this is a good start.

This leaves me with four categories of data currently, in order of decreasing size and increasing relevancy:

A. The “GPT-J training overlap” domain-general documents, but in the correct format (~4.2B tokens in ~410k documents)
B. The “new data” domain-general documents (~44M tokens in 4,600 documents)
C. Uncurated domain-specific documents (~535k tokens in ~110 documents)
D. Manually-curated domain-specific documents (~360k tokens in ~120 documents)

To date, I have been fine-tuning GPT-J with the manually-curated documents. This gives decent results.

I’ve experimented with finetuning on a combo of the curated & uncurated domain-specific documents, and that seems a bit better, but it’s hard to tell.

I want to try finetuning in stages.

Ideally, one would fine-tune with all the available data, either as:

GPT-J → A → B → C → D → result
GPT-J → A+B → C → D → result
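
To make the two orderings concrete, here is a minimal sketch in plain Python. The `Stage` class and the `build_plan` helper are hypothetical names made up for illustration, the token counts are the approximate figures from the category list, and nothing here is tied to a particular training library.

```python
from dataclasses import dataclass

# Approximate token counts per category, copied from the list above.
TOKENS = {"A": 4_200_000_000, "B": 44_000_000, "C": 535_000, "D": 360_000}


@dataclass
class Stage:
    name: str        # label such as "A+B"
    datasets: tuple  # categories mixed together in this stage
    tokens: int      # tokens seen in one epoch of this stage


def build_plan(spec):
    """Turn a string like 'A+B -> C -> D' into an ordered list of stages."""
    plan = []
    for part in spec.split("->"):
        names = tuple(n.strip() for n in part.split("+"))
        plan.append(Stage("+".join(names), names, sum(TOKENS[n] for n in names)))
    return plan


for spec in ("A -> B -> C -> D", "A+B -> C -> D"):
    for stage in build_plan(spec):
        print(f"{spec!r}: stage {stage.name} ~{stage.tokens:,} tokens per epoch")
```

Each stage would start from the checkpoint the previous one produced; the sketch only lays out the order and the size of each step.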

This raises a few questions for me.

  1. Is it even worth trying to fine-tune on set A?

Fine-tuning on A would be quite expensive and time-consuming; that’s probably a month on 8 A100s. And I have no idea whether or how much it would help GPT-J to see that data again in something much closer to the “correct” format.
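
For a sense of scale, the rough estimate can be turned into an implied throughput. This is back-of-the-envelope arithmetic only: it assumes a single epoch over set A and a 30-day month, neither of which I have measured, and `implied_throughput` is just a helper name for this sketch.

```python
def implied_throughput(tokens, gpus, days):
    """Return (tokens/s overall, tokens/s per GPU, GPU-hours) for one pass."""
    seconds = days * 24 * 3600
    total_tps = tokens / seconds
    return total_tps, total_tps / gpus, gpus * days * 24


# Assumed inputs: ~4.2B tokens in set A, 8 A100s, a 30-day month, one epoch.
total_tps, per_gpu_tps, gpu_hours = implied_throughput(4.2e9, gpus=8, days=30)
print(f"overall: {total_tps:,.0f} tokens/s")
print(f"per GPU: {per_gpu_tps:,.0f} tokens/s")
print(f"budget:  {gpu_hours:,} GPU-hours")
```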

  2. For early fine-tuning stages, what test set should I use?

E.g., if I’m fine-tuning on B, then C, then D, should I test the finetuning of B against a sample drawn from B, C, or D? The standard approach for fine-tuning on dataset B is to test with data also from B. But it makes sense to test the fine-tuning against a sample from D to see if it’s helping my actual goal and stop fine-tuning when it no longer is. Is that reasonable?
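
Here is a minimal sketch of the stopping rule I have in mind: keep training on B, but decide when to stop from the loss on a held-out sample of D. The `should_stop` helper, the `patience` and `min_delta` parameters, and the loss values are all hypothetical, made up to show the logic rather than taken from any real run.

```python
def should_stop(target_losses, patience=3, min_delta=0.0):
    """Return True once the target-domain loss has failed to improve by more
    than min_delta for `patience` consecutive evaluations."""
    best = float("inf")
    bad_evals = 0
    for loss in target_losses:
        if loss < best - min_delta:
            best = loss
            bad_evals = 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return True
    return False


# Made-up numbers: loss on held-out B keeps falling while loss on held-out D
# bottoms out and starts drifting up.
val_loss_b = [2.40, 2.31, 2.25, 2.20, 2.16, 2.13]
val_loss_d = [2.00, 1.90, 1.85, 1.86, 1.87, 1.88]

print("stop according to B:", should_stop(val_loss_b))
print("stop according to D:", should_stop(val_loss_d))
```

Tracking both curves at once would also show whether further training on B has started to hurt D even while the in-distribution loss is still improving.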

  3. How, if at all, should I adjust the learning rate for multi-stage fine-tuning compared with single-stage?
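
To make the question concrete, one option would be to restart a warmup-plus-decay schedule at each stage with a smaller peak. The sketch below uses linear warmup followed by cosine decay; the `lr_at` helper, the step counts, and the peak learning rates are placeholders I picked for illustration, not values I have tested or a claim that this is the right shape.

```python
import math


def lr_at(step, total_steps, peak_lr, warmup_steps):
    """Linear warmup to peak_lr, then cosine decay to zero at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


# Placeholder values only: (stage, total steps, peak LR), with the peak
# shrinking as the data gets smaller and more relevant.
STAGES = [("B", 1000, 1e-5), ("C", 100, 5e-6), ("D", 100, 2e-6)]

for name, total, peak in STAGES:
    warmup = max(1, total // 10)
    sample_steps = (0, warmup - 1, total // 2, total - 1)
    samples = [lr_at(s, total, peak, warmup) for s in sample_steps]
    print(name, [f"{lr:.2e}" for lr in samples])
```

The open question is whether each stage should get its own warmup and a reduced peak like this, or whether one schedule should simply continue across stages.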

I admit that this whole project is basically an effort to see how tightly I can press myself up against the wall of diminishing returns. 🙂

Thanks for any advice!