Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c817550
feat: implement prompt injection detection module (Issue #1979)
nac7 Jun 6, 2026
f24c247
feat: implement context length validation (Issue #1983)
nac7 Jun 6, 2026
b224443
Fix 9 code review issues for context-length validation (Issue #1983)
nac7 Jun 6, 2026
137d92a
Fix: Don't wrap ContextLengthExceededError in LLMCallException
nac7 Jun 6, 2026
9b267e7
Add full Apache license headers to context-length validation files
nac7 Jun 6, 2026
11571e9
fix: apply ruff formatting and linting fixes for PR #1999
nac7 Jun 6, 2026
abd5917
docs: add context length validation to CHANGELOG
nac7 Jun 6, 2026
1db1ac6
fix: update codecov action to v4 to resolve GPG verification error
nac7 Jun 6, 2026
dada7db
fix: improve TokenCounter partial matching and injection detector pat…
nac7 Jun 6, 2026
6a1f0bf
chore: normalize license headers (remove duplicates)
nac7 Jun 6, 2026
b559ef6
fix: improve TokenCounter model name matching with prefix-based logic
nac7 Jun 6, 2026
bc3e4f2
fix: use token-based matching for model context window lookup
nac7 Jun 6, 2026
4bce0eb
fix: move re import to top of file for token matching
nac7 Jun 6, 2026
bb1d541
fix: use coverage ratio for model name matching to resolve wrong key …
nac7 Jun 7, 2026
05945ae
fix: revert MODEL_CONTEXT_WINDOWS values to match test expectations
nac7 Jun 7, 2026
ed8ffe8
tests: add coverage for uncovered lines in injections and token_counter
nac7 Jun 7, 2026
bed9b66
fix(llm_call): add max_tokens pass-through and wire injection detecti…
nac7 Jun 8, 2026
6c0fccf
fix(lint): sort imports in utils.py
nac7 Jun 8, 2026
e9d6fe2
fix: rename max_tokens to context_window_tokens in llm_call to preven…
nac7 Jun 8, 2026
e131339
fix(llm_call): make context-length validation opt-in for backward com…
nac7 Jun 8, 2026
d179e9f
Merge branch 'develop' into fix/context-length-overflow
nac7 Jun 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
- *(llm)* Add LangChain adapter and framework registry ([#1759](https://github.com/NVIDIA-NeMo/Guardrails/issues/1759))
- *(llm)* Add streaming tool call accumulation and LLMResponse parity ([#1789](https://github.com/NVIDIA-NeMo/Guardrails/issues/1789))
- *(llm)* Add default framework with OpenAI-compatible client ([#1797](https://github.com/NVIDIA-NeMo/Guardrails/issues/1797))
- *(llm)* Add context length validation before LLM inference ([#1983](https://github.com/NVIDIA-NeMo/Guardrails/issues/1983))
- *(llm/frameworks)* Validate framework on registration ([#1863](https://github.com/NVIDIA-NeMo/Guardrails/issues/1863))
- *(types)* Add framework-agnostic LLM type system ([#1745](https://github.com/NVIDIA-NeMo/Guardrails/issues/1745))
- *(compat)* Transitional compat layer to migrate from 0.21 to 0.22+ ([#1841](https://github.com/NVIDIA-NeMo/Guardrails/issues/1841))
Expand Down
47 changes: 47 additions & 0 deletions nemoguardrails/actions/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
tool_calls_var,
)
from nemoguardrails.exceptions import LLMCallException
from nemoguardrails.llm.token_counter import validate_context_length
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.logging.llm_tracker import track_llm_call
from nemoguardrails.rails.llm.injections import validate_prompt_safety
from nemoguardrails.types import ChatMessage, LLMModel, LLMResponse, LLMResponseChunk, UsageInfo

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,7 +59,38 @@ async def llm_call(
stop: Optional[List[str]] = None,
llm_params: Optional[dict] = None,
streaming_handler: Optional["StreamingHandler"] = None,
context_window_tokens: Optional[int] = None,
check_prompt_injection: bool = False,
check_context_length: bool = False,
) -> LLMResponse:
"""Call the LLM with the given prompt.

Args:
llm: The LLM model instance.
prompt: The prompt string or list of message dicts.
model_name: Model name used for context-length look-up; falls back to
the model's own ``model_name`` attribute when omitted.
model_provider: Optional provider identifier for logging.
stop: Optional list of stop sequences.
llm_params: Extra keyword arguments forwarded verbatim to the model
call (e.g. ``{"temperature": 0.2, "max_tokens": 512}`` to cap
output length at 512 tokens).
streaming_handler: If provided, the response is streamed through this
handler instead of being returned as a single object.
context_window_tokens: Override the context-window size used **only**
for pre-call length validation (i.e. passed to
``validate_context_length``). This value is *not* forwarded to the
model. To cap the number of output tokens, include ``max_tokens``
inside ``llm_params`` instead. Passing this implicitly enables
context-length validation even when ``check_context_length`` is
False.
check_prompt_injection: When True, scan the prompt for injection
patterns before calling the model.
check_context_length: When True, validate that the assembled prompt
fits within the model's context window before calling the model.
Defaults to False to preserve backward compatibility with
deployments that use custom or unlisted model names.
"""
if llm is None:
raise LLMCallException(ValueError("No LLM provided to llm_call()"))

Expand All @@ -74,6 +107,20 @@ async def llm_call(
_log_prompt(prompt)
chat_prompt = _ensure_chat_messages(prompt)

# Validate context length only when explicitly requested or when a caller-supplied
# window override is present. The check is opt-in (False by default) so that
# deployments using custom/unlisted model names are not broken by the conservative
# 4096-token fallback. ContextLengthExceededError must propagate directly (not
# wrapped in LLMCallException) so callers can handle it specifically.
if check_context_length or context_window_tokens is not None:
validate_context_length(prompt, model_name=model_name or model.model_name, max_tokens=context_window_tokens)

if check_prompt_injection:
if isinstance(prompt, list):
validate_prompt_safety(messages=prompt)
elif isinstance(prompt, str):
validate_prompt_safety(prompt=prompt)

if streaming_handler:
return await _stream_llm_call(model, chat_prompt, streaming_handler, stop, llm_params)

Expand Down
2 changes: 0 additions & 2 deletions nemoguardrails/guardrails/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ def generate(
"""Generate an LLM response synchronously with guardrails applied.
Supported in both IORails and LLMRails
"""

generate_messages = self._convert_to_messages(prompt, messages)
return self.rails_engine.generate(messages=generate_messages, **kwargs)

Expand Down Expand Up @@ -247,7 +246,6 @@ def stream_async(
self, prompt: str | None = None, messages: LLMMessages | None = None, **kwargs
) -> AsyncIterator[str | dict]:
"""Generate an LLM response asynchronously with streaming support."""

stream_messages = self._convert_to_messages(prompt, messages)

async def _with_startup(iterator: AsyncIterator[str | dict]) -> AsyncIterator[str | dict]:
Expand Down
280 changes: 280 additions & 0 deletions nemoguardrails/llm/token_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Token counting and context length validation utilities.

Provides methods to estimate token counts for prompts and validate
that prompts don't exceed model context windows.
"""

import logging
import re
from typing import List, Optional, Union

# Module-level logger; used to emit a debug record when validation passes.
log = logging.getLogger(__name__)


class ContextLengthExceededError(ValueError):
    """Signal that a prompt does not fit in the model's context window.

    Subclasses ``ValueError`` so generic handlers still catch it, while
    carrying the numbers a caller needs to react (truncate, switch model).

    Attributes:
        prompt_tokens: Estimated token count of the rejected prompt.
        max_tokens: Context window size the prompt was checked against.
        model_name: Model the check was performed for, or ``None``.
    """

    def __init__(
        self,
        message: str,
        prompt_tokens: int,
        max_tokens: int,
        model_name: Optional[str] = None,
    ):
        # Only the human-readable message goes to the base class; the
        # structured details live on the instance.
        super().__init__(message)
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.prompt_tokens = prompt_tokens


class TokenCounter:
    """Estimates token counts for various model types.

    All methods are static; the class is a namespace for the heuristic
    ratios, the known context-window table and the helpers built on them.
    Counts are character-based approximations, not real tokenizer output.
    """

    # Approximate tokens per character ratios for different model families
    # These are conservative estimates; actual counts depend on tokenizer
    TOKENS_PER_CHAR = {
        "gpt": 0.25,  # OpenAI models: ~4 chars per token
        "claude": 0.27,  # Anthropic: ~3.7 chars per token
        "llama": 0.28,  # Meta: ~3.6 chars per token
        "mistral": 0.28,
        "gemini": 0.26,
        "default": 0.27,
    }

    # Model context window limits (in tokens).
    # Callers should pass max_tokens explicitly for deployment-specific variants
    # not listed here, as partial-name matching may resolve to a conservative value.
    MODEL_CONTEXT_WINDOWS = {
        # OpenAI
        "gpt-4o": 128000,
        "gpt-4-turbo": 128000,
        "gpt-4": 8192,
        "gpt-3.5-turbo": 4096,
        # Anthropic
        "claude-3-opus": 200000,
        "claude-3-sonnet": 200000,
        "claude-3-haiku": 200000,
        "claude-2.1": 100000,
        "claude-2": 100000,
        # Meta Llama
        "llama-2": 4096,
        "llama-2-70b": 4096,
        "llama-3": 8192,
        "llama-3-70b": 8192,
        # Mistral
        "mistral-7b": 32768,
        "mistral-large": 32768,
        # Google
        "gemini-pro": 32768,
        "gemini-2.0-flash": 1000000,
        # Default fallback
        "default": 4096,
    }

    @staticmethod
    def estimate_tokens(text: str, model_name: Optional[str] = None) -> int:
        """Estimate token count for text.

        Args:
            text: The text to estimate tokens for
            model_name: Optional model name for family-specific estimation

        Returns:
            Approximate token count: 0 for empty text, otherwise at least 1.
        """
        if not text:
            return 0

        # Determine ratio based on model family (first family name that
        # occurs as a substring of the lower-cased model name wins).
        ratio = TokenCounter.TOKENS_PER_CHAR.get("default", 0.27)
        if model_name:
            model_lower = model_name.lower()
            for family, family_ratio in TokenCounter.TOKENS_PER_CHAR.items():
                if family in model_lower:
                    ratio = family_ratio
                    break

        # Never report 0 tokens for non-empty text.
        return max(1, int(len(text) * ratio))

    @staticmethod
    def estimate_message_tokens(messages: List[dict], model_name: Optional[str] = None) -> int:
        """Estimate total token count for message list.

        Accounts for message structure overhead.

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model_name: Optional model name for family-specific estimation

        Returns:
            Approximate total token count including formatting
        """
        if not messages:
            return 0

        # Account for message structure overhead (~4 tokens per message)
        total_tokens = len(messages) * 4

        for msg in messages:
            # Non-dict entries contribute only the per-message overhead.
            if not isinstance(msg, dict):
                continue

            content = msg.get("content", "")
            if isinstance(content, str):
                total_tokens += TokenCounter.estimate_tokens(content, model_name)
            elif isinstance(content, list):
                # For multimodal content
                for item in content:
                    if not isinstance(item, dict):
                        continue
                    item_type = item.get("type")
                    if item_type == "text":
                        total_tokens += TokenCounter.estimate_tokens(item.get("text", ""), model_name)
                    elif item_type in ("image_url", "image"):
                        # Image tokens vary; rough estimate
                        total_tokens += 85

        return total_tokens

    @staticmethod
    def _tokenize(s: str) -> List[str]:
        """Split a model name into tokens on non-alphanumeric separators."""
        return [t for t in re.split(r"[^a-z0-9]+", s.lower()) if t]

    @staticmethod
    def get_model_context_window(model_name: Optional[str]) -> int:
        """Get context window size for a model.

        Resolution order: exact (case-insensitive) table hit, then the best
        partially matching known key, then the ``"default"`` entry.

        Args:
            model_name: Name of the model

        Returns:
            Context window in tokens, or default if unknown
        """
        windows = TokenCounter.MODEL_CONTEXT_WINDOWS

        if not model_name:
            return windows["default"]

        model_name_lower = model_name.lower()

        # Exact match first
        if model_name_lower in windows:
            return windows[model_name_lower]

        # Token-based matching using coverage ratio.
        # Coverage ratio = overlap / key_token_count prevents longer irrelevant keys
        # from beating shorter fully-matching keys.
        # e.g., "gpt-4-custom-variant" vs keys "gpt-4" and "gpt-4-turbo":
        #   "gpt-4"       tokens {gpt,4}:       overlap=2, ratio=2/2=1.00  -> wins
        #   "gpt-4-turbo" tokens {gpt,4,turbo}: overlap=2, ratio=2/3=0.67
        model_tokens = set(TokenCounter._tokenize(model_name_lower))

        best_key = None
        best_ratio = 0.0
        best_key_len = 0

        for key in windows:
            if key == "default":
                continue

            key_token_list = TokenCounter._tokenize(key)
            if not key_token_list:
                continue

            # A key is only a candidate when its leading (family) token is
            # part of the model name. Without this, an unrelated model could
            # match on a bare version number alone (e.g. "qwen-2" sharing
            # "2" with "claude-2") and inherit a window that is far too
            # large, letting oversized prompts through validation.
            if key_token_list[0] not in model_tokens:
                continue

            key_tokens = set(key_token_list)

            # Coverage ratio: fraction of key's tokens present in the model name
            ratio = len(model_tokens & key_tokens) / len(key_tokens)

            # Prefer higher ratio; tie-break on longer key (more specific)
            if ratio > best_ratio or (ratio == best_ratio and len(key) > best_key_len):
                best_ratio = ratio
                best_key = key
                best_key_len = len(key)

        if best_key:
            return windows[best_key]

        # Default fallback
        return windows["default"]

    @staticmethod
    def validate_context_length(
        prompt: Union[str, List[dict]],
        model_name: Optional[str] = None,
        max_tokens: Optional[int] = None,
    ) -> None:
        """Validate that prompt fits within model context window.

        The prompt is rejected when its estimated size exceeds 90% of the
        window; the remaining 10% is kept as a safety margin and room for
        output tokens. Prompts that are neither ``str`` nor ``list`` are not
        validated and return silently.

        Args:
            prompt: The prompt (string or message list) to validate
            model_name: Name of the model (for context window lookup)
            max_tokens: Override context window size

        Raises:
            ContextLengthExceededError: If prompt exceeds context window
        """
        if isinstance(prompt, str):
            prompt_tokens = TokenCounter.estimate_tokens(prompt, model_name)
        elif isinstance(prompt, list):
            prompt_tokens = TokenCounter.estimate_message_tokens(prompt, model_name)
        else:
            return  # Can't validate unknown type

        # Determine context window
        if max_tokens is None:
            max_tokens = TokenCounter.get_model_context_window(model_name)

        # Validate (reserve 10% for safety margin and output tokens)
        safety_threshold = int(max_tokens * 0.9)

        if prompt_tokens > safety_threshold:
            raise ContextLengthExceededError(
                f"Prompt exceeds model context length. "
                f"Prompt tokens: {prompt_tokens}, "
                f"Model context window: {max_tokens} "
                f"(using 90% threshold: {safety_threshold} tokens). "
                f"Context length exceeded by {prompt_tokens - safety_threshold} tokens. "
                f"Please reduce prompt length or use a model with larger context window.",
                prompt_tokens=prompt_tokens,
                max_tokens=max_tokens,
                model_name=model_name,
            )

        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        log.debug(
            "Prompt token validation passed: %s/%s tokens (model: %s)",
            prompt_tokens,
            safety_threshold,
            model_name or "unknown",
        )


def validate_context_length(
    prompt: Union[str, List[dict]],
    model_name: Optional[str] = None,
    max_tokens: Optional[int] = None,
) -> None:
    """Module-level shortcut for ``TokenCounter.validate_context_length``.

    Lets callers import a plain function instead of reaching through the
    class; all arguments are handed over unchanged.

    Args:
        prompt: The prompt to validate (string or list of message dicts).
        model_name: Name of the model, used for the context-window look-up.
        max_tokens: Override context window size.

    Raises:
        ContextLengthExceededError: If prompt exceeds context window
    """
    TokenCounter.validate_context_length(
        prompt=prompt,
        model_name=model_name,
        max_tokens=max_tokens,
    )
Loading