diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py index 0e2d755b28..5b1ef3d2ce 100644 --- a/nemoguardrails/actions/llm/utils.py +++ b/nemoguardrails/actions/llm/utils.py @@ -931,18 +931,23 @@ def get_and_clear_tool_calls_contextvar() -> Optional[list]: def extract_tool_calls_from_events(events: list) -> Optional[list]: - """Extract tool_calls from BotToolCalls events. + """Extract tool_calls from runtime events. - Args: - events: List of events to search through - - Returns: - tool_calls if found in BotToolCalls event, None otherwise + ``StartToolCallBotAction`` carries the tool calls that passed tool-output + rails and should be returned to the caller. ``BotToolCalls`` is used as a + fallback for paths that do not emit the post-rail action event. """ + bot_tool_calls = None + for event in events: - if event.get("type") == "BotToolCalls": - return event.get("tool_calls") - return None + if event.get("type") == "StartToolCallBotAction": + tool_calls = event.get("tool_calls") + if tool_calls is not None: + return tool_calls + elif event.get("type") == "BotToolCalls": + bot_tool_calls = event.get("tool_calls") + + return bot_tool_calls def extract_bot_thinking_from_events(events: list): diff --git a/nemoguardrails/logging/processing_log.py b/nemoguardrails/logging/processing_log.py index 00ddc554a9..056a6ef78f 100644 --- a/nemoguardrails/logging/processing_log.py +++ b/nemoguardrails/logging/processing_log.py @@ -47,6 +47,10 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog: "run dialog rails", "process bot message", "run output rails", + "process bot tool call", + "process user tool messages", + "run tool output rails", + "run tool input rails", ] generation_flows = [ "generate bot message", @@ -129,6 +133,22 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog: ) generation_log.activated_rails.append(activated_rail) + elif event_type == "StartToolOutputRail": + activated_rail = ActivatedRail( + type="tool_output", + name=event_data["flow_id"], + started_at=event["timestamp"], + ) + generation_log.activated_rails.append(activated_rail) + + elif event_type == "StartToolInputRail": + activated_rail = ActivatedRail( + type="tool_input", + name=event_data["flow_id"], + started_at=event["timestamp"], + ) + generation_log.activated_rails.append(activated_rail) + elif event_type == "StartInternalSystemAction": action_name = event_data["action_name"] if action_name in ignored_actions: @@ -154,7 +174,12 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog: executed_action.return_value = event_data["return_value"] executed_action = None - elif event_type in ["InputRailFinished", "OutputRailFinished"]: + elif event_type in [ + "InputRailFinished", + "OutputRailFinished", + "ToolOutputRailFinished", + "ToolInputRailFinished", + ]: if activated_rail is not None: activated_rail.finished_at = event["timestamp"] if activated_rail.finished_at is not None and activated_rail.started_at is not None: @@ -171,6 +196,8 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog: if activated_rail is not None and activated_rail.type in [ "input", "output", + "tool_output", + "tool_input", ]: activated_rail.stop = True if "stop" not in activated_rail.decisions: @@ -188,7 +215,7 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog: if activated_rail.finished_at is not None and activated_rail.started_at is not None: activated_rail.duration = activated_rail.finished_at - activated_rail.started_at - if activated_rail.type in ["input", "output"]: + if activated_rail.type in ["input", "output", "tool_output", "tool_input"]: activated_rail.stop = True if "stop" not in activated_rail.decisions: activated_rail.decisions.append("stop") diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 63c0fa27b7..35eb99a0ff 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -1017,8 +1017,16 @@ async def generate_async( # If the last message is from the assistant, rather than the user, then # we move that to the `$bot_message` variable. This is to enable a more - # convenient interface. (only when dialog rails are disabled) - if messages and messages[-1]["role"] == "assistant" and gen_options and gen_options.rails.dialog is False: + # convenient interface for text output rails. Tool-call assistant messages + # must remain in the history so they can be converted into BotToolCalls + # events and evaluated by tool output rails. + if ( + messages + and messages[-1]["role"] == "assistant" + and not messages[-1].get("tool_calls") + and gen_options + and gen_options.rails.dialog is False + ): # We already have the first message with a context update, so we use that messages[0]["content"]["bot_message"] = messages[-1]["content"] messages = messages[0:-1] diff --git a/tests/test_logging.py b/tests/test_logging.py index 8fa230c683..8f3e64c05e 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -21,11 +21,30 @@ from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var from nemoguardrails.logging.explain import ExplainInfo, LLMCallInfo from nemoguardrails.logging.llm_tracker import track_llm_call -from nemoguardrails.logging.processing_log import processing_log_var +from nemoguardrails.logging.processing_log import compute_generation_log, processing_log_var from nemoguardrails.logging.stats import LLMStats from nemoguardrails.types import LLMResponse, UsageInfo +def test_compute_generation_log_includes_tool_rails(): + generation_log = compute_generation_log( + [ + {"type": "step", "flow_id": "process bot tool call", "timestamp": 0.0, "next_steps": []}, + {"type": "event", "timestamp": 1.0, "data": {"type": "StartToolOutputRail", "flow_id": "check tool call"}}, + {"type": "event", "timestamp": 1.25, "data": {"type": "ToolOutputRailFinished"}}, + {"type": "step", "flow_id": "process user tool messages", "timestamp": 2.0, "next_steps": []}, + {"type": "event", "timestamp": 3.0, "data": {"type": "StartToolInputRail", "flow_id": "check tool result"}}, + {"type": "event", "timestamp": 3.5, "data": {"type": "ToolInputRailFinished"}}, + ] + ) + + activated_rails = generation_log.activated_rails + + assert [rail.type for rail in activated_rails] == ["tool_output", "tool_input"] + assert [rail.name for rail in activated_rails] == ["check tool call", "check tool result"] + assert [rail.duration for rail in activated_rails] == [0.25, 0.5] + + @pytest.mark.asyncio async def test_token_usage_tracking_with_usage(): llm_call_info = LLMCallInfo() diff --git a/tests/test_tool_calls_event_extraction.py b/tests/test_tool_calls_event_extraction.py index 4a5548b5d6..5a30c29ab3 100644 --- a/tests/test_tool_calls_event_extraction.py +++ b/tests/test_tool_calls_event_extraction.py @@ -126,6 +126,53 @@ def mock_get_and_clear(): assert result["tool_calls"][0]["name"] == "test_tool" +@pytest.mark.asyncio +async def test_extract_tool_calls_from_start_tool_call_action(): + from nemoguardrails.actions.llm.utils import extract_tool_calls_from_events + + test_tool_calls = [ + { + "id": "call_approved", + "type": "function", + "function": { + "name": "approved_tool", + "arguments": {"data": "safe"}, + }, + } + ] + + events = [{"type": "StartToolCallBotAction", "tool_calls": test_tool_calls}] + + assert extract_tool_calls_from_events(events) == test_tool_calls + + +@pytest.mark.asyncio +async def test_extract_tool_calls_prefers_post_rail_action_event(): + from nemoguardrails.actions.llm.utils import extract_tool_calls_from_events + + pre_rail = [ + { + "id": "call_modified", + "type": "function", + "function": {"name": "lookup", "arguments": {"query": "unfiltered"}}, + } + ] + post_rail = [ + { + "id": "call_modified", + "type": "function", + "function": {"name": "lookup", "arguments": {"query": "filtered"}}, + } + ] + + events = [ + {"type": "BotToolCalls", "tool_calls": pre_rail}, + {"type": "StartToolCallBotAction", "tool_calls": post_rail}, + ] + + assert extract_tool_calls_from_events(events) == post_rail + + @pytest.mark.asyncio async def test_llmrails_extracts_tool_calls_from_events(): config = RailsConfig.from_content(config={"models": [], "passthrough": True}) diff --git a/tests/test_tool_output_rails.py b/tests/test_tool_output_rails.py index f587be5365..464a9f32e8 100644 --- a/tests/test_tool_output_rails.py +++ b/tests/test_tool_output_rails.py @@ -13,16 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from unittest.mock import patch import pytest from nemoguardrails import LLMRails, RailsConfig from nemoguardrails.actions import action +from nemoguardrails.rails.llm.options import GenerationResponse from nemoguardrails.types import LLMResponse, ToolCall, ToolCallFunction from tests.utils import FakeLLMModel, TestChat +def _tool_arguments(func: dict) -> dict: + """Parse OpenAI-style tool call arguments (JSON string or dict).""" + arguments = func.get("arguments", {}) + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + return {} + return arguments if isinstance(arguments, dict) else {} + + @action(is_system_action=True) async def validate_tool_parameters(tool_calls, context=None, **kwargs): tool_calls = tool_calls or (context.get("tool_calls", []) if context else []) @@ -31,7 +44,7 @@ async def validate_tool_parameters(tool_calls, context=None, **kwargs): for tool_call in tool_calls: func = tool_call.get("function", {}) - args = func.get("arguments", {}) + args = _tool_arguments(func) for param_value in args.values(): if isinstance(param_value, str): if any(pattern.lower() in param_value.lower() for pattern in dangerous_patterns): @@ -190,3 +203,113 @@ async def test_multiple_tool_output_rails(): assert result["tool_calls"] is not None assert result["tool_calls"][0]["name"] == "test_tool" + + +@pytest.mark.asyncio +async def test_assistant_tool_calls_run_tool_output_rails_when_dialog_disabled(): + config = RailsConfig.from_content( + """ + define subflow validate tool parameters + $valid = execute validate_tool_parameters(tool_calls=$tool_calls) + + if not $valid + bot refuse dangerous tool parameters + abort + + define bot refuse dangerous tool parameters + "I cannot execute this tool request because the parameters may be unsafe." + """, + """ + models: [] + passthrough: true + rails: + tool_output: + flows: + - validate tool parameters + """, + ) + rails = LLMRails(config) + rails.runtime.register_action(validate_tool_parameters, name="validate_tool_parameters") + + messages = [ + {"role": "user", "content": "Use the requested tool"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_bad", + "type": "function", + "function": { + "name": "dangerous_tool", + "arguments": '{"param": "eval(\'malicious code\')"}', + }, + } + ], + }, + ] + + result = await rails.generate_async(messages=messages, options={"rails": {"dialog": False}}) + + assert isinstance(result, GenerationResponse) + assert isinstance(result.response, list) + assert "parameters may be unsafe" in result.response[0]["content"] + + +@pytest.mark.asyncio +async def test_approved_assistant_tool_calls_are_returned_when_dialog_disabled(): + config = RailsConfig.from_content( + """ + define subflow validate tool parameters + $valid = execute validate_tool_parameters(tool_calls=$tool_calls) + + if not $valid + bot refuse dangerous tool parameters + abort + + define bot refuse dangerous tool parameters + "I cannot execute this tool request because the parameters may be unsafe." + """, + """ + models: [] + passthrough: true + rails: + tool_output: + flows: + - validate tool parameters + """, + ) + rails = LLMRails(config) + rails.runtime.register_action(validate_tool_parameters, name="validate_tool_parameters") + + messages = [ + {"role": "user", "content": "Use the requested tool"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_safe", + "type": "function", + "function": { + "name": "safe_tool", + "arguments": '{"param": "safe value"}', + }, + } + ], + }, + ] + + result = await rails.generate_async(messages=messages, options={"rails": {"dialog": False}}) + + assert isinstance(result, GenerationResponse) + assert result.tool_calls == [ + { + "id": "call_safe", + "type": "function", + "function": { + "name": "safe_tool", + "arguments": '{"param": "safe value"}', + }, + } + ]