I’m encountering an issue when trying to log a `SetFitModel` with MLflow. My `SetFitModel` is loaded from a `.pth` file using PyTorch’s `torch.load()` function. However, when I try to log the model with `mlflow.pytorch.log_model()`, I’m receiving a `TypeError` stating that the `pytorch_model` argument should be a `torch.nn.Module`.
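```python
import torch
import mlflow

model_path = "/path/to/my/model.pth"
# The .pth file holds the whole pickled model object (not just a state_dict), so full
# unpickling is needed; PyTorch >= 2.6 defaults to weights_only=True and would refuse it.
model_V5 = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
model_V5.to("cpu")

# A context manager closes the run even if logging raises
with mlflow.start_run():
    # This call raises the TypeError described above
    mlflow.pytorch.log_model(model_V5, "my_SetFit_model_V5")
```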
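A quick way to see what the `TypeError` is pointing at is to inspect the loaded object directly. The sketch below assumes that `SetFitModel` is a wrapper class rather than a `torch.nn.Module` subclass, holding its sentence-transformer body and classification head as attributes (`model_body`, `model_head`). `mlflow.pytorch.log_model()` only accepts `torch.nn.Module` instances, so such a wrapper would be rejected even though it contains modules. The helper name `describe_loaded_object` is hypothetical.

```python
import torch


def describe_loaded_object(obj) -> dict:
    """Hypothetical diagnostic: report whether `obj` is a torch.nn.Module and
    which entries of its instance __dict__ are nn.Modules.

    The attribute scan is meaningful for plain wrapper objects; a real nn.Module
    keeps its submodules in an internal registry, not as direct __dict__ entries.
    """
    attrs = getattr(obj, "__dict__", {})  # objects without __dict__ yield no attributes
    return {
        "type": type(obj).__name__,
        "is_nn_module": isinstance(obj, torch.nn.Module),
        "module_attributes": sorted(
            name for name, value in attrs.items() if isinstance(value, torch.nn.Module)
        ),
    }
```

If the assumption about `SetFitModel` holds, calling `describe_loaded_object(model_V5)` on the object loaded above should report `is_nn_module` as `False`, with `model_body` listed among the module attributes.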
Any guidance on how to log a `SetFitModel` with MLflow would be greatly appreciated.
Thank you in advance!