Caching encoder state for multiple encoder-decoder `.generate()` calls?

I’m using a VisionEncoderDecoderModel and I want to decode from the same encoded image multiple times (say, 30+ times per image) without rerunning the encoder on every `model.generate()` call. Is there a way to cache the encoder state and reuse it? Or is there another efficient way to decode multiple times from the encoded input?

Hi! You can pass `generate` an argument called `encoder_outputs`, which the decoder will then use instead of re-running the encoder on every call. Optionally, you can also pass `decoder_input_ids`; otherwise they are initialized from the BOS token.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tok = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

inputs_encoder = tok("Hello, my dog is cute", return_tensors="pt")

# Forced decoder prefix; add_special_tokens=False keeps the tokenizer from
# appending </s>, which would otherwise push the decoder to stop immediately.
decoder_input_ids = tok("Bonjour", return_tensors="pt", add_special_tokens=False)["input_ids"]

# Run the encoder once; the resulting ModelOutput can be reused across calls.
encoder_outputs = model.get_encoder()(**inputs_encoder)

out = model.generate(decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, num_beams=1, do_sample=False)
print(tok.batch_decode(out))
```
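
For the VisionEncoderDecoderModel case in the question, the same pattern applies: run the vision encoder once on the image, then call `generate` from the cached features as many times as you want. Here is a minimal sketch, assuming a recent transformers version and the public nlpconnect/vit-gpt2-image-captioning checkpoint; the image path and sampling settings are placeholders:

```python
import torch
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel

ckpt = "nlpconnect/vit-gpt2-image-captioning"  # assumed checkpoint, for illustration
model = VisionEncoderDecoderModel.from_pretrained(ckpt)
processor = ViTImageProcessor.from_pretrained(ckpt)
tok = AutoTokenizer.from_pretrained(ckpt)

image = Image.open("photo.jpg").convert("RGB")  # placeholder path
pixel_values = processor(images=image, return_tensors="pt").pixel_values

# Encode the image exactly once.
with torch.no_grad():
    encoder_outputs = model.get_encoder()(pixel_values=pixel_values)

# Decode many times (e.g. 30 sampled captions) from the cached features.
captions = []
for _ in range(30):
    out = model.generate(encoder_outputs=encoder_outputs, do_sample=True, top_k=50,
                         max_new_tokens=32, pad_token_id=tok.eos_token_id)
    captions.append(tok.decode(out[0], skip_special_tokens=True))
print(captions)
```

One caveat: with `num_beams > 1`, some transformers versions expand `encoder_outputs` in place during beam search, so if you mix beam search with reuse, it’s safer to pass a fresh copy of the encoder output to each call.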
