AttributeError: 'ORTTrainingArguments' object has no attribute 'deepspeed_plugin'

I’m getting the following error after switching my trainer to the one from optimum[onnxruntime] in my token-classification app:

AttributeError                            Traceback (most recent call last)
<ipython-input-4-3f32f8ae7125> in <cell line: 184>()
    182 compute_metrics = compute_metrics_wrapper(id2label)
    183 
--> 184 trainer = ORTTrainer(
    185         model=model,
    186         args=training_args,

2 frames
/usr/local/lib/python3.10/dist-packages/optimum/onnxruntime/trainer.py in __init__(self, model, tokenizer, feature, args, data_collator, train_dataset, eval_dataset, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics, onnx_model_path)
    302         onnx_model_path: Union[str, os.PathLike] = None,
    303     ):
--> 304         super().__init__(
    305             model=model,
    306             args=args,

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in __init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
    334         self.is_in_train = False
    335 
--> 336         self.create_accelerator_and_postprocess()
    337 
    338         # memory metrics - must set up as early as possible

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in create_accelerator_and_postprocess(self)
   3805         # create accelerator object
   3806         self.accelerator = Accelerator(
-> 3807             deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin
   3808         )
   3809 

AttributeError: 'ORTTrainingArguments' object has no attribute 'deepspeed_plugin'
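
In case a version mismatch is the culprit: the traceback goes through transformers’ create_accelerator_and_postprocess, which I believe only fairly recent transformers releases have, so my installed optimum may simply predate it. Here is a quick sketch of how I’d check which versions pip actually resolved (I didn’t pin either package; the package list below is just my guess at what’s relevant):

# Check which versions pip resolved; the AttributeError suggests the
# installed optimum release may predate the transformers release that
# introduced Trainer.create_accelerator_and_postprocess().
from importlib.metadata import version, PackageNotFoundError

for pkg in ("transformers", "optimum", "accelerate", "onnxruntime-training"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")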

Below is the part of the code I changed to integrate the ONNX Runtime functionality:

!pip install optimum[onnxruntime]

import torch
from transformers import AutoModelForTokenClassification
from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

model = AutoModelForTokenClassification.from_pretrained(
    checkpoint,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)

training_args = ORTTrainingArguments(
        output_dir=output_path,
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        weight_decay=weight_decay,
        logging_steps=logging_steps,
        report_to="wandb",
        run_name="ml-training-v1",
        evaluation_strategy="epoch",
        save_strategy=save_strategy,
        save_total_limit=1,
        load_best_model_at_end=load_best_model_at_end,
        metric_for_best_model="overall_f1",
        optim="adamw_ort_fused"
    )

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

compute_metrics = compute_metrics_wrapper(id2label)

trainer = ORTTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        feature="token-classification"
    )

trainer.train()

metrics = trainer.evaluate(inference_with_ort=True)
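
The next thing I plan to try is upgrading optimum and transformers together, on the guess that pip resolved an incompatible pair of releases (not sure this is the actual fix):

!pip install --upgrade optimum[onnxruntime] transformers accelerate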

I haven’t been able to find this error mentioned anywhere. Any help is appreciated!