Trainer API Error when Hyperparameter Tuning with Custom Loss Function

I have managed to fine-tune a sequence classification model with a custom loss function on Google Colab, but I’m running into an error instantiating the Trainer when trying to perform a hyperparameter search using optuna.

# Dealing with Class Imbalance
class_weights = (1 - (df["labels"].value_counts().sort_index() / len(df))).values
class_weights = torch.from_numpy(class_weights).float().to("cuda") # Change to "cpu" if not GPU enabled

# Defining the Trainer to compute Custom Loss Function
class WeightedLossTrainer(Trainer):
  def compute_loss(self, model, inputs, return_outputs=False):
    # Feed inputs to model and extract logits
    outputs = model(**inputs)
    logits = outputs.get("logits")
    # Extract Labels
    labels = inputs.get("labels")
    # Define loss function with class weights
    loss_func = torch.nn.CrossEntropyLoss(weight=class_weights)
    # Compute loss
    loss = loss_func(logits, labels)
    return (loss, outputs) if return_outputs else loss

# Instantiating the Model
def model_init():
    return AutoModelForSequenceClassification.from_pretrained(checkpoint,
                                                              num_labels=len(labels),
                                                              id2label=id2label,
                                                              label2id=label2id,
                                                              ignored_mismatched_sizes=True
                                                              )

# Defining the Metrics Function
def compute_metrics(pred):
  labels = pred.label_ids
  preds = pred.predictions.argmax(-1)
  f1 = f1_score(labels, preds, average="weighted")
  return {"f1": f1}

training_args = TrainingArguments(output_dir=output_path,
                                  num_train_epochs=5,
                                  learning_rate=2e-5,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  weight_decay=0.01,
                                  evaluation_strategy="epoch",
                                  logging_steps=logging_steps,
                                  fp16=True,    # Mixed Precision to train faster
                                  push_to_hub=False
                                  )

trainer = WeightedLossTrainer(model=model_init,
                  args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=dataset["train"],
                  eval_dataset=dataset["validation"],
                  tokenizer=tokenizer
                  )

I am receiving the follow error…

    534     def _move_model_to_device(self, model, device):
--> 535         model = model.to(device)
    536         # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
    537         if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):

AttributeError: 'function' object has no attribute 'to'

Any insight would be much appreciated.