Hi all,
As suggested by @philschmid in my previous thread, it should be possible to pack multiple models into one SageMaker inference endpoint and get a prediction from each of them for the same input.
When trying this out on Inferentia instances, I ran into an error whose stack trace makes little sense to me.
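For context, the endpoint is deployed roughly like this (a sketch rather than my exact code: the S3 path, IAM role, framework versions and instance type below are placeholders):

from sagemaker.huggingface import HuggingFaceModel

# model.tar.gz holds code/inference.py plus the model_0 ... model_3 folders
huggingface_model = HuggingFaceModel(
    model_data="s3://my-bucket/model.tar.gz",  # placeholder S3 path
    role="my-sagemaker-execution-role",        # placeholder IAM role
    transformers_version="4.12",               # placeholder framework versions
    pytorch_version="1.9",
    py_version="py37",
)

predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.inf1.xlarge",            # placeholder Inferentia instance type
)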
I compiled each model to use 4 Neuron cores and adapted the inference code as follows:
import os
from pathlib import Path
from transformers import AutoConfig, AutoTokenizer
import torch
import torch.neuron

# To use one neuron core per worker
os.environ["NEURON_RT_NUM_CORES"] = "4"

N_MODELS = 4

# saved weights name
AWS_NEURON_TRACED_WEIGHTS_NAME = "neuron_model.pt"


def model_fn(model_dir):
    # load tokenizer and neuron model from model_dir
    models_path = Path(model_dir)
    all_models = []
    for i in range(N_MODELS):
        model_path = models_path / f'model_{i}'
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = torch.jit.load(os.path.join(model_path, AWS_NEURON_TRACED_WEIGHTS_NAME))
        model_config = AutoConfig.from_pretrained(model_path)
        all_models.append((tokenizer, model, model_config))
    return all_models


def predict_fn(data, model_tokenizer_model_config):
    all_models = model_tokenizer_model_config
    predictions = []
    inputs = data.pop("inputs", data)
    for i, (model, tokenizer, model_config) in enumerate(all_models):
        embeddings = tokenizer(
            inputs,
            return_tensors="pt",
            max_length=model_config.traced_sequence_length,
            padding="max_length",
            truncation=True,
        )
        # convert to tuple for neuron model
        neuron_inputs = tuple(embeddings.values())
        with torch.no_grad():
            predictions = model(*neuron_inputs)[0]
        scores = torch.nn.Softmax(dim=1)(predictions)
        for item in scores:
            predictions.append({
                f"label_{i}": model_config.id2label[item.argmax().item()],
                f"score_{i}": item.max().item(),
            })
    return predictions
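For completeness: the model.tar.gz contains code/inference.py plus one folder per model (model_0 … model_3), each holding config.json, the tokenizer files and neuron_model.pt. Each model was compiled along these lines before packaging (a sketch rather than my exact script; the model id, sequence length and output paths are placeholders):

import torch
import torch.neuron
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "some-org/some-finetuned-classifier"  # placeholder
sequence_length = 128                            # placeholder

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, torchscript=True)

# dummy input with the same shape the endpoint will see at inference time
dummy = tokenizer("dummy text", max_length=sequence_length,
                  padding="max_length", truncation=True, return_tensors="pt")
neuron_inputs = tuple(dummy.values())

# compile for 4 Neuron cores
traced = torch.neuron.trace(
    model,
    neuron_inputs,
    compiler_args=["--neuroncore-pipeline-cores", "4"],
)

# save everything model_fn expects under model_0/
model.config.traced_sequence_length = sequence_length  # read back in predict_fn
model.config.save_pretrained("model_0")
tokenizer.save_pretrained("model_0")
traced.save("model_0/neuron_model.pt")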
The stack trace is as follows:
2022-09-13T12:31:15,797 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Prediction error
2022-09-13T12:31:15,797 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback (most recent call last):
2022-09-13T12:31:15,797 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File "/opt/conda/lib/python3.7/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 234, in handle
2022-09-13T12:31:15,797 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - response = self.transform_fn(self.model, input_data, content_type, accept)
2022-09-13T12:31:15,797 [INFO ] W-9000-model com.amazonaws.ml.mms.wlm.WorkerThread - Backend response time: 7
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File "/opt/conda/lib/python3.7/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 190, in transform_fn
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - predictions = self.predict(processed_data, model)
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File "/.sagemaker/mms/models/model/code/inference.py", line 44, in predict_fn
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - truncation=True,
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - return forward_call(*input, **kwargs)
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - RuntimeError: forward() expected at most 3 argument(s) but received 6 argument(s). Declaration: forward(__torch__.torch_neuron.runtime.___torch_mangle_422.AwsNeuronGraphModule self, Tensor argument_1, Tensor tensor) -> ((Tensor))
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - During handling of the above exception, another exception occurred:
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback (most recent call last):
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File "/opt/conda/lib/python3.7/site-packages/mms/service.py", line 108, in predict
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - ret = self._entry_point(input_batch, self.context)
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File "/opt/conda/lib/python3.7/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 243, in handle
2022-09-13T12:31:15,800 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - raise PredictionException(str(e), 400)
2022-09-13T12:31:15,800 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - mms.service.PredictionException: forward() expected at most 3 argument(s) but received 6 argument(s). Declaration: forward(__torch__.torch_neuron.runtime.___torch_mangle_422.AwsNeuronGraphModule self, Tensor argument_1, Tensor tensor) -> ((Tensor)) : 400
Model loading seems to work; it's the inference step that crashes.
What I don't understand is that the failure appears to happen during tokenization. This part:
File "/.sagemaker/mms/models/model/code/inference.py", line 44, in predict_fn
truncation=True,
corresponds to the tokenization, yet the next frame in the stack trace looks like a forward pass:
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
RuntimeError: forward() expected at most 3 argument(s) but received 6 argument(s). Declaration: forward(__torch__.torch_neuron.runtime.___torch_mangle_422.AwsNeuronGraphModule self, Tensor argument_1, Tensor tensor) -> ((Tensor))
I don't understand why there would be a forward pass during tokenization, so I'm not sure how to proceed. I also added some logging to confirm that it really is the tokenization step that crashes and not the forward pass (in case the stack trace was misaligned with the source), and that is indeed the case.
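The logging was roughly of this form (illustrative, not my exact statements):

import logging

logger = logging.getLogger(__name__)

# inside predict_fn, around the failing call
for i, (model, tokenizer, model_config) in enumerate(all_models):
    logger.info("model %d: starting tokenization", i)
    embeddings = tokenizer(
        inputs,
        return_tensors="pt",
        max_length=model_config.traced_sequence_length,
        padding="max_length",
        truncation=True,
    )
    logger.info("model %d: tokenization done, running forward pass", i)
    # ... rest of predict_fn unchanged

Only the first message shows up before the error, so the exception really is raised inside the tokenizer call.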
Any hints would be welcome, thanks a lot!
Best,
Vil