Skip to content

Commit 304486e

Browse files
committed
✨ filter from config
1 parent 5abd664 commit 304486e

File tree

11 files changed

+186
-106
lines changed

11 files changed

+186
-106
lines changed

Diff for: arclet/entari/_subscriber.py

-28
This file was deleted.

Diff for: arclet/entari/builtins/auto_reload.py

+46-18
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,21 @@ class Config:
3030
logger = log.wrapper("[AutoReload]")
3131

3232

33+
def detect_filter_change(old: dict, new: dict):
34+
added = set(new) - set(old)
35+
removed = set(old) - set(new)
36+
changed = {key for key in set(new) & set(old) if new[key] != old[key]}
37+
if "$allow" in removed:
38+
allow = {}
39+
else:
40+
allow = new["$allow"]
41+
if "$deny" in removed:
42+
deny = {}
43+
else:
44+
deny = new["$deny"]
45+
return allow, deny, not ((added | removed | changed) - {"$allow", "$deny"})
46+
47+
3348
class Watcher(Service):
3449
id = "watcher"
3550

@@ -61,7 +76,6 @@ async def watch(self):
6176
dispose_plugin(pid)
6277
if plugin := load_plugin(pid):
6378
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
64-
plugin._load()
6579
await plugin._startup()
6680
await plugin._ready()
6781
del plugin
@@ -72,7 +86,6 @@ async def watch(self):
7286
logger("INFO", f"Detected change in {change[1]!r} which failed to reload, retrying...")
7387
if plugin := load_plugin(self.fail[change[1]]):
7488
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
75-
plugin._load()
7689
await plugin._startup()
7790
await plugin._ready()
7891
del plugin
@@ -102,37 +115,53 @@ async def watch_config(self):
102115
f"Basic config <y>{key!r}</y> changed from <r>{old_basic[key]!r}</r> "
103116
f"to <g>{EntariConfig.instance.basic[key]!r}</g>",
104117
)
105-
await es.publish(ConfigReload("basic", key, EntariConfig.instance.basic[key]))
118+
await es.publish(ConfigReload("basic", key, EntariConfig.instance.basic[key], old_basic[key]))
119+
for key in set(EntariConfig.instance.basic) - set(old_basic):
120+
logger("DEBUG", f"Basic config <y>{key!r}</y> appended")
121+
await es.publish(ConfigReload("basic", key, EntariConfig.instance.basic[key]))
106122
for plugin_name in old_plugin:
107123
pid = plugin_name.replace("::", "arclet.entari.builtins.")
108124
if (
109125
plugin_name not in EntariConfig.instance.plugin
110126
or EntariConfig.instance.plugin[plugin_name] is False
111-
) and (plugin := find_plugin(pid)):
112-
await plugin._cleanup()
113-
del plugin
114-
dispose_plugin(pid)
115-
logger("INFO", f"Disposed plugin <blue>{pid!r}</blue>")
127+
):
128+
if plugin := find_plugin(pid):
129+
await plugin._cleanup()
130+
del plugin
131+
dispose_plugin(pid)
132+
logger("INFO", f"Disposed plugin <blue>{pid!r}</blue>")
116133
continue
117134
if old_plugin[plugin_name] != EntariConfig.instance.plugin[plugin_name]:
118135
logger(
119136
"DEBUG",
120137
f"Plugin <y>{plugin_name!r}</y> config changed from <r>{old_plugin[plugin_name]!r}</r> "
121138
f"to <g>{EntariConfig.instance.plugin[plugin_name]!r}</g>",
122139
)
123-
res = await es.post(
124-
ConfigReload("plugin", plugin_name, EntariConfig.instance.plugin[plugin_name])
125-
)
126-
if res and res.value:
127-
logger("DEBUG", f"Plugin <y>{pid!r}</y> config change handled by itself.")
128-
continue
140+
if isinstance(old_plugin[plugin_name], bool):
141+
old_conf = {}
142+
else:
143+
old_conf: dict = old_plugin[plugin_name] # type: ignore
144+
if isinstance(EntariConfig.instance.plugin[plugin_name], bool):
145+
new_conf = {}
146+
else:
147+
new_conf: dict = EntariConfig.instance.plugin[plugin_name] # type: ignore
129148
if plugin := find_plugin(pid):
149+
allow, deny, only_filter = detect_filter_change(old_conf, new_conf)
150+
plugin.update_filter(allow, deny)
151+
if only_filter:
152+
logger("DEBUG", f"Plugin <y>{pid!r}</y> config only changed filter.")
153+
continue
154+
res = await es.post(
155+
ConfigReload("plugin", plugin_name, new_conf, old_conf),
156+
)
157+
if res and res.value:
158+
logger("DEBUG", f"Plugin <y>{pid!r}</y> config change handled by itself.")
159+
continue
130160
logger("INFO", f"Detected <blue>{pid!r}</blue>'s config change, reloading...")
131161
plugin_file = str(plugin.module.__file__)
132162
await plugin._cleanup()
133163
dispose_plugin(plugin_name)
134-
if plugin := load_plugin(plugin_name):
135-
plugin._load()
164+
if plugin := load_plugin(plugin_name, new_conf):
136165
await plugin._startup()
137166
await plugin._ready()
138167
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
@@ -142,12 +171,11 @@ async def watch_config(self):
142171
self.fail[plugin_file] = pid
143172
else:
144173
logger("INFO", f"Detected <blue>{pid!r}</blue> appended, loading...")
145-
load_plugin(plugin_name)
174+
load_plugin(plugin_name, new_conf)
146175
if new := (set(EntariConfig.instance.plugin) - set(old_plugin)):
147176
for plugin_name in new:
148177
if not (plugin := load_plugin(plugin_name)):
149178
continue
150-
plugin._load()
151179
await plugin._startup()
152180
await plugin._ready()
153181
del plugin

Diff for: arclet/entari/command/__init__.py

+30-24
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ async def handle(self, session: Session, message: MessageChain, ctx: Contexts):
5555
if not msg:
5656
return
5757
if matches := list(self.trie.prefixes(msg)):
58-
results = await asyncio.gather(*(res.value.handle(ctx.copy(), inner=True) for res in matches if res.value))
58+
results = await asyncio.gather(*(res.value.handle(ctx.copy()) for res in matches if res.value))
5959
for result in results:
6060
if result is not None:
6161
await session.send(result)
@@ -67,7 +67,7 @@ async def handle(self, session: Session, message: MessageChain, ctx: Contexts):
6767
command_manager.find_shortcut(get_cmd(value), data)
6868
except ValueError:
6969
continue
70-
result = await value.handle(ctx.copy(), inner=True)
70+
result = await value.handle(ctx.copy())
7171
if result is not None:
7272
await session.send(result)
7373

@@ -143,7 +143,7 @@ def wrapper(func: TTarget[Optional[TM]]) -> Subscriber[Optional[TM]]:
143143
f" {arg.value.target}" for arg in _command.args if isinstance(arg.value, DirectPattern)
144144
)
145145
auxiliaries.insert(0, AlconnaSuppiler(_command))
146-
target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func)
146+
target = self.publisher.register(func, auxiliaries=auxiliaries, providers=providers)
147147
self.publisher.remove_subscriber(target)
148148
self.trie[key] = target
149149

@@ -152,30 +152,34 @@ def _remove(_):
152152
self.trie.pop(key, None) # type: ignore
153153

154154
target._dispose = _remove
155+
return target
156+
157+
_command = cast(Alconna, command)
158+
if not isinstance(command.command, str):
159+
raise TypeError("Command name must be a string.")
160+
_command.reset_namespace(self.__namespace__)
161+
auxiliaries.insert(0, AlconnaSuppiler(_command))
162+
keys = []
163+
if not _command.prefixes:
164+
keys.append(_command.command)
165+
elif not all(isinstance(i, str) for i in _command.prefixes):
166+
raise TypeError("Command prefixes must be a list of string.")
155167
else:
156-
auxiliaries.insert(0, AlconnaSuppiler(command))
157-
target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func)
158-
self.publisher.remove_subscriber(target)
159-
if not isinstance(command.command, str):
160-
raise TypeError("Command name must be a string.")
161-
keys = []
162-
if not command.prefixes:
163-
self.trie[command.command] = target
164-
keys.append(command.command)
165-
elif not all(isinstance(i, str) for i in command.prefixes):
166-
raise TypeError("Command prefixes must be a list of string.")
167-
else:
168-
for prefix in cast(list[str], command.prefixes):
169-
self.trie[prefix + command.command] = target
170-
keys.append(prefix + command.command)
168+
for prefix in cast(list[str], _command.prefixes):
169+
keys.append(prefix + _command.command)
171170

172-
def _remove(_):
173-
command_manager.delete(get_cmd(_))
174-
for key in keys:
175-
self.trie.pop(key, None) # type: ignore
171+
target = self.publisher.register(func, auxiliaries=auxiliaries, providers=providers)
172+
self.publisher.remove_subscriber(target)
176173

177-
target._dispose = _remove
178-
command.reset_namespace(self.__namespace__)
174+
for _key in keys:
175+
self.trie[_key] = target
176+
177+
def _remove(_):
178+
command_manager.delete(get_cmd(_))
179+
for _key in keys:
180+
self.trie.pop(_key, None) # type: ignore
181+
182+
target._dispose = _remove
179183
return target
180184

181185
return wrapper
@@ -209,6 +213,8 @@ def _(plg: RootlessPlugin):
209213
if "use_config_prefix" in plg.config:
210214
_commands.judge.use_config_prefix = plg.config["use_config_prefix"]
211215

216+
plg.dispatch(MessageCreatedEvent).handle(_commands.handle, auxiliaries=[_commands.judge])
217+
212218
@plg.use(ConfigReload)
213219
def update(event: ConfigReload):
214220
if event.scope != "plugin":

Diff for: arclet/entari/command/plugin.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from arclet.alconna import Alconna, command_manager
66
from arclet.letoderea import BaseAuxiliary, Provider, ProviderFactory
77

8-
from .._subscriber import SubscribeLoader
98
from ..event import MessageCreatedEvent
109
from ..event.command import pub as execute_handles
1110
from ..plugin.model import Plugin, PluginDispatcher
@@ -60,8 +59,7 @@ def on_execute(
6059
_auxiliaries.append(self.supplier)
6160

6261
def wrapper(func):
63-
caller = execute_handles.register(priority=priority, auxiliaries=_auxiliaries, providers=providers)
64-
sub = SubscribeLoader(func, caller)
62+
sub = execute_handles.register(func, priority=priority, auxiliaries=_auxiliaries, providers=providers)
6563
self._subscribers.append(sub)
6664
return sub
6765

Diff for: arclet/entari/core.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from contextlib import suppress
44
import os
55

6-
from arclet.letoderea import BaseAuxiliary, Contexts, Param, Provider, ProviderFactory, es, global_providers
6+
from arclet.letoderea import BaseAuxiliary, Contexts, Param, Provider, ProviderFactory, Subscriber, es, global_providers
77
from creart import it
88
from launart import Launart, Service
99
from satori import LoginStatus
@@ -14,15 +14,14 @@
1414
from satori.model import Event
1515
from tarina.generic import get_origin
1616

17-
from .command import _commands
1817
from .config import EntariConfig
1918
from .event.config import ConfigReload
2019
from .event.lifespan import AccountUpdate
2120
from .event.protocol import MessageCreatedEvent, event_parse
2221
from .event.send import SendResponse
2322
from .logger import log
2423
from .plugin import load_plugin, plugin_config, requires
25-
from .plugin.model import RootlessPlugin
24+
from .plugin.model import Plugin, RootlessPlugin
2625
from .plugin.service import plugin_service
2726
from .session import EntariProtocol, Session
2827

@@ -55,7 +54,17 @@ async def __call__(self, context: Contexts):
5554
return context["account"]
5655

5756

58-
global_providers.extend([ApiProtocolProvider(), SessionProvider(), AccountProvider()])
57+
class PluginProvider(Provider[Plugin]):
58+
async def __call__(self, context: Contexts):
59+
subscriber: Subscriber = context["$subscriber"]
60+
func = subscriber.callable_target
61+
if hasattr(func, "__globals__") and "__plugin__" in func.__globals__: # type: ignore
62+
return func.__globals__["__plugin__"]
63+
if hasattr(func, "__module__"):
64+
return plugin_service.plugins.get(func.__module__)
65+
66+
67+
global_providers.extend([ApiProtocolProvider(), SessionProvider(), AccountProvider(), PluginProvider()])
5968

6069

6170
@RootlessPlugin.apply("record_message")
@@ -117,6 +126,8 @@ def __init__(
117126
super().__init__(*configs, default_api_cls=EntariProtocol)
118127
if not hasattr(EntariConfig, "instance"):
119128
EntariConfig.load()
129+
if "~commands" not in EntariConfig.instance.plugin:
130+
EntariConfig.instance.plugin["~commands"] = True
120131
log.set_level(log_level)
121132
log.core.opt(colors=True).debug(f"Log level set to <y><c>{log_level}</c></y>")
122133
requires(*EntariConfig.instance.plugin)
@@ -128,7 +139,6 @@ def __init__(
128139
self._ref_tasks = set()
129140

130141
es.on(ConfigReload, self.reset_self)
131-
es.on(MessageCreatedEvent, _commands.handle, auxiliaries=[_commands.judge])
132142

133143
def reset_self(self, scope, key, value):
134144
if scope != "basic":

Diff for: arclet/entari/event/config.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Any
2+
from typing import Any, Optional
33

44
from arclet.letoderea import es
55

@@ -11,6 +11,7 @@ class ConfigReload(BasedEvent):
1111
scope: str
1212
key: str
1313
value: Any
14+
old: Optional[Any] = None
1415

1516
__publisher__ = "entari.event/config_reload"
1617
__result_type__: type[bool] = bool

0 commit comments

Comments
 (0)