Caching a dataset processed with randomness

Is it possible to cache datasets that have been processed with a “random” function?

For example, the Process tutorial uses random.randint to pick words to mask. I believe this breaks the hashing that Datasets relies on to perform caching, so I thought to use a torch generator instead, like so:

        def sub_tok(batch, generator):
            input_ids: List[List[int]] = batch["input_ids"]
            for i, example in enumerate(input_ids):
                # Pick one token at random and mask every occurrence of it.
                token = example[
                    torch.randint(0, len(example), (1,), generator=generator).item()
                ]
                input_ids[i] = [MASK if tok == token else tok for tok in example]
            batch["input_ids"] = input_ids
            return batch

        g = torch.Generator(4444)
        lm_datasets["train"] = lm_datasets["train"].map(
            sub_tok,
            fn_kwargs={"generator": g},
            batched=True,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc="Enciphering XXX tokens per example",
        )

But I receive the error TypeError: cannot pickle 'torch._C.Generator' object. Removing the generator allows the dataset to be processed, but it breaks caching as expected. Any advice?

Hi! I opened a PR that implements a serializer for torch.Generator to avoid the pickle error.

PS: torch.Generator expects a device as the first argument in the constructor, not a seed.
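
Until the serializer is available, one workaround is to pass a plain integer seed through fn_kwargs and construct the generator inside the mapped function; integers pickle and hash fine, so the fingerprint Datasets computes for caching stays deterministic. A minimal sketch, reusing the MASK constant and tokenized lm_datasets from the snippet above (note it re-seeds per batch rather than drawing from one long-lived generator):

        import torch

        def sub_tok(batch, seed):
            # Build the generator here from a picklable int seed so the map
            # fingerprint (and therefore the cache) stays deterministic.
            g = torch.Generator()
            g.manual_seed(seed)
            input_ids = batch["input_ids"]
            for i, example in enumerate(input_ids):
                # Pick one token at random and mask every occurrence of it.
                token = example[
                    torch.randint(0, len(example), (1,), generator=g).item()
                ]
                input_ids[i] = [MASK if tok == token else tok for tok in example]
            batch["input_ids"] = input_ids
            return batch

        lm_datasets["train"] = lm_datasets["train"].map(
            sub_tok,
            fn_kwargs={"seed": 4444},
            batched=True,
            desc="Masking one token per example",
        )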
