Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scanpy-pl #226

Merged
merged 17 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions benchmark/data/benchmark_api_calling_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,24 @@ api_calling:
expected:
parts_of_query:
["https://bio.tools/api/t/", "\\?topic=", "[mM]etabolomics"]
- case: scanpy:pl:scatter
input:
prompt:
exact_variable_names: "Make a scatter plot of n_genes_by_counts vs total_counts."
expected:
parts_of_query:
["sc.pl.scatter\\(", "n_genes_by_counts", "total_counts", "\\)"]
- case: scanpy:pl:pca
input:
prompt:
fuzzy_search: "plot the PCA of the data colored by n_genes_by_counts and total_counts."
expected:
parts_of_query:
["sc.pl.pca\\(", ["n_genes_by_counts", "total_counts"], "\\)"]
- case: scanpy:pl:tsne
input:
prompt:
fuzzy_search: "plot the tsne embedding of the data colored by n_genes_by_counts."
expected:
parts_of_query:
["sc.pl.tsne\\(", "n_genes_by_counts", "\\)"]
71 changes: 63 additions & 8 deletions benchmark/test_api_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from biochatter._misc import ensure_iterable
from biochatter.api_agent import BioToolsQueryBuilder, OncoKBQueryBuilder
from biochatter.api_agent import BioToolsQueryBuilder, OncoKBQueryBuilder, ScanpyPlQueryBuilder, format_as_rest_call, format_as_python_call

from .benchmark_utils import (
get_result_file_path,
Expand All @@ -15,7 +15,7 @@
from .conftest import calculate_bool_vector_score


def test_api_calling(
def test_web_api_calling(
model_name,
test_data_api_calling,
conversation,
Expand All @@ -32,24 +32,25 @@ def test_api_calling(
pytest.skip(
f"model {model_name} does not support API calling for {task} benchmark",
)
if "scanpy" in yaml_data["case"]:
pytest.skip(
"scanpy is not a web API",
)

def run_test():
conversation.reset() # needs to be reset for each test
if "oncokb" in yaml_data["case"]:
builder = OncoKBQueryBuilder()
elif "biotools" in yaml_data["case"]:
builder = BioToolsQueryBuilder()
elif "scanpy:pl" in yaml_data["case"]:
builder = ScanpyPlQueryBuilder()
parameters = builder.parameterise_query(
question=yaml_data["input"]["prompt"],
conversation=conversation,
)

params = parameters.dict(exclude_none=True)
endpoint = params.pop("endpoint")
base_url = params.pop("base_url")
params.pop("question_uuid")
full_url = f"{base_url.rstrip('/')}/{endpoint.lstrip('/')}"
api_query = f"{full_url}?{urlencode(params)}"
api_query = format_as_rest_call(parameters)

score = []
for expected_part in ensure_iterable(
Expand All @@ -72,3 +73,57 @@ def run_test():
yaml_data["hash"],
get_result_file_path(task),
)

def test_python_api_calling(
    model_name,
    test_data_api_calling,
    conversation,
    multiple_testing,
):
    """Test the Python API calling capability.

    For each benchmark case, a query builder parameterises a Python method
    call from the prompt, the parameters are rendered with
    ``format_as_python_call``, and every expected regex pattern is searched
    for in the rendered call string. The fraction of matching patterns is the
    score of one iteration.

    Args:
        model_name: Name of the model under test (pytest fixture).
        test_data_api_calling: One benchmark case; read here via the keys
            ``case``, ``hash``, ``input.prompt`` and
            ``expected.parts_of_query``.
        conversation: Conversation fixture; reset before every iteration.
        multiple_testing: Fixture that runs ``run_test`` repeatedly and
            returns ``(mean_score, max_score, n_iterations)``.

    """
    task = inspect.currentframe().f_code.co_name.replace("test_", "")
    yaml_data = test_data_api_calling
    case = yaml_data["case"]

    skip_if_already_run(
        model_name=model_name,
        task=task,
        md5_hash=yaml_data["hash"],
    )

    if "scanpy" not in case:
        pytest.skip(
            "Function to be tested is not a Python API",
        )

    # Only the plotting module has a query builder here. Skipping explicitly
    # avoids an unbound ``builder`` inside ``run_test`` for other scanpy cases.
    if "scanpy:pl" not in case:
        pytest.skip(
            f"no Python API query builder available for case {case}",
        )

    def iter_patterns(parts):
        """Yield regex patterns from ``parts``, flattening nested lists."""
        for part in ensure_iterable(parts):
            if isinstance(part, (list, tuple)):
                # Expected parts may group several patterns in a sub-list
                # (e.g. a list-valued argument); each pattern is scored on
                # its own because ``re.search`` cannot take a list.
                yield from iter_patterns(part)
            else:
                yield part

    def run_test():
        conversation.reset()  # needs to be reset for each test
        builder = ScanpyPlQueryBuilder()
        parameters = builder.parameterise_query(
            question=yaml_data["input"]["prompt"],
            conversation=conversation,
        )

        method_call = format_as_python_call(parameters)

        score = [
            bool(re.search(pattern, method_call))
            for pattern in iter_patterns(yaml_data["expected"]["parts_of_query"])
        ]

        return calculate_bool_vector_score(score)

    # ``max_score`` rather than ``max`` so the builtin is not shadowed.
    mean_score, max_score, n_iterations = multiple_testing(run_test)

    write_results_to_file(
        model_name,
        case,
        f"{mean_score}/{max_score}",
        f"{n_iterations}",
        yaml_data["hash"],
        get_result_file_path(task),
    )
7 changes: 6 additions & 1 deletion biochatter/api_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
BlastQueryParameters,
)
from .oncokb import OncoKBFetcher, OncoKBInterpreter, OncoKBQueryBuilder
from .scanpy_pl import ScanpyPlQueryBuilder
from .formatters import format_as_rest_call, format_as_python_call

__all__ = [
"BaseFetcher",
Expand All @@ -28,4 +30,7 @@
"BioToolsInterpreter",
"BioToolsQueryBuilder",
"APIAgent",
]
"ScanpyPlQueryBuilder",
"format_as_rest_call",
"format_as_python_call",
]
41 changes: 41 additions & 0 deletions biochatter/api_agent/formatters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Formatters for API calls (Pydantic models to strings)."""

from urllib.parse import urlencode

from pydantic import BaseModel

def format_as_rest_call(model: BaseModel) -> str:
    """Convert a parameter model (BaseModel) into a REST API call string.

    The model's ``base_url`` and ``endpoint`` fields are joined with exactly
    one slash. Every remaining field that is not ``None`` (except
    ``question_uuid``, which is dropped) is URL-encoded into the query string.

    Args:
        model: Pydantic model containing API call parameters. It must provide
            non-``None`` ``endpoint`` and ``base_url`` fields.

    Returns:
        String representation of the REST API call. The ``?`` separator is
        only present when there is at least one query parameter.

    Raises:
        KeyError: If ``endpoint`` or ``base_url`` is absent or ``None``.

    """
    params = model.dict(exclude_none=True)
    endpoint = params.pop("endpoint")
    base_url = params.pop("base_url")
    # question_uuid is removed so it never ends up in the query string.
    params.pop("question_uuid", None)

    full_url = f"{base_url.rstrip('/')}/{endpoint.strip('/')}"
    if not params:
        # Avoid a dangling "?" when there is nothing to encode.
        return full_url

    # doseq=True expands list/tuple values into repeated ``key=value`` pairs
    # instead of encoding the Python repr of the sequence.
    return f"{full_url}?{urlencode(params, doseq=True)}"

def format_as_python_call(model: BaseModel) -> str:
    """Convert a parameter model into a Python method call string.

    Every field that is not ``None`` (except ``method_name`` and
    ``question_uuid``) is rendered as a ``name=repr(value)`` keyword
    argument, in the order the model's ``dict()`` returns them.

    Args:
        model: Pydantic model containing method parameters. Its
            ``method_name`` field holds the callable to render, e.g.
            ``"sc.pl.scatter"``.

    Returns:
        String representation of the Python method call, e.g.
        ``"sc.pl.scatter(x='a', y='b')"``.

    Raises:
        ValueError: If the model has no (non-``None``) ``method_name``;
            without it the result would be the meaningless ``"None(...)"``.

    """
    params = model.dict(exclude_none=True)
    method_name = params.pop("method_name", None)
    # question_uuid is removed so it never appears as a keyword argument.
    params.pop("question_uuid", None)

    if method_name is None:
        raise ValueError(
            "Cannot format a Python call: the model has no 'method_name'.",
        )

    param_str = ", ".join(f"{key}={value!r}" for key, value in params.items())

    return f"{method_name}({param_str})"
Loading