gene retrieval via API #7

Open. Wants to merge 49 commits into base: main.

Commits (49):
e7476df
initial commit
Vincent-Ustach Jan 16, 2025
a4eb270
store false for inference_bool
Vincent-Ustach Jan 16, 2025
3749b33
require task_desc_infile
Vincent-Ustach Jan 16, 2025
a719456
task_desc_infile pathlib path
Vincent-Ustach Jan 16, 2025
a149c13
inference_bool default True
Vincent-Ustach Jan 16, 2025
0dc40c3
inference_bool typo
Vincent-Ustach Jan 16, 2025
6d30f24
import DataArgs
Vincent-Ustach Jan 16, 2025
9df5b98
set model device data_args to None
Vincent-Ustach Jan 16, 2025
7a7ed70
Merge pull request #2 from mims-harvard/main
Vincent-Ustach Jan 17, 2025
dda1bca
read disease desc
Vincent-Ustach Jan 17, 2025
43a3218
Merge branch 'refs/heads/main' into retrieval_script
Vincent-Ustach Jan 17, 2025
b18d3d6
reduce precision of model before loading to device
Vincent-Ustach Jan 17, 2025
c0e4881
copy retrieval script
Vincent-Ustach Jan 17, 2025
a589308
rename files
Vincent-Ustach Jan 17, 2025
824b3c2
if no model load no create_input_retrieval
Vincent-Ustach Jan 17, 2025
1205024
remove unused imports
Vincent-Ustach Jan 17, 2025
54943dd
if no model no create_input_retrieval
Vincent-Ustach Jan 17, 2025
3a92807
remove unused imports
Vincent-Ustach Jan 17, 2025
694b729
black formatting
Vincent-Ustach Jan 17, 2025
48f8fb9
refactor with startup and do methods for later api
Vincent-Ustach Jan 17, 2025
4844320
docstrings
Vincent-Ustach Jan 17, 2025
6a7b4f1
bug in calling startup_retrieval
Vincent-Ustach Jan 17, 2025
4f8368e
fastapi app
Vincent-Ustach Jan 17, 2025
4e5bb42
add more required env vars
Vincent-Ustach Jan 17, 2025
643cc6f
move utils to inference/retrieval_utils.py
Vincent-Ustach Jan 17, 2025
f4299b1
fix imports for retrieval_utils
Vincent-Ustach Jan 17, 2025
397d631
typo
Vincent-Ustach Jan 17, 2025
01c887f
update comment
Vincent-Ustach Jan 17, 2025
1b3b262
instruction_source_dataset passed as argument
Vincent-Ustach Jan 18, 2025
f4fa800
update docstring for app
Vincent-Ustach Jan 18, 2025
7c6619d
bug w repeated model.to_device() commands
Vincent-Ustach Jan 19, 2025
58ff2a6
return top k
Vincent-Ustach Jan 19, 2025
f30352f
by default return all records
Vincent-Ustach Jan 19, 2025
7a9edd7
fillna
Vincent-Ustach Jan 19, 2025
491d950
remove disease from input description. remove unused imports in app.
Vincent-Ustach Jan 22, 2025
6067989
remove disease from input description. remove unused imports in app.
Vincent-Ustach Jan 22, 2025
3c60c9a
delete drug script
Vincent-Ustach Jan 22, 2025
00de80f
update docstrings
Vincent-Ustach Jan 22, 2025
6947b3b
remove exception catch
Vincent-Ustach Jan 22, 2025
6bbd355
move args in script
Vincent-Ustach Jan 22, 2025
f9e0b39
move args in call to do_retrieval
Vincent-Ustach Jan 22, 2025
187de6f
move args in app call to do_retrieval
Vincent-Ustach Jan 22, 2025
00c7929
try again w args in script
Vincent-Ustach Jan 22, 2025
dcdafc3
default of omim
Vincent-Ustach Jan 22, 2025
3fcf39a
declare all_protein_embeddings as global in retrieve_proteins
Vincent-Ustach Jan 22, 2025
a7aeb6b
remove exceptions from highest level of app
Vincent-Ustach Jan 22, 2025
7aa2698
duplicate loguru removed from pyproject.toml
Vincent-Ustach Jan 22, 2025
cdb23fc
black formatting
Vincent-Ustach Jan 22, 2025
7e3631e
Merge pull request #1 from GeneDx/retrieval_script
Vincent-Ustach Jan 27, 2025
100 changes: 100 additions & 0 deletions procyon/app/main.py
@@ -0,0 +1,100 @@
import os
from typing import Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import pandas as pd
from loguru import logger

# Import the key functions from the existing codebase
from procyon.inference.retrieval_utils import startup_retrieval, do_retrieval

app = FastAPI()

# Global variables holding the model and related retrieval state
model = None
device = None
data_args = None
all_protein_embeddings = None


class RetrievalRequest(BaseModel):
task_desc: str = Field(description="The task description.")
disease_desc: str = Field(description="The disease description.")
instruction_source_dataset: str = Field(
description="Dataset source for instructions - either 'disgenet' or 'omim'"
)
k: Optional[int] = Field(
default=None,
description="Number of top results to return. If None, returns all results",
ge=1,
)


@app.on_event("startup")
async def startup_event():
"""Initialize the model and required components on startup"""
global model, device, data_args, all_protein_embeddings

    # Fail fast if any required environment variable is missing
    for var in ("HF_TOKEN", "CHECKPOINT_PATH", "HOME_DIR", "DATA_DIR", "LLAMA3_PATH"):
        if not os.getenv(var):
            raise EnvironmentError(f"{var} environment variable not set")

# Use the existing startup_retrieval function
model, device, data_args, all_protein_embeddings = startup_retrieval(
inference_bool=True
)
logger.info("Model loaded and ready")


@app.post("/retrieve")
async def retrieve_proteins(request: RetrievalRequest):
"""Endpoint to perform protein retrieval"""
global model, device, data_args, all_protein_embeddings

    # Check each component against None explicitly: putting a multi-element
    # tensor inside all([...]) would raise "Boolean value of Tensor is ambiguous"
    if any(x is None for x in (model, device, data_args, all_protein_embeddings)):
        raise HTTPException(status_code=500, detail="Model not initialized")

# Use the existing do_retrieval function
results_df = do_retrieval(
model=model,
data_args=data_args,
device=device,
instruction_source_dataset=request.instruction_source_dataset,
all_protein_embeddings=all_protein_embeddings,
task_desc=request.task_desc,
disease_desc=request.disease_desc,
)

results_df = results_df.fillna("")

# Return all results if k is None, otherwise return top k
if request.k is None:
return {"results": results_df.to_dict(orient="records")}
return {"results": results_df.head(request.k).to_dict(orient="records")}


if __name__ == "__main__":
"""
This API endpoint will allow users to perform protein retrieval for a given disease description using the
pre-trained ProCyon model Procyon-Full.
Collaborator (review comment), suggested change:
-    pre-trained ProCyon model Procyon-Full.
+    pre-trained ProCyon model ProCyon-Full.

    This script can be run directly with `python main.py`,
    which starts the FastAPI server on port 8000.
    The API will then be available at http://localhost:8000.
An example request can be made using curl:
curl -X POST "http://localhost:8000/retrieve" \
-H "Content-Type: application/json" \
-d '{"task_desc": "Find proteins related to this disease",
"disease_desc": "Major depressive disorder",
"instruction_source_dataset": "disgenet",
"k": 1000}'
"""
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
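For completeness, a minimal Python client sketch for the same request (the endpoint,
field names, and response shape come from the app above; the `requests` library is an
assumed extra dependency):

import requests

resp = requests.post(
    "http://localhost:8000/retrieve",
    json={
        "task_desc": "Find proteins related to this disease",
        "disease_desc": "Major depressive disorder",
        "instruction_source_dataset": "disgenet",
        "k": 10,
    },
)
resp.raise_for_status()
for record in resp.json()["results"]:  # the endpoint returns {"results": [...]}
    print(record)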
199 changes: 199 additions & 0 deletions procyon/inference/retrieval_utils.py
@@ -0,0 +1,199 @@
import os
from pathlib import Path

from huggingface_hub import login as hf_login
from loguru import logger
import pandas as pd
from typing import Optional, Tuple, Union
import torch

from procyon.data.inference_utils import (
create_input_retrieval,
get_proteins_from_embedding,
)
from procyon.evaluate.framework.utils import move_inputs_to_device
from procyon.model.model_unified import UnifiedProCyon
from procyon.training.train_utils import DataArgs

# NOTE: CHECKPOINT_PATH must be set at import time; os.path.expanduser(None) raises a TypeError
CKPT_NAME = os.path.expanduser(os.getenv("CHECKPOINT_PATH"))


def startup_retrieval(
inference_bool: bool = True,
) -> Tuple[
Union[UnifiedProCyon, None],
Collaborator (review comment): I'd been using e.g. Optional[UnifiedProCyon], so I
googled what's preferred, and it appears the current best practice for a value that
could be None is to express it as e.g. UnifiedProCyon | None.
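For reference, a quick sketch of the equivalent spellings (assumes this file's imports;
the bare `|` syntax needs Python 3.10+, while the quoted form works on older versions):

from typing import Optional, Union

a: Union[UnifiedProCyon, None]   # classic Union form, as used in this PR
b: Optional[UnifiedProCyon]      # equivalent shorthand
c: "UnifiedProCyon | None"       # PEP 604 spelling, preferred on Python >= 3.10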

Union[torch.device, None],
Union[DataArgs, None],
torch.Tensor,
]:
"""
This function performs startup functions to initiate protein retrieval:
Logs into the huggingface hub and loads the pre-trained ProCyon model.
Args:
        inference_bool (bool): OPTIONAL; set to False if you do not intend to run
            inference, in which case the model will not be loaded.
Returns:
model (UnifiedProCyon): The pre-trained ProCyon model
device (torch.device): The compute device (GPU or CPU) on which the model is loaded
data_args (DataArgs): The data arguments defined by the pre-trained model
all_protein_embeddings (torch.Tensor): The pre-calculated protein target embeddings
"""
logger.info("Now running startup functions for protein retrieval")

logger.info("Now logging into huggingface hub")
hf_login(token=os.getenv("HF_TOKEN"))
logger.info("Done logging into huggingface hub")

if inference_bool:
logger.info("Inference is enabled.")

# load the pre-trained ProCyon model
model, device, data_args = load_model_onto_device()
else:
logger.info("Inference is disabled.")
        # loading the model takes considerable time and memory, so we skip it when it is not needed
model = None
device = None
data_args = None

# Load the pre-calculated protein target embeddings
logger.info("Now loading protein target embeddings")
all_protein_embeddings, all_protein_ids = torch.load(
os.path.join(CKPT_NAME, "protein_target_embeddings.pkl")
)
all_protein_embeddings = all_protein_embeddings.float()
logger.info(
f"shape of precalculated embeddings matrix: {all_protein_embeddings.shape}"
)
logger.info("Done loading protein target embeddings")
logger.info("Done running startup functions for protein retrieval")

return model, device, data_args, all_protein_embeddings
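A minimal sketch of driving these two functions directly from Python, without the
FastAPI app (the signatures are the ones defined in this file; the example descriptions
are taken from the curl example in main.py):

from procyon.inference.retrieval_utils import startup_retrieval, do_retrieval

model, device, data_args, all_protein_embeddings = startup_retrieval(inference_bool=True)
results = do_retrieval(
    model=model,
    data_args=data_args,
    device=device,
    instruction_source_dataset="omim",
    all_protein_embeddings=all_protein_embeddings,
    task_desc="Find proteins related to this disease",
    disease_desc="Major depressive disorder",
)
print(results.head(10))  # top 10 retrieved proteins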


def load_model_onto_device() -> Tuple[UnifiedProCyon, torch.device, DataArgs]:
"""
Load the pre-trained ProCyon model and move it to the compute device.
Returns:
model (UnifiedProCyon): The pre-trained ProCyon model
device (torch.device): The compute device (GPU or CPU) on which the model is loaded
data_args (DataArgs): The data arguments defined by the pre-trained model
"""
# Load the pre-trained ProCyon model
logger.info("Now loading pretrained model")
    # CKPT_NAME (from the CHECKPOINT_PATH env var) should point to the directory of a downloaded pre-trained ProCyon model (e.g. ProCyon-Full)
data_args = torch.load(os.path.join(CKPT_NAME, "data_args.pt"))
model, _ = UnifiedProCyon.from_pretrained(checkpoint_dir=CKPT_NAME)
logger.info("Done loading pretrained model")

logger.info("Now quantizing the model to a smaller precision")
model.bfloat16() # Quantize the model to a smaller precision
logger.info("Done quantizing the model to a smaller precision")
Collaborator (review comment): Possibly personal opinion, but the logging seems quite
verbose. It might be nice to log some of these messages at the debug level instead and
add a command-line option to turn them on. I haven't used loguru before, but it looks
like some combination of logger.add or logger.remove should do it.
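A minimal sketch of that suggestion (logger.remove and logger.add are standard loguru
calls; the verbose flag itself would be a new, hypothetical command-line option):

import sys
from loguru import logger

def configure_logging(verbose: bool = False) -> None:
    # Drop loguru's default sink, then re-add stderr at the chosen level, so
    # messages demoted to logger.debug(...) appear only when verbose is requested.
    logger.remove()
    logger.add(sys.stderr, level="DEBUG" if verbose else "INFO")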


logger.info("Now setting the model to evaluation mode")
model.eval()
logger.info("Done setting the model to evaluation mode")

logger.info("Now applying pretrained model to device")
logger.info(f"Total memory allocated by PyTorch: {torch.cuda.memory_allocated()}")
# identify available devices on the machine
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
logger.info(f"Total memory allocated by PyTorch: {torch.cuda.memory_allocated()}")

logger.info("Done loading model and applying it to compute device")

return model, device, data_args


def do_retrieval(
model: UnifiedProCyon,
data_args: DataArgs,
device: torch.device,
instruction_source_dataset: str,
all_protein_embeddings: torch.Tensor,
inference_bool: bool = True,
    task_desc_infile: Optional[Path] = None,
    disease_desc_infile: Optional[Path] = None,
    task_desc: Optional[str] = None,
    disease_desc: Optional[str] = None,
) -> Optional[pd.DataFrame]:
"""
This function performs protein retrieval for a given disease using the pre-trained ProCyon model.
Args:
model (UnifiedProCyon): The pre-trained ProCyon model
data_args (DataArgs): The data arguments defined by the pre-trained model
device (torch.device): The compute device (GPU or CPU) on which the model is loaded
instruction_source_dataset (str): Dataset source for instructions - either "disgenet" or "omim"
all_protein_embeddings (torch.Tensor): The pre-calculated protein target embeddings
        inference_bool (bool): OPTIONAL; set to False if you do not intend to run inference
task_desc_infile (Path): The path to the file containing the task description.
disease_desc_infile (Path): The path to the file containing the disease description.
task_desc (str): The task description.
disease_desc (str): The disease description.
Returns:
        df_dep (Optional[pd.DataFrame]): DataFrame of the top protein retrieval results,
            or None when inference_bool is False
"""
logger.info("Now performing protein retrieval")

if instruction_source_dataset not in ["disgenet", "omim"]:
raise ValueError(
'instruction_source_dataset must be either "disgenet" or "omim"'
)

logger.info("entering task description and prompt")
if task_desc_infile is not None:
if task_desc is not None:
raise ValueError(
"Only one of task_desc_infile and task_desc can be provided."
)
# read the task description from a file
with open(task_desc_infile, "r") as f:
task_desc = f.read()
elif task_desc is None:
raise ValueError("Either task_desc_infile or task_desc must be provided.")

if disease_desc_infile is not None:
if disease_desc is not None:
raise ValueError(
"Only one of disease_desc_infile and disease_desc can be provided."
)
# read the disease description from a file
with open(disease_desc_infile, "r") as f:
disease_desc = f.read()
elif disease_desc is None:
raise ValueError("Either disease_desc_infile or disease_desc must be provided.")

task_desc = task_desc.replace("\n", " ")
disease_desc = disease_desc.replace("\n", " ")
logger.info("Done entering task description and prompt")

    df_dep = None  # remains None (and is returned as such) when inference is disabled
    if inference_bool:
logger.info("Now performing protein retrieval")

# Create input for retrieval
input_simple = create_input_retrieval(
input_description=disease_desc,
data_args=data_args,
task_definition=task_desc,
instruction_source_dataset=instruction_source_dataset,
instruction_source_relation="all",
aaseq_type="protein",
            icl_example_number=1,  # number of in-context examples to include (0, 1, or 2)
)

input_simple = move_inputs_to_device(input_simple, device=device)
with torch.no_grad():
model_out = model(
inputs=input_simple,
retrieval=True,
aaseq_type="protein",
)
# The script can run up to here without a GPU, but the following line requires a GPU
Collaborator (review comment): Let me know if running this on the GPU is a headache.
If so, we can change get_proteins_from_embedding to operate on a specified device;
I don't think it has to run on the GPU.

df_dep = get_proteins_from_embedding(
all_protein_embeddings, model_out, top_k=None
)

logger.info("Done performing protein retrieval")

return df_dep
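On the GPU question in the review comment above: a minimal, hypothetical sketch of the
kind of embedding-similarity scoring a retrieval step like get_proteins_from_embedding
presumably performs (the actual ProCyon implementation may differ); nothing in this
pattern inherently requires a GPU:

import torch
import torch.nn.functional as F

def top_k_by_similarity(query_emb: torch.Tensor,
                        protein_embs: torch.Tensor,
                        k: int = 10) -> torch.Tensor:
    # Cosine similarity between one query embedding and every protein embedding;
    # runs on whichever device the inputs live on, CPU included.
    query = F.normalize(query_emb.reshape(1, -1), dim=-1)   # (1, d)
    bank = F.normalize(protein_embs, dim=-1)                # (N, d)
    scores = bank @ query.squeeze(0)                        # (N,)
    return torch.topk(scores, k).indices                    # top-k protein indices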
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
"jupyter==1.0.0",
"kiwisolver==1.4.5",
"llvmlite==0.42.0",
"loguru==0.7.3",
"matplotlib==3.8.3",
"multiprocess==0.70.16",
"networkx==3.2.1",
@@ -50,6 +51,10 @@ dependencies = [
"virtualenv==20.28.0",
"wandb==0.16.3",
"bert-score>=0.3.13",
"argparse>=1.4.0",
"huggingface-hub==0.23.4",
"fastapi>=0.109.0",
"uvicorn>=0.27.0",
]

[project.optional-dependencies]
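With fastapi and uvicorn added here, the server can also be started via the uvicorn CLI
instead of `python main.py` (assuming the package is installed so that
`procyon.app.main` is importable):

uvicorn procyon.app.main:app --host 0.0.0.0 --port 8000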