import os
import gradio as gr
import openai
from langchain.vectorstores import Chroma
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import AzureChatOpenAI
from langchain.memory import ConversationBufferWindowMemory
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.prompts import PromptTemplate
from langchain.chains import ConversationalRetrievalChain
from named_tuples import ConversationalRetrievalChainResponse
openai.api_base = os.getenv("OPENAI_API_BASE")
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_type = "azure"
openai.api_version = os.getenv("OPENAI_API_VERSION")
# The LLM
llm = AzureChatOpenAI(
    streaming=True,
    callbacks=CallbackManager([StreamingStdOutCallbackHandler()]),
    temperature=0.0,
    deployment_name=os.getenv("OPENAI_DEPLOYMENT_NAME"),
)
# The Prompt
template = """
...
"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
# The Memory
buffer_window_memory = ConversationBufferWindowMemory(
    memory_key="chat_history", return_messages=True, k=4
)
# The Embeddings
openai_embeddings = OpenAIEmbeddings(
    deployment=os.getenv("OPENAI_EMBEDDINGS_MODEL"),
    openai_api_base=os.getenv("OPENAI_API_BASE"),
    openai_api_type="azure",
)
# The Knowledge Base
persist_directory = "vector_database/chromadb/"
vectordb = Chroma(
    persist_directory=persist_directory, embedding_function=openai_embeddings
)
# The Retriever
retriever = vectordb.as_retriever(search_type="mmr")
# The chain
chain = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=retriever,
    memory=buffer_window_memory,
    combine_docs_chain_kwargs={"prompt": QA_CHAIN_PROMPT},
    verbose=True,
)
# The function that takes a message as an argument and returns a streaming LLM response
def ask_conversational_retrieval_chain(
    question,
) -> ConversationalRetrievalChainResponse:
    result = chain({"question": question})
    print(result)
    return ConversationalRetrievalChainResponse(result=result["answer"])
# The Gradio App
def respond(message, history):
    return str(ask_conversational_retrieval_chain(message).result)
with gr.Blocks(title="Chatbot") as demo:
    gr.ChatInterface(respond, autofocus=True)
if __name__ == "__main__":
    demo.queue(concurrency_count=5, max_size=20).launch()
The ask_conversational_retrieval_chain function generates a streaming response in the form of chunks, and I would like to print those chunks to the chat interface as they are being generated by the LLM, rather than all at once at the end.
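To illustrate what I mean by streaming into the chat interface: as far as I understand, gr.ChatInterface can take a generator function and updates the chat bubble with every partial string it yields. A toy sketch of that behaviour with a hard-coded reply (not my actual chain) would be:

import time
import gradio as gr

def fake_respond(message, history):
    reply = "This is a fake streamed reply."
    partial = ""
    for word in reply.split():
        partial += word + " "
        time.sleep(0.1)  # simulate token-by-token generation
        yield partial    # each yield updates the chat bubble

gr.ChatInterface(fake_respond).launch()

That is the behaviour I am after with ask_conversational_retrieval_chain.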
I have come across a few methods for streaming, but they all wait for the whole response to be generated before producing a "streaming-like" effect.
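Based on my reading, I suspect the answer involves a callback handler that pushes tokens onto a queue while the chain runs in a background thread, with the Gradio function yielding the accumulated text, roughly like the sketch below. (QueueCallbackHandler, token_queue and respond_streaming are names I made up, and I have not been able to verify that this actually streams with ConversationalRetrievalChain.)

from queue import Queue
from threading import Thread
from langchain.callbacks.base import BaseCallbackHandler

class QueueCallbackHandler(BaseCallbackHandler):
    """Push each new LLM token onto a queue as it is generated."""

    def __init__(self, token_queue):
        self.token_queue = token_queue

    def on_llm_new_token(self, token, **kwargs):
        self.token_queue.put(token)

    def on_llm_end(self, response, **kwargs):
        self.token_queue.put(None)  # sentinel marking the end of generation

def respond_streaming(message, history):
    token_queue = Queue()
    # Run the chain in a background thread so tokens can be consumed as they arrive
    thread = Thread(
        target=chain,
        args=({"question": message},),
        kwargs={"callbacks": [QueueCallbackHandler(token_queue)]},
    )
    thread.start()
    partial = ""
    while True:
        token = token_queue.get()
        if token is None:
            break
        partial += token
        yield partial  # Gradio updates the chat bubble with each yield
    thread.join()

Is this the right direction, or is there a simpler way to hook the streaming output into Gradio?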
P.S.: I am new to all this, so any help would be appreciated!