Tokenizer Trainer Crashing

Trying to train a tokenizer on a corpus of protein sequences ~8 GB in size (32M examples).

However, the memory usage grows incrementally until the process is killed when training from an iterator:

from tokenizers import SentencePieceBPETokenizer
from tokenizers.processors import BertProcessing

import datasets
from datasets import load_dataset

import itertools

from datetime import datetime
from absl import flags, app

# Command-line knobs for sizing a test run. num_batches=1 samples a single
# batch by default; batch_iterator treats num_batches == -1 as "no limit".
flags.DEFINE_integer("num_batches", 1, help="Number of batches to sample")
flags.DEFINE_integer("batch_size", 100, help="Batch size")

# HuggingFace Hub dataset id passed to datasets.load_dataset.
dataset_path = "agemagician/uniref30"

def batch_iterator(ds: "datasets.Dataset", batch_size=1000, num_batches=-1, split='train', feature='text'):
    """Lazily yield lists of ``feature`` values from ``ds[split]``.

    Args:
        ds: A datasets.Dataset (or any mapping whose ``split`` entry is an
            iterable of dict-like examples).
        batch_size: Number of examples per yielded batch.
        num_batches: Maximum number of batches to yield; ``-1`` means
            "iterate the whole split".
        split: Which split of ``ds`` to read.
        feature: Example key whose value is collected into each batch.

    Returns:
        An iterator of lists, each of length ``batch_size`` (the final batch
        may be shorter).
    """
    iterator = iter(ds[split])

    def batch_generator(iterable):
        # islice consumes up to batch_size examples per pass; the walrus
        # expression stops the loop when the underlying iterator is exhausted
        # (empty list is falsy).
        while (batch := [ex[feature] for ex in itertools.islice(iterable, batch_size)]):
            yield batch

    if num_batches == -1:
        return batch_generator(iterator)
    # Cap the stream at num_batches. (The original wrapped this in
    # itertools.chain(...), which is a no-op on a single iterable.)
    return itertools.islice(batch_generator(iterator), num_batches)
def train(batch_size, num_batches):
    """Train a SentencePiece-BPE tokenizer on the uniref corpus.

    Args:
        batch_size: Examples per batch fed to the trainer.
        num_batches: Number of batches to train on; -1 trains on the full split.

    Returns:
        The trained SentencePieceBPETokenizer.
    """
    uniref = load_dataset(dataset_path)
    tokenizer = SentencePieceBPETokenizer()
    it = batch_iterator(uniref, batch_size=batch_size, num_batches=num_batches)
    # NOTE(review): the original source was truncated at `special_tokens=[`;
    # reconstructed with the conventional special tokens — confirm against
    # the intended list.
    tokenizer.train_from_iterator(
        it,
        vocab_size=1000,
        min_frequency=2,
        special_tokens=["<unk>", "<s>", "</s>", "<pad>", "<mask>"],
    )
    return tokenizer

# FLAGS was referenced below but never defined anywhere in the file; this is
# the standard absl idiom for accessing parsed flag values.
FLAGS = flags.FLAGS


def main(args):
    """Entry point: run tokenizer training and report wall-clock duration.

    Args:
        args: Positional command-line arguments forwarded by app.run (unused).
    """
    # Original lines were truncated at `start =` / `end =`; datetime.now() is
    # implied by the `dur.total_seconds()` arithmetic below.
    start = datetime.now()
    print(f"Starting tokenizer training on {FLAGS.num_batches} of batch size {FLAGS.batch_size}")
    train(FLAGS.batch_size, FLAGS.num_batches)
    end = datetime.now()
    dur = end - start
    print(f"Took {dur.total_seconds() / 60.:.1f} minutes")

if __name__ == "__main__":
    # Standard absl entry point: parses flags, then calls main(argv).
    # (The original body was lost in extraction; app.run(main) is the
    # conventional pairing with the absl imports above.)
    app.run(main)
I included batch_size and num_batches commandline args to enable testing. However, when I try to train the tokenizer on the entire dataset (setting num_batches=-1), the memory usage during the "tokenizing" phase starts to grow until the entire process crashes.

Are there any suggested workarounds for this? Is it possible to train the tokenizer on chunks of data (i.e., call tokenizer.train_from_iterator multiple times), or does it need to see the entire dataset at once for the BPE algorithm to work correctly?