Skip to content

Commit dde0570

Browse files
authored
feat: introduce per client prompts and use the one in kodu (#963)
* feat: introduce per client prompts and use the one in kodu Using the client detection functionality, expose the ability to send customized system prompts per client, that will add more control to the result we provide to our tools * fix system prompt
1 parent 53f33a2 commit dde0570

File tree

5 files changed

+36
-12
lines changed

5 files changed

+36
-12
lines changed

prompts/default.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,7 @@ red_team: "You are a red team member conducting a security assessment. Identify
5050

5151
# BlueTeam prompts
5252
blue_team: "You are a blue team member conducting a security assessment. Identify security controls, misconfigurations, and potential vulnerabilities."
53+
54+
# Per client prompts
55+
client_prompts:
56+
kodu: "If malicious packages or leaked secrets are found, please end the task, sending the problems found embedded in <attempt_completion><result> tags"

src/codegate/pipeline/factory.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr
2929
CodegateSecrets(),
3030
CodegateCli(),
3131
CodegateContextRetriever(),
32-
SystemPrompt(Config.get_config().prompts.default_chat),
32+
SystemPrompt(
33+
Config.get_config().prompts.default_chat, Config.get_config().prompts.client_prompts
34+
),
3335
]
3436
return SequentialPipelineProcessor(
3537
input_steps,

src/codegate/pipeline/system_prompt/codegate.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional
22

3+
from codegate.clients.clients import ClientType
34
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage
45

56
from codegate.pipeline.base import (
@@ -16,8 +17,9 @@ class SystemPrompt(PipelineStep):
1617
the word "codegate" in the user message.
1718
"""
1819

19-
def __init__(self, system_prompt: str):
20+
def __init__(self, system_prompt: str, client_prompts: dict[str]):
2021
self.codegate_system_prompt = system_prompt
22+
self.client_prompts = client_prompts
2123

2224
@property
2325
def name(self) -> str:
@@ -36,6 +38,7 @@ async def _get_workspace_custom_instructions(self) -> str:
3638

3739
async def _construct_system_prompt(
3840
self,
41+
client: ClientType,
3942
wrksp_custom_instr: str,
4043
req_sys_prompt: Optional[str],
4144
should_add_codegate_sys_prompt: bool,
@@ -59,6 +62,10 @@ def _start_or_append(existing_prompt: str, new_prompt: str) -> str:
5962
if req_sys_prompt and "codegate" not in req_sys_prompt.lower():
6063
system_prompt = _start_or_append(system_prompt, req_sys_prompt)
6164

65+
# Add per client system prompt
66+
if client and client.value in self.client_prompts:
67+
system_prompt = _start_or_append(system_prompt, self.client_prompts[client.value])
68+
6269
return system_prompt
6370

6471
async def _should_add_codegate_system_prompt(self, context: PipelineContext) -> bool:
@@ -92,7 +99,10 @@ async def process(
9299
req_sys_prompt = request_system_message.get("content")
93100

94101
system_prompt = await self._construct_system_prompt(
95-
wrksp_custom_instructions, req_sys_prompt, should_add_codegate_sys_prompt
102+
context.client,
103+
wrksp_custom_instructions,
104+
req_sys_prompt,
105+
should_add_codegate_sys_prompt,
96106
)
97107
context.add_alert(self.name, trigger_string=system_prompt)
98108
if not request_system_message:

src/codegate/prompts.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,19 @@ def from_file(cls, prompt_path: Union[str, Path]) -> "PromptConfig":
4141
if not isinstance(prompt_data, dict):
4242
raise ConfigurationError("Prompts file must contain a YAML dictionary")
4343

44-
# Validate all values are strings
45-
for key, value in prompt_data.items():
46-
if not isinstance(value, str):
47-
raise ConfigurationError(f"Prompt '{key}' must be a string, got {type(value)}")
48-
44+
def validate_prompts(data, parent_key=""):
45+
"""Recursively validate prompt values."""
46+
for key, value in data.items():
47+
full_key = f"{parent_key}.{key}" if parent_key else key
48+
if isinstance(value, dict):
49+
validate_prompts(value, full_key) # Recurse into nested dictionaries
50+
elif not isinstance(value, str):
51+
raise ConfigurationError(
52+
f"Prompt '{full_key}' must be a string, got {type(value)}"
53+
)
54+
55+
# Validate the entire structure
56+
validate_prompts(prompt_data)
4957
return cls(prompts=prompt_data)
5058
except yaml.YAMLError as e:
5159
raise ConfigurationError(f"Failed to parse prompts file: {e}")

tests/pipeline/system_prompt/test_system_prompt.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_init_with_system_message(self):
1313
Test initialization with a system message
1414
"""
1515
test_message = "Test system prompt"
16-
step = SystemPrompt(system_prompt=test_message)
16+
step = SystemPrompt(system_prompt=test_message, client_prompts={})
1717
assert step.codegate_system_prompt == test_message
1818

1919
@pytest.mark.asyncio
@@ -28,7 +28,7 @@ async def test_process_system_prompt_insertion(self):
2828

2929
# Create system prompt step
3030
system_prompt = "Security analysis system prompt"
31-
step = SystemPrompt(system_prompt=system_prompt)
31+
step = SystemPrompt(system_prompt=system_prompt, client_prompts={})
3232
step._get_workspace_custom_instructions = AsyncMock(return_value="")
3333

3434
# Mock the get_last_user_message method
@@ -62,7 +62,7 @@ async def test_process_system_prompt_update(self):
6262

6363
# Create system prompt step
6464
system_prompt = "Security analysis system prompt"
65-
step = SystemPrompt(system_prompt=system_prompt)
65+
step = SystemPrompt(system_prompt=system_prompt, client_prompts={})
6666
step._get_workspace_custom_instructions = AsyncMock(return_value="")
6767

6868
# Mock the get_last_user_message method
@@ -97,7 +97,7 @@ async def test_edge_cases(self, edge_case):
9797
mock_context = Mock(spec=PipelineContext)
9898

9999
system_prompt = "Security edge case prompt"
100-
step = SystemPrompt(system_prompt=system_prompt)
100+
step = SystemPrompt(system_prompt=system_prompt, client_prompts={})
101101
step._get_workspace_custom_instructions = AsyncMock(return_value="")
102102

103103
# Mock get_last_user_message to return None

0 commit comments

Comments
 (0)