I want to train an LLM on a TPU v4-32 using JAX/Flax. The dataset is stored in a mounted Google Cloud Storage bucket. The dataset (RedPajama-V2) consists of 5000 shards stored as .json.gz files: ~/folder-for-bucket/red_pajama/****/en_head.json.gz. Each file contains JSON lines, one example per line, with the example's text under the key "raw_content".
I use LlamaTokenizerFast from HuggingFace. The model's context size is 1024 tokens and the batch size is 512. My question is: what would be an optimal pipeline for loading the dataset, tokenizing it, and iterating over batches, at least at a high level?
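For reference, each shard looks roughly like this when read directly (a quick inspection sketch, not part of the training script; the glob pattern is just how the bucket is mounted on my machine):

import glob, gzip, json, os

pattern = os.path.expanduser("~/folder-for-bucket/red_pajama/*/en_head.json.gz")
shard = sorted(glob.glob(pattern))[0]      # first of the ~5000 shards
with gzip.open(shard, "rt") as f:          # each shard is a gzipped JSON-lines file
    example = json.loads(f.readline())
print(example["raw_content"][:200])        # document text lives under "raw_content"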
I didn't find any conventional way to do this on the internet. I asked ChatGPT, and it suggested building a token stream. However, in the current formulation it loads batches very slowly, so the script is input-bound:
import itertools
import os

import jax
import numpy as np
from datasets import load_dataset
from transformers import LlamaTokenizerFast

# DOCS_PER_CHUNK, seq_len (= 1024) and args are defined earlier in the script.

# ---------- tokenizer ----------
tokenizer = LlamaTokenizerFast.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
# ---------- streaming dataset ----------
pattern = os.path.join(args.data_dir, "*", "en_head.json.gz")
raw_stream = load_dataset("json", data_files=pattern, split="train", streaming=True)
raw_stream = raw_stream.shard(jax.process_count(), jax.process_index())  # disjoint slice of shards per TPU host
# ---------- fast batched tokenizer ----------
def token_stream():
    """Yield one flat stream of token ids, with an EOS token between documents."""
    buf = []
    for ex in raw_stream:
        buf.append(ex["raw_content"])
        # tokenize documents in chunks so the fast tokenizer can batch its work
        if len(buf) >= DOCS_PER_CHUNK:
            for ids in tokenizer(buf, add_special_tokens=False)["input_ids"]:
                yield from ids + [tokenizer.eos_token_id]
            buf.clear()
    # flush remaining docs
    if buf:
        for ids in tokenizer(buf, add_special_tokens=False)["input_ids"]:
            yield from ids + [tokenizer.eos_token_id]
# ---------- token → batch iterator ----------
def batch_iter(global_bsz: int):
    ts = token_stream()
    buf, rows = [], []
    while True:
        # top up the token buffer until it holds one full sequence
        buf.extend(itertools.islice(ts, seq_len + 1 - len(buf)))
        if len(buf) < seq_len + 1:
            return  # token stream exhausted
        rows.append(np.asarray(buf[:seq_len], dtype=np.int32))
        buf = buf[seq_len:]
        # emit a batch of distinct sequences once enough rows are collected
        if len(rows) == global_bsz:
            yield {"input_ids": np.stack(rows)}
            rows = []