Unable to run Optuna hyperparam search

I’m using this simple script, using the example blog post. However, it fails because of wandb. It was of no use to make wandb as OFFLINE as well.

from datasets import load_dataset, load_metric
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)
import wandb


wandb.init()

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
dataset = load_dataset('glue', 'mrpc')
metric = load_metric('glue', 'mrpc')

def encode(examples):
    outputs = tokenizer(
        examples['sentence1'], examples['sentence2'], truncation=True)
    return outputs

encoded_dataset = dataset.map(encode, batched=True)

def model_init():
    return AutoModelForSequenceClassification.from_pretrained(
        'distilbert-base-uncased', return_dict=True)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = predictions.argmax(axis=-1)
    return metric.compute(predictions=predictions, references=labels)

# Evaluate during training and a bit more often
# than the default to be able to prune bad trials early.
# Disabling tqdm is a matter of preference.
training_args = TrainingArguments(
    "test", eval_steps=500, disable_tqdm=True,
    evaluation_strategy='steps',)

trainer = Trainer(
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    model_init=model_init,
    compute_metrics=compute_metrics,
)

def my_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 0.1, 0.3),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 5, 10),
        "seed": trial.suggest_int("seed", 20, 40),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [32, 64]),
    }


trainer.hyperparameter_search(
    direction="maximize",
    backend="optuna",
    n_trials=10,
    hp_space=my_hp_space
)

Trail 0 finishes successfully, but next Trail 1 crashes with following error:

  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/integrations.py", line 138, in _objective
    trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer.py", line 1376, in train
    self.log(metrics)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer.py", line 1688, in log
    self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer_callback.py", line 371, in on_log
    return self.call_event("on_log", args, state, control, logs=logs)
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/trainer_callback.py", line 378, in call_event
    result = getattr(callback, event)(
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/transformers/integrations.py", line 754, in on_log
    self._wandb.log({**logs, "train/global_step": state.global_step})
  File "/home/user123/anaconda3/envs/iza/lib/python3.8/site-packages/wandb/sdk/lib/preinit.py", line 38, in preinit_wrapper
    raise wandb.Error("You must call wandb.init() before {}()".format(name))
wandb.errors.Error: You must call wandb.init() before wandb.log()

Any help is highly appreciated.