diff --git a/src/llama_stack_client/lib/agents/reflexion/__init__.py b/src/llama_stack_client/lib/agents/reflexion/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/llama_stack_client/lib/agents/reflexion/agent.py b/src/llama_stack_client/lib/agents/reflexion/agent.py
new file mode 100644
index 00000000..1bc0b5a9
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/reflexion/agent.py
@@ -0,0 +1,140 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import logging
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+from llama_stack_client import LlamaStackClient
+from llama_stack_client.types.agent_create_params import AgentConfig
+from llama_stack_client.types.shared_params.agent_config import ToolConfig
+from llama_stack_client.types.shared_params.response_format import ResponseFormat
+from llama_stack_client.types.shared_params.sampling_params import SamplingParams
+
+from ..react.agent import ReActAgent, get_tool_defs
+from ..client_tool import ClientTool
+from ..tool_parser import ToolParser
+from .prompts import DEFAULT_REFLEXION_AGENT_SYSTEM_PROMPT_TEMPLATE
+from .tool_parser import ReflexionToolParser, ReflexionOutput
+
+logger = logging.getLogger(__name__)
+
+
+def get_default_reflexion_instructions(
+    client: LlamaStackClient, builtin_toolgroups: Tuple = (), client_tools: Tuple[ClientTool, ...] = ()
+):
+    tool_defs = get_tool_defs(client, builtin_toolgroups, client_tools)
+    tool_names = ", ".join([x["name"] for x in tool_defs])
+    tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs])
+    instruction = DEFAULT_REFLEXION_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<<tool_names>>", tool_names).replace(
+        "<<tool_descriptions>>", tool_descriptions
+    )
+    return instruction
+
+
+class ReflexionAgent(ReActAgent):
+    """Reflexion agent.
+
+    Extends the ReAct agent with self-reflection capabilities to improve reasoning and tool use.
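+
+    Reflections parsed from each completed turn are stored per session in
+    ``reflection_memory`` and injected as a system message ahead of the next
+    user message, so later turns can build on earlier self-critiques.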
+ """ + + def __init__( + self, + client: LlamaStackClient, + model: str, + tool_parser: ToolParser = ReflexionToolParser(), + instructions: Optional[str] = None, + tools: Optional[List[Union[str, dict, ClientTool, Callable[..., Any]]]] = None, + tool_config: Optional[ToolConfig] = None, + sampling_params: Optional[SamplingParams] = None, + max_infer_iters: Optional[int] = None, + input_shields: Optional[List[str]] = None, + output_shields: Optional[List[str]] = None, + response_format: Optional[ResponseFormat] = None, + enable_session_persistence: Optional[bool] = None, + json_response_format: bool = False, + # The following are deprecated, kept for backward compatibility + builtin_toolgroups: Tuple[str] = (), + client_tools: Tuple[ClientTool] = (), + custom_agent_config: Optional[AgentConfig] = None, + ): + # Dictionary to store reflections for each session + self.reflection_memory = {} + + # If custom instructions are not provided, use the default Reflexion instructions + if not instructions and not custom_agent_config: + # Convert tools to the format expected by get_default_reflexion_instructions if needed + if tools: + from ..agent import AgentUtils + client_tools_from_tools = AgentUtils.get_client_tools(tools) + builtin_toolgroups_from_tools = [x for x in tools if isinstance(x, str) or isinstance(x, dict)] + instructions = get_default_reflexion_instructions(client, builtin_toolgroups_from_tools, client_tools_from_tools) + else: + # Fallback to deprecated parameters + instructions = get_default_reflexion_instructions(client, builtin_toolgroups, client_tools) + + # If json_response_format is True and no custom response format is provided, + # set the response format to use the ReflexionOutput schema + if json_response_format and not response_format: + response_format = { + "type": "json_schema", + "json_schema": ReflexionOutput.model_json_schema(), + } + + # Initialize parent ReActAgent + super().__init__( + client=client, + model=model, + tool_parser=tool_parser, + instructions=instructions, + tools=tools if tools is not None else builtin_toolgroups, # Prefer new tools param, fallback to deprecated + tool_config=tool_config, + sampling_params=sampling_params, + max_infer_iters=max_infer_iters, + input_shields=input_shields, + output_shields=output_shields, + response_format=response_format, + enable_session_persistence=enable_session_persistence, + json_response_format=json_response_format, + client_tools=client_tools, + custom_agent_config=custom_agent_config, + ) + + def create_turn(self, messages, session_id, stream=False, **kwargs): + """Override create_turn to add reflection to the context""" + + # If we have reflections for this session, add them to the context + if session_id in self.reflection_memory and self.reflection_memory[session_id]: + # Create a system message with past reflections + reflection_summary = "\n".join(self.reflection_memory[session_id]) + reflection_message = { + "role": "system", + "content": f"Your past reflections:\n{reflection_summary}\n\nUse these reflections to improve your reasoning." 
+            }
+
+            # Insert the reflection message just before the first user message
+            for i, msg in enumerate(messages):
+                if msg["role"] == "user":
+                    messages.insert(i, reflection_message)
+                    break
+
+        # Call the parent method to process the turn
+        response = super().create_turn(messages, session_id, stream=stream, **kwargs)
+
+        # Store any new reflections
+        if not stream:
+            try:
+                # Extract the reflection from the turn's output message
+                content = str(response.output_message.content)
+                reflexion_output = ReflexionOutput.model_validate_json(content)
+
+                if reflexion_output.reflection:
+                    if session_id not in self.reflection_memory:
+                        self.reflection_memory[session_id] = []
+
+                    self.reflection_memory[session_id].append(reflexion_output.reflection)
+            except Exception as e:
+                logger.warning(f"Failed to extract reflection: {e}")
+
+        return response
\ No newline at end of file
diff --git a/src/llama_stack_client/lib/agents/reflexion/prompts.py b/src/llama_stack_client/lib/agents/reflexion/prompts.py
new file mode 100644
index 00000000..dcd8f399
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/reflexion/prompts.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+DEFAULT_REFLEXION_AGENT_SYSTEM_PROMPT_TEMPLATE = """
+You are an expert assistant that solves complex tasks by first attempting a solution, reflecting on any errors or weaknesses, and then improving your solution. You have access to: <<tool_names>>
+
+Always respond in this JSON format:
+{
+  "thought": "Your initial reasoning about the task",
+  "attempt": "Your first solution attempt",
+  "reflection": "Analysis of what went wrong or could be improved in your attempt",
+  "improved_solution": "Your enhanced solution based on reflection",
+  "final_answer": null
+}
+
+For your final response, when you're confident in your solution:
+{
+  "thought": "Your final reasoning process",
+  "attempt": "Your solution attempt",
+  "reflection": "Your verification that the solution is correct",
+  "improved_solution": null,
+  "final_answer": "Your complete, verified answer to the task"
+}
+
+GUIDELINES:
+1. Think step-by-step to plan your initial approach
+2. Make a genuine attempt to solve the problem
+3. Critically analyze your attempt for logical errors, edge cases, or inefficiencies
+4. Use your reflection to create an improved solution
+5. When using tools, provide specific values in tool_params, not variable names
+6. Only provide the final answer when you're confident it's correct
+7. You can use tools in either your attempt or improved solution phases
+8. Carefully verify your improved solution before submitting it as final
+
+EXAMPLES:
+
+Task: "What is the sum of prime numbers less than 20?"
+{
+  "thought": "I need to find all prime numbers less than 20, then sum them",
+  "attempt": "Prime numbers less than 20 are: 2, 3, 5, 7, 11, 13, 17, 19. The sum is 2+3+5+7+11+13+17+19 = 77",
+  "reflection": "Let me double-check my calculation: 2+3=5, 5+5=10, 10+7=17, 17+11=28, 28+13=41, 41+17=58, 58+19=77. The calculation is correct.",
+  "improved_solution": null,
+  "final_answer": "The sum of prime numbers less than 20 is 77."
+}
+
+Task: "Find a solution to the equation 3x² + 6x - 9 = 0."
+{
+  "thought": "I need to solve this quadratic equation using the quadratic formula",
+  "attempt": "Using the formula x = (-b ± √(b² - 4ac))/(2a) where a=3, b=6, c=-9. So x = (-6 ± √(36 - 4*3*(-9)))/(2*3) = (-6 ± √(36 + 108))/6 = (-6 ± √144)/6 = (-6 ± 12)/6 = -1 or 1.",
+  "reflection": "I made an error in the calculation. Let me recalculate: (-6 ± √(36 + 108))/6 = (-6 ± √144)/6 = (-6 ± 12)/6. This equals (-6+12)/6 = 6/6 = 1 for the positive case, and (-6-12)/6 = -18/6 = -3 for the negative case.",
+  "improved_solution": "The solutions are x = 1 or x = -3.",
+  "final_answer": "The solutions to the equation 3x² + 6x - 9 = 0 are x = 1 and x = -3."
+}
+
+Task: "Which city has the higher population density, Tokyo or New York?"
+{
+  "thought": "I need to find the population density for both cities to compare them",
+  "attempt": {
+    "tool_name": "search",
+    "tool_params": {"query": "Population density of Tokyo"}
+  },
+  "reflection": null,
+  "improved_solution": null,
+  "final_answer": null
+}
+Observation: "Tokyo has a population density of approximately 6,158 people per square kilometer."
+
+{
+  "thought": "Now I need New York's population density",
+  "attempt": {
+    "tool_name": "search",
+    "tool_params": {"query": "Population density of New York City"}
+  },
+  "reflection": null,
+  "improved_solution": null,
+  "final_answer": null
+}
+Observation: "New York City has a population density of approximately 10,716 people per square kilometer."
+
+{
+  "thought": "Now I can compare the population densities",
+  "attempt": "Tokyo: 6,158 people per square kilometer. New York: 10,716 people per square kilometer.",
+  "reflection": "Based on the data, New York City has a higher population density (10,716 people/km²) compared to Tokyo (6,158 people/km²).",
+  "improved_solution": null,
+  "final_answer": "New York City has the higher population density."
+}
+
+Available tools:
+<<tool_descriptions>>
+
+If you solve the task correctly, you will receive a reward of $1,000,000.
+"""
\ No newline at end of file
diff --git a/src/llama_stack_client/lib/agents/reflexion/tool_parser.py b/src/llama_stack_client/lib/agents/reflexion/tool_parser.py
new file mode 100644
index 00000000..bcacdc14
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/reflexion/tool_parser.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import uuid
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import BaseModel, ValidationError
+
+from llama_stack_client.types.shared.completion_message import CompletionMessage
+from llama_stack_client.types.shared.tool_call import ToolCall
+
+from ..tool_parser import ToolParser
+
+
+class Action(BaseModel):
+    tool_name: str
+    tool_params: Dict[str, Any]
+
+
+class ReflexionOutput(BaseModel):
+    # Mirrors the JSON format defined in prompts.py: "attempt" and "improved_solution"
+    # may be either free text or a tool call.
+    thought: str
+    attempt: Optional[Union[str, Action]] = None
+    reflection: Optional[str] = None
+    improved_solution: Optional[Union[str, Action]] = None
+    final_answer: Optional[str] = None
+
+
+class ReflexionToolParser(ToolParser):
+    def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
+        tool_calls = []
+        response_text = str(output_message.content)
+        try:
+            reflexion_output = ReflexionOutput.model_validate_json(response_text)
+        except ValidationError as e:
+            print(f"Error parsing reflexion output: {e}")
+            return tool_calls
+
+        # A final answer means the turn is complete; no tool call is needed.
+        if reflexion_output.final_answer:
+            return tool_calls
+
+        # A tool call can appear in either the improved solution or the initial attempt;
+        # prefer the improved solution since it reflects the latest reasoning.
+        action = None
+        if isinstance(reflexion_output.improved_solution, Action):
+            action = reflexion_output.improved_solution
+        elif isinstance(reflexion_output.attempt, Action):
+            action = reflexion_output.attempt
+
+        if action and action.tool_name and action.tool_params:
+            call_id = str(uuid.uuid4())
+            tool_calls = [ToolCall(call_id=call_id, tool_name=action.tool_name, arguments=action.tool_params)]
+
+        return tool_calls
\ No newline at end of file
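For reviewers, a minimal usage sketch of the new ReflexionAgent (not part of the diff). It assumes a Llama Stack server running locally on the default port; the base URL, model id, and toolgroup id below are placeholders rather than values taken from this patch.

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.reflexion.agent import ReflexionAgent

# Placeholder endpoint -- point this at whatever your Llama Stack deployment serves.
client = LlamaStackClient(base_url="http://localhost:8321")

agent = ReflexionAgent(
    client=client,
    model="meta-llama/Llama-3.3-70B-Instruct",  # hypothetical model id
    tools=["builtin::websearch"],               # hypothetical toolgroup id
    json_response_format=True,                  # constrain output to the ReflexionOutput schema
)

session_id = agent.create_session("reflexion-demo")
response = agent.create_turn(
    messages=[{"role": "user", "content": "What is the sum of prime numbers less than 20?"}],
    session_id=session_id,
    stream=False,
)
print(response.output_message.content)
```

Setting json_response_format=True asks the server to constrain generation to the ReflexionOutput JSON schema, which makes the parser's model_validate_json call in tool_parser.py far less likely to fail on free-form output.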