Commit: pubmed_data_coll

someshfengde committed Mar 28, 2024
1 parent 05594e9 commit d06c28f
Showing 14 changed files with 39 additions and 45 deletions.
8 changes: 6 additions & 2 deletions backend/app/api/endpoints/search_endpoint.py
@@ -35,7 +35,8 @@
 @version(1, 0)
 async def get_search_results(
     query: str = "",
-    routecategory: RouteCategory =RouteCategory.PBW# RouteCategory.NS
+    routecategory: RouteCategory = RouteCategory.PBW, # RouteCategory.NS,
+    ragas_experimentation: bool = True
 ) -> JSONResponse:
     if trace_transaction := sentry_sdk.Hub.current.scope.transaction:
         trace_transaction.set_tag("title", 'api_get_search_results')
@@ -51,7 +52,10 @@ async def get_search_results(
         search_result = json.loads(search_result)
         logger.info(f"get_search_results. cached_result: {search_result}")
     else:
-        search_result = await orchestrator.query_and_get_answer(search_text=query, routecategory=routecategory)
+        search_result = await orchestrator.query_and_get_answer(
+            search_text=query,
+            routecategory=routecategory,
+            ragas_experimentation=ragas_experimentation)
         await cache.set_value(cache_key, json.dumps(search_result))

     await cache.add_to_sorted_set("searched_queries", query)
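As a quick sanity check of the new flag, the endpoint can be exercised over HTTP. The sketch below is illustrative only: it assumes the API runs locally on port 8000 (as the evaluation script further down does) and that the query parameters serialize as shown; neither is confirmed by this commit.

    # Hypothetical smoke test for the updated /search endpoint.
    import requests

    resp = requests.get(
        "http://127.0.0.1:8000/search",
        params={
            "query": "What are the latest treatments for glioblastoma?",
            "ragas_experimentation": "false",  # skip the RAGAS debug payload
        },
        headers={"Authorization": "Bearer <token>"},  # auth scheme as in the eval script below
    )
    print(resp.json()["result"])
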
2 changes: 1 addition & 1 deletion backend/app/rag/retrieval/pubmed/pubmedqueryengine.py
@@ -42,7 +42,7 @@ def __init__(self, config):
             sparse_top_k=int(QDRANT_SPARSE_TOP_K),
             vector_store_query_mode=VectorStoreQueryMode.HYBRID,
             embed_model=TextEmbeddingsInference(base_url=EMBEDDING_MODEL_API, model_name="")
-        )z
+        )

     async def call_pubmed_vectors(self, search_text: str) -> List[BaseNode]:
         logger.info(
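The stray `z` removed here was a syntax error. For context, the call being repaired configures a llama-index hybrid retriever over Qdrant. A minimal sketch of that setup follows; the URLs, collection name, and top-k values are assumptions, and import paths vary across llama-index versions.

    # Minimal sketch of a llama-index hybrid (dense + sparse) retriever over Qdrant.
    # URLs, collection name, and top-k values are assumptions for illustration.
    import qdrant_client
    from llama_index import VectorStoreIndex
    from llama_index.embeddings import TextEmbeddingsInference
    from llama_index.vector_stores import QdrantVectorStore
    from llama_index.vector_stores.types import VectorStoreQueryMode

    client = qdrant_client.QdrantClient(url="http://localhost:6333")
    vector_store = QdrantVectorStore(client=client, collection_name="pubmed", enable_hybrid=True)
    index = VectorStoreIndex.from_vector_store(
        vector_store,
        embed_model=TextEmbeddingsInference(base_url="http://localhost:8080", model_name=""),
    )
    retriever = index.as_retriever(
        vector_store_query_mode=VectorStoreQueryMode.HYBRID,  # fuse dense and sparse scores
        sparse_top_k=5,       # candidates from the sparse (keyword) leg
        similarity_top_k=5,   # candidates from the dense (embedding) leg
    )
    nodes = retriever.retrieve("latest treatments for glioblastoma")
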
36 changes: 14 additions & 22 deletions backend/app/router/orchestrator.py
@@ -22,18 +22,15 @@
 logger = setup_logger("Orchestrator")
 TAG_RE = re.compile(r'<[^>]+>')

-
 class Orchestrator:
     """
     Orchestrator is responsible for routing the search engine query.
     It currently routes the query along three routes: the first is clinical trials, the second is
     drug-related information, and the third is pubmed/brave.
     """

     def __init__(self, config):
         self.config = config

-
-
         self.llm = dspy.OpenAI(model="gpt-3.5-turbo", api_key=str(OPENAI_API_KEY))
         dspy.settings.configure(lm = self.llm)
         self.router = Router_module()
@@ -43,18 +40,13 @@ def __init__(self, config):
         self.drugChemblSearch = DrugChEMBLText2CypherEngine(config)
         self.pubmedsearch = PubmedSearchQueryEngine(config)
         self.bravesearch = BraveSearchQueryEngine(config)
-
         self.ragas_pubmed = {}
-
-
-    def store_ragas_pubmed(self):
-        with open("ragas_pubmed_results.txt", "a") as f:
-            f.write(str(self.ragas_pubmed) + "\n")

     async def query_and_get_answer(
         self,
         search_text: str,
-        routecategory: RouteCategory = RouteCategory.NS
+        routecategory: RouteCategory = RouteCategory.NS,
+        ragas_experimentation: bool = False
     ) -> dict[str, str]:
         # search router call
         logger.info(
@@ -120,7 +112,6 @@ async def query_and_get_answer(
                 stack_info=True,
             )

-        self.ragas_pubmed['question'] = search_text
         # if routing fails, sql and cypher calls fail, routing to pubmed or brave
         logger.info(
             "Orchestrator.query_and_get_answer.router_id Fallback Entered."
@@ -131,27 +122,28 @@
         )

         extracted_results = extracted_pubmed_results + extracted_web_results
-        self.ragas_pubmed['context_pubmed_results'] = extracted_pubmed_results
-        self.ragas_pubmed['context_web_results'] = extracted_pubmed_results
         logger.info(
             f"Orchestrator.query_and_get_answer.extracted_results count: {len(extracted_pubmed_results), len(extracted_web_results)}"
         )

         # rerank call
         reranked_results = TextEmbeddingInferenceRerankEngine(top_n=2)._postprocess_nodes(
             nodes=extracted_results,
             query_bundle=QueryBundle(query_str=search_text))

-        self.ragas_pubmed['context_reranked_pubmed_web'] = str([TAG_RE.sub('', node.get_content()) for node in reranked_results])
-        self.ragas_pubmed['context_sources'] = str([node.node.metadata for node in reranked_results])
-
         summarizer = SimpleSummarize(llm=TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key=str(TOGETHER_KEY)))
         result = summarizer.get_response(query_str=search_text, text_chunks=[TAG_RE.sub('', node.get_content()) for node in reranked_results])
         sources = [node.node.metadata for node in reranked_results]

-        self.ragas_pubmed['answer'] = result
-
-        self.store_ragas_pubmed()
+        if ragas_experimentation:
+            self.ragas_pubmed['question'] = search_text
+            self.ragas_pubmed['context_pubmed_results'] = extracted_pubmed_results
+            self.ragas_pubmed['context_web_results'] = extracted_web_results
+            self.ragas_pubmed['context_reranked_pubmed_web'] = str([TAG_RE.sub('', node.get_content()) for node in reranked_results])
+            self.ragas_pubmed['context_sources'] = str([node.node.metadata for node in reranked_results])
+            self.ragas_pubmed['answer'] = result
+            return {
+                "result": self.ragas_pubmed,
+                "sources": {}
+            }
         return {
             "result": result,
             "sources": sources
         }
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -1,43 +1,39 @@
 #%%
 import requests
 from tqdm import tqdm
 import pandas as pd
 import json
 #%%
-df = pd.read_csv("pubmed_eval_data.csv")
-
-#%%
-
 df = pd.read_csv("pubmed_eval_data.csv")
 pubmed_data_debug = pd.DataFrame()

 for i in tqdm(df.iterrows()):
-    url = f"http://127.0.0.1:8000/search?query={i[1]['Question']}"
-    data_store = {
+    url: str = f"http://127.0.0.1:8000/search?query={i[1]['Question']}"
+    data_store: dict = {
         "question": i[1]['Question'],
-        "answer": i[1]['Answer'],
+        "ground_truth": i[1]['Answer'],
         "study": i[1]['Study Title'],
         "link": i[1]['Link'],
         "source": i[1]['Source']
     }
-    payload = {}
+    payload: dict = {}
     headers = {
         'Authorization': 'Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJjdXJpZW8iLCJqdGkiOiI2ZmI2MTEyNS1jZGU1LTQ2MDAtYWE2MS1jMjBiYzEwNmRhNDMiLCJ0eXBlIjoiYWNjZXNzIiwiZnJlc2giOmZhbHNlLCJpYXQiOjE3MTA3NjUzMDYsImV4cCI6MTcxMDc2NjIwNi4yNzg4Mzd9.nmuzrzmr81ulI8TDauwx19QvLHFi8nXJeUgrEVzfhXs'
     }
     tries = 0
     while tries < 1:
         try:
+            tries += 1
             response = requests.request("GET", url, headers=headers, data=payload)
-            json_str = response.text.replace("'", '"').strip('"')
-            data = json.loads(json_str)
-            if len(data['results']) > 0:
-                data_store['refined_results'] = data['result']
-                break
+            new_data_store = {
+                **data_store,
+                **dict(response.json()['result'])
+            }
+            pubmed_data_debug = pd.concat([pubmed_data_debug, pd.DataFrame([new_data_store])], ignore_index=True)
         except:
             tries += 1
             print(' no results for ', i[1]['Question'])
             continue
-    if data_store.get("refined_results", None) is not None:
-        pubmed_data_debug = pd.concat([pubmed_data_debug, pd.DataFrame([data_store])], ignore_index= True )
-    else:
-        data_store['refined_results'] = ""
-        pubmed_data_debug = pd.concat([pubmed_data_debug, pd.DataFrame([data_store])], ignore_index= True)

 pubmed_data_debug.to_csv("pubmed_eval_data_debug.csv", index=False)
@@ -52,3 +48,5 @@
 # pubmed_data = pd.concat([cancer_pubmed, brain_hemorrhage_pubmed, bioinformatics_pubmed, brain_damage_pubmed])
 # # %%
 # pubmed_data.reset_index(drop = True).to_csv("pubmed_eval_data.csv", index = False)
+
+# %%
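Nothing in this commit consumes the collected CSV yet. One plausible next step, an assumption rather than anything shown in the repository, is scoring the records with the ragas library, whose metrics expect question, answer, contexts, and ground-truth columns (exact column names vary slightly between ragas versions):

    # Hypothetical scoring pass over the collected records with the ragas library.
    # Column handling assumes the shapes written by the collection loop above.
    import ast
    import pandas as pd
    from datasets import Dataset
    from ragas import evaluate
    from ragas.metrics import answer_relevancy, faithfulness

    df = pd.read_csv("pubmed_eval_data_debug.csv")
    dataset = Dataset.from_dict({
        "question": df["question"].tolist(),
        "answer": df["answer"].tolist(),
        # context_reranked_pubmed_web was stored upstream as a stringified list
        "contexts": [ast.literal_eval(c) for c in df["context_reranked_pubmed_web"]],
        "ground_truth": df["ground_truth"].tolist(),
    })
    print(evaluate(dataset, metrics=[faithfulness, answer_relevancy]))
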
