Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit e3aefe3

Browse files
jhrozekblkt
authored andcommitted
Amend the context replacer for the new data structures (#1081)
* Take the tool output into consideration when constructing the user block * Remove the not-needed OpenInterpreter special-case in the context retriever * Skip content that has no text in context retriever This is the case for Anthropic ToolUse messages. * Add a unit test for the context retriever * Re-enable the OpenInterpreter test case in test_messages_block.py
1 parent 4dc5566 commit e3aefe3

File tree

6 files changed

+435
-107
lines changed

6 files changed

+435
-107
lines changed

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,21 @@ class CodegateContextRetriever(PipelineStep):
2626
the word "codegate" in the user message.
2727
"""
2828

29+
def __init__(
30+
self,
31+
storage_engine: StorageEngine | None = None,
32+
package_extractor: PackageExtractor | None = None,
33+
):
34+
"""
35+
Initialize the CodegateContextRetriever with optional dependencies.
36+
37+
Args:
38+
storage_engine: Optional StorageEngine instance for package searching
39+
package_extractor: Optional PackageExtractor class for package extraction
40+
"""
41+
self.storage_engine = storage_engine or StorageEngine()
42+
self.package_extractor = package_extractor or PackageExtractor
43+
2944
@property
3045
def name(self) -> str:
3146
"""
@@ -67,9 +82,6 @@ async def process( # noqa: C901
6782
return PipelineResult(request=request)
6883
user_message, last_user_idx = last_message
6984

70-
# Create storage engine object
71-
storage_engine = StorageEngine()
72-
7385
# Extract any code snippets
7486
extractor = MessageCodeExtractorFactory.create_snippet_extractor(context.client)
7587
snippets = extractor.extract_snippets(user_message)
@@ -81,15 +93,15 @@ async def process( # noqa: C901
8193
snippet_packages = []
8294
for snippet in snippets:
8395
snippet_packages.extend(
84-
PackageExtractor.extract_packages(snippet.code, snippet.language) # type: ignore
96+
self.package_extractor.extract_packages(snippet.code, snippet.language) # type: ignore
8597
)
8698

8799
logger.info(
88100
f"Found {len(snippet_packages)} packages "
89101
f"for language {snippet_language} in code snippets."
90102
)
91103
# Find bad packages in the snippets
92-
bad_snippet_packages = await storage_engine.search(
104+
bad_snippet_packages = await self.storage_engine.search(
93105
language=snippet_language, packages=snippet_packages
94106
) # type: ignore
95107
logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.")
@@ -107,7 +119,11 @@ async def process( # noqa: C901
107119
collected_bad_packages = []
108120
for item_message in filter(None, map(str.strip, split_messages)):
109121
# Vector search to find bad packages
110-
bad_packages = await storage_engine.search(query=item_message, distance=0.5, limit=100)
122+
bad_packages = await self.storage_engine.search(
123+
query=item_message,
124+
distance=0.5,
125+
limit=100,
126+
)
111127
if bad_packages and len(bad_packages) > 0:
112128
collected_bad_packages.extend(bad_packages)
113129

@@ -130,42 +146,36 @@ async def process( # noqa: C901
130146
# perform replacement in all the messages starting from this index
131147
messages = request.get_messages()
132148
filtered = itertools.dropwhile(lambda x: x[0] < last_user_idx, enumerate(messages))
133-
if context.client != ClientType.OPEN_INTERPRETER:
134-
for i, message in filtered:
135-
message_str = "".join([
136-
txt
137-
for content in message.get_content()
138-
for txt in content.get_text()
139-
])
140-
context_msg = message_str
141-
# Add the context to the last user message
142-
if context.client in [ClientType.CLINE, ClientType.KODU]:
143-
match = re.search(r"<task>\s*(.*?)\s*</task>(.*)", message_str, re.DOTALL)
144-
if match:
145-
task_content = match.group(1) # Content within <task>...</task>
146-
rest_of_message = match.group(
147-
2
148-
).strip() # Content after </task>, if any
149-
150-
# Embed the context into the task block
151-
updated_task_content = (
152-
f"<task>Context: {context_str}"
153-
+ f"Query: {task_content.strip()}</task>"
154-
)
155-
156-
# Combine updated task content with the rest of the message
157-
context_msg = updated_task_content + rest_of_message
158-
else:
159-
context_msg = f"Context: {context_str} \n\n Query: {message_str}"
160-
content = next(message.get_content())
161-
content.set_text(context_msg)
162-
logger.debug("Final context message", context_message=context_msg)
163-
else:
164-
#  just add a message in the end
165-
new_request["messages"].append(
166-
{
167-
"content": context_str,
168-
"role": "assistant",
169-
}
170-
)
149+
for i, message in filtered:
150+
message_str = ""
151+
for content in message.get_content():
152+
txt = content.get_text()
153+
if not txt:
154+
logger.debug(f"content has no text: {content}")
155+
continue
156+
message_str += txt
157+
context_msg = message_str
158+
# Add the context to the last user message
159+
if context.client in [ClientType.CLINE, ClientType.KODU]:
160+
match = re.search(r"<task>\s*(.*?)\s*</task>(.*)", message_str, re.DOTALL)
161+
if match:
162+
task_content = match.group(1) # Content within <task>...</task>
163+
rest_of_message = match.group(
164+
2
165+
).strip() # Content after </task>, if any
166+
167+
# Embed the context into the task block
168+
updated_task_content = (
169+
f"<task>Context: {context_str}"
170+
+ f"Query: {task_content.strip()}</task>"
171+
)
172+
173+
# Combine updated task content with the rest of the message
174+
context_msg = updated_task_content + rest_of_message
175+
else:
176+
context_msg = f"Context: {context_str} \n\n Query: {message_str}"
177+
content = next(message.get_content())
178+
content.set_text(context_msg)
179+
logger.debug("Final context message", context_message=context_msg)
180+
171181
return PipelineResult(request=request, context=context)

src/codegate/types/anthropic/_request_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def get_text(self) -> Iterable[str]:
110110

111111
def set_text(self, text) -> None:
112112
if isinstance(self.content, str):
113-
self.content = txt
113+
self.content = text
114114
return
115115

116116
# should have been called on the content

src/codegate/types/ollama/_request_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ def last_user_message(self) -> tuple[Message, int] | None:
142142

143143
def last_user_block(self) -> Iterable[tuple[Message, int]]:
144144
for idx, msg in enumerate(reversed(self.messages)):
145-
if isinstance(msg, UserMessage):
146-
yield msg, len(self.messages) - 1 - idx
147-
elif isinstance(msg, (SystemMessage, ToolMessage)):
145+
if isinstance(msg, (UserMessage, ToolMessage)):
146+
yield msg, len(self.messages) - 1 - idx
147+
elif isinstance(msg, SystemMessage):
148148
# these can occur in the middle of a user block
149149
continue
150150
elif isinstance(msg, AssistantMessage):

src/codegate/types/openai/_request_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,9 @@ def last_user_message(self) -> tuple[Message, int] | None:
339339

340340
def last_user_block(self) -> Iterable[tuple[Message, int]]:
341341
for idx, msg in enumerate(reversed(self.messages)):
342-
if isinstance(msg, UserMessage):
342+
if isinstance(msg, (UserMessage, ToolMessage)):
343343
yield msg, len(self.messages) - 1 - idx
344-
elif isinstance(msg, (SystemMessage, DeveloperMessage, ToolMessage)):
344+
elif isinstance(msg, (SystemMessage, DeveloperMessage)):
345345
# these can occur in the middle of a user block
346346
continue
347347
elif isinstance(msg, (AssistantMessage, FunctionMessage)):

0 commit comments

Comments
 (0)