
Commit c4386d6

Fixed pipeline output tests.

1 parent 12c1445

2 files changed: +62 -29 lines changed

src/codegate/pipeline/output.py
Lines changed: 14 additions & 5 deletions

@@ -187,11 +187,20 @@ async def process_stream(
                 content=final_content,
                 len=len(self._context.buffer),
             )
-            # NOTE: Original code created chunks for all remaining
-            # messages in `self._context.buffer`, but it looks
-            # like it was defensive code. We should instead ensure
-            # that no messages remain there at each step of the
-            # pipeline in some way.
+
+            # NOTE: this block ensured that buffered chunks were
+            # flushed at the end of the pipeline. This was possible
+            # as long as the implementation assumed that all
+            # messages were equivalent and position was not
+            # relevant.
+            #
+            # This is not the case for Anthropic, whose protocol
+            # is much more structured than that of the others.
+            #
+            # We cannot yet guarantee that such a protocol is not
+            # broken when messages are arbitrarily retained at
+            # each pipeline step, so we decided to treat a
+            # clogged pipeline as a bug.
             self._context.buffer.clear()
 
         if finish_stream:
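
For context, the end-of-stream behavior described by the new comment can be sketched as a small, self-contained program. Everything below is illustrative: the class name, the buffering condition, and the logging are assumptions rather than codegate's actual process_stream; only the log-and-clear shape at end of stream is taken from the diff.

import asyncio

# Hypothetical stand-in for the pipeline context; codegate's real class
# differs, but the diff shows it carries a clearable `buffer`.
class PipelineContext:
    def __init__(self) -> None:
        self.buffer: list[str] = []

async def process_stream(context: PipelineContext, stream):
    async for chunk in stream:
        # A step may retain a chunk by buffering it instead of yielding it,
        # e.g. while waiting for more text before matching on it.
        if chunk.endswith("-"):
            context.buffer.append(chunk)
        else:
            yield chunk
    if len(context.buffer) > 0:
        final_content = "".join(context.buffer)
        # Post-commit behavior: log the leftovers and clear them rather than
        # flushing synthetic chunks downstream; a clogged pipeline is a bug.
        print(f"BUG: unflushed content at end of stream: {final_content!r}")
        context.buffer.clear()

async def main() -> None:
    async def stream():
        for c in ("hel-", "lo"):
            yield c

    ctx = PipelineContext()
    out = [c async for c in process_stream(ctx, stream())]
    assert out == ["lo"]     # the retained chunk was never emitted
    assert ctx.buffer == []  # cleared, not flushed

asyncio.run(main())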

tests/pipeline/test_output.py
Lines changed: 48 additions & 24 deletions

@@ -10,6 +10,12 @@
     OutputPipelineStep,
 )
 from codegate.types.common import Delta, ModelResponse, StreamingChoices
+from codegate.types.openai import (
+    ChatCompletionRequest,
+    ChoiceDelta,
+    MessageDelta,
+    StreamingChatCompletion,
+)
 
 
 class MockOutputPipelineStep(OutputPipelineStep):
@@ -26,30 +32,37 @@ def name(self) -> str:
 
     async def process_chunk(
         self,
-        chunk: ModelResponse,
+        chunk: StreamingChatCompletion,
         context: OutputPipelineContext,
         input_context: PipelineContext = None,
-    ) -> list[ModelResponse]:
+    ) -> list[StreamingChatCompletion]:
         if self._should_pause:
             return []
 
-        if self._modify_content and chunk.choices[0].delta.content:
+        if next(chunk.get_content(), None) is None:
+            return [chunk]  # short-circuit
+
+        content = next(chunk.get_content())
+        if content.get_text() is None or content.get_text() == "":
+            return [chunk]  # short-circuit
+
+        if self._modify_content:
             # Append step name to content to track modifications
-            modified_content = f"{chunk.choices[0].delta.content}_{self.name}"
-            chunk.choices[0].delta.content = modified_content
+            modified_content = f"{content.get_text()}_{self.name}"
+            content.set_text(modified_content)
 
         return [chunk]
 
 
-def create_model_response(content: str, id: str = "test") -> ModelResponse:
-    """Helper to create test ModelResponse objects"""
-    return ModelResponse(
+def create_model_response(content: str, id: str = "test") -> StreamingChatCompletion:
+    """Helper to create test StreamingChatCompletion objects"""
+    return StreamingChatCompletion(
         id=id,
         choices=[
-            StreamingChoices(
+            ChoiceDelta(
                 finish_reason=None,
                 index=0,
-                delta=Delta(content=content, role="assistant"),
+                delta=MessageDelta(content=content, role="assistant"),
                 logprobs=None,
             )
         ],
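
The new short-circuits rely on next(iterator, default), which returns the default instead of raising StopIteration when a chunk carries no content blocks. A minimal sketch of the pattern, using a hypothetical Content stand-in (the real objects behind StreamingChatCompletion.get_content() may differ; only get_text()/set_text() are taken from the diff):

from dataclasses import dataclass
from typing import Iterator, Optional

@dataclass
class Content:
    # Hypothetical stand-in for the objects yielded by get_content().
    text: Optional[str]

    def get_text(self) -> Optional[str]:
        return self.text

    def set_text(self, text: str) -> None:
        self.text = text

def get_content(blocks: list[Content]) -> Iterator[Content]:
    # A fresh iterator per call, which the test depends on when it
    # calls chunk.get_content() twice in a row.
    return iter(blocks)

assert next(get_content([]), None) is None  # empty chunk: short-circuit

blocks = [Content("hello")]
content = next(get_content(blocks))
content.set_text(f"{content.get_text()}_test-step")
assert blocks[0].get_text() == "hello_test-step"  # mutation visible in place

Note that process_chunk calls chunk.get_content() twice; this only works if each call returns a fresh iterator, since a single shared generator would already be exhausted by the first next().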
@@ -64,7 +77,7 @@ class MockContext:
     def __init__(self):
         self.sensitive = False
 
-    def add_output(self, chunk: ModelResponse):
+    def add_output(self, chunk: StreamingChatCompletion):
         pass
 
 
@@ -157,10 +170,23 @@ async def mock_stream():
     async for chunk in instance.process_stream(mock_stream()):
         chunks.append(chunk)
 
+    # NOTE: this test ensured that buffered chunks were flushed at
+    # the end of the pipeline. This was possible as long as the
+    # implementation assumed that all messages were equivalent
+    # and position was not relevant.
+    #
+    # This is not the case for Anthropic, whose protocol is much
+    # more structured than that of the others.
+    #
+    # We cannot yet guarantee that such a protocol is not
+    # broken when messages are arbitrarily retained at each
+    # pipeline step, so we decided to treat a clogged
+    # pipeline as a bug.
+
     # Should get one chunk at the end with all buffered content
-    assert len(chunks) == 1
+    assert len(chunks) == 0
     # Content should be buffered and combined
-    assert chunks[0].choices[0].delta.content == "hello world"
+    # assert chunks[0].choices[0].delta.content == "hello world"
     # Buffer should be cleared after flush
     assert len(instance._context.buffer) == 0
 
@@ -180,19 +206,19 @@ def name(self) -> str:
 
     async def process_chunk(
         self,
-        chunk: ModelResponse,
+        chunk: StreamingChatCompletion,
         context: OutputPipelineContext,
         input_context: PipelineContext = None,
-    ) -> List[ModelResponse]:
+    ) -> List[StreamingChatCompletion]:
         # Replace 'world' with 'moon' in buffered content
         content = "".join(context.buffer)
         if "world" in content:
             content = content.replace("world", "moon")
         chunk.choices = [
-            StreamingChoices(
+            ChoiceDelta(
                 finish_reason=None,
                 index=0,
-                delta=Delta(content=content, role="assistant"),
+                delta=MessageDelta(content=content, role="assistant"),
                 logprobs=None,
             )
         ]
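
This step searches the joined context.buffer rather than the individual chunk because the target word may be split across streamed chunks. A self-contained sketch of that buffered replacement (an illustrative helper, not the repo's code):

def replace_in_buffer(buffer: list[str], old: str, new: str) -> str:
    # Join everything buffered so far before searching: per-chunk
    # replacement would miss a word that straddles chunk boundaries.
    content = "".join(buffer)
    if old in content:
        content = content.replace(old, new)
    return content

# "world" is split across two chunks; only the joined buffer matches.
assert replace_in_buffer(["hello wor", "ld"], "world", "moon") == "hello moon"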
@@ -274,10 +300,10 @@ def name(self) -> str:
 
     async def process_chunk(
         self,
-        chunk: ModelResponse,
+        chunk: StreamingChatCompletion,
         context: OutputPipelineContext,
         input_context: PipelineContext = None,
-    ) -> List[ModelResponse]:
+    ) -> List[StreamingChatCompletion]:
         assert input_context.metadata["test"] == "value"
         return [chunk]
 
@@ -308,8 +334,6 @@ async def mock_stream():
     async for chunk in instance.process_stream(mock_stream()):
         chunks.append(chunk)
 
-    # Should get one chunk with combined buffer content
-    assert len(chunks) == 1
-    assert chunks[0].choices[0].delta.content == "HelloWorld"
-    # Buffer should be cleared after flush
-    assert len(instance._context.buffer) == 0
+    # We do not flush messages anymore; this should be treated as
+    # a bug in the pipeline rather than an edge case.
+    assert len(chunks) == 0
