.map() function extremely slow

I am preprocessing the Wikipedia dataset, and it's extremely slow: about 12 it/s, which works out to roughly 140 hours for the whole dataset. I have looked online and found no reports of anyone having a similar issue.
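Here is a quick sanity check of that estimate. It assumes roughly 6.46 million articles in the 20220301.en train split, a figure recalled from the dataset card rather than taken from this post. At 12 examples/s the run takes about 150 hours. That suggests the 12 it/s on the progress bar counts examples, not batches of 1000, since 12 batches/s would finish in minutes.

```python
# Rough ETA: total examples divided by throughput.
# ASSUMPTION: ~6.46M articles in the wikipedia 20220301.en train split.
N_ARTICLES = 6_458_670


def eta_hours(n_examples: int, examples_per_sec: float) -> float:
    """Hours needed to process n_examples at a constant rate."""
    return n_examples / examples_per_sec / 3600


for rate in (12, 1500, 2000):
    print(f"{rate:>5} examples/s -> {eta_hours(N_ARTICLES, rate):7.1f} h")
```

At 1500-2000 examples/s the same arithmetic gives roughly 0.9 to 1.2 hours.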

I am running the script on a Slurm cluster with 128 CPUs, no GPU.

#SBATCH --ntasks=1 --cpus-per-task=128 --mem=50000M
#SBATCH --time=200:00:00

Code - should be reproducible:

import datasets 
from random import randint 
from transformers import AutoTokenizer
import gc

dataset = datasets.load_dataset("wikipedia", "20220301.en", split="train")


def preprocess_examples(batch):
    documents = batch['text']
    data = {'example': [], 'summary': []}

    for document in documents:
        # Generate two random numbers for length of document and summary
        doc_length = randint(100, 400)  # Patches in the encoder
        sum_length = randint(20, 50)  # Tokens in the decoder
        # split() splits on any whitespace, so newlines no longer glue words together
        words = document.split()

        # Skip documents that are too short for the requested lengths
        if doc_length + sum_length > len(words):
            continue

        text = words[:doc_length]
        summary = words[doc_length:doc_length + sum_length]
        # Alternative 80/20 split (it previously overwrote the slices above):
        # cut = int(len(words) * 0.8)
        # text, summary = words[:cut], words[cut:]

        # Collect the outputs; without these appends the mapped dataset is empty
        data['example'].append(' '.join(text))
        data['summary'].append(' '.join(summary))

    return data

if __name__ == "__main__":
    train_dataset = dataset.map(preprocess_examples, batched=True, batch_size=1000, remove_columns=["id", "url", "title", "text"])
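In batched mode, map passes the function a dict of column lists and builds the new dataset from the lists it returns. When every input column is removed, the output may have a different number of rows than the input. The toy model below uses a hypothetical `simulate_batched_map`, which is not part of `datasets`. It shows why a function that builds `data` but never appends to it yields an empty dataset, while still doing all the splitting work.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List[str]]


def simulate_batched_map(fn: Callable[[Batch], Batch], rows: Batch, batch_size: int) -> Batch:
    """Toy model of Dataset.map(batched=True) with every input column removed:
    slice the columns into batches, call fn on each, concatenate its output."""
    n_rows = len(next(iter(rows.values())))
    out: Batch = {}
    for start in range(0, n_rows, batch_size):
        batch = {col: vals[start:start + batch_size] for col, vals in rows.items()}
        result = fn(batch)
        # All returned columns must have the same length (it may differ from the batch size)
        assert len({len(v) for v in result.values()}) <= 1, "ragged output columns"
        for col, vals in result.items():
            out.setdefault(col, []).extend(vals)
    return out


def never_fills(batch: Batch) -> Batch:
    data = {'example': [], 'summary': []}
    for document in batch['text']:
        document.split()  # work is done, but nothing is appended to data
    return data


def fills(batch: Batch) -> Batch:
    data = {'example': [], 'summary': []}
    for document in batch['text']:
        words = document.split()
        cut = int(len(words) * 0.8)
        data['example'].append(' '.join(words[:cut]))
        data['summary'].append(' '.join(words[cut:]))
    return data


rows = {'text': ['one two three four five'] * 5}
print({k: len(v) for k, v in simulate_batched_map(never_fills, rows, 2).items()})
print({k: len(v) for k, v in simulate_batched_map(fills, rows, 2).items()})
```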

I have tried:

  • setting num_proc to the number of CPU cores (os.cpu_count()): no speed-up
  • making the batch size smaller or bigger: no effect
  • replacing .map() with Dataset.from_generator(): no improvement
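One thing worth checking for the num_proc experiment is the CPU count itself. os.cpu_count() returns the number of CPUs on the whole machine, which on a shared Slurm node can be larger than the 128 the job was allocated. On Linux, os.sched_getaffinity(0) returns only the CPUs the process is allowed to run on. The minimal sketch below uses a hypothetical helper name, `usable_cpus`. It falls back to SLURM_CPUS_PER_TASK, which Slurm sets when --cpus-per-task is requested, on platforms without sched_getaffinity.

```python
import os


def usable_cpus() -> int:
    """Number of CPUs this process may actually run on (hypothetical helper).

    os.cpu_count() counts every CPU on the node, including ones not allocated
    to the job; sched_getaffinity() reflects the affinity mask of this process.
    """
    if hasattr(os, "sched_getaffinity"):  # Linux
        return len(os.sched_getaffinity(0))
    slurm_cpus = os.environ.get("SLURM_CPUS_PER_TASK")
    if slurm_cpus and slurm_cpus.isdigit():
        return int(slurm_cpus)
    return os.cpu_count() or 1


print(usable_cpus())
```

The result could then be passed as `num_proc=usable_cpus()` in the existing `dataset.map(...)` call, so the number of worker processes does not exceed the CPUs the job can actually use.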

Any ideas or hints would be greatly appreciated 🙏

Why do you need gc.collect()? It is a very expensive call. Without it, I get a processing speed of 1500-2000 examples/s in Colab.
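Here is a rough way to see that cost. Called with no arguments, `gc.collect()` runs a full collection. In CPython that walks every container object the collector tracks, so its price grows with everything the process keeps alive. The sketch times it against a no-op while a few hundred thousand small lists are alive. The object count and call count are arbitrary assumptions, and absolute timings will vary by machine.

```python
import gc
import time


def time_calls(fn, n_calls: int) -> float:
    """Wall-clock seconds for n_calls invocations of fn."""
    start = time.perf_counter()
    for _ in range(n_calls):
        fn()
    return time.perf_counter() - start


# Keep many live container objects around (a stand-in for a loaded dataset,
# tokenizer, etc.) so each full collection has a large heap to traverse.
live_objects = [[i] for i in range(300_000)]

n = 10
with_gc = time_calls(gc.collect, n)
without_gc = time_calls(lambda: None, n)
print(f"gc.collect(): {with_gc / n * 1e3:.2f} ms per call")
print(f"no-op:        {without_gc / n * 1e3:.4f} ms per call")
```

If the call runs once per document or once per batch, that per-call cost is paid millions or thousands of times. Removing it is the cheapest first experiment, because CPython's automatic collector keeps running anyway.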