Actually, sorry, I realized there were a couple of mistakes above. I also found handler_service.py, but I am still running into the same error. I added only one custom function, predict_fn, which is essentially a copy of the original predict function except that the inputs key is now labelled text1. It still produces the same error. For context, inference.py was placed in model1.tar.gz under a code folder, which is what the instructions describe. My original model, model.tar.gz, without the custom inference.py, works fine. The config files are identical; the only difference between the two archives is that the newer one, model1.tar.gz, contains code/inference.py.
Thanks.
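For reference, here is a minimal sketch of the archive layout I am assuming (the artifact filenames are illustrative; the point is that the model files sit at the root and the override script sits at code/inference.py):

```python
import io
import os
import tarfile
import tempfile

# Illustrative layout for model1.tar.gz: model artifacts at the archive
# root, custom script under code/inference.py (filenames are assumptions).
layout = ["config.json", "pytorch_model.bin", "code/inference.py"]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model1.tar.gz")
    with tarfile.open(path, "w:gz") as tar:
        for name in layout:
            data = b""  # placeholder contents for the sketch
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    # Listing the archive back shows the structure SageMaker unpacks.
    with tarfile.open(path, "r:gz") as tar:
        names = tar.getnames()

print(names)
```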
import os
import json
import torch
def predict_fn(self, data):
    """The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline`
    on decoded_input_data deserialized in input_fn. Runs prediction on GPU if one is available.
    The predict handler can be overridden to implement the model inference.

    Args:
        data (dict): deserialized decoded_input_data returned by the input_fn

    Returns:
        obj (dict): prediction result.
    """
    # pop inputs for pipeline
    inputs = data.pop("text1", data)
    parameters = data.pop("parameters", None)

    # pass inputs with all kwargs in data
    if parameters is not None:
        prediction = self.model(inputs, **parameters)
    else:
        prediction = self.model(inputs)
    return prediction
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import BytesDeserializer
import sagemaker
model_name = 'model1'
endpoint_name = 'endpoint1'
role = sagemaker.get_execution_role()
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    model_data="s3://call-summarization/model1.tar.gz",
    role=role,
    transformers_version="4.6.1",
    pytorch_version="1.7.1",
    py_version='py36',
    name=model_name
)
# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type='ml.g4dn.xlarge',
    endpoint_name=endpoint_name,
)