How to tune optimizer hyperparameters with Trainer.hyperparameter_search

Hi, I have a working hyperparameter tuning setup that looks something like the following:

training_args = TrainingArguments(
    output_dir="test_trainer",
    save_strategy="epoch",
    save_total_limit=1,
    evaluation_strategy="steps",
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    lr_scheduler_type="cosine",
    warmup_steps=500,
    optim=optimizer,
)

trainer = Trainer(
    model=None,
    model_init=model_init,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_cls_metrics,
    tokenizer=tokenizer,
)

def ray_hp_space(trial):
    return {
        "learning_rate": tune.loguniform(1e-6, 1e-2),
        "per_device_train_batch_size": tune.choice([16, 32, 64, 128]),
    }

best_trial = trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    hp_space=ray_hp_space,
    n_trials=tune_num_trials,
    compute_objective=lambda metrics: metrics['eval_accuracy'],
    search_alg=HyperOptSearch(metric="objective", mode="max"),
    scheduler=ASHAScheduler(metric="objective", mode="max")
)

This is based on the example code from the docs here. However, I'm noticing that the learning_rate entry in hp_space doesn't appear to be consumed anywhere in the docs example, so it looks like the learning rate is not actually being tuned at all: model_init is the only place the trial parameters are handed over, and model_init doesn't use the learning rate.

Optimization hyperparameters are probably the most commonly tuned parameters, so I'm wondering: how can we actually tune them? Beyond the learning rate, this would include parameters such as weight_decay and the betas for Adam.
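To make concrete what I'm hoping is possible, here is a self-contained sketch of the behavior I would expect: the hp_space keys match TrainingArguments field names (learning_rate, weight_decay, adam_beta1, adam_beta2 all exist as fields), and the search setup would presumably copy each sampled value onto the per-trial args before training starts. The Args dataclass and apply_trial helper below are stand-ins I made up for illustration, not the real transformers/ray API:

```python
import random
from dataclasses import dataclass

@dataclass
class Args:
    # Stand-in for TrainingArguments; field names mirror the real ones.
    learning_rate: float = 5e-5
    weight_decay: float = 0.0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999

def hp_space():
    # Sampled roughly the way tune.loguniform / tune.uniform would sample.
    return {
        "learning_rate": 10 ** random.uniform(-6, -2),
        "weight_decay": 10 ** random.uniform(-4, -1),
        "adam_beta1": random.uniform(0.85, 0.95),
        "adam_beta2": random.uniform(0.99, 0.9999),
    }

def apply_trial(args, trial_params):
    # My guess at what the search setup would have to do: copy each
    # sampled value onto the matching args field before the trial runs.
    for name, value in trial_params.items():
        assert hasattr(args, name), f"unknown hyperparameter: {name}"
        setattr(args, name, value)
    return args

args = apply_trial(Args(), hp_space())
print(args.learning_rate, args.weight_decay)
```

If Trainer.hyperparameter_search already does something like this mapping internally, that would answer my question, but I can't find it stated in the docs.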

Even subclassing Trainer and overriding create_optimizer_and_scheduler() doesn't seem like it would accomplish this, because as far as I can tell create_optimizer_and_scheduler() doesn't have access to the hyperparameters sampled for a given tuning trial.
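For reference, the kind of override I was considering looks roughly like the sketch below. MyTrainer is hypothetical, and StubTrainer/StubArgs are stand-ins for transformers.Trainer and TrainingArguments so the shape of the idea is visible without the real libraries; a real override would construct torch.optim.AdamW instead of the placeholder dict:

```python
class StubArgs:
    # Stand-in for TrainingArguments with its real field names.
    learning_rate = 3e-4
    weight_decay = 0.01
    adam_beta1 = 0.9
    adam_beta2 = 0.999

class StubTrainer:
    # Stand-in for transformers.Trainer.
    def __init__(self, args):
        self.args = args
        self.optimizer = None

class MyTrainer(StubTrainer):
    def create_optimizer(self):
        # The open question: by the time this runs inside a tuning
        # trial, does self.args already reflect that trial's sampled
        # values? If so, reading them here would be enough.
        a = self.args
        self.optimizer = {  # placeholder for torch.optim.AdamW(...)
            "lr": a.learning_rate,
            "betas": (a.adam_beta1, a.adam_beta2),
            "weight_decay": a.weight_decay,
        }
        return self.optimizer

opt = MyTrainer(StubArgs()).create_optimizer()
print(opt["lr"])
```

If self.args is not updated per trial, this override would just keep reusing the original values, which is exactly the problem I'm describing.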

Thanks for any suggestions you can provide! If this functionality exists but isn't currently demonstrated in the documentation, I'd be happy to contribute some doc updates showing how to do it.