
Commit

provide some configuration options for tweaking the chunk retrieval strategies (#99)

* provide some configuration options for tweaking the chunk retrieval strategies

* remove unused endpoint

* make a method "private"

* remove unused import
jkwatson authored Jan 9, 2025
1 parent a982e04 commit 8e3f9bc
Showing 8 changed files with 134 additions and 37 deletions.
2 changes: 2 additions & 0 deletions llm-service/app/rag_types.py
@@ -50,3 +50,5 @@ class RagPredictConfiguration(BaseModel):
chunk_size: int = 512
model_name: str = DEFAULT_BEDROCK_LLM_MODEL
exclude_knowledge_base: Optional[bool] = False
use_question_condensing: Optional[bool] = True
use_hyde: Optional[bool] = False
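
The two new fields are the knobs this commit exposes: use_question_condensing (default True) controls whether the chat history is condensed into a standalone question before retrieval, and use_hyde (default False) turns on HyDE-style retrieval, where a hypothetical answer document is embedded instead of the raw question. A minimal sketch of constructing the configuration with the flags flipped; the import path and field values are illustrative, not part of the commit:

# Sketch only: import path assumed from the repo layout (llm-service/app/rag_types.py);
# values below are examples, all other fields keep their defaults.
from app.rag_types import RagPredictConfiguration

config = RagPredictConfiguration(
    top_k=5,
    use_question_condensing=False,  # skip the standalone-question rewrite
    use_hyde=True,                  # retrieve with a hypothetical answer document
)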
42 changes: 21 additions & 21 deletions llm-service/app/services/chat.py
@@ -45,23 +45,23 @@
from llama_index.core.base.llms.types import MessageRole
from llama_index.core.chat_engine.types import AgentChatResponse

from ..ai.vector_stores.qdrant import QdrantVectorStore
from ..rag_types import RagPredictConfiguration
from . import evaluators, qdrant
from . import evaluators, querier
from .chat_store import (
ChatHistoryManager,
Evaluation,
RagContext,
RagPredictSourceNode,
RagStudioChatMessage,
)


def v2_chat(
session_id: int,
data_source_ids: list[int],
query: str,
configuration: RagPredictConfiguration,
) -> RagStudioChatMessage:
response_id = str(uuid.uuid4())

@@ -84,7 +84,7 @@ def v2_chat(
timestamp=time.time(),
)

response = qdrant.query(
response = querier.query(
data_source_id,
query,
configuration,
@@ -146,10 +146,10 @@ def format_source_nodes(response: AgentChatResponse) -> List[RagPredictSourceNod


def generate_suggested_questions(
configuration: RagPredictConfiguration,
data_source_ids: list[int],
data_source_size: int,
session_id: int,
) -> List[str]:
data_source_id = data_source_ids[0]
chat_history = retrieve_chat_history(session_id)
@@ -168,16 +168,16 @@ def generate_suggested_questions(
)
if chat_history:
query_str = (
query_str
+ (
"I will provide a response from my last question to help with generating new questions."
" Consider returning questions that are relevant to the response"
" They might be follow up questions or questions that are related to the response."
" Here is the last response received:\n"
)
+ chat_history[-1].content
)
response = qdrant.query(
response = querier.query(
data_source_id,
query_str,
configuration,
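
With retrieval now routed through querier.query, the same RagPredictConfiguration drives both regular chat and suggested-question generation, so the new flags apply to either path. A hypothetical call into v2_chat; the ids, question, and import path are illustrative:

# Sketch only: session/data-source ids and the question are made up.
from app.rag_types import RagPredictConfiguration
from app.services.chat import v2_chat

message = v2_chat(
    session_id=42,
    data_source_ids=[7],
    query="Summarize the onboarding guide.",
    configuration=RagPredictConfiguration(top_k=5, use_hyde=True),
)
# message is a RagStudioChatMessage; timestamp is one of the fields populated above.
print(message.timestamp)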
12 changes: 10 additions & 2 deletions llm-service/app/services/llm_completion.py
@@ -38,10 +38,11 @@
import itertools

from llama_index.core.base.llms.types import ChatMessage, ChatResponse
from llama_index.core.llms import LLM

from ..rag_types import RagPredictConfiguration
from .chat_store import ChatHistoryManager, RagStudioChatMessage
from .models import get_llm


def make_chat_messages(x: RagStudioChatMessage) -> list[ChatMessage]:
@@ -51,7 +52,7 @@ def make_chat_messages(x: RagStudioChatMessage) -> list[ChatMessage]:


def completion(
session_id: int, question: str, configuration: RagPredictConfiguration
) -> ChatResponse:
model = get_llm(configuration.model_name)
chat_history = ChatHistoryManager().retrieve_chat_history(session_id)[:10]
@@ -62,3 +63,10 @@ def completion(
)
messages.append(ChatMessage.from_str(question, role="user"))
return model.chat(messages)


def hypothetical(question: str, configuration: RagPredictConfiguration) -> str:
model: LLM = get_llm(configuration.model_name)
prompt: str = (f"You are an expert. You are asked: {question}. "
"Produce a brief document that would hypothetically answer this question.")
return model.complete(prompt).text
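
hypothetical() is the HyDE building block: when use_hyde is enabled, FlexibleChatEngine.chat (below) embeds the generated document in place of the question via QueryBundle.custom_embedding_strs. A rough sketch of that idea in isolation; the question is illustrative and the import paths are assumed:

# Sketch only: mirrors how FlexibleChatEngine.chat wires HyDE, outside the engine.
from llama_index.core import QueryBundle

from app.rag_types import RagPredictConfiguration
from app.services import llm_completion

config = RagPredictConfiguration(top_k=5, use_hyde=True)
question = "How are API keys rotated in this service?"

# Generate a hypothetical answer and embed that text instead of the raw question.
hyde_doc = llm_completion.hypothetical(question, config)
bundle = QueryBundle(question, custom_embedding_strs=[hyde_doc])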
@@ -36,30 +36,101 @@
# DATA.
# ##############################################################################
import logging
from typing import Optional, List, Any

import botocore.exceptions
from fastapi import HTTPException
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core import QueryBundle, PromptTemplate
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.callbacks import trace_method
from llama_index.core.chat_engine import CondenseQuestionChatEngine
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.indices.vector_store import VectorIndexRetriever
from llama_index.core.llms import LLM
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import get_response_synthesizer

from . import models, llm_completion
from .chat_store import RagContext
from ..ai.vector_stores.qdrant import QdrantVectorStore
from ..rag_types import RagPredictConfiguration
from . import models

logger = logging.getLogger(__name__)

CUSTOM_TEMPLATE = """\
Given a conversation (between Human and Assistant) and a follow up message from Human, \
rewrite the message to be a standalone question that captures all relevant context \
from the conversation. Just provide the question, not any description of it.
<Chat History>
{chat_history}
<Follow Up Message>
{question}
<Standalone question>
"""

CUSTOM_PROMPT = PromptTemplate(CUSTOM_TEMPLATE)


class FlexibleChatEngine(CondenseQuestionChatEngine):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._configuration : RagPredictConfiguration = RagPredictConfiguration()

@property
def configuration(self) -> RagPredictConfiguration:
return self._configuration

@configuration.setter
def configuration(self, value: RagPredictConfiguration) -> None:
self._configuration = value

@trace_method("chat")
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
chat_history = chat_history or self._memory.get(input=message)

if self.configuration.use_question_condensing:
# Generate standalone question from conversation context and last message
condensed_question = self._condense_question(chat_history, message)
log_str = f"Querying with condensed question: {condensed_question}"
logger.info(log_str)
if self._verbose:
print(log_str)
message = condensed_question

embedding_strings = None
if self.configuration.use_hyde:
hypothetical = llm_completion.hypothetical(message, self.configuration)
logger.info(f"hypothetical document: {hypothetical}")
embedding_strings = [hypothetical]

# Query with standalone question
query_bundle = QueryBundle(message, custom_embedding_strs=embedding_strings)
query_response = self._query_engine.query(query_bundle)

tool_output = self._get_tool_output_from_response(
message, query_response
)

# Record response
self._memory.put(ChatMessage(role=MessageRole.USER, content=message))
self._memory.put(
ChatMessage(role=MessageRole.ASSISTANT, content=str(query_response))
)

return AgentChatResponse(response=str(query_response), sources=[tool_output])


def query(
data_source_id: int,
query_str: str,
configuration: RagPredictConfiguration,
chat_history: list[RagContext],
) -> AgentChatResponse:
qdrant_store = QdrantVectorStore.for_chunks(data_source_id)
vector_store = qdrant_store.llama_vector_store()
@@ -75,17 +146,13 @@ def query(
similarity_top_k=configuration.top_k,
embed_model=embedding_model, # is this needed, really, if it's in the index?
)
# TODO: factor out LLM and chat engine into a separate function
llm = models.get_llm(model_name=configuration.model_name)

response_synthesizer = get_response_synthesizer(llm=llm)
query_engine = RetrieverQueryEngine(
retriever=retriever, response_synthesizer=response_synthesizer
)
chat_engine = CondenseQuestionChatEngine.from_defaults(
query_engine=query_engine,
llm=llm,
)
chat_engine = _build_chat_engine(configuration, llm, query_engine)

logger.info("querying chat engine")
chat_messages = list(
Expand All @@ -106,3 +173,13 @@ def query(
status_code=json_error["ResponseMetadata"]["HTTPStatusCode"],
detail=json_error["message"],
) from error


def _build_chat_engine(configuration: RagPredictConfiguration, llm: LLM, query_engine: RetrieverQueryEngine)-> FlexibleChatEngine:
chat_engine: FlexibleChatEngine = FlexibleChatEngine.from_defaults(
query_engine=query_engine,
llm=llm,
condense_question_prompt=CUSTOM_PROMPT,
)
chat_engine.configuration = configuration
return chat_engine
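
End to end, the configuration flows from query() into _build_chat_engine and then FlexibleChatEngine.chat, where the condensing and HyDE branches run. A hypothetical call into this module, assuming the import path and an existing data source; the id, question, and empty history are illustrative:

# Sketch only: illustrative values; query() returns an AgentChatResponse
# whose source_nodes carry the retrieved chunks.
from app.rag_types import RagPredictConfiguration
from app.services import querier

config = RagPredictConfiguration(
    top_k=5,
    use_question_condensing=True,  # rewrite the last message into a standalone question
    use_hyde=True,                 # embed a hypothetical answer document for retrieval
)

response = querier.query(
    data_source_id=1,
    query_str="What does the maintenance window policy say?",
    configuration=config,
    chat_history=[],
)
print(response.response)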
2 changes: 0 additions & 2 deletions llm-service/app/tests/services/test_chat.py
@@ -79,8 +79,6 @@ def suggested_questions_responses(


class TestProcessResponse:
# todo: add below for a failing case to be fixed!
# @reproduce_failure('6.122.1', b'AAAJAQAADQABAAEEAAABCgEAAQ0BAAEMAAAADQEAAA0A')
@given(response=suggested_questions_responses())
@example(response="Empty Response")
def test_process_response(self, response: str) -> None:
8 changes: 8 additions & 0 deletions ui/src/api/chatApi.ts
@@ -69,6 +69,8 @@ export interface QueryConfiguration {
top_k: number;
model_name: string;
exclude_knowledge_base: boolean;
use_question_condensing: boolean;
use_hyde: boolean;
}

export interface ChatMutationRequest {
@@ -200,6 +202,7 @@ const chatMutation = async (

export const createQueryConfiguration = (
excludeKnowledgeBase: boolean,
forSuggestedQuestions: boolean,
activeSession?: Session,
): QueryConfiguration => {
// todo: maybe we should just throw an exception here?
@@ -208,11 +211,16 @@
top_k: 5,
model_name: "",
exclude_knowledge_base: false,
use_question_condensing: false,
use_hyde: false,
};
}

return {
top_k: activeSession.responseChunks,
model_name: activeSession.inferenceModel ?? "",
exclude_knowledge_base: excludeKnowledgeBase,
use_question_condensing: !forSuggestedQuestions,
use_hyde: false,
};
};
@@ -58,6 +58,7 @@ const SuggestedQuestionsCards = () => {
data_source_ids: activeSession?.dataSourceIds ?? [],
configuration: createQueryConfiguration(
excludeKnowledgeBase,
true,
activeSession,
),
session_id: sessionId ?? "",
@@ -86,6 +87,7 @@ const SuggestedQuestionsCards = () => {
session_id: sessionId,
configuration: createQueryConfiguration(
excludeKnowledgeBase,
false,
activeSession,
),
});
@@ -63,6 +63,7 @@ const RagChatQueryInput = () => {

const configuration = createQueryConfiguration(
excludeKnowledgeBase,
true,
activeSession,
);
const {
@@ -99,6 +100,7 @@
session_id: sessionId,
configuration: createQueryConfiguration(
excludeKnowledgeBase,
false,
activeSession,
),
});
