Skip to content

Commit 387c57b

Browse files
committed
🍻 filter use propagator & depends
1 parent 73a57aa commit 387c57b

File tree

7 files changed

+164
-108
lines changed

7 files changed

+164
-108
lines changed

Diff for: arclet/entari/builtins/auto_reload.py

+10-22
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,7 @@ def detect_filter_change(old: dict, new: dict):
4040
added = set(new) - set(old)
4141
removed = set(old) - set(new)
4242
changed = {key for key in set(new) & set(old) if new[key] != old[key]}
43-
if "$allow" in removed:
44-
allow = {}
45-
else:
46-
allow = new.get("$allow", {})
47-
if "$deny" in removed:
48-
deny = {}
49-
else:
50-
deny = new.get("$deny", {})
51-
return allow, deny, not ((added | removed | changed) - {"$allow", "$deny"})
43+
return "allow" in (added | removed | changed) or "$deny" in (added | removed | changed)
5244

5345

5446
class Watcher(Service):
@@ -106,7 +98,6 @@ async def watch_config(self):
10698
or Path(change[1]).resolve() in extra
10799
or Path(change[1]).resolve().parent in extra
108100
):
109-
print(change)
110101
continue
111102
logger.info(f"Detected change in {change[1]!r}, reloading config...")
112103

@@ -141,20 +132,17 @@ async def watch_config(self):
141132
old_conf = old_plugin[plugin_name]
142133
new_conf = EntariConfig.instance.plugin[plugin_name]
143134
if plugin := find_plugin(pid):
144-
allow, deny, only_filter = detect_filter_change(old_conf, new_conf)
145-
plugin.update_filter(allow, deny)
146-
if only_filter:
147-
logger.debug(f"Plugin <y>{pid!r}</y> config only changed filter.")
148-
continue
149-
res = await es.post(
150-
ConfigReload("plugin", plugin_name, new_conf, old_conf),
151-
)
152-
if res and res.value:
153-
logger.debug(f"Plugin <y>{pid!r}</y> config change handled by itself.")
154-
continue
135+
filter_changed = detect_filter_change(old_conf, new_conf)
136+
if not filter_changed:
137+
res = await es.post(
138+
ConfigReload("plugin", plugin_name, new_conf, old_conf),
139+
)
140+
if res and res.value:
141+
logger.debug(f"Plugin <y>{pid!r}</y> config change handled by itself.")
142+
continue
155143
logger.info(f"Detected config of <blue>{pid!r}</blue> changed, reloading...")
156144
plugin_file = str(plugin.module.__file__)
157-
unload_plugin(plugin_name)
145+
unload_plugin(pid)
158146
if plugin := load_plugin(plugin_name, new_conf):
159147
logger.info(f"Reloaded <blue>{plugin.id!r}</blue>")
160148
del plugin

Diff for: arclet/entari/filter/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ..message import MessageChain
88
from ..session import Session
99
from .common import filter_ as filter_
10+
from .common import parse as parse
1011
from .message import direct_message
1112
from .message import notice_me as notice_me
1213
from .message import public_message
@@ -57,7 +58,7 @@ async def before(self, session: Optional[Session] = None):
5758
await session.send(self.limit_prompt)
5859
return STOP
5960

60-
async def after(self) -> Optional[bool]:
61+
async def after(self):
6162
self.semaphore.release()
6263

6364
def compose(self):

Diff for: arclet/entari/filter/common.py

+137-52
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,88 @@
11
from collections.abc import Awaitable
2+
from inspect import Parameter, Signature
23
from typing import Callable, Union
34
from typing_extensions import TypeAlias
45

5-
from arclet.letoderea import STOP
6-
from arclet.letoderea.handler import run_handler
7-
from arclet.letoderea.typing import run_sync
6+
from arclet.letoderea import STOP, Contexts, Depends, Propagator, propagate
7+
from arclet.letoderea.typing import Result, run_sync
8+
from satori import Channel, ChannelType, Guild, User
89
from tarina import is_coroutinefunction
910

10-
from ..event.base import MessageEvent
1111
from ..session import Session
12-
from .message import direct_message, notice_me, public_message, reply_me, to_me
1312

1413
_SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]]
15-
_sess_keys = {
16-
"user",
17-
"guild",
18-
"channel",
19-
"self",
20-
"platform",
14+
15+
16+
async def _direct_message(channel: Channel):
17+
return Result(channel.type == ChannelType.DIRECT)
18+
19+
20+
async def _public_message(channel: Channel):
21+
return Result(channel.type != ChannelType.DIRECT)
22+
23+
24+
async def _reply_me(is_reply_me: bool = False):
25+
return Result(is_reply_me)
26+
27+
28+
async def _notice_me(is_notice_me: bool = False):
29+
return Result(is_notice_me)
30+
31+
32+
async def _to_me(is_reply_me: bool = False, is_notice_me: bool = False):
33+
return Result(is_reply_me or is_notice_me)
34+
35+
36+
def _user(*ids: str):
37+
async def check_user(user: User):
38+
return Result(user.id in ids if ids else True)
39+
40+
return check_user
41+
42+
43+
def _channel(*ids: str):
44+
async def check_channel(channel: Channel):
45+
return Result(channel.id in ids if ids else True)
46+
47+
return check_channel
48+
49+
50+
def _guild(*ids: str):
51+
async def check_guild(guild: Guild):
52+
return Result(guild.id in ids if ids else True)
53+
54+
return check_guild
55+
56+
57+
def _account(*ids: str):
58+
async def check_account(session: Session):
59+
return Result(session.account.self_id in ids)
60+
61+
return check_account
62+
63+
64+
def _platform(*ids: str):
65+
async def check_platform(session: Session):
66+
return Result(session.account.platform in ids)
67+
68+
return check_platform
69+
70+
71+
_keys = {
72+
"user": (_user, 2),
73+
"guild": (_guild, 3),
74+
"channel": (_channel, 4),
75+
"self": (_account, 1),
76+
"platform": (_platform, 0),
2177
}
2278

23-
_message_keys = {
24-
"direct": direct_message,
25-
"private": direct_message,
26-
"public": public_message,
27-
"reply_me": reply_me,
28-
"notice_me": notice_me,
29-
"to_me": to_me,
79+
_mess_keys = {
80+
"direct": (_direct_message, 5),
81+
"private": (_direct_message, 5),
82+
"public": (_public_message, 6),
83+
"reply_me": (_reply_me, 7),
84+
"notice_me": (_notice_me, 8),
85+
"to_me": (_to_me, 9),
3086
}
3187

3288
_op_keys = {
@@ -54,46 +110,75 @@ async def _(session: Session):
54110
return _
55111

56112

113+
class _Filter(Propagator):
114+
def __init__(self):
115+
self.step: dict[int, Callable] = {}
116+
self.ops = []
117+
118+
def get_flow(self):
119+
if not self.step:
120+
return Depends(lambda: None)
121+
122+
steps = [slot[1] for slot in sorted(self.step.items(), key=lambda x: x[0])]
123+
124+
@propagate(*steps, prepend=True)
125+
async def flow(ctx: Contexts):
126+
if ctx.get("$result", False):
127+
return
128+
return STOP
129+
130+
return Depends(flow)
131+
132+
def generate(self):
133+
134+
if not self.ops:
135+
136+
async def check(res=self.get_flow()): # type: ignore
137+
return res
138+
139+
else:
140+
141+
async def check(**kwargs):
142+
res = kwargs["res"]
143+
for (op, _), res1 in zip(self.ops, list(kwargs.values())[1:]):
144+
if op == "and" and (res is None and res1 is None):
145+
continue
146+
if op == "or" and (res is None or res1 is None):
147+
res = None
148+
continue
149+
if op == "not" and (res is None and res1 is STOP):
150+
continue
151+
res = STOP
152+
return res
153+
154+
param = [Parameter("res", Parameter.POSITIONAL_OR_KEYWORD, default=self.get_flow())]
155+
for index, slot in enumerate(self.ops):
156+
param.append(
157+
Parameter(f"res_{index+1}", Parameter.POSITIONAL_OR_KEYWORD, default=Depends(slot[1].generate()))
158+
)
159+
check.__signature__ = Signature(param)
160+
161+
return check
162+
163+
def compose(self):
164+
yield self.generate(), True, 0
165+
166+
57167
def parse(patterns: PATTERNS):
58-
step: list[Callable[[Session], bool]] = []
59-
other = []
60-
ops = []
168+
f = _Filter()
169+
61170
for key, value in patterns.items():
62-
if key in _sess_keys:
63-
step.append(lambda session: getattr(session, key) in value)
64-
elif key in _message_keys:
65-
step.append(lambda session: isinstance(session.event, MessageEvent))
66-
other.append(_message_keys[key])
171+
if key in _keys:
172+
f.step[_keys[key][1]] = _keys[key][0](*value)
173+
elif key in _mess_keys:
174+
if value is True:
175+
f.step[_mess_keys[key][1]] = _mess_keys[key][0]
67176
elif key in _op_keys:
68177
op = _op_keys[key]
69178
if not isinstance(value, dict):
70179
raise ValueError(f"Expect a dict for operator {key}")
71-
ops.append((op, parse(value)))
180+
f.ops.append((op, parse(value)))
72181
else:
73182
raise ValueError(f"Unknown key: {key}")
74183

75-
async def f(session: Session):
76-
for i in step:
77-
if not i(session):
78-
return STOP
79-
for i in other:
80-
if not await run_handler(i, session.event):
81-
return STOP
82-
83-
if not ops:
84-
return f
85-
86-
async def _(session: Session):
87-
res = await f(session)
88-
89-
for op, f_ in ops:
90-
res1 = await f_(session)
91-
if op == "and" and (res is None and res1 is None):
92-
return
93-
if op == "or" and (res is None or res1 is None):
94-
return
95-
if op == "not" and (res is None and res1 is STOP):
96-
return
97-
return STOP
98-
99-
return _
184+
return f

Diff for: arclet/entari/logger.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def message(self):
3838

3939
def wrapper(self, name: str, color: str = "blue"):
4040
patched = logger.patch(
41-
lambda r: r.update(name="entari", extra=r["extra"] | {"entari_plugin_name": name, "entari_plugin_color": color})
41+
lambda r: r.update(
42+
name="entari", extra=r["extra"] | {"entari_plugin_name": name, "entari_plugin_color": color}
43+
)
4244
)
4345
patched = patched.bind(name=f"plugins.{name}")
4446
self.loggers[f"plugin.{name}"] = patched

Diff for: arclet/entari/plugin/model.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from launart import Launart, Service
1616
from tarina import ContextModel
1717

18-
from ..filter.common import parse
18+
from ..filter import parse
1919
from ..logger import log
2020
from .service import plugin_service
2121

@@ -216,20 +216,18 @@ def collect(self, *disposes: Callable[[], None]):
216216
self._dispose_callbacks.extend(disposes)
217217
return self
218218

219-
def update_filter(self, allow: dict, deny: dict):
220-
if not allow and not deny:
221-
return
219+
def __post_init__(self):
220+
self._scope = es.scope(self.id)
221+
plugin_service.plugins[self.id] = self
222+
allow = self.config.pop("$allow", {})
223+
deny = self.config.pop("$deny", {})
222224
pat = {}
223225
if allow:
224226
pat["$and"] = allow
225227
if deny:
226228
pat["$not"] = deny
227-
plugin_service.filters[self.id] = parse(pat)
228-
229-
def __post_init__(self):
230-
self._scope = es.scope(self.id)
231-
plugin_service.plugins[self.id] = self
232-
self.update_filter(self.config.pop("$allow", {}), self.config.pop("$deny", {}))
229+
if pat:
230+
self._scope.propagators.append(parse(pat))
233231
if "$static" in self.config:
234232
self.is_static = True
235233
self.config.pop("$static")
@@ -271,6 +269,7 @@ def dispose(self):
271269
plugin_service.plugins.pop(subplug, None)
272270
self.subplugins.clear()
273271
self._scope.dispose()
272+
self._scope.propagators.clear()
274273
del plugin_service.plugins[self.id]
275274
del self.module
276275

Diff for: arclet/entari/plugin/module.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ..config import EntariConfig
1414
from ..logger import log
1515
from .model import Plugin, PluginMetadata, _current_plugin
16-
from .service import PluginAccess, plugin_service
16+
from .service import plugin_service
1717

1818
_SUBMODULE_WAITLIST: dict[str, set[str]] = {}
1919
_ENSURE_IS_PLUGIN: set[str] = set()
@@ -217,11 +217,9 @@ def create_module(self, spec) -> Optional[ModuleType]:
217217
return super().create_module(spec)
218218

219219
def exec_module(self, module: ModuleType, config: Optional[dict[str, str]] = None) -> None:
220-
is_sub = False
221220
if plugin := plugin_service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None:
222221
plugin.subplugins.add(module.__name__)
223222
plugin_service._subplugined[module.__name__] = plugin.id
224-
is_sub = True
225223

226224
if self.loaded:
227225
return
@@ -238,8 +236,6 @@ def exec_module(self, module: ModuleType, config: Optional[dict[str, str]] = Non
238236
# enter plugin context
239237
token = _current_plugin.set(plugin)
240238
if not plugin.is_static:
241-
if not is_sub:
242-
plugin._scope.propagators.append(PluginAccess(plugin.id))
243239
token1 = scope_ctx.set(plugin._scope)
244240
try:
245241
super().exec_module(module)

0 commit comments

Comments
 (0)