Tokenizer Trainer Crashing

I'm trying to train a tokenizer on a corpus of protein sequences that is ~8 GB in size (32M examples).

However, when training from an iterator, memory usage grows steadily until the process is killed:

from tokenizers import SentencePieceBPETokenizer
from tokenizers.processors import BertProcessing

import datasets
from datasets import load_dataset

import itertools

from datetime import datetime
from absl import flags, app

flags.DEFINE_integer("num_batches", 1, help="Number of batches to sample")
flags.DEFINE_integer("batch_size", 100, help="Batch size")
FLAGS = flags.FLAGS

dataset_path = "agemagician/uniref30"

def batch_iterator(ds: datasets.DatasetDict, batch_size=1000, num_batches=-1, split='train', feature='text'):
    # Iterate over the split lazily so the whole dataset is never materialized in Python.
    iterator = iter(ds[split])

    def batch_generator(iterable):
        # Yield lists of `batch_size` sequences at a time.
        while (batch := [ex[feature] for ex in itertools.islice(iterable, batch_size)]):
            yield batch

    # num_batches == -1 means "use every batch"; otherwise stop after num_batches.
    gen = batch_generator(iterator)
    return gen if num_batches == -1 else itertools.islice(gen, num_batches)
    
def train(batch_size, num_batches):
    uniref = load_dataset(dataset_path)
    tokenizer = SentencePieceBPETokenizer()

    # Stream batches of sequences from the dataset straight into the BPE trainer.
    it = batch_iterator(uniref, batch_size=batch_size, num_batches=num_batches)
    tokenizer.train_from_iterator(it, vocab_size=1000, min_frequency=2, special_tokens=[
        "<s>",
        "<pad>",
        "</s>",
        "<unk>",
        "<mask>",
    ])
    tokenizer.save('proteins-tmp')

def main(args):
    start = datetime.now()
    print(f"Starting tokenizer training on {FLAGS.num_batches} of batch size {FLAGS.batch_size}")
    train(FLAGS.batch_size, FLAGS.num_batches)
    end = datetime.now()
    dur = end - start
    print(f"Took {dur.total_seconds() / 60.:.1f} minutes")    

if __name__ == "__main__":
    app.run(main)

I included the batch_size and num_batches command-line args to make testing on subsets easier. However, when I try to train the tokenizer on the entire dataset (setting num_batches=-1), memory usage during the “tokenizing” phase keeps growing until the process crashes.
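In case it's relevant, a rough way to check whether the Python side (the batch generator) is what holds onto memory would be to drain the iterator without training at all. Something like the sketch below; psutil is not part of the script above and is only used here to report the resident set size:

import psutil

# Drain the batch iterator without training, printing RSS as we go,
# to see whether the generator alone makes memory grow.
uniref = load_dataset(dataset_path)
proc = psutil.Process()

for i, batch in enumerate(batch_iterator(uniref, batch_size=1000, num_batches=-1)):
    if i % 1000 == 0:
        print(f"batch {i}: rss = {proc.memory_info().rss / 1e9:.2f} GB")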

Are there any suggested workarounds for this? Is it possible to train the tokenizer on chunks of data (i.e., call tokenizer.train_from_iterator multiple times), or does it need to see the entire dataset at once for the BPE algorithm to work correctly?
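Concretely, the chunked approach I have in mind is roughly the sketch below (batches_per_chunk is an arbitrary number I picked, and I don't know whether repeated calls to train_from_iterator accumulate merges or simply retrain from scratch on the latest chunk, which is exactly what I'm asking):

# Hypothetical chunked training: repeatedly call train_from_iterator on
# successive slices of the corpus. Unclear whether this is valid for BPE.
uniref = load_dataset(dataset_path)
tokenizer = SentencePieceBPETokenizer()

all_batches = batch_iterator(uniref, batch_size=1000, num_batches=-1)
batches_per_chunk = 1000  # arbitrary

while True:
    # Materialize one chunk of batches at a time to bound memory.
    chunk = list(itertools.islice(all_batches, batches_per_chunk))
    if not chunk:
        break
    tokenizer.train_from_iterator(chunk, vocab_size=1000, min_frequency=2,
                                  special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"])

tokenizer.save('proteins-tmp')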

Thanks!