I want to create a SageMaker endpoint for Whisper, but I can’t get it to return timestamps.
My deployment script is:
# Environment variables handed to the inference container through `env=hub` below.
# NOTE(review): HF_MODEL_ID is set here while model_data is left commented out.
# The Hugging Face SageMaker docs describe custom handlers as living under
# code/inference.py inside the model.tar.gz artifact, so with this configuration
# the entry_point below is probably never loaded and the default handler answers
# (text only, no timestamps) — confirm against the toolkit documentation.
hub = {"HF_MODEL_ID": "openai/whisper-base", "HF_TASK": "automatic-speech-recognition"}
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
# model_data = s3_location,
transformers_version="4.26",
pytorch_version="1.13",
py_version="py39",
env=hub,
role=role,
# Path of the custom handler script, relative to the working directory
# because source_dir is not passed.
entry_point="code/inference.py",
# source_dir="./code",
)
# deploy model to SageMaker Inference
# Requests made through the returned predictor are sent with content type "audio/x-audio".
audio_serializer = DataSerializer(content_type="audio/x-audio")
predictor = huggingface_model.deploy(
initial_instance_count=1, # number of instances
instance_type="ml.m5.xlarge", # ec2 instance type
serializer=audio_serializer,
)
And my code in code/inference.py is:
import json
import logging
import torch
import transformers
from transformers import (
WhisperProcessor,
WhisperForConditionalGeneration,
WhisperTokenizer,
WhisperFeatureExtractor,
)
from transformers import pipeline
# Module-level logger shared by the handler functions below; its level is set to DEBUG.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def model_fn(model_dir):
    """Build the Whisper automatic-speech-recognition pipeline for the endpoint.

    The checkpoint is taken from the ``HF_MODEL_ID`` environment variable when
    it is set (the deployment script sets it in the container environment) and
    falls back to ``openai/whisper-large`` otherwise, so the handler and the
    deployment configuration cannot silently disagree about the model.

    Args:
        model_dir: Directory passed in by the serving container. It is not
            used here; the weights are fetched by checkpoint id instead.

    Returns:
        A ``transformers`` pipeline for ``"automatic-speech-recognition"``
        configured to process audio in 30-second chunks.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    import os

    logger.info("Loading Model")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    checkpoint = os.environ.get("HF_MODEL_ID", "openai/whisper-large")

    tokenizer = WhisperTokenizer.from_pretrained(checkpoint)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(checkpoint)
    whisper = WhisperForConditionalGeneration.from_pretrained(checkpoint)

    # When `model` is an already-instantiated object the pipeline factory has
    # no checkpoint name to infer the audio preprocessor from, so the feature
    # extractor must be passed explicitly alongside the tokenizer.
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=whisper,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        chunk_length_s=30,
        device=device,
    )
    logger.info("Got model, returning it.")
    return asr_pipeline
def predict_fn(input_data, model):
    """Transcribe one request and return the result including timestamps.

    Args:
        input_data: The decoded request. The Hugging Face inference toolkit's
            default audio decoding is expected to deliver a dict with the raw
            audio under the ``"inputs"`` key; any other value (bytes, a file
            path, or a ``{"raw": ..., "sampling_rate": ...}`` dict) is passed
            to the pipeline unchanged.
        model: The callable returned by ``model_fn``.

    Returns:
        Whatever ``model`` returns when called with ``return_timestamps=True``.
    """
    logger.info("Predicting from model")
    # The ASR pipeline does not accept a dict keyed by "inputs", so unwrap the
    # payload first; the request dict itself is left unmodified.
    if isinstance(input_data, dict) and "inputs" in input_data:
        audio = input_data["inputs"]
    else:
        audio = input_data
    results = model(audio, return_timestamps=True)
    logger.info("Results\n%s", results)
    return results
def output_fn(predictions, accept):
    """Serialize the timestamped chunks of a prediction as JSON.

    Args:
        predictions: Mapping that contains a ``"chunks"`` entry.
        accept: Requested response content type; only ``"application/json"``
            is supported.

    Returns:
        A ``(body, accept)`` tuple where ``body`` is the JSON-encoded value of
        ``predictions["chunks"]``.

    Raises:
        ValueError: If ``accept`` is anything other than ``"application/json"``.
    """
    # NOTE(review): the Hugging Face toolkit's documented output_fn example
    # returns only the encoded body, not a tuple — confirm this return shape.
    if accept != "application/json":
        raise ValueError(f"Unsupported content type: {accept}")
    chunks = predictions["chunks"]
    body = json.dumps(chunks)
    return body, accept
However, when I run this, I just get the text without timestamps, so I presume that the inference.py
script isn’t even getting loaded. Do you know if I’m making any errors here?