@@ -32,6 +32,21 @@ class CodegateContextRetriever(PipelineStep):
32
32
the word "codegate" in the user message.
33
33
"""
34
34
35
def __init__(
    self,
    storage_engine: StorageEngine | None = None,
    package_extractor: type[PackageExtractor] | None = None,
):
    """
    Initialize the CodegateContextRetriever with optional dependencies.

    Both collaborators may be injected (for example, test doubles); any
    argument left as None falls back to its default.

    Args:
        storage_engine: StorageEngine instance whose ``search`` method is
            used to look up packages. When None, a new ``StorageEngine()``
            is created.
        package_extractor: PackageExtractor class (not an instance) whose
            ``extract_packages`` is called on code snippets. When None,
            ``PackageExtractor`` itself is used.
    """
    # Compare against None explicitly rather than relying on truthiness:
    # an injected object that evaluates as falsy (e.g. a stub defining
    # __len__ or __bool__) must not be silently replaced by the default.
    self.storage_engine = StorageEngine() if storage_engine is None else storage_engine
    self.package_extractor = (
        PackageExtractor if package_extractor is None else package_extractor
    )
49
+
35
50
@property
36
51
def name (self ) -> str :
37
52
"""
@@ -73,9 +88,6 @@ async def process( # noqa: C901
73
88
return PipelineResult (request = request )
74
89
user_message , last_user_idx = last_message
75
90
76
- # Create storage engine object
77
- storage_engine = StorageEngine ()
78
-
79
91
# Extract any code snippets
80
92
extractor = MessageCodeExtractorFactory .create_snippet_extractor (context .client )
81
93
snippets = extractor .extract_snippets (user_message )
@@ -87,15 +99,15 @@ async def process( # noqa: C901
87
99
snippet_packages = []
88
100
for snippet in snippets :
89
101
snippet_packages .extend (
90
- PackageExtractor .extract_packages (snippet .code , snippet .language ) # type: ignore
102
+ self . package_extractor .extract_packages (snippet .code , snippet .language ) # type: ignore
91
103
)
92
104
93
105
logger .info (
94
106
f"Found { len (snippet_packages )} packages "
95
107
f"for language { snippet_language } in code snippets."
96
108
)
97
109
# Find bad packages in the snippets
98
- bad_snippet_packages = await storage_engine .search (
110
+ bad_snippet_packages = await self . storage_engine .search (
99
111
language = snippet_language , packages = snippet_packages
100
112
) # type: ignore
101
113
logger .info (f"Found { len (bad_snippet_packages )} bad packages in code snippets." )
@@ -111,7 +123,11 @@ async def process( # noqa: C901
111
123
collected_bad_packages = []
112
124
for item_message in filter (None , map (str .strip , split_messages )):
113
125
# Vector search to find bad packages
114
- bad_packages = await storage_engine .search (query = item_message , distance = 0.5 , limit = 100 )
126
+ bad_packages = await self .storage_engine .search (
127
+ query = item_message ,
128
+ distance = 0.5 ,
129
+ limit = 100 ,
130
+ )
115
131
if bad_packages and len (bad_packages ) > 0 :
116
132
collected_bad_packages .extend (bad_packages )
117
133
@@ -134,42 +150,36 @@ async def process( # noqa: C901
134
150
# perform replacement in all the messages starting from this index
135
151
messages = request .get_messages ()
136
152
filtered = itertools .dropwhile (lambda x : x [0 ] < last_user_idx , enumerate (messages ))
137
- if context .client != ClientType .OPEN_INTERPRETER :
138
- for i , message in filtered :
139
- message_str = "" .join ([
140
- txt
141
- for content in message .get_content ()
142
- for txt in content .get_text ()
143
- ])
144
- context_msg = message_str
145
- # Add the context to the last user message
146
- if context .client in [ClientType .CLINE , ClientType .KODU ]:
147
- match = re .search (r"<task>\s*(.*?)\s*</task>(.*)" , message_str , re .DOTALL )
148
- if match :
149
- task_content = match .group (1 ) # Content within <task>...</task>
150
- rest_of_message = match .group (
151
- 2
152
- ).strip () # Content after </task>, if any
153
-
154
- # Embed the context into the task block
155
- updated_task_content = (
156
- f"<task>Context: { context_str } "
157
- + f"Query: { task_content .strip ()} </task>"
158
- )
159
-
160
- # Combine updated task content with the rest of the message
161
- context_msg = updated_task_content + rest_of_message
162
- else :
163
- context_msg = f"Context: { context_str } \n \n Query: { message_str } "
164
- content = next (message .get_content ())
165
- content .set_text (context_msg )
166
- logger .debug ("Final context message" , context_message = context_msg )
167
- else :
168
- # just add a message in the end
169
- new_request ["messages" ].append (
170
- {
171
- "content" : context_str ,
172
- "role" : "assistant" ,
173
- }
174
- )
153
+ for i , message in filtered :
154
+ message_str = ""
155
+ for content in message .get_content ():
156
+ txt = content .get_text ()
157
+ if not txt :
158
+ logger .debug (f"content has no text: { content } " )
159
+ continue
160
+ message_str += txt
161
+ context_msg = message_str
162
+ # Add the context to the last user message
163
+ if context .client in [ClientType .CLINE , ClientType .KODU ]:
164
+ match = re .search (r"<task>\s*(.*?)\s*</task>(.*)" , message_str , re .DOTALL )
165
+ if match :
166
+ task_content = match .group (1 ) # Content within <task>...</task>
167
+ rest_of_message = match .group (
168
+ 2
169
+ ).strip () # Content after </task>, if any
170
+
171
+ # Embed the context into the task block
172
+ updated_task_content = (
173
+ f"<task>Context: { context_str } "
174
+ + f"Query: { task_content .strip ()} </task>"
175
+ )
176
+
177
+ # Combine updated task content with the rest of the message
178
+ context_msg = updated_task_content + rest_of_message
179
+ else :
180
+ context_msg = f"Context: { context_str } \n \n Query: { message_str } "
181
+ content = next (message .get_content ())
182
+ content .set_text (context_msg )
183
+ logger .debug ("Final context message" , context_message = context_msg )
184
+
175
185
return PipelineResult (request = request , context = context )
0 commit comments