How to prompt Llama2 for text classification?

Here is my script:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

prompt = """
A message can be classified as one of the following categories: book, cancel, change.

Based on these categories, classify this message:
I would like to cancel my booking and ask for a refund.
"""

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
outputs = model.generate(input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```
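One thing worth checking: the prompt above is plain text, but the *-chat-hf checkpoints were tuned on a dialogue format. Below is a minimal sketch of wrapping the same request in that format by hand. It assumes the "[INST] <<SYS>> ... <</SYS>> ... [/INST]" layout used in Meta's reference code for the chat models. The system-message wording and the helper names (build_chat_prompt, build_classification_prompt) are hypothetical, chosen only for illustration. Recent transformers versions also offer tokenizer.apply_chat_template, which may be preferable if your version has it.

```python
# Minimal sketch: build a Llama-2-chat style prompt string by hand.
# ASSUMPTION: the "[INST] <<SYS>> ... <</SYS>> ... [/INST]" layout from
# Meta's reference code for the chat models. Helper names and the system
# message wording are hypothetical.

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

LABELS = ["book", "cancel", "change"]


def build_chat_prompt(system: str, user: str) -> str:
    """Wrap a system message and a user message in the chat markers.

    The leading <s> (BOS) token is omitted because the tokenizer
    adds it by default when encoding.
    """
    return f"{B_INST} {(B_SYS + system + E_SYS + user).strip()} {E_INST}"


def build_classification_prompt(message: str, labels=LABELS) -> str:
    """Hypothetical helper: ask for exactly one label and nothing else."""
    system = (
        "You are a text classifier. Reply with exactly one word from this "
        "list and nothing else: " + ", ".join(labels) + "."
    )
    user = f"Message: {message}\nCategory:"
    return build_chat_prompt(system, user)


if __name__ == "__main__":
    print(build_classification_prompt(
        "I would like to cancel my booking and ask for a refund."
    ))
```

The resulting string can be passed to tokenizer(prompt, return_tensors="pt") exactly as in the script; since only a one-word answer is wanted, a much smaller max_new_tokens is enough. This makes a direct answer more likely but does not guarantee the model will comply.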

Here is the output:

```
<s>
A message can be classified as one of the following categories: book, cancel, change.

Based on these categories, classify this message:
I would like to cancel my booking and ask for a refund.

Please select one of the following options:

book
cancel
change</s>
```
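Note that tokenizer.decode(outputs[0]) returns the echoed prompt plus the continuation, including the <s> and </s> special tokens. Decoding only the new tokens, e.g. tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True), isolates what the model actually added. Here that continuation just lists all three options, so it is not a classification. The sketch below (extract_label is a hypothetical helper, not part of transformers) maps a continuation to a single label, and returns None when zero or several labels appear, as in the output above.

```python
import re
from typing import Optional, Sequence


def extract_label(generated: str, labels: Sequence[str]) -> Optional[str]:
    """Return the one label that appears as a whole word in `generated`.

    Matching is case-insensitive. Returns None if no label, or more than
    one distinct label, appears, because the reply is then ambiguous
    (for example, the model merely listed the options).
    """
    found = {
        label
        for label in labels
        if re.search(rf"\b{re.escape(label)}\b", generated, flags=re.IGNORECASE)
    }
    return found.pop() if len(found) == 1 else None


if __name__ == "__main__":
    labels = ["book", "cancel", "change"]
    print(extract_label(" Cancel.", labels))
    print(extract_label("Please select one of the following options:\n\nbook\ncancel\nchange", labels))
```

The whole-word match keeps "cancellation" from counting as "cancel"; whether that is desirable depends on how the model tends to phrase its answers.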

How should I design the prompt so that Llama-2 answers with the single label “cancel” instead of just listing the options?
