diff --git a/README.md b/README.md index b38b46d8..e87e97e1 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,13 @@ Other tools require API keys, such as paper-qa for literature searches. We recom 1. Copy the `.env.example` file and rename it to `.env`: `cp .env.example .env` 2. Replace the placeholder values in `.env` with your actual keys +## Using the Streamlit Interface +If you'd like to use MDAgent via the Streamlit app, make sure you have completed the steps above. Then, in your terminal, run `streamlit run st_app.py` in the project root directory. + +From there you may upload files to use during the run. Note: the app currently accepts only .pdb and .cif uploads, and the maximum upload size defaults to 200 MB. +- To upload larger files, instead run `streamlit run st_app.py --server.maxUploadSize=some_large_number` (the value is in megabytes). +- To add different file types, you can add your desired file type to the list in the [Streamlit app file](https://github.com/ur-whitelab/md-agent/blob/main/st_app.py). + ## Contributing diff --git a/mdagent/mainagent/agent.py b/mdagent/mainagent/agent.py index 6f8d7f3b..61bff83d 100644 --- a/mdagent/mainagent/agent.py +++ b/mdagent/mainagent/agent.py @@ -47,9 +47,14 @@ def __init__( resume=False, top_k_tools=20, # set "all" if you want to use all tools (& skills if resume) use_human_tool=False, + uploaded_files=[], # user input files to add to path registry ): if path_registry is None: path_registry = PathRegistry.get_instance() + self.uploaded_files = uploaded_files + for file in uploaded_files: # todo -> allow users to add descriptions? 
+ path_registry.map_path(file, file, description="User uploaded file") + self.agent_type = agent_type self.user_tools = tools self.tools_llm = _make_llm(tools_model, temp, verbose) diff --git a/mdagent/subagents/subagent_fxns.py b/mdagent/subagents/subagent_fxns.py index 8b1c84ce..b010ba60 100644 --- a/mdagent/subagents/subagent_fxns.py +++ b/mdagent/subagents/subagent_fxns.py @@ -2,6 +2,8 @@ import os from typing import Optional +import streamlit as st + from .subagent_setup import SubAgentInitializer, SubAgentSettings @@ -76,6 +78,7 @@ def _run_loop(self, task, full_history, skills): """ critique = None print("\n\033[46m action agent is running, writing code\033[0m") + st.markdown("action agent is running, writing code", unsafe_allow_html=True) success, code, fxn_name, code_output = self.action._run_code( full_history, task, skills ) @@ -126,12 +129,20 @@ def _run_iterations(self, run, task): # give successful code to tool/skill manager print("\n\033[46mThe new code is complete, running skill agent\033[0m") + st.markdown( + "The new code is complete, running skill agent", + unsafe_allow_html=True, + ) tool_name = self.skill.add_new_tool(fxn_name, code) return success, tool_name iter += 1 # if max iterations reached without success, save failures to file print("\n\033[46m Max iterations reached, saving failed history to file\033[0m") + st.markdown( + "Max iterations reached, saving failed history to file", + unsafe_allow_html=True, + ) tool_name = None full_failed = self._add_to_history( full_history, diff --git a/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py b/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py index b5433f82..684d5f37 100644 --- a/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py +++ b/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py @@ -4,6 +4,7 @@ import matplotlib.pyplot as plt import MDAnalysis as mda import numpy as np +import streamlit as st from langchain.tools import BaseTool from MDAnalysis.analysis import align, 
diffusionmap, rms from pydantic import BaseModel, Field @@ -44,15 +45,27 @@ def calculate_rmsd( if rmsd_type == "rmsd": if self.ref_file: print("Calculating 1-D RMSD between two sets of coordinates...") + st.markdown( + "Calculating 1-D RMSD between two sets of coordinates...", + unsafe_allow_html=True, + ) return self.compute_rmsd_2sets(selection=selection) else: print("Calculating time-dependent RMSD...") + st.markdown( + "Calculating time-dependent RMSD...", unsafe_allow_html=True + ) return self.compute_rmsd(selection=selection, plot=plot) elif rmsd_type == "pairwise_rmsd": print("Calculating pairwise RMSD...") + st.markdown("Calculating pairwise RMSD...", unsafe_allow_html=True) return self.compute_2d_rmsd(selection=selection, plot_heatmap=plot) elif rmsd_type == "rmsf": print("Calculating root mean square fluctuation (RMSF)...") + st.markdown( + "Calculating root mean square fluctuation (RMSF)...", + unsafe_allow_html=True, + ) return self.compute_rmsf(selection=selection, plot=plot) else: raise ValueError( diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py index fbad3a6f..45dacde9 100644 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Type, Union import requests +import streamlit as st from langchain.tools import BaseTool from pdbfixer import PDBFixer from pydantic import BaseModel, Field, ValidationError, root_validator @@ -39,6 +40,7 @@ def get_pdb(query_string, path_registry=None): if "result_set" in r.json() and len(r.json()["result_set"]) > 0: pdbid = r.json()["result_set"][0]["identifier"] print(f"PDB file found with this ID: {pdbid}") + st.markdown(f"PDB file found with this ID: {pdbid}", unsafe_allow_html=True) url = f"https://files.rcsb.org/download/{pdbid}.{filetype}" pdb = requests.get(url) filename = path_registry.write_file_name( diff 
--git a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py index 62176806..19ffaac6 100644 --- a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py +++ b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Type import langchain +import streamlit as st from langchain.base_language import BaseLanguageModel from langchain.chains import LLMChain from langchain.prompts import PromptTemplate @@ -316,7 +317,8 @@ def _setup_and_run_simulation(self, query, PathRegistry): ] Forcefield = Forcefield_files[0] Water_model = Forcefield_files[1] - print("Setting up forcields :", Forcefield, Water_model) + print("Setting up forcefields :", Forcefield, Water_model) + st.markdown("Setting up forcefields", unsafe_allow_html=True) # check if forcefields end in .xml if Forcefield.endswith(".xml") and Water_model.endswith(".xml"): forcefield = ForceField(Forcefield, Water_model) @@ -355,6 +357,7 @@ def _setup_and_run_simulation(self, query, PathRegistry): _timestep, "fs", ) + st.markdown("Setting up Langevin integrator", unsafe_allow_html=True) if params["Ensemble"] == "NPT": _pressure = params["Pressure"].split(" ")[0].strip() system.addForce(MonteCarloBarostat(_pressure * bar, _temp * kelvin)) @@ -378,6 +381,7 @@ def _setup_and_run_simulation(self, query, PathRegistry): "bar", ) print("Setting up Verlet integrator with Parameters:", _timestep, "fs") + st.markdown("Setting up Verlet integrator", unsafe_allow_html=True) integrator = VerletIntegrator(float(_timestep) * picoseconds) simulation = Simulation(modeller.topology, system, integrator) @@ -682,6 +686,7 @@ def __init__( def setup_system(self): print("Building system...") + st.markdown("Building system", unsafe_allow_html=True) self.pdb_id = self.params["pdb_id"] self.pdb_path = self.path_registry.get_mapped_path(name=self.pdb_id) self.pdb = PDBFile(self.pdb_path) @@ -703,6 +708,7 @@ def 
setup_system(self): def setup_integrator(self): print("Setting up integrator...") + st.markdown("Setting up integrator", unsafe_allow_html=True) int_params = self.int_params integrator_type = int_params.get("integrator_type", "LangevinMiddle") @@ -727,6 +733,7 @@ def setup_integrator(self): def create_simulation(self): print("Creating simulation...") + st.markdown("Creating simulation", unsafe_allow_html=True) self.simulation = Simulation( self.pdb.topology, self.system, @@ -1049,13 +1056,16 @@ def remove_leading_spaces(text): file.write(script_content) print(f"Standalone simulation script written to {directory}/{filename}") + st.markdown("Standalone simulation script written", unsafe_allow_html=True) def run(self): # Minimize and Equilibrate print("Performing energy minimization...") + st.markdown("Performing energy minimization", unsafe_allow_html=True) self.simulation.minimizeEnergy() print("Minimization complete!") + st.markdown("Minimization complete! Equilibrating...", unsafe_allow_html=True) print("Equilibrating...") _temp = self.int_params["Temperature"] self.simulation.context.setVelocitiesToTemperature(_temp) @@ -1063,9 +1073,11 @@ def run(self): self.simulation.step(_eq_steps) # Simulate print("Simulating...") + st.markdown("Simulating...", unsafe_allow_html=True) self.simulation.currentStep = 0 self.simulation.step(self.sim_params["Number of Steps"]) print("Done!") + st.markdown("Done!", unsafe_allow_html=True) if not self.save: if os.path.exists("temp_trajectory.dcd"): os.remove("temp_trajectory.dcd") @@ -1134,6 +1146,7 @@ def _run(self, **input_args): input, self.path_registry, save, sim_id, pdb_id ) print("simulation set!") + st.markdown("simulation set!", unsafe_allow_html=True) except ValueError as e: return str(e) + f"This were the inputs {input_args}" except FileNotFoundError: @@ -1594,9 +1607,11 @@ def check_system_params(cls, values): forcefield_files = values.get("forcefield_files") if forcefield_files is None or forcefield_files is []: 
print("Setting default forcefields") + st.markdown("Setting default forcefields", unsafe_allow_html=True) forcefield_files = ["amber14-all.xml", "amber14/tip3pfb.xml"] elif len(forcefield_files) == 0: print("Setting default forcefields v2") + st.markdown("Setting default forcefields", unsafe_allow_html=True) forcefield_files = ["amber14-all.xml", "amber14/tip3pfb.xml"] else: for file in forcefield_files: diff --git a/mdagent/tools/maketools.py b/mdagent/tools/maketools.py index 5fca1faf..d64f4284 100644 --- a/mdagent/tools/maketools.py +++ b/mdagent/tools/maketools.py @@ -2,6 +2,7 @@ import os from typing import Optional, Type +import streamlit as st from dotenv import load_dotenv from langchain import agents from langchain.base_language import BaseLanguageModel @@ -179,6 +180,10 @@ def get_tools( print(f"Invalid index {index}.") print("Some tools may be duplicated.") print(f"Try to delete vector DB at {ckpt_dir}/all_tools_vectordb.") + st.markdown( + "Invalid index. Some tools may be duplicated Try to delete VDB.", + unsafe_allow_html=True, + ) return retrieved_tools @@ -232,6 +237,7 @@ def _run(self, task, orig_prompt, curr_tools, execute=True, args=None): current_tools=curr_tools, ) print("running iterator to draft a new tool") + st.markdown("Running iterator to draft a new tool", unsafe_allow_html=True) tool_name = newcode_iterator.run(task, orig_prompt) if not tool_name: return "The 'CreateNewTool' tool failed to build a new tool." 
@@ -242,6 +248,7 @@ def _run(self, task, orig_prompt, curr_tools, execute=True, args=None):
         if execute:
             try:
                 print("\nexecuting tool")
+                st.markdown("Executing tool", unsafe_allow_html=True)
                 agent_initializer = SubAgentInitializer(self.subagent_settings)
                 skill = agent_initializer.create_skill_manager(resume=True)
                 if skill is None:
diff --git a/setup.py b/setup.py
index 531ed583..a25ff0db 100644
--- a/setup.py
+++ b/setup.py
@@ -31,6 +31,7 @@
         "requests",
         "rmrkl",
         "tiktoken",
+        "streamlit",
     ],
     test_suite="tests",
     long_description=long_description,
diff --git a/st_app.py b/st_app.py
new file mode 100644
index 00000000..ab21a360
--- /dev/null
+++ b/st_app.py
@@ -0,0 +1,94 @@
+"""Streamlit front end for MDAgent.
+
+Launch from the project root with ``streamlit run st_app.py``.
+"""
+import os
+from typing import List
+
+import streamlit as st
+from dotenv import load_dotenv
+from langchain.callbacks import StreamlitCallbackHandler
+from langchain.callbacks.base import BaseCallbackHandler
+
+from mdagent import MDAgent
+
+load_dotenv()
+
+
+st_callback = StreamlitCallbackHandler(st.container())
+
+
+# Streamlit app
+st.title("MDAgent")
+
+# option = st.selectbox("Choose an option:", ["Explore & Learn", "Use Learned Skills"])
+# if option == "Explore & Learn":
+#     explore = True
+# else:
+#     explore = False
+
+resume_op = st.selectbox("Resume:", ["False", "True"])
+resume = resume_op == "True"
+
+# for now I'm just going to allow pdb and cif files - we can add more later
+uploaded_files = st.file_uploader(
+    "Upload a .pdb or .cif file", type=["pdb", "cif"], accept_multiple_files=True
+)
+# Absolute paths of the uploads written to disk; handed to the agent below.
+uploaded_file: List[str] = []
+if uploaded_files:
+    for file in uploaded_files:
+        # Keep only the final path component so a crafted upload name cannot
+        # place the file outside the working directory.
+        dest = os.path.join(os.getcwd(), os.path.basename(file.name))
+        with open(dest, "wb") as f:
+            f.write(file.getbuffer())
+        uploaded_file.append(dest)
+
+    st.write("Files successfully uploaded!")
+
+mdagent = MDAgent(resume=resume, uploaded_files=uploaded_file)
+
+
+def generate_response(prompt):
+    """Run the agent on ``prompt`` and return its result."""
+    result = mdagent.run(prompt)
+    return result
+
+
+# make new container to store scratch
+scratch = st.empty()
+scratch.write(
+    """Hi! I am MDAgent, your MD automation assistant.
+    How can I help you today?"""
+)
+
+
+# This allows streaming of llm tokens
+class TokenStreamlitCallbackHandler(BaseCallbackHandler):
+    """Render streamed LLM tokens into a single Streamlit placeholder."""
+
+    def __init__(self, container):
+        self.container = container
+        # Text streamed so far for the current LLM call.
+        self.text = ""
+
+    def on_llm_start(self, serialized, prompts, **kwargs):
+        # Each LLM call starts a fresh message in the placeholder.
+        self.text = ""
+
+    def on_llm_new_token(self, token, **kwargs):
+        # The placeholder holds one element, so writing only the new token
+        # would overwrite everything before it; re-render the running text.
+        self.text += token
+        self.container.write(self.text)
+
+
+token_st_callback = TokenStreamlitCallbackHandler(scratch)
+
+if prompt := st.chat_input():
+    st.chat_message("user").write(prompt)
+    with st.chat_message("assistant"):
+        st_callback = StreamlitCallbackHandler(st.container())
+        response = mdagent.run(prompt, callbacks=[st_callback, token_st_callback])
+        st.write(response)