diff --git a/rag_demo/vector_chain.py b/rag_demo/vector_chain.py index afd89c0..1f5f320 100644 --- a/rag_demo/vector_chain.py +++ b/rag_demo/vector_chain.py @@ -1,6 +1,5 @@ from json import loads, dumps from langchain.prompts.prompt import PromptTemplate - from langchain_community.vectorstores import Neo4jVector from langchain.chains import RetrievalQAWithSourcesChain from langchain.chains.conversation.memory import ConversationBufferMemory @@ -11,7 +10,7 @@ VECTOR_PROMPT_TEMPLATE = """Human: You are a Financial expert with SEC filings who can answer questions only based on the context below. * Answer the question STRICTLY based on the context provided in JSON below. -* Do not assume or retrieve any information outside of the context +* Do not assume or retrieve any information outside of the context * Use three sentences maximum and keep the answer concise * Think step by step before answering. * Do not return helpful or extra text or apologies @@ -52,10 +51,10 @@ username = st.secrets["NEO4J_USERNAME"] password = st.secrets["NEO4J_PASSWORD"] - vector_store = None try: - logging.debug(f"Attempting to retrieve existing vector index: {index_name}...") + logging.debug( + f"Attempting to retrieve existing vector index: {index_name}...") vector_store = Neo4jVector.from_existing_index( embedding=EMBEDDING_MODEL, url=url, @@ -82,7 +81,8 @@ ) logging.debug(f"Created new index: {index_name}") except Exception as e: - logging.error(f"Failed to retrieve existing or to create a Neo4jVector: {e}") + logging.error( + f"Failed to retrieve existing or to create a Neo4jVector: {e}") if vector_store is None: logging.error(f"Failed to retrieve or create a Neo4jVector. Exiting.") @@ -133,26 +133,42 @@ def get_results(question) -> str: return result - # Using the vector store directly. But this could blow out the token count -# @retry(tries=5, delay=5) -# def get_results(question)-> str: -# """Generate response using Neo4jVector using vector index only -# Args: -# question (str): User query -# Returns: -# str: Formatted string answer with citations, if available. -# """ +@retry(tries=2, delay=5) +def get_results_minimized_tokens(question) -> str: + """Generate response using Neo4jVector with minimized token usage -# logging.info(f'Using Neo4j url: {url}') + Args: + question (str): User query -# # Returns a dict with keys: answer, sources -# vector_result = vector_store.similarity_search(question, k=3) + Returns: + str: Formatted string answer with citations, if available. + """ + logging.info(f"Using Neo4j url: {url}") -# logging.debug(f'chain_result: {vector_result}') + vector_result = vector_store.similarity_search(question, k=3) + context = { + "input": question, + "context": [doc.page_content for doc in vector_result] + } -# result = vector_result + chain_result = vector_chain.invoke( + {"question": question, "context": dumps(context)}, + prompt=VECTOR_PROMPT, + return_only_outputs=True, + ) + + logging.debug(f"chain_result: {chain_result}") -# return result + result = chain_result["answer"] + + # Cite sources, if any + sources = chain_result["sources"] + sources_split = sources.split(", ") + for source in sources_split: + if source != "" and source != "N/A" and source != "None": + result += f"\n - [{source}]({source})" + + return result