Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bring back tool retrieval #63

Merged
merged 8 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions mdagent/mainagent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mdagent.subagents import SubAgentSettings
from mdagent.utils import PathRegistry, _make_llm

from ..tools import make_all_tools
from ..tools import get_tools, make_all_tools
from .prompt import openaifxn_prompt, structured_prompt

load_dotenv()
Expand Down Expand Up @@ -35,7 +35,7 @@ class MDAgent:
def __init__(
self,
tools=None,
agent_type="OpenAIFunctionsAgent", # this can also be strucured_chat
agent_type="OpenAIFunctionsAgent", # this can also be structured_chat
model="gpt-4-1106-preview", # current name for gpt-4 turbo
tools_model="gpt-4-1106-preview",
temp=0.1,
Expand All @@ -45,14 +45,16 @@ def __init__(
subagents_model="gpt-4-1106-preview",
ckpt_dir="ckpt",
resume=False,
top_k_tools=10,
top_k_tools=20, # set "all" if you want to use all tools (& skills if resume)
use_human_tool=False,
):
if path_registry is None:
path_registry = PathRegistry.get_instance()
if tools is None:
tools_llm = _make_llm(tools_model, temp, verbose)
tools = make_all_tools(tools_llm, human=use_human_tool)
self.agent_type = agent_type
self.user_tools = tools
self.tools_llm = _make_llm(tools_model, temp, verbose)
self.top_k_tools = top_k_tools
self.use_human_tool = use_human_tool

self.llm = ChatOpenAI(
temperature=temp,
Expand All @@ -61,11 +63,7 @@ def __init__(
streaming=True,
callbacks=[StreamingStdOutCallbackHandler()],
)
self.agent = AgentExecutor.from_agent_and_tools(
tools=tools,
agent=AgentType.get_agent(agent_type).from_llm_and_tools(self.llm, tools),
handle_parsing_errors=True,
)

# assign prompt
if agent_type == "Structured":
self.prompt = structured_prompt
Expand All @@ -80,9 +78,37 @@ def __init__(
verbose=verbose,
ckpt_dir=ckpt_dir,
resume=resume,
retrieval_top_k=top_k_tools,
)

def _initialize_tools_and_agent(self, user_input=None):
"""Retrieve tools and initialize the agent."""
if self.user_tools is not None:
self.tools = self.user_tools
else:
if self.top_k_tools != "all" and user_input is not None:
# retrieve only tools relevant to user input
self.tools = get_tools(
query=user_input,
llm=self.tools_llm,
subagent_settings=self.subagents_settings,
human=self.use_human_tool,
)
else:
# retrieve all tools, including new tools if any
self.tools = make_all_tools(
self.tools_llm,
subagent_settings=self.subagents_settings,
human=self.use_human_tool,
)
return AgentExecutor.from_agent_and_tools(
tools=self.tools,
agent=AgentType.get_agent(self.agent_type).from_llm_and_tools(
self.llm,
self.tools,
),
handle_parsing_errors=True,
)

def run(self, user_input, callbacks=None):
# todo: check this for both agent types
self.agent = self._initialize_tools_and_agent(user_input)
return self.agent.run(self.prompt.format(input=user_input), callbacks=callbacks)
18 changes: 11 additions & 7 deletions mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,22 +123,25 @@ def get_tools(
query,
llm: BaseLanguageModel,
subagent_settings: Optional[SubAgentSettings] = None,
ckpt_dir="ckpt",
retrieval_top_k=10,
top_k_tools=15,
subagents_required=True,
human=False,
):
if subagent_settings:
ckpt_dir = subagent_settings.ckpt_dir
else:
ckpt_dir = "ckpt"

retrieved_tools = []
if subagents_required:
# add subagents-related tools by default
PathRegistry.get_instance()
retrieved_tools = [
CreateNewTool(subagent_settings=subagent_settings),
RetryExecuteSkill(subagent_settings=subagent_settings),
SkillRetrieval(subagent_settings=subagent_settings),
WorkflowPlan(subagent_settings=subagent_settings),
]
retrieval_top_k -= len(retrieved_tools)
top_k_tools -= len(retrieved_tools)
all_tools = make_all_tools(
llm, subagent_settings, skip_subagents=True, human=human
)
Expand All @@ -163,7 +166,7 @@ def get_tools(
vectordb.persist()

# retrieve 'k' tools
k = min(retrieval_top_k, vectordb._collection.count())
k = min(top_k_tools, vectordb._collection.count())
if k == 0:
return None
docs = vectordb.similarity_search(query, k=k)
Expand All @@ -173,7 +176,8 @@ def get_tools(
retrieved_tools.append(all_tools[index])
else:
print(f"Invalid index {index}.")
print(f"Try deleting vectordb at {ckpt_dir}/all_tools_vectordb.")
print("Some tools may be duplicated.")
print(f"Try to delete vector DB at {ckpt_dir}/all_tools_vectordb.")
return retrieved_tools


Expand Down Expand Up @@ -217,7 +221,7 @@ def get_all_tools_string(self):
all_tools_string += f"{tool.name}: {tool.description}\n"
return all_tools_string

def _run(self, task, orig_prompt, curr_tools, execute, args=None):
def _run(self, task, orig_prompt, curr_tools, execute=True, args=None):
# run iterator
try:
all_tools_string = self.get_all_tools_string()
Expand Down
Binary file removed notebooks/.DS_Store
Binary file not shown.
Loading