🐛 fixing app running issue
rathijitpapon committed Mar 26, 2024
1 parent dda9a50 commit eb0c633
Showing 6 changed files with 519 additions and 41 deletions.
1 change: 1 addition & 0 deletions backend/.env.template
@@ -71,6 +71,7 @@ WANDB_API_KEY=
# Sentry Configuration
## need to be set before running the server
SENTRY_DSN=
+ SENTRY_ENABLE_TRACING=

# GROQ API Configuration
## need to be set before running the server
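The new SENTRY_ENABLE_TRACING variable pairs with the SENTRY_ENABLE_TRACING setting added to backend/app/config.py below. A minimal sketch of how such a flag is typically passed to the Sentry SDK at startup; the init call itself is not part of this commit, and the import path of the settings module is assumed:

import sentry_sdk

from app import config  # assumed import path for the settings defined in backend/app/config.py

sentry_sdk.init(
    dsn=str(config.SENTRY_DSN),
    enable_tracing=config.SENTRY_ENABLE_TRACING,  # falls back to False when the env var is unset
)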
15 changes: 9 additions & 6 deletions backend/app/api/endpoints/search_endpoint.py
@@ -3,6 +3,7 @@
from fastapi_versioning import version
from authx import AuthX, AuthXConfig
import sentry_sdk
+ import json

from app.api.common.util import RouteCategory
from app.database.redis import Redis
@@ -25,16 +26,16 @@


@router.get(
"/search/{routecategory}",
"/search",
summary="List all Search Results",
description="List all Search Results",
dependencies=[Depends(security.access_token_required)],
- response_model=str,
+ response_model=dict[str, str],
)
@version(1, 0)
async def get_search_results(
query: str = "",
- routecategory: RouteCategory = RouteCategory.PBW
+ routecategory: RouteCategory = RouteCategory.NS
) -> JSONResponse:
if trace_transaction := sentry_sdk.Hub.current.scope.transaction:
trace_transaction.set_tag("title", 'api_get_search_results')
@@ -43,13 +44,15 @@ async def get_search_results(

query = query.strip()
cache = Redis()
- search_result = await cache.get_value(query)
+ cache_key = f"{query}##{routecategory}"
+ search_result = await cache.get_value(cache_key)

if search_result:
+ search_result = json.loads(search_result)
logger.info(f"get_search_results. cached_result: {search_result}")
else:
- search_result = await orchestrator.query_and_get_answer(routecategory, search_text=query)
- await cache.set_value(query, search_result['result'])
+ search_result = await orchestrator.query_and_get_answer(search_text=query, routecategory=routecategory)
+ await cache.set_value(cache_key, json.dumps(search_result))

await cache.add_to_sorted_set("searched_queries", query)

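A condensed sketch of the caching behavior this change introduces: the Redis key now combines the query text with the route category, and the cached value is the full result dict serialized as JSON rather than the bare result string. The helper below is illustrative only; the Redis wrapper and orchestrator are the app's own classes:

import json

async def cached_search(query: str, routecategory, cache, orchestrator) -> dict:
    # Keying on query + route category keeps results cached for one route
    # from being served for a different route with the same query text.
    cache_key = f"{query.strip()}##{routecategory}"
    cached = await cache.get_value(cache_key)
    if cached:
        return json.loads(cached)  # the whole result dict was stored as JSON
    result = await orchestrator.query_and_get_answer(search_text=query, routecategory=routecategory)
    await cache.set_value(cache_key, json.dumps(result))
    return result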
8 changes: 7 additions & 1 deletion backend/app/config.py
@@ -112,6 +112,7 @@

# Sentry Configuration
SENTRY_DSN: Secret = config("SENTRY_DSN", cast=Secret)
+ SENTRY_ENABLE_TRACING: bool = config("SENTRY_ENABLE_TRACING", cast=bool, default=False)

# GROQ API Configuration
GROQ_API_KEY: Secret = config("GROQ_API_KEY", cast=Secret)
@@ -129,4 +130,9 @@

# LLAMA_INDEX Configuration
CHAT_ENABLED: bool = config("CHAT_ENABLED", default=False)
- PUBMED_RELEVANCE_CRITERIA: float = config("PUBMED_RELEVANCE_CRITERIA", default=0.7)
+ PUBMED_RELEVANCE_CRITERIA: float = config("PUBMED_RELEVANCE_CRITERIA", default=0.7)

+ # Dspy Integration Configuration
+ CLINICAL_TRIAL_SQL_PROGRAM: str = "app/dspy_integration/dspy_programs/clinical_trials_sql_generation.json"
+ CLINICAL_TRIALS_RESPONSE_REFINEMENT_PROGRAM: str = "app/dspy_integration/dspy_programs/clinical_trials_response_refinement.json"
+ ORCHESRATOR_ROUTER_PROMPT_PROGRAM: str = "app/dspy_integration/dspy_programs/orchestrator_router_prompt.json"
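The three new paths appear to point at serialized DSPy programs under app/dspy_integration. A hypothetical sketch of how one of them might be restored at runtime, assuming a dspy.Module subclass defined elsewhere in the repo (the module and class names below are made up for illustration):

import dspy

from app.config import CLINICAL_TRIAL_SQL_PROGRAM
from app.dspy_integration.clinical_trials_sql import ClinicalTrialSQL  # hypothetical module/class

sql_program = ClinicalTrialSQL()
sql_program.load(CLINICAL_TRIAL_SQL_PROGRAM)  # DSPy modules can reload state saved as JSON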
71 changes: 38 additions & 33 deletions backend/app/router/orchestrator.py
@@ -1,9 +1,6 @@
import asyncio
import re

- from llama_index.core.tools import ToolMetadata
- from llama_index.core.selectors import LLMSingleSelector
- from llama_index.llms.openai import OpenAI
from llama_index.core.schema import QueryBundle
from llama_index.llms.together import TogetherLLM
from llama_index.core.response_synthesizers import SimpleSummarize
@@ -48,11 +45,12 @@ def __init__(self, config):
self.bravesearch = BraveSearchQueryEngine(config)

async def query_and_get_answer(
- self,
- routecategory: RouteCategory = RouteCategory.PBW,
- search_text: str = "") -> str:
+ self,
+ search_text: str,
+ routecategory: RouteCategory = RouteCategory.NS
+ ) -> dict[str, str]:
# search router call
- logger.debug(
+ logger.info(
f"Orchestrator.query_and_get_answer.router_id search_text: {search_text}"
)

@@ -68,28 +66,30 @@ async def query_and_get_answer(
logger.exception("query_and_get_answer.router_id Exception -", exc_info = e, stack_info=True)
logger.info(f"query_and_get_answer.router_id router_id: {router_id}")

- breaks_sql = False

#routing
if router_id == 0 or routecategory == RouteCategory.CT:
# clinical trial call
- logger.debug(
+ logger.info(
"Orchestrator.query_and_get_answer.router_id clinical trial Entered."
)
try:
sqlResponse = self.clinicalTrialSearch.call_text2sql(search_text=search_text)
result = str(sqlResponse)
sources = result

logger.debug(f"Orchestrator.query_and_get_answer.sqlResponse sqlResponse: {result}")
logger.info(f"Orchestrator.query_and_get_answer.sqlResponse sqlResponse: {result}")

+ return {
+ "result" : result,
+ "sources": sources
+ }
except Exception as e:
- breaks_sql = True
logger.exception("Orchestrator.query_and_get_answer.sqlResponse Exception -", exc_info = e, stack_info=True)
+ pass

elif router_id == 1 or routecategory == RouteCategory.DRUG:
# drug information call
- logger.debug(
+ logger.info(
"Orchestrator.query_and_get_answer.router_id drug_information_choice Entered."
)
try:
@@ -98,38 +98,43 @@
)
result = str(cypherResponse)
sources = result
- logger.debug(
+ logger.info(
f"Orchestrator.query_and_get_answer.cypherResponse cypherResponse: {result}"
)

+ return {
+ "result" : result,
+ "sources": sources
+ }
except Exception as e:
- breaks_sql = True
logger.exception(
"Orchestrator.query_and_get_answer.cypherResponse Exception -",
exc_info=e,
stack_info=True,
)

- if router_id == 2 or routecategory == RouteCategory.PBW or routecategory == RouteCategory.NS or breaks_sql:
- logger.debug(
- "Orchestrator.query_and_get_answer.router_id Fallback Entered."
- )
-
- extracted_pubmed_results, extracted_web_results = await asyncio.gather(
- self.pubmedsearch.call_pubmed_vectors(search_text=search_text), self.bravesearch.call_brave_search_api(search_text=search_text)
- )
- extracted_results = extracted_pubmed_results + extracted_web_results
- logger.debug(
- f"Orchestrator.query_and_get_answer.extracted_results count: {len(extracted_pubmed_results), len(extracted_web_results)}"
- )
+ # if routing fails, sql and cypher calls fail, routing to pubmed or brave
+ logger.info(
+ "Orchestrator.query_and_get_answer.router_id Fallback Entered."
+ )
+
+ extracted_pubmed_results, extracted_web_results = await asyncio.gather(
+ self.pubmedsearch.call_pubmed_vectors(search_text=search_text), self.bravesearch.call_brave_search_api(search_text=search_text)
+ )
+ extracted_results = extracted_pubmed_results + extracted_web_results
+ logger.info(
+ f"Orchestrator.query_and_get_answer.extracted_results count: {len(extracted_pubmed_results), len(extracted_web_results)}"
+ )

- # rerank call
- reranked_results = TextEmbeddingInferenceRerankEngine(top_n=2)._postprocess_nodes(
- nodes = extracted_results,
- query_bundle=QueryBundle(query_str=search_text))
+ # rerank call
+ reranked_results = TextEmbeddingInferenceRerankEngine(top_n=2)._postprocess_nodes(
+ nodes = extracted_results,
+ query_bundle=QueryBundle(query_str=search_text))

- summarizer = SimpleSummarize(llm=TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key=str(TOGETHER_KEY)))
- result = summarizer.get_response(query_str=search_text, text_chunks=[TAG_RE.sub('', node.get_content()) for node in reranked_results])
- sources = [node.node.metadata for node in reranked_results ]
+ summarizer = SimpleSummarize(llm=TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key=str(TOGETHER_KEY)))
+ result = summarizer.get_response(query_str=search_text, text_chunks=[TAG_RE.sub('', node.get_content()) for node in reranked_results])
+ sources = [node.node.metadata for node in reranked_results ]

return {
"result" : result,
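For orientation, a condensed sketch of the fallback path as it now stands, pulled together from the added lines above into one hypothetical helper (the method name _fallback_answer is made up; in the actual file this code sits inline in query_and_get_answer). The fallback now runs whenever neither the clinical-trial nor the drug branch has returned, replacing the gate on router_id == 2 and the dropped breaks_sql flag. The engine classes, TAG_RE and TOGETHER_KEY are the app's own names from the surrounding file:

async def _fallback_answer(self, search_text: str) -> dict:
    # run PubMed vector search and Brave web search concurrently
    pubmed_nodes, web_nodes = await asyncio.gather(
        self.pubmedsearch.call_pubmed_vectors(search_text=search_text),
        self.bravesearch.call_brave_search_api(search_text=search_text),
    )
    # rerank the merged results and keep the two best nodes
    reranked = TextEmbeddingInferenceRerankEngine(top_n=2)._postprocess_nodes(
        nodes=pubmed_nodes + web_nodes,
        query_bundle=QueryBundle(query_str=search_text),
    )
    # summarize the reranked chunks with the Together-hosted Mixtral model
    summarizer = SimpleSummarize(
        llm=TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key=str(TOGETHER_KEY))
    )
    result = summarizer.get_response(
        query_str=search_text,
        text_chunks=[TAG_RE.sub("", node.get_content()) for node in reranked],
    )
    return {"result": result, "sources": [node.node.metadata for node in reranked]}

After this commit every branch returns the same {"result", "sources"} dict, which is what the endpoint above serializes into the cache.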