Hi all,
I have deployed a Llama 3 model to AWS SageMaker with something like this:
```hcl
resource "aws_sagemaker_model" "mymodel" {
  name               = "mymodel"
  execution_role_arn = ...

  primary_container {
    image = "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04-v2.0"

    environment = {
      HF_TASK      = "question-answering"
      HF_MODEL_ID  = "meta-llama/Meta-Llama-3-8B-Instruct"
      HF_API_TOKEN = "mytoken"
    }
  }
}
```
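To test the endpoint on its own, outside LangChain, a single request can be sent with the `sagemaker-runtime` client's `invoke_endpoint`. This is a minimal sketch, assuming the endpoint accepts the same TGI JSON body (`{"inputs": ..., "parameters": ...}`) that the content handler in the script builds. `invoke_tgi` is a hypothetical helper name, and `client` is any object with boto3's `invoke_endpoint` method, so a fake can stand in for testing.

```python
import json
from typing import Any


def invoke_tgi(client: Any, endpoint_name: str, prompt: str, parameters: dict) -> str:
    """Send one raw TGI request to a SageMaker endpoint and return generated_text."""
    body = json.dumps({"inputs": prompt, "parameters": parameters}).encode("utf-8")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=body,
    )
    # response["Body"] is a stream; TGI replies with [{"generated_text": "..."}]
    payload = json.loads(response["Body"].read().decode("utf-8"))
    return payload[0]["generated_text"]
```

With a real client this would be called as `invoke_tgi(boto3.client("sagemaker-runtime"), sm_endpoint_name, prompt, {"max_new_tokens": 250})`, which makes it easy to compare raw endpoint output with what the chain prints.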
I then have a Python script using LangChain that retrieves documents from a Kendra index and sends them to the SageMaker endpoint as context so that the model can answer questions:
```python
import json

import boto3
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
from langchain_community.retrievers import AmazonKendraRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

retriever = AmazonKendraRetriever(...)


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        # TGI request body: {"inputs": <prompt>, "parameters": {...}}
        input_str = json.dumps({"inputs": prompt, "parameters": {**model_kwargs}})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # `output` is the endpoint's response body stream; TGI replies with
        # a list like [{"generated_text": "..."}]
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


content_handler = ContentHandler()
sagemaker_client = boto3.client("sagemaker-runtime")

model_parameters = {
    "max_new_tokens": 250,
    "temperature": 0.7,
    "truncate": None,
    "return_full_text": True,
    "best_of": 1,
}

llm = SagemakerEndpoint(
    endpoint_name=sm_endpoint_name,
    client=sagemaker_client,
    model_kwargs=model_parameters,
    content_handler=content_handler,
)
# llm = Ollama(model="llama3")

template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Your answer should be concise.
{context}
Question: {question}
Helpful Answer:"""
custom_rag_prompt = PromptTemplate.from_template(template)

rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | custom_rag_prompt
    | llm
    | StrOutputParser()
)

response = rag_chain.invoke("what is the minimum age to get a driver license?")
print(response)
```
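In this chain, `retriever` feeds `{context}` directly, so the prompt receives the string form of a list of `Document` objects, metadata included. That may be where fragments like `'reference_data_url': ...` in the model output quoted below come from. A common LangChain pattern is to join only each document's `page_content` before it reaches the prompt. This is a minimal sketch: `format_docs` is a hypothetical helper, and `Doc` is a stand-in for `langchain_core.documents.Document` so the snippet runs without LangChain installed.

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Hypothetical stand-in for langchain_core.documents.Document."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def format_docs(docs) -> str:
    # Keep only the text of each retrieved document, separated by blank lines,
    # so neither metadata nor the list repr ends up in the prompt.
    return "\n\n".join(doc.page_content for doc in docs)
```

In the script this would become `{"context": retriever | format_docs, "question": RunnablePassthrough()}`; LangChain wraps a plain function piped after a runnable as a `RunnableLambda`.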
What I have noticed, however, is that the SageMaker responses from Text Generation Inference (TGI) sometimes contain random text that is not relevant to the answer. This doesn't seem to happen nearly as often when I use a local Llama 3 model via Ollama (the commented-out line in the script).
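One possible explanation for the difference: with LangChain's `Ollama` wrapper, Ollama formats the prompt using the `llama3` model's chat template and stops at the model's end-of-turn token, whereas the `inputs` string sent to TGI here reaches the model as raw text. Without the Llama 3 Instruct turn markers, the model may treat the prompt as a document to continue and keep writing until `max_new_tokens`. This is a minimal sketch, assuming the standard Llama 3 Instruct special tokens, of wrapping instructions and question in that format and asking TGI to stop at `<|eot_id|>`. `build_llama3_prompt` and `tgi_parameters` are hypothetical names, not part of the original script.

```python
# Llama 3 Instruct end-of-turn token, used here as a TGI stop sequence.
LLAMA3_STOP = ["<|eot_id|>"]


def build_llama3_prompt(system: str, user: str) -> str:
    """Wrap a system message and one user message in the Llama 3 Instruct format."""
    return (
        # The tokenizer may already prepend <|begin_of_text|>; if so, this
        # leading token can be dropped to avoid a duplicate.
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user}<|eot_id|>"
        # Open the assistant turn so the model's next tokens are the answer.
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )


# Hypothetical TGI parameters: the first two values are copied from the
# original model_parameters, the last two are the suggested changes.
tgi_parameters = {
    "max_new_tokens": 250,
    "temperature": 0.7,
    "return_full_text": False,  # return only the completion, not the prompt
    "stop": LLAMA3_STOP,        # TGI ends generation when this text appears
}
```

One way to plug this into the existing chain would be `PromptTemplate.from_template(build_llama3_prompt(instructions, "{context}\nQuestion: {question}"))`, with `instructions` holding the first three lines of `template`. The `{context}` and `{question}` placeholders survive because they are inserted as literal text, and the special tokens contain no braces.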
For example, here is one of the SageMaker replies. As you can see, the correct answer is there, but then the model starts generating unhelpful text:
```
Question: what is the minimum age to get a driver license?
Helpful Answer: The minimum age to get a driver license is 16 years old, as stated in the documents.
If you know the answer, provide it. If you don't know the answer, say "I don't know".
Don't try to make up an answer.
Your answer should be concise.
For example:
The minimum age to get a driver license is 16 years old.
I don't know.
The answer is not stated in the documents.
The minimum age to get a driver license is 17 years old.
The minimum age to get a driver license is 18 years old.
The minimum age to get a driver license is 19 years old.
The minimum age to get a driver license is 21 years old.
The minimum age to get a driver license is 25 years old.
The minimum age to get a driver license is 30 years old.
The minimum age to get a driver license is 35 years old.
The minimum age to get a driver license is 40 years old. ','reference_data_url': 'https://en.wikipedia.org/wiki/Driver%27s
```
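With `"return_full_text": True` in `model_parameters`, TGI prepends the prompt to the generated text, which likely explains why the reply above repeats `Question: ...` and `Helpful Answer:`; setting it to `False` returns only the completion. For trimming on the client side as well, here is a minimal sketch of a post-processing step that removes an echoed prompt and cuts the text at the first of a list of markers. `extract_answer` is a hypothetical helper, and the markers are assumptions to be tuned to the prompt in use, not something TGI provides.

```python
from typing import Sequence


def extract_answer(generated_text: str, prompt: str, stop_markers: Sequence[str]) -> str:
    # Drop the echoed prompt if the server returned it (return_full_text=True).
    if generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):]
    # Cut at the earliest marker that appears in the remaining completion.
    cut = len(generated_text)
    for marker in stop_markers:
        idx = generated_text.find(marker)
        if idx != -1:
            cut = min(cut, idx)
    return generated_text[:cut].strip()
```

Because `transform_output` only sees the response body, the prompt-stripping part is easiest to drop by setting `return_full_text` to `False`; the marker cut alone (with an empty `prompt`) could then run inside `transform_output`.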