Hi, I have a working hyperparameter tuning setup that looks something like the following:
from transformers import Trainer, TrainingArguments
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.hyperopt import HyperOptSearch  # ray >= 2.x; older releases expose this under ray.tune.suggest.hyperopt

training_args = TrainingArguments(
    output_dir="test_trainer",
    save_strategy="epoch",
    save_total_limit=1,
    evaluation_strategy="steps",
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    lr_scheduler_type="cosine",
    warmup_steps=500,
    optim=optimizer,
)

trainer = Trainer(
    model=None,
    model_init=model_init,  # a fresh model is constructed for every trial
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_cls_metrics,
    tokenizer=tokenizer,
)

def ray_hp_space(trial):
    return {
        "learning_rate": tune.loguniform(1e-6, 1e-2),
        "per_device_train_batch_size": tune.choice([16, 32, 64, 128]),
    }

best_trial = trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    hp_space=ray_hp_space,
    n_trials=tune_num_trials,
    compute_objective=lambda metrics: metrics["eval_accuracy"],
    search_alg=HyperOptSearch(metric="objective", mode="max"),
    scheduler=ASHAScheduler(metric="objective", mode="max"),
)
This is based on the example code from the docs here. However, the learning_rate entry in the docs' hp_space doesn't appear to be consumed anywhere: the only place the trial parameters are handed to user code is model_init, and model_init never reads the learning rate. So as far as I can tell, this example isn't actually tuning the learning rate at all.
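For reference, my model_init is essentially the one from the docs; the checkpoint name and label count below are just placeholders for my actual setup, and nothing in it ever looks at the sampled learning_rate:

from transformers import AutoModelForSequenceClassification

def model_init(trial):
    # The trial config is passed in, but nothing here reads
    # trial["learning_rate"] -- only the architecture is set up.
    return AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=2
    )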
Optimization hyperparameters are probably the most commonly tuned parameters, so I'm wondering: how can we actually tune the optimizer hyperparameters? Beyond the learning rate, this would include parameters such as weight_decay and the betas for Adam.
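For concreteness, this is roughly the search space I would like to define, assuming (and this is exactly the part I'm unsure about) that each key gets mapped back onto the corresponding TrainingArguments field (adam_beta1/adam_beta2 being the Adam betas) for every trial; the ranges here are just illustrative:

def optimizer_hp_space(trial):
    return {
        "learning_rate": tune.loguniform(1e-6, 1e-2),
        "weight_decay": tune.loguniform(1e-4, 1e-1),
        "adam_beta1": tune.uniform(0.85, 0.95),
        "adam_beta2": tune.uniform(0.99, 0.9999),
        "per_device_train_batch_size": tune.choice([16, 32, 64, 128]),
    }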
Even subclassing Trainer and overriding create_optimizer_and_scheduler() doesn't seem like it would accomplish this, because create_optimizer_and_scheduler() has no access to the hyperparameters sampled for a given trial.
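To illustrate what I mean, here is a rough sketch of the kind of override I was considering (the AdamW choice is only for illustration); as far as I can tell, only self.args and self.model are visible inside it, with no handle on the values Ray sampled for the current trial:

from torch.optim import AdamW
from transformers import Trainer

class OptimizerTuningTrainer(Trainer):
    def create_optimizer_and_scheduler(self, num_training_steps):
        # Everything here comes from self.args, which (as far as I can tell)
        # still holds the values passed at construction time -- there is no
        # `trial` argument exposing the hyperparameters sampled for this run.
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=self.args.learning_rate,
            betas=(self.args.adam_beta1, self.args.adam_beta2),
            weight_decay=self.args.weight_decay,
        )
        self.lr_scheduler = self.create_scheduler(
            num_training_steps=num_training_steps, optimizer=self.optimizer
        )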
Thanks for any suggestions you can provide! If this functionality exists but just isn't shown in the docs, I'd be happy to contribute an example demonstrating it.