I have been very impressed with some of the latest open-source 7B LLMs, like OpenChat3.5 and OpenHermes2.5.
I have been working on some code to classify text with these models by:
- Asking questions about the input text (specifying a set of expected responses or options)
- Calculating the joint output probability of each expected response token sequence using the LLM output
The typical structure is something like:
CLASSIFY_PROMPT = "{instruction}{text_to_classify}{question_and_options}"
I ask a question with specific options as the output, e.g. "How did the reviewer find the movie? (good/bad/ok)".
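To make that concrete, here is a toy version of how the pieces fit together (the instruction, review text, and question below are invented placeholders; my real prompt is longer and wrapped in the model's chat template):

```python
# Invented placeholder values; the real instruction/text are longer and use the chat template.
instruction = "You are a classifier. Answer with exactly one of the given options.\n"
text_to_classify = "Review: I went in with low expectations and was still disappointed.\n"
question_and_options = "How did the reviewer find the movie? (good/bad/ok)\nAnswer: "

CLASSIFY_PROMPT = instruction + text_to_classify + question_and_options
```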
I then append each option (with the appropriate end-of-turn and end-of-sequence tokens) to the CLASSIFY_PROMPT and calculate the joint token probability of each option sequence. My result is something like:
good<eos_token> - 92%
ok<eos_token> - 3%
bad<eos_token> - 0.001%
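For reference, the scoring itself is roughly like this (a simplified sketch; in my real code the option string also carries the end-of-turn/eos tokens, and I normalise the scores afterwards):

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openchat/openchat_3.5")
model = AutoModelForCausalLM.from_pretrained("openchat/openchat_3.5")
model.eval()

def option_log_prob(prompt: str, option: str) -> float:
    """Sum of log-probabilities of the option tokens, conditioned on the prompt."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + option, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(full_ids).logits                  # (1, seq_len, vocab_size)
    log_probs = F.log_softmax(logits.float(), dim=-1)
    # Logits at position i predict token i+1, so each option token is scored
    # against the distribution produced one step earlier.
    # (Assumes the prompt tokenizes to the same ids as the prefix of prompt + option.)
    total = 0.0
    for pos in range(prompt_ids.shape[1], full_ids.shape[1]):
        total += log_probs[0, pos - 1, full_ids[0, pos]].item()
    return total

scores = {opt: option_log_prob(CLASSIFY_PROMPT, opt) for opt in ["good", "ok", "bad"]}
```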
In the real-world scenario, my text_to_classify is more complex, and I want to ask multiple questions about it.
The "{instruction}{text_to_classify}" prefix will be at the start of every question/option combination, so I was hoping I could save the hidden state (past_key_values) for "{instruction}{text_to_classify}" and then batch-process all my question/option combinations. This should save a huge amount of compute, especially when text_to_classify has many tokens.
The sad thing is that I couldn’t figure out how to use past_key_values with model() (the __call__ method) for processing batched inputs, where
model = AutoModelForCausalLM.from_pretrained("openchat/openchat_3.5")
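For context, continuing from the snippet above, this is roughly the shape of what I have been attempting. The expansion of past_key_values along the batch dimension and the attention-mask handling are my own guesses, so that may be exactly where it goes wrong:

```python
# prefix = "{instruction}{text_to_classify}", shared by every question/option combination
# suffixes = list of "{question_and_options}{option}" strings to score in one batch
prefix_ids = tokenizer(prefix, return_tensors="pt").input_ids
with torch.no_grad():
    prefix_out = model(prefix_ids, use_cache=True)
past = prefix_out.past_key_values                # per-layer (key, value) tensors for the prefix

tokenizer.pad_token = tokenizer.eos_token        # guess: needed since the suffixes differ in length
batch = tokenizer(suffixes, return_tensors="pt", padding=True, add_special_tokens=False)
n = batch.input_ids.shape[0]

# Guess: broadcast the cached keys/values from batch size 1 to batch size n.
past = tuple((k.expand(n, -1, -1, -1), v.expand(n, -1, -1, -1)) for k, v in past)

# Guess: the attention mask has to cover the cached prefix plus the new suffix tokens.
prefix_mask = torch.ones(n, prefix_ids.shape[1], dtype=batch.attention_mask.dtype)
full_mask = torch.cat([prefix_mask, batch.attention_mask], dim=1)

with torch.no_grad():
    out = model(input_ids=batch.input_ids,
                attention_mask=full_mask,
                past_key_values=past)
# out.logits should give me what I need to score each option sequence, but I am not
# confident this is how model() is meant to be used with a cache, nor how padding
# side interacts with past_key_values here.
```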
Advice would be greatly appreciated. Even if it is not possible with transformers as-is, I would appreciate those comments too, because then, at least, I will know whether I need to get my hands dirty and code something from scratch.