LLM zero-shot text classification - how do you answer multiple questions in a computationally efficient way?

I have been very impressed with some of the latest open-source 7B LLMs, like OpenChat3.5 and OpenHermes2.5.

I have been working on some code to classify text with these models by:

  1. Asking questions about the input text, specifying a set of expected responses (options)
  2. Calculating the joint output probability of each expected response token sequence using the LLM output

The typical structure is something like:
CLASSIFY_PROMPT = "{instruction}{text_to_classify}{question_and_options}"

I ask a question with a specific set of options as the output, e.g. "How did the reviewer find the movie?" (good/bad/ok)

I then append each option (with the appropriate tokens to specify end of turn and end of sequence) to the CLASSIFY_PROMPT and calculate the joint token probabilities of each output option sequence. My result is something like:
good<eos_token> - 92%
ok<eos_token> - 3%
bad<eos_token> - 0.001%
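
For concreteness, here is a stripped-down sketch of how I do this. The instruction/review/question strings are toy placeholders, and I use tokenizer.eos_token as a stand-in for the proper end-of-turn/end-of-sequence tokens:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "openchat/openchat_3.5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

# Toy versions of the three prompt components
instruction = "Answer the question about the following text using one of the given options.\n"
text_to_classify = "I walked out halfway through. A total waste of an evening.\n"
question_and_options = "How did the reviewer find the movie? (good/bad/ok)\nAnswer: "
CLASSIFY_PROMPT = instruction + text_to_classify + question_and_options

def option_log_prob(prompt, option):
    # Joint log-probability of `option` (plus EOS) as a continuation of `prompt`
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    option_ids = tokenizer(option + tokenizer.eos_token, add_special_tokens=False,
                           return_tensors="pt").input_ids
    input_ids = torch.cat([prompt_ids, option_ids], dim=-1)

    with torch.no_grad():
        logits = model(input_ids).logits  # (1, seq_len, vocab_size)

    # The logit at position i predicts token i + 1, so the option tokens are
    # predicted by positions prompt_len - 1 ... seq_len - 2
    log_probs = torch.log_softmax(logits[:, prompt_ids.shape[1] - 1 : -1, :], dim=-1)
    token_log_probs = log_probs.gather(-1, option_ids.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum().item()

options = ["good", "ok", "bad"]
scores = torch.tensor([option_log_prob(CLASSIFY_PROMPT, o) for o in options])
probs = torch.softmax(scores, dim=0)  # normalise over the option set

This works well for a single question and option set, but it re-encodes the full prompt for every option.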

In the real-world scenario, my text_to_classify is more complex, and I want to ask multiple questions about it.

The "{instruction}{text_to_classify}" part will be at the start of every question/option combination, so I was hoping I could cache the key/value states (past_key_values) for "{instruction}{text_to_classify}" once and then batch-process all my question/option combinations against that cache. This should save a huge amount of compute, especially when my text_to_classify has many tokens.
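
Concretely, the first step I have in mind looks roughly like this (continuing with the names from the sketch above):

prefix = instruction + text_to_classify  # shared by every question/option combination
prefix_ids = tokenizer(prefix, return_tensors="pt").input_ids

with torch.no_grad():
    prefix_out = model(prefix_ids, use_cache=True)

# One (key, value) pair per layer, each with batch size 1; this is what I would like
# to reuse for every "{question_and_options}{option}<eos>" continuation in one batched pass
prefix_cache = prefix_out.past_key_values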

The sad thing is that I couldn't figure out how to use past_key_values with model() (the __call__ method) for processing batched inputs, where

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("openchat/openchat_3.5")
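
For reference, this is roughly the kind of batched call I have been attempting, continuing from the sketch above. Here question_option_pairs is a hypothetical list of (question, option) strings, and expanding the cached keys/values along the batch dimension plus extending the attention mask over the cached prefix is just my guess at what is required (assuming the legacy tuple cache format), not something I have gotten to work:

suffixes = [q + opt + tokenizer.eos_token for q, opt in question_option_pairs]  # hypothetical pairs

tokenizer.pad_token = tokenizer.eos_token
batch = tokenizer(suffixes, return_tensors="pt", padding=True, add_special_tokens=False)
n = batch.input_ids.shape[0]

# Broadcast the batch-size-1 prefix cache across the suffix batch
expanded_cache = tuple(
    (k.expand(n, -1, -1, -1), v.expand(n, -1, -1, -1)) for k, v in prefix_cache
)

# The attention mask has to cover the cached prefix tokens as well as the new suffix tokens
prefix_mask = torch.ones(n, prefix_ids.shape[1], dtype=batch.attention_mask.dtype)
attention_mask = torch.cat([prefix_mask, batch.attention_mask], dim=-1)

with torch.no_grad():
    out = model(
        input_ids=batch.input_ids,
        attention_mask=attention_mask,
        past_key_values=expanded_cache,
    )

# out.logits has shape (n, suffix_len, vocab_size); the per-option joint probabilities
# would then be computed as in the single-question case, taking the padding into account.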

Advice would be greatly appreciated. Even if it is not possible with transformers as-is, I would appreciate hearing that too, because then at least I know whether I need to get my hands dirty and code something from scratch.