Skip to content

Commit

Permalink
pubmed integration
Browse files Browse the repository at this point in the history
  • Loading branch information
raahulrahl committed Mar 26, 2024
1 parent 80b4732 commit ec3de1f
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 136 deletions.
17 changes: 16 additions & 1 deletion backend/.vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,36 @@
"AACT",
"arize",
"biopharmaceutical",
"bioxriv",
"bravesearch",
"CHEMBL",
"clinicaltrials",
"Curieo",
"cypher",
"GROQ",
"llms",
"mistralai",
"Mixtral",
"openai",
"perticipents",
"postprocess",
"postprocessor",
"pubmed",
"pubmedqueryengine",
"pubmedsearch",
"pydantic",
"pyvis",
"QDRANT",
"Rerank",
"reranked",
"reranker",
"reranking",
"routecategory",
"sqlalchemy",
"starlette",
"togetherai",
"uvicorn"
"topqueries",
"uvicorn",
"WANDB"
]
}
7 changes: 7 additions & 0 deletions backend/app/api/common/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import Enum

class RouteCategory(str, Enum):
    """Categories a search request can be routed to.

    The class mixes in ``str`` so that every member compares equal to its
    plain string value and can be built directly from that string
    (``RouteCategory("drug")``).
    """

    CT = "clinicaltrials"
    DRUG = "drug"
    # NOTE(review): "bioxriv" looks like a misspelling of "biorxiv"; the value
    # is kept as-is because callers may already depend on this exact string.
    PBW = "pubmed_bioxriv_web"
    NS = "not_selected"
19 changes: 10 additions & 9 deletions backend/app/api/endpoints/search_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from fastapi.responses import JSONResponse
from fastapi_versioning import version
from authx import AuthX, AuthXConfig
from app.database.redis import Redis
import sentry_sdk


from app.api.common.util import RouteCategory
from app.database.redis import Redis
from app.api.router.gzip import GzipRoute
from app.router.orchestrator import Orchestrator
from app.config import config, JWT_SECRET_KEY, JWT_ALGORITHM
Expand All @@ -25,15 +25,16 @@


@router.get(
"/search",
"/search/{routecategory}",
summary="List all Search Results",
description="List all Search Results",
dependencies=[Depends(security.access_token_required)],
response_model=dict[str, str]
response_model=str,
)
@version(1, 0)
async def get_search_results(
query: str = ""
query: str = "",
routecategory: RouteCategory = RouteCategory.PBW
) -> JSONResponse:
if trace_transaction := sentry_sdk.Hub.current.scope.transaction:
trace_transaction.set_tag("title", 'api_get_search_results')
Expand All @@ -47,14 +48,14 @@ async def get_search_results(
if search_result:
logger.info(f"get_search_results. cached_result: {search_result}")
else:
search_result = await orchestrator.query_and_get_answer(search_text=query)
await cache.set_value(query, search_result)
search_result = await orchestrator.query_and_get_answer(routecategory, search_text=query)
await cache.set_value(query, search_result['result'])

await cache.add_to_sorted_set("searched_queries", query)

logger.info(f"get_search_results. result: {search_result}")

return JSONResponse(status_code=200, content={"result": search_result})
return JSONResponse(status_code=200, content=search_result)


@router.get(
Expand All @@ -81,4 +82,4 @@ async def get_top_search_queries(

logger.info(f"get_top_search_queries. result: {last_x_keys}")

return JSONResponse(status_code=200, content=last_x_keys)
return JSONResponse(status_code=200, content=last_x_keys)
123 changes: 71 additions & 52 deletions backend/app/rag/reranker/response_reranker.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,88 @@
import collections
from typing import Any, List, Optional
import requests
import re

from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle

from app.services.search_utility import setup_logger
from app.config import EMBEDDING_RERANK_API, EMBEDDING_CHUNK_SIZE
from app.config import EMBEDDING_RERANK_API, RERANK_TOP_COUNT

logger = setup_logger('Reranking')
TAG_RE = re.compile(r'<[^>]+>')


class ReRankEngine:
logger = setup_logger("TextEmbeddingInferenceRerankEngine")

class TextEmbeddingInferenceRerankEngine(BaseNodePostprocessor):
"""
This class implements the logic re-ranking response and returns the results.
It uses the embedding api that process the query and responses from the retrieval layer.
It returns the output in list format.
Node postprocessor that reranks retrieved nodes through the text embedding inference rerank API.
The class extends BaseNodePostprocessor: it strips HTML tags from each node's text, posts the texts
together with the query to EMBEDDING_RERANK_API, and rebuilds the node list with the scores the API returns.
"""
model: str = Field(
default="BAAI/bge-reranker-large",
description="The model to use when calling AI API",
)
_session: Any = PrivateAttr()

def __init__(self, config):
self.config = config

async def call_embedding_api(
def __init__(
self,
search_text: str,
retrieval_results: collections.defaultdict[list]
) -> collections.defaultdict[list]:
logger.info("call_embedding_api. search_text: " + search_text)
logger.info("call_embedding_api. retrieval_results length: " + str(len(retrieval_results)))

endpoint = "{url_address}".format(
url_address=EMBEDDING_RERANK_API
)

headers = {
'Accept': 'application/json'
}

results = collections.defaultdict(list)

#clean the data
retrieval_results_text_data = [result.get('text') for result in retrieval_results.values()]
retrieval_results_clean_text_data = [payload.replace("\n", " ").replace("\"","") for payload in retrieval_results_text_data]
retrieval_results_clean_html_data = [TAG_RE.sub('', payload) for payload in retrieval_results_clean_text_data]

#chunking the data
payload = [payload[:EMBEDDING_CHUNK_SIZE] for payload in retrieval_results_clean_html_data]
top_n: int = 2,
model: str = "BAAI/bge-reranker-large"
):
super().__init__(top_n=top_n, model=model)
self.model = model
self._session = requests.Session()

request_data = {
"query": search_text,
"texts": payload
}

try:
logger.info("call_embedding_api. endpoint: " + endpoint)
logger.info("call_embedding_api. headers: " + str(headers))
logger.info("call_embedding_api. request_data: " + str(request_data))
@classmethod
def class_name(cls) -> str:
return "TextEmbeddingInferenceRerankEngine"

response = requests.request("POST", endpoint, headers=headers, json=request_data)
response.raise_for_status()
def _postprocess_nodes(
    self,
    nodes: List[NodeWithScore],
    query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
    """Rerank ``nodes`` against the query by calling the rerank API.

    The text of every node is stripped of HTML tags and posted, together
    with the query string, to ``EMBEDDING_RERANK_API``. Each entry of the
    response is expected to carry the ``index`` of a node in ``nodes`` and
    its new ``score``; the nodes are rebuilt in the order the API returned
    them.

    Args:
        nodes: Candidate nodes to rerank.
        query_bundle: Query to rank against. Required despite the ``None``
            default.

    Returns:
        At most ``RERANK_TOP_COUNT`` re-scored nodes, in response order
        (presumably best match first -- confirm against the rerank service).
        An empty list when ``nodes`` is empty.

    Raises:
        ValueError: If ``query_bundle`` is missing.
        RuntimeError: If the API answers with an error status, with a
            payload that is not a list of results, or with an empty list.
    """
    if query_bundle is None:
        raise ValueError("Missing query bundle in extra info.")
    if not nodes:
        return []

    with self.callback_manager.event(CBEventType.RERANKING) as event:
        logger.info(
            "TextEmbeddingInferenceRerankEngine.postprocess_nodes query: "
            + query_bundle.query_str
        )
        texts = [TAG_RE.sub('', node.get_content()) for node in nodes]
        response = self._session.post(  # type: ignore
            EMBEDDING_RERANK_API,
            json={
                "query": query_bundle.query_str,
                "truncate": True,
                "texts": texts,
            },
        )
        # Fail loudly on HTTP errors instead of trying to read an error
        # body as if it were a list of rerank results.
        if not response.ok:
            raise RuntimeError(
                f"Rerank API request failed with status "
                f"{response.status_code}: {response.text}"
            )
        results = response.json()

        # A usable answer is a non-empty list. An error payload is a JSON
        # object, and indexing an empty list with "detail" (as the previous
        # check did) raises TypeError rather than the intended RuntimeError.
        if not isinstance(results, list) or not results:
            if isinstance(results, dict):
                detail = results.get("detail", results)
            else:
                detail = "Rerank API returned no results."
            raise RuntimeError(detail)

        new_nodes = [
            NodeWithScore(node=nodes[result["index"]], score=result["score"])
            for result in results
        ]
        event.on_end(payload={EventPayload.NODES: new_nodes})

    return new_nodes[:RERANK_TOP_COUNT]
Loading

0 comments on commit ec3de1f

Please sign in to comment.