@@ -26,6 +26,21 @@ class CodegateContextRetriever(PipelineStep):
26
26
the word "codegate" in the user message.
27
27
"""
28
28
29
+ def __init__ (
30
+ self ,
31
+ storage_engine : StorageEngine | None = None ,
32
+ package_extractor : PackageExtractor | None = None ,
33
+ ):
34
+ """
35
+ Initialize the CodegateContextRetriever with optional dependencies.
36
+
37
+ Args:
38
+ storage_engine: Optional StorageEngine instance for package searching
39
+ package_extractor: Optional PackageExtractor class for package extraction
40
+ """
41
+ self .storage_engine = storage_engine or StorageEngine ()
42
+ self .package_extractor = package_extractor or PackageExtractor
43
+
29
44
@property
30
45
def name (self ) -> str :
31
46
"""
@@ -67,9 +82,6 @@ async def process( # noqa: C901
67
82
return PipelineResult (request = request )
68
83
user_message , last_user_idx = last_message
69
84
70
- # Create storage engine object
71
- storage_engine = StorageEngine ()
72
-
73
85
# Extract any code snippets
74
86
extractor = MessageCodeExtractorFactory .create_snippet_extractor (context .client )
75
87
snippets = extractor .extract_snippets (user_message )
@@ -81,15 +93,15 @@ async def process( # noqa: C901
81
93
snippet_packages = []
82
94
for snippet in snippets :
83
95
snippet_packages .extend (
84
- PackageExtractor .extract_packages (snippet .code , snippet .language ) # type: ignore
96
+ self . package_extractor .extract_packages (snippet .code , snippet .language ) # type: ignore
85
97
)
86
98
87
99
logger .info (
88
100
f"Found { len (snippet_packages )} packages "
89
101
f"for language { snippet_language } in code snippets."
90
102
)
91
103
# Find bad packages in the snippets
92
- bad_snippet_packages = await storage_engine .search (
104
+ bad_snippet_packages = await self . storage_engine .search (
93
105
language = snippet_language , packages = snippet_packages
94
106
) # type: ignore
95
107
logger .info (f"Found { len (bad_snippet_packages )} bad packages in code snippets." )
@@ -107,7 +119,11 @@ async def process( # noqa: C901
107
119
collected_bad_packages = []
108
120
for item_message in filter (None , map (str .strip , split_messages )):
109
121
# Vector search to find bad packages
110
- bad_packages = await storage_engine .search (query = item_message , distance = 0.5 , limit = 100 )
122
+ bad_packages = await self .storage_engine .search (
123
+ query = item_message ,
124
+ distance = 0.5 ,
125
+ limit = 100 ,
126
+ )
111
127
if bad_packages and len (bad_packages ) > 0 :
112
128
collected_bad_packages .extend (bad_packages )
113
129
@@ -130,42 +146,36 @@ async def process( # noqa: C901
130
146
# perform replacement in all the messages starting from this index
131
147
messages = request .get_messages ()
132
148
filtered = itertools .dropwhile (lambda x : x [0 ] < last_user_idx , enumerate (messages ))
133
- if context .client != ClientType .OPEN_INTERPRETER :
134
- for i , message in filtered :
135
- message_str = "" .join ([
136
- txt
137
- for content in message .get_content ()
138
- for txt in content .get_text ()
139
- ])
140
- context_msg = message_str
141
- # Add the context to the last user message
142
- if context .client in [ClientType .CLINE , ClientType .KODU ]:
143
- match = re .search (r"<task>\s*(.*?)\s*</task>(.*)" , message_str , re .DOTALL )
144
- if match :
145
- task_content = match .group (1 ) # Content within <task>...</task>
146
- rest_of_message = match .group (
147
- 2
148
- ).strip () # Content after </task>, if any
149
-
150
- # Embed the context into the task block
151
- updated_task_content = (
152
- f"<task>Context: { context_str } "
153
- + f"Query: { task_content .strip ()} </task>"
154
- )
155
-
156
- # Combine updated task content with the rest of the message
157
- context_msg = updated_task_content + rest_of_message
158
- else :
159
- context_msg = f"Context: { context_str } \n \n Query: { message_str } "
160
- content = next (message .get_content ())
161
- content .set_text (context_msg )
162
- logger .debug ("Final context message" , context_message = context_msg )
163
- else :
164
- # just add a message in the end
165
- new_request ["messages" ].append (
166
- {
167
- "content" : context_str ,
168
- "role" : "assistant" ,
169
- }
170
- )
149
+ for i , message in filtered :
150
+ message_str = ""
151
+ for content in message .get_content ():
152
+ txt = content .get_text ()
153
+ if not txt :
154
+ logger .debug (f"content has no text: { content } " )
155
+ continue
156
+ message_str += txt
157
+ context_msg = message_str
158
+ # Add the context to the last user message
159
+ if context .client in [ClientType .CLINE , ClientType .KODU ]:
160
+ match = re .search (r"<task>\s*(.*?)\s*</task>(.*)" , message_str , re .DOTALL )
161
+ if match :
162
+ task_content = match .group (1 ) # Content within <task>...</task>
163
+ rest_of_message = match .group (
164
+ 2
165
+ ).strip () # Content after </task>, if any
166
+
167
+ # Embed the context into the task block
168
+ updated_task_content = (
169
+ f"<task>Context: { context_str } "
170
+ + f"Query: { task_content .strip ()} </task>"
171
+ )
172
+
173
+ # Combine updated task content with the rest of the message
174
+ context_msg = updated_task_content + rest_of_message
175
+ else :
176
+ context_msg = f"Context: { context_str } \n \n Query: { message_str } "
177
+ content = next (message .get_content ())
178
+ content .set_text (context_msg )
179
+ logger .debug ("Final context message" , context_message = context_msg )
180
+
171
181
return PipelineResult (request = request , context = context )
0 commit comments