Dataset.with_transform() hangs indefinitely while finetuning Stable Diffusion XL

I am following the finetuning script given at https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/README_sdxl.md

With multi-GPU training, it never prints the HERE4 statement.

with accelerator.main_process_first():
    print(accelerator.is_main_process)
    print("===========Here3.1===========")
    if args.max_train_samples is not None:
        dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
    print("===========Here3.2===========")
    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess_train)
print("==========HERE4=============")

Corresponding output:

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
10/25/2023 21:18:04 - INFO - main - Distributed environment: MULTI_GPU Backend: nccl
Num processes: 3
Process index: 1
Local process index: 1
Device: cuda:1

Mixed precision type: fp16

10/25/2023 21:18:04 - INFO - main - Distributed environment: MULTI_GPU Backend: nccl
Num processes: 3
Process index: 2
Local process index: 2
Device: cuda:2

Mixed precision type: fp16

10/25/2023 21:18:04 - INFO - main - Distributed environment: MULTI_GPU Backend: nccl
Num processes: 3
Process index: 0
Local process index: 0
Device: cuda:0

Mixed precision type: fp16

You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'variance_type', 'clip_sample_range', 'thresholding', 'dynamic_thresholding_ratio'} was not found in config. Values will be initialized to default values.
{'attention_type', 'reverse_transformer_layers_per_block', 'dropout'} was not found in config. Values will be initialized to default values.
==========HERE1=============
==========HERE1=============
==========HERE1=============
==========HERE2=============
==========HERE2=============
==========HERE2=============
==========HERE3=============
True
===========Here3.1===========
===========Here3.2===========
==========HERE3=============
==========HERE3=============

It works just fine in a single-GPU setting but sadly errors out with out-of-memory, since my NVIDIA A5000 24GB is not enough.

I am using CUDA 11.7, PyTorch 1.13.1, and diffusers 0.22.0.dev0.
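
In case it is relevant: the Accelerator in the script is constructed with its defaults. Below is a minimal sketch of raising the NCCL collective timeout via InitProcessGroupKwargs (my own guess at a diagnostic, not something the diffusers script does), in case the hang is a collective/barrier timing out while the main process prepares the dataset:

from datetime import timedelta
from accelerate import Accelerator, InitProcessGroupKwargs

# hypothetical tweak, not in the official train_text_to_image_sdxl.py:
# give collectives a longer timeout so slow dataset preparation on the main process
# does not look like a deadlock to the other ranks
kwargs = InitProcessGroupKwargs(timeout=timedelta(hours=2))
accelerator = Accelerator(kwargs_handlers=[kwargs])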


Any inputs, please?

I get:

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.

but my CUDA seems fine:

Model is currently on: next(iter(model.parameters())).device=device(type='cuda', index=0)
Model is currently on: next(iter(model.parameters())).dtype=torch.bfloat16
vocab_size: len(tokenizer)=32000 
Expected random loss: math.log(len(tokenizer))=10.373491181781864
CUDA version: torch.version.cuda='12.1'

and nvidia-smi:

(base) brando9@ampere8~ $ python -c "import torch; print(torch.randn(2, 4).to('cuda') @ torch.randn(4, 1).to('cuda'));"
tensor([[-2.4671],
        [-0.6787]], device='cuda:0')
(base) brando9@ampere8~ $ nvidia-smi
Fri Jan 26 16:58:00 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:07:00.0 Off |                    0 |
| N/A   64C    P0             385W / 400W |  10822MiB / 81920MiB |     99%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  | 00000000:0A:00.0 Off |                    0 |
| N/A   53C    P0             364W / 400W |  10399MiB / 81920MiB |     99%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM4-80GB          On  | 00000000:44:00.0 Off |                    0 |
| N/A   51C    P0             274W / 400W |  53285MiB / 81920MiB |     60%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM4-80GB          On  | 00000000:4A:00.0 Off |                    0 |
| N/A   55C    P0             310W / 400W |  53283MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM4-80GB          On  | 00000000:84:00.0 Off |                    0 |
| N/A   48C    P0             103W / 400W |  52889MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA A100-SXM4-80GB          On  | 00000000:8A:00.0 Off |                    0 |
| N/A   29C    P0              64W / 400W |    873MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA A100-SXM4-80GB          On  | 00000000:C0:00.0 Off |                    0 |
| N/A   29C    P0              60W / 400W |      7MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA A100-SXM4-80GB          On  | 00000000:C3:00.0 Off |                    0 |
| N/A   64C    P0             372W / 400W |  61557MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   1773004      C   python                                    10386MiB |
|    0   N/A  N/A   3303288      C   ...ihan/miniconda3/envs/rdb/bin/python      416MiB |
|    1   N/A  N/A   1773467      C   python                                    10386MiB |
|    2   N/A  N/A   1718375      C   python                                    53270MiB |
|    3   N/A  N/A   1718924      C   python                                    53268MiB |
|    4   N/A  N/A   1677419      C   python                                    52874MiB |
|    5   N/A  N/A   1768821      C   ...iconda/envs/beyond_scale/bin/python      748MiB |
|    7   N/A  N/A    459941      C   python                                    61542MiB |
+---------------------------------------------------------------------------------------+
(base) brando9@ampere8~ $ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_May__3_18:49:52_PDT_2022
Cuda compilation tools, release 11.7, V11.7.64
Build cuda_11.7.r11.7/compiler.31294372_0

Seems fine.
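
If I understand the warning correctly, it is about the Linux kernel version, not CUDA, so the relevant check is the kernel itself (plain Python, nothing HF-specific):

import platform, torch
print(f"kernel: {platform.release()}")   # this is what the 5.5.0 warning is about
print(f"cuda:   {torch.version.cuda}")   # unrelated to the kernel check above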

What is going on? How do I fix this? Is there anything to fix?

Code:

"""
Goal: make an HF training script for a model (e.g., Llama v2) using raw text of informal and formal mathematics (unpaired data).

Inspiration:
- ref: SO accelerate + trainer: https://stackoverflow.com/questions/76675018/how-does-one-use-accelerate-with-the-hugging-face-hf-trainer
- ref: The unreasonable effectiveness of few-shot learning for machine translation https://arxiv.org/abs/2302.01398
- ref: colab: https://colab.research.google.com/drive/1io951Ex17-6OUaogCo7OiR-eXga_oUOH?usp=sharing
- ref: SO on collate: https://stackoverflow.com/questions/76879872/how-to-use-huggingface-hf-trainer-train-with-custom-collate-function/76929999#76929999

Looks very useful especially for peft:
- peft https://github.com/huggingface/trl/blob/main/examples/scripts/sft_trainer.py

python trl/examples/scripts/sft_trainer.py \
    --model_name meta-llama/Llama-2-7b-hf \
    --dataset_name timdettmers/openassistant-guanaco \
    --load_in_4bit \
    --use_peft \
    --batch_size 4 \
    --gradient_accumulation_steps 2

- qlora https://github.com/artidoro/qlora/blob/main/scripts/finetune_llama2_guanaco_7b.sh, 
- https://github.com/artidoro/qlora/blob/main/qlora.py

export CUDA_VISIBLE_DEVICES=6
"""
from pathlib import Path
from typing import Callable
import datasets
from datasets import load_dataset, interleave_datasets
import torch
from transformers import GPT2LMHeadModel, PreTrainedTokenizer, AutoTokenizer, Trainer, TrainingArguments, AutoConfig
import math

import sys
from training.reinit_and_smaller_llama2 import get_deafult_smallest_baby_llama2_v1_36m_0p036b, get_weight_norms, reinitialize_weights_gpt_neox_20B_inspired_4_llama2
sys.path = [''] + sys.path
from training.utils import eval_hf, eval_hf_with_subsample, get_column_names, get_data_from_hf_dataset, group_texts, raw_dataset_2_lm_data

# -- Experiments 

def train():
    """
    I decided to make the string data close to the context length of Llama 2 7B (4096 tokens).
    If any string is shorter, the tokenizer will pad it, according to Claude.
    
    """
    # feel free to move the import statements if you want, sometimes I like everything in one place so I can easily copy-paste it into a script
    import datetime
    from pathlib import Path
    import datasets
    from datasets import load_dataset, interleave_datasets
    import torch
    import transformers
    from transformers import PreTrainedTokenizer
    from transformers import GPT2LMHeadModel, PreTrainedTokenizer, AutoTokenizer, Trainer, TrainingArguments, AutoConfig
    import random
    import math
    import os
    torch.cuda.empty_cache()
    # buffer_size = 500_000  # can't remember what this was for and doesn't seem to be anywhere
    probabilities = []
    data_mixture_name = None
    streaming = True
    data_files = [None]
    seed = 0
    split = 'train'
    max_length = 1024  # gpt2 context length
    shuffle = False
    report_to = 'none'  # safest default
    # CHUNK_SIZE = 16_896  # approximately trying to fill the llama2 context length of 4096
    batch_size = 2
    gradient_accumulation_steps = 2
    num_epochs = 1
    num_tokens_trained = None
    num_batches=1
    optim='paged_adamw_32bit'
    learning_rate=1e-5
    warmup_ratio=0.01
    weight_decay=0.01
    lr_scheduler_type='constant_with_warmup'
    # lr_scheduler_kwargs={}

    # -- Setup wandb
    import wandb
    # - Dryrun
    mode = 'dryrun'; seed = 0; report_to = 'none'

    # - Online (real experiment)
    # mode = 'online'; seed = 0; report_to = 'wandb'

    # - train data sets
    # path, name, data_files, split = ['c4'], ['en'], [None], ['train']
    path, name, data_files, split = ['UDACA/PileSubsets'], ['uspto'], [None], ['train']
    # path, name, data_files, split = ['UDACA/PileSubsets'], ['pubmed'], [None], ['train']
    # path, name, data_files, split = ['UDACA/PileSubsets', 'UDACA/PileSubsets'], ['uspto', 'pubmed'], [None, None], ['train', 'train']
    # - models
    # pretrained_model_name_or_path = 'gpt2'  # this is the smallest model gpt2, 124M params https://huggingface.co/gpt2 
    # pretrained_model_name_or_path = 'meta-llama/Llama-2-7b-hf'
    # pretrained_model_name_or_path = 'meta-llama/Llama-2-7b-chat-hf'
    # pretrained_model_name_or_path = 'meta-llama/Llama-2-13b-hf'
    # pretrained_model_name_or_path = 'meta-llama/Llama-2-70b-hf'
    # pretrained_model_name_or_path = 'mistralai/Mistral-7B-v0.1'
    pretrained_model_name_or_path = 'baby_llama2_v1'
    # - important training details or it won't run, mem issues maybe
    max_steps = 1
    # max_steps = 300
    # max_steps = 19_073 # <- CHANGE THIS  11 days with baby llama2 v1 36m 1, 32
    # max_steps = 866 # <- CHANGE THIS 12hs with with baby llama2 v1 36m 1, 32
    # max_steps = 1_761 # <- CHANGE THIS 12hs with with baby llama2 v1 36m 5, 6 0.2168M tokens
    # max_steps = 306_000 # <- CHANGE THIS 12hs with with baby llama2 v1 36m 1, 32 35.1 tokens
    # max_length = 4096
    max_length = 1024
    num_batches=1
    # single gpu
    # batch_size, gradient_accumulation_steps = 1, 32  # e.g., choosing large number mabe for stability of training? 4 (per_device_train_batch_size) * 8 (gradient_accumulation_steps), based on alpaca https://github.com/tatsu-lab/stanford_alpaca 
    # batch_size, gradient_accumulation_steps = 6, 5  # e.g., choosing large number mabe for stability of training? 4 (per_device_train_batch_size) * 8 (gradient_accumulation_steps), based on alpaca https://github.com/tatsu-lab/stanford_alpaca 
    # batch_size, gradient_accumulation_steps = 5, 6  # e.g., choosing large number mabe for stability of training? 4 (per_device_train_batch_size) * 8 (gradient_accumulation_steps), based on alpaca https://github.com/tatsu-lab/stanford_alpaca 
    # batch_size, gradient_accumulation_steps = 4, 6  # e.g., choosing large number mabe for stability of training? 4 (per_device_train_batch_size) * 8 (gradient_accumulation_steps), based on alpaca https://github.com/tatsu-lab/stanford_alpaca 
    # batch_size, gradient_accumulation_steps = 4, 8  # e.g., choosing large number mabe for stability of training? 4 (per_device_train_batch_size) * 8 (gradient_accumulation_steps), based on alpaca https://github.com/tatsu-lab/stanford_alpaca 
    learning_rate=1e-4
    # learning_rate=1e-5
    optim='paged_adamw_32bit'
    # optim = 'adafactor'
    weight_decay=0.1
    warmup_ratio=0.01
    lr_scheduler_type='cosine'
    # lr_scheduler_type='constant_with_warmup'
    # lr_scheduler_type='cosine_with_warmup'
    # lr_scheduler_kwargs={},  # ref: https://huggingface.co/docs/transformers/v4.37.0/en/main_classes/optimizer_schedules#transformers.SchedulerType 
    # -- multiple gpus 3 4096 context len
    # batch_size, gradient_accumulation_steps = 4, 8  # e.g., choosing large number mabe for stability of training? 4 (per_device_train_batch_size) * 8 (gradient_accumulation_steps), based on alpaca https://github.com/tatsu-lab/stanford_alpaca 
    # gradient_checkpointing = False
    gradient_checkpointing = True
    print(f'{batch_size=} {gradient_accumulation_steps=} {gradient_checkpointing=} {num_epochs=}')
    # -- wandb
    num_tokens_trained = max_steps * batch_size * max_length * num_batches 
    today = datetime.datetime.now().strftime('%Y-m%m-d%d-t%Hh_%Mm_%Ss')
    run_name = f'beyond scale: {path} ({today=} ({name=}) {data_mixture_name=} {probabilities=} {pretrained_model_name_or_path=} {data_files=} {max_steps=} {batch_size=} {num_tokens_trained=} {gradient_accumulation_steps=} {optim=} {learning_rate=} {max_length=} {weight_decay=} {warmup_ratio=})'
    print(f'\n---> {run_name=}\n')

    # - Init wandb
    debug: bool = mode == 'dryrun'  # BOOL, debug?
    run = wandb.init(mode=mode, project="beyond-scale", name=run_name, save_code=True)
    wandb.config.update({"path": path, "name": name, "today": today, 'probabilities': probabilities, 'batch_size': batch_size, 'debug': debug, 'data_mixture_name': data_mixture_name, 'streaming': streaming, 'data_files': data_files, 'seed': seed, 'pretrained_model_name_or_path': pretrained_model_name_or_path, 'num_epochs': num_epochs, 'gradient_accumulation_steps': gradient_accumulation_steps})
    # run.notify_on_failure() # https://community.wandb.ai/t/how-do-i-set-the-wandb-alert-programatically-for-my-current-run/4891
    output_dir = Path(f'~/data/results_{today}/').expanduser() if not debug else Path(f'~/data/results/').expanduser()
    print(f'{output_dir=}')
    print(f'{debug=}')
    print(f'{wandb.config=}')

    # -- Load model and tokenizer  
    print(f'{pretrained_model_name_or_path=}')
    if pretrained_model_name_or_path == 'gpt2':
        from transformers import GPT2Tokenizer, GPT2LMHeadModel
        tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path)
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f'{tokenizer.pad_token=}')
        print(f'{tokenizer.eos_token=}')
        print(f'{tokenizer.eos_token_id=}')
        model = GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path)
        device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        block_size: int = tokenizer.model_max_length
        print(f'{block_size=}')
        print()
    elif 'Llama-2' in pretrained_model_name_or_path or 'Mistral' in pretrained_model_name_or_path:
        # - llama2
        from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
        # bf16 or fp32
        torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability(torch.cuda.current_device())[0] >= 8 else torch.float32 # if >= 8 ==> brain float 16 available or set to True if you always want fp32
        # get model
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            # quantization_config=quantization_config,
            # device_map=device_map,  # device_map = None  https://github.com/huggingface/trl/blob/01c4a35928f41ba25b1d0032a085519b8065c843/examples/scripts/sft_trainer.py#L82
            trust_remote_code=True,
            torch_dtype=torch_dtype,
            use_auth_token=True,
        )
        # https://github.com/artidoro/qlora/blob/7f4e95a68dc076bea9b3a413d2b512eca6d004e5/qlora.py#L347C13-L347C13
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            # cache_dir=args.cache_dir,
            padding_side="right",
            use_fast=False, # Fast tokenizer giving issues.
            # tokenizer_type='llama' if 'llama' in args.model_name_or_path else None, # Needed for HF name change
            # tokenizer_type='llama',
            trust_remote_code=True,
            use_auth_token=True,
            # token=token,  # load from cat keys/brandos_hf_token.txt if you want to load it in python and not run huggingface-cli login
        )
        # - Ensure padding token is set TODO: how does this not screw up the fine-tuning? e.g., now the model doesn't learn to predict eos since it's padded out by the mask, ref: https://discuss.huggingface.co/t/why-does-the-falcon-qlora-tutorial-code-use-eos-token-as-pad-token/45954
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f'{tokenizer.pad_token=}')
        print(f'{tokenizer.eos_token=}')
        print(f'{ tokenizer.eos_token_id=}')
        # get context length for setting max length for training
        if hasattr(model.config, "context_length"):
            print("Context length:", model.config.context_length)
            max_length = model.config.context_length
        else:
            # CHUNK_SIZE = 16_896  # approximately trying to fill the llama2 context length of 4096
            max_length = 4096
        block_size: int = 4096
        print(f'{block_size=}')
    elif 'baby_llama2_v1' in pretrained_model_name_or_path:
        model = get_deafult_smallest_baby_llama2_v1_36m_0p036b()
        reinitialize_weights_gpt_neox_20B_inspired_4_llama2(model, L=max_length)
        # export HF_TOKEN=$(cat ~/keys/brandos_hf_token.txt)
        tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', padding_side="right", use_fast=False, trust_remote_code=True, use_auth_token=True)
        device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability(torch.cuda.current_device())[0] >= 8 else torch.float32 # if >= 8 ==> brain float 16 available or set to True if you always want fp32
        model = model.to(torch_dtype)
        block_size: int = max_length
        print(f'{block_size=}')
    print("Number of parameters:", sum(p.numel() for p in model.parameters()))
    print(f"Total weight norm: {get_weight_norms(model)=}")
    print(f'{torch.cuda.device_count()=} (makes sure GPUs are visible and accessible to PyTorch.)')
    print(f'Model is currently on: {next(iter(model.parameters())).device=}')
    print(f'Model is currently on: {next(iter(model.parameters())).dtype=}')
    # Sanity check -- is the loss random? ln(V) = -ln(1/V) = -ln(1/50257) = 10.82, since CE = sum_i y_i * ln(1/p_i) and only one token is correct, so y_i = 1 for exactly one i and CE = ln(1/p_i); uniform p_i = 1/V gives ln(V)
    print(f'vocab_size: {len(tokenizer)=} \nExpected random loss: {math.log(len(tokenizer))=}')
    print(f"CUDA version: {torch.version.cuda=}")
    eval_hf_with_subsample('UDACA/pile_openwebtext2', None, 'validation', model, tokenizer, block_size, output_dir, max_eval_samples=2, print_str='> Eval OpenWebtext rand mdl')
    eval_hf_with_subsample('c4', 'en', 'validation', model, tokenizer, block_size, output_dir, max_eval_samples=2, print_str='> Eval C4 rand mdl')
    
    # --- Load datasets
    # -- Get train data set
    # - Load interleaved combined datasets
    train_datasets = [load_dataset(path, name, data_files=data_file, streaming=streaming, split=split).with_format("torch") for path, name, data_file, split in zip(path, name, data_files, split)]
    probabilities = [1.0/len(train_datasets) for _ in train_datasets]  
    # - Get raw train data set
    raw_train_datasets = interleave_datasets(train_datasets, probabilities)
    remove_columns = get_column_names(raw_train_datasets)  # remove all keys that are not tensors to avoid bugs in collate function in task2vec's pytorch data loader
    # - Get tokenized train data set
    # Note: Setting `batched=True` in the `dataset.map` function of Hugging Face's datasets library processes the data in batches rather than one item at a time, significantly speeding up the tokenization and preprocessing steps.
    tokenize_function = lambda examples: tokenizer(examples["text"])
    tokenized_train_datasets = raw_train_datasets.map(tokenize_function, batched=True, remove_columns=remove_columns)
    _group_texts = lambda examples : group_texts(examples, block_size)
    # - Get actual data set for lm training (in this case each seq is of length block_size, no need to worry about pad = eos since we are filling each sequence)
    lm_train_dataset = tokenized_train_datasets.map(_group_texts, batched=True)
    batch = get_data_from_hf_dataset(lm_train_dataset, streaming=streaming, batch_size=batch_size)
    print(f'{len(next(iter(batch))["input_ids"])=}')
    assert all(len(data_dict['input_ids']) == block_size for data_dict in iter(batch)), f'Error, some seq in batch are not of length {block_size}'
    train_dataset = lm_train_dataset

    # -- max steps manually decided depending on how many tokens we want to train on
    per_device_train_batch_size = batch_size
    print(f'{per_device_train_batch_size=}')
    print(f'{num_epochs=} {max_steps=}')

    # -- Training arguments and trainer instantiation ref: https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments
    print(f'{output_dir=}')
    training_args = TrainingArguments(
        output_dir=output_dir,  # The output directory where the model predictions and checkpoints will be written.
        # output_dir='.',  # The output directory where the model predictions and checkpoints will be written.
        max_steps=max_steps,  # TODO: hard to fix, see above
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,  # based on alpaca https://github.com/tatsu-lab/stanford_alpaca, allows to process effective_batch_size = gradient_accumulation_steps * batch_size, num its to accumulate before opt update step
        gradient_checkpointing = gradient_checkpointing,  # TODO depending on hardware set to true?
        optim=optim,
        # warmup_steps=int(max_steps*warmup_ratio),  # TODO: once real training starts we can select this number for llama v2, what does llama v2 do to make it stable while v1 didn't?
        warmup_ratio=warmup_ratio,  # copying alpaca for now, number of steps for a linear warmup, TODO once real training starts change? 
        # weight_decay=0.01,  # TODO once real training change?
        weight_decay=weight_decay,  # TODO once real training change?
        learning_rate = learning_rate,  # TODO once real training change? anything larger than -3 I've had terrible experiences with
        max_grad_norm=1.0, # TODO once real training change?
        lr_scheduler_type=lr_scheduler_type,  # TODO once real training change? using what I've seen most in vision 
        # lr_scheduler_kwargs=lr_scheduler_kwargs,  # ref: https://huggingface.co/docs/transformers/v4.37.0/en/main_classes/optimizer_schedules#transformers.SchedulerType 
        logging_dir=Path('~/data/maf/logs').expanduser(),
        # save_steps=4000,  # alpaca does 2000, other defaults were 500
        save_steps=max_steps//3,  # alpaca does 2000, other defaults were 500
        # save_steps=1,  # alpaca does 2000, other defaults were 500
        # logging_steps=250,
        # logging_steps=50,  
        logging_first_step=True,
        # logging_steps=3,
        logging_steps=1,
        remove_unused_columns=False,  # TODO don't get why https://stackoverflow.com/questions/76879872/how-to-use-huggingface-hf-trainer-train-with-custom-collate-function/76929999#76929999 , https://claude.ai/chat/475a4638-cee3-4ce0-af64-c8b8d1dc0d90
        report_to=report_to,  # change to wandb!
        fp16=False,  # never ever set to True
        bf16=torch.cuda.get_device_capability(torch.cuda.current_device())[0] >= 8,  # if >= 8 ==> brain float 16 available or set to True if you always want fp32
    )
    print(f'{pretrained_model_name_or_path=}\n{optim=}\n{learning_rate=}')

    # -- Init Trainer
    print(f'{train_dataset=}')
    trainer = Trainer(
        model=model,
        args=training_args,  
        train_dataset=train_dataset,
    )

    # - Train
    cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if cuda_visible_devices is not None:
        print(f"CUDA_VISIBLE_DEVICES = {cuda_visible_devices}")
    trainer.train()
    trainer.save_model(output_dir=output_dir)  # TODO is this really needed? https://discuss.huggingface.co/t/do-we-need-to-explicity-save-the-model-if-the-save-steps-is-not-a-multiple-of-the-num-steps-with-hf/56745

    # -- Evaluation, NOTE: we are evaluating at the end not during training
    print()
    # - Evaluate model on OpenWebtext
    print('---- Evaluate model on OpenWebtext')
    metrics = eval_hf_with_subsample('UDACA/pile_openwebtext2', None, 'validation', model, tokenizer, block_size, output_dir)
    print(f'{metrics=}')
    # - Evaluate on C4
    print('---- Evaluate model on C4')
    metrics = eval_hf_with_subsample('c4', 'en', 'validation', model, tokenizer, block_size, output_dir)
    print(f'{metrics=}')
    # - Evaluate on whole datasets
    print('---- Evaluate model on Whole OpenWebtext')
    metrics = eval_hf_with_subsample('UDACA/pile_openwebtext2', None, 'validation', model, tokenizer, block_size, output_dir, max_eval_samples=None)
    # eval_hf(trainer=Trainer(model=model, args=eval_args, train_dataset=None, eval_dataset=eval_dataset1))
    print(f'{metrics=}')
    print('---- Evaluate model on Whole C4')
    metrics = eval_hf_with_subsample('c4', 'en', 'validation', model, tokenizer, block_size, output_dir, max_eval_samples=None)
    # eval_hf(trainer=Trainer(model=model, args=eval_args, train_dataset=None, eval_dataset=eval_dataset2))
    print(f'{metrics=}')
    print('Done!\a')

def main():  
    """Since accelerate config wants this, main_training_function: main"""
    train()

# -- Run __main__

if __name__ == '__main__':
    print(f'\n\n\n------------------- Running {__file__} -------------------')
    # -- Run tests and time it
    import time
    time_start = time.time()
    # -- Run tests
    main()
    # -- End tests, report how long it took in seconds, minutes, hours, days
    print(f'Time it took to run {__file__}: {time.time() - time_start} seconds, {(time.time() - time_start)/60} minutes, {(time.time() - time_start)/60/60} hours, {(time.time() - time_start)/60/60/24} days\a')

Related: huggingface transformers - Cuda detected version in Hugging Face (HF) is 5.4.0 while it's recommended to be 5.5.0, but pytorch & nvidia-smi say it's higher, how to fix? - Stack Overflow

Remaining code:

"""
todo:
    - finish passing the HF block_size tokenization code here so it's modular
    - add function to our train code train.py
    - print the sequence length of the data once we include this code
    - create a unit test here to test block size
    - use the re-init code smart ally & brando wrote
"""
from itertools import chain
import math
import random
from typing import Optional

import torch

import datasets
from datasets import load_dataset, interleave_datasets

from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, AutoConfig
from transformers.testing_utils import CaptureLogger
from transformers import GPT2Tokenizer

def cuda_debug():
    import torch

    # Check if CUDA is available
    cuda_available = torch.cuda.is_available()
    print(f"CUDA available: {cuda_available}")

    # Get the CUDA version used by PyTorch
    cuda_version = torch.version.cuda
    print(f"CUDA version: {cuda_version}")

    # Get the number of CUDA devices (GPUs)
    num_cuda_devices = torch.cuda.device_count()
    print(f"Number of CUDA devices: {num_cuda_devices}")

# # For each CUDA device, print its name and capabilities
# for i in range(num_cuda_devices):
#     print(f"CUDA Device {i}: {torch.cuda.get_device_name(i)}")
#     print(f"Compute Capability: {torch.cuda.get_device_capability(i)}")


def do_quick_matrix_multiply():
    """
python -c "import torch; print(torch.randn(2, 4).to('cuda') @ torch.randn(4, 1).to('cuda'));"
    """
    print(torch.randn(2, 4).to('cuda') @ torch.randn(4, 1).to('cuda'))

def get_actual_data_batch(data_set_or_batch):
    """ Returns the actual  data batch from the HF dataset obj e.g., dataset, batch etc. """
    data_batch = next(iter(data_set_or_batch))
    return data_batch

def get_vocab_size_and_ln(tokenizer: GPT2Tokenizer) -> tuple[int, float]:
    """
    Calculate the vocabulary size and its natural logarithm for a given tokenizer.

    Note:
        Sanity check -- is the loss random? ln(V) = -ln(1/V) = -ln(1/50257) = 10.82, since CE = sum_i y_i * ln(1/p_i) and only one token is correct, so y_i = 1 for exactly one i and CE = ln(1/p_i); uniform p_i = 1/V gives ln(V)

    Args:
    tokenizer (GPT2Tokenizer): A tokenizer from the Hugging Face library.

    Returns:
    tuple[int, float]: A tuple containing the vocabulary size and its natural logarithm.
    """
    vocab_size = len(tokenizer)  # Get the size of the tokenizer's vocabulary
    ln_vocab_size = math.log(vocab_size)  # Calculate the natural logarithm of the vocabulary size
    return vocab_size, ln_vocab_size

def num_tokens(max_steps: int, batch_size: int, L: int, num_batches: int) -> int:
    """
    All sequences are of length L, due to our block size code. 
    num_batch = when using distributed training. 
            num_tokens_trained = max_steps * batch_size * L * num_batches

    how long do I have to train     
    """
    num_tokens_trained = max_steps * batch_size * L * num_batches
    return num_tokens_trained
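
# Hedged, illustrative-only arithmetic for num_tokens above (numbers are examples, not from a real run):
# max_steps=866, batch_size=1, L=4096, num_batches=1 -> 866 * 1 * 4096 * 1 = 3_547_136 tokens (~3.5M).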

def get_freest_gpu():
    # Get the index of the GPU with the most free memory
    devices = list(range(torch.cuda.device_count()))
    free_memory = [torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device) for device in devices]
    freest_device = devices[free_memory.index(max(free_memory))]
    return freest_device

# Use for IterableDatasetDict objects (i.e. streaming=T, split is unspecified (each key in dict is name of split))
def view_exs_iterable_dataset_dict(dataset, num_exs=10, split='train'):
  dataset_split = dataset[split]
  for ex in dataset_split:
    print(ex)
    print('example details: keys', ex.keys(), ', text char length', len(ex['text']), '\n---')
    num_exs -= 1
    if num_exs == 0:
      break


# Use for IterableDataset objects (i.e. streaming=T, split=specified)
def view_exs_iterable_dataset(dataset_split, num_exs=10):
  for ex in dataset_split:
    print(ex)
    print('example details: keys', ex.keys(), ', text char length', len(ex['text']), '\n---')
    num_exs -= 1
    if num_exs == 0:
      break

def get_num_steps():
    # dataset_size: int = int(1.5e12)  # TODO, doesn't seem easy to solve. Either count all the sequennces/rows or have the meta data have this. Or make this number huge. 
    # dataset_size: int = train_dataset.num_rows
    # dataset_size: int = len(train_dataset)
    # TODO dataset.info['split']['train']['num_examples']
    # dataset_size = sum(len(dataset) for dataset in datasets)  # TODO: works on with streaming = False?
    # dataset_size = sum(dataset.cardinality() for dataset in datasets)
    pass

def raw_dataset_2_lm_data(raw_dataset, 
                          tokenizer, 
                          block_size: int, 
                          desired_dataset_column: str = 'text',
                          method_to_remove_columns: str = 'keys',
                          ):
    remove_columns = get_column_names(raw_dataset, method_to_remove_columns)  # remove all keys that are not tensors to avoid bugs in collate function in task2vec's pytorch data loader
    # - Get tokenized train data set
    # Note: Setting `batched=True` in the `dataset.map` function of Hugging Face's datasets library processes the data in batches rather than one item at a time, significantly speeding up the tokenization and preprocessing steps.
    tokenize_function = lambda examples: tokenizer(examples[desired_dataset_column])
    tokenized_train_datasets = raw_dataset.map(tokenize_function, batched=True, remove_columns=remove_columns)
    _group_texts = lambda examples : group_texts(examples, block_size)
    # - Get actual data set for lm training (in this case each seq is of length block_size, no need to worry about pad = eos since we are filling each sequence)
    lm_dataset = tokenized_train_datasets.map(_group_texts, batched=True)
    return lm_dataset

def get_size_of_seq_len(dataset_or_batch, verbose: bool = True, streaming: bool = True, batch_size: int = 2) -> int:
    """Print size of a sequence length in a batch. Give a hf data set obj (batches are data set objs sometimes)."""
    batch = get_data_from_hf_dataset(dataset_or_batch, streaming=streaming, batch_size=batch_size)
    size_seq_len = len(next(iter(batch))["input_ids"])
    if verbose:
        print(f'{size_seq_len=}')
        print(f'{len(next(iter(batch))["input_ids"])=}')
    return size_seq_len

def get_column_names(dataset, 
                    #   split: str = 'train',
                      method: str = 'keys', 
                      streaming: bool = True,
                      ):
    if method == 'features':
        # column_names = list(dataset[spit].features)
        column_names = list(dataset.features)
    elif method == 'keys':
        batch = get_data_from_hf_dataset(dataset, streaming=streaming, batch_size=1)
        column_names = next(iter(batch)).keys()
        # column_names = next(iter(dataset)).keys()
    else:
        raise ValueError(f"method {method} not supported")
    return column_names

def get_data_from_hf_dataset(dataset, 
                             streaming: bool = True, 
                             batch_size: int = 4, 
                            #  shuffle: bool= False, # shuffle is better but slower afaik
                            #  seed: int = 0, 
                            #  buffer_size: int = 500_000,
                             ):
    """ Gets data from a HF dataset, it's usually an iterator object e.g., some ds.map(fn, batched=True, remove_columns=remove_columns) has been applied. 
    Handles both streaming and non-streaming datasets, take for streaming and select for non-streaming.
    """
    # sample_data = dataset.select(range(batch_size)) if not isinstance(dataset, datasets.iterable_dataset.IterableDataset) else dataset.take(batch_size)
    batch = dataset.take(batch_size) if streaming else dataset.select(random.sample(list(range(len(dataset))), batch_size))
    return batch

def _tokenize_function(examples, tokenizer, tok_logger, text_column_name: str):
    """
    
    To use do:
    tokenizer = ...obtained from your model... 
    tokenize_function = lambda examples: tokenize_function(examples, tokenizer=tokenizer) 
    tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            remove_columns=column_names,
        )
    """
    with CaptureLogger(tok_logger) as cl:
        output = tokenizer(examples[text_column_name])
    # clm input could be much much longer than block_size
    if "Token indices sequence length is longer than the" in cl.out:
        tok_logger.warning(
            "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
            " before being passed to the model."
        )
    return output

def tokenize_function(examples, tokenizer, text_column_name: str):
    """
    Creates a tokenize function that can be used in HF's map function, where you specify which text column to tokenize.

    Assumes batched=True, so examples contains many rows/data points.
    """
    return tokenizer(examples[text_column_name])

def preprocess(examples, tokenizer, max_length: int = 1024):
    return tokenizer(examples["text"], padding="max_length", max_length=max_length, truncation=True, return_tensors="pt")
    # return tokenizer(examples["text"], padding="max_length", max_length=model.config.context_length, truncation=True, return_tensors="pt")

def group_texts(examples, # if batched=True it's a dict of input_ids, attention_mask, labels of len(examples['input_ids']) = 1000 
                block_size: int,  # 4096, 1024
                ):
    """
    tokenizer = ...obtained from your model... 
    tokenize_function = lambda examples: tokenize_function(examples, tokenizer=tokenizer) 
    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True, remove_columns=column_names)

    if used as above then examples is
    examples = {'input_ids': [[...], [...], ...], 'attention_mask': [[...], [...], ...], 'labels': [[...], [...], ...]]]}
    examples.keys() = dict_keys(['input_ids', 'attention_mask'])
    type(examples) = <class 'dict'>
    type(examples['input_ids']) = <class 'list'>
    len(examples['input_ids']) = 1000  # if batched=True

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map    
    """
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, and if the total_length < block_size  we exclude this batch and return an empty dict.
    # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
    total_length = (total_length // block_size) * block_size  # rounds down
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
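
# -- Hedged sketch (not in the original script): a tiny self-contained check of group_texts with toy ids,
# assuming block_size=4; it only illustrates the concatenate -> chunk -> copy-labels behavior.
def _demo_group_texts():
    toy = {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]], "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1]]}
    out = group_texts(toy, block_size=4)
    assert out["input_ids"] == [[1, 2, 3, 4], [5, 6, 7, 8]]  # 8 concatenated ids -> two chunks of 4
    assert out["labels"] == out["input_ids"]  # labels are a copy of input_ids for causal LM training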

def group_texts_v2(examples, # if batched=True it's a dict of input_ids, attention_mask, labels of len(examples['input_ids']) = 1000 
                block_size: int,  # 4096, 1024
                ):
    """
    tokenizer = ...obtained from your model... 
    tokenize_function = lambda examples: tokenize_function(examples, tokenizer=tokenizer) 
    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True, remove_columns=column_names)
    _group_texts = lambda examples : group_texts_v2(examples, block_size)
    lm_train_dataset = tokenized_train_datasets.map(_group_texts, batched=True)

    if used as above then examples is
    examples = {'input_ids': [[...], [...], ...], 'attention_mask': [[...], [...], ...], 'labels': [[...], [...], ...]]]}
    examples.keys() = dict_keys(['input_ids', 'attention_mask'])
    type(examples) = <class 'dict'>
    type(examples['input_ids']) = <class 'list'>
    len(examples['input_ids']) = 1000  # if batched=True

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder for each of those groups of 1,000 texts. 
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map    
    """
    # Concatenate all texts for each key in the examples e.g., it creates one concatenated list of all input_ids, one for all attention_mask, etc.
    # for column_name in examples.keys():
    #     # chain makes an iterator that returns elements from each iterator in order, basically concatenates iterators 
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, and if the total_length < block_size  we exclude this batch and return an empty dict.
    # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # # get sequences of length block_size, then add eos token to end of each sequence and mask the rest of the sequence
    # result = {}
    # for k, t in concatenated_examples.items():
    #     # Initialize a list for each key (really key="text" is the one we care about) in the result
    #     result[k] = []
    #     total_length = len(t)  # Assuming t is a list or has a length
    #     for i in range(0, total_length, block_size):
    #         # Append the sublist of t from i to i + block_size
    #         seq = t[i : i + block_size]

    #         result[k].append(t[i : i + block_size])
    result["labels"] = result["input_ids"].copy()
    return result

def collate_fn_train_only_first_eos_token_mask_everything_after_it(data: list[dict[str, str]], 
                                                                   tokenizer: PreTrainedTokenizer, 
                                                                   max_length: int=1024,  # GPT2 default, likely worth you change it! This default might cause bugs.
                                                                   ) -> dict[str, torch.Tensor]:
    """ Train only on first occurence of eos. The remaining eos are masked out.

    Sometimes the model might not have a padding token. Sometimes people set the padding token to be the eos token.
    But sometimes this seems to lead to the model to predict eos token to much. 
    So instead of actually using the pad token that was set to the eos token, we instead mask out all excesive eos tokens that act as pads 
    and leave the first eos token at the end to be predicted -- since that is the only one that semantically means end of sequence 
    and therby by not training on random eos at the end by masking it not unncesserily shift/amplify the distribution of eos. 
    
    ref: https://discuss.huggingface.co/t/why-does-the-falcon-qlora-tutorial-code-use-eos-token-as-pad-token/45954/13?u=brando 
    ref: https://chat.openai.com/share/02d16770-a1f3-4bf4-8fc2-464286daa8a1
    ref: https://claude.ai/chat/80565d1f-ece3-4fad-87df-364ce57aec15 on when to call .clone()
    ref: https://stackoverflow.com/questions/76633368/how-does-one-set-the-pad-token-correctly-not-to-eos-during-fine-tuning-to-avoi
    """
    # we are training full context length for llama so remove code bellow, if it tries to pad hopefully it throws an error
    # -- Ensure tokenizer has a padding token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # -- Extract sequences
    # sequences: list[str] = [example.get("text", "") or "" for example in data]
    sequences: list[str] = []
    for idx, example in enumerate(data):
        # Retrieve the value for "text" from the dictionary or default to an empty string if not present or falsy. ref: https://chat.openai.com/share/bead51fe-2acf-4f05-b8f7-b849134bbfd4
        text: str = example.get("text", "") or ""
        sequences.append(text)
    # -- Tokenize the sequences
    tokenized_data = tokenizer(sequences, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt")
    tokenized_data["labels"] = tokenized_data["input_ids"].clone()  # labels is hardcoded in HF so put it!
    # -- Set the mask value for the first eos_token in each sequence to 1 and remaining to -100
    eos_token_id = tokenizer.eos_token_id
    for idx, input_ids in enumerate(tokenized_data["input_ids"]):
        # Find all occurrences of eos_token
        eos_positions = (input_ids == eos_token_id).nonzero(as_tuple=True)[0]
        if eos_positions.nelement() > 0:  # Check if eos_token is present
            first_eos_position = eos_positions[0]
            tokenized_data["attention_mask"][idx, first_eos_position] = 1  # Set the mask value to 1
            
            # Assert that the label for the first occurrence of eos_token is eos_token_id
            assert tokenized_data["labels"][idx, first_eos_position] == eos_token_id, "The label for the first eos_token is incorrect!"
            
            # For all subsequent occurrences of eos_token, set their labels to -100
            for subsequent_eos_position in eos_positions[1:]:
                tokenized_data["labels"][idx, subsequent_eos_position] = -100
                assert tokenized_data["labels"][idx, subsequent_eos_position] == -100, "The label for the subsequent_eos_position incorrect! Should be -100."
    return tokenized_data
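
# -- Hedged usage sketch (not part of the original training run): how the collate_fn above could be wired
# into a Trainer. `model`, `training_args`, `train_dataset`, and `tokenizer` are assumed to exist in the caller;
# max_length=1024 is just an example value.
def _demo_collate_fn_usage(model, training_args, train_dataset, tokenizer):
    from functools import partial
    collate = partial(collate_fn_train_only_first_eos_token_mask_everything_after_it,
                      tokenizer=tokenizer, max_length=1024)
    return Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collate)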

# -- eval code

def compute_metrics(eval_preds):
    """ todo document clearly, from SS's code. """
    import evaluate
    metric = evaluate.load("accuracy")
    preds, labels = eval_preds
    # preds have the same shape as the labels, after the argmax(-1) has been calculated
    # by preprocess_logits_for_metrics but we need to shift the labels
    labels = labels[:, 1:].reshape(-1)
    preds = preds[:, :-1].reshape(-1)
    return metric.compute(predictions=preds, references=labels)
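
# -- Hedged companion sketch (assumed, not from the original script): the argmax step that compute_metrics
# above expects to have happened already; it would be passed as Trainer(preprocess_logits_for_metrics=...).
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):  # some models return extra tensors alongside the logits
        logits = logits[0]
    return logits.argmax(dim=-1)  # keep only predicted token ids to save memory during eval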

def whole_eval(model, 
         path, 
         name, 
         split, 
         tokenizer, 
         block_size,
         output_dir,
         max_eval_samples: int = 1028, 
         streaming: bool = True,
         ):
    """
    path, name, split = 'suolyer/pile_openwebtext2', None, 'validation'  # the one sudharsan used
    """
    eval_dataset = load_dataset(path, name, streaming=streaming, split=split).with_format("torch") 
    eval_dataset = raw_dataset_2_lm_data(eval_dataset, tokenizer, block_size)
    eval_dataset = eval_dataset.take(max_eval_samples)

    print(f'Saving eval results at: {output_dir=}') # The output directory where the model predictions and checkpoints will be written.
    eval_args = TrainingArguments(output_dir=output_dir, fp16=False, bf16=torch.cuda.get_device_capability(torch.cuda.current_device())[0] >= 8)

    trainer = Trainer(model=model, args=eval_args, train_dataset=None, eval_dataset=eval_dataset)
    metrics = trainer.evaluate()
    try:
        perplexity = math.exp(metrics["eval_loss"])
    except OverflowError:
        perplexity = float("inf")
    metrics["perplexity"] = perplexity
    print(f'Eval metrics: {metrics=}')
    trainer.log_metrics("eval", metrics)  # display metrics
    trainer.save_metrics("eval", metrics)
    return metrics

def eval_hf(trainer):
    metrics = trainer.evaluate()
    try:
        perplexity = math.exp(metrics["eval_loss"])
    except OverflowError:
        perplexity = float("inf")
    metrics["perplexity"] = perplexity
    print(f'Eval metrics: {metrics=}')
    trainer.log_metrics("eval", metrics)  # display metrics
    trainer.save_metrics("eval", metrics)
    return metrics

def eval_hf_with_subsample(path, name, split, model, tokenizer, block_size, output_dir, max_eval_samples: int = 1024, streaming: bool = True, print_str: Optional[str] = None):
    eval_dataset = load_dataset(path, name, streaming=streaming, split=split).with_format("torch") 
    eval_dataset2 = raw_dataset_2_lm_data(eval_dataset, tokenizer, block_size)
    if max_eval_samples is None:
        eval_batch2 = eval_dataset2 
    else:
        eval_batch2 = eval_dataset2.take(max_eval_samples)
    print(f'Saving eval results at: {output_dir=}') # The output directory where the model predictions and checkpoints will be written.
    eval_args = TrainingArguments(output_dir=output_dir, fp16=False, bf16=torch.cuda.get_device_capability(torch.cuda.current_device())[0] >= 8)
    trainer = Trainer(model=model, args=eval_args, train_dataset=None, eval_dataset=eval_batch2)
    metrics = eval_hf(trainer)
    if print_str is not None:
        print(print_str)
    return metrics

# -- unit tests -- #

def _test_all_batches_are_size_block_size():
    print('-- starting unit test')
    batch_size = 4
    # get gpt2 tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenize_function = lambda examples: tokenizer(examples["text"])
    # load c4 data set hf in streaming mode 
    from datasets import load_dataset
    streaming = True
    # raw_datasets = load_dataset("c4", "en", streaming=streaming, split="train")
    # raw_datasets = load_dataset('UDACA/PileSubsets', streaming=streaming).with_format('torch')   # this defaults to the subset 'all'
    # raw_datasets = load_dataset('UDACA/PileSubsets', 'pubmed', split='train', streaming=streaming).with_format('torch')
    raw_datasets = load_dataset('UDACA/PileSubsets', 'uspto', split='train', streaming=streaming).with_format('torch')
    batch = get_data_from_hf_dataset(raw_datasets, streaming=streaming, batch_size=batch_size) 
    # print(f'{batch=}')
    # print(f'{next(iter(batch))=}')
    # print(f'{next(iter(batch)).keys()}')
    # print()
    remove_columns = get_column_names(raw_datasets)  # remove all keys that are not tensors to avoid bugs in collate function in task2vec's pytorch data loader

    # how does it know which column to tokenize? gpt4 says default is text or your tokenized function can specify it, see my lambda fun above
    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,  # Setting `batched=True` in the `dataset.map` function of Hugging Face's datasets library processes the data in batches rather than one item at a time, significantly speeding up the tokenization and preprocessing steps.
        remove_columns=remove_columns,
    )
    batch = get_data_from_hf_dataset(tokenized_datasets, streaming=streaming, batch_size=batch_size)
    # print(f'{batch=}')
    # print(f'{next(iter(batch))=}')
    # print(f'{next(iter(batch)).keys()}')
    # print()

    _group_texts = lambda examples : group_texts(examples, block_size=tokenizer.model_max_length)
    lm_datasets = tokenized_datasets.map(
        _group_texts,
        batched=True,  # Setting `batched=True` in the `dataset.map` function of Hugging Face's datasets library processes the data in batches rather than one item at a time, significantly speeding up the tokenization and preprocessing steps.
    )
    batch = get_data_from_hf_dataset(lm_datasets, streaming=streaming, batch_size=batch_size)
    # print(f'{batch=}')
    # print(f'{next(iter(batch))=}')
    # print(f'{next(iter(batch)).keys()}')
    # print()

    # - Make sure all seq are of length block_size
    batch = get_data_from_hf_dataset(lm_datasets, streaming=streaming, batch_size=batch_size)
    for data_dict in iter(batch):
        seq = data_dict['input_ids']
        print(len(seq))
    print('Success!')

def _test_train_dataset_setup_for_main_code():
    import os
    batch_size = 2
    streaming = True
    # path, name, data_files, split = ['c4'], ['en'], [None], ['train']
    # path, name, data_files, split = ['c4', 'c4'], ['en', 'en'], [None, None], ['train', 'validation']
    # path, name, data_files, split = ['csv'], [None], [os.path.expanduser('~/data/maf_data/maf_textbooks_csv_v1/train.csv')], ['train']
    # path, name, data_files, split = ['suolyer/pile_pile-cc'] + ['parquet'] * 4, [None] + ['hacker_news', 'nih_exporter', 'pubmed', 'uspto'], [None] + [urls_hacker_news, urls_nih_exporter, urls_pubmed, urls_uspto], ['validation'] + ['train'] * 4
    # path, name, data_files, split = ['UDACA/PileSubsets'], ['uspto'], [None], ['train']
    # path, name, data_files, split = ['UDACA/PileSubsets'], ['pubmed'], [None], ['train']
    path, name, data_files, split = ['UDACA/PileSubsets', 'UDACA/PileSubsets'], ['uspto', 'pubmed'], [None, None], ['train', 'train']

    # -- Get tokenizer and model
    # tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', padding_side="right", use_fast=False, trust_remote_code=True, use_auth_token=True)
    tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1', padding_side="right", use_fast=False, trust_remote_code=True, use_auth_token=True)
    # tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenize_function = lambda examples: tokenizer(examples["text"])
    # torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability(torch.cuda.current_device())[0] >= 8 else torch.float32  # if >= 8 ==> brain float 16 available or set to True if you always want fp32 
    # model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf', trust_remote_code=True, torch_dtype=torch_dtype, use_auth_token=True)

    # -- Get raw train data set
    raw_train_datasets = [load_dataset(p, n, data_files=data_file, streaming=streaming, split=split).with_format("torch") for p, n, data_file, split in zip(path, name, data_files, split)]
    probabilities = [1.0/len(raw_train_datasets) for _ in raw_train_datasets]  
    raw_train_datasets = interleave_datasets(raw_train_datasets, probabilities)
    # raw_train_datasets = load_dataset(path[0], name[0], data_files=data_files[0], streaming=streaming, split=split[0]).with_format("torch")
    # raw_train_datasets = load_dataset('UDACA/PileSubsets', 'uspto', split='train', streaming=streaming).with_format('torch')
    batch = get_data_from_hf_dataset(raw_train_datasets, streaming=streaming, batch_size=batch_size) 
    print(f'{batch=}')
    print(f'{next(iter(batch))=}')
    print(f'{next(iter(batch)).keys()}')
    print()
    remove_columns = get_column_names(raw_train_datasets)  # remove all keys that are not tensors to avoid bugs in collate function in task2vec's pytorch data loader
    
    # - Get tokenized train data set
    # Note: Setting `batched=True` in the `dataset.map` function of Hugging Face's datasets library processes the data in batches rather than one item at a time, significantly speeding up the tokenization and preprocessing steps.
    tokenized_train_datasets = raw_train_datasets.map(tokenize_function, batched=True, remove_columns=remove_columns)
    # block_size: int = tokenizer.model_max_length
    block_size: int = 4096
    assert block_size != 1000000000000000019884624838656, f'Error, block_size is {block_size} which is the default value. This is likely because you are using a tokenizer that does not have a model_max_length attribute. Please set block_size to a value that makes sense for your model.'
    _group_texts = lambda examples : group_texts_v2(examples, block_size)
    batch = get_data_from_hf_dataset(tokenized_train_datasets, streaming=streaming, batch_size=batch_size) 
    print(f'{batch=}')
    print(f'{next(iter(batch))=}')
    print(f'{next(iter(batch)).keys()}')
    
    # - Get data set for lm training (in this case each seq is of length block_size, no need to worry about pad = eos since we are filling each sequence)
    lm_train_dataset = tokenized_train_datasets.map(_group_texts, batched=True)
    batch = get_data_from_hf_dataset(lm_train_dataset, streaming=streaming, batch_size=batch_size)
    print(f'{batch=}')
    # - get an example for debugging
    # batch = iter(batch)
    # example = next(batch)
    # print(f'{example=}')
    # print(f'{next(iter(batch))=}')
    print(f'{next(iter(batch)).keys()}')
    
    # - Make sure all seq are of length block_size
    batch = get_data_from_hf_dataset(lm_train_dataset, streaming=streaming, batch_size=batch_size)
    for data_dict in iter(batch):
        seq = data_dict['input_ids']
        print(len(seq))
    print('Success!')

def _test_expt_planning():
    # -- 2.5B tokens
    num_tokens_desired: int = int(2.5e9)
    batch_size = 32
    num_batches = 1
    L = 4096
    # num_tokens_trained = max_steps * batch_size * L * num_batches
    max_steps = num_tokens_desired / (batch_size * L * num_batches)
    print(f'{max_steps=}')
    # 19_073
    # 281:07:43 --> 11 days ...

    # -- 5.5M tokens
    num_tokens_desired: int = int(5.5e6)
    max_steps = num_tokens_desired / (batch_size * L * num_batches)
    print(f'{max_steps=}')
    # 42

def _test_utils_padding_and_eos():
    # GPT2 tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    block_size = 1024
    print(f'{tokenizer.model_max_length=}')
    print(f'{block_size=}')
    raw_dataset = load_dataset("c4", "en", streaming=True, split="train").with_format("torch")
    lm_dataset = raw_dataset_2_lm_data(raw_dataset, tokenizer, block_size=block_size)
    # take a batch of size 2 and print it
    batch = get_data_from_hf_dataset(lm_dataset, streaming=True, batch_size=2) 
    print(f'{batch=}')
    data_batch = next(iter(batch))
    # todo: test that when length changes attention mask labels etc make sense
    # todo: do we need to put eos & padding and make sure label = -1? 
    print()

if __name__ == "__main__":
    from time import time
    start_time = time()
    # _test_all_batches_are_size_block_size()
    # _test_train_dataset_setup_for_main_code()
    # _test_expt_planning()
    _test_utils_padding_and_eos()
    print(f"Done!\a Total time: {time() - start_time} seconds, or {(time() - start_time)/60} minutes. or {(time() - start_time)/60/60} hours.\a")