From 8e3f9bcc4f99f279ef70743c33e0c7d5fa3c1f20 Mon Sep 17 00:00:00 2001 From: John Watson Date: Thu, 9 Jan 2025 15:38:02 -0800 Subject: [PATCH] provide some configuration options for tweaking the chunk retrieval strategies (#99) * provide some configuration options for tweaking the chunk retrieval strategies * remove unused endpoint * make a method "private" * remove unused import --- llm-service/app/rag_types.py | 2 + llm-service/app/services/chat.py | 42 ++++---- llm-service/app/services/llm_completion.py | 12 ++- .../app/services/{qdrant.py => querier.py} | 101 +++++++++++++++--- llm-service/app/tests/services/test_chat.py | 2 - ui/src/api/chatApi.ts | 8 ++ .../Placeholders/SuggestedQuestionsCards.tsx | 2 + .../FooterComponents/RagChatQueryInput.tsx | 2 + 8 files changed, 134 insertions(+), 37 deletions(-) rename llm-service/app/services/{qdrant.py => querier.py} (56%) diff --git a/llm-service/app/rag_types.py b/llm-service/app/rag_types.py index a6a0be84..71dc96bb 100644 --- a/llm-service/app/rag_types.py +++ b/llm-service/app/rag_types.py @@ -50,3 +50,5 @@ class RagPredictConfiguration(BaseModel): chunk_size: int = 512 model_name: str = DEFAULT_BEDROCK_LLM_MODEL exclude_knowledge_base: Optional[bool] = False + use_question_condensing: Optional[bool] = True + use_hyde: Optional[bool] = False diff --git a/llm-service/app/services/chat.py b/llm-service/app/services/chat.py index 0db07d38..4ca6c16e 100644 --- a/llm-service/app/services/chat.py +++ b/llm-service/app/services/chat.py @@ -45,9 +45,7 @@ from llama_index.core.base.llms.types import MessageRole from llama_index.core.chat_engine.types import AgentChatResponse -from ..ai.vector_stores.qdrant import QdrantVectorStore -from ..rag_types import RagPredictConfiguration -from . import evaluators, qdrant +from . import evaluators, querier from .chat_store import ( ChatHistoryManager, Evaluation, @@ -55,13 +53,15 @@ RagPredictSourceNode, RagStudioChatMessage, ) +from ..ai.vector_stores.qdrant import QdrantVectorStore +from ..rag_types import RagPredictConfiguration def v2_chat( - session_id: int, - data_source_ids: list[int], - query: str, - configuration: RagPredictConfiguration, + session_id: int, + data_source_ids: list[int], + query: str, + configuration: RagPredictConfiguration, ) -> RagStudioChatMessage: response_id = str(uuid.uuid4()) @@ -84,7 +84,7 @@ def v2_chat( timestamp=time.time(), ) - response = qdrant.query( + response = querier.query( data_source_id, query, configuration, @@ -146,10 +146,10 @@ def format_source_nodes(response: AgentChatResponse) -> List[RagPredictSourceNod def generate_suggested_questions( - configuration: RagPredictConfiguration, - data_source_ids: list[int], - data_source_size: int, - session_id: int, + configuration: RagPredictConfiguration, + data_source_ids: list[int], + data_source_size: int, + session_id: int, ) -> List[str]: data_source_id = data_source_ids[0] chat_history = retrieve_chat_history(session_id) @@ -168,16 +168,16 @@ def generate_suggested_questions( ) if chat_history: query_str = ( - query_str - + ( - "I will provide a response from my last question to help with generating new questions." - " Consider returning questions that are relevant to the response" - " They might be follow up questions or questions that are related to the response." - " Here is the last response received:\n" - ) - + chat_history[-1].content + query_str + + ( + "I will provide a response from my last question to help with generating new questions." + " Consider returning questions that are relevant to the response" + " They might be follow up questions or questions that are related to the response." + " Here is the last response received:\n" + ) + + chat_history[-1].content ) - response = qdrant.query( + response = querier.query( data_source_id, query_str, configuration, diff --git a/llm-service/app/services/llm_completion.py b/llm-service/app/services/llm_completion.py index badf6b6f..a42a04d7 100644 --- a/llm-service/app/services/llm_completion.py +++ b/llm-service/app/services/llm_completion.py @@ -38,10 +38,11 @@ import itertools from llama_index.core.base.llms.types import ChatMessage, ChatResponse +from llama_index.core.llms import LLM -from ..rag_types import RagPredictConfiguration from .chat_store import ChatHistoryManager, RagStudioChatMessage from .models import get_llm +from ..rag_types import RagPredictConfiguration def make_chat_messages(x: RagStudioChatMessage) -> list[ChatMessage]: @@ -51,7 +52,7 @@ def make_chat_messages(x: RagStudioChatMessage) -> list[ChatMessage]: def completion( - session_id: int, question: str, configuration: RagPredictConfiguration + session_id: int, question: str, configuration: RagPredictConfiguration ) -> ChatResponse: model = get_llm(configuration.model_name) chat_history = ChatHistoryManager().retrieve_chat_history(session_id)[:10] @@ -62,3 +63,10 @@ def completion( ) messages.append(ChatMessage.from_str(question, role="user")) return model.chat(messages) + + +def hypothetical(question: str, configuration: RagPredictConfiguration) -> str: + model: LLM = get_llm(configuration.model_name) + prompt: str = (f"You are an expert. You are asked: {question}. " + "Produce a brief document that would hypothetically answer this question.") + return model.complete(prompt).text diff --git a/llm-service/app/services/qdrant.py b/llm-service/app/services/querier.py similarity index 56% rename from llm-service/app/services/qdrant.py rename to llm-service/app/services/querier.py index f4a941f8..7d7a8522 100644 --- a/llm-service/app/services/qdrant.py +++ b/llm-service/app/services/querier.py @@ -36,30 +36,101 @@ # DATA. # ############################################################################## import logging +from typing import Optional, List, Any import botocore.exceptions from fastapi import HTTPException -from llama_index.core.base.llms.types import ChatMessage +from llama_index.core import QueryBundle, PromptTemplate +from llama_index.core.base.llms.types import ChatMessage, MessageRole +from llama_index.core.callbacks import trace_method from llama_index.core.chat_engine import CondenseQuestionChatEngine from llama_index.core.chat_engine.types import AgentChatResponse from llama_index.core.indices import VectorStoreIndex from llama_index.core.indices.vector_store import VectorIndexRetriever +from llama_index.core.llms import LLM from llama_index.core.query_engine import RetrieverQueryEngine from llama_index.core.response_synthesizers import get_response_synthesizer +from . import models, llm_completion +from .chat_store import RagContext from ..ai.vector_stores.qdrant import QdrantVectorStore from ..rag_types import RagPredictConfiguration -from . import models -from .chat_store import RagContext logger = logging.getLogger(__name__) +CUSTOM_TEMPLATE = """\ +Given a conversation (between Human and Assistant) and a follow up message from Human, \ +rewrite the message to be a standalone question that captures all relevant context \ +from the conversation. Just provide the question, not any description of it. + + +{chat_history} + + +{question} + + +""" + +CUSTOM_PROMPT = PromptTemplate(CUSTOM_TEMPLATE) + + +class FlexibleChatEngine(CondenseQuestionChatEngine): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self._configuration : RagPredictConfiguration = RagPredictConfiguration() + + @property + def configuration(self) -> RagPredictConfiguration: + return self._configuration + + @configuration.setter + def configuration(self, value: RagPredictConfiguration) -> None: + self._configuration = value + + @trace_method("chat") + def chat( + self, message: str, chat_history: Optional[List[ChatMessage]] = None + ) -> AgentChatResponse: + chat_history = chat_history or self._memory.get(input=message) + + if self.configuration.use_question_condensing: + # Generate standalone question from conversation context and last message + condensed_question = self._condense_question(chat_history, message) + log_str = f"Querying with condensed question: {condensed_question}" + logger.info(log_str) + if self._verbose: + print(log_str) + message = condensed_question + + embedding_strings = None + if self.configuration.use_hyde: + hypothetical = llm_completion.hypothetical(message, self.configuration) + logger.info(f"hypothetical document: {hypothetical}") + embedding_strings = [hypothetical] + + # Query with standalone question + query_bundle = QueryBundle(message, custom_embedding_strs=embedding_strings) + query_response = self._query_engine.query(query_bundle) + + tool_output = self._get_tool_output_from_response( + message, query_response + ) + + # Record response + self._memory.put(ChatMessage(role=MessageRole.USER, content=message)) + self._memory.put( + ChatMessage(role=MessageRole.ASSISTANT, content=str(query_response)) + ) + + return AgentChatResponse(response=str(query_response), sources=[tool_output]) + def query( - data_source_id: int, - query_str: str, - configuration: RagPredictConfiguration, - chat_history: list[RagContext], + data_source_id: int, + query_str: str, + configuration: RagPredictConfiguration, + chat_history: list[RagContext], ) -> AgentChatResponse: qdrant_store = QdrantVectorStore.for_chunks(data_source_id) vector_store = qdrant_store.llama_vector_store() @@ -75,17 +146,13 @@ def query( similarity_top_k=configuration.top_k, embed_model=embedding_model, # is this needed, really, if it's in the index? ) - # TODO: factor out LLM and chat engine into a separate function llm = models.get_llm(model_name=configuration.model_name) response_synthesizer = get_response_synthesizer(llm=llm) query_engine = RetrieverQueryEngine( retriever=retriever, response_synthesizer=response_synthesizer ) - chat_engine = CondenseQuestionChatEngine.from_defaults( - query_engine=query_engine, - llm=llm, - ) + chat_engine = _build_chat_engine(configuration, llm, query_engine) logger.info("querying chat engine") chat_messages = list( @@ -106,3 +173,13 @@ def query( status_code=json_error["ResponseMetadata"]["HTTPStatusCode"], detail=json_error["message"], ) from error + + +def _build_chat_engine(configuration: RagPredictConfiguration, llm: LLM, query_engine: RetrieverQueryEngine)-> FlexibleChatEngine: + chat_engine: FlexibleChatEngine = FlexibleChatEngine.from_defaults( + query_engine=query_engine, + llm=llm, + condense_question_prompt=CUSTOM_PROMPT, + ) + chat_engine.configuration = configuration + return chat_engine diff --git a/llm-service/app/tests/services/test_chat.py b/llm-service/app/tests/services/test_chat.py index 4a24f221..006e5c95 100644 --- a/llm-service/app/tests/services/test_chat.py +++ b/llm-service/app/tests/services/test_chat.py @@ -79,8 +79,6 @@ def suggested_questions_responses( class TestProcessResponse: - # todo: add below for a failing case to be fixed! - # @reproduce_failure('6.122.1', b'AAAJAQAADQABAAEEAAABCgEAAQ0BAAEMAAAADQEAAA0A') @given(response=suggested_questions_responses()) @example(response="Empty Response") def test_process_response(self, response: str) -> None: diff --git a/ui/src/api/chatApi.ts b/ui/src/api/chatApi.ts index f84dac77..235265f0 100644 --- a/ui/src/api/chatApi.ts +++ b/ui/src/api/chatApi.ts @@ -69,6 +69,8 @@ export interface QueryConfiguration { top_k: number; model_name: string; exclude_knowledge_base: boolean; + use_question_condensing: boolean; + use_hyde: boolean; } export interface ChatMutationRequest { @@ -200,6 +202,7 @@ const chatMutation = async ( export const createQueryConfiguration = ( excludeKnowledgeBase: boolean, + forSuggestedQuestions: boolean, activeSession?: Session, ): QueryConfiguration => { // todo: maybe we should just throw an exception here? @@ -208,11 +211,16 @@ export const createQueryConfiguration = ( top_k: 5, model_name: "", exclude_knowledge_base: false, + use_question_condensing: false, + use_hyde: false, }; } + return { top_k: activeSession.responseChunks, model_name: activeSession.inferenceModel ?? "", exclude_knowledge_base: excludeKnowledgeBase, + use_question_condensing: !forSuggestedQuestions, + use_hyde: false, }; }; diff --git a/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx b/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx index 546798f0..56733b81 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx @@ -58,6 +58,7 @@ const SuggestedQuestionsCards = () => { data_source_ids: activeSession?.dataSourceIds ?? [], configuration: createQueryConfiguration( excludeKnowledgeBase, + true, activeSession, ), session_id: sessionId ?? "", @@ -86,6 +87,7 @@ const SuggestedQuestionsCards = () => { session_id: sessionId, configuration: createQueryConfiguration( excludeKnowledgeBase, + false, activeSession, ), }); diff --git a/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx b/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx index 6187dd42..61c741f4 100644 --- a/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx +++ b/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx @@ -63,6 +63,7 @@ const RagChatQueryInput = () => { const configuration = createQueryConfiguration( excludeKnowledgeBase, + true, activeSession, ); const { @@ -99,6 +100,7 @@ const RagChatQueryInput = () => { session_id: sessionId, configuration: createQueryConfiguration( excludeKnowledgeBase, + false, activeSession, ), });