Skip to content

Commit

Permalink
Multiple callback types.
Browse files Browse the repository at this point in the history
  • Loading branch information
tcdent committed Feb 15, 2025
1 parent cd3a90c commit 21ff841
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 6 deletions.
30 changes: 28 additions & 2 deletions agentstack/_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,32 @@ def _get_builtin_tool_path(name: str) -> Path:
return TOOLS_DIR / name / TOOLS_CONFIG_FILENAME


class Callback(pydantic.BaseModel):
"""A callback to be called after a tool is run."""

script: Optional[str] = None # call a script (this is the current implementation)
method: Optional[str] = None # call a python method
url: Optional[str] = None # call a URL

@pydantic.validator('script', 'method', 'url', mode='after')
def check_callback(cls, v, values, field):
if not any([values.get('script'), values.get('method'), values.get('url')]):
raise ValueError('At least one of script, method, or url must be set')
return v

@property
def SCRIPT(self) -> bool:
return self.script is not None

@property
def METHOD(self) -> bool:
return self.method is not None

@property
def URL(self) -> bool:
return self.url is not None


class ToolConfig(pydantic.BaseModel):
"""
This represents the configuration data for a tool.
Expand All @@ -38,8 +64,8 @@ class ToolConfig(pydantic.BaseModel):
cta: Optional[str] = None
env: Optional[dict] = None
dependencies: Optional[list[str]] = None
post_install: Optional[str] = None
post_remove: Optional[str] = None
post_install: Optional[Callback] = None
post_remove: Optional[Callback] = None

@classmethod
def from_tool_name(cls, name: str) -> 'ToolConfig':
Expand Down
34 changes: 30 additions & 4 deletions agentstack/generation/tool_generation.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
from typing import Optional
import json
import os, sys
from pathlib import Path
from typing import Optional
from importlib import import_module
import requests
from agentstack import conf, log
from agentstack.conf import ConfigFile
from agentstack.exceptions import ValidationError
from agentstack import frameworks
from agentstack import packaging
from agentstack.utils import term_color
from agentstack._tools import ToolConfig
from agentstack._tools import Callback, ToolConfig
from agentstack.generation import asttools
from agentstack.generation.files import EnvFile


def _handle_callback(callback: Callback) -> None:
if callback.SCRIPT:
log.debug(f'Calling script {callback.script}')
os.system(callback.script)

elif callback.METHOD:
log.debug(f'Calling method {callback.method}')
module_name, method_name = callback.method.rsplit('.', 1)
module = import_module(module_name)
method = getattr(module, method_name)
method()

elif callback.URL:
log.debug(f'Calling URL {callback.url}')
response = requests.get(callback.url) # TODO methods
if response.status_code == 200:
log.info(response.text)
else:
log.error(f'Response Code: {response.status_code}')

else:
raise ValidationError('Invalid callback type')


def add_tool(name: str, agents: Optional[list[str]] = []):
agentstack_config = ConfigFile()
tool = ToolConfig.from_tool_name(name)
Expand All @@ -33,7 +59,7 @@ def add_tool(name: str, agents: Optional[list[str]] = []):
env.append_if_new(var, value)

if tool.post_install:
os.system(tool.post_install)
_handle_callback(tool.post_install)

with agentstack_config as config:
config.tools.append(tool.name)
Expand Down Expand Up @@ -120,7 +146,7 @@ def remove_tool(name: str, agents: Optional[list[str]] = []):
frameworks.remove_tool(tool, agent_name)

if tool.post_remove:
os.system(tool.post_remove)
_handle_callback(tool.post_remove)
# We don't remove the .env variables to preserve user data.

with agentstack_config as config:
Expand Down

0 comments on commit 21ff841

Please sign in to comment.