Tokenizer trainer crashes (runs out of memory) when training on a large corpus

I am trying to train a tokenizer on a corpus of protein sequences that is ~8 GB in size (32M examples).

However, when training from an iterator, memory usage grows steadily until the process is killed:

from tokenizers import SentencePieceBPETokenizer
from tokenizers.processors import BertProcessing

import datasets
from datasets import load_dataset

import itertools

from datetime import datetime
from absl import flags, app

# Command-line configuration, read through FLAGS inside main().
flags.DEFINE_integer(
    "num_batches",
    1,
    help="Number of batches to sample; -1 uses every example in the split.",
    lower_bound=-1,  # values below -1 would otherwise crash inside itertools.islice
)
flags.DEFINE_integer(
    "batch_size",
    100,
    help="Number of examples per batch passed to the tokenizer trainer.",
    lower_bound=1,  # 0 would make batch_iterator yield no batches at all
)
FLAGS = flags.FLAGS

# Dataset identifier passed to load_dataset() in train().
dataset_path = "agemagician/uniref30"

def batch_iterator(
    ds: "datasets.DatasetDict",
    batch_size=1000,
    num_batches=-1,
    split="train",
    feature="text",
):
    """Return an iterator of lists of ``feature`` values taken from ``ds[split]``.

    Rows are pulled lazily with ``itertools.islice``, so this iterator never
    materializes more than one batch of values itself.

    Args:
        ds: Mapping from split name to an iterable of row mappings, e.g. the
            ``DatasetDict`` returned by ``load_dataset``.
        batch_size: Maximum number of values per yielded list. The final
            batch may be shorter.
        num_batches: Maximum number of batches to yield. ``-1`` (or any
            negative value, or ``None``) yields every batch in the split.
        split: Key of ``ds`` to read rows from.
        feature: Key of each row whose value is collected into the batch.

    Returns:
        An iterator over lists of ``row[feature]`` values.

    Raises:
        KeyError: Immediately, if ``split`` is not a key of ``ds``.
    """
    # Evaluated eagerly (outside the generator) so a bad split name fails at
    # call time rather than on the first iteration.
    rows = iter(ds[split])

    def _batches():
        # An empty list is falsy, which ends the loop once rows are exhausted.
        while batch := [row[feature] for row in itertools.islice(rows, batch_size)]:
            yield batch

    if num_batches is None or num_batches < 0:
        return _batches()
    return itertools.islice(_batches(), num_batches)
    
def train(batch_size, num_batches):
    """Train a ``SentencePieceBPETokenizer`` on the dataset at ``dataset_path``.

    Training uses ``vocab_size=1000`` and ``min_frequency=2`` and reserves the
    special tokens ``<s>``, ``<pad>``, ``</s>``, ``<unk>`` and ``<mask>``. The
    trained tokenizer is saved to ``proteins-tmp``, relative to the current
    working directory.

    Args:
        batch_size: Number of examples per batch fed to the trainer.
        num_batches: Number of batches to feed; ``-1`` feeds the whole split.

    NOTE(review): protein sequences presumably contain no whitespace, so the
    tokenizer's pre-tokenizer may treat each whole sequence as a single
    "word". The trainer's word-frequency table could then grow with the
    number of distinct sequences, which is a plausible cause of the memory
    growth seen with ``num_batches=-1``. Verify before relying on this.
    """
    reserved_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]

    corpus = load_dataset(dataset_path)
    tokenizer = SentencePieceBPETokenizer()

    batches = batch_iterator(corpus, batch_size=batch_size, num_batches=num_batches)
    tokenizer.train_from_iterator(
        batches,
        vocab_size=1000,
        min_frequency=2,
        special_tokens=reserved_tokens,
    )
    tokenizer.save('proteins-tmp')

def main(args):
    """Entry point for ``app.run``: train using flag values and report elapsed time.

    Args:
        args: Positional command-line arguments supplied by absl. They are
            unused; the script is configured entirely through ``FLAGS``.
    """
    import time  # monotonic clock: unaffected by system clock adjustments

    if FLAGS.num_batches < 0:
        scope = "all batches"
    else:
        scope = f"{FLAGS.num_batches} batches"
    print(f"Starting tokenizer training on {scope} of batch size {FLAGS.batch_size}")

    start = time.monotonic()
    train(FLAGS.batch_size, FLAGS.num_batches)
    elapsed_seconds = time.monotonic() - start
    print(f"Took {elapsed_seconds / 60.0:.1f} minutes")

if __name__ == "__main__":
    # Hand control to absl's app.run, which invokes main().
    app.run(main)

I included batch_size and num_batches command-line arguments to enable testing. However, when I try to train the tokenizer on the entire dataset (setting num_batches=-1), memory usage during the “tokenizing” phase keeps growing until the entire process crashes.

Are there any suggested workarounds for this? Is it possible to train the tokenizer on chunks of data (i.e., call tokenizer.train_from_iterator multiple times), or does it need to see the entire dataset at once for the BPE algorithm to work correctly?

Thanks!