Skip to content

Commit

Permalink
pubmed integration docs change
Browse files Browse the repository at this point in the history
  • Loading branch information
raahulrahl committed Mar 26, 2024
1 parent ec3de1f commit 353c4d0
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 45 deletions.
19 changes: 11 additions & 8 deletions backend/.env.template
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,8 @@ POSTGRES_ENGINE=

# Drug ChEMBL Configuration
## need to be set before running the server
NEBULA_GRAPH_HOST=
NEBULA_GRAPH_PORT=
NEBULA_GRAPH_USER=
NEBULA_GRAPH_PASSWORD=
NEBULA_GRAPH_SPACE=

# Redis Configuration
## need to be set before running the server
Expand All @@ -70,14 +67,20 @@ JWT_ALGORITHM=
# WANDB Configuration
## need to be set before running the server
WANDB_API_KEY=
WANDB_PROJECT=
WANDB_ENTITY=
WANDB_NOTE=

# Sentry Configuration
## need to be set before running the server
SENTRY_DSN=

# GROQ API onfiguration
# GROQ API Configuration
## need to be set before running the server
GROQ_API_KEY=
GROQ_API_KEY=

# POSTHOG API Configuration
## need to be set before running the server
POSTHOG_API_KEY=
POSTHOG_HOST=

# QDRANT API Configuration
QDRANT_API_KEY=
QDRANT_TOP_K=
34 changes: 16 additions & 18 deletions backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# Load environment variables from .env file
config = Config(".env")
ENVIRONMENT = config('ENVIRONMENT', default='production')

# Project Settings
DEBUG: bool = config("DEBUG", cast=bool, default=True)
Expand All @@ -15,8 +16,6 @@
)
PROMPT_LANGUAGE: str = config("PROMPT_LANGUAGE", default="en-US")



# SEARCH API Configuration
## The internal location of the Search Backend.
## Used for doing calls to the Search Backend service.
Expand All @@ -37,15 +36,13 @@
"SEARCH_PUBLIC_BASE_URL", default=SEARCH_PUBLIC_BASE_URL
)


# BRAVE SEARCH API Configuration
BRAVE_SEARCH_API: str = config(
"BRAVE_SEARCH_API_ROOT", default="https://api.search.brave.com/res/v1/web/search"
)
BRAVE_SUBSCRIPTION_KEY: Secret = config("BRAVE_SUBSCRIPTION_KEY", cast=Secret)
BRAVE_RESULT_COUNT: int = config("BRAVE_RESULT_COUNT", default=10)


# LLM SERVICE Configuration
LLM_SERVICE_PROVIDER: str = config("LLM_SERVICE_PROVIDER", default="togetherai")

Expand All @@ -68,7 +65,6 @@
# OpenAI API Configuration
OPENAI_API_KEY: Secret = config("OPENAI_API_KEY", cast=Secret)


# Embeddings Model Configuration
EMBEDDING_CHUNK_SIZE: int = config("EMBEDDING_CHUNK_SIZE", default=512)
EMBEDDING_MODEL_API: str = config("EMBEDDING_MODEL_API", default="http://127.0.0.1:8080")
Expand All @@ -89,46 +85,48 @@
## table info dir
DRUG_CHEMBL_TABLE_INFO_DIR: str = "app/rag/retrieval/drug_chembl/ChEMBLTableQuestions_TableInfo"


#Dspy programs
CLINICAL_TRIAL_SQL_PROGRAM: str = "app/dspy_integration/dspy_programs/clinical_trials_sql_generation.json"
CLINICAL_TRIALS_RESPONSE_REFINEMENT_PROGRAM: str = "app/dspy_integration/dspy_programs/clinical_trials_response_refinement.json"
ORCHESRATOR_ROUTER_PROMPT_PROGRAM: str = "app/dspy_integration/dspy_programs/orchestrator_router_prompt.json"

## NEBULA GRAPH Configuration
NEBULA_GRAPH_HOST: str = config("NEBULA_GRAPH_HOST", default="http://127.0.0.1")
NEBULA_GRAPH_PORT: str = config("NEBULA_GRAPH_PORT", default="9669")
NEBULA_GRAPH_USER: Secret = config("NEBULA_GRAPH_USER", cast=Secret)
NEBULA_GRAPH_PASSWORD: Secret = config("NEBULA_GRAPH_PASSWORD", cast=Secret)
NEBULA_GRAPH_SPACE: str = config("NEBULA_GRAPH_SPACE", default="chembl")


# REDIS Configuration
REDIS_URL: Secret = config("REDIS_URL", cast=Secret)
CACHE_MAX_AGE: str = config("SEARCH_CACHE_MAX_AGE", default="86400")
CACHE_MAX_SORTED_SET: int = config("CACHE_MAX_SORTED_SET", default=100)


# JWT Configuration
## JWT_SECRET_KEY key used to validate RS256 signed JWTs.
## Can also be shared secret for HS256 signed JWTs.
JWT_SECRET_KEY: Secret = config("JWT_SECRET_KEY", cast=Secret)

## Algorithm used to sign JWT. Can be RS256, HS256 and None.
JWT_ALGORITHM: str = config("JWT_ALGORITHM", default="HS256")


# WANDB Configuration
WANDB_API_KEY: Secret = config("WANDB_API_KEY", cast=Secret)
WANDB_PROJECT: str = config("WANDB_PROJECT", default="pe_router")
WANDB_ENTITY: str = config("WANDB_ENTITY", default="curieo")
WANDB_NOTE: str = config("WANDB_NOTE", default="Curieo Search")


# Sentry Configuration
SENTRY_DSN: Secret = config("SENTRY_DSN", cast=Secret)
SENTRY_ENABLE_TRACING: bool = config("SENTRY_ENABLE_TRACING", cast=bool, default=False)


# GROQ API Configuration
GROQ_API_KEY: Secret = config("GROQ_API_KEY", cast=Secret)

# QDRANT API configuration
QDRANT_API_KEY: Secret = config("QDRANT_API_KEY", cast=Secret)
QDRANT_API_PORT: int = config("QDRANT_API_PORT", default=6333)
if ENVIRONMENT == 'local':
QDRANT_API_URL = config("QDRANT_API_URL", default="localhost")
else:
QDRANT_API_URL = config("QDRANT_API_URL", default="https://ff1f8e90-959e-4cff-9455-03914d8a7002.europe-west3-0.gcp.cloud.qdrant.io")
QDRANT_COLLECTION_NAME: str = config("QDRANT_COLLECTION_NAME", default="pubmed_hybrid_vector_db")
QDRANT_TOP_K: int = config("QDRANT_TOP_K", default=20)
QDRANT_SPARSE_TOP_K: int = config("QDRANT_SPARSE_TOP_K", default=3)

# LLAMA_INDEX Configuration
CHAT_ENABLED: bool = config("CHAT_ENABLED", default=False)
PUBMED_RELEVANCE_CRITERIA: float = config("PUBMED_RELEVANCE_CRITERIA", default=0.7)
57 changes: 54 additions & 3 deletions backend/app/rag/retrieval/pubmed/pubmedqueryengine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,56 @@
from typing import List

from llama_index.core import VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from llama_index.core.schema import BaseNode
from qdrant_client import QdrantClient

from app.services.search_utility import setup_logger
from app.config import QDRANT_API_KEY, QDRANT_API_URL, QDRANT_API_PORT, QDRANT_COLLECTION_NAME, QDRANT_TOP_K, QDRANT_SPARSE_TOP_K, EMBEDDING_MODEL_API, PUBMED_RELEVANCE_CRITERIA

logger = setup_logger("PubmedSearchQueryEngine")


class PubmedSearchQueryEngine:
""" """
"""
This class implements the logic of call pubmed vector database.
It calls the pubmed vector database and processes the data and returns the result.
"""

def __init__(self, config):
self.config = config

self.client = QdrantClient(
url=QDRANT_API_URL,
port=QDRANT_API_PORT,
api_key=str(QDRANT_API_KEY),
https=False
)

self.vector_store = QdrantVectorStore(
client=self.client,
collection_name=QDRANT_COLLECTION_NAME,
enable_hybrid = True,
batch_size=20)

self.retriever = VectorIndexRetriever(
index=VectorStoreIndex.from_vector_store(vector_store=self.vector_store),
similarity_top_k=int(QDRANT_TOP_K),
sparse_top_k=int(QDRANT_SPARSE_TOP_K),
vector_store_query_mode=VectorStoreQueryMode.HYBRID,
embed_model=TextEmbeddingsInference(base_url=EMBEDDING_MODEL_API, model_name="")
)

async def call_pubmed_vectors(self, search_text: str) -> List[BaseNode]:
logger.info(
"PubmedSearchQueryEngine.call_pubmed_vectors query: " + search_text
)

def query_and_get_answer(self, search_text):
print()
try:
response = [eachNode.node for eachNode in self.retriever.retrieve(search_text) if eachNode.score >= float(PUBMED_RELEVANCE_CRITERIA)]
except Exception as ex:
raise ex
return response
39 changes: 23 additions & 16 deletions backend/app/rag/retrieval/web/brave_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import collections
from typing import List
import requests

from llama_index.core.schema import TextNode

from app.services.search_utility import setup_logger
from app.config import BRAVE_RESULT_COUNT, BRAVE_SEARCH_API, BRAVE_SUBSCRIPTION_KEY

Expand All @@ -9,18 +11,21 @@

class BraveSearchQueryEngine:
"""
This class implements the logic brave search api and returns the results.
It calls the brave api and processes the data and returns the result.
The BraveSearchQueryEngine class is a utility for interacting with the Brave search API within a larger application framework,
likely aimed at providing search capabilities or integrating search results into an application's functionality.
It abstracts the details of making API requests, handling responses, and error logging, providing a simple interface
(call_brave_search_api) for obtaining processed search results in an asynchronous manner.
This class leverages a configuration object for flexibility, allowing it to adapt to different settings or
requirements without changing the core implementation.
"""

def __init__(self, config):
self.config = config

#@storage_cached('brave_search_website', 'search_text')
async def call_brave_search_api(
self,
search_text: str
) -> collections.defaultdict[list]:
) -> List[TextNode]:
logger.info("call_brave_search_api. query: " + search_text)

endpoint = "{url_address}?count={count}&q={search_text}&search_lang=en&extra_snippets=True".format(
Expand All @@ -34,7 +39,7 @@ async def call_brave_search_api(
'Accept-Encoding': 'gzip',
'X-Subscription-Token': str(BRAVE_SUBSCRIPTION_KEY)
}
results = collections.defaultdict(list)
results = []

try:
logger.info("call_brave_search_api. endpoint: " + endpoint)
Expand All @@ -43,20 +48,22 @@ async def call_brave_search_api(
response = requests.get(endpoint, headers=headers)
response.raise_for_status()
web_response = response.json().get('web').get('results')
i = 0

if web_response:
for resp in web_response:
detailed_text = resp.get('description') + ''.join(resp.get('extra_snippets') if resp.get('extra_snippets') else '')
results[i] = {
"text": detailed_text,
"url": resp['url'],
"page_age": resp.get('page_age')
}
i = i + 1
results = [
TextNode(
text=resp.get('description') + ''.join(resp.get('extra_snippets') if resp.get('extra_snippets') else ''),
metadata={
"url": resp['url'],
"page_age": resp.get('page_age')
}
)
for resp in web_response
]

except Exception as ex:
logger.exception("call_brave_search_api Exception -", exc_info=ex, stack_info=True)
raise ex

logger.info("call_brave_search_api. result: " + str(results))
return results
return results
2 changes: 2 additions & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ nebula3-python = "^3.5.0"
sentry-sdk = "^1.43.0"
dspy = { git = "https://[email protected]/curieo-org/dspy.git" }
setuptools = "^69.2.0"
llama-index-llms-together = "^0.1.3"
llama-index-postprocessor-cohere-rerank = "^0.1.2"

[tool.poetry.group.dev.dependencies]
pytest = "^8.1.1"
Expand Down

0 comments on commit 353c4d0

Please sign in to comment.