Extremely slow dataset.map() operation

Hi,

I’d like to know why group_texts is soooo slow.
Initially it’s super fast, but when it reaches around 92% I see only 2 of the 16 worker processes (16/8) still busy in top.
The dataset has over 80 million examples, so it takes almost forever.

I’m wondering why dataset.map() only keeps 1/8 of the worker processes busy after 92%.
Is there a better solution for this preprocessing task?

You can copy & paste the code below and just run it; I copied it from the official HF tutorial.

Cheers! :beers:
Aiden

from datasets import concatenate_datasets, load_dataset
from transformers import BertTokenizerFast


def preprocess_function(examples):
    tokenizer = BertTokenizerFast.from_pretrained('prajjwal1/bert-tiny')
    return tokenizer([" ".join(x) for x in examples["text"]])


def group_texts(examples):
    block_size = 384
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; we could add padding instead if the model
    # supported it. You can customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    return result

# cpu cores: 16
num_proc_1 = 16
num_proc_2 = 16
# num_proc_2 = 16 * 8
bookcorpus = load_dataset("bookcorpus", split="train")
wiki = load_dataset("wikipedia", "20220301.en", split="train")

wiki = wiki.remove_columns([col for col in wiki.column_names if col != "text"])  # only keep the 'text' column

assert bookcorpus.features.type == wiki.features.type
raw_datasets_bert = concatenate_datasets([bookcorpus, wiki])

tokenized_datasets_bert_1 = raw_datasets_bert.map(
    preprocess_function,
    batched=True,
    batch_size=20000,
    writer_batch_size=20000,
    remove_columns=["text"],
    num_proc=num_proc_1
)

print("start tokenized_datasets_bert_2")
tokenized_datasets_bert_2 = tokenized_datasets_bert_1.map(
    group_texts,
    batched=True,
    batch_size=2_000,
    writer_batch_size=2_000,
    num_proc=num_proc_2
)
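
P.S. On the "better solution" part: below is a variant of group_texts I'm thinking of trying. It swaps the sum(examples[k], []) concatenation for itertools.chain.from_iterable, since sum(..., []) copies the growing list on every addition and is quadratic in the batch's total token count. I'm not sure this is what stalls at ~92%, so treat it as a sketch to compare against, not a fix.

from itertools import chain


def group_texts_chain(examples):
    block_size = 384
    # Flatten each column with chain.from_iterable instead of sum(..., []).
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder, same as the original group_texts.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    return {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }

It takes and returns the same batch dict, so it should be a drop-in replacement for group_texts in the second .map() call.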