Packing multiple models into one SageMaker inference instance with Inferentia

Hi all,

As suggested by @philschmid in my previous thread, it should be possible to pack multiple models into one SageMaker inference endpoint in order to run multiple predictions for the same input.

When trying this out on Inferentia instances, I ran into an error where the stack trace makes little sense to me.

I compiled each model to use 4 Neuron cores and adapted the inference code as follows:

import os
from pathlib import Path

from transformers import AutoConfig, AutoTokenizer
import torch
import torch.neuron

# Make 4 NeuronCores available to this worker process (one per model)
os.environ["NEURON_RT_NUM_CORES"] = "4"

N_MODELS = 4

# saved weights name
AWS_NEURON_TRACED_WEIGHTS_NAME = "neuron_model.pt"


def model_fn(model_dir):
    # load tokenizer and neuron model from model_dir
    models_path = Path(model_dir)
    all_models = []
    for i in range(N_MODELS):
        model_path = models_path / f'model_{i}'
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = torch.jit.load(os.path.join(model_path, AWS_NEURON_TRACED_WEIGHTS_NAME))
        model_config = AutoConfig.from_pretrained(model_path)
        all_models.append((tokenizer, model, model_config))
    return all_models


def predict_fn(data, model_tokenizer_model_config):
    all_models = model_tokenizer_model_config
    predictions = []
    inputs = data.pop("inputs", data)
    for i, (model, tokenizer, model_config) in enumerate(all_models):

        embeddings = tokenizer(
            inputs,
            return_tensors="pt",
            max_length=model_config.traced_sequence_length,
            padding="max_length",
            truncation=True,
        )
        # convert to tuple for neuron model
        neuron_inputs = tuple(embeddings.values())

        with torch.no_grad():
            predictions = model(*neuron_inputs)[0]
            scores = torch.nn.Softmax(dim=1)(predictions)

        for item in scores:
            predictions.append({f"label_{i}": model_config.id2label[item.argmax().item()],
                                f"score_{i}": item.max().item()})
    return predictions
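
For context, each model was traced and saved ahead of time roughly along these lines (just a sketch: the model id, example text and max_length are illustrative placeholders, and traced_sequence_length is the custom config field that predict_fn reads back):

import os

import torch
import torch.neuron
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Illustrative values -- the real model id, text and sequence length differ
model_id = "some-org/xlm-roberta-finetuned"  # placeholder
max_length = 128                             # placeholder
save_dir = "model_0"
os.makedirs(save_dir, exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, torchscript=True)

# Dummy input with the fixed shape the Neuron graph is traced for
dummy = tokenizer(
    "a dummy input",
    max_length=max_length,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
example_inputs = tuple(dummy.values())

# Compile for Inferentia and save the TorchScript artifact
model_neuron = torch.neuron.trace(model, example_inputs)
model_neuron.save(os.path.join(save_dir, "neuron_model.pt"))

# Save tokenizer and config next to it; traced_sequence_length is the
# custom field that predict_fn uses to pad inputs back to the traced shape
tokenizer.save_pretrained(save_dir)
model.config.update({"traced_sequence_length": max_length})
model.config.save_pretrained(save_dir)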

The stack trace given is as follows:

2022-09-13T12:31:15,797 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Prediction error
2022-09-13T12:31:15,797 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback (most recent call last):
2022-09-13T12:31:15,797 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.7/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 234, in handle
2022-09-13T12:31:15,797 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     response = self.transform_fn(self.model, input_data, content_type, accept)
2022-09-13T12:31:15,797 [INFO ] W-9000-model com.amazonaws.ml.mms.wlm.WorkerThread - Backend response time: 7
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.7/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 190, in transform_fn
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     predictions = self.predict(processed_data, model)
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/.sagemaker/mms/models/model/code/inference.py", line 44, in predict_fn
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     truncation=True,
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     return forward_call(*input, **kwargs)
2022-09-13T12:31:15,798 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - RuntimeError: forward() expected at most 3 argument(s) but received 6 argument(s). Declaration: forward(__torch__.torch_neuron.runtime.___torch_mangle_422.AwsNeuronGraphModule self, Tensor argument_1, Tensor tensor) -> ((Tensor))
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - 
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - During handling of the above exception, another exception occurred:
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - 
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback (most recent call last):
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.7/site-packages/mms/service.py", line 108, in predict
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     ret = self._entry_point(input_batch, self.context)
2022-09-13T12:31:15,799 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File "/opt/conda/lib/python3.7/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 243, in handle
2022-09-13T12:31:15,800 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     raise PredictionException(str(e), 400)
2022-09-13T12:31:15,800 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - mms.service.PredictionException: forward() expected at most 3 argument(s) but received 6 argument(s). Declaration: forward(__torch__.torch_neuron.runtime.___torch_mangle_422.AwsNeuronGraphModule self, Tensor argument_1, Tensor tensor) -> ((Tensor)) : 400

The model loading seems to have worked, but the inference part crashes.

The part that I don’t understand is that it seems to be the tokenization that fails. This part:

File "/.sagemaker/mms/models/model/code/inference.py", line 44, in predict_fn
     truncation=True,

corresponds to the tokenization call, yet the next frame in the stack trace appears to be a forward pass:

  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: forward() expected at most 3 argument(s) but received 6 argument(s). Declaration: forward(__torch__.torch_neuron.runtime.___torch_mangle_422.AwsNeuronGraphModule self, Tensor argument_1, Tensor tensor) -> ((Tensor))

I’m not sure why there would be a forward pass during tokenization, so I’m not sure how to proceed. I also added some logging to confirm that it is indeed the tokenization step that crashes and not the forward pass (in case the line numbers in the trace were misaligned with the deployed source code), and that is indeed the case.

Any hints would be welcome, thanks a lot!

Best,
Vil

The problem was just some variable issues in the inference code: in model_fn the tuples were appended as (tokenizer, model, model_config) but unpacked in predict_fn as (model, tokenizer, model_config), so the traced Neuron model ended up being called as the tokenizer, and the predictions list was also being overwritten inside the torch.no_grad() block. Here’s the fixed code for reference. With it I can fit at least 4 XLM-RoBERTa-based models on one inf1.xlarge instance.

import os
from pathlib import Path

from transformers import AutoConfig, AutoTokenizer
import torch
import torch.neuron

# Make 4 NeuronCores available to this worker process (one per model)
os.environ["NEURON_RT_NUM_CORES"] = "4"

N_MODELS = 4

# saved weights name
AWS_NEURON_TRACED_WEIGHTS_NAME = "neuron_model.pt"


def model_fn(model_dir):
    # load tokenizer and neuron model from model_dir
    models_path = Path(model_dir)
    all_models = []
    for i in range(N_MODELS):
        model_path = models_path / f'model_{i}'
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = torch.jit.load(os.path.join(model_path, AWS_NEURON_TRACED_WEIGHTS_NAME))
        model_config = AutoConfig.from_pretrained(model_path)
        all_models.append((model, tokenizer, model_config))
    return all_models


def predict_fn(data, model_tokenizer_model_config):
    all_models = model_tokenizer_model_config
    predictions = []
    inputs = data.pop("inputs", data)
    for i, (model, tokenizer, model_config) in enumerate(all_models):

        embeddings = tokenizer(
            inputs,
            return_tensors="pt",
            max_length=model_config.traced_sequence_length,
            padding="max_length",
            truncation=True,
        )
        # convert to tuple for neuron model
        neuron_inputs = tuple(embeddings.values())

        with torch.no_grad():
            prediction = model(*neuron_inputs)[0]
            scores = torch.nn.Softmax(dim=1)(prediction)

        for item in scores:
            predictions.append({f"label_{i}": model_config.id2label[item.argmax().item()],
                                f"score_{i}": item.max().item()})
    return predictions
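
For completeness, the repacked archive (model_0 … model_3 plus code/inference.py) can then be deployed as usual with the SageMaker Python SDK. The S3 path, role and framework versions below are placeholders, so pick a combination that has a Neuron-enabled DLC for ml.inf1 instances:

import sagemaker
from sagemaker.huggingface import HuggingFaceModel

# Placeholders -- use your own bucket, role and a DLC version combination
# that is available for Inferentia (ml.inf1.*) instances
role = sagemaker.get_execution_role()
model_data = "s3://my-bucket/multi-model-neuron/model.tar.gz"

huggingface_model = HuggingFaceModel(
    model_data=model_data,        # archive with model_0..model_3 and code/inference.py
    role=role,
    transformers_version="4.12",  # illustrative; pick a Neuron-compatible combination
    pytorch_version="1.9",
    py_version="py37",
)

predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.inf1.xlarge",
)

# One request returns one label/score pair per packed model
print(predictor.predict({"inputs": "This is a test sentence."}))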

This is a great thread
