Streaming Dataset of Sequence Length 2048

Hi @lvwerra ,

I greatly appreciate the help.

I trained 3 additional models with the same set of parameters for ~10k steps each over the last day, using the same GPT2Tokenizer, the AdamW optimizer, and the same hyperparameters.

To address the raw data quality concern, I switched datasets from the Python subset of lvwerra/github-code to lvwerra/codeparrot-clean-train, which is referenced in the CodeParrot article.

I tried 3 different test combinations of input file shuffling, data loader batch shuffling, and dataset streaming. Based on my results, I believe your assertion is correct: shuffling the input files alone is not enough, and the batches need to be shuffled in the data loader as well. (A simplified sketch of the shuffling calls for each setup follows the test list below.)

Test 1:

  • Streaming
  • Input file shuffle with a buffer size of 10_000
  • No dataloader batch shuffling. Batch size of 8

Test 2:

  • No streaming
  • Input file shuffle with a buffer size of 1_000
  • No dataloader batch shuffling. Batch size of 8

Test 3:

  • No streaming
  • Input file shuffle with a buffer size of 1_000
  • Dataloader batch shuffling. Batch size of 8
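
For reference, the two shuffling mechanisms involved map roughly onto the calls below (heavily simplified: tokenization, chunking to sequence length 2048, and the collator are omitted; the buffer size is the one from test 1 and the seed is just a placeholder):

```python
from datasets import load_dataset
from torch.utils.data import DataLoader

# "Input file" / example-level shuffle when streaming (test 1): shards are
# shuffled and examples pass through a fixed-size buffer, so the shuffle is
# only approximate.
streamed = load_dataset("lvwerra/codeparrot-clean-train", split="train", streaming=True)
streamed = streamed.shuffle(buffer_size=10_000, seed=42)

# Batch shuffling in the DataLoader (test 3): only possible once the dataset
# is fully downloaded (map-style), not when streaming.
downloaded = load_dataset("lvwerra/codeparrot-clean-train", split="train")
downloaded = downloaded.shuffle(seed=42)
loader = DataLoader(downloaded, batch_size=8, shuffle=True)
```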

Large spikes in the loss were seen in both tests 1 and 2, regardless of streaming; in neither of these tests were the batches shuffled in the data loader. I believe the discrepancy I was previously seeing between streaming and non-streaming was due to the fact that I was shuffling batches in test 3 but not in test 1. Training was more stable in test 3, with fewer exploding/vanishing gradients over the ~10k steps, though some spikes in the training loss were still noticeable.

The PyTorch DataLoader allows shuffle=True when dealing with map-style (non-streamed) datasets, but this parameter cannot be set to True when streaming data through an IterableDataset. I have yet to find a proper solution for shuffling batches when working with a PyTorch IterableDataset / DataLoader while streaming.
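
One workaround I am considering (untested; the class and variable names below are just placeholders) is to wrap the streamed, constant-length dataset in a small shuffle-buffer IterableDataset, so that examples are yielded in a randomized order before being batched:

```python
import random
from torch.utils.data import IterableDataset, DataLoader

class ShuffleBufferDataset(IterableDataset):
    """Yield examples from an underlying iterable in randomized order using a
    fixed-size buffer, approximating DataLoader(shuffle=True), which is not
    permitted for iterable datasets."""

    def __init__(self, dataset, buffer_size=10_000, seed=42):
        self.dataset = dataset
        self.buffer_size = buffer_size
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)
        buffer = []
        for example in self.dataset:
            if len(buffer) < self.buffer_size:
                buffer.append(example)
            else:
                # Swap a random buffered example out for the incoming one.
                idx = rng.randrange(self.buffer_size)
                yield buffer[idx]
                buffer[idx] = example
        # Flush whatever remains in the buffer at the end of the stream.
        rng.shuffle(buffer)
        yield from buffer

# Hypothetical usage with a streamed dataset that already yields fixed-length
# blocks of 2048 tokens (variable names are placeholders):
# shuffled = ShuffleBufferDataset(constant_length_dataset, buffer_size=10_000)
# train_loader = DataLoader(shuffled, batch_size=8)  # no shuffle=True needed
```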

Do you have any advice/input on this? I will run a 4th test with streaming, input file shuffling, and batch shuffling once I find a resolution.

A similar issue is brought up in this post as well:

It was stated, “training loss curves decrease smoothly with streaming=False but currently with iterable datasets losses do not converge smoothly and even tend to diverge”.

In all tests, the training loss ended at approximately 1.5 after ~10k steps, and convergence seemed to slow around ~4k steps.

After further review of the recent OPT-175B and PaLM LLM research papers, I have some additional considerations for my model training (a rough sketch combining several of these follows the list):

  • Implement a learning rate scheduler, warming up from step 0 to step 2,000 of the 10,000 total.
  • Use different momentum terms, e.g. β1 = 0.9, β2 = 0.95.
  • Use weight initialization.
  • Use dropout of 0.1.
  • Try a different optimizer such as Adafactor with parameter scaling.
  • Increase the batch size to 256/512.
  • Add gradient accumulation steps.
  • Construct a tokenizer with a larger vocab size.
  • Possibly increase the size of the model.
  • Use fp16 or bf16.
  • “Loss Divergences - Loss divergences were also an issue in our training run. When the loss diverged, we found that lowering the learning rate and restarting from an earlier checkpoint allowed for the job to recover and continue training” (OPT-175B paper, p. 3).
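
For several of the items above (warmup, the modified betas, gradient accumulation, and fp16), a rough sketch of how I would wire them together is below. The peak learning rate, weight decay, and accumulation factor are placeholders I have not tested, and the data loader is assumed to yield tokenized batches with input_ids and labels:

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel, get_cosine_schedule_with_warmup

model = GPT2LMHeadModel(GPT2Config()).cuda()   # stand-in for the model being trained

num_training_steps = 10_000
num_warmup_steps = 2_000
gradient_accumulation_steps = 32               # 8 per step * 32 ≈ effective batch of 256

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,                                   # placeholder peak learning rate
    betas=(0.9, 0.95),                         # momentum terms from the list above
    weight_decay=0.1,                          # placeholder
)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)
scaler = torch.cuda.amp.GradScaler()           # fp16 mixed precision

# train_loader is assumed to yield dicts with "input_ids" and "labels".
for step, batch in enumerate(train_loader):
    batch = {k: v.cuda() for k, v in batch.items()}
    with torch.cuda.amp.autocast():
        loss = model(**batch).loss / gradient_accumulation_steps
    scaler.scale(loss).backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()
```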

Once the IterableDataset / DataLoader batch shuffling issue is resolved, I plan to do a separate test streaming the Pile dataset as well.

Again, thank you for all of your help!

Sources: