Incremental decoding with T5

Recently, we have seen evidence that, for a variety of tasks, it can help a model to attend over its intermediate computation steps. An example is ReAct: Synergizing Reasoning and Acting in Language Models on the Google AI Blog (googleblog.com). The authors cite work from the neural program synthesis community where this approach was found beneficial.

Let’s assume we are processing conversations, where the context grows progressively longer as the user and agent interact. Typically, we would re-encode the dialogue history and generate the answer from scratch for every interaction. Schematically, this could be represented as follows:

step 1: [usr] sent_1 → answer_1
step 2: [usr] sent_1 [agent] sent_1 [usr] sent_2 → answer_2

step k: [usr] sent_1 [agent] sent_1 [usr] sent_2 ... [agent] sent_{k-1} [usr] sent_k → answer_k

Above, sent is just an abbreviation for “sentence”; the left-hand side of “→” is the encoder input and the right-hand side is the decoder output. However, the answers are highly correlated, so arguably the model could predict more consistently if it were asked to show all the reasoning steps as the conversation progresses, instead of producing only the latest answer at each step. Schematically (a small construction sketch follows the schematic):

step 1: [usr] sent_1 → answer_1
step 2: [usr] sent_1 [agent] sent_1 [usr] sent_2 → answer_1 <sep> answer_2

step k: [usr] sent_1 [agent] sent_1 [usr] sent_2 ... [agent] sent_{k-1} [usr] sent_k → answer_1 <sep> answer_2 <sep> ... <sep> answer_k
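For concreteness, here is a minimal sketch of how the step-k example could be assembled under this scheme. The helper name build_step_example is hypothetical, and the [usr]/[agent] markers and the <sep> string are just the illustrative choices from the schematic, not an established format:

```python
SEP = " <sep> "

def build_step_example(user_sents, agent_sents, answers, k):
    """Encoder input and decoder target for step k (1-indexed).

    user_sents:  user sentences 1..k
    agent_sents: agent sentences 1..k-1
    answers:     answers 1..k
    """
    parts = []
    for i in range(k):
        parts.append(f"[usr] {user_sents[i]}")
        if i < k - 1:  # the agent has replied k-1 times so far
            parts.append(f"[agent] {agent_sents[i]}")
    encoder_input = " ".join(parts)
    decoder_target = SEP.join(answers[:k])  # answer_1 <sep> ... <sep> answer_k
    return encoder_input, decoder_target

# step 2 example:
# build_step_example(["hi", "how are you?"], ["hello"], ["answer_1", "answer_2"], 2)
# -> ("[usr] hi [agent] hello [usr] how are you?", "answer_1 <sep> answer_2")
```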

At inference time, this is problematic because concatenating the answers can lead to very long sequences if everything is generated from scratch. However, I was wondering whether the use_cache feature, together with past_key_values, could be used to effectively implement a memory on the decoder side. In the scheme above, after we decode answer_1, we feed the keys and values generated during decoding back as past_key_values to decode answer_2; we would then feed those outputs back to generate answer_3, and so on. The model could thus attend over an updated conversational context and its past answers, but would not “revise” all of its previous answers.
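To make the caching idea concrete, here is a rough sketch of the decoding loop I have in mind, using T5ForConditionalGeneration with a manual greedy loop instead of generate(). It is only a sketch: in particular, the T5 cache also stores cross-attention key/value states computed from the previous (shorter) encoder input, and I am not sure whether those are, or can be, refreshed when the context grows, which is essentially what I am asking below.

```python
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()

@torch.no_grad()
def decode_turn(context, decoder_ids, past_key_values=None, max_new_tokens=64):
    """Greedily decode the next answer and return the decoder cache for reuse.

    context:         full conversation so far (re-encoded at every turn)
    decoder_ids:     all decoder tokens produced in previous turns
                     (just the start token on the first turn)
    past_key_values: cache returned by the previous turn, if any
    """
    enc = tokenizer(context, return_tensors="pt")
    # NOTE: when a cache from a previous turn is reused, its cross-attention
    # entries were computed from the previous, shorter encoder input.
    encoder_outputs = model.get_encoder()(**enc)

    # With a cache, only the tokens the decoder has not seen yet are fed in.
    next_input = decoder_ids if past_key_values is None else decoder_ids[:, -1:]

    for _ in range(max_new_tokens):
        out = model(
            encoder_outputs=encoder_outputs,
            attention_mask=enc.attention_mask,
            decoder_input_ids=next_input,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = out.past_key_values      # extended by the newly fed tokens
        next_token = out.logits[:, -1:].argmax(-1)
        decoder_ids = torch.cat([decoder_ids, next_token], dim=-1)
        next_input = next_token
        # In the proposed format the model would emit <sep> between answers;
        # stopping on EOS here is just a placeholder.
        if next_token.item() == tokenizer.eos_token_id:
            break
    return decoder_ids, past_key_values

# Two turns, reusing the decoder cache from the first turn:
start = torch.tensor([[model.config.decoder_start_token_id]])
ids, cache = decode_turn("[usr] sent_1", start)                        # decodes answer_1
ids, cache = decode_turn("[usr] sent_1 [agent] sent_1 [usr] sent_2",
                         ids, past_key_values=cache)                   # appends answer_2
```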

@patrickvonplaten, am I naive to think that the caching during inference could be implemented with huggingface as is?