Pqa tool #91

Merged on Feb 27, 2024 (13 commits)
3 changes: 0 additions & 3 deletions .env.example
@@ -4,8 +4,5 @@
# OpenAI API Key
OPENAI_API_KEY=YOUR_OPENAI_API_KEY_GOES_HERE # pragma: allowlist secret

# PQA API Key
PQA_API_KEY=YOUR_PQA_API_KEY_GOES_HERE # pragma: allowlist secret

# Serp API key
SERP_API_KEY=YOUR_SERP_API_KEY_GOES_HERE # pragma: allowlist secret
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -13,10 +13,10 @@ jobs:

steps:
- uses: actions/checkout@v2
- name: Set up Python "3.9"
- name: Set up Python "3.11"
uses: actions/setup-python@v2
with:
python-version: "3.9"
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
5 changes: 2 additions & 3 deletions .github/workflows/tests.yml
@@ -23,10 +23,10 @@ jobs:
environment-file: environment.yaml
python-version: ${{ matrix.python-version }}
auto-activate-base: true
- name: Install openmm pdbfixer mdanalysis with conda
- name: Install pdbfixer with conda
shell: bash -l {0}
run: |
conda install -c conda-forge openmm pdbfixer mdanalysis
conda install -c conda-forge pdbfixer
- name: Install dependencies
shell: bash -l {0}
run: |
@@ -45,6 +45,5 @@
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SEMANTIC_SCHOLAR_API_KEY: ${{ secrets.SEMANTIC_SCHOLAR_API_KEY }}
PQA_API_KEY : ${{ secrets.PQA_API_TOKEN }}
run: |
pytest -m "not skip" tests
3 changes: 0 additions & 3 deletions .secrets.baseline
@@ -3,8 +3,5 @@
# Rule for detecting OpenAI API keys
OpenAI API Key: \b[secrets]{3}_[a-zA-Z0-9]{32}\b

# Rule for detecting pqa API keys
PQA API Key: "pqa[a-zA-Z0-9-._]+"

# Rule for detecting serp API keys
# Serp API Key: "[a-zA-Z0-9]{64}"
4 changes: 2 additions & 2 deletions README.md
@@ -7,8 +7,8 @@ To use the OpenMM features in the agent, please set up a conda environment, foll
- Create conda environment: `conda env create -n mdagent -f environment.yaml`
- Activate your environment: `conda activate mdagent`

If you already have a conda environment, you can install the necessary dependencies with the following steps.
- Install the necessary conda dependencies: `conda install -c conda-forge openmm pdbfixer mdanalysis`
If you already have a conda environment, you can install pdbfixer, a necessary dependency, with the following steps.
- Install the necessary conda dependency: `conda install -c conda-forge pdbfixer`


## Installation
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -1,2 +1,3 @@
pre-commit
pytest
pytest-mock
85 changes: 67 additions & 18 deletions mdagent/tools/base_tools/util_tools/search_tools.py
@@ -1,26 +1,75 @@
import pqapi
import os
import re

import langchain
import paperqa
import paperscraper
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool
from pypdf.errors import PdfReadError


def paper_scraper(search: str, pdir: str = "query") -> dict:
try:
return paperscraper.search_papers(search, pdir=pdir)
except KeyError:
return {}


def paper_search(llm, query):
prompt = langchain.prompts.PromptTemplate(
input_variables=["question"],
template="""
I would like to find scholarly papers to answer
this question: {question}. Your response must be at
most 10 words long.
'A search query that would bring up papers that can answer
this question would be: '""",
)

query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt)
if not os.path.isdir("./query"): # todo: move to ckpt
os.mkdir("query/")
search = query_chain.run(query)
print("\nSearch:", search)
papers = paper_scraper(search, pdir=f"query/{re.sub(' ', '', search)}")
return papers


def scholar2result_llm(llm, query, k=5, max_sources=2):
"""Useful to answer questions that require
technical knowledge. Ask a specific question."""
papers = paper_search(llm, query)
if len(papers) == 0:
return "Not enough papers found"
docs = paperqa.Docs(llm=llm)
not_loaded = 0
for path, data in papers.items():
try:
docs.add(path, data["citation"])
except (ValueError, FileNotFoundError, PdfReadError):
not_loaded += 1

print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}")
answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer
return answer


class Scholar2ResultLLM(BaseTool):
name = "LiteratureSearch"
description = """Input a specific question,
returns an answer from literature search."""
description = (
"Useful to answer questions that require technical "
"knowledge. Ask a specific question."
)
llm: BaseLanguageModel = None

pqa_key: str = ""

def __init__(self, pqa_key: str):
def __init__(self, llm):
super().__init__()
self.pqa_key = pqa_key
self.llm = llm

def _run(self, question: str) -> str:
"""Use the tool"""
try:
response = pqapi.agent_query("default", question)
return response.answer
except Exception:
return "Literature search failed."

async def _arun(self, question: str) -> str:
"""Use the tool asynchronously"""
raise NotImplementedError
def _run(self, query) -> str:
return scholar2result_llm(self.llm, query)

async def _arun(self, query) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("this tool does not support async")
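For reviewers who want to try the refactored tool outside the agent, here is a minimal usage sketch. It assumes OPENAI_API_KEY is set in the environment and uses langchain's ChatOpenAI (as the new test does); any BaseLanguageModel should work, and the call makes real paper-scraper and OpenAI requests.

# Minimal sketch: exercise the refactored LiteratureSearch tool directly.
# Assumes OPENAI_API_KEY is exported; performs real network/API calls.
from langchain.chat_models import ChatOpenAI

from mdagent.tools.base_tools import Scholar2ResultLLM

llm = ChatOpenAI(temperature=0.1)
lit_search = Scholar2ResultLLM(llm=llm)

# Papers are scraped into ./query/<search-term>/ and answered with paper-qa.
answer = lit_search._run("How does mask usage affect COVID-19 transmission?")
print(answer)
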
9 changes: 1 addition & 8 deletions mdagent/tools/maketools.py
@@ -76,6 +76,7 @@ def make_all_tools(

# add base tools
base_tools = [
Scholar2ResultLLM(llm=llm),
CleaningToolFunction(path_registry=path_instance),
ListRegistryPaths(path_registry=path_instance),
ProteinName2PDBTool(path_registry=path_instance),
@@ -108,14 +109,6 @@ def make_all_tools(
learned_tools = get_learned_tools(subagent_settings.ckpt_dir)

all_tools += base_tools + subagents_tools + learned_tools

# add other tools depending on api keys
os.getenv("SERP_API_KEY")
pqa_key = os.getenv("PQA_API_KEY")
# if serp_key:
# all_tools.append(SerpGitTool(serp_key)) # github issues search
if pqa_key:
all_tools.append(Scholar2ResultLLM(pqa_key)) # literature search
return all_tools


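One consequence worth noting for deployment: PQA_API_KEY is no longer read anywhere, so the literature path depends only on the OpenAI key (plus the SEMANTIC_SCHOLAR_API_KEY the CI already passes for paper scraping). A purely illustrative pre-flight check:

import os

# Illustrative only: the lit-search path no longer consults PQA_API_KEY;
# OPENAI_API_KEY is the one secret it still requires.
assert os.getenv("OPENAI_API_KEY"), "set OPENAI_API_KEY before building the toolset"
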
121 changes: 121 additions & 0 deletions notebooks/lit_search.ipynb
@@ -0,0 +1,121 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/samcox/anaconda3/envs/mda_feb21/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from mdagent import MDAgent"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#until we update to new version\n",
"import nest_asyncio\n",
"nest_asyncio.apply()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"mda = MDAgent()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"Are there any studies that show that the use of a mask can reduce the spread of COVID-19?\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\"Masks COVID-19 transmission reduction studies\"\n",
"Search: \"Masks COVID-19 transmission reduction studies\"\n",
"\n",
"Found 14 papers but couldn't load 0\n",
"Yes, there are studies that show that the use of a mask can reduce the spread of COVID-19. The review by Howard et al. (2021) indicates that mask-wearing reduces the transmissibility of COVID-19 by limiting the spread of infected respiratory particles. This conclusion is supported by evidence from both laboratory and clinical studies."
]
}
],
"source": [
"answer = mda.run(prompt)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Yes, there are studies that show that the use of a mask can reduce the spread of COVID-19. The review by Howard et al. (2021) indicates that mask-wearing reduces the transmissibility of COVID-19 by limiting the spread of infected respiratory particles. This conclusion is supported by evidence from both laboratory and clinical studies.'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"answer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mdagent",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
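For review without Jupyter, the notebook's flow condenses to a short script; it mirrors the cells above, including the temporary nest_asyncio workaround.

# Plain-script version of notebooks/lit_search.ipynb, mirroring the cells above.
import nest_asyncio

from mdagent import MDAgent

nest_asyncio.apply()  # temporary workaround noted in the notebook

mda = MDAgent()
prompt = (
    "Are there any studies that show that the use of a mask "
    "can reduce the spread of COVID-19?"
)
answer = mda.run(prompt)
print(answer)
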
9 changes: 4 additions & 5 deletions setup.py
@@ -17,22 +17,21 @@
license="MIT",
packages=find_packages(),
install_requires=[
"paper-scraper @ git+https://github.com/blackadad/paper-scraper.git",
"chromadb==0.3.29",
"google-search-results",
"langchain==0.0.336",
"langchain_experimental",
"matplotlib",
"nbformat",
"openai",
"paper-qa",
"python-dotenv",
"pqapi",
"requests",
"rmrkl",
"tiktoken",
"rdkit",
"streamlit",
"paper-qa",
"openmm",
"MDAnalysis",
"paper-scraper @ git+https://github.com/blackadad/paper-scraper.git",
],
test_suite="tests",
long_description=long_description,
23 changes: 23 additions & 0 deletions tests/test_fxns.py
@@ -5,9 +5,11 @@
from unittest.mock import MagicMock, mock_open, patch

import pytest
from langchain.chat_models import ChatOpenAI

from mdagent.tools.base_tools import (
CleaningTools,
Scholar2ResultLLM,
SimulationFunctions,
VisFunctions,
get_pdb,
@@ -438,3 +440,24 @@ def test_init_path_registry(path_registry_with_mocked_fs):
# you may need to check the internal state or the contents of the JSON file.
# For example:
assert "water_000000" in path_registry_with_mocked_fs.list_path_names()


@pytest.fixture
def questions():
qs = [
"What are the effects of norhalichondrin B in mammals?",
]
return qs


@pytest.mark.skip(reason="This requires an API call")
def test_litsearch(questions):
llm = ChatOpenAI()

searchtool = Scholar2ResultLLM(llm=llm)
for q in questions:
ans = searchtool._run(q)
assert isinstance(ans, str)
assert len(ans) > 0
if os.path.exists("../query"):
os.rmdir("../query")