Skip to content

Commit 2c858b1

Browse files
jhrozekblkt
authored and committed
Amend the context replacer for the new data structures (#1081)
* Take the tool output into consideration when constructing the user block
* Remove the not-needed OpenInterpreter special-case in the context retriever
* Skip content that has no text in the context retriever. This is the case for Anthropic ToolUse messages.
* Add a unit test for the context retriever
* Re-enable the OpenInterpreter test case in test_messages_block.py
1 parent 05f593d commit 2c858b1

File tree

6 files changed

+434
-106
lines changed

6 files changed

+434
-106
lines changed

src/codegate/pipeline/codegate_context_retriever/codegate.py

+53-43
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ class CodegateContextRetriever(PipelineStep):
3232
the word "codegate" in the user message.
3333
"""
3434

35+
def __init__(
36+
self,
37+
storage_engine: StorageEngine | None = None,
38+
package_extractor: PackageExtractor | None = None,
39+
):
40+
"""
41+
Initialize the CodegateContextRetriever with optional dependencies.
42+
43+
Args:
44+
storage_engine: Optional StorageEngine instance for package searching
45+
package_extractor: Optional PackageExtractor class for package extraction
46+
"""
47+
self.storage_engine = storage_engine or StorageEngine()
48+
self.package_extractor = package_extractor or PackageExtractor
49+
3550
@property
3651
def name(self) -> str:
3752
"""
@@ -80,9 +95,6 @@ async def process( # noqa: C901
8095
return PipelineResult(request=request)
8196
user_message, last_user_idx = last_message
8297

83-
# Create storage engine object
84-
storage_engine = StorageEngine()
85-
8698
# Extract any code snippets
8799
extractor = MessageCodeExtractorFactory.create_snippet_extractor(context.client)
88100
snippets = extractor.extract_snippets(user_message)
@@ -106,7 +118,7 @@ async def process( # noqa: C901
106118
f"for language {snippet_language} in code snippets."
107119
)
108120
# Find bad packages in the snippets
109-
bad_snippet_packages = await storage_engine.search(
121+
bad_snippet_packages = await self.storage_engine.search(
110122
language=snippet_language, packages=snippet_packages
111123
) # type: ignore
112124
logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.")
@@ -122,7 +134,11 @@ async def process( # noqa: C901
122134
collected_bad_packages = []
123135
for item_message in filter(None, map(str.strip, split_messages)):
124136
# Vector search to find bad packages
125-
bad_packages = await storage_engine.search(query=item_message, distance=0.5, limit=100)
137+
bad_packages = await self.storage_engine.search(
138+
query=item_message,
139+
distance=0.5,
140+
limit=100,
141+
)
126142
if bad_packages and len(bad_packages) > 0:
127143
collected_bad_packages.extend(bad_packages)
128144

@@ -145,42 +161,36 @@ async def process( # noqa: C901
145161
# perform replacement in all the messages starting from this index
146162
messages = request.get_messages()
147163
filtered = itertools.dropwhile(lambda x: x[0] < last_user_idx, enumerate(messages))
148-
if context.client != ClientType.OPEN_INTERPRETER:
149-
for i, message in filtered:
150-
message_str = "".join([
151-
txt
152-
for content in message.get_content()
153-
for txt in content.get_text()
154-
])
155-
context_msg = message_str
156-
# Add the context to the last user message
157-
if context.client in [ClientType.CLINE, ClientType.KODU]:
158-
match = re.search(r"<task>\s*(.*?)\s*</task>(.*)", message_str, re.DOTALL)
159-
if match:
160-
task_content = match.group(1) # Content within <task>...</task>
161-
rest_of_message = match.group(
162-
2
163-
).strip() # Content after </task>, if any
164-
165-
# Embed the context into the task block
166-
updated_task_content = (
167-
f"<task>Context: {context_str}"
168-
+ f"Query: {task_content.strip()}</task>"
169-
)
170-
171-
# Combine updated task content with the rest of the message
172-
context_msg = updated_task_content + rest_of_message
173-
else:
174-
context_msg = f"Context: {context_str} \n\n Query: {message_str}"
175-
content = next(message.get_content())
176-
content.set_text(context_msg)
177-
logger.debug("Final context message", context_message=context_msg)
178-
else:
179-
#  just add a message in the end
180-
new_request["messages"].append(
181-
{
182-
"content": context_str,
183-
"role": "assistant",
184-
}
185-
)
164+
for i, message in filtered:
165+
message_str = ""
166+
for content in message.get_content():
167+
txt = content.get_text()
168+
if not txt:
169+
logger.debug(f"content has no text: {content}")
170+
continue
171+
message_str += txt
172+
context_msg = message_str
173+
# Add the context to the last user message
174+
if context.client in [ClientType.CLINE, ClientType.KODU]:
175+
match = re.search(r"<task>\s*(.*?)\s*</task>(.*)", message_str, re.DOTALL)
176+
if match:
177+
task_content = match.group(1) # Content within <task>...</task>
178+
rest_of_message = match.group(
179+
2
180+
).strip() # Content after </task>, if any
181+
182+
# Embed the context into the task block
183+
updated_task_content = (
184+
f"<task>Context: {context_str}"
185+
+ f"Query: {task_content.strip()}</task>"
186+
)
187+
188+
# Combine updated task content with the rest of the message
189+
context_msg = updated_task_content + rest_of_message
190+
else:
191+
context_msg = f"Context: {context_str} \n\n Query: {message_str}"
192+
content = next(message.get_content())
193+
content.set_text(context_msg)
194+
logger.debug("Final context message", context_message=context_msg)
195+
186196
return PipelineResult(request=request, context=context)

src/codegate/types/anthropic/_request_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def get_text(self) -> Iterable[str]:
110110

111111
def set_text(self, text) -> None:
112112
if isinstance(self.content, str):
113-
self.content = txt
113+
self.content = text
114114
return
115115

116116
# should have been called on the content

src/codegate/types/ollama/_request_models.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ def last_user_message(self) -> tuple[Message, int] | None:
142142

143143
def last_user_block(self) -> Iterable[tuple[Message, int]]:
144144
for idx, msg in enumerate(reversed(self.messages)):
145-
if isinstance(msg, UserMessage):
146-
yield msg, len(self.messages) - 1 - idx
147-
elif isinstance(msg, (SystemMessage, ToolMessage)):
145+
if isinstance(msg, (UserMessage, ToolMessage)):
146+
yield msg, len(self.messages) - 1 - idx
147+
elif isinstance(msg, SystemMessage):
148148
# these can occur in the middle of a user block
149149
continue
150150
elif isinstance(msg, AssistantMessage):

src/codegate/types/openai/_request_models.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,9 @@ def last_user_message(self) -> tuple[Message, int] | None:
339339

340340
def last_user_block(self) -> Iterable[tuple[Message, int]]:
341341
for idx, msg in enumerate(reversed(self.messages)):
342-
if isinstance(msg, UserMessage):
342+
if isinstance(msg, (UserMessage, ToolMessage)):
343343
yield msg, len(self.messages) - 1 - idx
344-
elif isinstance(msg, (SystemMessage, DeveloperMessage, ToolMessage)):
344+
elif isinstance(msg, (SystemMessage, DeveloperMessage)):
345345
# these can occur in the middle of a user block
346346
continue
347347
elif isinstance(msg, (AssistantMessage, FunctionMessage)):

0 commit comments

Comments
 (0)