RAG for Reading Comprehension

Hi,

I am currently working on a reading comprehension task and was thinking of using RAG for answer generation instead of a span extraction model.

Please find the starter code, question, and passage below:

query = "Who is Adam's sister?"
passage = "Adam is Bob's friend. Bob was born in 1906. Bob married Angela and they are now happily living together. Angela is Adam's sister. Angela lives in Los Angeles. Bob has a dog and its name is Moxie. Adam likes Bob because Bob is a kind person. Adam has 2 kids."

import torch
from transformers import RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
nq_model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
nq_model.rag.config.n_docs = 1
_ = nq_model.to("cuda:0")
_ = nq_model.eval()

num_beams=1
min_length=10
max_length=16

device = "cuda:0"
input_ids = tokenizer(query, return_tensors="pt").input_ids.to(device)
passage_ids = tokenizer(passage, return_tensors="pt").input_ids.to(device)
passage_attention_mask = tokenizer(passage, return_tensors="pt").attention_mask.to(device)


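# Bypass the retriever: pass the passage directly as the single "retrieved"
# document, together with a manually chosen doc score.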
generated_ids = nq_model.generate(
    input_ids=input_ids,
    context_input_ids=passage_ids,
    context_attention_mask=passage_attention_mask,
    doc_scores=torch.tensor([[100.0]]).to("cuda:0"),
    num_beams=1,
    num_return_sequences=1,
    min_length=min_length,
    max_length=max_length,
    length_penalty=1.0,
)
answer_texts = [
    tokenizer.generator.decode(gen_seq.tolist(), skip_special_tokens=True).strip() for gen_seq in generated_ids
]

print(answer_texts)

The model outputs "exit bar houses j j j j j".
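My current suspicion (just a guess at this point) is that the context passed to generate() needs to be tokenized with the generator (BART) tokenizer rather than the question-encoder tokenizer, and formatted the way the retriever would format a retrieved document, i.e. with the question appended after the doc separator. A minimal sketch of what I mean, assuming the default " // " separator, skipping the document title, and reusing the objects defined above:

# Sketch only: build the generator input roughly the way RagRetriever.postprocess_docs
# does ("doc text // question"), and tokenize it with the generator (BART) tokenizer.
# The document title part of the usual format is skipped here for simplicity.
context_string = passage + " // " + query
context_inputs = tokenizer.generator(context_string, return_tensors="pt", truncation=True)
context_input_ids = context_inputs.input_ids.to(device)
context_attention_mask = context_inputs.attention_mask.to(device)

generated_ids = nq_model.generate(
    input_ids=input_ids,
    context_input_ids=context_input_ids,
    context_attention_mask=context_attention_mask,
    doc_scores=torch.tensor([[100.0]]).to(device),
    num_beams=num_beams,
    min_length=min_length,
    max_length=max_length,
)
print(tokenizer.generator.decode(generated_ids[0].tolist(), skip_special_tokens=True).strip())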

Any thoughts on what might be wrong here, or whether the sketch above is the right direction?
Thanks!

Possibly related: recent work from Google AI found that RAG models often do not use the retrieved context at all.
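If it helps, one quick check I was considering for my example (an assumption on my side, not something taken from that work): run the same generate() call with the real passage and then with an unrelated placeholder passage as the context and compare the outputs; if the answer does not change, the model is probably ignoring the context. Rough sketch, reusing the code above (the unrelated_passage string is just made up):

unrelated_passage = "The Eiffel Tower is located in Paris. It was completed in 1889."

for ctx_text in (passage, unrelated_passage):
    # Same call as above; only the context document changes.
    ctx = tokenizer.generator(ctx_text + " // " + query, return_tensors="pt", truncation=True)
    out_ids = nq_model.generate(
        input_ids=input_ids,
        context_input_ids=ctx.input_ids.to(device),
        context_attention_mask=ctx.attention_mask.to(device),
        doc_scores=torch.tensor([[100.0]]).to(device),
        num_beams=num_beams,
        min_length=min_length,
        max_length=max_length,
    )
    print(tokenizer.generator.decode(out_ids[0].tolist(), skip_special_tokens=True).strip())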