1
1
import itertools
2
2
from abc import abstractmethod
3
- from typing import List , Optional , Tuple
3
+ from typing import Any , List , Optional , Tuple
4
4
5
5
import regex as re
6
6
19
19
from codegate .pipeline .output import OutputPipelineContext , OutputPipelineStep
20
20
from codegate .pipeline .secrets .manager import SecretsManager
21
21
from codegate .pipeline .secrets .signatures import CodegateSignatures , Match
22
- from codegate .pipeline .systemmsg import add_or_update_system_message
23
- from codegate .types .common import (
24
- ChatCompletionRequest ,
25
- ChatCompletionSystemMessage ,
26
- Delta ,
27
- ModelResponse ,
28
- StreamingChoices ,
29
- )
30
22
31
23
32
24
logger = structlog .get_logger ("codegate" )
@@ -280,7 +272,7 @@ def _redact_text(
280
272
return text_encryptor .obfuscate (text , snippet )
281
273
282
274
async def process (
283
- self , request : ChatCompletionRequest , context : PipelineContext
275
+ self , reques : Any , context : PipelineContext
284
276
) -> PipelineResult :
285
277
"""
286
278
Process the request to find and protect secrets in all messages.
@@ -293,68 +285,34 @@ async def process(
293
285
PipelineResult containing the processed request and context with redaction metadata
294
286
"""
295
287
296
- ##### NEW CODE PATH #####
297
-
298
- if type (request ) != ChatCompletionRequest :
299
- secrets_manager = context .sensitive .manager
300
- if not secrets_manager or not isinstance (secrets_manager , SecretsManager ):
301
- raise ValueError ("Secrets manager not found in context" )
302
- session_id = context .sensitive .session_id
303
- if not session_id :
304
- raise ValueError ("Session ID not found in context" )
305
-
306
- total_matches = []
307
-
308
- # get last user message block to get index for the first relevant user message
309
- last_user_message = self .get_last_user_message_block (request , context .client )
310
- last_assistant_idx = last_user_message [1 ] - 1 if last_user_message else - 1
311
-
312
- # Process all messages
313
- for i , message in enumerate (request .get_messages ()):
314
- for content in message .get_content ():
315
- txt = content .get_text ()
316
- if txt is not None :
317
- redacted_content , secrets_matched = self ._redact_message_content (
318
- "" .join (txt for txt in content .get_text ()), secrets_manager , session_id , context
319
- )
320
- content .set_text (redacted_content )
321
- if i > last_assistant_idx :
322
- total_matches += secrets_matched
323
-
324
- # Not count repeated secret matches
325
- request = self ._finalize_redaction (context , total_matches , request )
326
- return PipelineResult (request = request , context = context )
327
-
328
- ##### OLD CODE PATH #####
329
-
330
- if "messages" not in request :
331
- return PipelineResult (request = request , context = context )
332
-
333
288
secrets_manager = context .sensitive .manager
334
289
if not secrets_manager or not isinstance (secrets_manager , SecretsManager ):
335
290
raise ValueError ("Secrets manager not found in context" )
336
291
session_id = context .sensitive .session_id
337
292
if not session_id :
338
293
raise ValueError ("Session ID not found in context" )
339
294
340
- new_request = request .copy ()
341
295
total_matches = []
342
296
343
297
# get last user message block to get index for the first relevant user message
344
- last_user_message = self .get_last_user_message_block (new_request , context .client )
298
+ last_user_message = self .get_last_user_message_block (request , context .client )
345
299
last_assistant_idx = last_user_message [1 ] - 1 if last_user_message else - 1
346
300
347
301
# Process all messages
348
- for i , message in enumerate (new_request ["messages" ]):
349
- if "content" in message and message ["content" ]:
350
- redacted_content , secrets_matched = self ._redact_message_content (
351
- message ["content" ], secrets_manager , session_id , context
352
- )
353
- new_request ["messages" ][i ]["content" ] = redacted_content
354
- if i > last_assistant_idx :
355
- total_matches += secrets_matched
356
- new_request = self ._finalize_redaction (context , total_matches , new_request )
357
- return PipelineResult (request = new_request , context = context )
302
+ for i , message in enumerate (request .get_messages ()):
303
+ for content in message .get_content ():
304
+ txt = content .get_text ()
305
+ if txt is not None :
306
+ redacted_content , secrets_matched = self ._redact_message_content (
307
+ "" .join (txt for txt in content .get_text ()), secrets_manager , session_id , context
308
+ )
309
+ content .set_text (redacted_content )
310
+ if i > last_assistant_idx :
311
+ total_matches += secrets_matched
312
+
313
+ # Not count repeated secret matches
314
+ request = self ._finalize_redaction (context , total_matches , request )
315
+ return PipelineResult (request = request , context = context )
358
316
359
317
def _redact_message_content (self , message_content , secrets_manager , session_id , context ):
360
318
# Extract any code snippets
@@ -404,14 +362,7 @@ def _finalize_redaction(self, context, total_matches, new_request):
404
362
logger .info (f"Total secrets redacted since last assistant message: { total_redacted } " )
405
363
context .metadata ["redacted_secrets_count" ] = total_redacted
406
364
if total_redacted > 0 :
407
- if isinstance (new_request , pydantic .BaseModel ):
408
- new_request .add_system_prompt (Config .get_config ().prompts .secrets_redacted )
409
- return new_request
410
- system_message = ChatCompletionSystemMessage (
411
- content = Config .get_config ().prompts .secrets_redacted ,
412
- role = "system" ,
413
- )
414
- return add_or_update_system_message (new_request , system_message , context )
365
+ new_request .add_system_prompt (Config .get_config ().prompts .secrets_redacted )
415
366
return new_request
416
367
417
368
@@ -449,10 +400,10 @@ def _find_complete_redaction(self, text: str) -> tuple[Optional[re.Match[str]],
449
400
450
401
async def process_chunk (
451
402
self ,
452
- chunk : ModelResponse ,
403
+ chunk : Any ,
453
404
context : OutputPipelineContext ,
454
405
input_context : Optional [PipelineContext ] = None ,
455
- ) -> list [ModelResponse ]:
406
+ ) -> list [Any ]:
456
407
"""Process a single chunk of the stream"""
457
408
if not input_context :
458
409
raise ValueError ("Input context not found" )
@@ -461,9 +412,6 @@ async def process_chunk(
461
412
if input_context .sensitive .session_id == "" :
462
413
raise ValueError ("Session ID not found in input context" )
463
414
464
- # if len(chunk.choices) == 0 or not chunk.choices[0].delta.content:
465
- # return [chunk]
466
-
467
415
for content in chunk .get_content ():
468
416
# Check the buffered content
469
417
buffered_content = "" .join (context .buffer )
@@ -518,37 +466,20 @@ class SecretRedactionNotifier(OutputPipelineStep):
518
466
def name (self ) -> str :
519
467
return "secret-redaction-notifier"
520
468
521
- def _create_chunk (self , original_chunk : ModelResponse , content : str ) -> ModelResponse :
469
+ def _create_chunk (self , original_chunk : Any , content : str ) -> Any :
522
470
"""
523
471
Creates a new chunk with the given content, preserving the original chunk's metadata
524
472
"""
525
- if isinstance (original_chunk , ModelResponse ):
526
- return ModelResponse (
527
- id = original_chunk .id ,
528
- choices = [
529
- StreamingChoices (
530
- finish_reason = None ,
531
- index = 0 ,
532
- delta = Delta (content = content , role = "assistant" ),
533
- logprobs = None ,
534
- )
535
- ],
536
- created = original_chunk .created ,
537
- model = original_chunk .model ,
538
- object = "chat.completion.chunk" ,
539
- )
540
- else :
541
- # TODO verify if deep-copy is necessary
542
- copy = original_chunk .model_copy (deep = True )
543
- copy .set_text (content )
544
- return copy
473
+ copy = original_chunk .model_copy (deep = True )
474
+ copy .set_text (content )
475
+ return copy
545
476
546
477
async def process_chunk (
547
478
self ,
548
- chunk : ModelResponse ,
479
+ chunk : Any ,
549
480
context : OutputPipelineContext ,
550
481
input_context : Optional [PipelineContext ] = None ,
551
- ) -> list [ModelResponse ]:
482
+ ) -> list [Any ]:
552
483
"""Process a single chunk of the stream"""
553
484
if (
554
485
not input_context
@@ -568,20 +499,21 @@ async def process_chunk(
568
499
)
569
500
570
501
# Check if this is the first chunk (delta role will be present, others will not)
571
- # if len(chunk.choices) > 0 and chunk.choices[0].delta.role:
572
502
for _ in itertools .takewhile (lambda x : x [0 ] == 1 , enumerate (chunk .get_content ())):
573
503
redacted_count = input_context .metadata ["redacted_secrets_count" ]
574
504
secret_text = "secret" if redacted_count == 1 else "secrets"
575
505
# Create notification chunk
576
506
if tool_name in ["cline" , "kodu" ]:
507
+ # NOTE: Original code was ensuring that role was
508
+ # "assistant" here, we might have to do that as well,
509
+ # but I believe it was defensive programming or
510
+ # leftover of some refactoring.
577
511
notification_chunk = self ._create_chunk (
578
512
chunk ,
579
513
f"<thinking>\n 🛡️ [CodeGate prevented { redacted_count } { secret_text } ]"
580
514
f"(http://localhost:9090/?search=codegate-secrets) from being leaked "
581
515
f"by redacting them.</thinking>\n \n " ,
582
516
)
583
- # TODO fix this
584
- # notification_chunk.choices[0].delta.role = "assistant"
585
517
else :
586
518
notification_chunk = self ._create_chunk (
587
519
chunk ,
0 commit comments