Skip to content

Commit d014ff7

Browse files
committed
1. fix bug at getting id and filename from modified simulations. 2. Change stucture in make_tools so modifyscript is with other llm tools. 3. Fix bug at modifyscriputils during its init (added the llm when called)
1 parent fbbfc2d commit d014ff7

File tree

3 files changed

+30
-20
lines changed

3 files changed

+30
-20
lines changed

mdagent/tools/base_tools/simulation_tools/create_simulation.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
from langchain.tools import BaseTool
99
from pydantic import BaseModel, Field
1010

11-
from mdagent.utils import PathRegistry
11+
from mdagent.utils import FileType, PathRegistry
1212

1313

1414
class ModifyScriptUtils:
15+
llm: Optional[BaseLanguageModel]
16+
1517
def __init__(self, llm):
16-
llm = llm
18+
self.llm = llm
1719

1820
Examples = [
1921
"""
@@ -153,12 +155,16 @@ class ModifyBaseSimulationScriptTool(BaseTool):
153155
llm: Optional[BaseLanguageModel]
154156
path_registry: Optional[PathRegistry]
155157

156-
def __init__(self, path_registry: Optional[PathRegistry], llm: BaseLanguageModel):
158+
def __init__(self, path_registry: Optional[PathRegistry], llm):
157159
super().__init__()
158160
self.path_registry = path_registry
159161
self.llm = llm
162+
print(f"fModifyScriptTool initialized, llm is {llm}")
160163

161164
def _run(self, *args, **input):
165+
if self.llm is None: # this should not happen
166+
print("No language model provided at ModifyScriptTool")
167+
return "llm not initialized"
162168
if len(args) > 0:
163169
return (
164170
"This tool expects you to provide the input as a "
@@ -178,7 +184,7 @@ def _run(self, *args, **input):
178184
with open(base_script_path, "r") as file:
179185
base_script = file.read()
180186
base_script = "".join(base_script)
181-
utils = ModifyScriptUtils()
187+
utils = ModifyScriptUtils(self.llm)
182188

183189
description = input.get("query")
184190
answer = utils._prompt_summary(
@@ -194,9 +200,9 @@ def _run(self, *args, **input):
194200
script_content = textwrap.dedent(script_content).strip()
195201
# Write to file
196202
filename = self.path_registry.write_file_name(
197-
type="SIMULATION", Sim_id=base_script_id, modified=True
203+
type=FileType.SIMULATION, Sim_id=base_script_id, modified=True
198204
)
199-
file_id = self.path_registry.get_fileid(filename, type="SIMULATION")
205+
file_id = self.path_registry.get_fileid(filename, type=FileType.SIMULATION)
200206
directory = "files/simulations"
201207
if not os.path.exists(directory):
202208
os.makedirs(directory)

mdagent/tools/maketools.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
PPIDistance,
2525
RMSDCalculator,
2626
Scholar2ResultLLM,
27-
SerpGitTool,
2827
SetUpandRunFunction,
2928
SimulationOutputFigures,
3029
VisualizeProtein,
@@ -63,15 +62,17 @@ def make_all_tools(
6362
):
6463
load_dotenv()
6564
all_tools = []
66-
65+
path_instance = PathRegistry.get_instance() # get instance first
6766
if llm:
6867
all_tools += agents.load_tools(["llm-math"], llm)
6968
all_tools += [PythonREPLTool()] # or PythonREPLTool(llm=llm)?
69+
all_tools += [
70+
ModifyBaseSimulationScriptTool(path_registry=path_instance, llm=llm)
71+
]
7072
if human:
7173
all_tools += [agents.load_tools(["human"], llm)[0]]
7274

7375
# get path registry
74-
path_instance = PathRegistry.get_instance() # get instance first
7576

7677
# add base tools
7778
base_tools = [
@@ -110,10 +111,10 @@ def make_all_tools(
110111
all_tools += base_tools + subagents_tools + learned_tools
111112

112113
# add other tools depending on api keys
113-
serp_key = os.getenv("SERP_API_KEY")
114+
os.getenv("SERP_API_KEY")
114115
pqa_key = os.getenv("PQA_API_KEY")
115-
if serp_key:
116-
all_tools.append(SerpGitTool(serp_key)) # github issues search
116+
# if serp_key:
117+
# all_tools.append(SerpGitTool(serp_key)) # github issues search
117118
if pqa_key:
118119
all_tools.append(Scholar2ResultLLM(pqa_key)) # literature search
119120
return all_tools

mdagent/utils/path_registry.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,17 +159,20 @@ def write_file_name(self, type: FileType, **kwargs):
159159
conditions = kwargs.get("conditions", None)
160160
Sim_id = kwargs.get("Sim_id", None)
161161
modified = kwargs.get("modified", False)
162-
162+
file_name = ""
163163
if type == FileType.PROTEIN:
164-
file_name = f"{protein_name}_{description}_{time_stamp}.{file_format}"
164+
file_name += f"{protein_name}_{description}_{time_stamp}.{file_format}"
165165
if type == FileType.SIMULATION:
166+
print("im here inside")
166167
if conditions:
167-
file_name = f"{type_of_sim}_{protein_file_id}_{conditions}_{time_stamp}"
168+
file_name += (
169+
f"{type_of_sim}_{protein_file_id}_{conditions}_{time_stamp}.py"
170+
)
168171
elif modified:
169-
file_name = f"{Sim_id}_MOD_{time_stamp}"
172+
print("I got here!!!!")
173+
file_name += f"{Sim_id}_MOD_{time_stamp}.py"
170174
else:
171-
file_name = f"{type_of_sim}_{protein_file_id}_{time_stamp}"
172-
if type == FileType.RECORD:
173-
file_name = f"{protein_file_id}_{Sim_id}_{time_stamp}"
174-
175+
file_name += f"{type_of_sim}_{protein_file_id}_{time_stamp}.py"
176+
if file_name == "":
177+
file_name += "ErrorDuringNaming_error.py"
175178
return file_name

0 commit comments

Comments
 (0)