Skip to content

Commit

Permalink
pushing the ragas experimentation
Browse files Browse the repository at this point in the history
  • Loading branch information
someshfengde committed Apr 4, 2024
1 parent d06c28f commit b2a0e68
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, config):
)
self.obj_retriever = self.obj_index.as_retriever(similarity_top_k=3)
self.qp = self.build_query_pipeline()
self.debug_log = {}
self.ragas_clinical_trial = {}

def _get_table_info_with_index(self, idx: int) -> str:
results_gen = Path(CLINICAL_TRIALS_TABLE_INFO_DIR).glob(f"{idx}_*")
Expand Down Expand Up @@ -148,27 +148,21 @@ def extract_sql(self, llm_response: str) -> str:
return extracted_sql
return llm_response

def debug_ragas(self):
with open("debug_log_ragas.txt", "a") as f:
f.write(str(self.debug_log)+ ",\n")



def get_sql_query(self, question, context):
self.debug_log["question"] = question
self.debug_log["table_context"] = context
self.ragas_clinical_trial["question"] = question
self.ragas_clinical_trial["table_context"] = context
sql_query = self.sql_module(question = question, context = context).answer
return sql_query

def get_synthesized_response(self, question, sql, database_output):
if len(database_output) > 0:
database_output = database_output[0].text
# context provided is sql and database output
self.debug_log["sql"] = sql
self.debug_log["database_output"] = database_output
self.ragas_clinical_trial["sql"] = sql
self.ragas_clinical_trial["database_output"] = database_output
with dspy.context(lm=self.nous):
response = self.response_synthesizer(question = question, sql = sql, database_output = database_output).answer
self.debug_log["answer"] = response
self.ragas_clinical_trial["answer"] = response
return response

def build_query_pipeline(self):
Expand Down Expand Up @@ -203,13 +197,15 @@ def build_query_pipeline(self):

async def call_text2sql(
self,
search_text:str
search_text:str,
ragas_experimentation: bool = False
) -> dict[str, str]:
try:
logger.info(f"call_text2sql search_text: {search_text}")
response = self.qp.run(query=search_text)
logger.info(f"call_text2sql response: {str(response)}")
self.debug_ragas()
if ragas_experimentation:
return {"result": str(self.ragas_clinical_trial)}
except Exception as ex:
logger.exception("call_text2sql Exception -", exc_info = ex, stack_info=True)
raise ex
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self, config):
"Response: "
)
self.qp = self.build_query_pipeline()
self.debug_chembl = {}
self.ragas_chembl = {}

def execute_graph_query(self, queries):
logger.info(
Expand Down Expand Up @@ -111,7 +111,7 @@ def execute_graph_query(self, queries):
result_dict = self.graph_storage.execute_query(query)
results.append(result_dict)

self.debug_chembl['cypher_query'] = str(query_list)
self.ragas_chembl['cypher_query'] = str(query_list)
logger.info(
f"execute_graph_query results: {results}"
)
Expand Down Expand Up @@ -157,14 +157,11 @@ def get_table_context_str(self, table_schema_objs: List[dict[str, str]]) -> str:
table_context += f"{key}: {value}\n"

context_strs.append(table_context)
self.debug_chembl['table_context_str'] = "\n\n".join(context_strs)

return "\n\n".join(context_strs)

# ragas
self.ragas_chembl['table_context_str'] = "\n\n".join(context_strs)

def store_debug(self):
with open("debug_chembl.txt", "a") as f:
f.write(str(self.debug_chembl) + ",\n")
return "\n\n".join(context_strs)

def get_response_synthesis_prompt(
self, query_str, sql_query, context_str
Expand All @@ -188,7 +185,7 @@ def cypher_output_parser(self, response: list[dict[str, list]]) -> str:

response_str += " ## ".join(record_in_list) + "\n"

self.debug_chembl['cypher_response'] = response_str
self.ragas_chembl['cypher_response'] = response_str
logger.info(
f"cypher_output_parser response_str: {response_str}"
)
Expand Down Expand Up @@ -238,19 +235,20 @@ def build_query_pipeline(self):

return qp

async def call_text2cypher(self, search_text:str) -> str:
async def call_text2cypher(
self,
search_text:str,
ragas_experimentation: bool = False) -> str:
try:
logger.info(f"call_text2cypher search_text: {search_text}")
self.debug_chembl['question'] = search_text
response = self.qp.run(query=search_text)

self.debug_chembl['response'] = response

if ragas_experimentation:
self.ragas_chembl['question'] = search_text
self.ragas_chembl['response'] = response
return self.ragas_chembl
logger.info(f"call_text2cypher response: {str(response)}")
self.store_debug()
except Exception as ex:
logger.exception("call_text2cypher Exception -", exc_info = ex, stack_info=True)

raise ex

return response

0 comments on commit b2a0e68

Please sign in to comment.