Using dataset.map for n-grams in a Word2Vec-style data pipeline

I am using JAX to train Word2Vec-style embeddings of dimension D from scratch on wikitext, using datasets and tokenizers to handle the data pipeline. I will store the embeddings in two separate jax.numpy.ndarray objects.

I am using a WordLevel tokenizer rather than a PreTrainedTokenizer operating at the subword or WordPiece level, as in the Transformer examples.

I would like to learn the most efficient way of doing this using Huggingface for the data pipeline only.

In a given iteration, the code needs to:

  1. Sample a batch of normalized training data
  2. Tokenize
  3. Process each article of length M into non-overlapping n-grams of length N
  4. Use the tokens in each n-gram to look up word embedding vectors of dimension D and concatenate them in sequence order, so that a data sample of original length M becomes a jax.numpy.ndarray of shape (1, M // N, N*D), where // denotes integer division

Using the example code:

from datasets import load_dataset 
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.normalizers import Lowercase, NFKC, Sequence, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer
# from tokenizers.decoders import WordLevel as WordLevelDecoder

dataset = load_dataset("wikitext", "wikitext-103-v1", split="train+test+validation")

vocab_size = 400000

tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))  # give the model an unk token so unseen words don't raise at encode time

tokenizer.normalizer = Sequence([
    NFKC(),
    Lowercase(),
    StripAccents()
])
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(vocab_size=vocab_size, special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
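
For completeness, I then train the tokenizer over the raw text column before using it, something like this (the batched iterator and batch size are just what I happen to use; any iterator of strings should work):

def batch_iterator(batch_size=1000):
    # yield batches of raw article strings from the "text" column
    for i in range(0, len(dataset), batch_size):
        yield dataset[i : i + batch_size]["text"]

tokenizer.train_from_iterator(batch_iterator(), trainer=trainer, length=len(dataset))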

This loads, normalizes, and tokenizes data successfully.

My question from here is: what's the most efficient way of handling steps (3) and (4)? Should I combine all of this into a single map over the dataset? A rough sketch of what I currently have in mind follows.
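
To make the question concrete, here is an untested sketch of what I am imagining; N, D, the randomly initialized embeddings table, and the helper names are all placeholders:

import numpy as np
import jax.numpy as jnp

N = 4    # n-gram length (placeholder)
D = 128  # embedding dimension (placeholder)
embeddings = jnp.asarray(np.random.normal(size=(vocab_size, D)).astype("float32"))  # stand-in table

def to_ngram_ids(batch):
    # steps 2 and 3: tokenize, then chop each article into non-overlapping n-grams of length N
    encodings = tokenizer.encode_batch(batch["text"])
    ngrams = []
    for enc in encodings:
        ids = enc.ids[: (len(enc.ids) // N) * N]                      # drop the remainder tokens
        ngrams.append(np.asarray(ids, dtype="int32").reshape(-1, N))  # shape (M // N, N)
    return {"ngram_ids": ngrams}

ngram_dataset = dataset.map(to_ngram_ids, batched=True, remove_columns=dataset.column_names)

def to_embedding_block(ngram_ids):
    # step 4: look up embeddings and concatenate within each n-gram -> (1, M // N, N * D)
    ids = jnp.asarray(ngram_ids)  # (M // N, N)
    return embeddings[ids].reshape(1, ids.shape[0], N * D)

What I am unsure about is whether step 4 belongs inside the same map or in the training step just before the forward pass, and whether a single map over the dataset is the most efficient way to handle steps 3 and 4 together.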

A more specific question that came up: how do I find the maximum sequence length from a trained tokenizer object? EDIT: to JIT as efficiently as possible, my input data should all be the same shape, so knowing the max sequence length after the tokenizer has iterated over the raw data would be really useful here.
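
If it matters, the only fallback I can think of is measuring it with another full pass over the encoded text, roughly as below; I would be glad to hear if the tokenizer or the dataset already tracks this somewhere.

def measure_lengths(batch):
    # encode a batch of articles and record the token count of each
    encodings = tokenizer.encode_batch(batch["text"])
    return {"length": [len(enc.ids) for enc in encodings]}

lengths = dataset.map(measure_lengths, batched=True, remove_columns=dataset.column_names)
max_seq_len = max(lengths["length"])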

Thanks!