AttributeError: 'InMemoryTable' object has no attribute '_batches'

I’m not sure if this is a bug or something else; I’ve opened an issue on the Datasets GitHub repo as well.

I’m running an MLOps flow.
The error appears when I run the following code:

data_tokenized = data.map(partial(funcs.tokenize_function, tokenizer, seq_length),
                          batched=True,
                          batch_size=batch_size,
                          remove_columns=['col1', 'col2'])

def tokenize_function(tok, seq_length, example):
    # Pad so that each batch has the same sequence length
    inp = tok(example['col1'], padding=True, truncation=True)
    outp = tok(example['col2'], padding="max_length", max_length=seq_length)

    res = {
        'input_ids': inp['input_ids'],
        'attention_mask': inp['attention_mask'],
        'decoder_input_ids': outp['input_ids'],
        'labels': outp['input_ids'],
        'decoder_attention_mask': outp['attention_mask']
    }
    return res
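
In case it helps reproduce, here is roughly how everything is wired up end to end. The checkpoint, dataset contents, and sizes below are placeholders standing in for my real pipeline (which may matter, since the placeholder data might not hit the same code path as my real dataset):

from functools import partial

from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder checkpoint
data = DatasetDict({
    "train": Dataset.from_dict({
        'col1': ["some source text"] * 8,   # placeholder rows
        'col2': ["some target text"] * 8,
    })
})

# tokenize_function as defined above (imported from funcs in the real script)
data_tokenized = data.map(partial(tokenize_function, tokenizer, 128),
                          batched=True,
                          batch_size=4,
                          remove_columns=['col1', 'col2'])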

Here is the error:

Traceback (most recent call last):
  File "finetune.py", line 103, in <module>
    main(args)
  File "finetune.py", line 45, in main
    data_tokenized = data.map(partial(funcs.tokenize_function, tokenizer,
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/datasets/dataset_dict.py", line 868, in map
    {
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/datasets/dataset_dict.py", line 869, in <dictcomp>
    k: dataset.map(
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 592, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 557, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 3093, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 3432, in _map_single
    arrow_formatted_shard = shard.with_format("arrow")
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 2667, in with_format
    dataset = copy.deepcopy(self)
  File "/opt/conda/envs/ptca/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/conda/envs/ptca/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/opt/conda/envs/ptca/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/opt/conda/envs/ptca/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/opt/conda/envs/ptca/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/datasets/table.py", line 176, in __deepcopy__
    memo[id(self._batches)] = list(self._batches)
AttributeError: 'InMemoryTable' object has no attribute '_batches'
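
If I’m reading the traceback right, map() itself may be incidental: the crash happens while with_format("arrow") deep-copies the dataset, inside InMemoryTable.__deepcopy__. So I’d guess (untested) that a bare deepcopy of the same dataset triggers the identical error:

import copy

# `data` is the same DatasetDict passed to map() above; this should hit the
# same InMemoryTable.__deepcopy__ path that the traceback ends in.
copy.deepcopy(data)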

This ran fine a couple of weeks ago. I have recreated the environment since then, but I see that the datasets library was last updated at the end of December '23, i.e. before my last successful run, so recreating the environment should have installed the same version.

I’m running the latest version of datasets (2.16.1) and Transformers 4.35.2. I’m not on the latest Transformers because the last time I used it there was a conflict with Azure MLflow.
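
For reference, the versions above come from a quick check along these lines:

import datasets
import transformers

print(datasets.__version__)      # 2.16.1
print(transformers.__version__)  # 4.35.2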