Why does deleting the columns before passing the datasets to interleave_datasets sometimes work but sometimes NOT work?

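For context, the pattern I expect to work is: keep only the shared 'text' column in every dataset so the schemas match, then interleave. A minimal sketch of that pattern (my assumptions: streaming mode, c4/'en' and wikitext/'wikitext-103-v1' as in my script, uniform 0.5/0.5 weights):

    # Minimal sketch: drop every non-shared column, then interleave two
    # streaming datasets with equal probabilities.
    from datasets import load_dataset, interleave_datasets

    c4 = load_dataset('c4', 'en', streaming=True, split='train')
    wt = load_dataset('wikitext', 'wikitext-103-v1', streaming=True, split='train')

    datasets = [c4, wt]
    # keep only 'text' so both schemas match before interleaving
    datasets = [ds.remove_columns([c for c in ds.column_names if c != 'text'])
                for ds in datasets]

    mixed = interleave_datasets(datasets, probabilities=[0.5, 0.5], seed=0)
    print(next(iter(mixed)))  # expect a dict with only a 'text' key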

Code snippet:

        # -Interleaving datasets
        print('- Interleaving datasets')
        datasets = [my_load_dataset(path, name, data_files).with_format("torch") for path, name, data_files in zip(path, name, data_files)]
        # datasets = [my_load_dataset(path, name).with_format("torch") for path, name in zip(path, name)]
        if path[0] == 'parquet':  # not sure why this guard is needed: I carefully removed all non-shared columns so the interleaved datasets match, yet doing this with c4 & wikitext fails while it works with the parquet datasets
            dataset_descriptions = [dataset.description for dataset in datasets]  # print description if available
            print(f'{dataset_descriptions=}')
            # - make sure all datasets have the same columns so interleave_datasets doesn't complain
            all_columns = [col for dataset in datasets for col in dataset.column_names]
            print(f'{all_columns=}')
            columns_to_remove = [col for dataset in datasets for col in dataset.column_names if col != 'text']
            columns_to_remove = list(set(columns_to_remove))  # remove duplicates
            print(f'{columns_to_remove=}')
            datasets = [dataset.remove_columns(columns_to_remove) for dataset in datasets]
            # - interleave
            print(f'{probabilities=}')
            dataset_descriptions = [dataset.description for dataset in datasets]  # print description if available
            print(f'{dataset_descriptions=}')
        dataset = interleave_datasets(datasets, probabilities)
        # dataset = dataset.remove_columns(columns_to_remove)
        print(f'{dataset=}')
        print(f'{dataset.column_names=}')
    print(f'{dataset=}')
    print(f'{type(dataset)=}')
    # datasets.iterable_dataset.IterableDataset
    # datasets.arrow_dataset.Dataset
    # dataset = IterableDataset(dataset) if type(dataset) != IterableDataset else dataset  # to force dataset.take(batch_size) to work in non-streaming mode
    batch = dataset.take(batch_size)
    print(f'{batch=}')
    print(f'{next(iter(batch))=}')

Full code:

def experiment_compute_diveristy_coeff_single_dataset_then_combined_datasets_with_domain_weights():
    """
    Get divs using pt ft, pt (rand, rand ft?) 
    - div c4 
    - div wt = wt-103
    Then with unioned datasets
    - div c4+wt, uniform [0.5, 0.5]
    - # div c4+wt, data set size proportions (using GBs)
    - div c4+wt, respect doremi
    - div c4+wt, respect the pile
    - div c4+wt, respect gpt3 weights
    then repeat all with pt (no ft)
    """
    import random
    from diversity.data_mixtures import get_uniform_data_mixture_for_c4_wt103, get_doremi_based_data_mixture_for_c4_wt103, get_llama_v1_based_data_mixture_for_c4_wt103
    probabilities = []
    data_mixture_name = None
    streaming = True
    data_files = []
    seed = 0
    # -- Setup wandb
    import wandb
    # - Dryrun
    mode = 'dryrun'; num_batches = 3
    # mode = 'dryrun'; num_batches = 3; seed = random.randint(0, 2**32 - 1)

    # - Online (real experiment)
    # mode='online'; num_batches = 600; seed = random.randint(0, 2**32 - 1)
    # path, name = 'c4', 'en'
    # path, name = "wikitext", 'wikitext-103-v1'
    path, name, data_files = ['c4', 'wikitext'], ['en', 'wikitext-103-v1'], [None, None]
    probabilities, data_mixture_name = get_uniform_data_mixture_for_c4_wt103()
    # probabilities, data_mixture_name = get_doremi_based_data_mixture_for_c4_wt103()
    # probabilities, data_mixture_name = get_llama_v1_based_data_mixture_for_c4_wt103()
    # probabilities, data_mixture_name = [0.75, 0.25], '[0.75, 0.25]' 
    # probabilities, data_mixture_name = [0.25, 0.75], '[0.25, 0.75]' 
    # path, name = 'EleutherAI/pile', 'all'
    # path, name = 'conceptofmind/pile_cc', 'sep_ds'
    # streaming = False
    # path, name = 'conceptofmind/pile_cc', 'sep_ds'
    # path, name = 'EleutherAI/pile', 'hacker_news' 
    # path, name = 'EleutherAI/pile', 'nih_exporter'  # https://github.com/huggingface/datasets/issues/6144
    # path, name = 'EleutherAI/pile', 'pubmed' 
    # path, name = 'EleutherAI/pile', 'uspto' 
    # - 5 subsets of pile using hf data set viewer (parquet)) 
    # from diversity.pile_subset_urls import urls_hacker_news, urls_nih_exporter, urls_pubmed, urls_uspto
    # path, name, data_files = 'parquet', 'hacker_news', urls_hacker_news
    # path, name, data_files = 'parquet', 'nih_exporter', urls_nih_exporter
    # path, name, data_files = 'parquet', 'pubmed', urls_pubmed
    # path, name, data_files = 'parquet', 'uspto', urls_uspto
    # - 5 subsets of the pile interleaved
    # from diversity.pile_subset_urls import urls_hacker_news, urls_nih_exporter, urls_pubmed, urls_uspto
    # from diversity.data_mixtures import get_doremi_data_mixture_5subsets_of_pile, get_llama_v1_data_mixtures_5subsets_of_pile
    # path, name, data_files = ['conceptofmind/pile_cc'] + ['parquet'] * 4, ['sep_ds'] + ['hacker_news', 'nih_exporter', 'pubmed', 'uspto'], [None] + [urls_hacker_news, urls_nih_exporter, urls_pubmed, urls_uspto]
    # probabilities, data_mixture_name = [1.0/len(path)] * len(path), f'{[1.0/len(path)] * len(path)=}'
    # probabilities, data_mixture_name = get_doremi_data_mixture_5subsets_of_pile(name)
    # probabilities, data_mixture_name = get_llama_v1_data_mixtures_5subsets_of_pile(name)
    # not changing
    batch_size = 512
    today = datetime.datetime.now().strftime('%Y-m%m-d%d-t%Hh_%Mm_%Ss')
    run_name = f'{path} div_coeff_{num_batches=} ({today=} ({name=}) {data_mixture_name=} {probabilities=})'
    print(f'\n---> {run_name=}\n')

    # - Init wandb
    debug: bool = mode == 'dryrun'
    run = wandb.init(mode=mode, project="beyond-scale", name=run_name, save_code=True)
    wandb.config.update({"num_batches": num_batches, "path": path, "name": name, "today": today, 'probabilities': probabilities, 'batch_size': batch_size, 'debug': debug, 'data_mixture_name': data_mixture_name, 'streaming': streaming, 'data_files': data_files, 'seed': seed})
    # run.notify_on_failure() # https://community.wandb.ai/t/how-do-i-set-the-wandb-alert-programatically-for-my-current-run/4891
    print(f'{debug=}')
    print(f'{wandb.config=}')

    # -- Get probe network
    from datasets import load_dataset, interleave_datasets
    from datasets.iterable_dataset import IterableDataset
    import torch
    from transformers import GPT2Tokenizer, GPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    probe_network = GPT2LMHeadModel.from_pretrained("gpt2")
    device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
    probe_network = probe_network.to(device)

    # -- Get data set
    def my_load_dataset(path, name, data_files=data_files):
        print(f'{path=} {name=} {streaming=} {data_files=}')
        if path == 'json' or path == 'bin' or path == 'csv':
            print(f'{data_files_prefix+name=}')
            return load_dataset(path, data_files=data_files_prefix+name, streaming=streaming, split="train").with_format("torch")
        elif path == 'parquet':
            print(f'{data_files=}')
            return load_dataset(path, data_files=data_files, streaming=streaming, split="train").with_format("torch")
        else:
            return load_dataset(path, name, streaming=streaming, split="train").with_format("torch")
    # - get data set for real now
    if isinstance(path, str):
        dataset = my_load_dataset(path, name)
    else:
        # -Interleaving datasets
        print('- Interleaving datasets')
        datasets = [my_load_dataset(path, name, data_files).with_format("torch") for path, name, data_files in zip(path, name, data_files)]
        # datasets = [my_load_dataset(path, name).with_format("torch") for path, name in zip(path, name)]
        if path[0] == 'parquet':  # not sure why this guard is needed: I carefully removed all non-shared columns so the interleaved datasets match, yet doing this with c4 & wikitext fails while it works with the parquet datasets
            dataset_descriptions = [dataset.description for dataset in datasets]  # print description if available
            print(f'{dataset_descriptions=}')
            # - make sure all datasets have the same columns so interleave_datasets doesn't complain
            all_columns = [col for dataset in datasets for col in dataset.column_names]
            print(f'{all_columns=}')
            columns_to_remove = [col for dataset in datasets for col in dataset.column_names if col != 'text']
            columns_to_remove = list(set(columns_to_remove))  # remove duplicates
            print(f'{columns_to_remove=}')
            datasets = [dataset.remove_columns(columns_to_remove) for dataset in datasets]
            # - interleave
            print(f'{probabilities=}')
            dataset_descriptions = [dataset.description for dataset in datasets]  # print description if available
            print(f'{dataset_descriptions=}')
        dataset = interleave_datasets(datasets, probabilities)
        # dataset = dataset.remove_columns(columns_to_remove)
        print(f'{dataset=}')
        print(f'{dataset.column_names=}')
    print(f'{dataset=}')
    print(f'{type(dataset)=}')
    # datasets.iterable_dataset.IterableDataset
    # datasets.arrow_dataset.Dataset
    # dataset = IterableDataset(dataset) if type(dataset) != IterableDataset else dataset  # to force dataset.take(batch_size) to work in non-streaming mode
    batch = dataset.take(batch_size)
    print(f'{batch=}')
    print(f'{next(iter(batch))=}')
    column_names = next(iter(batch)).keys()
    print(f'{column_names=}')

    # - Prepare functions to tokenize batch
    def preprocess(examples):
        return tokenizer(examples["text"], padding="max_length", max_length=128, truncation=True, return_tensors="pt")
    remove_columns = column_names  # remove all keys that are not tensors to avoid bugs in collate function in task2vec's pytorch data loader
    def map(batch):
        return batch.map(preprocess, batched=True, remove_columns=remove_columns)
    print(f'{batch=}')
    tokenized_batch = map(batch)
    print(f'{next(iter(tokenized_batch))=}')

    # -- Compute diversity coefficient
    print(f'-- Compute diversity coefficient')
    print(f'{seed=}, {streaming=}')
    # - Debug run
    # results: dict = get_diversity_coefficient(dataset, map, probe_network, num_batches=3, seed=seed, debug=True, shuffle=False)  # (quick debug) hardcoded for debugging
    # results: dict = get_diversity_coefficient(dataset, map, probe_network, num_batches=3, seed=seed, debug=True, shuffle=True)  # (slow debug) hardcoded for debugging
    # results: dict = get_diversity_coefficient(dataset, map, probe_network, num_batches=3, seed=seed, debug=False, shuffle=False)  # (real) hardcoded for debugging
    # - Real run
    # assert not debug, f'Err: {debug=} for real run'
    results: dict = get_diversity_coefficient(dataset, map, probe_network, num_batches=num_batches, seed=seed, debug=debug, shuffle=False)
    # results: dict = get_diversity_coefficient(dataset, map, probe_network, num_batches=num_batches, seed=seed, debug=debug, shuffle=True)
    # - Log results
    div_coeff, div_coeff_ci = results['div_coeff'], results['div_coeff_ci']
    print(f'{div_coeff=} {div_coeff_ci=}')
    wandb.log({'div_coeff': div_coeff, 'div_coeff_ci': div_coeff_ci})

    # -- Save results or not
    save_results = True
    if save_results:
        output_dir = Path(f'~/data/div_coeff/{today}').expanduser()
        output_dir.mkdir(parents=True, exist_ok=True)
        np.save(output_dir / f'distance_matrix{today}.npy', results['distance_matrix'])
        np.save(output_dir / f'results{today}.npy', results)
        # Save results as a pretty-printed JSON
        results = {key: str(value) for key, value in results.items()}
        with open(output_dir / f'results{today}.json', 'w') as f:
            json.dump(results, f, indent=4)
        # - wandb save
        base_path = str(output_dir.parent)
        wandb.save(str(output_dir / f'distance_matrix{today}.npy'), base_path=base_path)
        wandb.save(str(output_dir / f'results{today}.npy'), base_path=base_path)
        wandb.save(str(output_dir / f'results{today}.json'), base_path=base_path)
        wandb.save(__file__)
    
    # -- Finish wandb
    wandb.finish()

Error:

Exception has occurred: KeyError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
'timestamp'
  File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 729, in _iter
    del transformed_example[c]
  File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 652, in __iter__
    yield from self._iter()
  File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 73, in __next__
    result = next(self.it)
  File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 400, in __iter__
    yield next(iterators[i])
  File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 1013, in __iter__
    yield from islice(self.ex_iterable, self.n)
  File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 1353, in __iter__
    for key, example in ex_iterable:
  File "/lfs/ampere1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 531, in experiment_compute_diveristy_coeff_single_dataset_then_combined_datasets_with_domain_weights
    print(f'{next(iter(batch))=}')
  File "/lfs/ampere1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 592, in <module>
    experiment_compute_diveristy_coeff_single_dataset_then_combined_datasets_with_domain_weights()
  File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/runpy.py", line 196, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
KeyError: 'timestamp'
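
The KeyError is on 'timestamp', which on my setup is one of the extra columns c4 has on top of 'text' (wikitext only has 'text'). A small diagnostic I use to compare the schemas before interleaving (hypothetical snippet, not part of the script above):

    # Hypothetical diagnostic: print each streaming dataset's columns before interleaving.
    # Locally, c4 reports ['text', 'timestamp', 'url'] and wikitext only ['text'].
    from datasets import load_dataset

    for path, name in [('c4', 'en'), ('wikitext', 'wikitext-103-v1')]:
        ds = load_dataset(path, name, streaming=True, split='train')
        print(f'{path}/{name}: {ds.column_names}')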

Cross-posted on Stack Overflow: "huggingface datasets - Why does deleting the columns before giving it to interleave work but sometimes it does NOT work?"