Hi,
I trained a BERT model with a custom classifier head. I tried to load it using ORTModelForSequenceClassification, but it did not load my custom classifier head. It just loaded the default head.
After several hours of digging through the Optimum source code, I was able to get my code to work using this:

```python
from optimum.onnxruntime import ORTModelForSequenceClassification
from optimum.exporters.tasks import TasksManager

# Monkey-patch TasksManager so the exporter resolves my custom model class
# (BertWithCustomHead is my own class, defined elsewhere)
TasksManager.infer_library_from_model = lambda *args, **kwargs: "transformers"
TasksManager.get_model_class_for_task = lambda *args, **kwargs: BertWithCustomHead

onnx_model = ORTModelForSequenceClassification.from_pretrained(
    "./my-classifier", export=True, task="text-classification"
)
```
It gets the job done, but it feels kind of hacky. Does anyone know if there is a better way to do this?