Skip to content

Commit de09cad

Browse files
jhrozeklukehinds
andauthored
Fix streaming output corruption with copilot (#1261)
* Shortcut other steps if one holds the output We had a bug in the output pipeline where if one step returned [] meaning that the output chunk should be held off, all the current chunks would continue to run and potentially modify the context. * Shortcut buffered PII sooner, as soon as the buffer can't be a UUID Our PII refaction format is `#UUID#`. Our code was finding an opening #, then checking for a closing matching # or end of the output. For copilot, however, this meant that we were buffering the whole file, because the filename comes in this format: ``` ``` This means we would keep searching for the closing hash which never came. Instead, buffer only as long as the context between the hashes can reasonably be a UUID. Fixes: #1250 --------- Co-authored-by: Luke Hinds <[email protected]>
1 parent cc8fd71 commit de09cad

File tree

3 files changed

+87
-2
lines changed

3 files changed

+87
-2
lines changed

src/codegate/pipeline/output.py

+2
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ async def process_stream(
153153
step_result = await step.process_chunk(
154154
c, self._context, self._input_context
155155
)
156+
if not step_result:
157+
break
156158
processed_chunks.extend(step_result)
157159

158160
current_chunks = processed_chunks

src/codegate/pipeline/pii/pii.py

+39-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,38 @@
2020
logger = structlog.get_logger("codegate")
2121

2222

23+
def can_be_uuid(buffer):
24+
"""
25+
This is a way to check if a buffer can be a UUID. It aims to return as soon as possible
26+
meaning that we buffer as little as possible. This is important for performance reasons
27+
but also to make sure other steps don't wait too long as we don't buffer more than we need to.
28+
"""
29+
# UUID structure: 8-4-4-4-12 hex digits
30+
# Expected positions of hyphens
31+
hyphen_positions = {8, 13, 18, 23}
32+
33+
# Maximum length of a UUID
34+
max_uuid_length = 36
35+
36+
if buffer == "":
37+
return True
38+
39+
# If buffer is longer than a UUID, it can't be a UUID
40+
if len(buffer) > max_uuid_length:
41+
return False
42+
43+
for i, char in enumerate(buffer):
44+
# Check if hyphens are in the right positions
45+
if i in hyphen_positions:
46+
if char != "-":
47+
return False
48+
# Check if non-hyphen positions contain hex digits
49+
elif not (char.isdigit() or char.lower() in "abcdef"):
50+
return False
51+
52+
return True
53+
54+
2355
class CodegatePii(PipelineStep):
2456
"""
2557
CodegatePii is a pipeline step that handles the detection and redaction of PII
@@ -278,8 +310,13 @@ async def process_chunk( # noqa: C901
278310

279311
end_idx = content.find(self.marker_end, start_idx + 1)
280312
if end_idx == -1:
281-
# Incomplete marker, buffer the rest
282-
context.prefix_buffer = content[current_pos:]
313+
# Incomplete marker, buffer the rest only if it can be a UUID
314+
if start_idx + 1 < len(content) and not can_be_uuid(content[start_idx + 1 :]):
315+
# the buffer can't be a UUID, so we can't process it, just return
316+
result.append(content[current_pos:])
317+
else:
318+
# this can still be a UUID
319+
context.prefix_buffer = content[current_pos:]
283320
break
284321

285322
# Add text before marker

tests/pipeline/pii/test_pi.py

+46
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,52 @@ async def test_process_chunk_with_uuid(self, unredaction_step):
120120
result = await unredaction_step.process_chunk(chunk, context, input_context)
121121
assert result[0].choices[0].delta.content == "Text with [email protected]"
122122

123+
@pytest.mark.asyncio
124+
async def test_detect_not_an_uuid(self, unredaction_step):
125+
chunk1 = ModelResponse(
126+
id="test",
127+
choices=[
128+
StreamingChoices(
129+
finish_reason=None,
130+
index=0,
131+
delta=Delta(content="#"),
132+
logprobs=None,
133+
)
134+
],
135+
created=1234567890,
136+
model="test-model",
137+
object="chat.completion.chunk",
138+
)
139+
chunk2 = ModelResponse(
140+
id="test",
141+
choices=[
142+
StreamingChoices(
143+
finish_reason=None,
144+
index=0,
145+
delta=Delta(content=" filepath"),
146+
logprobs=None,
147+
)
148+
],
149+
created=1234567890,
150+
model="test-model",
151+
object="chat.completion.chunk",
152+
)
153+
154+
context = OutputPipelineContext()
155+
manager = SensitiveDataManager()
156+
sensitive = PipelineSensitiveData(manager=manager, session_id="session-id")
157+
input_context = PipelineContext(sensitive=sensitive)
158+
159+
# Mock PII manager in input context
160+
mock_sensitive_data_manager = MagicMock()
161+
mock_sensitive_data_manager.get_original_value = MagicMock(return_value="[email protected]")
162+
input_context.metadata["sensitive_data_manager"] = mock_sensitive_data_manager
163+
164+
result = await unredaction_step.process_chunk(chunk1, context, input_context)
165+
assert not result
166+
result = await unredaction_step.process_chunk(chunk2, context, input_context)
167+
assert result[0].choices[0].delta.content == "# filepath"
168+
123169

124170
class TestPiiRedactionNotifier:
125171
@pytest.fixture

0 commit comments

Comments
 (0)