RuntimeError: model_init should have 0 or 1 argument

I’m trying to tune hyper-params with the following code:

def my_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 5e-3, 5e-5),
        "arr_gradient_accumulation_steps": trial.suggest_int("num_train_epochs", 8, 16),
        "arr_per_device_train_batch_size": trial.suggest_int(2, 4),        
    }


def get_model(model_name, config):
    return AutoModelForSequenceClassification.from_pretrained(model_name, config=config)

def compute_metric(eval_predictions):
    
    metric         = load_metric('accuracy')    
    logits, labels = eval_predictions
    predictions    = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

training_args   = TrainingArguments(output_dir='test-trainer', 
                                    evaluation_strategy="epoch",
                                    num_train_epochs= 10)
data_collator   = default_data_collator
model_name      = 'sentence-transformers/nli-roberta-base-v2'
config = AutoConfig.from_pretrained(model_name,num_labels=3)

trainer = Trainer(
    model_init      = get_model(model_name, config),
    args            = training_args,
    train_dataset   = tokenized_datasets['TRAIN'],
    eval_dataset    = tokenized_datasets['TEST'],    
    compute_metrics = compute_metric,
    tokenizer       = None,
    data_collator   = data_collator,
)

best = trainer.hyperparameter_search(direction="maximize", hp_space=my_hp_space)

And getting error:

 1173     model = self.model_init(trial)
   1174 else:
-> 1175     raise RuntimeError("model_init should have 0 or 1 argument.")
   1177 if model is None:
   1178     raise RuntimeError("model_init should not return None.")

RuntimeError: model_init should have 0 or 1 argument.
  1. What am I doing wrong ?
  2. How can I fix it and run hyper parameter method and get best model parameters ?

model_init should be a function that takes the 0 or 1 (the trial hyperparameters) argument and returns your model. You are passing it a model.