I have a fine-tuned Llama 2 model that I am trying to deploy on SageMaker using `sagemaker.huggingface`. I have tried deploying both a model fine-tuned elsewhere and a model fine-tuned on SageMaker, and both give me the error below:
Error:
```
ValueError: Could not load model /opt/ml/model with any of the following classes: (<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>, <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>).
```
Below is the structure of my model.tar.gz:
```
├── code
│   ├── inference.py
│   └── requirements.txt
├── config.json
├── generation_config.json
├── model-00007-of-00007.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── tokenizer_config.json
├── tokenizer.json
├── tokenizer.model
├── special_tokens_map
├── adapter_config.json
└── adapter_model.bin
```
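One way to sanity-check an archive like this before deploying is to compare its contents against model.safetensors.index.json. The listing above shows only model-00007-of-00007.safetensors (it may simply be abbreviated), while the index file's `weight_map` names every shard the loader will try to open. The sketch below assumes the archive is available locally; `inspect_model_archive` is a hypothetical helper, not part of any SageMaker or Hugging Face API.

```python
import json
import tarfile


def _normalize(name: str) -> str:
    # Archives built with "tar -czf model.tar.gz ." prefix entries with "./"
    return name[2:] if name.startswith("./") else name


def inspect_model_archive(archive_path: str) -> dict:
    """Hypothetical helper: report which shards listed in the safetensors
    index are absent from the archive, and whether config/adapter files exist."""
    with tarfile.open(archive_path, "r:gz") as tar:
        # Map normalized file names to their tar members (directories skipped)
        members = {_normalize(m.name): m for m in tar.getmembers() if m.isfile()}
        report = {
            "has_config": "config.json" in members,
            "has_adapter": "adapter_config.json" in members,
            "missing_shards": [],
        }
        index = members.get("model.safetensors.index.json")
        if index is not None:
            # weight_map maps each parameter name to the shard file that stores it
            weight_map = json.load(tar.extractfile(index))["weight_map"]
            shards = sorted(set(weight_map.values()))
            report["missing_shards"] = [s for s in shards if s not in members]
        return report


if __name__ == "__main__":
    print(inspect_model_archive("model.tar.gz"))
```

Gzip has to be decompressed sequentially, so listing a multi-GB archive this way can take a while.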
I fine-tuned Llama-2-7b-chat using transformers.Trainer and the PEFT library, and I am trying to deploy it to a SageMaker endpoint. Below is my deployment code:
```python
import sagemaker
from sagemaker.huggingface import HuggingFaceModel

role = sagemaker.get_execution_role()  # IAM role used by the endpoint

config = {
    'HF_TASK': 'text-generation',
}

huggingface_model = HuggingFaceModel(
    model_data='s3://…/output/model.tar.gz',
    role=role,
    transformers_version="4.28",
    pytorch_version="2.0",
    py_version="py310",
    model_server_workers=1,
    env=config,
)

predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.4xlarge",
    container_startup_health_check_timeout=600,
)

payload = {
    "inputs": prompt,  # prompt: the input text string
    "parameters": {
        "do_sample": True,
        "top_p": 0.7,
        "temperature": 0.3,
        "top_k": 50,
        "max_new_tokens": 50,
        "repetition_penalty": 1.03,
    },
}

predictor.predict(payload)
```
I receive the same `ValueError` when running `predictor.predict(payload)`.
I have tried two approaches:
Approach 1: Fine-tune the model with a SageMaker training job (HuggingFace estimator), then pass the saved model.tar.gz to the model with model_data=huggingface_estimator.model_data. There is no inference.py, i.e. the default inference code of the SageMaker Hugging Face container is used.
Approach 2: Fine-tune the model locally, upload it to an S3 bucket, and set model_data="s3://…/model.tar.gz". This archive contains code/inference.py, as shown in the structure above.
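For Approach 2, SageMaker extracts model.tar.gz into /opt/ml/model, so the model files and the code/ directory need to sit at the archive root rather than inside a top-level folder. A minimal sketch for building such an archive with Python's tarfile module, assuming the fine-tuned files and code/ have already been collected in one local directory; `make_model_tarball` is a hypothetical helper.

```python
import os
import tarfile


def make_model_tarball(src_dir: str, out_path: str) -> list:
    """Pack every entry of src_dir at the root of a gzipped tar archive and
    return the archive's member names. out_path must lie outside src_dir."""
    with tarfile.open(out_path, "w:gz") as tar:
        for entry in sorted(os.listdir(src_dir)):
            # arcname=entry drops the src_dir prefix, so "code/inference.py"
            # and "config.json" land at the root; directories are added recursively
            tar.add(os.path.join(src_dir, entry), arcname=entry)
    with tarfile.open(out_path, "r:gz") as tar:
        return tar.getnames()
```

The returned names make it easy to confirm that entries such as "config.json" and "code/inference.py" appear without a leading folder before the file is uploaded to S3.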
In both cases I get the same error. Here are the inference.py and requirements.txt files from Approach 2:
```python
# inference.py
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from sagemaker_huggingface_inference_toolkit import decoder_encoder


def model_fn(model_dir):
    # Load the tokenizer and a 4-bit quantized model from the extracted archive
    compute_dtype = getattr(torch, "float16")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    quant_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=False,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        quantization_config=quant_config,
        device_map={"": 0},  # place the whole model on GPU 0
    )
    return base_model, tokenizer


def input_fn(input_data, content_type):
    # Deserialize the request body with the toolkit's decoder
    print(input_data)
    print(content_type)
    sentences = decoder_encoder.decode(input_data, content_type)
    print(sentences)
    return sentences


def predict_fn(data, model):
    # `model` is the (model, tokenizer) tuple returned by model_fn
    model, tokenizer = model
    input_ids = tokenizer.encode(data, padding="max_length", truncation=True, return_tensors="pt")
    summary_ids = model.generate(input_ids=input_ids.cuda(), max_length=100, num_beams=4, do_sample=True)
    output = tokenizer.decode(summary_ids[0])
    return {'vectors': output}


def output_fn(prediction, accept):
    # Serialize the prediction into the requested response format
    response = decoder_encoder.encode(prediction, accept)
    print(response)
    return response
```
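With the JSON payload used in the deployment code, `decoder_encoder.decode` should return a Python dict rather than a string, so predict_fn as written passes `{"inputs": ..., "parameters": ...}` straight to tokenizer.encode. A minimal sketch of separating the prompt from the generation keyword arguments, assuming the payload shape shown earlier; `split_payload` is a hypothetical helper, not part of the toolkit.

```python
from typing import Any, Dict, Tuple


def split_payload(data: Any) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: separate the prompt from generation parameters.

    Accepts a bare string or a dict shaped like
    {"inputs": <str>, "parameters": {<generate kwargs>}}.
    """
    if isinstance(data, str):
        return data, {}
    if not isinstance(data, dict) or "inputs" not in data:
        raise ValueError("expected a string or a dict with an 'inputs' key")
    # Copy so the caller's payload is not mutated
    params = dict(data.get("parameters") or {})
    return data["inputs"], params


# Inside predict_fn, roughly:
#   prompt, gen_kwargs = split_payload(data)
#   input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
#   output_ids = model.generate(input_ids=input_ids, **gen_kwargs)
# (padding is left out because the Llama tokenizer defines no pad token by default)
```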
```
# requirements.txt
accelerate==0.16.0
transformers==4.26.0
bitsandbytes==0.37.0
```
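These pins appear older than the features the code uses: as far as I know, Llama support (LlamaForCausalLM) arrived in transformers 4.28.0, 4-bit loading via `BitsAndBytesConfig(load_in_4bit=True)` in transformers 4.30.0, and 4-bit quantization in bitsandbytes 0.39.0. A sketch that compares installed versions against those minimums; the `MINIMUMS` table is an assumption taken from those release numbers and is worth double-checking against the release notes.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple

# Assumed minimums for Llama + 4-bit loading (verify against release notes)
MINIMUMS: Dict[str, str] = {
    "transformers": "4.30.0",
    "bitsandbytes": "0.39.0",
}


def parse_release(v: str) -> Tuple[int, ...]:
    """Turn '4.26.0' or '0.39.0.dev0' into a 3-int tuple (pre-release tags ignored)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    while len(parts) < 3:
        parts.append(0)  # so "4.30" compares equal to "4.30.0"
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    return parse_release(installed) >= parse_release(minimum)


def check_installed(minimums: Dict[str, str] = MINIMUMS) -> Dict[str, Optional[bool]]:
    """Map each package to True/False (meets minimum) or None (not installed)."""
    result: Dict[str, Optional[bool]] = {}
    for pkg, minimum in minimums.items():
        try:
            result[pkg] = meets_minimum(version(pkg), minimum)
        except PackageNotFoundError:
            result[pkg] = None
    return result


if __name__ == "__main__":
    print(check_installed())
```

Running this inside the container (for example from a print in model_fn) would show which versions are actually in use after requirements.txt is installed.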
I read somewhere that this issue might be related to TensorFlow and that using transformers.pipeline() might solve it.
If that is the case, could someone show me what my inference.py should look like, and whether any other file that I am missing should be present in my model/code folder?
OR
If I have written inference.py incorrectly, how should it look?
OR
Is the error coming from somewhere else?
Thanks in advance for the help!