@@ -1,12 +1,17 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
-from litellm import ChatCompletionRequest, ModelResponse
-from litellm.types.utils import Delta, StreamingChoices
 
 from codegate.pipeline.base import PipelineContext
 from codegate.pipeline.output import OutputPipelineContext
 from codegate.pipeline.pii.pii import CodegatePii, PiiRedactionNotifier, PiiUnRedactionStep
+from codegate.types.openai import (
+    ChatCompletionRequest,
+    ChoiceDelta,
+    MessageDelta,
+    StreamingChatCompletion,
+    UserMessage,
+)
 
 
 class TestCodegatePii:
@@ -43,7 +48,7 @@ def test_get_redacted_snippet_with_pii(self, pii_step):
 
     @pytest.mark.asyncio
     async def test_process_no_messages(self, pii_step):
-        request = ChatCompletionRequest(model="test-model")
+        request = ChatCompletionRequest(model="test-model", messages=[])
         context = PipelineContext()
 
         result = await pii_step.process(request, context)
@@ -55,7 +60,7 @@ async def test_process_no_messages(self, pii_step):
     async def test_process_with_pii(self, pii_step):
         original_text = "My email is [email protected]"
         request = ChatCompletionRequest(
-            model="test-model", messages=[{"role": "user", "content": original_text}]
+            model="test-model", messages=[UserMessage(role="user", content=original_text)]
         )
         context = PipelineContext()
 
@@ -77,9 +82,10 @@ async def test_process_with_pii(self, pii_step):
         result = await pii_step.process(request, context)
 
         # Verify the user message was anonymized
-        user_messages = [m for m in result.request["messages"] if m["role"] == "user"]
+        user_messages = [m for m in result.request.get_messages() if isinstance(m, UserMessage)]
         assert len(user_messages) == 1
-        assert user_messages[0]["content"] == anonymized_text
+        content = next(user_messages[0].get_content())
+        assert content.get_text() == anonymized_text
 
         # Verify metadata was updated
         assert result.context.metadata["redacted_pii_count"] == 1
@@ -89,9 +95,9 @@ async def test_process_with_pii(self, pii_step):
         assert "pii_manager" in result.context.metadata
 
         # Verify system message was added
-        system_messages = [m for m in result.request["messages"] if m["role"] == "system"]
+        system_messages = [m for m in result.request.get_system_prompt()]
         assert len(system_messages) == 1
-        assert system_messages[0]["content"] == "PII has been redacted"
+        assert system_messages[0] == "PII has been redacted"
 
     def test_restore_pii(self, pii_step):
         anonymized_text = "My email is <test-uuid>"
@@ -121,11 +127,11 @@ def test_is_complete_uuid_invalid(self, unredaction_step):
 
     @pytest.mark.asyncio
     async def test_process_chunk_no_content(self, unredaction_step):
-        chunk = ModelResponse(
+        chunk = StreamingChatCompletion(
             id="test",
             choices=[
-                StreamingChoices(
-                    finish_reason=None, index=0, delta=Delta(content=None), logprobs=None
+                ChoiceDelta(
+                    finish_reason=None, index=0, delta=MessageDelta(content=None), logprobs=None
                 )
             ],
             created=1234567890,
@@ -142,13 +148,13 @@ async def test_process_chunk_no_content(self, unredaction_step):
     @pytest.mark.asyncio
     async def test_process_chunk_with_uuid(self, unredaction_step):
         uuid = "12345678-1234-1234-1234-123456789012"
-        chunk = ModelResponse(
+        chunk = StreamingChatCompletion(
             id="test",
             choices=[
-                StreamingChoices(
+                ChoiceDelta(
                     finish_reason=None,
                     index=0,
-                    delta=Delta(content=f"Text with <{uuid}>"),
+                    delta=MessageDelta(content=f"Text with <{uuid}>"),
                     logprobs=None,
                 )
             ],
@@ -168,6 +174,7 @@ async def test_process_chunk_with_uuid(self, unredaction_step):
 
         result = await unredaction_step.process_chunk(chunk, context, input_context)
 
+        # TODO this should use the abstract interface
         assert result[0].choices[0].delta.content == "Text with [email protected]"
 
 
@@ -199,11 +206,11 @@ def test_format_pii_summary_multiple(self, notifier):
 
     @pytest.mark.asyncio
     async def test_process_chunk_no_pii(self, notifier):
-        chunk = ModelResponse(
+        chunk = StreamingChatCompletion(
             id="test",
             choices=[
-                StreamingChoices(
-                    finish_reason=None, index=0, delta=Delta(content="Hello"), logprobs=None
+                ChoiceDelta(
+                    finish_reason=None, index=0, delta=MessageDelta(content="Hello"), logprobs=None
                 )
             ],
             created=1234567890,
@@ -219,13 +226,13 @@ async def test_process_chunk_no_pii(self, notifier):
 
     @pytest.mark.asyncio
     async def test_process_chunk_with_pii(self, notifier):
-        chunk = ModelResponse(
+        chunk = StreamingChatCompletion(
             id="test",
             choices=[
-                StreamingChoices(
+                ChoiceDelta(
                     finish_reason=None,
                     index=0,
-                    delta=Delta(content="Hello", role="assistant"),
+                    delta=MessageDelta(content="Hello", role="assistant"),
                     logprobs=None,
                 )
             ],
@@ -244,6 +251,7 @@ async def test_process_chunk_with_pii(self, notifier):
         result = await notifier.process_chunk(chunk, context, input_context)
 
         assert len(result) == 2  # Notification chunk + original chunk
+        # TODO this should use the abstract interface
         notification_content = result[0].choices[0].delta.content
         assert "CodeGate protected" in notification_content
         assert "1 email address" in notification_content