Formatting an Inference API call for Llama 2

I am trying to call the Hugging Face Inference API to generate text using Llama-2 (specifically, Llama-2-7b-chat-hf). Following this documentation page, I am able to generate text using the following code:

import json
import requests

API_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-2-7b-chat-hf"
headers = {
    "Authorization": "Bearer hf_XXXXXXXXXXXXXXXX",
    "Content-Type": "application/json",
}

def query(payload):
    # Send the prompt as a JSON payload and decode the JSON response
    data = json.dumps({"inputs": payload})
    response = requests.post(API_URL, headers=headers, data=data)
    return json.loads(response.content.decode("utf-8"))

data = query("Can you please let us know more details about your ")
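For what it's worth, the response comes back as a list of objects with a generated_text field, so the output can be read like this:

print(data[0]["generated_text"])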

However, when I try to set a system context in the input, the Inference API cannot parse my request.

For example, using my own AWS endpoint (following this), I am able to send this JSON body to Llama 2:

{
 "inputs": [
  [
   {"role": "system", "content": "You are chat bot who writes songs"},
   {"role": "user", "content": "Write a rap about Barbie"}
  ]
 ],
 "parameters": {"max_new_tokens":256, "top_p":0.9, "temperature":0.6}
}
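For reference, this is roughly how I send that body to the endpoint (a minimal sketch; "llama-2-7b-chat-endpoint" is a placeholder for your endpoint's name, and the JumpStart Llama 2 containers require the accept_eula custom attribute):

import json
import boto3

runtime = boto3.client("sagemaker-runtime")

payload = {
    "inputs": [[
        {"role": "system", "content": "You are chat bot who writes songs"},
        {"role": "user", "content": "Write a rap about Barbie"},
    ]],
    "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
}

response = runtime.invoke_endpoint(
    EndpointName="llama-2-7b-chat-endpoint",  # placeholder name
    ContentType="application/json",
    Body=json.dumps(payload),
    CustomAttributes="accept_eula=true",  # required by the Llama 2 JumpStart containers
)
print(json.loads(response["Body"].read()))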

But in the code above, when I call json.dumps on this JSON body, I get this error in the response:

Failed to deserialize the JSON body into the target type: inputs: invalid type: sequence, expected a string at line 1 column 11. 

I haven’t been able to figure out a format that works. Is there an example or documentation anywhere for using Llama 2 with the Hugging Face Inference API?

Hi @peteceptron,

Did you ever end up finding a solution to this? I am in the same boat.

I have set up Llama 2 via JumpStart with inputs very similar to yours. I have written a Flask API that sits in front of the LLM and reads and writes conversation context to a DynamoDB table, so that it can keep the context of the conversation. When I try to switch over to the Hugging Face model, since it offers more capability for bringing up the infrastructure as code, I see your problem exactly.

I’ve spent some time researching, but can’t seem to find anything related to our issue other than this post, which has had no responses. It seems as though the Hugging Face Inference API/endpoint is only programmed to take a single string as the input, which doesn’t make sense, since the model itself clearly supports more. If that is the case, I won’t be able to use this model for my purposes, as I don’t see how I would provide the previous questions/responses as context for the current question.

Hopefully you were able to figure out your issue, and might be able to provide some feedback.

I solved it by inputting a single string using the official Llama 2 prompt format (see Llama 2 is here - get it on Hugging Face). I don’t know why the default SageMaker Llama endpoint doesn’t work that way, but this works for me:

import json
import requests

API_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-2-7b-chat-hf"
headers = {
    "Authorization": "Bearer hf_XXXXXXXXXXXXXXXXXX",
    "Content-Type": "application/json",
}

def query(payload):
    # Official Llama 2 chat format: [INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]
    # (note the closing tag is <</SYS>>, not <<SYS>>)
    json_body = {
        "inputs": f"[INST] <<SYS>>\nYour job is to talk like a pirate. Every response must sound like a pirate.\n<</SYS>>\n\n{payload} [/INST]",
        "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.7},
    }
    data = json.dumps(json_body)
    response = requests.post(API_URL, headers=headers, data=data)
    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:  # not valid JSON, e.g. the model is still loading
        return response
data = query("Just say hi!")
print(data[0]['generated_text'].split('[/INST]')[-1].strip())  # drop the echoed prompt
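For the multi-turn case raised above: as far as I can tell, previous turns get folded into the same single string. A sketch of that (my own helper, following the format in the post linked above; not tested against every endpoint version):

def build_prompt(system_prompt, history, new_message):
    """Fold prior (user, assistant) turns into one Llama 2 chat string.

    history is a list of (user_msg, assistant_msg) tuples.
    """
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
    for user_msg, assistant_msg in history:
        prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
    prompt += f"{new_message} [/INST]"
    return prompt

# Example: pass the previous question/answer pair as context
prompt = build_prompt(
    "Your job is to talk like a pirate.",
    [("Just say hi!", "Ahoy there, matey!")],
    "Now say goodbye.",
)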

Hi, I am trying to use the API in my JavaScript project. I got this API endpoint from the Llama 2 Hugging Face Space under "use via API", but I am getting a 404 Not Found error even though I used the exact same code given by Hugging Face:

import { client } from "@gradio/client";

const app = await client("https://huggingface-projects-llama-2-7b-chat.hf.space/--replicas/gm5p8/");
const result = await app.predict("/chat", [
  "Howdy!", // string in 'Message' Textbox component
  "Howdy!", // string in 'System prompt' Textbox component
  1,        // number (between 1 and 2048) in 'Max new tokens' Slider component
  0.1,      // number (between 0.1 and 4.0) in 'Temperature' Slider component
  0.05,     // number (between 0.05 and 1.0) in 'Top-p (nucleus sampling)' Slider component
  1,        // number (between 1 and 1000) in 'Top-k' Slider component
  1,        // number (between 1.0 and 2.0) in 'Repetition penalty' Slider component
]);

console.log(result.data);
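Not a definitive answer, but one likely cause (an assumption on my part): the /--replicas/gm5p8/ part of that URL points at a specific replica, and replicas rotate, so a hard-coded replica URL will eventually 404. The @gradio/client also accepts the Space name directly, which resolves the current replica for you:

import { client } from "@gradio/client";

// Connect by Space name instead of a hard-coded (ephemeral) replica URL
const app = await client("huggingface-projects/llama-2-7b-chat");
const result = await app.predict("/chat", [
  "Howdy!", // Message
  "Howdy!", // System prompt
  1,        // Max new tokens
  0.1,      // Temperature
  0.05,     // Top-p (nucleus sampling)
  1,        // Top-k
  1,        // Repetition penalty
]);

console.log(result.data);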