Streaming Dataset of Sequence Length 2048

Hello,

I am working on an implementation of a streamed dataset in which the input examples are concatenated together and then split into sequences of exactly 2048 tokens, so that there are no padding tokens. Examples can be split in the middle. I use drop_last=True in the DataLoader to drop the final incomplete batch, and the leftover tokens that do not fill a full 2048-token sequence are discarded during processing. I am using .map() to apply the processing function to the input examples.

Is this the correct method of doing so? When training the model while using streaming=True there seems to be some instability with spiking training loss. When using the same data loading process but loading the entire dataset into memory, that instability disappears and the training loss becomes smooth. Is anyone able to provide any additional advice or information on improving the data loading and processing method?

This is the code I have for processing, tokenizing, and loading the data into the DataLoader:

from itertools import chain

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer, default_data_collator

tokenizer = GPT2Tokenizer(vocab_file='/token/vocab.json',
                          merges_file='/token/merges.txt')

ds = load_dataset("lvwerra/github-code",
                  streaming=True,
                  split="train",
                  languages=["Python"])

shuffled_dataset = ds.shuffle(seed=42,
                              buffer_size=10_000)

def tokenize(examples):
    seq_length = 2048
    examples = tokenizer(examples["code"])
    # Concatenate all tokenized examples in the batch into one long list per field
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the tail so the total length is an exact multiple of seq_length
    if total_length >= seq_length:
        total_length = (total_length // seq_length) * seq_length

    # Split the concatenated tokens into chunks of exactly seq_length
    result = {
        k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)]
        for k, t in concatenated_examples.items()
    }

    return result

tokenized_dataset = shuffled_dataset.map(
    tokenize,
    batched=True,
    remove_columns=['code', 'language', 'path', 'repo_name', 'size', 'license']
)

dataset = tokenized_dataset.with_format("torch")

dataloader = DataLoader(dataset,
                        drop_last=True,
                        collate_fn=default_data_collator,
                        batch_size=8)

Is this the correct method of doing so?

yes it is :slight_smile:

When training the model while using streaming=True there seems to be some instability with spiking training loss. When using the same data loading process but loading the entire dataset into memory, that instability disappears and the training loss becomes smooth.

When you use streaming=False, or when you have a “map-style” dataset (i.e. when you can get any example of the dataset at any time, as you can with a python list), then shuffling the dataset can be a perfect shuffle. In contrast, a streaming dataset can only be shuffled approximately (unless you set buffer_size=num_examples). You can try setting a bigger buffer_size to see if it helps.

Otherwise you can try to reshuffle the dataset yourself if you think this is a shuffling issue.
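For example, here is a minimal sketch of what a bigger buffer looks like on the streamed dataset above (the 100_000 value is only illustrative; a bigger buffer costs more RAM but approximates a full shuffle more closely):

from datasets import load_dataset

ds = load_dataset("lvwerra/github-code",
                  streaming=True,
                  split="train",
                  languages=["Python"])

# the buffer keeps this many examples in memory and samples from it at random;
# the larger it is, the closer the approximate shuffle gets to a perfect one
shuffled_dataset = ds.shuffle(seed=42, buffer_size=100_000)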


Hi @lhoestq ,

Firstly, congratulations on the Series C!

Secondly, thank you for confirming that this is the correct way to stream, structure, and map the data loader to meet those specifications.

I increased the buffer_size to 100_000. I will further increase the size in later testing to what I can appropriately fit into RAM at any given time. I did two additional streaming tests using wikitext-103-v1 and wikitext-2-v1 along with the original python subset of lvwerra/github-code. Same batch size of 8 for all. No gradient accumulation steps.

Both tests with wikitext-103-v1 and wikitext-2-v1 had ‘smooth’ training loss and generated decent results when training from scratch on an autoregressive gpt-like decoder-only model. There still seems to be a lot of spiking in the training loss when using the python subset of lvwerra/github-code though.

Training for each test:

Wikitext-2-v1: 25k steps. Dataset size ~ 12.72 MB

Wikitext-103-v1: 120k steps. Dataset size ~ 522.66 MB

lvwerra/github-code - languages=[“Python”]: 13.5k steps. Dataset size ~ 52.03 GB

Other considerations:

I am using AdamW with a learning rate of 0.0002 and a weight_decay of 0.099. Gradient clipping is 1.0. The same model has been used for all three tests. I am using the GPT2Tokenizer with the vocab and merges files shown above. I will test with sequence length 512 and a larger model as well.
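For reference, a minimal sketch of that optimizer and clipping setup (model stands in for the GPT-like model being trained; the rest of the training loop is omitted):

import torch
from torch.optim import AdamW

# model is the GPT-like model being trained (placeholder here)
optimizer = AdamW(model.parameters(), lr=2e-4, weight_decay=0.099)

# inside the training loop, after loss.backward():
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()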

Further, I am going to test on a different smaller coding language subset such as Julia and log performance.

Any additional help/input would be great.

Thank you again.

cc @lvwerra wdyt of those spikes encountered using the python subset of lvwerra/github-code?

That’s interesting. We are also currently looking into the code we used to train lvwerra/codeparrot. There are a few things we noticed:

  • Shuffling: just shuffling the input files might not be enough. Since some files can be very long (several tens of thousands of tokens), it can happen that a single batch consists only of samples from one file. Thus shuffling at the data loader level could also be a good idea (see the sketch after this list).

  • Data quality: the github-code dataset is quite raw and was only minimally filtered. The reason is to give users the freedom to choose their own filtering logic. There is a nice thread from @loubnabnl about data quality: https://twitter.com/LoubnaBenAllal1/status/1514300881419878403. So in essence some of these spikes could be due to bad/noisy training samples. But that does not explain the difference between streaming/not streaming.

  • We are also investigating a slower than expected training convergence of our model with the codeparrot training script but we have not yet identified the root cause of that.
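To illustrate the data-loader-level shuffling from the first point, here is a rough sketch that assumes a non-streamed (map-style) version of the tokenized dataset, since DataLoader only accepts shuffle=True for map-style datasets:

from torch.utils.data import DataLoader
from transformers import default_data_collator

# with a map-style dataset the DataLoader can shuffle globally, so chunks cut
# from the same long file are spread across different batches
dataloader = DataLoader(tokenized_dataset.with_format("torch"),
                        shuffle=True,
                        drop_last=True,
                        collate_fn=default_data_collator,
                        batch_size=8)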

Do the spikes completely disappear when you turn off streaming?


Hi @lvwerra ,

I greatly appreciate the help.

Over the last day I trained 3 additional models with the same set of parameters for ~10k steps each. The same GPT2Tokenizer was used, along with the AdamW optimizer and hyperparameters.

To address the raw data quality concern I changed datasets from the Python subset of lvwerra/github-code to lvwerra/codeparrot-clean-train which was referenced in the Code Parrot article.

I tried 3 different test combinations of input file shuffling, data loader batch shuffling, and streaming of the datasets. From my results, I believe your assertion that shuffling input files alone is not enough, and that the batches need to be shuffled in the data loader as well, is correct.

Test 1:

  • Streaming
  • Input file shuffle with a buffer size of 10_000
  • No dataloader batch shuffling. Batch size of 8

Test 2:

  • No Streaming
  • Input file shuffle with a buffer size of 1_000
  • No dataloader batch shuffling. Batch size of 8

Test 3:

  • No streaming
  • Input file shuffle with a buffer size of 1_000
  • Dataloader batch shuffling. Batch size of 8

Large spiking in loss was seen in both tests 1 and 2 regardless of streaming. In neither of these tests were the batches shuffled in the data loader. I believe the discrepancy I was previously seeing in streaming vs non-streaming was due to the fact I was shuffling batches in test 3 while not shuffling batches in test 1. The training was more stable in test 3 with fewer exploding/vanishing gradients over the ~10k steps. Some spikes in training loss were still noticed.

The PyTorch DataLoader allows shuffle=True for map-style (non-streamed) datasets, but shuffle=True is not supported when the underlying dataset is an IterableDataset, as it is when streaming. I have yet to find a proper solution for shuffling batches when working with the PyTorch IterableDataset / DataLoader while streaming.

Do you have any advice/input on doing so? I will do a 4th test with streaming, input file shuffling as well as batch shuffling when I find a resolution.

A similar issue is brought up here in this post as well:

It was stated, “training loss curves decrease smoothly with streaming=False but currently with iterable datasets losses do not converge smoothly and even tend to diverge”.

In all tests, training loss ended approximately at 1.5 after ~10k steps, and convergence seemed to slow around ~4k steps.

After further review of the recent OPT-175B and PaLM LLM research papers, I had some additional considerations for my model training (a sketch of the first two follows the list):

  • Implement a learning rate scheduler. Warming up from steps 0 to 2000/10000.
  • Add different momentum. For example: B1 = 0.9, B2 = 0.95.
  • Use weight initialization.
  • Use dropout of 0.1.
  • Try a different optimizer such as Adafactor with parameter scaling.
  • Increase the batch size to 256/512.
  • Add gradient accumulation steps.
  • Construct a tokenizer with a larger vocab size.
  • Possibly increase the size of the model.
  • Use fp16 or bf16.
  • “Loss Divergences - Loss divergences were also an issue in our training run. When the loss diverged, we found that lowering the learning rate and restarting from an earlier checkpoint allowed for the job to recover and continue training”, in OPT-175 Pg. 3.
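As a rough sketch of the first two items above (the warm-up length and betas mirror the values listed; get_cosine_schedule_with_warmup from transformers is just one possible scheduler, and model and num_training_steps are placeholders):

from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup

optimizer = AdamW(model.parameters(),
                  lr=2e-4,
                  betas=(0.9, 0.95),   # B1 = 0.9, B2 = 0.95 as listed above
                  weight_decay=0.099)

# warm the learning rate up over the first 2000 steps, then decay it
lr_scheduler = get_cosine_schedule_with_warmup(optimizer,
                                               num_warmup_steps=2000,
                                               num_training_steps=num_training_steps)

# in the training loop, call lr_scheduler.step() right after optimizer.step()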

Once the IterableDataset / DataLoader issue with batch shuffling is resolved, I plan to do a separate test streaming the Pile dataset as well.

Again, thank you for all of your help!


  • When training a model on the lvwerra/codeparrot-clean-train dataset, we don’t get these spikes even without shuffling the sequences inside the batches, but we’re using a different training setting that is supposed to be more stable and that implements some of the ideas you mentioned; you can find the code here: transformers/codeparrot_training.py at main · huggingface/transformers · GitHub. We also use file concatenation and sequence splitting without padding.
  • As for the shuffling of a torch IterableDataset, you can create a ShuffledDataset class to which you pass your IterableDataset, like here: How to shuffle an iterable dataset - #6 by sharvil - PyTorch Forums. Or use combinatorics.ShufflerIterDataPipe(IterableDataset, buffer_size) from torch.utils.data.datapipes.iter, which I think is supposed to do the same thing (a rough sketch follows).
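For example, a rough sketch of the ShufflerIterDataPipe option applied to the streamed, tokenized dataset from earlier (the import path is the one used in recent torch releases and may move between versions; the buffer_size of 10_000 simply mirrors the datasets shuffle buffer):

from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
from transformers import default_data_collator

# keep a buffer of tokenized sequences in memory and yield them in random order,
# so consecutive chunks from the same file end up in different batches
shuffled_pipe = ShufflerIterDataPipe(tokenized_dataset.with_format("torch"),
                                     buffer_size=10_000)

dataloader = DataLoader(shuffled_pipe,
                        drop_last=True,
                        collate_fn=default_data_collator,
                        batch_size=8)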

Hi @loubnabnl ,

I appreciate the additional information. I will thoroughly review the codeparrot_training script.

Before posting yesterday I had reviewed the resources below about shuffling iterable datasets. I ended up using the ShufflerIterDataPipe which you also mentioned this morning. I could not find any concrete documentation for the class, but it seems to shuffle the batch data properly; I may have missed the official PyTorch documentation if it is available. I set the buffer size of the ShufflerIterDataPipe equal to that of the Hugging Face Dataset shuffle function. I did not know if there was an efficient or officially documented way to do this with Hugging Face alone, so I thought it would be worth asking as well.

https://discuss.pytorch.org/t/how-to-shuffle-an-iterable-dataset/64130/6
https://discuss.pytorch.org/t/how-to-shuffle-multi-worker-iterabledataset-loading/142203
https://www.ccoderun.ca/programming/doxygen/pytorch/classtorch_1_1utils_1_1data_1_1datapipes_1_1iter_1_1combinatorics_1_1ShufflerIterDataPipe.html

I completed the 4th test last night on lvwerra/codeparrot-clean-train for 8k steps while shuffling the batch data in the streamed IterableDataset.

Test 4:

  • Streaming
  • Input file shuffle with a buffer size of 1_000
  • Dataloader batch shuffling with ShufflerIterDataPipe. Batch size of 8.

Shuffling the data within the batches while using streaming=True seemed to resolve the instances of large spiking. Slower convergence and spiking still occur in many cases when streaming data without shuffling, as mentioned in the other linked post: https://discuss.huggingface.co/t/limitations-of-iterable-datasets/16794/3. Strangely, this does not happen when loading the data locally into memory and shuffling in the data loader. This may be due to the smaller batch sizes, models, etc. used in my pre-training, since my environment/configuration is different from the linked one.

I will have to further investigate after reviewing the Code Parrot repository. I am currently implementing some of the considerations I listed as well and will provide an update to see if they resolve the unshuffled streaming data issues.

It might also be worth reviewing whether using .set_epoch() to additionally shuffle input files will have any benefit in my case.
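If set_epoch() does help, a minimal sketch of the per-epoch reshuffle could look like this (num_epochs and the loop body are placeholders):

for epoch in range(num_epochs):
    # reshuffles the streamed files/buffer with an effective seed of seed + epoch
    shuffled_dataset.set_epoch(epoch)
    for batch in dataloader:
        ...  # forward pass, loss, backward, optimizer step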

I greatly appreciate your help.

Thank you.