diff --git a/mdagent/mainagent/agent.py b/mdagent/mainagent/agent.py index 0e20df8c..6f8d7f3b 100644 --- a/mdagent/mainagent/agent.py +++ b/mdagent/mainagent/agent.py @@ -7,7 +7,7 @@ from mdagent.subagents import SubAgentSettings from mdagent.utils import PathRegistry, _make_llm -from ..tools import make_all_tools +from ..tools import get_tools, make_all_tools from .prompt import openaifxn_prompt, structured_prompt load_dotenv() @@ -35,7 +35,7 @@ class MDAgent: def __init__( self, tools=None, - agent_type="OpenAIFunctionsAgent", # this can also be strucured_chat + agent_type="OpenAIFunctionsAgent", # this can also be structured_chat model="gpt-4-1106-preview", # current name for gpt-4 turbo tools_model="gpt-4-1106-preview", temp=0.1, @@ -45,14 +45,16 @@ def __init__( subagents_model="gpt-4-1106-preview", ckpt_dir="ckpt", resume=False, - top_k_tools=10, + top_k_tools=20, # set "all" if you want to use all tools (& skills if resume) use_human_tool=False, ): if path_registry is None: path_registry = PathRegistry.get_instance() - if tools is None: - tools_llm = _make_llm(tools_model, temp, verbose) - tools = make_all_tools(tools_llm, human=use_human_tool) + self.agent_type = agent_type + self.user_tools = tools + self.tools_llm = _make_llm(tools_model, temp, verbose) + self.top_k_tools = top_k_tools + self.use_human_tool = use_human_tool self.llm = ChatOpenAI( temperature=temp, @@ -61,11 +63,7 @@ def __init__( streaming=True, callbacks=[StreamingStdOutCallbackHandler()], ) - self.agent = AgentExecutor.from_agent_and_tools( - tools=tools, - agent=AgentType.get_agent(agent_type).from_llm_and_tools(self.llm, tools), - handle_parsing_errors=True, - ) + # assign prompt if agent_type == "Structured": self.prompt = structured_prompt @@ -80,9 +78,37 @@ def __init__( verbose=verbose, ckpt_dir=ckpt_dir, resume=resume, - retrieval_top_k=top_k_tools, + ) + + def _initialize_tools_and_agent(self, user_input=None): + """Retrieve tools and initialize the agent.""" + if self.user_tools is not None: + self.tools = self.user_tools + else: + if self.top_k_tools != "all" and user_input is not None: + # retrieve only tools relevant to user input + self.tools = get_tools( + query=user_input, + llm=self.tools_llm, + subagent_settings=self.subagents_settings, + human=self.use_human_tool, + ) + else: + # retrieve all tools, including new tools if any + self.tools = make_all_tools( + self.tools_llm, + subagent_settings=self.subagents_settings, + human=self.use_human_tool, + ) + return AgentExecutor.from_agent_and_tools( + tools=self.tools, + agent=AgentType.get_agent(self.agent_type).from_llm_and_tools( + self.llm, + self.tools, + ), + handle_parsing_errors=True, ) def run(self, user_input, callbacks=None): - # todo: check this for both agent types + self.agent = self._initialize_tools_and_agent(user_input) return self.agent.run(self.prompt.format(input=user_input), callbacks=callbacks) diff --git a/mdagent/tools/maketools.py b/mdagent/tools/maketools.py index 2009dc3a..17daab31 100644 --- a/mdagent/tools/maketools.py +++ b/mdagent/tools/maketools.py @@ -123,22 +123,25 @@ def get_tools( query, llm: BaseLanguageModel, subagent_settings: Optional[SubAgentSettings] = None, - ckpt_dir="ckpt", - retrieval_top_k=10, + top_k_tools=15, subagents_required=True, human=False, ): + if subagent_settings: + ckpt_dir = subagent_settings.ckpt_dir + else: + ckpt_dir = "ckpt" + retrieved_tools = [] if subagents_required: # add subagents-related tools by default - PathRegistry.get_instance() retrieved_tools = [ CreateNewTool(subagent_settings=subagent_settings), RetryExecuteSkill(subagent_settings=subagent_settings), SkillRetrieval(subagent_settings=subagent_settings), WorkflowPlan(subagent_settings=subagent_settings), ] - retrieval_top_k -= len(retrieved_tools) + top_k_tools -= len(retrieved_tools) all_tools = make_all_tools( llm, subagent_settings, skip_subagents=True, human=human ) @@ -163,7 +166,7 @@ def get_tools( vectordb.persist() # retrieve 'k' tools - k = min(retrieval_top_k, vectordb._collection.count()) + k = min(top_k_tools, vectordb._collection.count()) if k == 0: return None docs = vectordb.similarity_search(query, k=k) @@ -173,7 +176,8 @@ def get_tools( retrieved_tools.append(all_tools[index]) else: print(f"Invalid index {index}.") - print(f"Try deleting vectordb at {ckpt_dir}/all_tools_vectordb.") + print("Some tools may be duplicated.") + print(f"Try to delete vector DB at {ckpt_dir}/all_tools_vectordb.") return retrieved_tools @@ -217,7 +221,7 @@ def get_all_tools_string(self): all_tools_string += f"{tool.name}: {tool.description}\n" return all_tools_string - def _run(self, task, orig_prompt, curr_tools, execute, args=None): + def _run(self, task, orig_prompt, curr_tools, execute=True, args=None): # run iterator try: all_tools_string = self.get_all_tools_string() diff --git a/notebooks/.DS_Store b/notebooks/.DS_Store deleted file mode 100644 index 78ff399b..00000000 Binary files a/notebooks/.DS_Store and /dev/null differ