diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 2ab4bcd68..21e90725c 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -939,6 +939,7 @@ def stream_async( prompt: Optional[str] = None, messages: Optional[List[dict]] = None, options: Optional[Union[dict, GenerationOptions]] = None, + state: Optional[Union[dict, State]] = None, ) -> AsyncIterator[str]: """Simplified interface for getting directly the streamed tokens from the LLM.""" streaming_handler = StreamingHandler() @@ -951,6 +952,7 @@ def stream_async( messages=messages, streaming_handler=streaming_handler, options=options, + state=state, ) ) # TODO: