I am trying to deploy a fine-tuned mistralai/Mistral-7B-v0.1 model to SageMaker. Along with the base model, we also have an adapter to load. The adapter files and model files are packaged together into a .tar.gz file.
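For context, this is roughly how the archive is laid out, following the SageMaker Hugging Face convention of a code/ directory for the entrypoint (the adapter/ contents shown are the standard PEFT outputs; the exact base-model weight files may differ):

model.tar.gz
├── config.json + model weights       # base model files at the archive root
├── tokenizer files
├── adapter/                          # adapter_config.json, adapter weights
└── code/
    ├── inference.py
    └── requirements.txt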
Here is our inference.py:
import torch
import json
from torch import nn
from transformers import AutoTokenizer, AutoModel
from sagemaker_huggingface_inference_toolkit import decoder_encoder  # provides the decode/encode helpers used in transform_fn
class AutoModelForSentenceEmbedding(nn.Module):
    def __init__(self, base_model, tokenizer, normalize=True):
        super(AutoModelForSentenceEmbedding, self).__init__()
        self.model = base_model  # , load_in_8bit=True, device_map={"":0})
        self.normalize = normalize
        self.tokenizer = tokenizer

    def forward(self, **kwargs):
        ea = self.get_embeddings(kwargs['input_ids'], kwargs['attention_mask'])
        ep = self.get_embeddings(kwargs['input_ids_positive'], kwargs['attention_mask_positive'])
        en = self.get_embeddings(kwargs['input_ids_negative'], kwargs['attention_mask_negative'])
        return ea, ep, en

    def get_embeddings(self, input_ids, attention_mask):
        model_output = self.model(input_ids, attention_mask)
        embeddings = self.mean_pooling(model_output, attention_mask)
        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.model, name)
def model_fn(model_dir, context=None):
    model = AutoModel.from_pretrained(model_dir, device_map='auto', load_in_8bit=True)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model.load_adapter(f"{model_dir}/adapter")
    print("loaded adapter")
    return AutoModelForSentenceEmbedding(model, tokenizer)
def transform_fn(model, input_data, content_type, accept):
    # decode the input data (e.g. JSON string -> dict); assumes a payload like {"inputs": "<text>"}
    data = decoder_encoder.decode(input_data, content_type)

    # tokenize the request text and embed it with the wrapped model
    encoded = model.tokenizer(data["inputs"], padding=True, truncation=True, return_tensors="pt").to(model.model.device)
    with torch.no_grad():
        outputs = model.get_embeddings(encoded["input_ids"], encoded["attention_mask"])

    # convert the model output to the desired output format (e.g. dict -> JSON string)
    response = decoder_encoder.encode({"embeddings": outputs.tolist()}, accept)
    return response
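For what it's worth, a minimal local smoke test of the handler looks like this (local_model_dir is a placeholder for a directory holding the unpacked archive):

import json
from inference import model_fn, transform_fn

# local_model_dir is a placeholder for the unpacked model.tar.gz contents
model = model_fn("local_model_dir")

payload = json.dumps({"inputs": "example query"})
print(transform_fn(model, payload, "application/json", "application/json"))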
requirements.txt
accelerate
transformers
bitsandbytes
And our deployment notebook has:
from sagemaker.huggingface.model import HuggingFaceModel

huggingface_model = HuggingFaceModel(
    model_data=s3_model_uri,        # path to your model and script
    role=role,                      # IAM role with permissions to create an endpoint
    transformers_version="4.26",    # Transformers version used
    pytorch_version="1.13",         # PyTorch version used
    py_version="py39",
)

predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
    endpoint_name="query-embeddings",
)
We are encountering the following error when we call predict:
com.amazonaws.ml.mms.wlm.WorkerLifeCycle - KeyError: 'mistral'
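For completeness, this is roughly how we invoke the endpoint (the exact query text is illustrative):

# the HuggingFacePredictor serializes this dict to JSON by default
predictor.predict({"inputs": "example query"})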
It looks like a Transformers version mismatch, but we are not able to figure out which version to use. Has anyone deployed a fine-tuned Mistral model before? If so, what is the correct approach?