Generate an answer to a query starting from Decoder Embeddings

Hello,

I am currently working on using an Encoder-Decoder model in a non-standard way: specifically, I use SwitchTransformers for the encoder part and Flan-T5 for the decoder part, without explicitly going through the generate() function.

I have successfully processed my document and query through the Encoder, obtaining the Encoder embeddings using the following code snippet:

import torch
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

doc = "The lion (Panthera leo) is a large cat of the genus Panthera, native to Africa and India. (continues..)" # len() = 2018
query = "Which is the native zone of the lion?"

self.tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8", model_max_length=self.MAX_LENGTH) # Set reasonable default for models without max length
self.model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8", device_map="auto", quantization_config=quantization_config)

def forward(self, docs, queries):
    # in lightning, forward defines the prediction/inference actions
    tokenized_doc = self.tokenizer(docs, return_tensors="pt", truncation=True, max_length=self.MAX_LENGTH).to(device="cuda")
    tokenized_query = self.tokenizer(queries, return_tensors="pt", truncation=True, max_length=self.MAX_LENGTH).to(device="cuda")

    with torch.no_grad():
        output_doc = self.model(**tokenized_doc, decoder_input_ids=torch.zeros_like(tokenized_doc['input_ids']))         # (batch_size, sequence_length, hidden_size)
        output_query = self.model(**tokenized_query, decoder_input_ids=torch.zeros_like(tokenized_query['input_ids']))   # (batch_size, sequence_length, hidden_size)

    # Get the embeddings for doc and query
    embedding_doc = output_doc.encoder_last_hidden_state
    embedding_query = output_query.encoder_last_hidden_state

    # 'input_ids' from tokenized_doc / tokenized_query are needed for the Decoder Module
    return (embedding_doc, embedding_query), (tokenized_doc['input_ids'], tokenized_query['input_ids'])

I then successfully passed this latent representation through the Flan-T5 decoder in the following way:

from transformers import AutoModelForSeq2SeqLM

# Standard call for Flan-T5
self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base", model_max_length=self.MAX_LENGTH)
self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", device_map="auto", quantization_config=quantization_config)

def forward(self, embedding, input_ids):
    # What about query_emb, query_input_ids?
    doc_emb, query_emb = embedding
    doc_input_ids, query_input_ids = input_ids

    decoder_input_ids = self.model._shift_right(doc_input_ids)
    outputs = self.model.base_model.model.decoder(input_ids=decoder_input_ids, encoder_hidden_states=doc_emb)
    return self.model.base_model.model.lm_head(outputs['last_hidden_state'])

After taking the decoder's last_hidden_state and passing it through the lm_head layer, the resulting tensor has shape torch.Size([1, 497, 32128]), i.e. (batch_size, decoder_sequence_length, vocab_size).
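
For completeness, the two modules above are chained roughly like this (the module names here are just placeholders for my Lightning modules, not my exact code):

# Illustrative chaining of the two forward passes defined above (names are placeholders)
embeddings, input_ids = switch_encoder_module(doc, query)   # SwitchTransformers side
logits = flan_t5_decoder_module(embeddings, input_ids)      # Flan-T5 side, shape (1, 497, 32128)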

My primary question is how to generate text from this representation, since it is a tensor of logits over the vocabulary rather than a sequence of token ids that can be passed directly to the tokenizer's decode() function.
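
To make the question concrete, my best guess so far is to take the per-position argmax over the vocabulary dimension and detokenize it; a minimal, untested sketch (assuming logits is the tensor returned by the decoder forward above and flan_t5_tokenizer is the Flan-T5 tokenizer):

# Untested sketch: greedily pick one token id per position, then detokenize
token_ids = logits.argmax(dim=-1)                                         # (1, 497)
text = flan_t5_tokenizer.decode(token_ids[0], skip_special_tokens=True)
print(text)

However, I suspect this mostly reconstructs the teacher-forced decoder input (the shifted document tokens) rather than autoregressively generating an answer.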

Additionally, I am wondering how to properly incorporate the query embeddings: at the moment the model seems to produce an output that closely resembles the document rather than an answer to the query.
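
One idea I have considered for the query part (untested, just to illustrate what I mean by "incorporate") is to concatenate the document and query encoder outputs along the sequence dimension before they reach the decoder's cross-attention, roughly:

# Untested idea: let the decoder cross-attend over both document and query tokens
encoder_states = torch.cat([doc_emb, query_emb], dim=1)   # (1, doc_len + query_len, hidden_size)
outputs = self.model.base_model.model.decoder(
    input_ids=decoder_input_ids,
    encoder_hidden_states=encoder_states,
)

But I am not sure whether this is the right way to condition the decoder on the query.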

Any guidance on these matters would be greatly appreciated.

Thank you.

P.S. I use self.model.base_model.model.decoder etc. because I wrapped the model with QLoRA.