My code:
# Must be set BEFORE faiss and torch are imported/initialized: both bundle
# an OpenMP runtime, and loading two copies aborts the process unless this
# workaround flag is set.
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import numpy as np
import torch
import faiss
from transformers import RagTokenizer, RagTokenForGeneration, DPRContextEncoder, DPRContextEncoderTokenizer
from datasets import Dataset

# Load DPR model and tokenizer.
# NOTE(review): this is the *context* encoder; it matches how the index was
# built, but DPR normally embeds queries with the question encoder — confirm.
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

# Load the FAISS index built offline (see the index-building script below).
index = faiss.read_index("faiss_index.index")
# Define a function to retrieve documents using FAISS
def retrieve(query, index, k=5):
# Convert query to embedding
inputs = context_tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
with torch.no_grad():
query_embedding = context_encoder(**inputs).pooler_output.numpy().flatten()
# Perform the search in FAISS index
distances, indices = index.search(np.array([query_embedding]), k)
return indices
# Load the RAG model and tokenizer.
# RagTokenizer wraps two tokenizers (question-encoder side and generator
# side); calling it directly tokenizes for the question encoder.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
# Function to query the RAG model with retrieved documents
def query_rag(query):
# Retrieve document indices
indices = retrieve(query, index)
# Prepare the context for RAG model
context = [chunks[i] for i in indices[0]] # Use 'chunks' which should be globally accessible
# Tokenize the context
context_inputs = tokenizer(context, return_tensors="pt", padding=True, truncation=True, max_length=512, truncation_strategy="longest_first")
# Prepare input_ids for the query
input_ids = tokenizer(query, return_tensors="pt")["input_ids"]
# Generate the response
generated = model.generate(
input_ids=input_ids,
context_input_ids=context_inputs["input_ids"],
context_attention_mask=context_inputs["attention_mask"], # Include attention mask
max_length=256
)
# Decode the generated response
return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
# Example text file read and chunk creation.
# Explicit encoding avoids platform-dependent default codecs (e.g. cp1252
# on Windows) mangling the corpus.
with open("rag_context.txt", "r", encoding="utf-8") as f:
    context = f.read()
def chunk_text(text, chunk_size=500):
    """Yield successive chunks of `text`, each holding at most `chunk_size`
    whitespace-separated words re-joined with single spaces.

    Yields nothing for empty or whitespace-only input.
    """
    words = text.split()
    for i in range(0, len(words), chunk_size):
        yield ' '.join(words[i:i + chunk_size])
# Materialize the chunk list; query_rag() reads this module-level global,
# and its order must match the row order of the FAISS index.
chunks = list(chunk_text(context))
# Query the model with an example question.
query = "Tell me about MPL's business model"
response = query_rag(query)
print(response)
In `model.generate`, I am getting: AttributeError: 'NoneType' object has no attribute 'repeat_interleave'
I am not able to understand what I am doing wrong here.
Code to generate the FAISS index:
import numpy as np
# Read the corpus as UTF-8 (explicit encoding avoids platform-dependent
# defaults; must match how the query-time script reads the same file).
with open("rag_context.txt", "r", encoding="utf-8") as f:
    context = f.read()
def chunk_text(text, chunk_size=500):
    """Split `text` into pieces of at most `chunk_size` words each.

    Words are whitespace-delimited and re-joined with single spaces; empty
    input produces no chunks.
    """
    words = text.split()
    total = len(words)
    start = 0
    while start < total:
        yield ' '.join(words[start:start + chunk_size])
        start += chunk_size
# Materialize the chunk list; its order defines the FAISS row ids that the
# query-time script maps back to text.
chunks = list(chunk_text(context))
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
# Load DPR model and tokenizer (same context-encoder checkpoint the
# query-time script uses, so query and document embeddings are comparable).
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
def get_embeddings(texts):
    """Return a list of 1-D numpy embeddings (DPR pooler_output), one per text.

    Args:
        texts: iterable of strings to embed.

    Texts are encoded one at a time; for large corpora, batching the
    tokenizer/encoder calls would be substantially faster.
    """
    embeddings = []
    for text in texts:
        inputs = context_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            emb = context_encoder(**inputs).pooler_output
        embeddings.append(emb.numpy().flatten())
    return embeddings
# Generate embeddings for each chunk; rows align with `chunks` order.
embedding_matrix = np.array(get_embeddings(chunks))
import faiss
# Create a FAISS index: exact (brute-force) L2 search over the embeddings.
dimension = embedding_matrix.shape[1] # Dimensionality of embeddings
index = faiss.IndexFlatL2(dimension)
# NOTE(review): faiss expects float32 rows; torch pooler_output is float32,
# so embedding_matrix should already satisfy this — confirm.
index.add(embedding_matrix)
# Save the FAISS index so the query-time script can faiss.read_index it.
faiss.write_index(index, "faiss_index.index")