Understanding how token batches and fine-tuning interact

I’m fine-tuning GPT-J, which has a 2048-token input size limit. Similarly to this question, I’m looking at group_texts(), trying to understand how it all fits together. (But unlike that question, my fine-tuning texts are quite long, often spanning multiple blocks.)

For simplicity, suppose the token limit were 4 instead of 2048. Then if I have the document ABCDEFGHIJ, it will end up going into fine-tuning as two blocks:

ABCD

EFGH

And group_texts() as written will drop IJ. (Long term, I plan to pad that out. My fine-tuning input is quite hard to come by.)
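For reference, the group_texts() I’m looking at is essentially the one from the Hugging Face causal language modeling example; paraphrased from memory with a toy block_size of 4, it does roughly this:

```python
def group_texts(examples, block_size=4):
    # Concatenate all the tokenized texts into one long sequence per column.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the tail that doesn't fill a whole block (the "IJ" above).
    total_length = (total_length // block_size) * block_size
    # Split into consecutive, non-overlapping blocks of block_size tokens.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # For causal LM fine-tuning, the labels are just a copy of the inputs.
    result["labels"] = result["input_ids"].copy()
    return result
```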

What happens then?

Will the fine-tuning process break this down to:

A → B

AB → C

ABC → D

E → F

EF → G

EFG → H?

If so, will it learn anything about BCD → E or CDE → F?

If not, what does it do instead?
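To make the question concrete, my working assumption (which may well be wrong) is that within a single block the labels are just the inputs shifted by one, so every prefix of a block contributes a next-token prediction, but nothing ever crosses a block boundary:

```python
# Toy illustration of my assumption: each block of length n contributes n-1
# next-token predictions, none of which span two blocks.
blocks = [["A", "B", "C", "D"], ["E", "F", "G", "H"]]

for block in blocks:
    for t in range(1, len(block)):
        prefix, target = block[:t], block[t]
        print("".join(prefix), "->", target)

# Prints A -> B, AB -> C, ABC -> D, then E -> F, EF -> G, EFG -> H
# (one per line), and never BCD -> E, because D and E sit in different blocks.
```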

And how should I structure my input if I want to make sure that all the causal LM possibilities are covered:

A → B

AB → C

ABC → D

BCD → E

CDE → F

DEF → G

EFG → H

FGH → I

GHI → J

My first instinct was to feed in every 4-token sequence in the document: ABCD, BCDE, CDEF, etc., as a separate input block. But while that might work for 4 tokens, that’s an awful lot of repetition and potential for overfitting at 2048. (Not to mention the massive increase in fine-tuning time!)
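For concreteness, that instinct amounts to something like the sliding-window chunking below (just a sketch; stride=1 gives every 4-token window, and a larger stride would trade coverage for less repetition):

```python
def sliding_window_blocks(token_ids, block_size=4, stride=1):
    # Overlapping blocks: stride=1 yields every block_size-token window,
    # i.e. maximal coverage but also maximal repetition.
    return [
        token_ids[i : i + block_size]
        for i in range(0, len(token_ids) - block_size + 1, stride)
    ]

doc = list("ABCDEFGHIJ")
print(["".join(w) for w in sliding_window_blocks(doc)])
# ['ABCD', 'BCDE', 'CDEF', 'DEFG', 'EFGH', 'FGHI', 'GHIJ']
print(["".join(w) for w in sliding_window_blocks(doc, stride=2)])
# ['ABCD', 'CDEF', 'EFGH', 'GHIJ']
```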

What’s the right way to do this?

Thanks for any advice!
