-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
80b4732
commit ec3de1f
Showing
5 changed files
with
222 additions
and
136 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from enum import Enum | ||
|
||
class RouteCategory(str, Enum):
    """Closed set of route category identifiers.

    The ``str`` mix-in makes every member compare equal to (and serialize as)
    its plain string value, e.g. ``RouteCategory.CT == "clinicaltrials"``.
    """

    CT = "clinicaltrials"
    DRUG = "drug"
    # NOTE(review): the value is spelled "bioxriv"; kept byte-for-byte because
    # callers may match on this exact string -- confirm before renaming.
    PBW = "pubmed_bioxriv_web"
    NS = "not_selected"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,69 +1,88 @@ | ||
import collections | ||
from typing import Any, List, Optional | ||
import requests | ||
import re | ||
|
||
from llama_index.core.bridge.pydantic import Field, PrivateAttr | ||
from llama_index.core.callbacks import CBEventType, EventPayload | ||
from llama_index.core.postprocessor.types import BaseNodePostprocessor | ||
from llama_index.core.schema import NodeWithScore, QueryBundle | ||
|
||
from app.services.search_utility import setup_logger | ||
from app.config import EMBEDDING_RERANK_API, EMBEDDING_CHUNK_SIZE | ||
from app.config import EMBEDDING_RERANK_API, RERANK_TOP_COUNT | ||
|
||
logger = setup_logger('Reranking') | ||
TAG_RE = re.compile(r'<[^>]+>') | ||
|
||
|
||
class ReRankEngine: | ||
logger = setup_logger("TextEmbeddingInferenceRerankEngine") | ||
|
||
class TextEmbeddingInferenceRerankEngine(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes through a remote rerank API.

    For every call, the content of each node is stripped of HTML tags and
    POSTed, together with the query string, to ``EMBEDDING_RERANK_API``.
    The response is expected to be a list of ``{"index": ..., "score": ...}``
    entries; the node list is rebuilt in that order with the returned scores
    and truncated to ``RERANK_TOP_COUNT`` entries.
    """

    model: str = Field(
        default="BAAI/bge-reranker-large",
        description="The model to use when calling AI API",
    )
    # HTTP session reused across rerank calls; created in __init__.
    _session: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        model: str = "BAAI/bge-reranker-large",
    ):
        """Create the postprocessor and open a reusable HTTP session.

        Args:
            top_n: Forwarded to the base initializer.
                NOTE(review): this class declares no ``top_n`` field and
                ``_postprocess_nodes`` truncates to ``RERANK_TOP_COUNT``
                instead -- confirm whether ``top_n`` should be honoured.
            model: Stored on the ``model`` field. It is not sent in the
                rerank request body.
        """
        super().__init__(top_n=top_n, model=model)
        self.model = model
        self._session = requests.Session()

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier of this postprocessor class."""
        return "TextEmbeddingInferenceRerankEngine"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query and return the best-scored ones.

        Steps:
          -- Validate the input.
          -- Extract the text of each node and strip HTML tags.
          -- POST the query and texts to ``EMBEDDING_RERANK_API``.
          -- Rebuild the nodes in the returned order with the returned scores.
          -- Return at most ``RERANK_TOP_COUNT`` nodes.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` is sent to the API.

        Returns:
            The reranked nodes, or an empty list when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
            RuntimeError: If the API answers with anything other than a
                non-empty list of results.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        with self.callback_manager.event(CBEventType.RERANKING) as event:
            logger.info(
                "TextEmbeddingInferenceRerankEngine.postprocess_nodes query: "
                + query_bundle.query_str
            )
            texts = [TAG_RE.sub('', node.get_content()) for node in nodes]
            # NOTE(review): no timeout is set, so a stalled server blocks
            # this call indefinitely -- consider passing ``timeout=``.
            results = self._session.post(  # type: ignore
                EMBEDDING_RERANK_API,
                json={
                    "query": query_bundle.query_str,
                    "truncate": True,
                    "texts": texts,
                },
            ).json()

            # A successful response is a list of ranking entries. Anything
            # else (typically a JSON object describing an error) must not be
            # iterated as if it were results.
            if not isinstance(results, list):
                detail = results
                if isinstance(results, dict):
                    detail = (
                        results.get("detail")
                        or results.get("error")
                        or results
                    )
                raise RuntimeError(f"Rerank API request failed: {detail}")
            if not results:
                raise RuntimeError("Rerank API returned no results.")

            # ``nodes`` already holds NodeWithScore wrappers, so unwrap the
            # underlying node before attaching the new score.
            new_nodes = [
                NodeWithScore(
                    node=nodes[result["index"]].node,
                    score=result["score"],
                )
                for result in results
            ]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes[:RERANK_TOP_COUNT]
Oops, something went wrong.