diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 2690d066af..e116544d68 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -9,9 +9,17 @@ from dspy.signatures.signature import ensure_signature from dspy.utils.callback import with_callbacks - class Tool: - def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None): + + def __init__( + self, + func: Callable, + name: str = None, + desc: str = None, + args: dict[str, Any] = None, + defaults: dict[str, Any] = None, + private_defaults: dict[str, Any] = None, + ): annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__ self.func = func self.name = name or getattr(func, "__name__", type(func).__name__) @@ -23,6 +31,8 @@ def __init__(self, func: Callable, name: str = None, desc: str = None, args: dic for k, v in (args or get_type_hints(annotations_func)).items() if k != "return" } + self.defaults = defaults + self.private_defaults = private_defaults @with_callbacks def __call__(self, *args, **kwargs): @@ -63,6 +73,10 @@ def __init__(self, signature, tools: list[Callable], max_iters=5): args = tool.args if hasattr(tool, "args") else str({tool.input_variable: str}) desc = (f", whose description is {tool.desc}." if tool.desc else ".").replace("\n", " ") desc += f" It takes arguments {args} in JSON format." + if tool.defaults: + desc += f" Default arguments are {tool.defaults}." + if tool.private_defaults: + desc += f" Assume the following function arguments will be provided at function execution time: {tool.private_defaults.keys()}. Therefore do not propose these arguments in the `next_tool_args`." instr.append(f"({idx+1}) {tool.name}{desc}") react_signature = ( @@ -91,13 +105,25 @@ def format(trajectory: dict[str, Any], last_iteration: bool): for idx in range(self.max_iters): pred = self.react(**input_args, trajectory=format(trajectory, last_iteration=(idx == self.max_iters - 1))) + # extract private defaults from the tool and supply them to the next tool call + # do not assign the private defaults to the next_tool_args as this will be captured in the trajectory logs, which is not what we want + private_defaults = ( + self.tools[pred.next_tool_name].private_defaults + if pred.next_tool_name in self.tools + and self.tools[pred.next_tool_name].private_defaults + else {} + ) + trajectory[f"thought_{idx}"] = pred.next_thought trajectory[f"tool_name_{idx}"] = pred.next_tool_name trajectory[f"tool_args_{idx}"] = pred.next_tool_args try: - trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args) + trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name]( + **pred.next_tool_args, **private_defaults + ) except Exception as e: + # risk that the error log will capture the private defaults? trajectory[f"observation_{idx}"] = f"Failed to execute: {e}" if pred.next_tool_name == "finish":