Skip to content

Commit

Permalink
Adjust passing around of configs and params
Browse files Browse the repository at this point in the history
  • Loading branch information
morganmcg1 committed Jan 29, 2025
1 parent 311d9d2 commit 3b9fd86
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 41 deletions.
6 changes: 1 addition & 5 deletions src/wandbot/chat/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,7 @@ def __init__(
)
self.retrieval_engine = FusionRetrievalEngine(
vector_store=vector_store,
top_k=chat_config.top_k,
search_type=chat_config.search_type,
english_reranker_model=chat_config.english_reranker_model,
multilingual_reranker_model=chat_config.multilingual_reranker_model,
do_web_search=chat_config.do_web_search,
chat_config=chat_config,
)
self.response_synthesizer = ResponseSynthesizer(
model=chat_config.response_synthesizer_model,
Expand Down
2 changes: 1 addition & 1 deletion src/wandbot/configs/chat_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ChatConfig(BaseSettings):
redundant_similarity_threshold: float = 0.95 # used to remove very similar retrieved documents

# Retrieval settings: MMR settings
fetch_k: int = 20 # Used in mmr retrieval. Typically set as top_k * 4
fetch_k: int = 60 # Used in mmr retrieval. Typically set as top_k * 4
mmr_lambda_mult: float = 0.5 # used in mmr retrieval

# Reranker models
Expand Down
47 changes: 25 additions & 22 deletions src/wandbot/rag/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from wandbot.schema.document import Document
from wandbot.schema.retrieval import RetrievalResult
from wandbot.schema.api_status import APIStatus
from wandbot.configs.chat_config import ChatConfig
import cohere
import weave
from weave.trace.autopatch import autopatch
Expand All @@ -26,18 +27,10 @@ class FusionRetrievalEngine:
def __init__(
self,
vector_store: VectorStore,
top_k: int,
search_type: str,
english_reranker_model: str,
multilingual_reranker_model: str,
do_web_search: bool
chat_config: ChatConfig,
):
self.vectorstore = vector_store
self.top_k = top_k
self.search_type = search_type
self.english_reranker_model = english_reranker_model
self.multilingual_reranker_model = multilingual_reranker_model
self.do_web_search = do_web_search
self.chat_config = chat_config
try:
self.reranker_client = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
except Exception as e:
Expand All @@ -59,8 +52,8 @@ def rerank_results(

documents = [doc.page_content for doc in context]
reranker_model_name = (
self.english_reranker_model if language == "en"
else self.multilingual_reranker_model
self.chat_config.english_reranker_model if language == "en"
else self.chat_config.multilingual_reranker_model
)
assert isinstance(query, str), "In rerank, `query` must be a string"
assert len(documents) > 0, "No context documents passed to the re-ranker"
Expand Down Expand Up @@ -130,23 +123,33 @@ async def _run_retrieval_common(self, inputs: Dict[str, Any], use_async: bool) -
docs_context, web_search_results = await asyncio.gather(
self.vectorstore._async_retrieve(
query_texts=inputs["all_queries"],
search_type=self.search_type,
search_type=self.chat_config.search_type,
search_params={
"top_k": self.chat_config.top_k_per_query,
"fetch_k": self.chat_config.fetch_k,
"lambda_mult": self.chat_config.mmr_lambda_mult
}
),
_async_run_web_search(
query=inputs["standalone_query"],
top_k=self.top_k,
avoid=not self.do_web_search
top_k=self.chat_config.top_k,
avoid=not self.chat_config.do_web_search
)
)
else:
docs_context = self.vectorstore.retrieve(
query_texts=inputs["all_queries"],
search_type=self.search_type,
search_type=self.chat_config.search_type,
search_params={
"top_k": self.chat_config.top_k_per_query,
"fetch_k": self.chat_config.fetch_k,
"lambda_mult": self.chat_config.mmr_lambda_mult
}
)
web_search_results = run_sync(_async_run_web_search(
query=inputs["standalone_query"],
top_k=self.top_k,
avoid=not self.do_web_search
top_k=self.chat_config.top_k,
avoid=not self.chat_config.do_web_search
))

def flatten_retrieved_results(results: Dict[str, Any]) -> tuple[List[Document], Any]:
Expand Down Expand Up @@ -187,19 +190,19 @@ def flatten_retrieved_results(results: Dict[str, Any]) -> tuple[List[Document],
context, api_status = await self._async_rerank_results(
query=inputs["standalone_query"],
context=fused_context_deduped,
top_k=self.top_k,
top_k=self.chat_config.top_k,
language=inputs["language"]
)
else:
context, api_status = self.rerank_results(
query=inputs["standalone_query"],
context=fused_context_deduped,
top_k=self.top_k,
top_k=self.chat_config.top_k,
language=inputs["language"]
)
if api_status.has_error:
logger.error(f"FUSION-RETRIEVAL: Reranker failed: {api_status.error_message}")
context = fused_context_deduped[:self.top_k] # Fallback to non-reranked results
context = fused_context_deduped[:self.chat_config.top_k] # Fallback to non-reranked results
raise Exception(api_status.error_message) # Raise for weave tracing
except Exception as e:
error_info = ErrorInfo(
Expand All @@ -209,7 +212,7 @@ def flatten_retrieved_results(results: Dict[str, Any]) -> tuple[List[Document],
error_type=type(e).__name__,
stacktrace=''.join(traceback.format_exc())
)
context = fused_context_deduped[:self.top_k] # Fallback to non-reranked results
context = fused_context_deduped[:self.chat_config.top_k] # Fallback to non-reranked results

logger.debug(f"RETRIEVAL-ENGINE: Reranked {len(context)} documents.")

Expand Down
63 changes: 50 additions & 13 deletions src/wandbot/retriever/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,34 +50,71 @@ def from_config(cls, vector_store_config: VectorStoreConfig, chat_config: ChatCo
return cls(vector_store_config=vector_store_config, chat_config=chat_config)

@weave.op
def retrieve(self, query_texts: List[str], search_type: str = "mmr", search_kwargs: dict = None) -> Dict[str, List[Document]]:
"""retrieve method returns a list of documents per query in query_texts."""
search_kwargs = search_kwargs or {}
def retrieve(
self,
query_texts: List[str],
search_type: str = "mmr",
search_params: dict = None,
filter_params: dict = None,
) -> Dict[str, List[Document]]:
"""Retrieve documents using either MMR or similarity search.
Args:
query_texts: List of queries to search for
search_type: Type of search ("mmr" or "similarity")
search_params: Parameters specific to the search type
For MMR: {"top_k": int, "fetch_k": int, "lambda_mult": float}
For similarity: {"top_k": int}
filter_params: Optional filtering parameters
{"filter": dict, "where_document": dict}
"""
# Use config as defaults if not provided in search_params
if search_type == "mmr":
default_params = {
"top_k": self.chat_config.top_k_per_query,
"fetch_k": self.chat_config.fetch_k,
"lambda_mult": self.chat_config.mmr_lambda_mult
}
else:
default_params = {
"top_k": self.chat_config.top_k_per_query
}

# Merge provided params with defaults
search_params = {**default_params, **(search_params or {})}
filter_params = filter_params or {}

if search_type == "mmr":
results = self.chroma_vectorstore.max_marginal_relevance_search(
query_texts=query_texts,
top_k=self.chat_config.top_k_per_query,
fetch_k=self.chat_config.fetch_k,
lambda_mult=self.chat_config.mmr_lambda_mult,
filter=search_kwargs.get("filter"),
where_document=search_kwargs.get("where_document")
top_k=search_params.get("top_k", self.chat_config.top_k_per_query),
fetch_k=search_params.get("fetch_k", self.chat_config.fetch_k),
lambda_mult=search_params.get("lambda_mult", self.chat_config.mmr_lambda_mult),
filter=filter_params.get("filter"),
where_document=filter_params.get("where_document")
)
else:
results = self.chroma_vectorstore.similarity_search(
query_texts=query_texts,
top_k=self.chat_config.top_k_per_query,
filter=search_kwargs.get("filter"),
where_document=search_kwargs.get("where_document")
top_k=search_params.get("top_k", self.chat_config.top_k_per_query),
filter=filter_params.get("filter"),
where_document=filter_params.get("where_document")
)

return results

async def _async_retrieve(self, query_texts: List[str], search_type: str = "mmr", search_kwargs: dict = None) -> Dict[str, List[Document]]:
async def _async_retrieve(
self,
query_texts: List[str],
search_type: str = "mmr",
search_params: dict = None,
filter_params: dict = None
) -> Dict[str, List[Document]]:
"""Async version of retrieve that returns the same dictionary structure."""
return await asyncio.to_thread(
self.retrieve,
query_texts=query_texts,
search_type=search_type,
search_kwargs=search_kwargs,
search_params=search_params,
filter_params=filter_params,
)

0 comments on commit 3b9fd86

Please sign in to comment.