What is the official way to run a wandb sweep with Hugging Face (HF) transformers?

Initially I wanted to write a Hugging Face run such that, if the user wanted to run a sweep, they could (merging the sweep config with the command line arguments given), or else just execute the run with the command line arguments alone. The merging is so that the train script uses a single args object (e.g. tuple[DataClass, …]) to execute its run. This would mean merging the arguments from the sweep with those from the command line. But then I realized that if the user wanted to call wandb.init in a custom way through the arguments, one couldn't do the standard run = wandb.init() with no arguments that is common for sweeps, since the wandb config usually specifies this fully. So I'd need two wandb.init() calls. Then the code got ugly and confusing, and I realized that perhaps running only from the cmd arguments or only from the sweep, separately, is best. And that made me wonder: how do people actually use wandb sweeps officially with Hugging Face?

So what is an example demo of how to run wandb sweeps with Hugging Face transformers? At some point the wandb config and the run arguments have to merge so that the HF run executes correctly. And I assume report_to='wandb' is needed for the trainer to call wandb.init() properly (or one needs to call it manually).
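
For concreteness, here is the "standard" pattern I mean (a minimal sketch from my understanding of the wandb docs; build_model_and_data, the project name, and the hyperparameter values are made up). As far as I understand, with report_to='wandb' the HF Trainer's wandb callback reuses the already-active run created by the agent rather than starting a new one:

import wandb
from transformers import Trainer, TrainingArguments

def build_model_and_data():
    # placeholder: construct the model and tokenized dataset here (omitted)
    raise NotImplementedError

def train():
    # sweeps expect a no-arg wandb.init(); the agent injects the sampled hyperparameters
    run = wandb.init()
    model, train_dataset = build_model_and_data()
    training_args = TrainingArguments(
        output_dir='./out',
        learning_rate=wandb.config.learning_rate,  # <- the merge point: sweep config -> HF args
        report_to='wandb',  # so the trainer logs to the active run
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
    trainer.train()
    run.finish()

sweep_config = {
    'method': 'random',
    'metric': {'name': 'train_loss', 'goal': 'minimize'},
    'parameters': {'learning_rate': {'values': [1e-5, 3e-5, 5e-5]}},
}
sweep_id = wandb.sweep(sweep_config, project='my_project')
wandb.agent(sweep_id, function=train, count=5)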


Pseudo Python

def exec_train(args: tuple):
    """
    notes:
        - decided against a named obj to simplify the code, i.e. I didn't know how to have the code write the
        variables model_args, data_args, training_args, general_args on its own. Would Namespace(**tup) work? I don't
        want to do d['x'] = x manually. I don't think automatically naming objects is possible in python:
        https://chat.openai.com/share/b1d58369-ce27-4ee3-a588-daf28137f774 (better reference maybe some day).
        See also the SimpleNamespace note after this function.
        - separates the wandb setup logic from the actual training code a little, for cleaner (easier to reason about) code.
        - passes the run var just in case it's needed.
    """
    model_args, data_args, training_args, general_args = args
    print(training_args.report_to)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    special_tokens_dict = get_special_tokens_dict() 

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
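
# Re the docstring question above: Namespace(**tup) won't work on a bare tuple, but
# unpacking a dict of named objects does give attribute access without manual d['x'] = x:
#   from types import SimpleNamespace
#   d = dict(model_args=model_args, data_args=data_args, training_args=training_args)
#   args = SimpleNamespace(**d)  # then args.model_args, args.data_args, ... work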

def train(args: tuple):
    """
    Runs train but separates the wandb setup from the actual training code.
    """
    # - init wandb run
    run = wandb.init()
    print(f'{wandb.get_sweep_url()=}')
    # - exec run
    # args[3].run = run  # just in case the GeneralArguments has a pointer to run. Decided against this to avoid multiple pointers to the same object.
    exec_train(args)
    # - finish wandb
    run.finish()
    
def exec_run_from_sweep():
    """ Run a standard sweep.

    Lives in uutils since this is standard code. (You can write optional expansions in your private repo.)
    """
    # -- 1. Define the sweep configuration in a YAML file and load it in Python as a dict.
    path2sweep_config = '~/ultimate-utils/tutorials_for_myself/my_wandb_uu/my_wandb_sweeps_uu/sweep_in_python_yaml_config/sweep_config.yaml'
    config_path = Path(path2sweep_config).expanduser()
    with open(config_path, 'r') as file:
        sweep_config = yaml.safe_load(file)
    # -- 2. Initialize the sweep in Python, which creates it on your project/entity on the wandb platform, and get the sweep_id.
    sweep_id = wandb.sweep(sweep_config, entity=sweep_config['entity'], project=sweep_config['project'])
    # -- 3. Finally, once the sweep_id is acquired, execute the sweep using the desired number of agents in python.
    print(f"Sweep URL: https://wandb.ai/{sweep_config['entity']}/{sweep_config['project']}/sweeps/{sweep_id}")
    # note: the agent calls function() with no arguments, so pre-bind any args (e.g. functools.partial(train, args)).
    wandb.agent(sweep_id, function=train, count=5)
    
def exec_run_from_cmd_args_or_sweep():
    """
    Simply execs a run either from a wandb sweep file or from the command line arguments. Ignore the wandb sweep
    details if they confuse you.
    """
    # 1. parse all the arguments from the command line
    parser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments, GeneralArguments))
    model_args, data_args, training_args, general_args = parser.parse_args_into_dataclasses()  # default is to parse sys.argv
    # 2. if the wandb_config option is on, then overwrite the cmd line run configuration in favor of the sweep_config.
    if general_args.path2sweep_config:  # None => False => not getting wandb_config
        # overwrite run configuration with the wandb_config configuration (get config and create new args)
        config_path = Path(general_args.path2sweep_config).expanduser()
        with open(config_path, 'r') as file:
            sweep_config = dict(yaml.safe_load(file))
        sweep_args: list[str] = [item for pair in [[f'--{k}', str(v)] for k, v in sweep_config.items()] for item in pair]
        model_args, data_args, training_args, general_args = parser.parse_args_into_dataclasses(args=sweep_args)
        args: tuple = (model_args, data_args, training_args, general_args)  # decided against named obj to simplify code
        # 3. execute run from sweep
        # Initialize the sweep in Python, which creates it on your project/entity on the wandb platform, and get the sweep_id.
        sweep_id = wandb.sweep(sweep_config, entity=sweep_config['entity'], project=sweep_config['project'])
        # Finally, once the sweep_id is acquired, execute the sweep using the desired number of agents in python.
        train_with_args = lambda: train(args)  # pre-bind args, i.e. calling train_with_args() calls train(args)
        wandb.agent(sweep_id, function=train_with_args, count=general_args.count)
        print(f"Sweep URL: https://wandb.ai/{sweep_config['entity']}/{sweep_config['project']}/sweeps/{sweep_id}")
    else:
        # use the args from the command line (already parsed above)
        args: tuple = (model_args, data_args, training_args, general_args)  # decided against named obj to simplify code
        # 3. execute run
        # train(args)
    return args
    

if __name__ == '__main__':
    import time
    start_time = time.time()
    exec_run_from_cmd_args_or_sweep()
    print(f"The main function executed in {time.time() - start_time} seconds.\a")

Some Notes

Wandb sweeps, current thoughts:
Major Assumption: wandb.config comes from a .yaml file that has a specific structure that doesn't change (since the website needs this structure to set up the UI correctly).

  • soln1: have a ScriptArguments dataclass with the same structure as wandb.config and merge it. The merging still needs to respect the wandb structure and the custom HF args structure.
    • this is under the assumption that wandb.config has a specific structure that doesn't change
  • soln2: loop through the wandb.config (dict) and create strings that look like sys.argv arguments (--{name} {value}),
    then have the HF argparser parse them and join with the previous structure (mdl, data, train) we specified for the
    args in the code:
    run = wandb.init()
    wandb.get_sweep_url()
    sweep_config = run.config
    # might need to change a little bit to respect the wandb_config structure
    args: list[str] = [item for pair in [[f'--{k}', str(v)] for k, v in sweep_config.items()] for item in pair]
    parser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=args)
    # make sure the 3 or X args have the fields from the wandb_config
  • this is under the assumption that wandb.config has a specific structure that doesn't change
  • I'm also assuming that parser.parse_args_into_dataclasses(args) will do the recursive matching of names I want
  • soln3: recursively loop through the args generated from the HF parser and replace the values with the ones from wandb.config (see the sketch after this list)
    • this is under the assumption that wandb.config has a specific structure that doesn't change
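    A sketch of soln3, assuming a flat wandb.config (the helper name is hypothetical; pass dict(wandb.config) for the config):
    import dataclasses

    def overwrite_args_with_wandb_config(args: tuple, config: dict) -> tuple:
        # for each HF dataclass, overwrite any field that also appears in the (flat) sweep config
        for args_obj in args:
            for field in dataclasses.fields(args_obj):
                if field.name in config:
                    setattr(args_obj, field.name, config[field.name])
        return args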

Decision is to keep it simple. Ideally we give a flag that says to either

  1. use the given arguments to the python cmd, or
  2. use the wandb_config.

I guess the easiest thing would be to do this:
→ Key Decision: if an arg says to use the config, then overwrite the args using the config; else don't use the config.

Current attempt

from pathlib import Path
from typing import Callable, Optional

import wandb
import yaml


def get_sweep_config(path2sweep_config: str) -> dict:
    """ Get sweep config from path """
    config_path = Path(path2sweep_config).expanduser()
    with open(config_path, 'r') as file:
        sweep_config = yaml.safe_load(file)
    return sweep_config
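
# For reference, after yaml.safe_load the sweep config dict might look like this
# (hypothetical values; method/metric/parameters/entity/project/run_cap are standard
# wandb sweep keys, while mode and report_to are custom keys this script reads itself):
#   {'entity': 'my_entity', 'project': 'my_project', 'method': 'random', 'run_cap': 5,
#    'metric': {'name': 'train_loss', 'goal': 'minimize'},
#    'parameters': {'lr': {'values': [1e-4, 1e-3]}, 'num_its': {'values': [2, 3]}},
#    'mode': 'online', 'report_to': 'wandb'}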


def wandb_sweep_config_2_sys_argv_args_str(config: dict) -> list[str]:
    """Make a sweep config into a string of args the way they are given in the terminal.
    Replaces sys.argv list of strings "--{arg_name} str(v)" with the arg vals from the config.
    This is so that the input to the train script is still an HF argument tuple object (as if it was called from
    the terminal) but overwrites it with the args/opts given from the sweep config file.
    """
    args: list[str] = [item for pair in [[f'--{arg_name}', str(v)] for arg_name, v in config.items()] for item in pair]
    return args
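
# Sanity check with made-up values:
#   wandb_sweep_config_2_sys_argv_args_str({'lr': 0.001, 'num_its': 2})
#   -> ['--lr', '0.001', '--num_its', '2']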


def exec_run_for_wandb_sweep(path2sweep_config: str,
                             function: Callable,
                             pass_sweep_id: bool = False
                             ) -> None:  # maybe should return the sweep_id (str)? not sure: https://chat.openai.com/share/4ef4748c-1796-4c5f-a4b7-be39dfb33cc4
    """
    Run a standard sweep from a config file. Given a correctly set up train func., it will run a sweep in the standard way.
    Note: if entity and project are None, then wandb might try to infer them and the call might fail. If you want
    a debug mode, set wandb.init(mode='dryrun'); else, to log to the wandb platform, use 'online' (ref: https://chat.openai.com/share/c5f26f70-37be-4143-95f9-408c92c59669, unverified).
    You need to code the mode in your train file correctly yourself (e.g. pre-bind args via functools.partial(train, args)), or put mode in
    the wandb_config; but note that mode is given to init, so you'd need to read that field from a file and not from
    wandb.config (since you haven't initialized wandb yet).

    e.g.
        path2sweep_config = '~/ultimate-utils/tutorials_for_myself/my_wandb_uu/my_wandb_sweeps_uu/sweep_in_python_yaml_config/sweep_config.yaml'

    Important remark:
        - run = wandb.init() and run.finish() is run inside the train function.
    """
    # -- 1. Define the sweep configuration in a YAML file and load it in Python as a dict.
    sweep_config: dict = get_sweep_config(path2sweep_config)

    # -- 2. Initialize the sweep in Python, which creates it on your project/entity on the wandb platform, and get the sweep_id.
    sweep_id = wandb.sweep(sweep_config, entity=sweep_config.get('entity'), project=sweep_config.get('project'))
    print(f"Sweep URL: https://wandb.ai/{sweep_config.get('entity')}/{sweep_config.get('project')}/sweeps/{sweep_id}")

    # -- 3. Finally, once the sweep_id is acquired, execute the sweep using the desired number of agents in python.
    if pass_sweep_id:
        _function = function
        function = lambda: _function(sweep_id)  # pre-bind sweep_id under a new name (rebinding `function` itself would make the lambda call itself)
    wandb.agent(sweep_id, function=function,
                count=sweep_config.get('run_cap'))  # train does wandb.init() & run.finish()
    # return sweep_id  # not sure if I should be returning this
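
# Note: functools.partial is an equivalent, arguably cleaner way to do the pre-binding
# above, since it captures the callable at creation time:
#   from functools import partial
#   function = partial(function, sweep_id)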


def setup_and_run_train(parser,
                        mode: str,
                        train: Callable,
                        sweep_id: Optional[str] = None,
                        ):
    # if sweep get args from wandb.config else use cmd args (e.g. default args)
    if sweep_id is None:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()  # default args is to parse sys.argv
        run = wandb.init(mode=mode)
        train(args=(model_args, data_args, training_args), run=run)
    else:  # run sweep
        assert mode == 'online'
        run = wandb.init(mode=mode)
        # print(f'{wandb.get_sweep_url()=}')
        sweep_config = wandb.config
        args_strs: list[str] = wandb_sweep_config_2_sys_argv_args_str(sweep_config)
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=args_strs)
        train(args=(model_args, data_args, training_args), run=run)


# - examples & tests

def train_demo(args: tuple, run):
    import torch

    # usually here in the wandb demos
    # # Initialize a new wandb run
    # run = wandb.init(mode=mode)
    # # print(f'{wandb.get_sweep_url()=}')

    # unpack args
    model_args, data_args, training_args = args

    # unpack args/config
    num_its = training_args.num_its
    lr = training_args.lr

    # Simulate the training process
    train_loss = 8.0 + torch.rand(1).item()
    for i in range(num_its):
        update_step = lr * torch.rand(1).item()
        train_loss -= update_step
        wandb.log({"lr": lr, "train_loss": train_loss})

    # Finish the current run
    run.finish()
    
def main_example_run_train_debug_sweep_mode_for_hf_trainer(train: Callable = train_demo):
    """
    idea:
    - get path2sweep_config from the argparse args.
    - decide if it's debug or not from report_to.

    if report_to == "none" => mode='dryrun', entity & project are None, call agent(..., count=1)
    if report_to == "wandb" => mode='online', set entity, project from the config file, call agent(..., count=run_cap)

    (HF training args report_to, wandb.init mode) combinations:
    Yes, makes sense:
        ("none", "disabled")  == debug, no wandb
        ("wandb", "dryrun")   == debug & test the wandb logging locally
        ("wandb", "online")   == usually means run a real expt and log to the wandb platform
    No, doesn't make sense:
        ("none", "dryrun")    == no issue, but it won't log to wandb locally anyway since the hf trainer wasn't instructed to do so
    """
    from transformers import HfArgumentParser
    from uutils.hf_uu.hf_argparse.falcon_uu import ModelArguments, DataArguments, TrainingArguments

    # - run sweep or debug
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    path2sweep_config: str = training_args.path2sweep_config
    sweep_config: dict = get_sweep_config(path2sweep_config)

    # note: these if stmts could've just been done with the report_to hf train args opt.
    mode, report_to = sweep_config.get('mode'), sweep_config.get('report_to')
    if mode == 'online':
        # run a standard sweep; setup_and_run_train makes sure wandb.config is merged into args correctly
        assert report_to == 'wandb'
        run_train = lambda sweep_id: setup_and_run_train(parser, mode, train, sweep_id)  # don't rebind setup_and_run_train itself, or the lambda recurses
        exec_run_for_wandb_sweep(path2sweep_config, function=run_train, pass_sweep_id=True)
    elif mode == 'dryrun':
        raise ValueError('dryrun for the hf trainer is not needed, since whether the wandb logging works is already tested')
    elif mode == 'disabled':
        assert report_to == 'none'
        setup_and_run_train(parser, mode, train)  # sweep_id=None => use the cmd line args


if __name__ == '__main__':
    import time

    start_time = time.time()
    main_example_run_train_debug_sweep_mode_for_hf_trainer()
    print(f"The main function executed in {time.time() - start_time} seconds.\a")

refs:

I would be interested to know about this too. @brando did you get any response for this?

maybe this is the official doc for it: Hyperparameter Search using Trainer API? @brando