Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 6f6fe62

Browse files
convert muxing rule match to coroutines
1 parent 2b4b13c commit 6f6fe62

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

src/codegate/muxing/rulematcher.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, route: ModelRoute, mux_rule: mux_models.MuxRule):
6161
self._mux_rule = mux_rule
6262

6363
@abstractmethod
64-
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
64+
async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
6565
"""Return True if the rule matches the thing_to_match."""
6666
pass
6767

@@ -97,7 +97,7 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch
9797
class CatchAllMuxingRuleMatcher(MuxingRuleMatcher):
9898
"""A catch all muxing rule matcher."""
9999

100-
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
100+
async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
101101
logger.info("Catch all rule matched")
102102
return True
103103

@@ -132,7 +132,7 @@ def _is_matcher_in_filenames(self, detected_client: ClientType, data: dict) -> b
132132
)
133133
return is_filename_match
134134

135-
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
135+
async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
136136
"""
137137
Return True if the matcher is in one of the request filenames.
138138
"""
@@ -156,7 +156,7 @@ def _is_request_type_match(self, is_fim_request: bool) -> bool:
156156
return True
157157
return False
158158

159-
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
159+
async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
160160
"""
161161
Return True if the matcher is in one of the request filenames and
162162
if the request type matches the MuxMatcherType.
@@ -194,7 +194,7 @@ def _get_user_messages_from_body(self, body: Dict) -> List[str]:
194194
user_messages.append(msgs_content)
195195
return user_messages
196196

197-
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
197+
async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
198198
"""
199199
Return True if the matcher is the persona description matched with the
200200
user messages.
@@ -204,14 +204,11 @@ def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
204204
return False
205205

206206
persona_manager = PersonaManager()
207-
is_persona_matched = persona_manager.check_persona_match(
207+
is_persona_matched = await persona_manager.check_persona_match(
208208
persona_name=self._mux_rule.matcher, queries=user_messages
209209
)
210-
logger.info(
211-
"Persona rule matched",
212-
matcher=self._mux_rule.matcher,
213-
is_persona_matched=is_persona_matched,
214-
)
210+
if is_persona_matched:
211+
logger.info("Persona rule matched", persona=self._mux_rule.matcher)
215212
return is_persona_matched
216213

217214

@@ -258,7 +255,7 @@ async def get_match_for_active_workspace(
258255
try:
259256
rules = await self.get_ws_rules(self._active_workspace)
260257
for rule in rules:
261-
if rule.match(thing_to_match):
258+
if await rule.match(thing_to_match):
262259
return rule.destination()
263260
return None
264261
except KeyError:

tests/muxing/test_rulematcher.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525

2626

27+
@pytest.mark.asyncio
2728
@pytest.mark.parametrize(
2829
"matcher_blob, thing_to_match",
2930
[
@@ -40,12 +41,13 @@
4041
),
4142
],
4243
)
43-
def test_catch_all(matcher_blob, thing_to_match):
44+
async def test_catch_all(matcher_blob, thing_to_match):
4445
muxing_rule_matcher = rulematcher.CatchAllMuxingRuleMatcher(mocked_route_openai, matcher_blob)
4546
# It should always match
46-
assert muxing_rule_matcher.match(thing_to_match) is True
47+
assert await muxing_rule_matcher.match(thing_to_match) is True
4748

4849

50+
@pytest.mark.asyncio
4951
@pytest.mark.parametrize(
5052
"matcher, filenames_to_match, expected_bool",
5153
[
@@ -60,7 +62,7 @@ def test_catch_all(matcher_blob, thing_to_match):
6062
("*.ts", ["main.tsx", "test.tsx"], False), # Extension no match
6163
],
6264
)
63-
def test_file_matcher(
65+
async def test_file_matcher(
6466
matcher,
6567
filenames_to_match,
6668
expected_bool,
@@ -81,9 +83,10 @@ def test_file_matcher(
8183
is_fim_request=False,
8284
client_type="generic",
8385
)
84-
assert muxing_rule_matcher.match(mocked_thing_to_match) is expected_bool
86+
assert await muxing_rule_matcher.match(mocked_thing_to_match) is expected_bool
8587

8688

89+
@pytest.mark.asyncio
8790
@pytest.mark.parametrize(
8891
"matcher, filenames_to_match, expected_bool_filenames",
8992
[
@@ -107,7 +110,7 @@ def test_file_matcher(
107110
(True, "chat_filename", False), # No match
108111
],
109112
)
110-
def test_request_file_matcher(
113+
async def test_request_file_matcher(
111114
matcher,
112115
filenames_to_match,
113116
expected_bool_filenames,
@@ -143,7 +146,7 @@ def test_request_file_matcher(
143146
)
144147
is expected_bool_filenames
145148
)
146-
assert muxing_rule_matcher.match(mocked_thing_to_match) is (
149+
assert await muxing_rule_matcher.match(mocked_thing_to_match) is (
147150
expected_bool_request and expected_bool_filenames
148151
)
149152

@@ -155,6 +158,10 @@ def test_request_file_matcher(
155158
(mux_models.MuxMatcherType.filename_match, rulematcher.FileMuxingRuleMatcher),
156159
(mux_models.MuxMatcherType.fim_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher),
157160
(mux_models.MuxMatcherType.chat_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher),
161+
(
162+
mux_models.MuxMatcherType.persona_description,
163+
rulematcher.PersonaDescriptionMuxingRuleMatcher,
164+
),
158165
("invalid_matcher", None),
159166
],
160167
)

0 commit comments

Comments
 (0)