diff --git a/nemoguardrails/eval/check.py b/nemoguardrails/eval/check.py index ba1b145d6a..6071c22306 100644 --- a/nemoguardrails/eval/check.py +++ b/nemoguardrails/eval/check.py @@ -23,7 +23,7 @@ from rich.progress import Progress from rich.text import Text -from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails import RailsConfig from nemoguardrails.actions.llm.utils import llm_call from nemoguardrails.context import llm_call_info_var from nemoguardrails.eval.models import ( @@ -36,6 +36,7 @@ InteractionSet, ) from nemoguardrails.eval.ui.utils import EvalData +from nemoguardrails.llm.models.initializer import init_llm_model from nemoguardrails.llm.taskmanager import LLMTaskManager from nemoguardrails.logging.explain import LLMCallInfo from nemoguardrails.rails.llm.config import Model @@ -44,6 +45,17 @@ executor = ThreadPoolExecutor(max_workers=1) +def _prepare_model_kwargs(model_config: Model): + kwargs = dict(model_config.parameters or {}) + + if model_config.api_key_env_var: + api_key = os.environ.get(model_config.api_key_env_var) + if api_key: + kwargs["api_key"] = api_key + + return kwargs + + class LLMJudgeComplianceChecker: """LLM Judge compliance checker.""" @@ -107,15 +119,20 @@ def __init__( console.print(f"The model `{self.llm_judge_model}` is not defined in the evaluation configuration.") exit(1) - model_cls, kwargs = LLMRails.get_model_cls_and_kwargs(model_config) - self.llm = model_cls(**kwargs) + self.llm = init_llm_model( + model_name=model_config.model, + provider_name=model_config.engine, + mode=model_config.mode, + kwargs=_prepare_model_kwargs(model_config), + ) # We create a minimal RailsConfig object, so we can initialize an LLMTaskManager. # We add a placeholder main model, to avoid some edge case errors when one is not defined. - _config = RailsConfig( - models=self.eval_config.models + [Model(type="main", engine="", model="")], - prompts=self.eval_config.prompts, - ) + task_manager_models = list(self.eval_config.models) + if not any(_model.type == "main" for _model in task_manager_models): + task_manager_models.append(model_config.model_copy(update={"type": "main"})) + + _config = RailsConfig(models=task_manager_models, prompts=self.eval_config.prompts) # Initializer the LLMTaskManager self.llm_task_manager = LLMTaskManager(config=_config) diff --git a/tests/eval/test_eval_check.py b/tests/eval/test_eval_check.py new file mode 100644 index 0000000000..e26c2ccade --- /dev/null +++ b/tests/eval/test_eval_check.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, call, patch + +import pytest + +from nemoguardrails.eval.check import LLMJudgeComplianceChecker +from nemoguardrails.eval.models import ( + EvalConfig, + EvalOutput, + InteractionLog, + InteractionOutput, + InteractionSet, + Policy, +) +from nemoguardrails.eval.ui.utils import EvalData +from nemoguardrails.llm.taskmanager import LLMTaskManager +from nemoguardrails.rails.llm.config import Model + + +def _checker(*, force=False, reset=False, policy_apply_to_all=True, response="Reason: ok\nCompliance: Yes"): + checker = LLMJudgeComplianceChecker.__new__(LLMJudgeComplianceChecker) + checker.eval_config = EvalConfig(policies=[Policy(id="policy", description="policy")], interactions=[], prompts=[]) + checker.policies = checker.eval_config.policies + checker.policy_by_id = {"policy": Policy(id="policy", description="policy", apply_to_all=policy_apply_to_all)} + checker.policy_ids = ["policy"] + checker.verbose = False + checker.force = force + checker.reset = reset + checker.parallel = 1 + checker.llm = MagicMock() + checker.llm_judge_model = "judge" + checker.progress = MagicMock() + checker.llm_task_manager = MagicMock() + checker.llm_task_manager.render_task_prompt.return_value = "rendered prompt" + checker.llm_response = response + return checker + + +def _interaction_set(*, include=None, exclude=None, expected=None): + return InteractionSet( + id="set", + inputs=["hello"], + expected_output=expected or [], + include_policies=include or [], + exclude_policies=exclude or [], + ) + + +def test_compliance_checker_init_builds_model_task_manager_and_eval_data(tmp_path, monkeypatch): + monkeypatch.setenv("JUDGE_API_KEY", "token") + eval_config = EvalConfig( + policies=[Policy(id="policy", description="policy")], + interactions=[], + models=[ + Model( + type="judge", + engine="mock", + model="judge-model", + api_key_env_var="JUDGE_API_KEY", + parameters={"temperature": 0}, + ) + ], + prompts=[], + ) + with ( + patch("nemoguardrails.eval.check.EvalConfig.from_path", return_value=eval_config) as mock_from_path, + patch("nemoguardrails.eval.check.init_llm_model", return_value="llm") as mock_init_llm_model, + ): + checker = LLMJudgeComplianceChecker( + eval_config_path=str(tmp_path), + output_paths=["run-a"], + llm_judge_model="judge-model", + policy_ids=[], + verbose=True, + force=True, + reset=True, + parallel=2, + ) + + mock_from_path.assert_called_once_with(str(tmp_path)) + mock_init_llm_model.assert_called_once_with( + model_name="judge-model", + provider_name="mock", + mode="chat", + kwargs={"temperature": 0, "api_key": "token"}, + ) + # The task manager is built through the real RailsConfig/LLMTaskManager path, + # so the models/prompts contract is validated by the actual constructors. + assert isinstance(checker.llm_task_manager, LLMTaskManager) + task_manager_models = checker.llm_task_manager.config.models + assert any(model.type == "judge" and model.model == "judge-model" for model in task_manager_models) + main_models = [model for model in task_manager_models if model.type == "main"] + assert len(main_models) == 1 + assert main_models[0].model == "judge-model" + assert checker.llm == "llm" + assert checker.policy_ids == ["policy"] + assert checker.eval_data.output_paths == ["run-a"] + assert checker.verbose is True + assert checker.force is True + assert checker.reset is True + assert checker.parallel == 2 + + +def test_compliance_checker_print_helpers_delegate_to_progress(): + checker = LLMJudgeComplianceChecker.__new__(LLMJudgeComplianceChecker) + checker.progress = MagicMock() + checker.parallel = 1 + + checker.print_prompt("[cyan]prompt[/]\n[/]\nplain") + checker.print_completion("completion") + checker.print_progress_detail("detail") + + printed = [call.args[0] for call in checker.progress.print.call_args_list] + assert len(printed) == 4 + assert printed[-1] == "detail" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("response", "expected"), + [ + ("Reason: good\nCompliance: Yes", True), + ("Reason: bad\nCompliance: No", False), + ("Reason: skip\nCompliance: n/a", "n/a"), + ], +) +async def test_check_interaction_compliance_records_valid_judgements(response, expected): + checker = _checker(response=response) + output = InteractionOutput(id="set/0", input="hello", compliance={"policy": None}) + log = InteractionLog(id="set/0", events=[{"type": "Event"}]) + + with patch("nemoguardrails.eval.check.llm_call", AsyncMock(return_value=SimpleNamespace(content=response))): + changed = await checker.check_interaction_compliance(output, log, _interaction_set(), 1) + + assert changed is True + assert output.compliance["policy"] == expected + assert output.compliance_checks[0].method == "judge" + assert output.compliance_checks[0].compliance == {"policy": expected} + assert log.compliance_checks[0].llm_calls[0].task == "llm_judge_check_single_policy_compliance" + + +@pytest.mark.asyncio +async def test_check_interaction_compliance_turns_targeted_na_into_failure(): + checker = _checker(response="Reason: skip\nCompliance: n/a") + output = InteractionOutput(id="set/0", input="hello", compliance={"policy": None}) + log = InteractionLog(id="set/0", events=[]) + + with patch( + "nemoguardrails.eval.check.llm_call", + AsyncMock(return_value=SimpleNamespace(content="Reason: skip\nCompliance: n/a")), + ): + changed = await checker.check_interaction_compliance( + output, + log, + _interaction_set(include=["policy"]), + 1, + ) + + assert changed is True + assert output.compliance["policy"] is False + assert "not acceptable" in output.compliance_checks[0].details + + +@pytest.mark.asyncio +async def test_check_interaction_compliance_skips_not_applicable_policy(): + checker = _checker(policy_apply_to_all=False) + output = InteractionOutput(id="set/0", input="hello", compliance={"policy": None}) + log = InteractionLog(id="set/0", events=[]) + + with patch("nemoguardrails.eval.check.llm_call", AsyncMock()) as mock_llm_call: + changed = await checker.check_interaction_compliance(output, log, _interaction_set(), 1) + + assert changed is True + assert output.compliance["policy"] == "n/a" + mock_llm_call.assert_not_called() + + +@pytest.mark.asyncio +async def test_check_interaction_compliance_skips_existing_rating_without_force(): + checker = _checker(force=False) + output = InteractionOutput(id="set/0", input="hello", compliance={"policy": True}) + log = InteractionLog(id="set/0", events=[]) + + with patch("nemoguardrails.eval.check.llm_call", AsyncMock()) as mock_llm_call: + changed = await checker.check_interaction_compliance(output, log, _interaction_set(), 1) + + assert changed is False + assert output.compliance_checks == [] + mock_llm_call.assert_not_called() + + +@pytest.mark.asyncio +async def test_check_interaction_compliance_force_rechecks_existing_rating(): + checker = _checker(force=True, response="Reason: changed\nCompliance: No") + output = InteractionOutput(id="set/0", input="hello", compliance={"policy": True}) + log = InteractionLog(id="set/0", events=[]) + + with patch( + "nemoguardrails.eval.check.llm_call", + AsyncMock(return_value=SimpleNamespace(content="Reason: changed\nCompliance: No")), + ): + changed = await checker.check_interaction_compliance(output, log, _interaction_set(), 1) + + assert changed is True + assert output.compliance["policy"] is False + + +@pytest.mark.asyncio +async def test_check_interaction_compliance_reset_clears_existing_checks(): + checker = _checker(reset=True) + output = InteractionOutput( + id="set/0", + input="hello", + compliance={"policy": None}, + compliance_checks=[ + { + "id": "old", + "created_at": "2024-01-01T00:00:00", + "interaction_id": "set/0", + "method": "old", + "compliance": {"policy": False}, + "details": "", + } + ], + ) + log = InteractionLog(id="set/0", compliance_checks=[{"id": "old", "llm_calls": []}]) + + with patch( + "nemoguardrails.eval.check.llm_call", + AsyncMock(return_value=SimpleNamespace(content="Reason: ok\nCompliance: Yes")), + ): + changed = await checker.check_interaction_compliance(output, log, _interaction_set(), 1) + + assert changed is True + assert len(output.compliance_checks) == 1 + assert output.compliance_checks[0].method == "judge" + assert len(log.compliance_checks) == 1 + + +@pytest.mark.asyncio +async def test_check_interaction_compliance_ignores_invalid_response(): + checker = _checker(response="not parseable") + output = InteractionOutput(id="set/0", input="hello", compliance={"policy": None}) + log = InteractionLog(id="set/0", events=[]) + + with patch( + "nemoguardrails.eval.check.llm_call", + AsyncMock(return_value=SimpleNamespace(content="not parseable")), + ): + changed = await checker.check_interaction_compliance(output, log, _interaction_set(), 1) + + assert changed is False + assert output.compliance["policy"] is None + assert output.compliance_checks == [] + + +@pytest.mark.asyncio +async def test_compliance_checker_run_updates_changed_outputs(tmp_path): + interaction_set = _interaction_set() + eval_config = EvalConfig( + policies=[Policy(id="policy", description="policy")], + interactions=[interaction_set], + ) + eval_output = EvalOutput( + results=[ + InteractionOutput(id="set/0", input="hello", output="hi", compliance={"policy": None}), + InteractionOutput(id="set/1", input="bye", output="bye", compliance={"policy": None}), + ], + logs=[InteractionLog(id="set/0"), InteractionLog(id="set/1")], + ) + checker = LLMJudgeComplianceChecker.__new__(LLMJudgeComplianceChecker) + checker.output_paths = [str(tmp_path)] + checker.eval_config = eval_config + checker.eval_data = EvalData( + eval_config_path="config", + eval_config=eval_config, + output_paths=[str(tmp_path)], + eval_outputs={}, + ) + checker.parallel = 1 + checker.check_interaction_compliance = AsyncMock(side_effect=[True, False]) + checker.progress_idx = 0 + + with ( + patch("nemoguardrails.eval.check.EvalOutput.from_path", return_value=eval_output), + patch.object(EvalData, "update_results_and_logs", autospec=True) as mock_update_results_and_logs, + ): + await checker.run() + + assert checker.eval_data.eval_outputs[str(tmp_path)] == eval_output + assert checker.check_interaction_compliance.await_count == 2 + assert checker.check_interaction_compliance.await_args_list[0].kwargs["interaction_set"] == interaction_set + assert mock_update_results_and_logs.call_args_list == [ + call(checker.eval_data, str(tmp_path)), + call(checker.eval_data, str(tmp_path)), + ] diff --git a/tests/eval/test_eval_cli.py b/tests/eval/test_eval_cli.py new file mode 100644 index 0000000000..fbf2d94109 --- /dev/null +++ b/tests/eval/test_eval_cli.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from typer.testing import CliRunner + +from nemoguardrails.eval import cli as eval_cli + +runner = CliRunner() + + +def test_run_command_requires_guardrail_config(tmp_path): + config_dir = tmp_path / "config" + config_dir.mkdir() + + result = runner.invoke(eval_cli.app, ["run", "-e", str(config_dir)]) + + assert result.exit_code == 1 + assert "No guardrail configuration provided" in result.stdout + + +def test_run_command_invokes_run_eval(tmp_path): + eval_dir = tmp_path / "eval" + guardrails_dir = tmp_path / "guardrails" + output_dir = tmp_path / "output" + eval_dir.mkdir() + guardrails_dir.mkdir() + + with patch("nemoguardrails.eval.cli.run_eval", AsyncMock()) as mock_run_eval: + result = runner.invoke( + eval_cli.app, + [ + "run", + "-e", + str(eval_dir), + "-g", + str(guardrails_dir), + "-o", + str(output_dir), + "--output-format", + "YAML", + "--parallel", + "3", + ], + ) + + assert result.exit_code == 0 + mock_run_eval.assert_awaited_once_with( + eval_config_path=str(eval_dir.resolve()), + guardrail_config_path=str(guardrails_dir), + output_path=str(output_dir), + output_format="yaml", + parallel=3, + ) + + +def test_check_compliance_command_uses_explicit_output_paths(tmp_path): + eval_dir = tmp_path / "eval" + out_a = tmp_path / "out-a" + out_b = tmp_path / "out-b" + eval_dir.mkdir() + out_a.mkdir() + out_b.mkdir() + checker = MagicMock() + checker.run = AsyncMock() + + with patch("nemoguardrails.eval.cli.LLMJudgeComplianceChecker", return_value=checker) as mock_checker: + result = runner.invoke( + eval_cli.app, + [ + "check-compliance", + "--llm-judge", + "judge-model", + "-e", + str(eval_dir), + "-o", + f"{out_a},{out_b}", + "-p", + "policy-a", + "-p", + "policy-b", + "--force", + "--reset", + "--parallel", + "2", + "--disable-llm-cache", + ], + ) + + assert result.exit_code == 0 + mock_checker.assert_called_once_with( + eval_config_path=str(eval_dir), + output_paths=[str(out_a), str(out_b)], + llm_judge_model="judge-model", + policy_ids=["policy-a", "policy-b"], + verbose=False, + force=True, + reset=True, + parallel=2, + ) + checker.run.assert_awaited_once() + + +def test_check_compliance_command_discovers_output_paths(tmp_path): + eval_dir = tmp_path / "eval" + eval_dir.mkdir() + checker = MagicMock() + checker.run = AsyncMock() + + with ( + patch("nemoguardrails.eval.cli.get_output_paths", return_value=["run-a"]) as mock_get_output_paths, + patch("nemoguardrails.eval.cli.LLMJudgeComplianceChecker", return_value=checker) as mock_checker, + ): + result = runner.invoke( + eval_cli.app, + [ + "check-compliance", + "--llm-judge", + "judge-model", + "-e", + str(eval_dir), + "--disable-llm-cache", + ], + ) + + assert result.exit_code == 0 + mock_get_output_paths.assert_called_once() + mock_checker.assert_called_once() + assert mock_checker.call_args.kwargs["output_paths"] == ["run-a"] + + +def test_ui_command_launches_readme_page(tmp_path): + eval_dir = tmp_path / "eval" + eval_dir.mkdir() + + with ( + patch("nemoguardrails.eval.cli.get_output_paths", return_value=["run-a"]) as mock_get_output_paths, + patch("nemoguardrails.eval.cli._launch_ui") as mock_launch_ui, + ): + result = runner.invoke(eval_cli.app, ["ui", "--eval-config-path", str(eval_dir)]) + + assert result.exit_code == 0 + mock_get_output_paths.assert_called_once() + mock_launch_ui.assert_called_once_with("README.py", port=8501) + + +def test_launch_ui_exits_when_streamlit_is_missing(): + with patch.dict("sys.modules", {"streamlit.web": None}), pytest.raises(SystemExit) as exc_info: + eval_cli._launch_ui("README.py") + + assert exc_info.value.code == 1 diff --git a/tests/eval/test_eval_runtime.py b/tests/eval/test_eval_runtime.py new file mode 100644 index 0000000000..748db68c43 --- /dev/null +++ b/tests/eval/test_eval_runtime.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nemoguardrails.eval.eval import ( + _extract_interaction_log, + _extract_interaction_outputs, + _extract_spans, + _load_eval_output, + run_eval, +) +from nemoguardrails.eval.models import ( + EvalConfig, + EvalOutput, + InteractionLog, + InteractionOutput, + InteractionSet, + Policy, + Span, +) +from nemoguardrails.eval.utils import ( + _collect_span_metrics, + get_output_paths, + load_dict_from_file, + load_dict_from_path, + save_dict_to_file, + save_eval_output, + update_dict_at_path, +) +from nemoguardrails.logging.explain import LLMCallInfo +from nemoguardrails.rails.llm.options import ( + ActivatedRail, + ExecutedAction, + GenerationLog, + GenerationResponse, +) + + +def _eval_config(): + return EvalConfig( + policies=[ + Policy(id="global", description="global policy"), + Policy(id="targeted", description="targeted policy", apply_to_all=False), + Policy(id="included", description="included policy", apply_to_all=False), + Policy(id="excluded", description="excluded policy"), + ], + interactions=[ + InteractionSet( + id="set", + inputs=[ + "hello", + { + "type": "messages", + "messages": [{"role": "user", "content": "hello"}], + }, + ], + expected_output=[{"type": "refusal", "policy": "targeted"}], + include_policies=["included"], + exclude_policies=["excluded"], + ) + ], + ) + + +def _activated_rail(): + return ActivatedRail( + type="input", + name="check input", + started_at=10.0, + finished_at=13.0, + duration=3.0, + executed_actions=[ + ExecutedAction( + action_name="self_check", + started_at=10.5, + finished_at=12.5, + duration=2.0, + llm_calls=[ + LLMCallInfo( + llm_model_name="model/name", + started_at=11.0, + finished_at=12.0, + duration=1.0, + prompt_tokens=7, + completion_tokens=3, + total_tokens=10, + ) + ], + ) + ], + ) + + +def test_extract_interaction_outputs_initializes_policy_statuses(): + outputs = _extract_interaction_outputs(_eval_config()) + + assert [output.id for output in outputs] == ["set/0", "set/1"] + assert outputs[0].input == "hello" + assert outputs[1].input["messages"][0]["content"] == "hello" + assert outputs[0].compliance == { + "global": None, + "targeted": None, + "included": None, + "excluded": "n/a", + } + + +def test_load_eval_output_reuses_matching_results_and_replaces_changed_inputs(tmp_path): + existing_output = EvalOutput( + results=[ + InteractionOutput(id="set/0", input="hello", output="old output", compliance={"global": True}), + InteractionOutput(id="set/1", input="old input", output="stale output", compliance={"global": False}), + ], + logs=[InteractionLog(id="set/0", events=[{"type": "Existing"}])], + ) + save_eval_output(existing_output, str(tmp_path), "json") + + output = _load_eval_output(str(tmp_path), _eval_config()) + + assert output.results[0].output == "old output" + assert output.logs[0].events == [{"type": "Existing"}] + assert output.results[1].output is None + assert output.results[1].input["messages"][0]["content"] == "hello" + assert output.logs[1] == InteractionLog(id="set/1") + + +def test_extract_spans_builds_trace_and_metrics(): + spans = _extract_spans([_activated_rail()]) + + assert [span.name for span in spans] == [ + "interaction", + "rail: check input", + "action: self_check", + "LLM: model/name", + ] + assert spans[0].duration == 3.0 + assert spans[1].parent_id == spans[0].span_id + assert spans[2].metrics["action_self_check_seconds_total"] == 2.0 + assert spans[3].metrics["llm_call_model_name_prompt_tokens_total"] == 7 + assert spans[3].metrics["llm_call_model_name_tokens_total"] == 10 + + +def test_extract_interaction_log_uses_generation_log_data(): + interaction = InteractionOutput(id="set/0", input="hello") + generation_log = GenerationLog( + activated_rails=[_activated_rail()], + internal_events=[{"type": "UserIntent", "intent": "greet"}], + ) + + log = _extract_interaction_log(interaction, generation_log) + + assert log.id == "set/0" + assert log.activated_rails == [_activated_rail()] + assert log.events == [{"type": "UserIntent", "intent": "greet"}] + assert log.trace[0].name == "interaction" + + +def test_collect_span_metrics_sums_totals_and_averages_avg_metrics(): + metrics = _collect_span_metrics( + [ + Span(span_id="1", name="one", start_time=0, end_time=1, duration=1, metrics={"calls_total": 1}), + Span( + span_id="2", + name="two", + start_time=1, + end_time=2, + duration=1, + metrics={"calls_total": 2, "latency_avg": 4.0}, + ), + Span( + span_id="3", + name="three", + start_time=2, + end_time=3, + duration=1, + metrics={"latency_avg": 6.0}, + ), + ] + ) + + assert metrics == {"calls_total": 3, "latency_avg": 5.0} + + +def test_eval_output_compute_compliance_counts_statuses(): + eval_config = EvalConfig( + policies=[ + Policy(id="policy", description="policy"), + Policy(id="missing", description="missing"), + ], + interactions=[], + ) + output = EvalOutput( + results=[ + InteractionOutput(id="1", input="a", compliance={"policy": True, "missing": None}), + InteractionOutput(id="2", input="b", compliance={"policy": False, "missing": None}), + InteractionOutput(id="3", input="c", compliance={"policy": "n/a", "missing": None}), + InteractionOutput(id="4", input="d", compliance={"policy": None, "missing": None}), + ] + ) + + compliance = output.compute_compliance(eval_config) + + assert compliance["policy"]["rate"] == pytest.approx(1 / 3) + assert compliance["policy"]["interactions_comply_count"] == 1 + assert compliance["policy"]["interactions_violation_count"] == 1 + assert compliance["policy"]["interactions_not_applicable_count"] == 1 + assert compliance["policy"]["interactions_not_rated_count"] == 1 + assert compliance["missing"]["interactions_not_rated_count"] == 4 + + +def test_load_save_and_update_dict_helpers(tmp_path, monkeypatch): + config_dir = tmp_path / "config" + config_dir.mkdir() + yaml_file = config_dir / "policies.yaml" + json_file = config_dir / "interactions.json" + yaml_file.write_text("policies:\n - id: p1\n description: one\n", encoding="utf-8") + json_file.write_text(json.dumps({"interactions": [{"id": "i1"}]}), encoding="utf-8") + + assert load_dict_from_file(str(yaml_file)) == {"policies": [{"id": "p1", "description": "one"}]} + assert load_dict_from_path(str(config_dir)) == { + "policies": [{"id": "p1", "description": "one"}], + "interactions": [{"id": "i1"}], + } + + update_dict_at_path(str(config_dir), {"interactions": [{"id": "i2"}]}) + assert load_dict_from_file(str(json_file)) == {"interactions": [{"id": "i2"}]} + + save_dict_to_file({"value": 1}, str(tmp_path / "saved"), "json") + assert load_dict_from_file(str(tmp_path / "saved.json")) == {"value": 1} + + (tmp_path / "run-a").mkdir() + (tmp_path / "config").mkdir(exist_ok=True) + (tmp_path / ".hidden").mkdir() + monkeypatch.chdir(tmp_path) + assert get_output_paths() == [str(tmp_path / "run-a")] + + +@pytest.mark.asyncio +async def test_run_eval_generates_for_string_and_message_inputs(tmp_path): + eval_config = _eval_config() + rails = MagicMock() + rails.generate_async = AsyncMock( + return_value=GenerationResponse( + response="ok", + log=GenerationLog( + activated_rails=[_activated_rail()], + internal_events=[{"type": "Done"}], + ), + ) + ) + + with ( + patch("nemoguardrails.eval.eval.EvalConfig.from_path", return_value=eval_config) as mock_eval_config, + patch("nemoguardrails.eval.eval.RailsConfig.from_path", return_value=SimpleNamespace()) as mock_rails_config, + patch("nemoguardrails.eval.eval.LLMRails", return_value=rails) as mock_rails_cls, + ): + await run_eval( + eval_config_path=str(tmp_path / "eval"), + guardrail_config_path=str(tmp_path / "guardrails"), + output_path=str(tmp_path / "output"), + output_format="json", + parallel=1, + ) + + mock_eval_config.assert_called_once() + mock_rails_config.assert_called_once_with(str(tmp_path / "guardrails")) + mock_rails_cls.assert_called_once() + assert rails.generate_async.await_count == 2 + first_call = rails.generate_async.await_args_list[0].kwargs + second_call = rails.generate_async.await_args_list[1].kwargs + assert first_call["prompt"] == "hello" + assert second_call["messages"] == [{"role": "user", "content": "hello"}] + output = EvalOutput.from_path(str(tmp_path / "output")) + assert [result.output for result in output.results] == ["ok", "ok"] + assert output.logs[0].events == [{"type": "Done"}] diff --git a/tests/eval/test_eval_ui_utils.py b/tests/eval/test_eval_ui_utils.py new file mode 100644 index 0000000000..e4450551c2 --- /dev/null +++ b/tests/eval/test_eval_ui_utils.py @@ -0,0 +1,463 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from nemoguardrails.eval.models import ( + EvalConfig, + EvalOutput, + InteractionLog, + InteractionOutput, + Policy, + Span, +) + +# The eval UI modules import `plotly.express` at module load time, so a stub must +# be present before the imports below (a fixture runs too late: these imports +# happen during collection). We only inject the stub when plotly cannot really be +# imported, so an installed plotly is never shadowed, and we remove our own stubs +# after this module's tests so they do not leak into other files on the worker. +_injected_plotly_modules = {} +try: + import plotly.express # noqa: F401 +except ImportError: + for _module_name in ("plotly", "plotly.express"): + if _module_name not in sys.modules: + _injected_plotly_modules[_module_name] = types.ModuleType(_module_name) + sys.modules[_module_name] = _injected_plotly_modules[_module_name] + + +@pytest.fixture(scope="module", autouse=True) +def _cleanup_plotly_stub(): + yield + for _module_name, stub_module in _injected_plotly_modules.items(): + if sys.modules.get(_module_name) is stub_module: + del sys.modules[_module_name] + # The eval UI modules below were imported while the plotly stub was active, so + # drop them too; otherwise they stay cached bound to the stubbed plotly.express + # and a later import in this worker would not rebind the real dependency. + if _injected_plotly_modules: + for _ui_module_name in ("nemoguardrails.eval.ui.common", "nemoguardrails.eval.ui.chart_utils"): + sys.modules.pop(_ui_module_name, None) + + +ui_common = importlib.import_module("nemoguardrails.eval.ui.common") +ui_utils = importlib.import_module("nemoguardrails.eval.ui.utils") +chart_utils = importlib.import_module("nemoguardrails.eval.ui.chart_utils") +streamlit_utils = importlib.import_module("nemoguardrails.eval.ui.streamlit_utils") +readme_page = importlib.import_module("nemoguardrails.eval.ui.README") + +_get_compliance_df = ui_common._get_compliance_df +_get_resource_usage_and_latencies_df = ui_common._get_resource_usage_and_latencies_df +EvalData = ui_utils.EvalData +collect_interaction_metrics = ui_utils.collect_interaction_metrics +collect_interaction_metrics_with_expected_latencies = ui_utils.collect_interaction_metrics_with_expected_latencies + + +class _Context: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +class _Sidebar(_Context): + def expander(self, *args, **kwargs): + return _Context() + + +class _FakeStreamlit: + def __init__(self, checkbox_values=None, button_values=None): + self.checkbox_values = list(checkbox_values or []) + self.button_values = list(button_values or []) + self.session_state = SimpleNamespace(use_expected_latencies=False) + self.sidebar = _Sidebar() + self.calls = [] + + def _record(self, name, *args, **kwargs): + self.calls.append((name, args, kwargs)) + + def checkbox(self, *args, **kwargs): + self._record("checkbox", *args, **kwargs) + return self.checkbox_values.pop(0) if self.checkbox_values else True + + def button(self, *args, **kwargs): + self._record("button", *args, **kwargs) + return self.button_values.pop(0) if self.button_values else False + + def expander(self, *args, **kwargs): + self._record("expander", *args, **kwargs) + return _Context() + + def rerun(self): + self._record("rerun") + + def __getattr__(self, name): + def method(*args, **kwargs): + self._record(name, *args, **kwargs) + + return method + + +def _eval_data(): + eval_config = EvalConfig( + policies=[Policy(id="p1", description="one"), Policy(id="p2", description="two")], + interactions=[ + { + "id": "1", + "inputs": ["a"], + "expected_output": [], + "tags": ["keep"], + }, + { + "id": "2", + "inputs": ["b"], + "expected_output": [], + "tags": ["drop"], + }, + ], + expected_latencies={ + "llm_call_main_fixed_latency": 1.0, + "llm_call_main_prompt_token_latency": 0.5, + "llm_call_main_completion_token_latency": 0.25, + }, + ) + output = EvalOutput( + results=[ + InteractionOutput( + id="1", + input="a", + compliance={"p1": True, "p2": False}, + resource_usage={"llm_call_main_total": 1, "tokens_total": 10}, + latencies={"llm_call_main_seconds_avg": 2.0}, + ), + InteractionOutput( + id="2", + input="b", + compliance={"p1": False, "p2": "n/a"}, + resource_usage={"llm_call_main_total": 1, "tokens_total": 5}, + latencies={"llm_call_main_seconds_avg": 4.0}, + ), + ], + logs=[ + InteractionLog( + id="1", + trace=[ + Span( + span_id="interaction", + name="interaction", + start_time=0, + end_time=2, + duration=2, + metrics={ + "interaction_seconds_avg": 2.0, + "interaction_seconds_total": 2.0, + }, + ), + Span( + span_id="llm", + parent_id="interaction", + name="LLM: main", + start_time=0, + end_time=2, + duration=2, + metrics={ + "llm_call_main_total": 1, + "llm_call_main_seconds_avg": 2.0, + "llm_call_main_seconds_total": 2.0, + "llm_call_main_prompt_tokens_total": 4, + "llm_call_main_completion_tokens_total": 2, + "llm_call_main_tokens_total": 6, + }, + ), + ], + ), + InteractionLog(id="2"), + ], + ) + return EvalData( + eval_config_path="config", + eval_config=eval_config, + output_paths=["run-a"], + eval_outputs={"run-a": output}, + ) + + +def test_collect_interaction_metrics_sums_resource_usage_and_averages_latency(): + metrics = collect_interaction_metrics(_eval_data().eval_outputs["run-a"].results) + + assert metrics["llm_call_main_total"] == 2 + assert metrics["tokens_total"] == 15 + assert metrics["llm_call_main_seconds_avg"] == 3.0 + + +def test_collect_interaction_metrics_with_expected_latencies_recomputes_llm_spans(): + eval_data = _eval_data() + + metrics = collect_interaction_metrics_with_expected_latencies( + [eval_data.eval_outputs["run-a"].results[0]], + eval_data.eval_outputs["run-a"].logs, + eval_data.eval_config.expected_latencies, + ) + + assert metrics["llm_call_main_seconds_avg"] == pytest.approx(3.5) + assert metrics["interaction_seconds_avg"] == pytest.approx(3.5) + assert metrics["tokens_total"] == 10 + + +def test_get_compliance_df_builds_policy_rows(): + df = _get_compliance_df(["run-a"], ["p1", "p2"], _eval_data()) + + rows = df.to_dict("records") + assert rows == [ + { + "Guardrail Config": "run-a", + "Policy": "p1", + "Compliance Rate": 50.0, + "Violations Count": 1, + "Interactions Count": 2, + }, + { + "Guardrail Config": "run-a", + "Policy": "p2", + "Compliance Rate": 0.0, + "Violations Count": 1, + "Interactions Count": 1, + }, + ] + + +def test_get_resource_usage_and_latencies_df_splits_metric_tables(): + eval_data = _eval_data() + + resource_df, latency_df = _get_resource_usage_and_latencies_df( + ["run-a"], + eval_data, + eval_data.eval_config, + use_expected_latencies=False, + ) + + assert resource_df.to_dict("records") == [ + {"Metric": "llm_call_main_total", "run-a": 2}, + {"Metric": "tokens_total", "run-a": 15}, + ] + assert latency_df.to_dict("records") == [{"Metric": "llm_call_main_seconds_avg", "run-a": 3.0}] + + +def test_eval_data_update_methods_delegate_to_update_dict_at_path(): + eval_data = _eval_data() + eval_data.selected_output_path = "run-a" + calls = [] + + def fake_update(path, data): + calls.append((path, data)) + + from nemoguardrails.eval.ui import utils as ui_utils + + original = ui_utils.update_dict_at_path + ui_utils.update_dict_at_path = fake_update + try: + eval_data.update_results() + eval_data.update_results_and_logs("run-a") + eval_data.update_config_latencies() + finally: + ui_utils.update_dict_at_path = original + + assert calls[0][0] == "run-a" + assert "results" in calls[0][1] + assert calls[1][0] == "run-a" + assert set(calls[1][1]) == {"results", "logs"} + assert calls[2] == ("config", {"expected_latencies": eval_data.eval_config.expected_latencies}) + + +def test_chart_utils_render_charts_and_optional_tables(monkeypatch): + fake_st = _FakeStreamlit() + fake_bar = MagicMock(return_value="figure") + monkeypatch.setattr(chart_utils, "st", fake_st) + monkeypatch.setattr(chart_utils.px, "bar", fake_bar, raising=False) + + import pandas as pd + + chart_utils.plot_as_series(pd.DataFrame({"Config": ["a"], "Value": [1]}), include_table=True) + chart_utils.plot_bar_series(pd.DataFrame({"Config": ["a"], "Metric": ["m"], "Value": [1]}), include_table=True) + chart_utils.plot_matrix_series( + pd.DataFrame({"Metric": ["m"], "run-a": [1]}), + var_name="Guardrail Config", + value_name="Value", + include_table=True, + ) + + assert fake_bar.call_count == 3 + assert [call[0] for call in fake_st.calls].count("plotly_chart") == 3 + assert [call[0] for call in fake_st.calls].count("dataframe") == 3 + + +def test_streamlit_utils_get_span_colors_is_stable(): + output = EvalOutput( + logs=[ + InteractionLog( + id="1", + trace=[ + Span(span_id="1", name="interaction", start_time=0, end_time=1, duration=1), + Span(span_id="2", name="rail", start_time=0, end_time=1, duration=1), + ], + ), + InteractionLog( + id="2", + trace=[Span(span_id="3", name="interaction", start_time=0, end_time=1, duration=1)], + ), + ] + ) + + colors = streamlit_utils.get_span_colors(output) + + assert set(colors) == {"interaction", "rail"} + assert all(value.startswith("#") and len(value) == 7 for value in colors.values()) + + +def test_streamlit_utils_load_eval_data_uses_discovered_paths(tmp_path, monkeypatch): + eval_config = EvalConfig(policies=[Policy(id="p1", description="one")], interactions=[]) + eval_output = EvalOutput(results=[InteractionOutput(id="1/0", input="a")], logs=[InteractionLog(id="1/0")]) + hidden_output = EvalOutput(results=[InteractionOutput(id="2/0", input="b")], logs=[InteractionLog(id="2/0")]) + output_dir = tmp_path / "run-a" + hidden_dir = tmp_path / ".hidden" + output_dir.mkdir() + hidden_dir.mkdir() + + with ( + patch("sys.argv", ["streamlit", "--eval-config-path", str(tmp_path / "config")]), + patch.object(streamlit_utils.EvalConfig, "from_path", return_value=eval_config) as mock_config, + patch.object(streamlit_utils.EvalOutput, "from_path", side_effect=[eval_output, hidden_output]) as mock_output, + patch.object(streamlit_utils, "get_output_paths", return_value=[str(output_dir), str(hidden_dir)]), + ): + monkeypatch.chdir(tmp_path) + streamlit_utils.load_eval_data.clear() + data = streamlit_utils.load_eval_data() + + mock_config.assert_called_once_with(str((tmp_path / "config").resolve())) + assert mock_output.call_count == 1 + assert data.output_paths == [str(output_dir), str(hidden_dir)] + assert data.eval_outputs == {"run-a": eval_output} + + +def test_eval_readme_page_renders_markdown(monkeypatch): + fake_st = _FakeStreamlit() + monkeypatch.setattr(readme_page, "st", fake_st) + + readme_page.main() + + markdown_calls = [call for call in fake_st.calls if call[0] == "markdown"] + assert len(markdown_calls) == 1 + assert markdown_calls[0][2]["unsafe_allow_html"] is True + + +def test_render_sidebar_filters_and_reload(monkeypatch): + fake_st = _FakeStreamlit( + checkbox_values=[True, True, False, True, False, False, True], + button_values=[True], + ) + clear = MagicMock() + monkeypatch.setattr(ui_common, "st", fake_st) + monkeypatch.setattr(ui_common.load_eval_data, "clear", clear) + + output_names, policy_options, tags = ui_common._render_sidebar(["run-a", "run-b"], ["p1", "p2"], ["keep", "drop"]) + + assert output_names == ["run-a"] + assert policy_options == ["p1"] + assert tags == ["drop"] + clear.assert_called_once() + assert any(call[0] == "rerun" for call in fake_st.calls) + + +def test_render_compliance_data_full_and_short(monkeypatch): + fake_st = _FakeStreamlit() + plot_calls = [] + monkeypatch.setattr(ui_common, "st", fake_st) + monkeypatch.setattr(ui_common, "plot_as_series", lambda *args, **kwargs: plot_calls.append(("series", kwargs))) + monkeypatch.setattr(ui_common, "plot_bar_series", lambda *args, **kwargs: plot_calls.append(("bar", kwargs))) + + ui_common._render_compliance_data(["run-a"], ["p1", "p2"], _eval_data(), short=False) + ui_common._render_compliance_data(["run-a"], ["p1"], _eval_data(), short=True) + + assert any(call[0] == "info" for call in fake_st.calls) + assert [call[0] for call in plot_calls].count("series") == 2 + assert [call[0] for call in plot_calls].count("bar") == 3 + + +def test_render_resource_usage_and_latencies_full_and_short(monkeypatch): + fake_st = _FakeStreamlit(checkbox_values=[True, True]) + fake_st.session_state.use_expected_latencies = False + plot_calls = [] + monkeypatch.setattr(ui_common, "st", fake_st) + monkeypatch.setattr(ui_common, "plot_as_series", lambda *args, **kwargs: plot_calls.append(("series", kwargs))) + monkeypatch.setattr(ui_common, "plot_bar_series", lambda *args, **kwargs: plot_calls.append(("bar", kwargs))) + monkeypatch.setattr(ui_common, "plot_matrix_series", lambda *args, **kwargs: plot_calls.append(("matrix", kwargs))) + + eval_data = _eval_data() + for result in eval_data.eval_outputs["run-a"].results: + result.resource_usage.update( + { + "llm_call_aux_total": 1, + "llm_call_main_prompt_tokens_total": 4, + "llm_call_main_completion_tokens_total": 2, + "llm_call_main_tokens_total": 6, + } + ) + result.latencies.update( + { + "interaction_seconds_total": 2.0, + "interaction_seconds_avg": 1.0, + "llm_call_main_seconds_total": 1.0, + "llm_call_main_seconds_avg": 1.0, + "action_self_check_seconds_total": 1.0, + "action_self_check_seconds_avg": 1.0, + } + ) + + ui_common._render_resource_usage_and_latencies(["run-a"], eval_data, eval_data.eval_config, short=False) + ui_common._render_resource_usage_and_latencies(["run-a"], eval_data, eval_data.eval_config, short=True) + + assert any(call[0] == "dataframe" for call in fake_st.calls) + assert [call[0] for call in plot_calls].count("matrix") >= 5 + assert [call[0] for call in plot_calls].count("series") >= 3 + + +def test_render_summary_filters_by_selected_tags(monkeypatch): + eval_data = _eval_data() + fake_st = _FakeStreamlit() + monkeypatch.setattr(ui_common, "st", fake_st) + monkeypatch.setattr(ui_common, "load_eval_data", MagicMock(return_value=eval_data)) + monkeypatch.setattr( + ui_common, + "_render_sidebar", + MagicMock(return_value=(["run-a"], ["p1"], ["keep"])), + ) + render_compliance = MagicMock() + render_resources = MagicMock() + monkeypatch.setattr(ui_common, "_render_compliance_data", render_compliance) + monkeypatch.setattr(ui_common, "_render_resource_usage_and_latencies", render_resources) + + ui_common.render_summary(short=True) + + filtered_eval_data = render_compliance.call_args.args[2] + assert [result.id for result in filtered_eval_data.eval_outputs["run-a"].results] == ["1"] + render_resources.assert_called_once() diff --git a/tests/evaluate/test_evaluate_cli_and_data.py b/tests/evaluate/test_evaluate_cli_and_data.py new file mode 100644 index 0000000000..9a92eda219 --- /dev/null +++ b/tests/evaluate/test_evaluate_cli_and_data.py @@ -0,0 +1,338 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +from unittest.mock import MagicMock, patch + +import pytest +from typer.testing import CliRunner + +from nemoguardrails.evaluate.cli import evaluate +from nemoguardrails.evaluate.cli.simplify_formatter import SimplifyFormatter +from nemoguardrails.evaluate.data.moderation import process_anthropic_dataset +from nemoguardrails.evaluate.data.topical.dataset_tools import ( + Banking77Connector, + ChitChatConnector, + DatasetConnector, + Intent, + IntentExample, +) + +runner = CliRunner() + + +def test_topical_command_rejects_multiple_configs(tmp_path): + config_a = tmp_path / "a" + config_b = tmp_path / "b" + config_a.mkdir() + config_b.mkdir() + + result = runner.invoke( + evaluate.app, + ["topical", "--config", str(config_a), "--config", str(config_b)], + ) + + assert result.exit_code == 1 + assert "Multiple configurations are not supported" in result.stdout + + +def test_topical_command_invokes_evaluation_class(tmp_path): + config = tmp_path / "config" + output = tmp_path / "output" + config.mkdir() + topical_eval = MagicMock() + + with ( + patch("nemoguardrails.evaluate.cli.evaluate.TopicalRailsEvaluation", return_value=topical_eval) as mock_cls, + patch("nemoguardrails.evaluate.cli.evaluate.set_verbose") as mock_set_verbose, + ): + result = runner.invoke( + evaluate.app, + [ + "topical", + "--config", + str(config), + "--verbose", + "--test-percentage", + "0.4", + "--max-tests-intent", + "4", + "--max-samples-intent", + "2", + "--results-frequency", + "5", + "--sim-threshold", + "0.75", + "--random-seed", + "123", + "--output-dir", + str(output), + ], + ) + + assert result.exit_code == 0 + mock_set_verbose.assert_called_once_with(True) + mock_cls.assert_called_once_with( + config=str(config), + verbose=True, + test_set_percentage=0.4, + max_samples_per_intent=2, + max_tests_per_intent=4, + print_test_results_frequency=5, + similarity_threshold=0.75, + random_seed=123, + output_dir=str(output), + ) + topical_eval.evaluate_topical_rails.assert_called_once() + + +def test_rail_commands_invoke_evaluation_classes(tmp_path): + config = tmp_path / "config" + dataset = tmp_path / "dataset.txt" + output = tmp_path / "output" + config.mkdir() + dataset.write_text("prompt\n", encoding="utf-8") + + cases = [ + ( + [ + "moderation", + "--config", + str(config), + "--dataset-path", + str(dataset), + "--num-samples", + "3", + "--output-dir", + str(output), + "--split", + "helpful", + ], + "ModerationRailsEvaluation", + (str(config), str(dataset), 3, True, True, str(output), True, "helpful"), + ), + ( + [ + "hallucination", + "--config", + str(config), + "--dataset-path", + str(dataset), + "--num-samples", + "2", + "--output-dir", + str(output), + ], + "HallucinationRailsEvaluation", + (str(config), str(dataset), 2, str(output), True), + ), + ( + [ + "fact-checking", + "--config", + str(config), + "--dataset-path", + str(dataset), + "--num-samples", + "4", + "--output-dir", + str(output), + ], + "FactCheckEvaluation", + (str(config), str(dataset), 4, True, str(output), True), + ), + ] + + for args, class_name, expected_args in cases: + instance = MagicMock() + with patch(f"nemoguardrails.evaluate.cli.evaluate.{class_name}", return_value=instance) as mock_cls: + result = runner.invoke(evaluate.app, args) + + assert result.exit_code == 0 + mock_cls.assert_called_once_with(*expected_args) + instance.run.assert_called_once() + + +def test_process_anthropic_dataset_helpers(tmp_path, monkeypatch): + harmful_path = tmp_path / "harmful.json" + helpful_path = tmp_path / "helpful.jsonl" + harmful_path.write_text( + json.dumps( + [ + { + "transcript": "\n\nHuman: unsafe\n\nAssistant: no", + "task_descripton_harmlessness_score": -2, + }, + { + "transcript": "\n\nHuman: safe\n\nAssistant: ok", + "task_descripton_harmlessness_score": 0, + }, + ] + ), + encoding="utf-8", + ) + helpful_path.write_text( + json.dumps({"chosen": "\n\nHuman: useful\n\nAssistant: ok"}) + "\n", + encoding="utf-8", + ) + + assert process_anthropic_dataset.split_messages("\n\nHuman: hello\n\nAssistant: hi") == ( + ["hello"], + ["hi"], + ) + assert len(process_anthropic_dataset.load_dataset(str(harmful_path), "harmful")) == 2 + assert len(process_anthropic_dataset.load_dataset(str(helpful_path), "helpful")) == 1 + + monkeypatch.chdir(tmp_path) + process_anthropic_dataset.process_anthropic_harmful_data(str(harmful_path), rating=4.0) + process_anthropic_dataset.process_anthropic_helpful_data(str(helpful_path)) + + assert (tmp_path / "anthropic_harmful.txt").read_text(encoding="utf-8") == "unsafe\n" + assert (tmp_path / "anthropic_helpful.txt").read_text(encoding="utf-8") == "useful\n" + + +def test_process_anthropic_dataset_main_dispatches(tmp_path): + with ( + patch.object(process_anthropic_dataset, "process_anthropic_harmful_data") as mock_harmful, + patch.object(process_anthropic_dataset, "process_anthropic_helpful_data") as mock_helpful, + ): + process_anthropic_dataset.main(dataset_path="data.json", rating=3.0, split="harmful") + process_anthropic_dataset.main(dataset_path="data.jsonl", rating=3.0, split="helpful") + + mock_harmful.assert_called_once_with("data.json", 3.0) + mock_helpful.assert_called_once_with("data.jsonl") + + +def test_simplify_formatter_masks_noisy_log_details(): + formatter = SimplifyFormatter("%(message)s") + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname=__file__, + lineno=1, + msg=( + "Process internal event: {'id': '123e4567-e89b-12d3-a456-426614174000', " + "'_created_at': '2024-01-01T00:00:00.000000+00:00', " + "'final_transcript': 'hello', 'loop_id': 'main-123'} " + ), + args=(), + exc_info=None, + ) + + text = formatter.format(record) + + assert "123e..." in text + assert "_created_at" not in text + assert "<>" in text + assert ":thumbs_up:'final_transcript': 'hello':thumbs_up:" in text + assert "Process internal event" not in text + assert formatter.format("Processing event details") == "" + assert formatter.format("prefix :: hidden") == "" + + +def test_dataset_connector_sampling_and_colang_output(tmp_path): + connector = DatasetConnector(name="test") + intent = Intent(intent_name="greet", canonical_form="greet user") + connector.intents.add(intent) + connector.intent_examples = [ + IntentExample(intent=intent, text='say "hello"'), + IntentExample(intent=intent, text="good morning"), + ] + + sample = connector.get_intent_sample("greet", num_samples=1) + assert len(sample) == 1 + assert sample[0] in {'say "hello"', "good morning"} + + output_path = tmp_path / "user.co" + connector.write_colang_output(str(output_path), num_samples_per_intent=0) + content = output_path.read_text(encoding="utf-8") + assert "define user greet user" in content + assert '"say hello"' in content + + +def test_dataset_connector_edge_cases(tmp_path): + connector = DatasetConnector(name="test") + with pytest.raises(NotImplementedError): + connector.read_dataset("missing") + + assert connector.write_colang_output(None) is None + + no_canonical = Intent(intent_name="unclear") + duplicate_a = Intent(intent_name="a", canonical_form="same canonical") + duplicate_b = Intent(intent_name="b", canonical_form="same canonical") + connector.intents.update([no_canonical, duplicate_a, duplicate_b]) + connector.intent_examples.extend( + [ + IntentExample(intent=duplicate_a, text="sample a"), + IntentExample(intent=duplicate_b, text="sample b"), + ] + ) + output_path = tmp_path / "duplicates.co" + + connector.write_colang_output(str(output_path), num_samples_per_intent=1) + + assert "define user same canonical" in output_path.read_text(encoding="utf-8") + + +def test_banking_connector_reads_canonical_forms_and_dataset(tmp_path, monkeypatch): + canonical_path = tmp_path / "canonical.json" + canonical_path.write_text(json.dumps([["greet", "greet user"], ["bad"], ["bye", "say bye"]]), encoding="utf-8") + + assert Banking77Connector._read_canonical_forms(str(canonical_path)) == { + "greet": "greet user", + "bye": "say bye", + } + + dataset_dir = tmp_path / "banking" + dataset_dir.mkdir() + (dataset_dir / "train.csv").write_text("text,category\nhello,greet\n", encoding="utf-8") + (dataset_dir / "test.csv").write_text("goodbye,bye\nunknown,missing\n", encoding="utf-8") + monkeypatch.setattr( + Banking77Connector, + "_read_canonical_forms", + staticmethod(lambda: {"greet": "greet user", "bye": "say bye"}), + ) + connector = Banking77Connector() + + connector.read_dataset(str(dataset_dir) + "/") + + assert Intent(intent_name="greet", canonical_form="greet user") in connector.intents + assert Intent(intent_name="bye", canonical_form="say bye") in connector.intents + assert Intent(intent_name="missing", canonical_form=None) in connector.intents + assert [example.dataset_split for example in connector.intent_examples] == ["train", "test", "test"] + + +def test_chitchat_connector_reads_rasa_markdown(tmp_path, monkeypatch): + dataset_dir = tmp_path / "chitchat" + dataset_dir.mkdir() + (dataset_dir / "nlu.md").write_text("## intent:greet\n- hello\n- hi\n", encoding="utf-8") + + monkeypatch.setattr(ChitChatConnector, "_read_canonical_forms", staticmethod(lambda: {"greet": "greet user"})) + connector = ChitChatConnector() + connector.read_dataset(str(dataset_dir) + "/") + + assert connector.intents == {Intent(intent_name="greet", canonical_form="greet user")} + assert [example.text for example in connector.intent_examples] == ["hello", "hi"] + + +def test_chitchat_connector_reads_canonical_forms(tmp_path): + canonical_path = tmp_path / "canonical.json" + canonical_path.write_text(json.dumps([["greet", "greet user"], ["bad"], ["bye", "say bye"]]), encoding="utf-8") + + assert ChitChatConnector._read_canonical_forms(str(canonical_path)) == { + "greet": "greet user", + "bye": "say bye", + } diff --git a/tests/evaluate/test_evaluate_runtime_classes.py b/tests/evaluate/test_evaluate_runtime_classes.py new file mode 100644 index 0000000000..c8e2b00507 --- /dev/null +++ b/tests/evaluate/test_evaluate_runtime_classes.py @@ -0,0 +1,410 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nemoguardrails.evaluate.evaluate_factcheck import FactCheckEvaluation +from nemoguardrails.evaluate.evaluate_hallucination import HallucinationRailsEvaluation +from nemoguardrails.evaluate.evaluate_moderation import ModerationRailsEvaluation + + +def _moderation_evaluator(): + evaluator = ModerationRailsEvaluation.__new__(ModerationRailsEvaluation) + evaluator.llm = MagicMock() + evaluator.llm_task_manager = MagicMock() + evaluator.llm_task_manager.render_task_prompt.return_value = "rendered prompt" + evaluator.check_input = True + evaluator.check_output = True + evaluator.dataset = ["prompt"] + evaluator.split = "harmful" + evaluator.write_outputs = False + evaluator.output_dir = "unused" + evaluator.dataset_path = "harmful.txt" + return evaluator + + +def test_moderation_init_loads_rails_dataset_and_creates_output_dir(tmp_path): + dataset = tmp_path / "moderation.txt" + output_dir = tmp_path / "outputs" + dataset.write_text("one\ntwo\n", encoding="utf-8") + rails = SimpleNamespace(llm="llm") + + with ( + patch( + "nemoguardrails.evaluate.evaluate_moderation.RailsConfig.from_path", return_value="config" + ) as mock_config, + patch("nemoguardrails.evaluate.evaluate_moderation.LLMRails", return_value=rails) as mock_rails, + patch( + "nemoguardrails.evaluate.evaluate_moderation.LLMTaskManager", return_value="task-manager" + ) as mock_task_manager, + ): + evaluator = ModerationRailsEvaluation( + config="config-path", + dataset_path=str(dataset), + num_samples=1, + output_dir=str(output_dir), + ) + + mock_config.assert_called_once_with("config-path") + mock_rails.assert_called_once_with("config") + mock_task_manager.assert_called_once_with("config") + assert evaluator.dataset == ["one\n"] + assert evaluator.llm == "llm" + assert output_dir.exists() + + +def test_moderation_get_jailbreak_results_counts_flags_and_correct_predictions(): + evaluator = _moderation_evaluator() + results = {"flagged": 0, "correct": 0, "error": 0, "label": "yes"} + + with patch( + "nemoguardrails.evaluate.evaluate_moderation.llm_call", + AsyncMock(return_value=SimpleNamespace(content="YES")), + ): + prediction, updated = evaluator.get_jailbreak_results("prompt", results) + + assert prediction == "yes" + assert updated["flagged"] == 1 + assert updated["correct"] == 1 + evaluator.llm_task_manager.render_task_prompt.assert_called_once() + + +def test_moderation_get_jailbreak_results_records_error_after_retries(): + evaluator = _moderation_evaluator() + results = {"flagged": 0, "correct": 0, "error": 0, "label": "yes"} + + mock_llm_call = AsyncMock(side_effect=RuntimeError("failed")) + with patch("nemoguardrails.evaluate.evaluate_moderation.llm_call", mock_llm_call): + prediction, updated = evaluator.get_jailbreak_results("prompt", results) + + assert prediction is None + assert updated["error"] == 1 + # The max_tries loop must exhaust all three attempts before recording the error. + assert mock_llm_call.await_count == 3 + + +def test_moderation_get_check_output_results_counts_flags_and_correct_predictions(): + evaluator = _moderation_evaluator() + results = {"flagged": 0, "correct": 0, "error": 0, "label": "yes"} + + with patch( + "nemoguardrails.evaluate.evaluate_moderation.llm_call", + AsyncMock( + side_effect=[ + SimpleNamespace(content="bot response"), + SimpleNamespace(content="yes"), + ] + ), + ): + bot_response, prediction, updated = evaluator.get_check_output_results("prompt", results) + + assert bot_response == "bot response" + assert prediction == "yes" + assert updated["flagged"] == 1 + assert updated["correct"] == 1 + + +def test_moderation_check_moderation_combines_enabled_checks(): + evaluator = _moderation_evaluator() + evaluator.get_jailbreak_results = MagicMock( + return_value=("yes", {"flagged": 1, "correct": 1, "error": 0, "label": "yes"}) + ) + evaluator.get_check_output_results = MagicMock( + return_value=( + "bot", + "yes", + {"flagged": 1, "correct": 1, "error": 0, "label": "yes"}, + ) + ) + + predictions, jailbreak_results, check_output_results = evaluator.check_moderation() + + assert predictions == [ + { + "prompt": "prompt", + "jailbreak": "yes", + "bot_response": "bot", + "check_output": "yes", + } + ] + assert jailbreak_results["correct"] == 1 + assert check_output_results["flagged"] == 1 + + +def test_moderation_run_writes_predictions(tmp_path): + evaluator = _moderation_evaluator() + evaluator.write_outputs = True + evaluator.output_dir = str(tmp_path) + evaluator.dataset_path = "harmful.txt" + evaluator.check_moderation = MagicMock( + return_value=( + [{"prompt": "prompt", "jailbreak": "yes"}], + {"flagged": 1, "correct": 1, "error": 0}, + {"flagged": 0, "correct": 0, "error": 0}, + ) + ) + + evaluator.run() + + output_path = tmp_path / "harmful_harmful_moderation_results.json" + assert json.loads(output_path.read_text(encoding="utf-8")) == [{"prompt": "prompt", "jailbreak": "yes"}] + + +def _hallucination_evaluator(): + evaluator = HallucinationRailsEvaluation.__new__(HallucinationRailsEvaluation) + evaluator.dataset = ["question"] + evaluator.write_outputs = False + evaluator.output_dir = "unused" + evaluator.dataset_path = "sample.txt" + evaluator.llm_task_manager = MagicMock() + evaluator.llm_task_manager.render_task_prompt.return_value = "check hallucination" + evaluator.llm = MagicMock(return_value="no") + return evaluator + + +def test_hallucination_init_loads_rails_dataset_and_creates_output_dir(tmp_path): + dataset = tmp_path / "hallucination.txt" + output_dir = tmp_path / "outputs" + dataset.write_text("one\ntwo\n", encoding="utf-8") + rails = SimpleNamespace(llm="llm") + + with ( + patch( + "nemoguardrails.evaluate.evaluate_hallucination.RailsConfig.from_path", return_value="config" + ) as mock_config, + patch("nemoguardrails.evaluate.evaluate_hallucination.LLMRails", return_value=rails) as mock_rails, + patch( + "nemoguardrails.evaluate.evaluate_hallucination.LLMTaskManager", return_value="task-manager" + ) as mock_task_manager, + ): + evaluator = HallucinationRailsEvaluation( + config="config-path", + dataset_path=str(dataset), + num_samples=1, + output_dir=str(output_dir), + ) + + mock_config.assert_called_once_with("config-path") + mock_rails.assert_called_once_with("config") + mock_task_manager.assert_called_once_with("config") + assert evaluator.dataset == ["one\n"] + assert evaluator.llm == "llm" + assert output_dir.exists() + + +def test_hallucination_get_response_with_retries_uses_bound_llm_params(): + evaluator = _hallucination_evaluator() + bound = MagicMock(return_value="bound response") + evaluator.llm = MagicMock() + evaluator.llm.bind.return_value = bound + + response = evaluator.get_response_with_retries( + "prompt", + max_tries=2, + llm_params={"temperature": 1.0}, + ) + + assert response == "bound response" + evaluator.llm.bind.assert_called_once_with(temperature=1.0) + bound.assert_called_once_with("prompt") + + +def test_hallucination_get_extra_responses_skips_errors(): + evaluator = _hallucination_evaluator() + evaluator.get_response_with_retries = MagicMock(side_effect=[None, "extra"]) + + assert evaluator.get_extra_responses("prompt", num_responses=2) == ["extra"] + + +def test_hallucination_get_response_with_retries_returns_none_after_errors(): + evaluator = _hallucination_evaluator() + evaluator.llm = MagicMock(side_effect=RuntimeError("failed")) + + assert evaluator.get_response_with_retries("prompt", max_tries=2) is None + # Both attempts must be made before giving up. + assert evaluator.llm.call_count == 2 + + +def test_hallucination_self_check_counts_no_as_flagged(): + evaluator = _hallucination_evaluator() + evaluator.get_response_with_retries = MagicMock(return_value="main response") + evaluator.get_extra_responses = MagicMock(return_value=["extra one", "extra two"]) + + predictions, num_flagged, num_error = evaluator.self_check_hallucination() + + assert num_flagged == 1 + assert num_error == 0 + assert predictions[0]["hallucination_agreement"] == "no" + + +def test_hallucination_self_check_records_error_when_main_response_fails(): + evaluator = _hallucination_evaluator() + evaluator.get_response_with_retries = MagicMock(return_value=None) + + predictions, num_flagged, num_error = evaluator.self_check_hallucination() + + assert num_flagged == 0 + assert num_error == 1 + assert predictions[0]["hallucination_agreement"] == "na" + + +def test_hallucination_run_writes_predictions(tmp_path): + evaluator = _hallucination_evaluator() + evaluator.write_outputs = True + evaluator.output_dir = str(tmp_path) + evaluator.dataset_path = "sample.txt" + evaluator.self_check_hallucination = MagicMock(return_value=([{"question": "q"}], 1, 0)) + + evaluator.run() + + output_path = tmp_path / "sample_hallucination_predictions.json" + assert json.loads(output_path.read_text(encoding="utf-8")) == [{"question": "q"}] + + +def _factcheck_evaluator(): + evaluator = FactCheckEvaluation.__new__(FactCheckEvaluation) + evaluator.dataset = [ + { + "question": "q", + "evidence": "e", + "answer": "a", + "incorrect_answer": "bad", + } + ] + evaluator.llm = MagicMock() + evaluator.llm_task_manager = MagicMock() + evaluator.llm_task_manager.render_task_prompt.return_value = "fact prompt" + evaluator.llm_task_manager.get_stop_tokens.return_value = ["stop"] + evaluator.create_negatives = False + evaluator.write_outputs = False + evaluator.output_dir = "unused" + evaluator.dataset_path = "sample.json" + return evaluator + + +def test_factcheck_init_loads_rails_dataset_and_creates_output_dir(tmp_path): + dataset = tmp_path / "fact.json" + output_dir = tmp_path / "outputs" + dataset.write_text(json.dumps([{"question": "q"}]), encoding="utf-8") + rails = SimpleNamespace(llm="llm") + + with ( + patch("nemoguardrails.evaluate.evaluate_factcheck.RailsConfig.from_path", return_value="config") as mock_config, + patch("nemoguardrails.evaluate.evaluate_factcheck.LLMRails", return_value=rails) as mock_rails, + patch( + "nemoguardrails.evaluate.evaluate_factcheck.LLMTaskManager", return_value="task-manager" + ) as mock_task_manager, + ): + evaluator = FactCheckEvaluation( + config="config-path", + dataset_path=str(dataset), + num_samples=1, + output_dir=str(output_dir), + ) + + mock_config.assert_called_once_with("config-path") + mock_rails.assert_called_once_with("config") + mock_task_manager.assert_called_once_with("config") + assert evaluator.dataset == [{"question": "q"}] + assert evaluator.llm == "llm" + assert output_dir.exists() + + +@pytest.mark.asyncio +async def test_factcheck_create_negative_samples_adds_incorrect_answer(): + evaluator = _factcheck_evaluator() + evaluator.llm.generate_async = AsyncMock(return_value=SimpleNamespace(content=" incorrect ")) + dataset = [{"question": "q", "evidence": "e", "answer": "a"}] + + result = await evaluator.create_negative_samples(dataset) + + assert result[0]["incorrect_answer"] == "incorrect" + + +def test_factcheck_check_facts_uses_positive_and_negative_labels(): + evaluator = _factcheck_evaluator() + + with ( + patch( + "nemoguardrails.evaluate.evaluate_factcheck.llm_call", + AsyncMock(return_value=SimpleNamespace(content="yes")), + ) as mock_llm_call, + patch("nemoguardrails.evaluate.evaluate_factcheck.time.sleep"), + ): + predictions, num_correct, total_time = evaluator.check_facts(split="positive") + + assert predictions[0]["answer"] == "a" + assert predictions[0]["label"] == "yes" + assert num_correct == 1 + assert total_time >= 0 + mock_llm_call.assert_awaited_once() + + with ( + patch( + "nemoguardrails.evaluate.evaluate_factcheck.llm_call", + AsyncMock(return_value=SimpleNamespace(content="no")), + ), + patch("nemoguardrails.evaluate.evaluate_factcheck.time.sleep"), + ): + predictions, num_correct, _ = evaluator.check_facts(split="negative") + + assert predictions[0]["answer"] == "bad" + assert predictions[0]["label"] == "no" + assert num_correct == 1 + + +def test_factcheck_run_writes_positive_and_negative_predictions(tmp_path): + evaluator = _factcheck_evaluator() + evaluator.write_outputs = True + evaluator.output_dir = str(tmp_path) + evaluator.dataset_path = "sample.json" + evaluator.check_facts = MagicMock( + side_effect=[ + ([{"label": "yes"}], 1, 0.1), + ([{"label": "no"}], 1, 0.2), + ] + ) + + evaluator.run() + + positive_path = tmp_path / "sample_positive_fact_check_predictions.json" + negative_path = tmp_path / "sample_negative_fact_check_predictions.json" + assert json.loads(positive_path.read_text(encoding="utf-8")) == [{"label": "yes"}] + assert json.loads(negative_path.read_text(encoding="utf-8")) == [{"label": "no"}] + + +def test_factcheck_run_creates_negative_samples_when_enabled(tmp_path): + evaluator = _factcheck_evaluator() + evaluator.create_negatives = True + original_dataset = evaluator.dataset + negatives = [{"question": "q", "evidence": "e", "answer": "a", "incorrect_answer": "sentinel"}] + evaluator.create_negative_samples = AsyncMock(return_value=negatives) + evaluator.check_facts = MagicMock( + side_effect=[ + ([{"label": "yes"}], 1, 0.1), + ([{"label": "no"}], 1, 0.2), + ] + ) + + evaluator.run() + + evaluator.create_negative_samples.assert_awaited_once_with(original_dataset) + # run() must assign the coroutine result back onto self.dataset before checking facts. + assert evaluator.dataset == negatives + assert evaluator.check_facts.call_count == 2 diff --git a/tests/evaluate/test_evaluate_topical.py b/tests/evaluate/test_evaluate_topical.py new file mode 100644 index 0000000000..32acb3696b --- /dev/null +++ b/tests/evaluate/test_evaluate_topical.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nemoguardrails.evaluate.evaluate_topical import ( + TopicalRailsEvaluation, + _split_test_set_from_config, + cosine_similarity, + sync_wrapper, +) +from nemoguardrails.evaluate.utils import load_dataset + + +def test_load_dataset_reads_json_and_text(tmp_path): + json_path = tmp_path / "data.json" + text_path = tmp_path / "data.txt" + json_path.write_text(json.dumps([{"question": "q"}]), encoding="utf-8") + text_path.write_text("one\ntwo\n", encoding="utf-8") + + assert load_dataset(str(json_path)) == [{"question": "q"}] + assert load_dataset(str(text_path)) == ["one\n", "two\n"] + + +def test_cosine_similarity(): + assert cosine_similarity([1, 0], [1, 0]) == pytest.approx(1.0) + assert cosine_similarity([1, 0], [0, 1]) == pytest.approx(0.0) + + +def test_sync_wrapper_runs_async_function(): + async def add(left, right): + return left + right + + assert sync_wrapper(add)(2, 3) == 5 + + +def test_sync_wrapper_falls_back_to_asyncio_run(monkeypatch): + import asyncio + + async def add(left, right): + return left + right + + def _no_event_loop(): + raise RuntimeError("no current event loop") + + monkeypatch.setattr(asyncio, "get_event_loop", _no_event_loop) + + assert sync_wrapper(add)(2, 3) == 5 + + +def test_split_test_set_from_config_uses_seed_and_limits_remaining_samples(): + config = SimpleNamespace(user_messages={"greet": ["a", "b", "c", "d"]}) + test_set = {} + + _split_test_set_from_config( + config, + test_set_percentage=0.5, + test_set=test_set, + max_samples_per_intent=1, + random_seed=7, + ) + + assert len(test_set["greet"]) == 2 + assert len(config.user_messages["greet"]) == 1 + assert set(test_set["greet"]).isdisjoint(config.user_messages["greet"]) + + +def test_split_test_set_ignores_single_sample_intents(): + config = SimpleNamespace(user_messages={"solo": ["only"]}) + test_set = {} + + _split_test_set_from_config(config, 0.5, test_set, 0) + + assert test_set == {} + assert config.user_messages == {"solo": ["only"]} + + +def test_topical_init_initializes_rails_seed_and_embeddings(tmp_path): + rails_config = SimpleNamespace(user_messages={"greet": ["a", "b"]}, flows=[], models=[]) + rails_app = SimpleNamespace(config=rails_config) + + with ( + patch( + "nemoguardrails.evaluate.evaluate_topical.RailsConfig.from_path", return_value=rails_config + ) as mock_config, + patch("nemoguardrails.evaluate.evaluate_topical.LLMRails", return_value=rails_app) as mock_rails, + patch("nemoguardrails.evaluate.evaluate_topical.random.seed") as mock_seed, + patch.object(TopicalRailsEvaluation, "_initialize_embeddings_model") as mock_embeddings, + ): + evaluator = TopicalRailsEvaluation( + config=str(tmp_path), + verbose=True, + test_set_percentage=0.5, + max_tests_per_intent=1, + max_samples_per_intent=1, + similarity_threshold=0.0, + random_seed=11, + ) + + mock_config.assert_called_once_with(config_path=str(tmp_path)) + mock_rails.assert_called_once_with(rails_config, verbose=True) + mock_seed.assert_called_once_with(11) + mock_embeddings.assert_called_once_with() + assert evaluator.rails_app == rails_app + assert len(evaluator.test_set["greet"]) == 1 + + +def test_topical_initialize_embeddings_model_import_error(monkeypatch): + evaluator = TopicalRailsEvaluation.__new__(TopicalRailsEvaluation) + + def fake_import(name, *args, **kwargs): + if name == "sentence_transformers": + raise ImportError("missing") + return original_import(name, *args, **kwargs) + + original_import = __import__ + monkeypatch.setattr("builtins.__import__", fake_import) + + with pytest.raises(ImportError, match="sentence_transformers"): + evaluator._initialize_embeddings_model() + + +def test_topical_initialize_embeddings_model_creates_model(monkeypatch): + evaluator = TopicalRailsEvaluation.__new__(TopicalRailsEvaluation) + evaluator.similarity_threshold = 0.5 + sentence_transformers = types.ModuleType("sentence_transformers") + model_cls = MagicMock(return_value="model") + sentence_transformers.SentenceTransformer = model_cls + monkeypatch.setitem(sys.modules, "sentence_transformers", sentence_transformers) + + evaluator._initialize_embeddings_model() + + model_cls.assert_called_once_with("all-MiniLM-L6-v2") + assert evaluator._model == "model" + + +def test_topical_helper_methods_compute_embeddings_and_similarity(): + evaluator = TopicalRailsEvaluation.__new__(TopicalRailsEvaluation) + evaluator.similarity_threshold = 0.5 + evaluator._model = SimpleNamespace( + encode=lambda values: [[1.0, 0.0], [0.0, 1.0]] if isinstance(values, list) else [1.0, 0.0] + ) + + evaluator._compute_intent_embeddings(["greet", "bye"]) + + assert evaluator._intent_embeddings == {"greet": [1.0, 0.0], "bye": [0.0, 1.0]} + assert evaluator._get_most_similar_intent("hello") == "greet" + + +def test_topical_helper_methods_return_original_intent_without_model(): + evaluator = TopicalRailsEvaluation.__new__(TopicalRailsEvaluation) + evaluator.similarity_threshold = 0.9 + evaluator._model = None + + evaluator._compute_intent_embeddings(["greet"]) + + assert not hasattr(evaluator, "_intent_embeddings") + assert evaluator._get_most_similar_intent("generated") == "generated" + + +def test_get_main_llm_model(): + evaluator = TopicalRailsEvaluation.__new__(TopicalRailsEvaluation) + evaluator.rails_app = SimpleNamespace( + config=SimpleNamespace( + models=[ + SimpleNamespace(type="embedding", model="embed"), + SimpleNamespace(type="main", model="main-model"), + ] + ) + ) + + assert evaluator._get_main_llm_model() == "main-model" + + evaluator.rails_app.config.models = [] + assert evaluator._get_main_llm_model() == "unknown_main_llm" + + +def test_evaluate_topical_rails_runs_with_mock_runtime_and_writes_predictions(tmp_path): + generate_events = AsyncMock( + return_value=[ + {"type": "UserIntent", "intent": "wrong"}, + {"type": "BotIntent", "intent": "wrong bot"}, + {"type": "StartUtteranceBotAction", "script": "unexpected"}, + ] + ) + evaluator = TopicalRailsEvaluation.__new__(TopicalRailsEvaluation) + evaluator.config_path = str(tmp_path / "configs" / "topical") + evaluator.test_set = {"greet": ["hello"]} + evaluator.max_tests_per_intent = 1 + evaluator.max_samples_per_intent = 0 + evaluator.print_test_results_frequency = 1 + evaluator.similarity_threshold = 0.0 + evaluator.output_dir = str(tmp_path) + evaluator._model = None + evaluator.rails_app = SimpleNamespace( + runtime=SimpleNamespace(generate_events=generate_events), + config=SimpleNamespace( + flows=[ + { + "elements": [ + {"_type": "UserIntent", "intent_name": "greet"}, + { + "_type": "run_action", + "action_name": "utter", + "action_params": {"value": "bot greet"}, + }, + ] + } + ], + bot_messages={"bot greet": ["hello there"]}, + models=[SimpleNamespace(type="main", model="mock-model")], + ), + ) + + evaluator.evaluate_topical_rails() + + generate_events.assert_awaited_once_with([{"type": "UtteranceUserActionFinished", "final_transcript": "hello"}]) + output_files = list(tmp_path.glob("*_topical_results.json")) + assert len(output_files) == 1 + predictions = json.loads(output_files[0].read_text(encoding="utf-8")) + assert predictions[0]["generated_user_intent"] == "wrong" + assert predictions[0]["generated_bot_intent"] == "wrong bot" diff --git a/tests/server/conftest.py b/tests/server/conftest.py new file mode 100644 index 0000000000..0c6f21f390 --- /dev/null +++ b/tests/server/conftest.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +@pytest.fixture(autouse=True) +def _disable_chat_ui(monkeypatch): + # Force the isolated path for every server test, regardless of any ambient + # value; monkeypatch restores the prior state afterwards. + monkeypatch.setenv("NEMO_GUARDRAILS_DISABLE_CHAT_UI", "true") diff --git a/tests/server/test_api.py b/tests/server/test_api.py index 9e469dae71..a6bc904cef 100644 --- a/tests/server/test_api.py +++ b/tests/server/test_api.py @@ -1028,7 +1028,10 @@ def test_list_models_forwards_auth_header(): mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) - with patch.dict(os.environ, {"MAIN_MODEL_BASE_URL": "http://localhost:8000"}): + with patch.dict( + os.environ, + {"MAIN_MODEL_ENGINE": "openai", "MAIN_MODEL_BASE_URL": "http://localhost:8000", "OPENAI_API_KEY": ""}, + ): with patch("httpx.AsyncClient", return_value=mock_client): response = client.get( "/v1/models", @@ -1041,6 +1044,32 @@ def test_list_models_forwards_auth_header(): assert call_kwargs.kwargs["headers"]["Authorization"] == "Bearer my-token" +def test_list_models_prefers_openai_api_key_over_auth_header(): + mock_response = _make_httpx_response({"data": []}) + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch.dict( + os.environ, + { + "MAIN_MODEL_ENGINE": "openai", + "MAIN_MODEL_BASE_URL": "http://localhost:8000", + "OPENAI_API_KEY": "sk-env-key", + }, + ): + with patch("httpx.AsyncClient", return_value=mock_client): + response = client.get( + "/v1/models", + headers={"Authorization": "Bearer not-used"}, + ) + + assert response.status_code == 200 + call_kwargs = mock_client.get.call_args + assert call_kwargs.kwargs["headers"]["Authorization"] == "Bearer sk-env-key" + + def test_list_models_uses_openai_api_key_fallback(): """Test /v1/models falls back to OPENAI_API_KEY when no auth header.""" mock_response = _make_httpx_response({"data": []}) @@ -1052,6 +1081,7 @@ def test_list_models_uses_openai_api_key_fallback(): with patch.dict( os.environ, { + "MAIN_MODEL_ENGINE": "openai", "MAIN_MODEL_BASE_URL": "http://localhost:8000", "OPENAI_API_KEY": "sk-test-key", }, diff --git a/tests/server/test_openai_integration.py b/tests/server/test_openai_integration.py index 26f6010e5c..90aff7eb1b 100644 --- a/tests/server/test_openai_integration.py +++ b/tests/server/test_openai_integration.py @@ -25,6 +25,12 @@ from nemoguardrails.server import api +# The live `test_list_models_*` tests below are opt-in: they require LIVE_TEST_MODE +# (or TEST_LIVE_MODE) to be set in addition to the relevant provider API key. A +# provider key alone is intentionally NOT sufficient, so CI that exports only +# OPENAI_API_KEY (etc.) skips them and never reaches a live provider. +LIVE_TEST_MODE = os.environ.get("LIVE_TEST_MODE") or os.environ.get("TEST_LIVE_MODE") + @pytest.fixture(scope="function", autouse=True) def set_rails_config_path(): @@ -45,17 +51,22 @@ def set_rails_config_path(): @pytest.fixture(scope="function") -def openai_client(): +def make_openai_client(): """Create an OpenAI client that uses the guardrails FastAPI app via TestClient.""" - # Create a TestClient for the FastAPI app - test_client = TestClient(api.app) - client = OpenAI( - api_key="dummy-key", - base_url="http://dummy-url/v1", - http_client=test_client, - ) - return client + def _make_client(api_key="dummy-key"): + return OpenAI( + api_key=api_key, + base_url="http://dummy-url/v1", + http_client=TestClient(api.app), + ) + + return _make_client + + +@pytest.fixture(scope="function") +def openai_client(make_openai_client): + return make_openai_client() def test_openai_client_chat_completion(openai_client): @@ -313,14 +324,15 @@ def test_openai_client_with_rails_disabled(openai_client): @pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY"), - reason="OPENAI_API_KEY is required for this test.", + not (LIVE_TEST_MODE and os.environ.get("OPENAI_API_KEY")), + reason="LIVE_TEST_MODE or TEST_LIVE_MODE and OPENAI_API_KEY are required for this test.", ) -def test_list_models_openai(openai_client): +def test_list_models_openai(make_openai_client, monkeypatch): """List models from the OpenAI API.""" - os.environ.setdefault("MAIN_MODEL_BASE_URL", "https://api.openai.com") - os.environ["MAIN_MODEL_ENGINE"] = "openai" + monkeypatch.setenv("MAIN_MODEL_BASE_URL", os.environ.get("MAIN_MODEL_BASE_URL") or "https://api.openai.com") + monkeypatch.setenv("MAIN_MODEL_ENGINE", "openai") + openai_client = make_openai_client("not-used") models = list(openai_client.models.list()) assert len(models) > 0 @@ -331,14 +343,15 @@ def test_list_models_openai(openai_client): @pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY"), - reason="OPENAI_API_KEY is required for this test.", + not (LIVE_TEST_MODE and os.environ.get("OPENAI_API_KEY")), + reason="LIVE_TEST_MODE or TEST_LIVE_MODE and OPENAI_API_KEY are required for this test.", ) -def test_list_models_openai_fields(openai_client): +def test_list_models_openai_fields(make_openai_client, monkeypatch): """Verify that well-known OpenAI models appear with expected fields.""" - os.environ.setdefault("MAIN_MODEL_BASE_URL", "https://api.openai.com") - os.environ["MAIN_MODEL_ENGINE"] = "openai" + monkeypatch.setenv("MAIN_MODEL_BASE_URL", os.environ.get("MAIN_MODEL_BASE_URL") or "https://api.openai.com") + monkeypatch.setenv("MAIN_MODEL_ENGINE", "openai") + openai_client = make_openai_client("not-used") models = {m.id: m for m in openai_client.models.list()} # At least one GPT model should be present @@ -351,13 +364,14 @@ def test_list_models_openai_fields(openai_client): @pytest.mark.skipif( - not os.environ.get("ANTHROPIC_API_KEY"), - reason="ANTHROPIC_API_KEY is required for this test.", + not (LIVE_TEST_MODE and os.environ.get("ANTHROPIC_API_KEY")), + reason="LIVE_TEST_MODE or TEST_LIVE_MODE and ANTHROPIC_API_KEY are required for this test.", ) -def test_list_models_anthropic(openai_client): +def test_list_models_anthropic(make_openai_client, monkeypatch): """List models from the Anthropic API.""" - os.environ["MAIN_MODEL_ENGINE"] = "anthropic" + monkeypatch.setenv("MAIN_MODEL_ENGINE", "anthropic") + openai_client = make_openai_client("not-used") models = list(openai_client.models.list()) assert len(models) > 0 @@ -369,13 +383,14 @@ def test_list_models_anthropic(openai_client): @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY"), - reason="COHERE_API_KEY is required for this test.", + not (LIVE_TEST_MODE and os.environ.get("COHERE_API_KEY")), + reason="LIVE_TEST_MODE or TEST_LIVE_MODE and COHERE_API_KEY are required for this test.", ) -def test_list_models_cohere(openai_client): +def test_list_models_cohere(make_openai_client, monkeypatch): """List models from the Cohere API.""" - os.environ["MAIN_MODEL_ENGINE"] = "cohere" + monkeypatch.setenv("MAIN_MODEL_ENGINE", "cohere") + openai_client = make_openai_client("not-used") models = list(openai_client.models.list()) assert len(models) > 0 @@ -387,13 +402,14 @@ def test_list_models_cohere(openai_client): @pytest.mark.skipif( - not (os.environ.get("AZURE_OPENAI_ENDPOINT") and os.environ.get("AZURE_OPENAI_API_KEY")), - reason="AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY are required for this test.", + not (LIVE_TEST_MODE and os.environ.get("AZURE_OPENAI_ENDPOINT") and os.environ.get("AZURE_OPENAI_API_KEY")), + reason="LIVE_TEST_MODE or TEST_LIVE_MODE, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_API_KEY are required.", ) -def test_list_models_azure(openai_client): +def test_list_models_azure(make_openai_client, monkeypatch): """List models from Azure OpenAI.""" - os.environ["MAIN_MODEL_ENGINE"] = "azure" + monkeypatch.setenv("MAIN_MODEL_ENGINE", "azure") + openai_client = make_openai_client("not-used") models = list(openai_client.models.list()) assert len(models) > 0 diff --git a/tests/server/test_schema_utils.py b/tests/server/test_schema_utils.py index 81a7002b03..53f784fd9f 100644 --- a/tests/server/test_schema_utils.py +++ b/tests/server/test_schema_utils.py @@ -582,13 +582,26 @@ async def test_fetch_unknown_engine_no_base_url(): async def test_fetch_auth_forwarded(): """Incoming Authorization header is forwarded for OpenAI-compatible providers.""" mock = _mock_httpx({"data": []}) - with patch.dict(os.environ, {"MAIN_MODEL_BASE_URL": "http://localhost:8000"}): + with patch.dict(os.environ, {"MAIN_MODEL_BASE_URL": "http://localhost:8000", "OPENAI_API_KEY": ""}): with patch("httpx.AsyncClient", return_value=mock): await fetch_models("openai", {"Authorization": "Bearer user-token"}) call_headers = mock.get.call_args.kwargs["headers"] assert call_headers["Authorization"] == "Bearer user-token" +@pytest.mark.asyncio +async def test_fetch_env_auth_precedes_forwarded_auth(): + mock = _mock_httpx({"data": []}) + with patch.dict( + os.environ, + {"MAIN_MODEL_BASE_URL": "http://localhost:8000", "OPENAI_API_KEY": "sk-env-key"}, + ): + with patch("httpx.AsyncClient", return_value=mock): + await fetch_models("openai", {"Authorization": "Bearer not-used"}) + call_headers = mock.get.call_args.kwargs["headers"] + assert call_headers["Authorization"] == "Bearer sk-env-key" + + @pytest.mark.asyncio async def test_fetch_non_dict_items_skipped(): """Non-dict items in the response data are skipped.""" diff --git a/tests/server/test_server_coverage.py b/tests/server/test_server_coverage.py new file mode 100644 index 0000000000..3e2279cb58 --- /dev/null +++ b/tests/server/test_server_coverage.py @@ -0,0 +1,446 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import importlib +import os +import sys +import types +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nemoguardrails.exceptions import StreamingNotSupportedError +from nemoguardrails.server import api +from nemoguardrails.server.datastore import redis_store +from nemoguardrails.server.datastore.datastore import DataStore + + +@pytest.fixture +def reset_api_state(): + old_challenges = list(api.challenges) + old_instances = dict(api.llm_rails_instances) + old_history = dict(api.llm_rails_events_history_cache) + old_datastore = api.datastore + old_loggers = list(api.registered_loggers) + old_attrs = { + "rails_config_path": api.app.rails_config_path, + "single_config_mode": api.app.single_config_mode, + "single_config_id": api.app.single_config_id, + "default_config_id": api.app.default_config_id, + "auto_reload": api.app.auto_reload, + "stop_signal": api.app.stop_signal, + } + yield + api.challenges[:] = old_challenges + api.llm_rails_instances.clear() + api.llm_rails_instances.update(old_instances) + api.llm_rails_events_history_cache.clear() + api.llm_rails_events_history_cache.update(old_history) + api.datastore = old_datastore + api.registered_loggers[:] = old_loggers + for name, value in old_attrs.items(): + setattr(api.app, name, value) + + +@pytest.mark.asyncio +async def test_lifespan_loads_challenges_and_single_config(tmp_path, reset_api_state): + app = api.GuardrailsApp() + app.rails_config_path = str(tmp_path) + (tmp_path / "config.yml").write_text("models: []\n", encoding="utf-8") + (tmp_path / "challenges.json").write_text('[{"name": "c", "content": "prompt"}]', encoding="utf-8") + + with patch("nemoguardrails.telemetry.set_deployment_type") as mock_set_deployment_type: + async with api.lifespan(app): + assert app.single_config_mode is True + assert app.single_config_id == tmp_path.name + assert api.challenges[-1] == {"name": "c", "content": "prompt"} + + mock_set_deployment_type.assert_called_once() + + +@pytest.mark.asyncio +async def test_lifespan_loads_config_py_init(tmp_path, reset_api_state): + app = api.GuardrailsApp() + app.rails_config_path = str(tmp_path) + (tmp_path / "config.py").write_text("def init(app):\n app.loaded_from_config_py = True\n", encoding="utf-8") + + async with api.lifespan(app): + assert app.loaded_from_config_py is True + + +@pytest.mark.asyncio +async def test_lifespan_auto_reload_sets_and_cancels_task(tmp_path, reset_api_state): + app = api.GuardrailsApp() + app.rails_config_path = str(tmp_path) + app.auto_reload = True + task = MagicMock() + loop = MagicMock() + loop.run_in_executor.return_value = task + + with patch("asyncio.get_running_loop", return_value=loop): + async with api.lifespan(app): + assert app.loop == loop + assert app.task == task + + assert app.stop_signal is True + task.cancel.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_rails_cache_and_invalid_single_config(reset_api_state): + cached = SimpleNamespace(events_history_cache={}) + api.llm_rails_instances["cfg"] = cached + assert await api._get_rails(["cfg"]) is cached + + api.app.single_config_mode = True + api.app.single_config_id = "only" + with pytest.raises(ValueError, match="Invalid configuration ids"): + await api._get_rails(["other"]) + + +@pytest.mark.asyncio +async def test_get_rails_rejects_bad_config_ids(tmp_path, reset_api_state): + api.app.rails_config_path = str(tmp_path) + for config_id in ["../outside", "nested/config"]: + with pytest.raises(ValueError): + await api._get_rails([config_id]) + + +@pytest.mark.asyncio +async def test_get_rails_loads_config_updates_model_and_restores_history(tmp_path, reset_api_state): + config_dir = tmp_path / "cfg" + config_dir.mkdir() + api.app.rails_config_path = str(tmp_path) + api.llm_rails_events_history_cache["cfg:override-model"] = {"events": [1]} + config = MagicMock() + config.models = [] + config.model_copy.side_effect = lambda update: SimpleNamespace(models=update["models"]) + rails = SimpleNamespace(events_history_cache={}) + + with ( + patch.object(api.RailsConfig, "from_path", return_value=config) as mock_from_path, + patch.object(api, "LLMRails", return_value=rails) as mock_llm_rails, + patch.dict("os.environ", {"MAIN_MODEL_ENGINE": "openai", "MAIN_MODEL_BASE_URL": "http://model"}), + ): + result = await api._get_rails(["cfg"], model_name="override-model") + + assert result is rails + expected_path = os.path.normpath(os.path.join(os.path.abspath(str(tmp_path)), "cfg")) + mock_from_path.assert_called_once_with(expected_path) + mock_llm_rails.assert_called_once() + assert rails.events_history_cache == {"events": [1]} + assert "cfg:override-model" in api.llm_rails_instances + + +@pytest.mark.asyncio +async def test_format_streaming_response_yields_error_and_done(): + async def stream(): + yield '{"error": {"message": "bad"}}' + yield "ignored" + + chunks = [chunk async for chunk in api._format_streaming_response(stream(), model_name="model")] + + assert '"message": "bad"' in chunks[0] + assert chunks[-1] == "data: [DONE]\n\n" + + +def test_process_chunk_handles_unexpected_validation_error(monkeypatch): + def boom(value): + raise RuntimeError("bad validator") + + monkeypatch.setattr(api.ChunkError, "model_validate_json", boom) + + assert api.process_chunk("plain") == "plain" + + +def test_registration_helpers(reset_api_state): + api.register_challenges([{"name": "one"}]) + assert asyncio.run(api.get_challenges()) == [{"name": "one"}] + + store = object() + logger = object() + api.register_datastore(store) + api.register_logger(logger) + api.set_default_config_id("default") + + assert api.datastore is store + assert logger in api.registered_loggers + assert api.app.default_config_id == "default" + assert isinstance(api.GuardrailsConfigurationError(), Exception) + + +def test_start_auto_reload_monitoring_clears_changed_config_cache(tmp_path, monkeypatch, reset_api_state): + watchdog = types.ModuleType("watchdog") + events = types.ModuleType("watchdog.events") + observers = types.ModuleType("watchdog.observers") + + class FileSystemEventHandler: + pass + + class Observer: + def schedule(self, handler, path, recursive): + self.handler = handler + assert path == api.app.rails_config_path + assert recursive is True + + def start(self): + self.handler.on_any_event(SimpleNamespace(is_directory=True, event_type="modified", src_path="ignored")) + self.handler.on_any_event( + SimpleNamespace(is_directory=False, event_type="modified", src_path=str(tmp_path / "cfg" / ".hidden")) + ) + self.handler.on_any_event( + SimpleNamespace( + is_directory=False, event_type="modified", src_path=str(tmp_path / "cfg" / "config.yml") + ) + ) + + def stop(self): + self.stopped = True + + def join(self): + self.joined = True + + events.FileSystemEventHandler = FileSystemEventHandler + observers.Observer = Observer + monkeypatch.setitem(sys.modules, "watchdog", watchdog) + monkeypatch.setitem(sys.modules, "watchdog.events", events) + monkeypatch.setitem(sys.modules, "watchdog.observers", observers) + (tmp_path / "cfg").mkdir() + (tmp_path / "cfg" / "config.yml").write_text("models: []\n", encoding="utf-8") + api.app.rails_config_path = str(tmp_path) + api.app.stop_signal = False + api.llm_rails_instances["cfg"] = SimpleNamespace(events_history_cache={"cached": True}) + + def stop_loop(seconds): + api.app.stop_signal = True + + monkeypatch.setattr(api.time, "sleep", stop_loop) + + api.start_auto_reload_monitoring() + + assert "cfg" not in api.llm_rails_instances + assert api.llm_rails_events_history_cache["cfg"] == {"cached": True} + + +def test_start_auto_reload_monitoring_import_error(monkeypatch): + original_import = __import__ + + def fake_import(name, *args, **kwargs): + if name.startswith("watchdog"): + raise ImportError("missing") + return original_import(name, *args, **kwargs) + + monkeypatch.setattr("builtins.__import__", fake_import) + monkeypatch.setattr(api.os, "_exit", MagicMock(side_effect=SystemExit(-1))) + + with pytest.raises(SystemExit): + api.start_auto_reload_monitoring() + + +class FakeSession: + def __init__(self): + self.values = {} + + def set(self, key, value): + self.values[key] = value + + def get(self, key): + return self.values.get(key) + + +class FakeMessage: + sent = [] + streamed = [] + updated = [] + + def __init__(self, content=""): + self.content = content + + async def send(self): + self.sent.append(self) + return self + + async def stream_token(self, token): + self.streamed.append(token) + + async def update(self): + self.updated.append(self.content) + + +class FakeChatSettings: + def __init__(self, widgets): + self.widgets = widgets + + async def send(self): + return {"config_id": "cfg2"} + + +class FakeSelect: + def __init__(self, **kwargs): + self.kwargs = kwargs + + +class FakeStarter: + def __init__(self, **kwargs): + self.kwargs = kwargs + + +@pytest.fixture +def server_app_module(monkeypatch): + import nemoguardrails.server.app as server_app + + server_app = importlib.reload(server_app) + FakeMessage.sent = [] + FakeMessage.streamed = [] + FakeMessage.updated = [] + session = FakeSession() + fake_cl = SimpleNamespace( + User=object, + Starter=FakeStarter, + Message=FakeMessage, + ChatSettings=FakeChatSettings, + input_widget=SimpleNamespace(Select=FakeSelect), + user_session=session, + ) + monkeypatch.setattr(server_app, "cl", fake_cl) + old_challenges = list(server_app.challenges) + old_path = server_app.app.rails_config_path + old_single = server_app.app.single_config_mode + old_single_id = server_app.app.single_config_id + old_default = server_app.app.default_config_id + yield server_app, session + server_app.challenges[:] = old_challenges + server_app.app.rails_config_path = old_path + server_app.app.single_config_mode = old_single + server_app.app.single_config_id = old_single_id + server_app.app.default_config_id = old_default + + +@pytest.mark.asyncio +async def test_chat_app_starters_and_config_discovery(tmp_path, server_app_module): + server_app, _ = server_app_module + server_app.challenges[:] = [] + assert await server_app.set_starters() == [] + + server_app.challenges[:] = [{"name": "n", "content": "content", "icon": "i"}] + starters = await server_app.set_starters() + assert starters[0].kwargs == {"label": "n", "message": "content", "icon": "i"} + + server_app.app.single_config_mode = True + server_app.app.single_config_id = "single" + assert server_app._discover_configs() == ["single"] + + server_app.app.single_config_mode = False + server_app.app.rails_config_path = str(tmp_path) + (tmp_path / "cfg").mkdir() + (tmp_path / "cfg" / "config.yml").write_text("models: []\n", encoding="utf-8") + (tmp_path / ".hidden").mkdir() + assert server_app._discover_configs() == ["cfg"] + + server_app.app.rails_config_path = str(tmp_path / "missing") + assert server_app._discover_configs() == [] + + +@pytest.mark.asyncio +async def test_chat_app_start_settings_and_no_config(server_app_module): + server_app, session = server_app_module + with patch.object(server_app, "_discover_configs", return_value=[]): + await server_app.on_chat_start() + assert session.values["config_id"] is None + assert "No guardrails" in FakeMessage.sent[-1].content + + with patch.object(server_app, "_discover_configs", return_value=["cfg1", "cfg2"]): + server_app.app.default_config_id = "cfg1" + await server_app.on_chat_start() + assert session.values["config_id"] == "cfg2" + + await server_app.on_settings_update({"config_id": "cfg1"}) + assert session.values["config_id"] == "cfg1" + assert session.values["messages"] == [] + + +@pytest.mark.asyncio +async def test_chat_app_on_message_paths(server_app_module): + server_app, session = server_app_module + await server_app.on_message(SimpleNamespace(content="hello")) + assert "No guardrails" in FakeMessage.sent[-1].content + + session.set("config_id", "cfg") + session.set("messages", []) + with patch.object(server_app, "_get_rails", AsyncMock(side_effect=RuntimeError("bad"))): + await server_app.on_message(SimpleNamespace(content="hello")) + assert session.get("messages") == [] + assert "Error loading" in FakeMessage.sent[-1].content + + async def stream_success(messages): + yield "hi" + yield " there" + + rails = SimpleNamespace(stream_async=stream_success) + with patch.object(server_app, "_get_rails", AsyncMock(return_value=rails)): + await server_app.on_message(SimpleNamespace(content="hello")) + assert FakeMessage.streamed[-2:] == ["hi", " there"] + assert session.get("messages")[-1] == {"role": "assistant", "content": "hi there"} + + async def stream_unsupported(messages): + raise StreamingNotSupportedError("no stream") + yield "never" + + rails = SimpleNamespace( + stream_async=stream_unsupported, generate_async=AsyncMock(return_value={"content": "fallback"}) + ) + with patch.object(server_app, "_get_rails", AsyncMock(return_value=rails)): + await server_app.on_message(SimpleNamespace(content="again")) + assert FakeMessage.updated[-1] == "fallback" + + async def stream_error(messages): + raise RuntimeError("boom") + yield "never" + + rails = SimpleNamespace(stream_async=stream_error) + with patch.object(server_app, "_get_rails", AsyncMock(return_value=rails)): + await server_app.on_message(SimpleNamespace(content="bad")) + assert "An error occurred" in FakeMessage.updated[-1] + + +@pytest.mark.asyncio +async def test_datastore_base_methods_raise(): + store = DataStore() + with pytest.raises(NotImplementedError): + await store.set("key", "value") + with pytest.raises(NotImplementedError): + await store.get("key") + + +@pytest.mark.asyncio +async def test_redis_store_import_error_and_client_calls(monkeypatch): + monkeypatch.setattr(redis_store, "aioredis", None) + with pytest.raises(ImportError, match="aioredis is required"): + redis_store.RedisStore("redis://localhost") + + client = AsyncMock() + fake_aioredis = SimpleNamespace(from_url=MagicMock(return_value=client)) + monkeypatch.setattr(redis_store, "aioredis", fake_aioredis) + store = redis_store.RedisStore("redis://localhost", username="user", password="pass") + await store.set("key", "value") + client.get.return_value = "value" + assert await store.get("key") == "value" + fake_aioredis.from_url.assert_called_once_with( + url="redis://localhost", username="user", password="pass", decode_responses=True + ) + client.set.assert_awaited_once_with("key", "value") + client.get.assert_awaited_once_with("key") diff --git a/tests/test_actions_math.py b/tests/test_actions_math.py new file mode 100644 index 0000000000..a99c85148a --- /dev/null +++ b/tests/test_actions_math.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import nemoguardrails.actions.math as math_actions +from nemoguardrails.actions.actions import ActionResult + + +class FakeWolframResponse: + def __init__(self, status, body): + self.status = status + self.body = body + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def text(self): + return self.body + + +class FakeWolframSession: + def __init__(self, response): + self.response = response + self.url = None + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def get(self, url): + self.url = url + return self.response + + +@pytest.mark.asyncio +async def test_wolfram_alpha_requires_query(monkeypatch): + monkeypatch.setattr(math_actions, "APP_ID", "app-id") + + with pytest.raises(Exception, match="No query was provided"): + await math_actions.wolfram_alpha_request() + + +@pytest.mark.asyncio +async def test_wolfram_alpha_app_id_missing(monkeypatch): + monkeypatch.setattr(math_actions, "APP_ID", None) + + result = await math_actions.wolfram_alpha_request(context={"last_user_message": "2+2"}) + + assert isinstance(result, ActionResult) + assert result.return_value is False + assert [event["intent"] for event in result.events if event["type"] == "BotIntent"] == [ + "inform wolfram alpha app id not set", + "stop", + ] + assert [event["script"] for event in result.events if event["type"] == "StartUtteranceBotAction"] == [ + "Wolfram Alpha app ID is not set. Please set the WOLFRAM_ALPHA_APP_ID environment variable.", + ] + + +@pytest.mark.asyncio +async def test_wolfram_alpha_success(monkeypatch): + monkeypatch.setattr(math_actions, "APP_ID", "app-id") + monkeypatch.setattr(math_actions, "API_URL_BASE", "https://example.test/v2/result?appid=app-id") + session = FakeWolframSession(FakeWolframResponse(200, "4")) + monkeypatch.setattr(math_actions.aiohttp, "ClientSession", lambda: session) + + result = await math_actions.wolfram_alpha_request("2 + 2") + + assert result == "4" + assert session.url == "https://example.test/v2/result?appid=app-id&i=2+%2B+2" + + +@pytest.mark.asyncio +async def test_wolfram_alpha_non_200_returns_action_result(monkeypatch): + monkeypatch.setattr(math_actions, "APP_ID", "app-id") + monkeypatch.setattr(math_actions, "API_URL_BASE", "https://example.test/v2/result?appid=app-id") + session = FakeWolframSession(FakeWolframResponse(500, "error")) + monkeypatch.setattr(math_actions.aiohttp, "ClientSession", lambda: session) + + result = await math_actions.wolfram_alpha_request("integrate x") + + assert isinstance(result, ActionResult) + assert result.return_value is False + assert [event["intent"] for event in result.events if event["type"] == "BotIntent"] == [ + "inform wolfram alpha not working", + "stop", + ] + assert [event["script"] for event in result.events if event["type"] == "StartUtteranceBotAction"] == [ + "Apologies, but I cannot answer this question at this time. " + "I am having trouble getting the answer from Wolfram Alpha.", + ] diff --git a/tests/test_actions_validation.py b/tests/test_actions_validation.py index 357a88d16d..23f3f0a423 100644 --- a/tests/test_actions_validation.py +++ b/tests/test_actions_validation.py @@ -13,9 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import builtins +import sys +from types import SimpleNamespace + import pytest from nemoguardrails.actions.validation import validate_input, validate_response +from nemoguardrails.actions.validation.filter_secrets import contains_secrets @validate_input("name", validators=["length"], max_len=100) @@ -71,3 +76,36 @@ def test_cls_validation(): # length is smaller than max len validation assert s_name.run(name="IP 10.40.139.92 should be trimmed") == "IP should be trimmed" + + +def test_contains_secrets_detects_scan_result(monkeypatch): + class DefaultSettings: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + fake_detect_secrets = SimpleNamespace( + settings=SimpleNamespace(default_settings=DefaultSettings), + scan_adhoc_string=lambda resp: resp, + ) + monkeypatch.setitem(sys.modules, "detect_secrets", fake_detect_secrets) + + assert contains_secrets("AWSKeyDetector: False\nTokenDetector: True") is True + assert contains_secrets("AWSKeyDetector: False\nTokenDetector: False") is False + + +def test_contains_secrets_missing_dependency(monkeypatch): + original_import = builtins.__import__ + monkeypatch.delitem(sys.modules, "detect_secrets", raising=False) + + def fake_import(name, *args, **kwargs): + if name == "detect_secrets": + raise ModuleNotFoundError(name) + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + with pytest.raises(ValueError, match="Could not import detect_secrets"): + contains_secrets("secret") diff --git a/tests/v2_x/test_generation_actions_unit.py b/tests/v2_x/test_generation_actions_unit.py new file mode 100644 index 0000000000..1886757073 --- /dev/null +++ b/tests/v2_x/test_generation_actions_unit.py @@ -0,0 +1,397 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from nemoguardrails.actions.v2_x import generation as v2_generation +from nemoguardrails.colang.v2_x.lang.colang_ast import Flow, SpecOp +from nemoguardrails.colang.v2_x.runtime.errors import LlmResponseError +from nemoguardrails.colang.v2_x.runtime.flows import InternalEvents +from nemoguardrails.context import generation_options_var, raw_llm_request +from nemoguardrails.rails.llm.options import GenerationOptions + + +class FakeIndex: + def __init__(self, results=None): + self.results = results or [] + self.items = [] + self.built = False + self.search_calls = [] + + async def add_items(self, items): + self.items.extend(items) + + async def build(self): + self.built = True + + async def search(self, **kwargs): + self.search_calls.append(kwargs) + return self.results + + +class FakeTaskManager: + def __init__(self, prompt="prompt"): + self.prompt = prompt + self.rendered_prompts = [] + self.parsed_outputs = [] + self.rendered_strings = [] + + def render_task_prompt(self, task, events, context): + self.rendered_prompts.append((task, events, context)) + return self.prompt + + def get_stop_tokens(self, task): + return ["STOP"] + + def parse_task_output(self, task, output): + self.parsed_outputs.append((task, output)) + return output + + def _render_string(self, text, context, events): + self.rendered_strings.append((text, context, events)) + return text.format(**context) + + +class RuntimeFlowConfig: + def __init__(self, flow_id, source_code="", elements=None, decorators=None, has_user_intent=False): + self.id = flow_id + self.source_code = source_code + self.elements = elements or [] + self.decorators = decorators or {} + self.has_user_intent = has_user_intent + + def has_meta_tag(self, tag): + return tag == "user_intent" and self.has_user_intent + + +def make_actions(flows=None, prompt="prompt"): + actions = v2_generation.LLMGenerationActionsV2dotx.__new__(v2_generation.LLMGenerationActionsV2dotx) + user_messages = SimpleNamespace( + embeddings_only=False, + embeddings_only_similarity_threshold=0.7, + embeddings_only_fallback_intent=None, + ) + actions.config = SimpleNamespace( + core=SimpleNamespace(embedding_search_provider=None), + flows=flows or [], + rails=SimpleNamespace(dialog=SimpleNamespace(user_messages=user_messages)), + lowest_temperature=0.0, + ) + actions.llm = object() + actions.llm_task_manager = FakeTaskManager(prompt=prompt) + actions.user_message_index = None + actions.flows_index = None + actions.instruction_flows_index = None + actions._init_lock = asyncio.Lock() + actions._last_docstring = "Fallback for {name}" + actions.get_embedding_search_provider_instance = lambda provider: FakeIndex() + return actions + + +def patch_llm_call(monkeypatch, *outputs): + calls = [] + pending = list(outputs) + + async def fake_llm_call(llm, prompt, **kwargs): + calls.append((llm, prompt, kwargs)) + return SimpleNamespace(content=pending.pop(0)) + + monkeypatch.setattr(v2_generation, "llm_call", fake_llm_call) + return calls + + +@pytest.mark.asyncio +async def test_init_flows_index_builds_instruction_index(): + flows = [ + Flow(name="included", source_code="flow included\n # instruction\n bot say hi"), + Flow(name="regular", source_code="flow regular\n bot say hi"), + Flow(name="excluded", source_code="flow excluded\n bot say no", file_info={"exclude_from_llm": True}), + ] + actions = make_actions(flows=flows) + indexes = [] + + def make_index(provider): + index = FakeIndex() + indexes.append(index) + return index + + actions.get_embedding_search_provider_instance = make_index + + await actions._init_flows_index() + + assert [item.text for item in indexes[0].items] == [ + "flow included\n # instruction\n bot say hi", + "flow regular\n bot say hi", + ] + assert [item.text for item in indexes[1].items] == ["flow included\n # instruction\n bot say hi"] + assert actions.flows_index is indexes[0] + assert actions.instruction_flows_index is indexes[1] + + +@pytest.mark.asyncio +async def test_collect_user_intent_examples_from_index_and_active_match(monkeypatch): + actions = make_actions() + actions.user_message_index = FakeIndex( + [ + SimpleNamespace(text="hello", meta={"intent": "user greet"}), + SimpleNamespace(text="help", meta={"intent": "user ask help"}), + ] + ) + doc_element = { + "_type": "doc_string_stmt", + "elements": [{"elements": [{"elements": ['"""documented utterance"""']}]}], + } + active_config = RuntimeFlowConfig("documented", elements=[{}, doc_element], has_user_intent=True) + state = SimpleNamespace( + flow_states={"head": object(), "expected": object()}, + flow_configs={"documented": active_config}, + flow_id_states={"documented": [SimpleNamespace(context={})]}, + ) + heads = [SimpleNamespace(flow_state_uid="head"), SimpleNamespace(flow_state_uid="expected")] + + monkeypatch.setattr(v2_generation, "find_all_active_event_matchers", lambda state, event=None: heads) + monkeypatch.setattr(v2_generation, "get_element_from_head", lambda state, head: SpecOp(op="match")) + + def get_event_from_element(state, flow_state, element): + flow_id = "documented" if flow_state is state.flow_states["head"] else "expected only" + return SimpleNamespace(name=InternalEvents.FLOW_FINISHED, arguments={"flow_id": flow_id}) + + monkeypatch.setattr(v2_generation, "get_event_from_element", get_event_from_element) + + intents, examples, is_embedding_only = await actions._collect_user_intent_and_examples(state, "hi", 3) + + assert intents == ["user ask help", "user greet", "expected only"] + assert 'user action: user said "help"' in examples + assert "user action: " in examples + assert "user intent: expected only" in examples + assert is_embedding_only is False + + +@pytest.mark.asyncio +async def test_generate_user_intent_embedding_only_and_llm(monkeypatch): + actions = make_actions() + state = SimpleNamespace(context={"topic": "support"}) + actions._collect_user_intent_and_examples = AsyncMock(return_value=(["user cached intent"], "", True)) + + assert await actions.generate_user_intent(state, [], "hello") == "user cached intent" + + actions._collect_user_intent_and_examples = AsyncMock(return_value=(["user ask"], "examples", False)) + calls = patch_llm_call(monkeypatch, "user intent: ask about account") + + assert await actions.generate_user_intent(state, [{"type": "event"}], "help") == "ask about account" + assert calls[0][2]["llm_params"] == {"temperature": 0.0} + assert actions.llm_task_manager.rendered_prompts[-1][2]["potential_user_intents"] == "user ask" + + +@pytest.mark.asyncio +async def test_generate_user_intent_and_bot_action_success_and_error(monkeypatch): + actions = make_actions() + state = SimpleNamespace(context={}) + actions._collect_user_intent_and_examples = AsyncMock(return_value=(["user ask"], "", False)) + patch_llm_call( + monkeypatch, + 'user intent: ask help\nbot intent: provide help\nbot action: bot say "Here"', + "assistant response only", + ) + + result = await actions.generate_user_intent_and_bot_action(state, [], "help") + + assert result == { + "user_intent": "ask help", + "bot_intent": "provide help", + "bot_action": 'bot say "Here"', + } + monkeypatch.setattr(v2_generation, "get_first_bot_action", lambda lines: None) + with pytest.raises(LlmResponseError): + await actions.generate_user_intent_and_bot_action(state, [], "help") + + +@pytest.mark.asyncio +async def test_passthrough_llm_action_branches(monkeypatch): + actions = make_actions() + events = [{"type": "UtteranceUserActionFinished", "final_transcript": "rewritten"}] + + with pytest.raises(RuntimeError, match="No LLM provided"): + await actions.passthrough_llm_action("message", SimpleNamespace(), events) + + with pytest.raises(RuntimeError, match="couldn't find last user utterance"): + await actions.passthrough_llm_action("message", SimpleNamespace(), [], llm=object()) + + calls = patch_llm_call(monkeypatch, "parsed response") + raw_token = raw_llm_request.set([{"role": "user", "content": "original"}]) + options_token = generation_options_var.set(GenerationOptions(llm_params={"top_p": 0.2})) + try: + result = await actions.passthrough_llm_action("message", SimpleNamespace(), events, llm=object()) + finally: + raw_llm_request.reset(raw_token) + generation_options_var.reset(options_token) + + assert result == "parsed response" + assert calls[0][1] == "message" + assert calls[0][2]["llm_params"] == {"top_p": 0.2} + + +@pytest.mark.asyncio +async def test_check_flow_helpers(monkeypatch): + actions = make_actions() + state = SimpleNamespace(flow_id_states={"known": []}, flow_configs={"defined": object()}) + + assert await actions.check_if_flow_exists(state, "known") is True + assert await actions.check_if_flow_exists(state, "missing") is False + assert await actions.check_if_flow_defined(state, "defined") is True + assert await actions.check_if_flow_defined(state, "missing") is False + + captured = [] + monkeypatch.setattr( + v2_generation, "find_all_active_event_matchers", lambda state, event: captured.append(event) or [object()] + ) + + assert await actions.check_for_active_flow_finished_match(state, InternalEvents.FLOW_FINISHED, flow_id="x") is True + assert await actions.check_for_active_flow_finished_match(state, "SomeActionFinished", uid="x") is True + assert await actions.check_for_active_flow_finished_match(state, "ExternalEvent", uid="x") is True + assert [event.name for event in captured] == [InternalEvents.FLOW_FINISHED, "SomeActionFinished", "ExternalEvent"] + + +@pytest.mark.asyncio +async def test_generate_flow_from_instructions_success_and_fallback(monkeypatch): + actions = make_actions() + actions.instruction_flows_index = FakeIndex([SimpleNamespace(meta={"flow": "flow example\n bot say hi"})]) + state = SimpleNamespace(context={"name": "Ada"}) + monkeypatch.setattr(v2_generation, "new_uuid", lambda: "abcd1234") + patch_llm_call(monkeypatch, '\n bot say "ok"', "bot say missing indent") + + generated = await actions.generate_flow_from_instructions(state, "say ok", []) + fallback = await actions.generate_flow_from_instructions(state, "say ok", []) + + assert generated == {"name": "dynamic_abcd", "body": 'flow dynamic_abcd\n bot say "ok"'} + assert fallback["name"] == "bot inform LLM issue" + assert "GenerateFlowFromInstructionsAction" in fallback["body"] + + +@pytest.mark.asyncio +async def test_generate_flow_from_name_success_variants(monkeypatch): + actions = make_actions() + actions.flows_index = FakeIndex() + actions.instruction_flows_index = FakeIndex([SimpleNamespace(meta={"flow": "flow sample\n bot say hi"})]) + state = SimpleNamespace(context={}) + patch_llm_call(monkeypatch, 'flow generated\n bot say "hello"', 'bot say "hello"') + + assert await actions.generate_flow_from_name(state, "generated", []) == 'flow generated\n bot say "hello"' + assert await actions.generate_flow_from_name(state, "fallback", []) == 'flow fallback\n bot say "hello"' + + actions.flows_index = None + with pytest.raises(RuntimeError, match="No flows index"): + await actions.generate_flow_from_name(state, "missing", []) + + +@pytest.mark.asyncio +async def test_generate_flow_continuation_success_fallback_and_error(monkeypatch): + actions = make_actions() + actions.flows_index = FakeIndex([SimpleNamespace(meta={"flow": "flow example\n bot say hi # remove"})]) + actions.instruction_flows_index = FakeIndex() + state = SimpleNamespace(context={}) + monkeypatch.setattr(v2_generation, "colang", lambda events: "user said hi\nbot line") + monkeypatch.setattr(v2_generation, "new_uuid", lambda: "12345678abcd") + patch_llm_call( + monkeypatch, + 'bot intent: provide answer\nbot action: bot say "Answer"', + "\n", + "assistant response only", + ) + + generated = await actions.generate_flow_continuation(state, [], temperature=0.3) + fallback = await actions.generate_flow_continuation(state, []) + + assert generated["name"] == "_dynamic_12345678 provide answer" + assert generated["parameters"] == [] + assert 'bot say "Answer"' in generated["body"] + assert fallback["name"] == "bot inform LLM issue" + monkeypatch.setattr(v2_generation, "get_first_bot_action", lambda lines: None) + with pytest.raises(LlmResponseError): + await actions.generate_flow_continuation(state, []) + + +@pytest.mark.asyncio +async def test_create_flow_escapes_name_and_applies_decorators(monkeypatch): + actions = make_actions() + monkeypatch.setattr(v2_generation, "new_uuid", lambda: "abcdef123456") + + result = await actions.create_flow([], "bot greet", 'bot say "hello"', decorators="@active") + + assert result == { + "name": "_dynamic_abcdef12 bot greet", + "parameters": [], + "body": '@active\nflow _dynamic_abcdef12 bot greet\n bot say "hello"', + } + + +@pytest.mark.asyncio +async def test_generate_value_parses_prompt_variants_and_errors(monkeypatch): + actions = make_actions(prompt="value =") + actions.flows_index = FakeIndex( + [ + SimpleNamespace(text="flow example\n $value = 1"), + SimpleNamespace(text="flow ignored\n GenerateValueAction()"), + ] + ) + state = SimpleNamespace(context={}) + patch_llm_call(monkeypatch, "value = {'ok': True};", "$answer = ['a'];", "not python") + + assert await actions.generate_value(state, "make dict", [], var_name="value") == {"ok": True} + + actions.llm_task_manager.prompt = [{"role": "user", "content": "$answer = "}] + assert await actions.generate_value(state, "make list", [], var_name="answer") == ["a"] + + with pytest.raises(Exception, match="Invalid LLM response"): + await actions.generate_value(state, "make bad", [], var_name="bad") + + +@pytest.mark.asyncio +async def test_generate_flow_success_and_error_paths(monkeypatch): + actions = make_actions() + trigger_config = RuntimeFlowConfig( + "trigger", + source_code='flow trigger\n """Help {name} using {tool_names}.\n{tools}"""\n ...', + ) + tool_config = RuntimeFlowConfig( + "lookup", + source_code='@meta(tool=True)\nflow lookup $query\n """Lookup docs"""\n await LookupAction()', + decorators={"meta": {"tool": True}}, + ) + state = SimpleNamespace( + context={"name": "Ada"}, + flow_configs={"trigger": trigger_config, "lookup": tool_config}, + flow_id_states={"trigger": [SimpleNamespace(context={"name": "Ada Lovelace"})]}, + ) + monkeypatch.setattr(v2_generation, "new_uuid", lambda: "fedcba987654") + patch_llm_call(monkeypatch, 'codeblock\nbot say "hello"\n user said something') + + result = await actions.generate_flow(state, [], flow_id="trigger") + + assert result["name"] == "_dynamic_fedcba98" + assert result["parameters"] == [] + assert result["body"] == 'flow _dynamic_fedcba98\n bot say "hello"\n wait user input\n ...' + assert "`lookup`" in actions.llm_task_manager.rendered_strings[-1][1]["tool_names"] + + with pytest.raises(RuntimeError, match="No flow_id"): + await actions.generate_flow(state, []) + + state.flow_configs["empty"] = RuntimeFlowConfig("empty", source_code="") + with pytest.raises(RuntimeError, match="No source_code"): + await actions.generate_flow(state, [], flow_id="empty")