GeneratorBasedBuilder gets stuck & consumes all RAM

Following the tutorial, I implemented a simple dataset builder that is supposed to generate the LibriSpeech dataset.

However, after 39999 examples the process has consumed the entire RAM (~64 GB) and the iterator gets stuck:

Downloading and preparing dataset new_dataset/default to /data/speech/corpora/hf/new_dataset/default/0.0.1/c928cea00cd327f5feccecf2cf274119a8859dbfe3298aca6333539088e3a3a4...
39999 examples [4:01:01,  5.91 examples/s]

Why is this the case? Each sample is supposed to get written to the .arrow file, so why is it kept in memory?

I am using load_dataset() for this like so:

cache_dir = "/data/speech/corpora/hf"
path = my_dataset.__file__
train_dataset = load_dataset(path, split="train", cache_dir=cache_dir, keep_in_memory=False)

and here is the code for the dataset itself; nothing too special there:

class NewDataset(datasets.GeneratorBasedBuilder):

    VERSION: datasets.Version = datasets.Version("0.0.1")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(version=VERSION, description="This part of my dataset covers a first domain"),
    ]

    def _info(self):

        features = datasets.Features(
            {
                "inputs": datasets.features.Sequence(datasets.Value("int16")),
                "targets": datasets.Value("string"),
                "length": datasets.Value("int64"),
            }
        )

        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": "/mariana/asr/corpora/converted/en/librispeech_train",
                    "split": "train",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={"filepath": "/mariana/asr/corpora/converted/en/librispeech_test", "split": "test"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={
                    "filepath": "/mariana/asr/corpora/converted/en/librispeech_dev",
                    "split": "dev",
                },
            ),
        ]

    def _generate_examples(self, filepath, split):
        corpus = ConvertedCorpus(filepath)
        for key, record in enumerate(corpus.sample_generator()):
            yield key, dict(inputs=record.wav, targets=record.transcript, length=len(record.wav))

In case you wonder: the sample generator of the ConvertedCorpus is not the issue here. It reads samples from 500 MB .tar files one by one and yields them to the consumer without keeping any references:

# ConvertedCorpus.sample_generator()
def sample_generator(self):

    df = self.df_index
    chunk_ids = sorted(df.chunk_id.unique())

    for chunk_id in chunk_ids:
        tar_fp = self.tar_fps[chunk_id]
        df_chunk = df[df.chunk_id == chunk_id]
        with tf.io.gfile.GFile(tar_fp, mode="rb") as f:
            with tarfile.open(fileobj=f, mode="r:*") as tar:
                for _, record in df_chunk.iterrows():
                    raw_bytes = tar.extractfile(record.wav_fp).read()
                    wav, rate = soundfile.read(io.BytesIO(raw_bytes), dtype=np.int16)
                    assert rate == 16000, f"Expected 16000 Hz but got {rate}"
                    if not len(wav):
                        logging.warning(f"Ignoring empty sample {record.sample_id} of {self.info.name}")
                        continue
                    record["wav"] = wav
                    yield record

Hi! Can you try adding the class attribute DEFAULT_WRITER_BATCH_SIZE = 256? The default value might be too big for your dataset.

This parameter controls how many examples can stay in RAM while the dataset is being written to the Arrow file.
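For a rough sense of scale, here is a back-of-the-envelope sketch (the ~10-second average clip length is an assumption; the 16 kHz int16 audio matches what your generator produces):

# Back-of-the-envelope sketch: how much raw audio the writer buffers before each
# flush to the Arrow file, for different writer batch sizes.
bytes_per_clip = 10 * 16_000 * 2          # ~10 s * 16 kHz * 2 bytes (int16) ≈ 320 KB
for batch_size in (10_000, 500, 256):
    buffer_mb = batch_size * bytes_per_clip / 1024**2
    print(f"batch of {batch_size:>6} examples ≈ {buffer_mb:,.0f} MB of raw audio in RAM")

With the default batch size that is already a few GB of raw audio kept buffered, before counting any conversion overhead, which is why lowering it helps.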

Indeed, this was the problem. After setting the field to a smaller value I was able to generate the dataset:

class AsrDataset(datasets.GeneratorBasedBuilder):
    # The default writer batch size (10_000) is too large for speech samples
    DEFAULT_WRITER_BATCH_SIZE = 500
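
Depending on your version of datasets, the same setting can also be passed when constructing the builder directly instead of hard-coding it on the class. A minimal sketch, assuming the writer_batch_size constructor argument is available in your version (AsrDataset and cache_dir as defined above):

# Sketch: tune the writer batch size per run instead of via the class attribute.
builder = AsrDataset(cache_dir=cache_dir, writer_batch_size=500)
builder.download_and_prepare()
train_dataset = builder.as_dataset(split="train")

Either way the dataset contents are unchanged; the value only controls how often the buffered examples are flushed, and therefore the peak RAM used during generation.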