
Commit 0bbdde3 (parent: a148a16)

Fixing tests.

File tree

9 files changed: +122 -476 lines

src/codegate/pipeline/system_prompt/codegate.py (4 additions, 7 deletions)

@@ -92,18 +92,15 @@ async def process(
         if not should_add_codegate_sys_prompt and not wrksp_custom_instructions:
             return PipelineResult(request=request, context=context)
 
-        request_system_message = {}
-        req_sys_prompt = ""
-        for sysprompt in request.get_system_prompt():
-            req_sys_prompt = sysprompt
-
         system_prompt = await self._construct_system_prompt(
             context.client,
             wrksp_custom_instructions,
-            req_sys_prompt,
+            "",
             should_add_codegate_sys_prompt,
         )
         context.add_alert(self.name, trigger_string=system_prompt)
-        request.set_system_prompt(system_prompt)
+        # NOTE: this was changed from adding more text to an existing
+        # system prompt to potentially adding a new system prompt.
+        request.add_system_prompt(system_prompt)
 
         return PipelineResult(request=request, context=context)
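
The diff shows the call-site change but not how `add_system_prompt` differs from the old `set_system_prompt`. A minimal sketch of the distinction the NOTE describes, using a hypothetical stand-in request model (the `SystemMessage` class and `messages` field below are illustrative, not CodeGate's actual types):

```python
from typing import Iterator, List

import pydantic


class SystemMessage(pydantic.BaseModel):
    # Hypothetical stand-in for a typed system message.
    role: str = "system"
    content: str


class SketchRequest(pydantic.BaseModel):
    # Hypothetical stand-in for a chat request holding system messages.
    model: str
    messages: List[SystemMessage] = pydantic.Field(default_factory=list)

    def get_system_prompt(self) -> Iterator[str]:
        # Yield the text of every existing system message.
        for m in self.messages:
            yield m.content

    def set_system_prompt(self, text: str) -> None:
        # Old behavior: replace the system prompt wholesale, so callers
        # had to fold any existing prompt text into `text` themselves.
        self.messages = [SystemMessage(content=text)]

    def add_system_prompt(self, text: str) -> None:
        # New behavior: append an additional system message and leave
        # any existing system prompt untouched.
        self.messages.append(SystemMessage(content=text))
```

Under that reading, passing `""` instead of `req_sys_prompt` to `_construct_system_prompt` is consistent: the existing prompt no longer needs to be merged into the generated text, because it survives as its own message.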

src/codegate/types/ollama/_request_models.py (2 additions, 2 deletions)

@@ -120,7 +120,7 @@ class ChatRequest(pydantic.BaseModel):
     format: dict | None = None
     keep_alive: int | str | None = None
     tools: List[ToolDef] | None = None
-    options: dict
+    options: dict | None = None
 
     def get_stream(self) -> bool:
         return self.stream
@@ -192,7 +192,7 @@ class GenerateRequest(pydantic.BaseModel):
     format: dict | None = None
     keep_alive: int | str | None = None
     images: List[bytes] | None = None
-    options: dict
+    options: dict | None = None
 
     def get_stream(self) -> bool:
         return self.stream
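
The practical effect of `options: dict` versus `options: dict | None = None` is easy to check in isolation: a required `dict` field makes pydantic reject any request payload that omits `options`, while the optional form accepts it. A standalone demonstration of plain pydantic behavior (not CodeGate code):

```python
import pydantic


class Before(pydantic.BaseModel):
    options: dict  # required: validation fails if the field is absent


class After(pydantic.BaseModel):
    options: dict | None = None  # optional: an absent field becomes None


try:
    Before()  # simulate a request payload with no "options" key
except pydantic.ValidationError:
    print("Before: rejected, options is required")

print("After:", After().options)  # accepted; prints "After: None"
```

Since Ollama clients routinely omit `options`, the required declaration would fail validation before the pipeline ever saw the request, which is presumably the kind of breakage a commit titled "Fixing tests." addresses.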

tests/pipeline/extract_snippets/test_extract_snippets.py (0 additions, 187 deletions)

This file was deleted.

tests/pipeline/pii/test_pi.py (28 additions, 20 deletions)

@@ -1,12 +1,17 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
-from litellm import ChatCompletionRequest, ModelResponse
-from litellm.types.utils import Delta, StreamingChoices
 
 from codegate.pipeline.base import PipelineContext
 from codegate.pipeline.output import OutputPipelineContext
 from codegate.pipeline.pii.pii import CodegatePii, PiiRedactionNotifier, PiiUnRedactionStep
+from codegate.types.openai import (
+    ChatCompletionRequest,
+    ChoiceDelta,
+    MessageDelta,
+    StreamingChatCompletion,
+    UserMessage,
+)
 
 
 class TestCodegatePii:
@@ -43,7 +48,7 @@ def test_get_redacted_snippet_with_pii(self, pii_step):
 
     @pytest.mark.asyncio
     async def test_process_no_messages(self, pii_step):
-        request = ChatCompletionRequest(model="test-model")
+        request = ChatCompletionRequest(model="test-model", messages=[])
         context = PipelineContext()
 
         result = await pii_step.process(request, context)
@@ -55,7 +60,7 @@ async def test_process_no_messages(self, pii_step):
     async def test_process_with_pii(self, pii_step):
         original_text = "My email is [email protected]"
         request = ChatCompletionRequest(
-            model="test-model", messages=[{"role": "user", "content": original_text}]
+            model="test-model", messages=[UserMessage(role="user", content=original_text)]
         )
         context = PipelineContext()
 
@@ -77,9 +82,10 @@ async def test_process_with_pii(self, pii_step):
         result = await pii_step.process(request, context)
 
         # Verify the user message was anonymized
-        user_messages = [m for m in result.request["messages"] if m["role"] == "user"]
+        user_messages = [m for m in result.request.get_messages() if isinstance(m, UserMessage)]
         assert len(user_messages) == 1
-        assert user_messages[0]["content"] == anonymized_text
+        content = next(user_messages[0].get_content())
+        assert content.get_text() == anonymized_text
 
         # Verify metadata was updated
         assert result.context.metadata["redacted_pii_count"] == 1
@@ -89,9 +95,9 @@ async def test_process_with_pii(self, pii_step):
         assert "pii_manager" in result.context.metadata
 
         # Verify system message was added
-        system_messages = [m for m in result.request["messages"] if m["role"] == "system"]
+        system_messages = [m for m in result.request.get_system_prompt()]
         assert len(system_messages) == 1
-        assert system_messages[0]["content"] == "PII has been redacted"
+        assert system_messages[0] == "PII has been redacted"
 
     def test_restore_pii(self, pii_step):
         anonymized_text = "My email is <test-uuid>"
@@ -121,11 +127,11 @@ def test_is_complete_uuid_invalid(self, unredaction_step):
 
     @pytest.mark.asyncio
     async def test_process_chunk_no_content(self, unredaction_step):
-        chunk = ModelResponse(
+        chunk = StreamingChatCompletion(
             id="test",
             choices=[
-                StreamingChoices(
-                    finish_reason=None, index=0, delta=Delta(content=None), logprobs=None
+                ChoiceDelta(
+                    finish_reason=None, index=0, delta=MessageDelta(content=None), logprobs=None
                 )
             ],
             created=1234567890,
@@ -142,13 +148,13 @@ async def test_process_chunk_no_content(self, unredaction_step):
     @pytest.mark.asyncio
     async def test_process_chunk_with_uuid(self, unredaction_step):
         uuid = "12345678-1234-1234-1234-123456789012"
-        chunk = ModelResponse(
+        chunk = StreamingChatCompletion(
             id="test",
             choices=[
-                StreamingChoices(
+                ChoiceDelta(
                     finish_reason=None,
                     index=0,
-                    delta=Delta(content=f"Text with <{uuid}>"),
+                    delta=MessageDelta(content=f"Text with <{uuid}>"),
                     logprobs=None,
                 )
             ],
@@ -168,6 +174,7 @@ async def test_process_chunk_with_uuid(self, unredaction_step):
 
         result = await unredaction_step.process_chunk(chunk, context, input_context)
 
+        # TODO this should use the abstract interface
         assert result[0].choices[0].delta.content == "Text with [email protected]"
 
 
@@ -199,11 +206,11 @@ def test_format_pii_summary_multiple(self, notifier):
 
     @pytest.mark.asyncio
    async def test_process_chunk_no_pii(self, notifier):
-        chunk = ModelResponse(
+        chunk = StreamingChatCompletion(
             id="test",
             choices=[
-                StreamingChoices(
-                    finish_reason=None, index=0, delta=Delta(content="Hello"), logprobs=None
+                ChoiceDelta(
+                    finish_reason=None, index=0, delta=MessageDelta(content="Hello"), logprobs=None
                 )
             ],
             created=1234567890,
@@ -219,13 +226,13 @@ async def test_process_chunk_no_pii(self, notifier):
 
     @pytest.mark.asyncio
     async def test_process_chunk_with_pii(self, notifier):
-        chunk = ModelResponse(
+        chunk = StreamingChatCompletion(
             id="test",
             choices=[
-                StreamingChoices(
+                ChoiceDelta(
                     finish_reason=None,
                     index=0,
-                    delta=Delta(content="Hello", role="assistant"),
+                    delta=MessageDelta(content="Hello", role="assistant"),
                     logprobs=None,
                 )
             ],
@@ -244,6 +251,7 @@ async def test_process_chunk_with_pii(self, notifier):
         result = await notifier.process_chunk(chunk, context, input_context)
 
         assert len(result) == 2  # Notification chunk + original chunk
+        # TODO this should use the abstract interface
        notification_content = result[0].choices[0].delta.content
         assert "CodeGate protected" in notification_content
         assert "1 email address" in notification_content
