-
Notifications
You must be signed in to change notification settings - Fork 739
feat(iorails): Add checks to run input or output rails only #2059
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,7 +62,7 @@ | |
| from nemoguardrails.llm.taskmanager import LLMTaskManager | ||
| from nemoguardrails.rails.llm.buffer import get_buffer_strategy | ||
| from nemoguardrails.rails.llm.config import RailsConfig, _get_flow_name | ||
| from nemoguardrails.rails.llm.options import GenerationOptions | ||
| from nemoguardrails.rails.llm.options import GenerationOptions, RailsResult, RailStatus, RailType | ||
| from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler | ||
| from nemoguardrails.tracing.constants import GuardrailsAttributes | ||
| from nemoguardrails.types import LLMModel, LLMResponse, ToolCall | ||
|
|
@@ -160,6 +160,44 @@ def _build_assistant_message(content: str, tool_calls: Optional[list[ToolCall]]) | |
| } | ||
|
|
||
|
|
||
| # TODO: _determine_rails_from_messages and _get_last_content_by_role are duplicated | ||
| # from nemoguardrails.rails.llm.llmrails. They should move to a shared checks-helper | ||
| # module that both engines import, rather than IORails depending on the heavy LLMRails | ||
| # module. Tracked for a future refactor. | ||
| def _determine_rails_from_messages(messages: list[dict]) -> Optional[dict]: | ||
| """Pick which rails to run from message roles. | ||
|
|
||
| user-only -> input, assistant-only -> output, both -> input and output. | ||
| Returns ``{"rails": [...]}`` or ``None`` when there is no user/assistant | ||
| message to check. | ||
| """ | ||
| roles = {msg.get("role") for msg in reversed(messages)} | ||
| has_user = "user" in roles | ||
| has_assistant = "assistant" in roles | ||
|
|
||
| if not has_user and not has_assistant: | ||
| log.warning( | ||
| "check() called with no user or assistant messages. " | ||
| "Only system, context, or tool messages found. " | ||
| "Returning passing result without running rails." | ||
| ) | ||
| return None | ||
|
|
||
| if has_user and has_assistant: | ||
| return {"rails": ["input", "output"]} | ||
| if has_user: | ||
| return {"rails": ["input"]} | ||
| return {"rails": ["output"]} | ||
|
|
||
|
|
||
| def _get_last_content_by_role(messages: list[dict], role: str) -> str: | ||
| """Return the content of the last message with the given role, or "".""" | ||
| for msg in reversed(messages): | ||
| if msg.get("role") == role: | ||
| return msg.get("content", "") | ||
| return "" | ||
|
|
||
|
|
||
| class IORails(BaseGuardrails): | ||
| """Workflow engine for accelerated Input/Output rails inference.""" | ||
|
|
||
|
|
@@ -572,6 +610,119 @@ async def _parallel_input_rail_and_response_generation( | |
| log.debug("[%s] Main LLM response: %s", req_id, truncate(response.content)) | ||
| return response | ||
|
|
||
def check(self, messages: LLMMessages, rail_types: Optional[list[RailType]] = None) -> RailsResult:
    """Blocking counterpart of ``check_async``.

    Mirrors ``generate``: a deep copy of this engine's config is made with
    tracing and metrics switched off, a short-lived ``IORails`` engine is
    built from it, and the check runs on that engine via ``asyncio.run``.
    Prefer the asynchronous ``check_async`` for production use.

    Args:
        messages: Messages to check.
        rail_types: Rail types to run, or ``None`` to let ``check_async``
            decide.

    Returns:
        The ``RailsResult`` produced by the temporary engine's
        ``check_async``.
    """
    # Deep copy so that disabling telemetry never touches ``self.config``.
    ephemeral_config = self.config.model_copy(deep=True)

    tracing_cfg = ephemeral_config.tracing
    if tracing_cfg is not None:
        tracing_cfg.enabled = False

    metrics_cfg = ephemeral_config.metrics
    if metrics_cfg is not None:
        metrics_cfg.enabled = False

    async def _check_once():
        """Run a single check on a throwaway engine, then tear it down."""
        async with IORails(ephemeral_config, _report_usage=False) as engine:
            return await engine.check_async(messages, rail_types=rail_types)

    return asyncio.run(_check_once())
|
|
||
async def check_async(self, messages: LLMMessages, rail_types: Optional[list[RailType]] = None) -> RailsResult:
    """Run input and/or output rails on messages without main-LLM generation.

    With ``rail_types=None`` the rails are chosen from the message roles
    (user-only -> input, assistant-only -> output, both -> input and
    output). Otherwise exactly the named rail types run.

    The work is submitted to ``self._generate_async_queue`` (the queue used
    for non-streaming generation), so checks share its concurrency limits,
    request metrics, and per-request trace span.

    Args:
        messages: Messages to check.
        rail_types: Rail types to run, or ``None`` for auto-detection.

    Returns:
        The ``RailsResult`` produced by ``_run_check``.

    Raises:
        asyncio.QueueFull: Re-raised when the queue rejects the request; a
            rejection is recorded first if metrics are enabled.
    """
    await self.start()

    if self._metrics_enabled:
        metrics_scope = request_metrics()
    else:
        metrics_scope = nullcontext()

    with metrics_scope:
        try:
            outcome = await self._generate_async_queue.submit(self._run_check, messages, rail_types)
        except asyncio.QueueFull:
            if self._metrics_enabled:
                record_nonstream_rejected()
            raise
        return outcome
|
|
||
async def _run_check(self, messages: LLMMessages, rail_types: Optional[list[RailType]]) -> RailsResult:
    """Queue-worker entry for ``check_async``: wrap the rails in a request span.

    Opens a ``traced_request`` (with no tracer when tracing is disabled),
    delegates to ``_do_check``, and logs the elapsed time on both the
    failure and success paths.

    Args:
        messages: Messages to check.
        rail_types: Rail types to run, or ``None`` for auto-detection.

    Returns:
        The ``RailsResult`` from ``_do_check``.

    Raises:
        Exception: Anything raised by ``_do_check`` is logged and re-raised.
    """
    active_tracer = None
    if self._tracing_enabled:
        active_tracer = self._tracer

    with traced_request(active_tracer) as (request_span, req_id):
        started = time.monotonic()

        def _elapsed_ms() -> float:
            """Milliseconds elapsed since the check started."""
            return (time.monotonic() - started) * 1000

        try:
            result = await self._do_check(messages, rail_types, req_id)
        except Exception:
            log.error("[%s] check_async failed time=%.1fms", req_id, _elapsed_ms(), exc_info=True)
            raise

        # Content is attached to the span only when capture is enabled.
        if self._content_capture_enabled:
            set_request_content(request_span, messages, result.content)

        log.info(
            "[%s] check_async completed time=%.1fms status=%s",
            req_id,
            _elapsed_ms(),
            result.status.value,
        )
        return result
|
|
||
async def _do_check(
    self,
    messages: LLMMessages,
    rail_types: Optional[list[RailType]],
    req_id: str,
) -> RailsResult:
    """Core check pipeline: run the requested input/output rails on messages.

    Args:
        messages: Messages to check.
        rail_types: Rail types to run. When ``None`` the rails are derived
            from the message roles via ``_determine_rails_from_messages``.
        req_id: Request identifier used as a prefix in log lines.

    Returns:
        * ``PASSED`` with the last message's content when auto-detection
          finds no user/assistant message.
        * ``BLOCKED`` with ``REFUSAL_MESSAGE`` and the triggering rail when an
          input or output rail reports the content as unsafe.
        * ``PASSED`` otherwise, carrying the checked content unchanged (the
          last assistant content when output rails were requested, else the
          last user content).
    """
    log.info("[%s] check_async called", req_id)
    log.debug("[%s] check_async messages=%s", req_id, truncate(messages))

    if rail_types is not None:
        rails_to_run = [rail_type.value for rail_type in rail_types]
    else:
        determined = _determine_rails_from_messages(messages)
        if determined is None:
            last_content = messages[-1].get("content", "") if messages else ""
            return RailsResult(status=RailStatus.PASSED, content=last_content)
        rails_to_run = determined["rails"]

    run_input = "input" in rails_to_run
    run_output = "output" in rails_to_run

    # Look the assistant content up once; it serves both as the output-rail
    # input and as the content reported on a pass. ``or ""`` normalises a
    # missing/None content to the empty string.
    bot_response = (_get_last_content_by_role(messages, "assistant") or "") if run_output else ""

    # Content reported on a pass. IORails rails only block or pass; they never
    # rewrite content, so a passing check returns the checked content unchanged
    # (RailStatus.MODIFIED never occurs for IORails).
    if run_output:
        pass_content = bot_response
    else:
        pass_content = _get_last_content_by_role(messages, "user")

    if run_input:
        log.info("[%s] Running input rails", req_id)
        input_result = await self.rails_manager.is_input_safe(messages)
        if not input_result.is_safe:
            log.info("[%s] Input blocked: %s", req_id, input_result.reason)
            if self._metrics_enabled:
                record_request_blocked(RailDirection.INPUT)
            return RailsResult(status=RailStatus.BLOCKED, content=REFUSAL_MESSAGE, rail=input_result.triggered_rail)

    if run_output:
        if not bot_response:
            # Output rails requested but there is no assistant content to
            # check (e.g. rail_types=[RailType.OUTPUT] with user-only
            # messages). Code review reported that handing an empty response
            # to the output rails comes back as "unsafe", i.e. a BLOCKED
            # verdict indistinguishable from a real safety decision. Skip the
            # rails and make the situation visible in the logs instead.
            log.warning(
                "[%s] Output rails requested but no assistant content found; skipping output rails",
                req_id,
            )
        else:
            log.info("[%s] Running output rails", req_id)
            output_result = await self.rails_manager.is_output_safe(messages, bot_response)
            if not output_result.is_safe:
                log.info("[%s] Output blocked: %s", req_id, output_result.reason)
                if self._metrics_enabled:
                    record_request_blocked(RailDirection.OUTPUT)
                return RailsResult(
                    status=RailStatus.BLOCKED, content=REFUSAL_MESSAGE, rail=output_result.triggered_rail
                )

    return RailsResult(status=RailStatus.PASSED, content=pass_content)
|
|
||
| def _validate_streaming_with_output_rails(self) -> None: | ||
| """Raise if output rails exist but streaming is not enabled for them.""" | ||
| if len(self.config.rails.output.flows) > 0 and not self._has_streaming_output_rails: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `reversed()` call inside a set comprehension is a no-op — set construction is order-independent — and just allocates an extra iterator. A plain `for msg in messages` produces an identical set. Prompt To Fix With AI
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!