diff --git a/backend/.vscode/settings.json b/backend/.vscode/settings.json index 19a8159d..614a3fe4 100644 --- a/backend/.vscode/settings.json +++ b/backend/.vscode/settings.json @@ -3,21 +3,36 @@ "AACT", "arize", "biopharmaceutical", + "bioxriv", "bravesearch", + "CHEMBL", + "clinicaltrials", "Curieo", + "cypher", + "GROQ", "llms", "mistralai", + "Mixtral", "openai", "perticipents", + "postprocess", + "postprocessor", "pubmed", + "pubmedqueryengine", + "pubmedsearch", "pydantic", "pyvis", + "QDRANT", "Rerank", "reranked", "reranker", + "reranking", + "routecategory", "sqlalchemy", "starlette", "togetherai", - "uvicorn" + "topqueries", + "uvicorn", + "WANDB" ] } \ No newline at end of file diff --git a/backend/app/api/common/util.py b/backend/app/api/common/util.py new file mode 100644 index 00000000..895fb382 --- /dev/null +++ b/backend/app/api/common/util.py @@ -0,0 +1,7 @@ +from enum import Enum + +class RouteCategory(str, Enum): + CT = "clinicaltrials" + DRUG = "drug" + PBW = "pubmed_bioxriv_web" + NS = "not_selected" \ No newline at end of file diff --git a/backend/app/api/endpoints/search_endpoint.py b/backend/app/api/endpoints/search_endpoint.py index 56d13183..fe4db234 100644 --- a/backend/app/api/endpoints/search_endpoint.py +++ b/backend/app/api/endpoints/search_endpoint.py @@ -2,10 +2,10 @@ from fastapi.responses import JSONResponse from fastapi_versioning import version from authx import AuthX, AuthXConfig -from app.database.redis import Redis import sentry_sdk - +from app.api.common.util import RouteCategory +from app.database.redis import Redis from app.api.router.gzip import GzipRoute from app.router.orchestrator import Orchestrator from app.config import config, JWT_SECRET_KEY, JWT_ALGORITHM @@ -25,15 +25,16 @@ @router.get( - "/search", + "/search/{routecategory}", summary="List all Search Results", description="List all Search Results", dependencies=[Depends(security.access_token_required)], - response_model=dict[str, str] + response_model=str, ) @version(1, 0) async def get_search_results( - query: str = "" + query: str = "", + routecategory: RouteCategory = RouteCategory.PBW ) -> JSONResponse: if trace_transaction := sentry_sdk.Hub.current.scope.transaction: trace_transaction.set_tag("title", 'api_get_search_results') @@ -47,14 +48,14 @@ async def get_search_results( if search_result: logger.info(f"get_search_results. cached_result: {search_result}") else: - search_result = await orchestrator.query_and_get_answer(search_text=query) - await cache.set_value(query, search_result) + search_result = await orchestrator.query_and_get_answer(routecategory, search_text=query) + await cache.set_value(query, search_result['result']) await cache.add_to_sorted_set("searched_queries", query) logger.info(f"get_search_results. result: {search_result}") - return JSONResponse(status_code=200, content={"result": search_result}) + return JSONResponse(status_code=200, content=search_result) @router.get( @@ -81,4 +82,4 @@ async def get_top_search_queries( logger.info(f"get_top_search_queries. 
result: {last_x_keys}") - return JSONResponse(status_code=200, content=last_x_keys) + return JSONResponse(status_code=200, content=last_x_keys) \ No newline at end of file diff --git a/backend/app/rag/reranker/response_reranker.py b/backend/app/rag/reranker/response_reranker.py index 9d178abe..6fea023f 100644 --- a/backend/app/rag/reranker/response_reranker.py +++ b/backend/app/rag/reranker/response_reranker.py @@ -1,69 +1,88 @@ -import collections +from typing import Any, List, Optional import requests import re +from llama_index.core.bridge.pydantic import Field, PrivateAttr +from llama_index.core.callbacks import CBEventType, EventPayload +from llama_index.core.postprocessor.types import BaseNodePostprocessor +from llama_index.core.schema import NodeWithScore, QueryBundle + from app.services.search_utility import setup_logger -from app.config import EMBEDDING_RERANK_API, EMBEDDING_CHUNK_SIZE +from app.config import EMBEDDING_RERANK_API, RERANK_TOP_COUNT -logger = setup_logger('Reranking') TAG_RE = re.compile(r'<[^>]+>') -class ReRankEngine: +logger = setup_logger("TextEmbeddingInferenceRerankEngine") + +class TextEmbeddingInferenceRerankEngine(BaseNodePostprocessor): """ - This class implements the logic re-ranking response and returns the results. - It uses the embedding api that process the query and responses from the retrieval layer. - It returns the output in list format. + The class extends the BaseNodePostprocessor class, aimed at reranking nodes (elements) based on text embedding inference. + This class is part of a larger framework, likely for processing and analyzing data within a specific domain, + such as document retrieval or search engine optimization. Here's an overview of the class and its components: """ + model: str = Field( + default="BAAI/bge-reranker-large", + description="The model to use when calling AI API", + ) + _session: Any = PrivateAttr() - def __init__(self, config): - self.config = config - - async def call_embedding_api( + def __init__( self, - search_text: str, - retrieval_results: collections.defaultdict[list] - ) -> collections.defaultdict[list]: - logger.info("call_embedding_api. search_text: " + search_text) - logger.info("call_embedding_api. retrieval_results length: " + str(len(retrieval_results))) - - endpoint = "{url_address}".format( - url_address=EMBEDDING_RERANK_API - ) - - headers = { - 'Accept': 'application/json' - } - - results = collections.defaultdict(list) - - #clean the data - retrieval_results_text_data = [result.get('text') for result in retrieval_results.values()] - retrieval_results_clean_text_data = [payload.replace("\n", " ").replace("\"","") for payload in retrieval_results_text_data] - retrieval_results_clean_html_data = [TAG_RE.sub('', payload) for payload in retrieval_results_clean_text_data] - - #chunking the data - payload = [payload[:EMBEDDING_CHUNK_SIZE] for payload in retrieval_results_clean_html_data] + top_n: int = 2, + model: str = "BAAI/bge-reranker-large" + ): + super().__init__(top_n=top_n, model=model) + self.model = model + self._session = requests.Session() - request_data = { - "query": search_text, - "texts": payload - } - try: - logger.info("call_embedding_api. endpoint: " + endpoint) - logger.info("call_embedding_api. headers: " + str(headers)) - logger.info("call_embedding_api. 
request_data: " + str(request_data)) + @classmethod + def class_name(cls) -> str: + return "TextEmbeddingInferenceRerankEngine" - response = requests.request("POST", endpoint, headers=headers, json=request_data) - response.raise_for_status() + def _postprocess_nodes( + self, + nodes: List[NodeWithScore], + query_bundle: Optional[QueryBundle] = None, + ) -> List[NodeWithScore]: + """ + This method takes a list of nodes (each represented by a NodeWithScore object) and an optional QueryBundle object. + It performs the reranking operation by: + -- Validating the input. + -- Extracting text from each node, removing HTML tags. + -- Sending a request to the specified AI API with the extracted texts and additional query information. + -- Processing the API's response to update the nodes' scores based on reranking results. + -- Returning the top N reranked nodes, according to the class's top_n attribute and possibly constrained by a global RERANK_TOP_COUNT. + """ + if query_bundle is None: + raise ValueError("Missing query bundle in extra info.") + if len(nodes) == 0: + return [] - rerank_orders = response.json() - results = [retrieval_results[order.get('index')] for order in rerank_orders] - except Exception as ex: - logger.exception("call_embedding_api Exception -", exc_info = ex, stack_info=True) - raise ex - - logger.info(f"call_embedding_api. results: {results}") + with self.callback_manager.event( + CBEventType.RERANKING + ) as event: + logger.info("TextEmbeddingInferenceRerankEngine.postprocess_nodes query: " + query_bundle.query_str) + texts = [TAG_RE.sub('', node.get_content()) for node in nodes] + results = self._session.post( # type: ignore + EMBEDDING_RERANK_API, + json={ + "query": query_bundle.query_str, + "truncate": True, + "texts": texts + }, + ).json() + + if len(results) == 0: + raise RuntimeError(results["detail"]) + + new_nodes = [] + for result in results: + new_node_with_score = NodeWithScore( + node=nodes[result["index"]], score=result["score"] + ) + new_nodes.append(new_node_with_score) + event.on_end(payload={EventPayload.NODES: new_nodes}) - return results + return new_nodes[:RERANK_TOP_COUNT] \ No newline at end of file diff --git a/backend/app/router/orchestrator.py b/backend/app/router/orchestrator.py index 07372082..d4e7f0ec 100644 --- a/backend/app/router/orchestrator.py +++ b/backend/app/router/orchestrator.py @@ -1,18 +1,24 @@ -import dspy -from app.rag.retrieval.web.brave_search import BraveSearchQueryEngine +import asyncio +import re + +from llama_index.core.tools import ToolMetadata +from llama_index.core.selectors import LLMSingleSelector +from llama_index.llms.openai import OpenAI +from llama_index.core.schema import QueryBundle +from llama_index.llms.together import TogetherLLM +from llama_index.core.response_synthesizers import SimpleSummarize + from app.rag.retrieval.clinical_trials.clinical_trial_sql_query_engine import ClinicalTrialText2SQLEngine from app.rag.retrieval.drug_chembl.drug_chembl_graph_query_engine import DrugChEMBLText2CypherEngine -from app.rag.reranker.response_reranker import ReRankEngine -from app.rag.generation.response_synthesis import ResponseSynthesisEngine -from app.config import config, OPENAI_API_KEY, RERANK_TOP_COUNT, ORCHESRATOR_ROUTER_PROMPT_PROGRAM - +from app.rag.retrieval.web.brave_search import BraveSearchQueryEngine +from app.rag.retrieval.pubmed.pubmedqueryengine import PubmedSearchQueryEngine +from app.rag.reranker.response_reranker import TextEmbeddingInferenceRerankEngine +from app.api.common.util import RouteCategory 
+from app.config import config, OPENAI_API_KEY, TOGETHER_KEY
 from app.services.search_utility import setup_logger
-from app.dspy_integration.router_prompt import Router_module
-
-
-logger = setup_logger('Orchestrator')
-
+logger = setup_logger("Orchestrator")
+TAG_RE = re.compile(r'<[^>]+>')
 
 class Orchestrator:
     """
@@ -23,80 +29,118 @@ class Orchestrator:
 
     def __init__(self, config):
         self.config = config
-
-        self.llm = dspy.OpenAI(model="gpt-3.5-turbo", api_key=str(OPENAI_API_KEY))
-        dspy.settings.configure(lm = self.llm)
-        self.router = Router_module()
-        self.router.load(ORCHESRATOR_ROUTER_PROMPT_PROGRAM)
+        self.choices = [
+            ToolMetadata(
+                description="""useful for retrieving only the clinical trials information like adverse effects, eligibility details
+                of clinical trial participants, sponsor details, death count, conditions of many healthcare problems""",
+                name="clinical_trial_choice",
+            ),
+            ToolMetadata(
+                description="""useful only for retrieving drug-related information like molecular weights,
+                similarities, SMILES codes, target medicines, effects on other medicines""",
+                name="drug_information_choice",
+            ),
+            ToolMetadata(
+                description="""useful for retrieving general information about healthcare data.""",
+                name="pubmed_brave_choice",
+            ),
+        ]
+
+        self.ROUTER_PROMPT = "You are working as a router for a healthcare search engine. Some choices are given below. They are provided in a numbered list (1 to {num_choices}) where each item in the list corresponds to a summary.\n---------------------\n{context_list}\n---------------------\nIf you are not confident, then please use choice 3 as the default choice. Using only the choices above and not prior knowledge, return the choice that is most relevant to the question: '{query_str}'\n"
+
+        self.selector = LLMSingleSelector.from_defaults(
+            llm=OpenAI(model="gpt-3.5-turbo", api_key=str(OPENAI_API_KEY)),
+            prompt_template_str=self.ROUTER_PROMPT,
+        )
+
+        self.clinicalTrialSearch = ClinicalTrialText2SQLEngine(config)
+        self.drugChemblSearch = DrugChEMBLText2CypherEngine(config)
+        self.pubmedsearch = PubmedSearchQueryEngine(config)
+        self.bravesearch = BraveSearchQueryEngine(config)
 
     async def query_and_get_answer(
-        self,
-        search_text: str
-    ) -> str:
-        logger.info(f"query_and_get_answer.router_id search_text: {search_text}")
-        try :
-            router_id = int(self.router(search_text).answer)
-        except Exception as e:
-            logger.exception("query_and_get_answer.router_id Exception -", exc_info = e, stack_info=True)
-        logger.info(f"query_and_get_answer.router_id router_id: {router_id}")
-
-        breaks_sql = False
-
-        if router_id == 0:
-            clinicalTrialSearch = ClinicalTrialText2SQLEngine(config)
+        self,
+        routecategory: RouteCategory = RouteCategory.PBW,
+        search_text: str = "") -> dict:
+        # search router call
+        logger.debug(
+            f"Orchestrator.query_and_get_answer.router_id search_text: {search_text}"
+        )
+
+        # initialize the router with an invalid value
+        router_id = -1
+
+        # route category not specified by the user
+        if routecategory == RouteCategory.NS:
+            selector_result = self.selector.select(self.choices, query=search_text)
+            router_id = selector_result.selections[0].index
+            logger.debug(
+                f"Orchestrator.query_and_get_answer.router_id router_id: {router_id}"
+            )
+        breaks_sql = False
+
+        # routing
+        if router_id == 0 or routecategory == RouteCategory.CT:
+            # clinical trial call
+            logger.debug(
+                "Orchestrator.query_and_get_answer.router_id clinical trial Entered."
+ ) try: - sqlResponse = await clinicalTrialSearch.call_text2sql(search_text=search_text) - result = sqlResponse.get('result', '') - logger.info(f"query_and_get_answer.sqlResponse sqlResponse: {result}") + sqlResponse = self.clinicalTrialSearch.call_text2sql(search_text=search_text) + result = str(sqlResponse) + sources = result + + logger.debug(f"Orchestrator.query_and_get_answer.sqlResponse sqlResponse: {result}") except Exception as e: breaks_sql = True - logger.exception("query_and_get_answer.sqlResponse Exception -", exc_info = e, stack_info=True) + logger.exception("Orchestrator.query_and_get_answer.sqlResponse Exception -", exc_info = e, stack_info=True) + pass - elif router_id == 1: + elif router_id == 1 or routecategory == RouteCategory.DRUG: # drug information call - logger.info("query_and_get_answer.router_id drug_information_choice Entered.") - - drugChemblSearch = DrugChEMBLText2CypherEngine(config) - result = [] - + logger.debug( + "Orchestrator.query_and_get_answer.router_id drug_information_choice Entered." + ) try: - cypherResponse = await drugChemblSearch.call_text2cypher(search_text=search_text) + cypherResponse = self.drugChemblSearch.call_text2cypher( + search_text=search_text + ) result = str(cypherResponse) - - logger.info(f"query_and_get_answer.cypherResponse cypherResponse: {result}") + sources = result + logger.debug( + f"Orchestrator.query_and_get_answer.cypherResponse cypherResponse: {result}" + ) except Exception as e: breaks_sql = True - logger.exception("query_and_get_answer.cypherResponse Exception -", exc_info = e, stack_info=True) - - print() - - if router_id == 2 or breaks_sql: - logger.info("query_and_get_answer.router_id Fallback Entered.") - - bravesearch = BraveSearchQueryEngine(config) - extracted_retrieved_results = await bravesearch.call_brave_search_api(search_text=search_text) - - logger.info(f"query_and_get_answer.extracted_retrieved_results: {extracted_retrieved_results}") - - - #rerank call - rerank = ReRankEngine(config) - rerankResponse = await rerank.call_embedding_api( - search_text=search_text, - retrieval_results=extracted_retrieved_results + logger.exception( + "Orchestrator.query_and_get_answer.cypherResponse Exception -", + exc_info=e, + stack_info=True, + ) + + if router_id == 2 or routecategory == RouteCategory.PBW or routecategory == RouteCategory.NS or breaks_sql: + logger.debug( + "Orchestrator.query_and_get_answer.router_id Fallback Entered." + ) + + extracted_pubmed_results, extracted_web_results = await asyncio.gather( + self.pubmedsearch.call_pubmed_vectors(search_text=search_text), self.bravesearch.call_brave_search_api(search_text=search_text) ) - rerankResponse_sliced = rerankResponse[:RERANK_TOP_COUNT] - logger.info(f"query_and_get_answer.rerankResponse_sliced: {rerankResponse_sliced}") - - #generation call - response_synthesis = ResponseSynthesisEngine(config) - result = await response_synthesis.call_llm_service_api( - search_text=search_text, - reranked_results=rerankResponse_sliced + extracted_results = extracted_pubmed_results + extracted_web_results + logger.debug( + f"Orchestrator.query_and_get_answer.extracted_results count: {len(extracted_pubmed_results), len(extracted_web_results)}" ) - result = result.get('result', '') + "\n\n" + "Source: " + ', '.join(result.get('source', [])) - logger.info(f"query_and_get_answer.response_synthesis: {result}") - - logger.info(f"query_and_get_answer. 
result: {result}") - return result + # rerank call + reranked_results = TextEmbeddingInferenceRerankEngine(top_n=2)._postprocess_nodes( + nodes = extracted_results, + query_bundle=QueryBundle(query_str=search_text)) + + summarizer = SimpleSummarize(llm=TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key=str(TOGETHER_KEY))) + result = summarizer.get_response(query_str=search_text, text_chunks=[TAG_RE.sub('', node.get_content()) for node in reranked_results]) + sources = [node.node.metadata for node in reranked_results ] + + return { + "result" : result, + "sources": sources + } \ No newline at end of file