Error converting a Transformers model to ONNX with HF Optimum

Hi everyone, I am still new to ONNX and would like to convert a vanilla Transformers model to ONNX to unlock better runtime and memory efficiency on GPU.

I am trying to use the HF Optimum library to do the high-level conversion, but I get the following error:

ValueError: Required inputs (['position_ids']) are missing from input feed (['input_ids', 'attention_mask']).

To reproduce the error, here is the code I used:

from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer
from pathlib import Path
import torch


# Hub checkpoint to export, and the local directory that receives the ONNX artifacts.
model_name = "Muennighoff/SGPT-125M-weightedmean-msmarco-specb-bitfit"
onnx_path = Path("onnx")

# Load the checkpoint; `export=True` asks Optimum to convert it to ONNX while loading.
model = ORTModelForFeatureExtraction.from_pretrained(model_name, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Persist the ONNX model and its tokenizer side by side so they can be reloaded together.
for artifact in (model, tokenizer):
    artifact.save_pretrained(onnx_path)

# Customize embedding pipelines

from transformers import Pipeline
import torch.nn.functional as F
import torch

# Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dim  
def mean_pooling(model_output, weights, attention_mask):
  token_embeddings = model_output[0] #First element of model_output contains all token embeddings
  input_mask_expanded = attention_mask
  return torch.sum(token_embeddings * input_mask_expanded * weights, 1) / torch.sum(input_mask_expanded * weights, dim=1)


class SentenceEmbeddingPipeline(Pipeline):
    """Pipeline producing one SGPT-style weighted-mean sentence embedding per input.

    Each input text is wrapped in bracket tokens (``[`` ... ``]`` for queries,
    ``{`` ... ``}`` for documents) and run through ``self.model``. The last
    hidden state is then pooled with weights ``1..seq_len`` along the sequence
    axis, restricted to attended tokens.

    Pass ``is_query=True`` when calling the pipeline to embed queries; the
    default embeds documents.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route the optional ``is_query`` call argument to :meth:`preprocess`.

        Returns the ``(preprocess, forward, postprocess)`` kwargs triple expected
        by the ``Pipeline`` base class; only ``preprocess`` takes a parameter.
        """
        preprocess_kwargs = {}
        if "is_query" in kwargs:
            preprocess_kwargs["is_query"] = kwargs["is_query"]
        return preprocess_kwargs, {}, {}

    def _bracket_token_ids(self, is_query):
        """Return the (opening, closing) bracket token ids for a query or a document."""
        open_char, close_char = ("[", "]") if is_query else ("{", "}")
        return (
            self.tokenizer.encode(open_char, add_special_tokens=False)[0],
            self.tokenizer.encode(close_char, add_special_tokens=False)[0],
        )

    def _model_input_names(self):
        """Return the input names the wrapped model declares, or an empty tuple.

        NOTE(review): assumes Optimum's ORTModel exposes its ONNX session inputs
        as ``input_names``; plain PyTorch models have no such attribute and get
        an empty tuple. Confirm against the installed Optimum version.
        """
        return getattr(self.model, "input_names", None) or ()

    def preprocess(self, inputs, is_query=False):
        """Tokenize one text, add the bracket tokens and build the model tensors.

        Args:
            inputs: a single text (the Pipeline base calls this once per item).
            is_query: wrap with ``[``/``]`` when True, with ``{``/``}`` otherwise.

        Returns:
            The padded encoding with ``input_ids`` / ``attention_mask`` of shape
            ``[1, seq_len]``. ``position_ids`` is added when the model declares
            that input.
        """
        # Use the pipeline's own tokenizer, not a module-level global.
        bos_id, eos_id = self._bracket_token_ids(is_query)

        # Tokenize without padding.
        inputs_tokens = self.tokenizer(inputs, padding=False, max_length=2000, truncation=True)

        # Surround the text with the bracket tokens and make them attended.
        inputs_tokens["input_ids"].insert(0, bos_id)
        inputs_tokens["input_ids"].append(eos_id)
        inputs_tokens["attention_mask"].insert(0, 1)
        inputs_tokens["attention_mask"].append(1)

        # Add padding / convert to tensors.
        batch_tokens = self.tokenizer.pad(inputs_tokens, padding=True, return_tensors="pt")

        # Ensure a leading batch dimension of size 1 (a 1-D or [1, seq_len] tensor
        # both end up as [1, seq_len]).
        batch_tokens["input_ids"] = batch_tokens["input_ids"].reshape(1, -1)
        batch_tokens["attention_mask"] = batch_tokens["attention_mask"].reshape(1, -1)

        # The exported ONNX graph for this checkpoint requires `position_ids`
        # (the "Required inputs (['position_ids']) are missing" error), so supply
        # them when the model declares that input. Positions count attended
        # tokens from 0; padded slots get a dummy position of 1.
        # NOTE(review): this only helps if the installed Optimum version forwards
        # `position_ids` to the ONNX session. If it does not,
        # `ORTModelForCustomTasks` passes every input through; confirm.
        if "position_ids" in self._model_input_names():
            mask = batch_tokens["attention_mask"]
            position_ids = mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(mask == 0, 1)
            batch_tokens["position_ids"] = position_ids

        return batch_tokens

    def _forward(self, model_inputs):
        """Run the model and return its last hidden state plus pooling weights and mask.

        Returns a dict with:
            ``"outputs"``: last hidden state of shape ``[bs, seq_len, hid_dim]``.
            ``"weights"``: float weights ``1..seq_len`` along the sequence axis,
                expanded to that shape.
            ``"attention_mask"``: the attention mask as floats, expanded to that shape.
        """
        with torch.no_grad():
            last_hidden_state = self.model(
                **model_inputs, output_hidden_states=True, return_dict=True
            ).last_hidden_state

        shape = last_hidden_state.size()
        device = last_hidden_state.device

        # Later tokens get linearly larger weights: token i has weight i + 1.
        weights = (
            torch.arange(start=1, end=shape[1] + 1)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(shape)
            .float()
            .to(device)
        )

        # Keep the mask on the same device as the hidden states and weights.
        input_mask_expanded = (
            model_inputs["attention_mask"]
            .unsqueeze(-1)
            .expand(shape)
            .float()
            .to(device)
        )
        return {"outputs": last_hidden_state, "weights": weights, "attention_mask": input_mask_expanded}

    def postprocess(self, model_outputs):
        """Pool the forward outputs into sentence embeddings of shape ``[bs, hid_dim]``."""
        return mean_pooling(
            model_outputs["outputs"], model_outputs["weights"], model_outputs["attention_mask"]
        )

I tried looking this up but could not find much information on this error. If anyone could look into it, that would be very helpful. Cheers!