@@ -32,6 +32,21 @@ class CodegateContextRetriever(PipelineStep):
32
32
the word "codegate" in the user message.
33
33
"""
34
34
35
+ def __init__ (
36
+ self ,
37
+ storage_engine : StorageEngine | None = None ,
38
+ package_extractor : PackageExtractor | None = None ,
39
+ ):
40
+ """
41
+ Initialize the CodegateContextRetriever with optional dependencies.
42
+
43
+ Args:
44
+ storage_engine: Optional StorageEngine instance for package searching
45
+ package_extractor: Optional PackageExtractor class for package extraction
46
+ """
47
+ self .storage_engine = storage_engine or StorageEngine ()
48
+ self .package_extractor = package_extractor or PackageExtractor
49
+
35
50
@property
36
51
def name (self ) -> str :
37
52
"""
@@ -80,9 +95,6 @@ async def process( # noqa: C901
80
95
return PipelineResult (request = request )
81
96
user_message , last_user_idx = last_message
82
97
83
- # Create storage engine object
84
- storage_engine = StorageEngine ()
85
-
86
98
# Extract any code snippets
87
99
extractor = MessageCodeExtractorFactory .create_snippet_extractor (context .client )
88
100
snippets = extractor .extract_snippets (user_message )
@@ -106,7 +118,7 @@ async def process( # noqa: C901
106
118
f"for language { snippet_language } in code snippets."
107
119
)
108
120
# Find bad packages in the snippets
109
- bad_snippet_packages = await storage_engine .search (
121
+ bad_snippet_packages = await self . storage_engine .search (
110
122
language = snippet_language , packages = snippet_packages
111
123
) # type: ignore
112
124
logger .info (f"Found { len (bad_snippet_packages )} bad packages in code snippets." )
@@ -122,7 +134,11 @@ async def process( # noqa: C901
122
134
collected_bad_packages = []
123
135
for item_message in filter (None , map (str .strip , split_messages )):
124
136
# Vector search to find bad packages
125
- bad_packages = await storage_engine .search (query = item_message , distance = 0.5 , limit = 100 )
137
+ bad_packages = await self .storage_engine .search (
138
+ query = item_message ,
139
+ distance = 0.5 ,
140
+ limit = 100 ,
141
+ )
126
142
if bad_packages and len (bad_packages ) > 0 :
127
143
collected_bad_packages .extend (bad_packages )
128
144
@@ -145,42 +161,36 @@ async def process( # noqa: C901
145
161
# perform replacement in all the messages starting from this index
146
162
messages = request .get_messages ()
147
163
filtered = itertools .dropwhile (lambda x : x [0 ] < last_user_idx , enumerate (messages ))
148
- if context .client != ClientType .OPEN_INTERPRETER :
149
- for i , message in filtered :
150
- message_str = "" .join ([
151
- txt
152
- for content in message .get_content ()
153
- for txt in content .get_text ()
154
- ])
155
- context_msg = message_str
156
- # Add the context to the last user message
157
- if context .client in [ClientType .CLINE , ClientType .KODU ]:
158
- match = re .search (r"<task>\s*(.*?)\s*</task>(.*)" , message_str , re .DOTALL )
159
- if match :
160
- task_content = match .group (1 ) # Content within <task>...</task>
161
- rest_of_message = match .group (
162
- 2
163
- ).strip () # Content after </task>, if any
164
-
165
- # Embed the context into the task block
166
- updated_task_content = (
167
- f"<task>Context: { context_str } "
168
- + f"Query: { task_content .strip ()} </task>"
169
- )
170
-
171
- # Combine updated task content with the rest of the message
172
- context_msg = updated_task_content + rest_of_message
173
- else :
174
- context_msg = f"Context: { context_str } \n \n Query: { message_str } "
175
- content = next (message .get_content ())
176
- content .set_text (context_msg )
177
- logger .debug ("Final context message" , context_message = context_msg )
178
- else :
179
- # just add a message in the end
180
- new_request ["messages" ].append (
181
- {
182
- "content" : context_str ,
183
- "role" : "assistant" ,
184
- }
185
- )
164
+ for i , message in filtered :
165
+ message_str = ""
166
+ for content in message .get_content ():
167
+ txt = content .get_text ()
168
+ if not txt :
169
+ logger .debug (f"content has no text: { content } " )
170
+ continue
171
+ message_str += txt
172
+ context_msg = message_str
173
+ # Add the context to the last user message
174
+ if context .client in [ClientType .CLINE , ClientType .KODU ]:
175
+ match = re .search (r"<task>\s*(.*?)\s*</task>(.*)" , message_str , re .DOTALL )
176
+ if match :
177
+ task_content = match .group (1 ) # Content within <task>...</task>
178
+ rest_of_message = match .group (
179
+ 2
180
+ ).strip () # Content after </task>, if any
181
+
182
+ # Embed the context into the task block
183
+ updated_task_content = (
184
+ f"<task>Context: { context_str } "
185
+ + f"Query: { task_content .strip ()} </task>"
186
+ )
187
+
188
+ # Combine updated task content with the rest of the message
189
+ context_msg = updated_task_content + rest_of_message
190
+ else :
191
+ context_msg = f"Context: { context_str } \n \n Query: { message_str } "
192
+ content = next (message .get_content ())
193
+ content .set_text (context_msg )
194
+ logger .debug ("Final context message" , context_message = context_msg )
195
+
186
196
return PipelineResult (request = request , context = context )
0 commit comments