gene retrieval via API #7
base: main
main.py (new file, +100 lines)
import os
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import pandas as pd
from loguru import logger

# Import the key functions from the existing codebase
from procyon.inference.retrieval_utils import startup_retrieval, do_retrieval

app = FastAPI()

# Global variables holding the model and its supporting components,
# populated once at startup
model = None
device = None
data_args = None
all_protein_embeddings = None


class RetrievalRequest(BaseModel):
    task_desc: str = Field(description="The task description.")
    disease_desc: str = Field(description="The disease description.")
    instruction_source_dataset: str = Field(
        description="Dataset source for instructions - either 'disgenet' or 'omim'"
    )
    k: Optional[int] = Field(
        default=None,
        description="Number of top results to return. If None, returns all results",
        ge=1,
    )
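
# A sketch of a payload that validates against RetrievalRequest (illustrative
# values only, mirroring the curl example in the __main__ docstring below):
#
#   RetrievalRequest(
#       task_desc="Find proteins related to this disease",
#       disease_desc="Major depressive disorder",
#       instruction_source_dataset="disgenet",
#       k=10,
#   )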


@app.on_event("startup")
async def startup_event():
    """Initialize the model and required components on startup"""
    global model, device, data_args, all_protein_embeddings

    # Fail fast if any required configuration is missing
    if not os.getenv("HF_TOKEN"):
        raise EnvironmentError("HF_TOKEN environment variable not set")
    if not os.getenv("CHECKPOINT_PATH"):
        raise EnvironmentError("CHECKPOINT_PATH environment variable not set")
    if not os.getenv("HOME_DIR"):
        raise EnvironmentError("HOME_DIR environment variable not set")
    if not os.getenv("DATA_DIR"):
        raise EnvironmentError("DATA_DIR environment variable not set")
    if not os.getenv("LLAMA3_PATH"):
        raise EnvironmentError("LLAMA3_PATH environment variable not set")
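
    # Example environment setup before launching the server (illustrative
    # values only; adjust paths to your installation):
    #   export HF_TOKEN=hf_...
    #   export CHECKPOINT_PATH=/path/to/ProCyon-Full
    #   export HOME_DIR=/path/to/ProCyon
    #   export DATA_DIR=/path/to/ProCyon_data
    #   export LLAMA3_PATH=/path/to/llama-3-8b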
|
||
# Use the existing startup_retrieval function | ||
model, device, data_args, all_protein_embeddings = startup_retrieval( | ||
inference_bool=True | ||
) | ||
logger.info("Model loaded and ready") | ||
|
||
|
||
@app.post("/retrieve") | ||
async def retrieve_proteins(request: RetrievalRequest): | ||
"""Endpoint to perform protein retrieval""" | ||
global model, device, data_args, all_protein_embeddings | ||
|
||
if not all([model, device, data_args, all_protein_embeddings]): | ||
raise HTTPException(status_code=500, detail="Model not initialized") | ||
|
||
# Use the existing do_retrieval function | ||
results_df = do_retrieval( | ||
model=model, | ||
data_args=data_args, | ||
device=device, | ||
instruction_source_dataset=request.instruction_source_dataset, | ||
all_protein_embeddings=all_protein_embeddings, | ||
task_desc=request.task_desc, | ||
disease_desc=request.disease_desc, | ||
) | ||
|
||
results_df = results_df.fillna("") | ||
|
||
# Return all results if k is None, otherwise return top k | ||
if request.k is None: | ||
return {"results": results_df.to_dict(orient="records")} | ||
return {"results": results_df.head(request.k).to_dict(orient="records")} | ||
|
||
|
||
if __name__ == "__main__": | ||
""" | ||
This API endpoint will allow users to perform protein retrieval for a given disease description using the | ||
pre-trained ProCyon model Procyon-Full. | ||
This API script can be run directly using the command `python main.py` | ||
this script will start the FastAPI server on port 8000 | ||
The API will be available at http://localhost:8000 | ||
An example request can be made using curl: | ||
curl -X POST "http://localhost:8000/retrieve" \ | ||
-H "Content-Type: application/json" \ | ||
-d '{"task_desc": "Find proteins related to this disease", | ||
"disease_desc": "Major depressive disorder", | ||
"instruction_source_dataset": "disgenet", | ||
"k": 1000}' | ||
""" | ||
import uvicorn | ||
|
||
uvicorn.run(app, host="0.0.0.0", port=8000) |
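
For reference, the same request issued from Python with the requests library (a minimal sketch; assumes the server above is running on localhost:8000):

import requests

resp = requests.post(
    "http://localhost:8000/retrieve",
    json={
        "task_desc": "Find proteins related to this disease",
        "disease_desc": "Major depressive disorder",
        "instruction_source_dataset": "disgenet",
        "k": 10,
    },
    timeout=600,  # retrieval runs a large model, so allow a generous timeout
)
resp.raise_for_status()
print(resp.json()["results"][:3])  # first few retrieved proteins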
procyon/inference/retrieval_utils.py (new file, +199 lines)
import os
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

from huggingface_hub import login as hf_login
from loguru import logger
import pandas as pd
import torch

from procyon.data.inference_utils import (
    create_input_retrieval,
    get_proteins_from_embedding,
)
from procyon.evaluate.framework.utils import move_inputs_to_device
from procyon.model.model_unified import UnifiedProCyon
from procyon.training.train_utils import DataArgs

# Path to the downloaded pre-trained ProCyon checkpoint, taken from CHECKPOINT_PATH
CKPT_NAME = os.path.expanduser(os.getenv("CHECKPOINT_PATH"))


def startup_retrieval(
    inference_bool: bool = True,
) -> Tuple[
    Union[UnifiedProCyon, None],
    Union[torch.device, None],
    Union[DataArgs, None],
    torch.Tensor,
]:
""" | ||
This function performs startup functions to initiate protein retrieval: | ||
Logs into the huggingface hub and loads the pre-trained ProCyon model. | ||
Args: | ||
inference_bool (bool): OPTIONAL; choose this if you do not intend to do inference; | ||
then the model will not be loaded. | ||
Returns: | ||
model (UnifiedProCyon): The pre-trained ProCyon model | ||
device (torch.device): The compute device (GPU or CPU) on which the model is loaded | ||
data_args (DataArgs): The data arguments defined by the pre-trained model | ||
all_protein_embeddings (torch.Tensor): The pre-calculated protein target embeddings | ||
""" | ||
logger.info("Now running startup functions for protein retrieval") | ||
|
||
logger.info("Now logging into huggingface hub") | ||
hf_login(token=os.getenv("HF_TOKEN")) | ||
logger.info("Done logging into huggingface hub") | ||
|
||
if inference_bool: | ||
logger.info("Inference is enabled.") | ||
|
||
# load the pre-trained ProCyon model | ||
model, device, data_args = load_model_onto_device() | ||
else: | ||
logger.info("Inference is disabled.") | ||
# loading the model takes much time and memory, so we skip it if we don't need it | ||
model = None | ||
device = None | ||
data_args = None | ||
|
||
# Load the pre-calculated protein target embeddings | ||
logger.info("Now loading protein target embeddings") | ||
all_protein_embeddings, all_protein_ids = torch.load( | ||
os.path.join(CKPT_NAME, "protein_target_embeddings.pkl") | ||
) | ||
all_protein_embeddings = all_protein_embeddings.float() | ||
logger.info( | ||
f"shape of precalculated embeddings matrix: {all_protein_embeddings.shape}" | ||
) | ||
logger.info("Done loading protein target embeddings") | ||
logger.info("Done running startup functions for protein retrieval") | ||
|
||
return model, device, data_args, all_protein_embeddings | ||
|
||
|
||
def load_model_onto_device() -> Tuple[UnifiedProCyon, torch.device, DataArgs]: | ||
""" | ||
Load the pre-trained ProCyon model and move it to the compute device. | ||
Returns: | ||
model (UnifiedProCyon): The pre-trained ProCyon model | ||
device (torch.device): The compute device (GPU or CPU) on which the model is loaded | ||
data_args (DataArgs): The data arguments defined by the pre-trained model | ||
""" | ||
# Load the pre-trained ProCyon model | ||
logger.info("Now loading pretrained model") | ||
# Replace with the path where you downloaded a pre-trained ProCyon model (e.g. ProCyon-Full) | ||
data_args = torch.load(os.path.join(CKPT_NAME, "data_args.pt")) | ||
model, _ = UnifiedProCyon.from_pretrained(checkpoint_dir=CKPT_NAME) | ||
logger.info("Done loading pretrained model") | ||
|
||
logger.info("Now quantizing the model to a smaller precision") | ||
model.bfloat16() # Quantize the model to a smaller precision | ||
logger.info("Done quantizing the model to a smaller precision") | ||
Review comment: Possibly personal opinion, but the logging seems quite verbose. It might be nice to log some of these messages at the debug level instead and add a command-line option to turn them on.
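A minimal sketch of that suggestion, using loguru's standard sink API; the LOG_LEVEL variable name is illustrative and not part of this PR:

import os
import sys
from loguru import logger

# Replace loguru's default sink with one whose level comes from the
# environment, so chatty progress messages can be demoted to DEBUG
# and re-enabled on demand (e.g. LOG_LEVEL=DEBUG python main.py).
logger.remove()
logger.add(sys.stderr, level=os.getenv("LOG_LEVEL", "INFO"))

logger.debug("Now casting the model to bfloat16 precision")  # hidden at INFO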

    logger.info("Now setting the model to evaluation mode")
    model.eval()
    logger.info("Done setting the model to evaluation mode")

    logger.info("Now applying pretrained model to device")
    logger.info(f"Total memory allocated by PyTorch: {torch.cuda.memory_allocated()}")
    # identify available devices on the machine
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    logger.info(f"Total memory allocated by PyTorch: {torch.cuda.memory_allocated()}")

    logger.info("Done loading model and applying it to compute device")

    return model, device, data_args


def do_retrieval(
    model: UnifiedProCyon,
    data_args: DataArgs,
    device: torch.device,
    instruction_source_dataset: str,
    all_protein_embeddings: torch.Tensor,
    inference_bool: bool = True,
    task_desc_infile: Optional[Path] = None,
    disease_desc_infile: Optional[Path] = None,
    task_desc: Optional[str] = None,
    disease_desc: Optional[str] = None,
) -> Optional[pd.DataFrame]:
    """
    This function performs protein retrieval for a given disease using the pre-trained ProCyon model.
    Args:
        model (UnifiedProCyon): The pre-trained ProCyon model
        data_args (DataArgs): The data arguments defined by the pre-trained model
        device (torch.device): The compute device (GPU or CPU) on which the model is loaded
        instruction_source_dataset (str): Dataset source for instructions - either "disgenet" or "omim"
        all_protein_embeddings (torch.Tensor): The pre-calculated protein target embeddings
        inference_bool (bool): optional; set to False if you do not intend to run inference
        task_desc_infile (Path): The path to the file containing the task description.
        disease_desc_infile (Path): The path to the file containing the disease description.
        task_desc (str): The task description.
        disease_desc (str): The disease description.
    Returns:
        df_dep (pd.DataFrame): The DataFrame containing the top protein retrieval results,
            or None when inference_bool is False.
    """
    logger.info("Now performing protein retrieval")

    if instruction_source_dataset not in ["disgenet", "omim"]:
        raise ValueError(
            'instruction_source_dataset must be either "disgenet" or "omim"'
        )

    logger.info("entering task description and prompt")
    if task_desc_infile is not None:
        if task_desc is not None:
            raise ValueError(
                "Only one of task_desc_infile and task_desc can be provided."
            )
        # read the task description from a file
        with open(task_desc_infile, "r") as f:
            task_desc = f.read()
    elif task_desc is None:
        raise ValueError("Either task_desc_infile or task_desc must be provided.")

    if disease_desc_infile is not None:
        if disease_desc is not None:
            raise ValueError(
                "Only one of disease_desc_infile and disease_desc can be provided."
            )
        # read the disease description from a file
        with open(disease_desc_infile, "r") as f:
            disease_desc = f.read()
    elif disease_desc is None:
        raise ValueError("Either disease_desc_infile or disease_desc must be provided.")

    task_desc = task_desc.replace("\n", " ")
    disease_desc = disease_desc.replace("\n", " ")
    logger.info("Done entering task description and prompt")

    # Initialize so the function returns None instead of raising NameError
    # when inference is disabled.
    df_dep = None

    if inference_bool:
        logger.info("Now running model inference for retrieval")

        # Create input for retrieval
        input_simple = create_input_retrieval(
            input_description=disease_desc,
            data_args=data_args,
            task_definition=task_desc,
            instruction_source_dataset=instruction_source_dataset,
            instruction_source_relation="all",
            aaseq_type="protein",
            icl_example_number=1,  # number of in-context examples (0, 1, or 2)
        )

        input_simple = move_inputs_to_device(input_simple, device=device)
        with torch.no_grad():
            model_out = model(
                inputs=input_simple,
                retrieval=True,
                aaseq_type="protein",
            )
        # The script can run up to here without a GPU, but the following line requires a GPU
Review comment: Let me know if running this on the GPU is a headache.
        df_dep = get_proteins_from_embedding(
            all_protein_embeddings, model_out, top_k=None
        )

    logger.info("Done performing protein retrieval")

    return df_dep
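
Taken together, a minimal end-to-end sketch of using these two functions outside the API (assumes the environment variables listed earlier are set and a GPU is available):

from procyon.inference.retrieval_utils import startup_retrieval, do_retrieval

# Load the model, device, data args, and pre-computed protein embeddings once
model, device, data_args, all_protein_embeddings = startup_retrieval(
    inference_bool=True
)

# Retrieve proteins for a disease description; returns a pandas DataFrame
results_df = do_retrieval(
    model=model,
    data_args=data_args,
    device=device,
    instruction_source_dataset="disgenet",
    all_protein_embeddings=all_protein_embeddings,
    task_desc="Find proteins related to this disease",
    disease_desc="Major depressive disorder",
)
print(results_df.head(10))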