Is it possible to cache datasets that have been processed with a “random” function?
For example, the Process tutorial uses random.randint to pick words to mask. I believe this breaks the hashing that Datasets relies on for caching, so I thought I'd use a seeded torch generator instead, like so:
import torch
from typing import List

# MASK is the mask token id, defined elsewhere in my script
def sub_tok(batch, generator):
    input_ids: List[List[int]] = batch["input_ids"]
    for i, example in enumerate(input_ids):
        # pick one token from the example using the seeded generator
        token = example[
            torch.randint(0, len(example), (1,), generator=generator).item()
        ]
        # mask every occurrence of the chosen token
        input_ids[i] = [MASK if tok == token else tok for tok in example]
    batch["input_ids"] = input_ids
    return batch
g = torch.Generator()
g.manual_seed(4444)
lm_datasets["train"] = lm_datasets["train"].map(
    sub_tok,
    fn_kwargs={"generator": g},
    batched=True,
    num_proc=args.preprocessing_num_workers,
    load_from_cache_file=not args.overwrite_cache,
    desc="Enciphering XXX tokens per example",
)
But I receive the error TypeError: cannot pickle 'torch._C.Generator' object. Removing the generator lets the dataset be processed, but it breaks caching, as expected. Any advice?
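
In case it helps, the direction I was planning to try next is passing a plain integer seed through fn_kwargs (ints pickle fine) and building the generator inside the mapped function. A minimal sketch, reusing the same masking logic; the seed parameter name and the hard-coded MASK value here are mine, not from the tutorial:

import torch
from typing import List

MASK = 103  # hypothetical mask token id; in practice, tokenizer.mask_token_id

def sub_tok_seeded(batch, seed):
    # Build the generator inside the worker so only the int seed gets pickled
    generator = torch.Generator()
    generator.manual_seed(seed)
    input_ids: List[List[int]] = batch["input_ids"]
    for i, example in enumerate(input_ids):
        token = example[
            torch.randint(0, len(example), (1,), generator=generator).item()
        ]
        input_ids[i] = [MASK if tok == token else tok for tok in example]
    batch["input_ids"] = input_ids
    return batch

lm_datasets["train"] = lm_datasets["train"].map(
    sub_tok_seeded,
    fn_kwargs={"seed": 4444},  # an int is picklable, unlike torch._C.Generator
    batched=True,
    num_proc=args.preprocessing_num_workers,
    load_from_cache_file=not args.overwrite_cache,
)

One caveat I can see with this approach: with batched=True, every batch restarts from the same seed, so the draws are deterministic but identical across batches; mixing the example indices into the seed (e.g. via with_indices=True) might be needed if that matters. I haven't verified that this actually produces a stable fingerprint across runs, though, which is why I'm asking.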