diff --git a/nemoguardrails/library/regex/actions.py b/nemoguardrails/library/regex/actions.py index 5b47478ec8..f0d059b04e 100644 --- a/nemoguardrails/library/regex/actions.py +++ b/nemoguardrails/library/regex/actions.py @@ -33,6 +33,28 @@ def _regex_blocked_mapping(result: RegexDetectionResult) -> bool: return result.get("is_match", False) +def _get_regex_options(source: str, config: RailsConfig): + """Return the RegexDetectionOptions for *source*, or None with a warning.""" + if source not in ("input", "output", "retrieval"): + raise ValueError("source must be one of 'input', 'output', or 'retrieval'") + + regex_config = config.rails.config.regex_detection + if regex_config is None: + log.warning("No regex_detection configuration found.") + return None + + options = getattr(regex_config, source, None) + if options is None: + log.warning("No regex rails configuration found for source: %s", source) + return None + + if not options.compiled_patterns: + log.debug("No regex patterns specified for source: %s", source) + return None + + return options + + @action(is_system_action=True, output_mapping=_regex_blocked_mapping) async def detect_regex_pattern( source: str, @@ -53,23 +75,8 @@ async def detect_regex_pattern( - text (str): The original text that was checked. - detections (List[str]): List of pattern strings that matched. """ - if source not in ("input", "output", "retrieval"): - raise ValueError("source must be one of 'input', 'output', or 'retrieval'") - - regex_config = config.rails.config.regex_detection - if regex_config is None: - log.warning("No regex_detection configuration found.") - return RegexDetectionResult(is_match=False, text=text, detections=[]) - - options = getattr(regex_config, source, None) - + options = _get_regex_options(source, config) if options is None: - log.warning("No regex rails configuration found for source: %s", source) - return RegexDetectionResult(is_match=False, text=text, detections=[]) - - compiled_patterns = options.compiled_patterns - if not compiled_patterns: - log.debug("No regex patterns specified for source: %s", source) return RegexDetectionResult(is_match=False, text=text, detections=[]) if not text: @@ -78,12 +85,49 @@ async def detect_regex_pattern( # Match against pre-compiled patterns and collect all matches. matched: List[str] = [] - for compiled, raw_pattern in zip(compiled_patterns, options.patterns): + for compiled, pcfg in zip(options.compiled_patterns, options.normalized_patterns): if compiled.search(text): - log.info("Regex pattern matched: %s", raw_pattern) - matched.append(raw_pattern) + log.info("Regex pattern matched: %s", pcfg.pattern) + matched.append(pcfg.pattern) if matched: return RegexDetectionResult(is_match=True, text=text, detections=matched) return RegexDetectionResult(is_match=False, text=text, detections=[]) + + +@action(is_system_action=True) +async def redact_regex_pattern( + source: str, + text: str, + config: RailsConfig, + **kwargs, +) -> str: + """Replace all regex-matched spans with the configured mask token. + + Args: + source: The source for the text, i.e. "input", "output", "retrieval". + text: The text to redact. + config: The rails configuration object. + + Returns: + The text with every match of every configured pattern replaced by + the mask_token (default ````). + """ + options = _get_regex_options(source, config) + if options is None: + return text + + if not text: + log.debug("Empty text provided, skipping regex redaction.") + return text + + redacted = text + for compiled, pcfg in zip(options.compiled_patterns, options.normalized_patterns): + if compiled.search(redacted): + log.info("Regex pattern redacted: %s", pcfg.pattern) + mask = pcfg.mask_token + # use a lambda to ensure the mask token is treated as a literal string, not a regex + redacted = compiled.sub(lambda _: mask, redacted) + + return redacted diff --git a/nemoguardrails/library/regex/flows.co b/nemoguardrails/library/regex/flows.co index 6ac5c8f666..b5466d4568 100644 --- a/nemoguardrails/library/regex/flows.co +++ b/nemoguardrails/library/regex/flows.co @@ -2,32 +2,50 @@ flow regex check input """Check if the user input matches any forbidden regex patterns.""" - $result = await DetectRegexMatchAction(source="input", text=$user_message) + $result = await DetectRegexPatternAction(source="input", text=$user_message) if $result["is_match"] bot refuse to respond abort +flow regex redact input + """Redact any regex-matched content in the user input.""" + global $user_message + $user_message = await RedactRegexPatternAction(source="input", text=$user_message) + + # OUTPUT RAILS flow regex check output """Check if the bot output matches any forbidden regex patterns.""" - $result = await DetectRegexMatchAction(source="output", text=$bot_message) + $result = await DetectRegexPatternAction(source="output", text=$bot_message) if $result["is_match"] bot refuse to respond abort +flow regex redact output + """Redact any regex-matched content in the bot output.""" + global $bot_message + $bot_message = await RedactRegexPatternAction(source="output", text=$bot_message) + + # RETRIEVAL RAILS flow regex check retrieval """Check if the relevant chunks from the knowledge base match any forbidden regex patterns. """ - $result = await DetectRegexMatchAction(source="retrieval", text=$relevant_chunks) + $result = await DetectRegexPatternAction(source="retrieval", text=$relevant_chunks) if $result["is_match"] $relevant_chunks = "" + + +flow regex redact retrieval + """Redact any regex-matched content in the retrieved knowledge base chunks.""" + global $relevant_chunks + $relevant_chunks = await RedactRegexPatternAction(source="retrieval", text=$relevant_chunks) diff --git a/nemoguardrails/library/regex/flows.v1.co b/nemoguardrails/library/regex/flows.v1.co index 5b00682df9..025894ce60 100644 --- a/nemoguardrails/library/regex/flows.v1.co +++ b/nemoguardrails/library/regex/flows.v1.co @@ -12,6 +12,11 @@ define subflow regex check input stop +define subflow regex redact input + """Redact any regex-matched content in the user input.""" + $user_message = execute redact_regex_pattern(source="input", text=$user_message) + + # OUTPUT RAILS @@ -24,6 +29,11 @@ define subflow regex check output stop +define subflow regex redact output + """Redact any regex-matched content in the bot output.""" + $bot_message = execute redact_regex_pattern(source="output", text=$bot_message) + + # RETRIEVAL RAILS @@ -33,3 +43,8 @@ define subflow regex check retrieval if $result["is_match"] $relevant_chunks = "" + + +define subflow regex redact retrieval + """Redact any regex-matched content in the retrieved knowledge base chunks.""" + $relevant_chunks = execute redact_regex_pattern(source="retrieval", text=$relevant_chunks) diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index fb890d99c6..ddbbd3d13c 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -248,30 +248,55 @@ class SensitiveDataDetection(BaseModel): ) +class RegexPatternConfig(BaseModel): + """A single regex pattern with an optional per-pattern mask token.""" + + pattern: str = Field(description="The regex pattern string.") + mask_token: str = Field( + default="", + description="Replacement token used when redacting this pattern's matches.", + ) + + class RegexDetectionOptions(BaseModel): - """Configuration options for regex pattern detection on a specific source.""" + """Configuration options for regex pattern detection on a specific source. + + Each entry in ``patterns`` may be a plain string (uses the default + ```` mask token) or an object with ``pattern`` and ``mask_token`` + keys. + """ - patterns: List[str] = Field( + patterns: List[Union[str, RegexPatternConfig]] = Field( default_factory=list, - description="List of regex patterns to match against the text.", + description="List of regex patterns (strings or objects with pattern/mask_token).", ) case_insensitive: bool = Field( default=False, description="Whether to perform case-insensitive matching.", ) + _normalized_patterns: List[RegexPatternConfig] = PrivateAttr(default_factory=list) _compiled_patterns: List["re.Pattern[str]"] = PrivateAttr(default_factory=list) @model_validator(mode="after") def compile_patterns(self) -> "RegexDetectionOptions": - """Pre-compile regex patterns at config load time.""" + """Normalize plain strings to RegexPatternConfig and pre-compile.""" + normalized: List[RegexPatternConfig] = [] + for entry in self.patterns: + if isinstance(entry, str): + normalized.append(RegexPatternConfig(pattern=entry)) + else: + normalized.append(entry) + flags = re.IGNORECASE if self.case_insensitive else 0 compiled = [] - for i, pattern in enumerate(self.patterns): + for i, cfg in enumerate(normalized): try: - compiled.append(re.compile(pattern, flags)) + compiled.append(re.compile(cfg.pattern, flags)) except re.error as e: - raise ValueError(f"Invalid regex pattern at index {i} ({pattern!r}): {e}") from e + raise ValueError(f"Invalid regex pattern at index {i} ({cfg.pattern!r}): {e}") from e + + object.__setattr__(self, "_normalized_patterns", normalized) object.__setattr__(self, "_compiled_patterns", compiled) return self @@ -280,6 +305,11 @@ def compiled_patterns(self) -> List["re.Pattern[str]"]: """Return the pre-compiled regex patterns.""" return self._compiled_patterns + @property + def normalized_patterns(self) -> List[RegexPatternConfig]: + """Return the normalized pattern configs.""" + return self._normalized_patterns + class RegexDetection(BaseModel): """Configuration for regex pattern detection.""" diff --git a/tests/test_regex_detection.py b/tests/test_regex_detection.py index 982017b132..996ca3d45b 100644 --- a/tests/test_regex_detection.py +++ b/tests/test_regex_detection.py @@ -19,7 +19,7 @@ from nemoguardrails import RailsConfig from nemoguardrails.actions import action from nemoguardrails.actions.actions import ActionResult -from nemoguardrails.library.regex.actions import detect_regex_pattern +from nemoguardrails.library.regex.actions import detect_regex_pattern, redact_regex_pattern from tests.utils import TestChat @@ -704,3 +704,319 @@ def test_regex_output_mapping_is_registered(): "detect_regex_pattern is missing output_mapping — streaming output rails " "will silently pass matched content through (see #1936 follow-up)" ) + + +# ── Redaction tests ────────────────────────────────────────────────────── + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_redact_regex_replaces_matches(): + """redact_regex_pattern replaces matched spans with .""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + regex_detection: + input: + patterns: + - "\\\\d{3}-\\\\d{2}-\\\\d{4}" + """, + colang_content="", + ) + + result = await redact_regex_pattern( + source="input", + text="My SSN is 123-45-6789 and yours is 987-65-4321.", + config=config, + ) + assert result == "My SSN is and yours is ." + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_redact_regex_custom_mask_token(): + """Per-pattern mask_token overrides the default replacement string.""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + regex_detection: + output: + patterns: + - pattern: "\\\\bsecret\\\\b" + mask_token: "[MASKED]" + """, + colang_content="", + ) + + result = await redact_regex_pattern( + source="output", + text="This is a secret message with secret data.", + config=config, + ) + assert result == "This is a [MASKED] message with [MASKED] data." + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_redact_regex_mixed_plain_and_object_patterns(): + """Plain string patterns use default , object patterns use their mask_token.""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + regex_detection: + input: + patterns: + - "\\\\d{3}-\\\\d{2}-\\\\d{4}" + - pattern: "[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\\\.[a-zA-Z]{2,}" + mask_token: "" + """, + colang_content="", + ) + + result = await redact_regex_pattern( + source="input", + text="SSN: 123-45-6789, email: test@example.com", + config=config, + ) + assert result == "SSN: , email: " + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_redact_regex_no_match_passes_through(): + """Text without matches is returned unchanged.""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + regex_detection: + input: + patterns: + - "\\\\d{3}-\\\\d{2}-\\\\d{4}" + """, + colang_content="", + ) + + result = await redact_regex_pattern( + source="input", + text="Hello, no sensitive data here.", + config=config, + ) + assert result == "Hello, no sensitive data here." + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_redact_regex_multiple_patterns(): + """Multiple patterns are all redacted in a single pass.""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + regex_detection: + input: + patterns: + - "\\\\d{3}-\\\\d{2}-\\\\d{4}" + - "[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\\\.[a-zA-Z]{2,}" + """, + colang_content="", + ) + + result = await redact_regex_pattern( + source="input", + text="SSN: 123-45-6789, email: test@example.com", + config=config, + ) + assert result == "SSN: , email: " + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_redact_regex_empty_text(): + """Empty text is returned as-is.""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + regex_detection: + input: + patterns: + - "\\\\bsecret\\\\b" + """, + colang_content="", + ) + + result = await redact_regex_pattern(source="input", text="", config=config) + assert result == "" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_redact_regex_accepts_extra_kwargs(): + """redact_regex_pattern must accept extra kwargs from the action dispatcher.""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + regex_detection: + output: + patterns: + - "\\\\bconfidential\\\\b" + """, + colang_content="", + ) + + result = await redact_regex_pattern( + source="output", + text="This is confidential information.", + config=config, + context={"user_message": "hi"}, + llm_task_manager=object(), + ) + assert result == "This is information." + + +@pytest.mark.unit +def test_regex_redact_input_e2e(): + """End-to-end: regex redact input flow replaces matched content.""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + regex_detection: + input: + patterns: + - "\\\\d{3}-\\\\d{2}-\\\\d{4}" + input: + flows: + - regex redact input + - check user message + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define flow check user message + execute check_user_message(user_message=$user_message) + """, + ) + + chat = TestChat( + config, + llm_completions=[" express greeting", ' "Got it, your SSN is on file."'], + ) + + @action() + def check_user_message(user_message: str): + assert "123-45-6789" not in user_message + assert "" in user_message + + chat.app.register_action(check_user_message) + + chat >> "My SSN is 123-45-6789" + chat << "Got it, your SSN is on file." + + +@pytest.mark.unit +def test_regex_redact_output_e2e(): + """End-to-end: regex redact output flow replaces matched content in bot reply.""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + regex_detection: + output: + patterns: + - pattern: "\\\\d{3}-\\\\d{2}-\\\\d{4}" + mask_token: "" + output: + flows: + - regex redact output + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + """, + ) + + chat = TestChat( + config, + llm_completions=[" express greeting", ' "Your SSN is 123-45-6789."'], + ) + + chat >> "Hi!" + chat << "Your SSN is ." + + +@pytest.mark.unit +def test_regex_redact_retrieval_e2e(): + """End-to-end: regex redact retrieval flow replaces matched content in chunks.""" + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + regex_detection: + retrieval: + patterns: + - "\\\\d{3}-\\\\d{2}-\\\\d{4}" + retrieval: + flows: + - regex redact retrieval + - check relevant chunks + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define flow check relevant chunks + execute check_relevant_chunks(relevant_chunks=$relevant_chunks) + """, + ) + + chat = TestChat( + config, + llm_completions=[" express greeting", ' "Here is what I found."'], + ) + + @action() + def retrieve_relevant_chunks(): + context_updates = {"relevant_chunks": "Employee SSN: 123-45-6789 in record."} + return ActionResult( + return_value=context_updates["relevant_chunks"], + context_updates=context_updates, + ) + + @action() + def check_relevant_chunks(relevant_chunks: str): + assert relevant_chunks == "Employee SSN: in record." + + chat.app.register_action(retrieve_relevant_chunks) + chat.app.register_action(check_relevant_chunks) + + chat >> "Hi!" + chat << "Here is what I found."