Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the scanpy tools (`tl`) module classes to the API agent `__init__.py` #228

4 changes: 4 additions & 0 deletions biochatter/api_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
BlastQueryParameters,
)
from .oncokb import OncoKBFetcher, OncoKBInterpreter, OncoKBQueryBuilder
from .scanpy_tl import ScanpyTLQueryBuilder, ScanpyTLQueryFetcher, ScanpyTLQueryInterpreter

__all__ = [
"BaseFetcher",
Expand All @@ -28,4 +29,7 @@
"BioToolsInterpreter",
"BioToolsQueryBuilder",
"APIAgent",
"ScanpyTLQueryBuilder",
"ScanpyTLQueryFetcher",
"ScanpyTLQueryInterpreter",
]
77 changes: 77 additions & 0 deletions biochatter/api_agent/generate_pydantic_classes_from_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import inspect
from typing import Any, Dict, Optional, Type
from types import ModuleType
from docstring_parser import parse
from langchain_core.pydantic_v1 import BaseModel, Field, create_model

def generate_pydantic_classes(module: ModuleType) -> list[Type[BaseModel]]:
    """
    Generate a Pydantic model class for each public function in a module.

    Parameter names and descriptions are extracted from each function's
    docstring using docstring-parser. Each generated class has one field per
    documented parameter. If a parameter name conflicts with a BaseModel
    attribute (e.g. ``schema``), the field is renamed with a ``_param``
    suffix and aliased back to the original name.

    Parameters
    ----------
    module : ModuleType
        The Python module from which to extract functions and generate models.

    Returns
    -------
    list[Type[BaseModel]]
        One dynamically created Pydantic model per public function, in the
        order returned by ``inspect.getmembers`` (alphabetical by name).
    """
    base_attributes = set(dir(BaseModel))
    classes_list = []

    # Iterate over all plain functions defined in (or imported into) the module.
    for name, func in inspect.getmembers(module, inspect.isfunction):
        # Skip private/internal helpers.
        if name.startswith("_"):
            continue

        doc = inspect.getdoc(func)
        if not doc:
            # No docstring: still expose the function, but with no fields.
            classes_list.append(create_model(name))
            continue

        parsed_doc = parse(doc)

        # Collect parameter descriptions; first occurrence wins if a
        # parameter is documented more than once.
        param_info = {}
        for p in parsed_doc.params:
            if p.arg_name not in param_info:
                param_info[p.arg_name] = p.description or "No description available."

        # Prepare the field definitions for create_model.
        fields = {}
        for param_name, param_desc in param_info.items():
            field_kwargs = {"default": None, "description": param_desc}
            field_name = param_name

            # Alias names that would shadow BaseModel attributes so the
            # generated schema still presents the original parameter name.
            if param_name in base_attributes:
                field_kwargs["alias"] = param_name
                field_name = param_name + "_param"

            # Docstring parsing yields no reliable type information, so
            # default every field to Optional[str].
            fields[field_name] = (Optional[str], Field(**field_kwargs))

        # Dynamically create the model for this function.
        classes_list.append(create_model(name, **fields))

    return classes_list


# Example usage:
#import scanpy as sc
#generated_classes = generate_pydantic_classes(sc.tl)
#for func in generated_classes:
# print(func.schema())
152 changes: 152 additions & 0 deletions biochatter/api_agent/scanpy_tl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Module for interacting with the `scanpy` API for data transformation tools (`tl`)."""
from collections.abc import Callable
from typing import TYPE_CHECKING

import requests
from langchain.chains.openai_functions import create_structured_output_runnable
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import PydanticToolsParser
if TYPE_CHECKING:
from biochatter.llm_connect import Conversation

from .abc import BaseFetcher, BaseInterpreter, BaseQueryBuilder
from .generate_pydantic_classes_from_module import generate_pydantic_classes

# System prompt describing the scanpy.tl module for a query-building LLM.
# NOTE(review): this constant is not referenced anywhere in this module —
# ScanpyTLQueryBuilder.parameterise_query builds its own system message.
# Confirm whether it should be wired into the query or removed.
SCANPY_QUERY_PROMPT = """
You are a world class algorithm for creating queries in structured formats. Your task is to use the scanpy python package
to provide the user with the appropriate function call to answer their question. You focus on the scanpy.tl module, which has
the following overview:
Any transformation of the data matrix that is not *preprocessing*. In contrast to a *preprocessing* function, a *tool* usually adds an easily interpretable annotation to the data matrix, which can then be visualized with a corresponding plotting function.

### Embeddings

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

pp.pca
tl.tsne
tl.umap
tl.draw_graph
tl.diffmap
```

Compute densities on embeddings.

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.embedding_density
```

### Clustering and trajectory inference

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.leiden
tl.louvain
tl.dendrogram
tl.dpt
tl.paga
```

### Data integration

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.ingest
```

### Marker genes

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.rank_genes_groups
tl.filter_rank_genes_groups
tl.marker_gene_overlap
```

### Gene scores, Cell cycle

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.score_genes
tl.score_genes_cell_cycle
```

### Simulations

```{eval-rst}
.. autosummary::
:nosignatures:
:toctree: ../generated/

tl.sim

```
"""
class ScanpyTLQueryBuilder(BaseQueryBuilder):
    """A class for building a scanpy.tl query via LLM tool-calling."""

    def create_runnable(
        self,
        query_parameters: BaseModel,
        conversation: "Conversation",
    ):
        """Not used: parameterise_query builds its own chain directly."""
        pass

    def parameterise_query(
        self,
        question: str,
        conversation: "Conversation",
    ):
        """Generate a parameterised scanpy.tl function call.

        Dynamically generates one Pydantic model per public function in the
        ``scanpy.tl`` module and binds them as tools to the conversation's
        LLM (via langchain's ``bind_tools``), letting the model choose and
        parameterise the appropriate function call for the question.

        Args:
        ----
            question (str): The question to be answered.

            conversation: The conversation object used for parameterising the
                query; its ``chat`` attribute must support ``bind_tools``.

        Returns:
        -------
            The parsed tool call(s) as instances of the generated Pydantic
            models.

        """
        # Imported lazily: scanpy is an optional dependency.
        import scanpy as sc

        module = sc.tl
        generated_classes = generate_pydantic_classes(module)
        llm = conversation.chat
        llm_with_tools = llm.bind_tools(generated_classes)
        # NOTE(review): SCANPY_QUERY_PROMPT above is unused; consider using
        # it as the system message here.
        query = [
            ("system", "You're an expert data scientist"),
            # Pass the question string itself; the original `{question}`
            # wrapped it in a one-element set, which is not a valid message
            # content type.
            ("human", question),
        ]
        chain = llm_with_tools | PydanticToolsParser(tools=generated_classes)
        result = chain.invoke(query)
        return result
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ python = ">=3.10,<3.13"
langchain = "^0.2.5"
openai = "^1.1.0"
pymupdf = "^1.22.3"
pymilvus = "2.2.8"
pymilvus = ">=2.2.8"
nltk = "^3.8.1"
redis = "^4.5.5"
retry = "^0.9.2"
Expand All @@ -52,6 +52,7 @@ rouge_score = "0.1.2"
evaluate = "^0.4.1"
pillow = ">=10.2,<11.0"
pdf2image = "^1.16.0"
scanpy = { version = "^1.11.0", optional = true }
langchain-community = "^0.2.5"
langgraph = "^0.1.5"
langchain-openai = "^0.1.14"
Expand All @@ -62,6 +63,7 @@ colorcet = "^3.1.0"

langchain-anthropic = "^0.1.22"
anthropic = "^0.33.0"
docstring-parser = "^0.16.0"
[tool.poetry.extras]
streamlit = ["streamlit"]
podcast = ["gTTS"]
Expand Down
79 changes: 79 additions & 0 deletions test/test_api_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
OncoKBQueryBuilder,
OncoKBQueryParameters,
)
from biochatter.api_agent.scanpy_tl import ScanpyTLQueryBuilder
from biochatter.llm_connect import Conversation, GptConversation


Expand Down Expand Up @@ -422,3 +423,81 @@ def test_summarise_results(mock_prompt, mock_conversation, mock_chain):
mock_chain.invoke.assert_called_once_with(
{"input": {expected_summary_prompt}},
)

class TestScanpyTLQueryBuilder:
    """Unit tests for ScanpyTLQueryBuilder.parameterise_query."""

    @pytest.fixture()
    def mock_generate_pydantic_classes(self):
        # Patch the name where it is *used*: scanpy_tl does
        # `from .generate_pydantic_classes_from_module import ...`, so
        # patching the defining module would leave scanpy_tl's already-bound
        # reference untouched and the mock would never take effect.
        with patch("biochatter.api_agent.scanpy_tl.generate_pydantic_classes") as mock:
            # The real function returns a list of generated model classes.
            mock.return_value = [MagicMock()]
            yield mock

    @pytest.fixture()
    def mock_pydantic_tools_parser(self):
        # Same "patch where used" rule applies to PydanticToolsParser.
        with patch("biochatter.api_agent.scanpy_tl.PydanticToolsParser") as mock_parser_cls:
            mock_parser_instance = MagicMock()
            mock_parser_cls.return_value = mock_parser_instance
            yield mock_parser_cls, mock_parser_instance

    def test_parameterise_query(
        self,
        mock_generate_pydantic_classes,
        mock_pydantic_tools_parser,
    ):
        # Arrange
        query_builder = ScanpyTLQueryBuilder()
        mock_conversation = MagicMock()
        mock_llm = MagicMock()
        mock_conversation.chat = mock_llm

        # Mock the LLM with tools bound.
        mock_llm_with_tools = MagicMock()
        mock_llm.bind_tools.return_value = mock_llm_with_tools

        # `llm_with_tools | PydanticToolsParser(...)` is emulated by setting
        # the return value of __or__ (the pipe operator) to a mock chain.
        mock_parser_cls, mock_parser_instance = mock_pydantic_tools_parser
        mock_chain = MagicMock()
        mock_llm_with_tools.__or__.return_value = mock_chain

        # The chain.invoke(...) result.
        mock_result = MagicMock()
        mock_chain.invoke.return_value = mock_result

        question = "Find the best parameters for leiden clustering."

        # Act
        result = query_builder.parameterise_query(question, mock_conversation)

        # Assert
        # generate_pydantic_classes must receive the scanpy.tl module.
        args, kwargs = mock_generate_pydantic_classes.call_args
        assert "scanpy.tl" in str(args[0])

        # The generated classes must be bound to the LLM as tools.
        mock_llm.bind_tools.assert_called_once_with(
            mock_generate_pydantic_classes.return_value,
        )

        # The human message must be the question string itself, not a set.
        mock_chain.invoke.assert_called_once_with([
            ("system", "You're an expert data scientist"),
            ("human", question),
        ])

        # The chain's result is returned unchanged.
        assert result == mock_result

class TestScanpyPlFetcher:
    # TODO: placeholder with no tests yet. NOTE(review): the name says
    # "Pl" but this file exercises the scanpy_tl module — confirm whether
    # this should be TestScanpyTLFetcher.
    pass


class TestScanpyPlInterpreter:
    # TODO: placeholder with no tests yet. NOTE(review): "Pl" vs the
    # scanpy_tl module under test — confirm intended name.
    pass