Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 0c4497e

Browse files
jhrozekblkt
authored andcommitted
Amend the context replacer for the new data structures (#1081)
* Take the tool output into consideration when constructing the user block * Remove the not-needed OpenInterpreter special-case in the context retriever * Skip content that has no text in context retriever This is the case for Anthropic ToolUse messages. * Add a unit test for the context retriever * Re-enable the OpenInterpreter test case in test_messages_block.py
1 parent 7ced113 commit 0c4497e

File tree

6 files changed

+435
-107
lines changed

6 files changed

+435
-107
lines changed

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ class CodegateContextRetriever(PipelineStep):
3232
the word "codegate" in the user message.
3333
"""
3434

35+
def __init__(
36+
self,
37+
storage_engine: StorageEngine | None = None,
38+
package_extractor: PackageExtractor | None = None,
39+
):
40+
"""
41+
Initialize the CodegateContextRetriever with optional dependencies.
42+
43+
Args:
44+
storage_engine: Optional StorageEngine instance for package searching
45+
package_extractor: Optional PackageExtractor class for package extraction
46+
"""
47+
self.storage_engine = storage_engine or StorageEngine()
48+
self.package_extractor = package_extractor or PackageExtractor
49+
3550
@property
3651
def name(self) -> str:
3752
"""
@@ -73,9 +88,6 @@ async def process( # noqa: C901
7388
return PipelineResult(request=request)
7489
user_message, last_user_idx = last_message
7590

76-
# Create storage engine object
77-
storage_engine = StorageEngine()
78-
7991
# Extract any code snippets
8092
extractor = MessageCodeExtractorFactory.create_snippet_extractor(context.client)
8193
snippets = extractor.extract_snippets(user_message)
@@ -87,15 +99,15 @@ async def process( # noqa: C901
8799
snippet_packages = []
88100
for snippet in snippets:
89101
snippet_packages.extend(
90-
PackageExtractor.extract_packages(snippet.code, snippet.language) # type: ignore
102+
self.package_extractor.extract_packages(snippet.code, snippet.language) # type: ignore
91103
)
92104

93105
logger.info(
94106
f"Found {len(snippet_packages)} packages "
95107
f"for language {snippet_language} in code snippets."
96108
)
97109
# Find bad packages in the snippets
98-
bad_snippet_packages = await storage_engine.search(
110+
bad_snippet_packages = await self.storage_engine.search(
99111
language=snippet_language, packages=snippet_packages
100112
) # type: ignore
101113
logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.")
@@ -111,7 +123,11 @@ async def process( # noqa: C901
111123
collected_bad_packages = []
112124
for item_message in filter(None, map(str.strip, split_messages)):
113125
# Vector search to find bad packages
114-
bad_packages = await storage_engine.search(query=item_message, distance=0.5, limit=100)
126+
bad_packages = await self.storage_engine.search(
127+
query=item_message,
128+
distance=0.5,
129+
limit=100,
130+
)
115131
if bad_packages and len(bad_packages) > 0:
116132
collected_bad_packages.extend(bad_packages)
117133

@@ -134,42 +150,36 @@ async def process( # noqa: C901
134150
# perform replacement in all the messages starting from this index
135151
messages = request.get_messages()
136152
filtered = itertools.dropwhile(lambda x: x[0] < last_user_idx, enumerate(messages))
137-
if context.client != ClientType.OPEN_INTERPRETER:
138-
for i, message in filtered:
139-
message_str = "".join([
140-
txt
141-
for content in message.get_content()
142-
for txt in content.get_text()
143-
])
144-
context_msg = message_str
145-
# Add the context to the last user message
146-
if context.client in [ClientType.CLINE, ClientType.KODU]:
147-
match = re.search(r"<task>\s*(.*?)\s*</task>(.*)", message_str, re.DOTALL)
148-
if match:
149-
task_content = match.group(1) # Content within <task>...</task>
150-
rest_of_message = match.group(
151-
2
152-
).strip() # Content after </task>, if any
153-
154-
# Embed the context into the task block
155-
updated_task_content = (
156-
f"<task>Context: {context_str}"
157-
+ f"Query: {task_content.strip()}</task>"
158-
)
159-
160-
# Combine updated task content with the rest of the message
161-
context_msg = updated_task_content + rest_of_message
162-
else:
163-
context_msg = f"Context: {context_str} \n\n Query: {message_str}"
164-
content = next(message.get_content())
165-
content.set_text(context_msg)
166-
logger.debug("Final context message", context_message=context_msg)
167-
else:
168-
#  just add a message in the end
169-
new_request["messages"].append(
170-
{
171-
"content": context_str,
172-
"role": "assistant",
173-
}
174-
)
153+
for i, message in filtered:
154+
message_str = ""
155+
for content in message.get_content():
156+
txt = content.get_text()
157+
if not txt:
158+
logger.debug(f"content has no text: {content}")
159+
continue
160+
message_str += txt
161+
context_msg = message_str
162+
# Add the context to the last user message
163+
if context.client in [ClientType.CLINE, ClientType.KODU]:
164+
match = re.search(r"<task>\s*(.*?)\s*</task>(.*)", message_str, re.DOTALL)
165+
if match:
166+
task_content = match.group(1) # Content within <task>...</task>
167+
rest_of_message = match.group(
168+
2
169+
).strip() # Content after </task>, if any
170+
171+
# Embed the context into the task block
172+
updated_task_content = (
173+
f"<task>Context: {context_str}"
174+
+ f"Query: {task_content.strip()}</task>"
175+
)
176+
177+
# Combine updated task content with the rest of the message
178+
context_msg = updated_task_content + rest_of_message
179+
else:
180+
context_msg = f"Context: {context_str} \n\n Query: {message_str}"
181+
content = next(message.get_content())
182+
content.set_text(context_msg)
183+
logger.debug("Final context message", context_message=context_msg)
184+
175185
return PipelineResult(request=request, context=context)

src/codegate/types/anthropic/_request_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def get_text(self) -> Iterable[str]:
110110

111111
def set_text(self, text) -> None:
112112
if isinstance(self.content, str):
113-
self.content = txt
113+
self.content = text
114114
return
115115

116116
# should have been called on the content

src/codegate/types/ollama/_request_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ def last_user_message(self) -> tuple[Message, int] | None:
142142

143143
def last_user_block(self) -> Iterable[tuple[Message, int]]:
144144
for idx, msg in enumerate(reversed(self.messages)):
145-
if isinstance(msg, UserMessage):
146-
yield msg, len(self.messages) - 1 - idx
147-
elif isinstance(msg, (SystemMessage, ToolMessage)):
145+
if isinstance(msg, (UserMessage, ToolMessage)):
146+
yield msg, len(self.messages) - 1 - idx
147+
elif isinstance(msg, SystemMessage):
148148
# these can occur in the middle of a user block
149149
continue
150150
elif isinstance(msg, AssistantMessage):

src/codegate/types/openai/_request_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,9 @@ def last_user_message(self) -> tuple[Message, int] | None:
339339

340340
def last_user_block(self) -> Iterable[tuple[Message, int]]:
341341
for idx, msg in enumerate(reversed(self.messages)):
342-
if isinstance(msg, UserMessage):
342+
if isinstance(msg, (UserMessage, ToolMessage)):
343343
yield msg, len(self.messages) - 1 - idx
344-
elif isinstance(msg, (SystemMessage, DeveloperMessage, ToolMessage)):
344+
elif isinstance(msg, (SystemMessage, DeveloperMessage)):
345345
# these can occur in the middle of a user block
346346
continue
347347
elif isinstance(msg, (AssistantMessage, FunctionMessage)):

0 commit comments

Comments
 (0)