Skip to content

Commit c2fb77f

Browse files
committed
fix: finalize name changes and default values
1 parent edc834d commit c2fb77f

File tree

4 files changed

+49
-47
lines changed

4 files changed

+49
-47
lines changed

nemoguardrails/rails/llm/buffer.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,33 +25,37 @@ class BufferStrategy(ABC):
2525
def from_config(cls, config: OutputRailsStreamingConfig) -> "BufferStrategy":
2626
pass
2727

28+
# The abstract method is not async to ensure the return type
29+
# matches the async generator in the concrete implementation.
2830
@abstractmethod
29-
async def __call__(self, streaming_handler) -> AsyncGenerator:
31+
def __call__(
32+
self, streaming_handler
33+
) -> AsyncGenerator[Tuple[List[str], str], None]:
3034
pass
3135

3236
@abstractmethod
3337
def generate_chunk_str(self, *args, **kwargs) -> str:
3438
pass
3539

3640

37-
class SlidingWindow(BufferStrategy):
38-
"""DRFAT: A minimal buffer strategy that buffers chunks and yields them when the buffer is full."""
41+
class RollingBuffer(BufferStrategy):
42+
"""A minimal buffer strategy that buffers chunks and yields them when the buffer is full.
3943
40-
# - **chunk_size (X)**: This would correspond to the number of tokens in each chunk processed by the `streaming_handler`.
41-
# - **max_validation_length (N)**: This would correspond to the `look_back_size` parameter in the code, representing the maximum number of lookback chunks.
42-
#
43-
# In the code:
44-
# - `window_size` represents the number of chunks to process in each window.
45-
# - `look_back_size` represents the number of previous chunks to include in the window for context.
44+
Args:
45+
buffer_context_size (int): The number of tokens carried over from the previous chunk to provide context for continuity in processing.
46+
buffer_chunk_size (int): The number of tokens in each processing chunk. This is the size of the token block on which output rails are applied.
47+
"""
4648

47-
def __init__(self, look_back_size: int = 5, window_size: int = 10):
48-
self.look_back_size = look_back_size
49-
self.window_size = window_size
49+
def __init__(self, buffer_context_size: int = 5, buffer_chunk_size: int = 10):
50+
self.buffer_context_size = buffer_context_size
51+
self.buffer_chunk_size = buffer_chunk_size
5052
self.last_index = 0
5153

5254
@classmethod
5355
def from_config(cls, config: OutputRailsStreamingConfig):
54-
return cls(look_back_size=config.look_back_size, window_size=config.window_size)
56+
return cls(
57+
buffer_context_size=config.context_size, buffer_chunk_size=config.chunk_size
58+
)
5559

5660
async def __call__(
5761
self, streaming_handler
@@ -62,30 +66,26 @@ async def __call__(
6266
async for chunk in streaming_handler:
6367
buffer.append(chunk)
6468
index += 1
65-
# TODO: this is done in StreamingHandler, we need to find away to remove this duplication
66-
# print(f"\033[92m{chunk}\033[0m", end="", flush=True)
67-
# the hackish solution in StreamingHandler is resolved in Chat ClI, we should not alter interfaces
68-
# when we have stream_async we must use it everywhere, adding enable_print will cause headaches
69-
# then this hackish solution will cause a cancer of this hackish solution and will contaminate the whole codebase
7069

71-
if len(buffer) >= self.window_size:
70+
if len(buffer) >= self.buffer_chunk_size:
7271
yield (
73-
# buffer is used to apply output rails
74-
buffer[-self.window_size - self.look_back_size :],
75-
# this is what gets printed in the console or yield to user
72+
# we apply output rails on the buffer
73+
buffer[-self.buffer_chunk_size - self.buffer_context_size :],
74+
# generate_chunk_str is what gets printed in the console or yield to user
7675
# to avoid repeating the already streamed/printed chunk
7776
self.generate_chunk_str(
78-
buffer[-self.window_size - self.look_back_size :], index
77+
buffer[-self.buffer_chunk_size - self.buffer_context_size :],
78+
index,
7979
),
8080
)
81-
buffer = buffer[-self.look_back_size :]
81+
buffer = buffer[-self.buffer_context_size :]
8282

8383
# Yield any remaining buffer if it's not empty
8484
if buffer:
8585
yield (
8686
buffer,
8787
self.generate_chunk_str(
88-
buffer[-self.window_size - self.look_back_size :], index
88+
buffer[-self.buffer_chunk_size - self.buffer_context_size :], index
8989
),
9090
)
9191

@@ -104,5 +104,5 @@ def generate_chunk_str(self, buffer, current_index) -> str:
104104

105105
def get_buffer_strategy(config: OutputRailsStreamingConfig) -> BufferStrategy:
106106
# TODO: use a factory function or class
107-
# currently we only have SlidingWindow, in future we use a registry
108-
return SlidingWindow.from_config(config)
107+
# currently we only have RollingBuffer, in future we use a registry
108+
return RollingBuffer.from_config(config)

nemoguardrails/rails/llm/config.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,14 +305,22 @@ class InputRails(BaseModel):
305305

306306

307307
class OutputRailsStreamingConfig(BaseModel):
308+
"""Configuration for managing streaming output of LLM tokens."""
309+
308310
enabled: bool = Field(
309-
default=False, description="Indicates if streaming is enabled."
311+
default=False, description="Enables streaming mode when True."
312+
)
313+
chunk_size: int = Field(
314+
default=200,
315+
description="The number of tokens in each processing chunk. This is the size of the token block on which output rails are applied.",
316+
)
317+
context_size: int = Field(
318+
default=50,
319+
description="The number of tokens carried over from the previous chunk to provide context for continuity in processing.",
310320
)
311-
look_back_size: int = Field(default=5, description="The look back size.")
312-
window_size: int = Field(default=10, description="The window size.")
313321
stream_first: bool = Field(
314322
default=True,
315-
description="Prioritizes streaming chunks before applying output rails.",
323+
description="If True, token chunks are streamed immediately before output rails are applied.",
316324
)
317325
model_config = ConfigDict(extra="allow")
318326

tests/test_buffer_strategy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import pytest
1717

18-
from nemoguardrails.rails.llm.buffer import SlidingWindow as BufferStrategy
18+
from nemoguardrails.rails.llm.buffer import RollingBuffer as BufferStrategy
1919

2020

2121
async def fake_streaming_handler():
@@ -26,7 +26,7 @@ async def fake_streaming_handler():
2626

2727
@pytest.mark.asyncio
2828
async def test_buffer_strategy():
29-
buffer_strategy = BufferStrategy(look_back_size=5, window_size=10)
29+
buffer_strategy = BufferStrategy(buffer_context_size=5, buffer_chunk_size=10)
3030
streaming_handler = fake_streaming_handler()
3131

3232
expected_buffers = [
@@ -69,7 +69,7 @@ async def async_enumerate(aiterable, start=0):
6969

7070

7171
async def test_generate_chunk_str():
72-
buffer_strategy = BufferStrategy(look_back_size=5, window_size=10)
72+
buffer_strategy = BufferStrategy(buffer_context_size=5, buffer_chunk_size=10)
7373
buffer = ["chunk0", "chunk1", "chunk2", "chunk3", "chunk4", "chunk5"]
7474
current_index = 6
7575

tests/test_streaming.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ def output_rails_streaming_config():
266266
"flows": {"self check output"},
267267
"streaming": {
268268
"enabled": True,
269-
"window_size": 4,
270-
"look_back_size": 2,
269+
"chunk_size": 4,
270+
"context_size": 2,
271271
"stream_first": False,
272272
},
273273
}
@@ -403,15 +403,9 @@ async def test_streaming_output_rails_blocked_at_first_call(
403403
await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
404404

405405

406-
def _calculate_number_of_actions(M, W, N):
407-
"""
408-
M: input_length
409-
W: window_size
410-
N: look_back_size
411-
"""
412-
413-
if W <= N:
414-
raise ValueError("Window size must be greater than look-back size.")
415-
if M <= W:
406+
def _calculate_number_of_actions(input_length, chunk_size, context_size):
407+
if chunk_size <= context_size:
408+
raise ValueError("chunk_size must be greater than context_size.")
409+
if input_length <= chunk_size:
416410
return 1
417-
return math.ceil((M - N) / (W - N))
411+
return math.ceil((input_length - context_size) / (chunk_size - context_size))

0 commit comments

Comments
 (0)