diff --git a/arclet/entari/__init__.py b/arclet/entari/__init__.py
index de19b30..c2dd738 100644
--- a/arclet/entari/__init__.py
+++ b/arclet/entari/__init__.py
@@ -1,4 +1,5 @@
from arclet.letoderea import bind as bind
+from arclet.letoderea import propagate as propagate
from satori import ArgvInteraction as ArgvInteraction
from satori import At as At
from satori import Audio as Audio
@@ -45,7 +46,6 @@
from .core import Entari as Entari
from .event import MessageCreatedEvent as MessageCreatedEvent
from .event import MessageEvent as MessageEvent
-from .filter import Filter as Filter
from .filter import filter_ as filter_
from .message import MessageChain as MessageChain
from .plugin import Plugin as Plugin
diff --git a/arclet/entari/builtins/auto_reload.py b/arclet/entari/builtins/auto_reload.py
index a0e389a..fe8bb6b 100644
--- a/arclet/entari/builtins/auto_reload.py
+++ b/arclet/entari/builtins/auto_reload.py
@@ -40,15 +40,7 @@ def detect_filter_change(old: dict, new: dict):
added = set(new) - set(old)
removed = set(old) - set(new)
changed = {key for key in set(new) & set(old) if new[key] != old[key]}
- if "$allow" in removed:
- allow = {}
- else:
- allow = new.get("$allow", {})
- if "$deny" in removed:
- deny = {}
- else:
- deny = new.get("$deny", {})
- return allow, deny, not ((added | removed | changed) - {"$allow", "$deny"})
+ return "allow" in (added | removed | changed) or "$deny" in (added | removed | changed)
class Watcher(Service):
@@ -86,9 +78,7 @@ async def watch(self):
logger.error(f"Failed to reload {pid!r}")
self.fail[change[1]] = pid
elif change[1] in self.fail:
- logger.info(
- f"Detected change in {change[1]!r} which failed to reload, retrying..."
- )
+ logger.info(f"Detected change in {change[1]!r} which failed to reload, retrying...")
if plugin := load_plugin(self.fail[change[1]]):
logger.info(f"Reloaded {plugin.id!r}")
del plugin
@@ -108,7 +98,6 @@ async def watch_config(self):
or Path(change[1]).resolve() in extra
or Path(change[1]).resolve().parent in extra
):
- print(change)
continue
logger.info(f"Detected change in {change[1]!r}, reloading config...")
@@ -143,22 +132,17 @@ async def watch_config(self):
old_conf = old_plugin[plugin_name]
new_conf = EntariConfig.instance.plugin[plugin_name]
if plugin := find_plugin(pid):
- allow, deny, only_filter = detect_filter_change(old_conf, new_conf)
- plugin.update_filter(allow, deny)
- if only_filter:
- logger.debug(f"Plugin {pid!r} config only changed filter.")
- continue
- res = await es.post(
- ConfigReload("plugin", plugin_name, new_conf, old_conf),
- )
- if res and res.value:
- logger.debug(f"Plugin {pid!r} config change handled by itself.")
- continue
- logger.info(
- f"Detected config of {pid!r} changed, reloading..."
- )
+ filter_changed = detect_filter_change(old_conf, new_conf)
+ if not filter_changed:
+ res = await es.post(
+ ConfigReload("plugin", plugin_name, new_conf, old_conf),
+ )
+ if res and res.value:
+ logger.debug(f"Plugin {pid!r} config change handled by itself.")
+ continue
+ logger.info(f"Detected config of {pid!r} changed, reloading...")
plugin_file = str(plugin.module.__file__)
- unload_plugin(plugin_name)
+ unload_plugin(pid)
if plugin := load_plugin(plugin_name, new_conf):
logger.info(f"Reloaded {plugin.id!r}")
del plugin
diff --git a/arclet/entari/command/__init__.py b/arclet/entari/command/__init__.py
index 94e49b5..611c056 100644
--- a/arclet/entari/command/__init__.py
+++ b/arclet/entari/command/__init__.py
@@ -4,7 +4,7 @@
from arclet.alconna import Alconna, Arg, Args, CommandMeta, Namespace, command_manager, config
from arclet.alconna.tools.construct import AlconnaString, alconna_from_format
from arclet.alconna.typing import TAValue
-from arclet.letoderea import BaseAuxiliary, Provider, Scope, Subscriber, es
+from arclet.letoderea import Provider, Scope, Subscriber, es
from arclet.letoderea.handler import generate_contexts
from arclet.letoderea.provider import ProviderFactory, get_providers
from arclet.letoderea.typing import Contexts, TTarget
@@ -28,7 +28,9 @@
def get_cmd(target: Subscriber):
- return next(a for a in target.auxiliaries["prepare"] if isinstance(a, AlconnaSuppiler)).cmd
+ if sub := target.get_propagator(AlconnaSuppiler.supply):
+ return sub.callable_target.__self__.cmd # type: ignore
+ raise ValueError("Subscriber has no command.")
class EntariCommands:
@@ -104,12 +106,11 @@ def command(
self,
command: str,
help_text: Optional[str] = None,
- auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
):
class Command(AlconnaString):
def __call__(_cmd_self, func: TTarget[Optional[TM]]) -> Subscriber[Optional[TM]]:
- return self.on(_cmd_self.build(), auxiliaries, providers)(func)
+ return self.on(_cmd_self.build(), providers)(func)
return Command(command, help_text)
@@ -117,7 +118,6 @@ def __call__(_cmd_self, func: TTarget[Optional[TM]]) -> Subscriber[Optional[TM]]
def on(
self,
command: Alconna,
- auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
) -> Callable[[TTarget[Optional[TM]]], Subscriber[Optional[TM]]]: ...
@@ -125,7 +125,6 @@ def on(
def on(
self,
command: str,
- auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
*,
args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None,
@@ -135,15 +134,12 @@ def on(
def on(
self,
command: Union[Alconna, str],
- auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
*,
args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None,
meta: Optional[CommandMeta] = None,
) -> Callable[[TTarget[Optional[TM]]], Subscriber[Optional[TM]]]:
- auxiliaries = auxiliaries or []
- if plg := _current_plugin.get():
- auxiliaries.extend(plg._scope.auxiliaries)
+ plg = _current_plugin.get()
providers = providers or []
def wrapper(func: TTarget[Optional[TM]]) -> Subscriber[Optional[TM]]:
@@ -155,8 +151,10 @@ def wrapper(func: TTarget[Optional[TM]]) -> Subscriber[Optional[TM]]:
key = _command.name + "".join(
f" {arg.value.target}" for arg in _command.args if isinstance(arg.value, DirectPattern)
)
- auxiliaries.append(AlconnaSuppiler(_command))
- target = self.scope.register(func, auxiliaries=auxiliaries, providers=providers)
+ target = self.scope.register(func, providers=providers)
+ target.propagate(AlconnaSuppiler(_command))
+ if plg:
+ target.propagates(*plg._scope.propagators)
self.trie[key] = target.id
def _remove(_):
@@ -175,7 +173,6 @@ def _remove(_):
if not isinstance(command.command, str):
raise TypeError("Command name must be a string.")
_command.reset_namespace(self.__namespace__)
- auxiliaries.insert(0, AlconnaSuppiler(_command))
keys = []
if not _command.prefixes:
keys.append(_command.command)
@@ -185,7 +182,10 @@ def _remove(_):
for prefix in cast(list[str], _command.prefixes):
keys.append(prefix + _command.command)
- target = self.scope.register(func, auxiliaries=auxiliaries, providers=providers)
+ target = self.scope.register(func, providers=providers)
+ target.propagate(AlconnaSuppiler(_command))
+ if plg:
+ target.propagates(*plg._scope.propagators)
for _key in keys:
self.trie[_key] = target.id
@@ -234,7 +234,7 @@ def _(plg: RootlessPlugin):
if "use_config_prefix" in plg.config:
_commands.judge.use_config_prefix = plg.config["use_config_prefix"]
- plg.dispatch(MessageCreatedEvent).handle(_commands.handle, auxiliaries=[_commands.judge])
+ plg.dispatch(MessageCreatedEvent).handle(_commands.handle).propagate(_commands.judge)
@plg.use(ConfigReload)
def update(event: ConfigReload):
diff --git a/arclet/entari/command/plugin.py b/arclet/entari/command/plugin.py
index 897de18..1305c6c 100644
--- a/arclet/entari/command/plugin.py
+++ b/arclet/entari/command/plugin.py
@@ -3,7 +3,7 @@
from typing import Any
from arclet.alconna import Alconna, command_manager
-from arclet.letoderea import BaseAuxiliary, Provider, ProviderFactory, es
+from arclet.letoderea import Provider, ProviderFactory, es
from ..event import MessageCreatedEvent
from ..event.command import CommandExecute
@@ -28,10 +28,10 @@ def __init__(
plugin._extra.setdefault("commands", []).append((command.prefixes, command.command))
self.supplier = AlconnaSuppiler(command)
super().__init__(plugin, MessageCreatedEvent, command.path)
- self.auxiliaries.append(
+ self.propagators.append(
MessageJudges(need_reply_me, need_notice_me, use_config_prefix),
)
- self.auxiliaries.append(self.supplier)
+ self.propagators.append(self.supplier)
self.providers.append(AlconnaProviderFactory())
@plugin.collect
@@ -46,24 +46,28 @@ def assign(
value: Any = _seminal,
or_not: bool = False,
priority: int = 16,
- auxiliaries: list[BaseAuxiliary] | None = None,
providers: list[Provider | type[Provider] | ProviderFactory | type[ProviderFactory]] | None = None,
):
- _auxiliaries = auxiliaries or []
- _auxiliaries.append(Assign(path, value, or_not))
- return self.register(priority=priority, auxiliaries=_auxiliaries, providers=providers)
+ assign = Assign(path, value, or_not)
+ try:
+ self.propagators.append(assign)
+ return self.register(priority=priority, providers=providers)
+ finally:
+ self.propagators.remove(assign)
def on_execute(
self,
priority: int = 16,
- auxiliaries: list[BaseAuxiliary] | None = None,
providers: list[Provider | type[Provider] | ProviderFactory | type[ProviderFactory]] | None = None,
):
- _auxiliaries = auxiliaries or []
- _auxiliaries.append(self.supplier)
- return self.plugin._scope.register(
- priority=priority, auxiliaries=_auxiliaries, providers=providers, publisher=exec_pub
- )
+ wrapper = self.plugin._scope.register(priority=priority, providers=providers, publisher=exec_pub)
+
+ def decorator(func):
+ sub = wrapper(func)
+ sub.propagate(self.supplier)
+ return sub
+
+ return decorator
Match = Match
Query = Query
diff --git a/arclet/entari/command/provider.py b/arclet/entari/command/provider.py
index 9142643..de73701 100644
--- a/arclet/entari/command/provider.py
+++ b/arclet/entari/command/provider.py
@@ -3,7 +3,8 @@
from arclet.alconna import Alconna, Arparma, Duplication, Empty, output_manager
from arclet.alconna.builtin import generate_duplication
-from arclet.letoderea import BaseAuxiliary, Contexts, Interface, Param, Provider
+from arclet.letoderea import STOP, Contexts, Param, Propagator, Provider
+from arclet.letoderea.exceptions import ProviderUnsatisfied
from arclet.letoderea.provider import ProviderFactory
from nepattern.util import CUnionType
from satori.element import Text
@@ -30,50 +31,34 @@ def _remove_config_prefix(message: MessageChain):
return MessageChain()
-class MessageJudges(BaseAuxiliary):
+class MessageJudges(Propagator):
def __init__(self, need_reply_me: bool, need_notice_me: bool, use_config_prefix: bool):
self.need_reply_me = need_reply_me
self.need_notice_me = need_notice_me
self.use_config_prefix = use_config_prefix
- async def on_prepare(self, interface: Interface):
- if "$message_content" in interface.ctx:
- message: MessageChain = interface.ctx["$message_content"]
- is_reply_me = interface.ctx.get("is_reply_me", False)
- is_notice_me = interface.ctx.get("is_notice_me", False)
- if self.need_reply_me and not is_reply_me:
- return False
- if self.need_notice_me and not is_notice_me:
- return False
- if self.use_config_prefix and not (message := _remove_config_prefix(message)):
- return False
- return interface.update(**{"$message_content": message})
- return (await interface.query(MessageChain, "message", force_return=True)) is not None
-
- @property
- def before(self) -> set[str]:
- return {"entari.filter"}
-
- @property
- def after(self) -> set[str]:
- return {"entari.command/supplier"}
-
- @property
- def id(self) -> str:
- return "entari.command/message_judges"
-
-
-class AlconnaSuppiler(BaseAuxiliary):
+ async def judge(self, ctx: Contexts, message: MessageChain, is_reply_me: bool = False, is_notice_me: bool = False):
+ if self.need_reply_me and not is_reply_me:
+ return STOP
+ if self.need_notice_me and not is_notice_me:
+ return STOP
+ if self.use_config_prefix and not (message := _remove_config_prefix(message)):
+ return STOP
+ if "$message_content" in ctx:
+ return {"$message_content": message}
+ return {"message": message}
+
+ def compose(self):
+ yield self.judge, True, 60
+
+
+class AlconnaSuppiler(Propagator):
cmd: Alconna
def __init__(self, cmd: Alconna):
self.cmd = cmd
- async def on_prepare(self, interface: Interface) -> Optional[Union[bool, Interface.Update]]:
- message = await interface.query(MessageChain, "message", force_return=True)
- if not message:
- return False
- session = await interface.query(Session, "session", force_return=True)
+ async def supply(self, message: MessageChain, session: Optional[Session] = None):
with output_manager.capture(self.cmd.name) as cap:
output_manager.set_action(lambda x: x, self.cmd.name)
try:
@@ -82,17 +67,16 @@ async def on_prepare(self, interface: Interface) -> Optional[Union[bool, Interfa
_res = Arparma(self.cmd._hash, message, False, error_info=e)
may_help_text: Optional[str] = cap.get("output", None)
if _res.matched:
- return interface.update(alc_result=CommandResult(self.cmd, _res, may_help_text))
- elif may_help_text:
+ return {"alc_result": CommandResult(self.cmd, _res, may_help_text)}
+ if may_help_text:
if session:
await session.send(MessageChain(may_help_text))
- return False
- return interface.update(alc_result=CommandResult(self.cmd, _res, may_help_text))
- return False
+ return STOP
+ return {"alc_result": CommandResult(self.cmd, _res, may_help_text)}
+ return STOP
- @property
- def id(self) -> str:
- return "entari.command/supplier"
+ def compose(self):
+ yield self.supply, True, 70
class AlconnaProvider(Provider[Any]):
@@ -103,7 +87,9 @@ def __init__(self, type_: str, extra: Optional[dict] = None):
async def __call__(self, context: Contexts):
if "alc_result" not in context:
- return
+ if self.type == "args":
+ return
+ raise ProviderUnsatisfied("alc_result")
result: CommandResult = context["alc_result"]
if self.type == "result":
return result
@@ -134,36 +120,29 @@ async def __call__(self, context: Contexts):
_seminal = type("_seminal", (object,), {})
-class Assign(BaseAuxiliary):
+class Assign(Propagator):
def __init__(self, path: str, value: Any = _seminal, or_not: bool = False):
self.path = path
self.value = value
self.or_not = or_not
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- result = await interface.query(CommandResult, "alc_result", force_return=True)
- if result is None:
- return False
+ async def check(self, alc_result: CommandResult):
if self.value == _seminal:
if self.path == "$main" or self.or_not:
- if not result.result.components:
- return True
- return False
- return result.result.query(self.path, "\1") != "\1"
+ if not alc_result.result.components:
+ return
+ return STOP
+ if alc_result.result.query(self.path, "\1") == "\1":
+ return STOP
else:
- if result.result.query(self.path) == self.value:
- return True
- if self.or_not and result.result.query(self.path) == Empty:
- return True
- return False
-
- @property
- def before(self) -> set[str]:
- return {"entari.command/supplier"}
-
- @property
- def id(self) -> str:
- return f"entari.command/assign:{self.path}"
+ if alc_result.result.query(self.path) != self.value:
+ return STOP
+ if self.or_not and alc_result.result.query(self.path) == Empty:
+ return
+ return STOP
+
+ def compose(self):
+ yield self.check, True, 80
class AlconnaProviderFactory(ProviderFactory):
diff --git a/arclet/entari/core.py b/arclet/entari/core.py
index cdb7cd7..2d33537 100644
--- a/arclet/entari/core.py
+++ b/arclet/entari/core.py
@@ -36,7 +36,7 @@ async def __call__(self, context: Contexts):
class SessionProvider(Provider[Session]):
def validate(self, param: Param):
- return get_origin(param.annotation) == Session
+ return (get_origin(param.annotation) == Session) or super().validate(param)
async def __call__(self, context: Contexts):
if "session" in context and isinstance(context["session"], Session):
diff --git a/arclet/entari/filter/__init__.py b/arclet/entari/filter/__init__.py
index 55d8e1a..a5b2ae2 100644
--- a/arclet/entari/filter/__init__.py
+++ b/arclet/entari/filter/__init__.py
@@ -1,67 +1,90 @@
import asyncio
+from collections.abc import Awaitable
from datetime import datetime
-from typing import Optional, Union, cast
+from typing import Callable, Final, Optional, Union
+from typing_extensions import TypeAlias
-from arclet.letoderea import BaseAuxiliary, Interface
+from arclet.letoderea import STOP, Propagator
+from arclet.letoderea.typing import run_sync
+from tarina import is_coroutinefunction
+from . import common, message
from ..message import MessageChain
from ..session import Session
-from .common import Filter as Filter
+from .common import parse as parse
+_SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]]
-class Interval(BaseAuxiliary):
+
+class _Filter:
+ user = staticmethod(common.user)
+ guild = staticmethod(common.guild)
+ channel = staticmethod(common.channel)
+ self_ = staticmethod(common.account)
+ platform = staticmethod(common.platform)
+ direct = staticmethod(message.direct_message)
+ private = staticmethod(message.direct_message)
+ direct_message = staticmethod(message.direct_message)
+ public = staticmethod(message.public_message)
+ public_message = staticmethod(message.public_message)
+ notice_me = staticmethod(message.notice_me)
+ reply_me = staticmethod(message.reply_me)
+ to_me = staticmethod(message.to_me)
+
+ def __call__(self, func: _SessionFilter):
+ _func = func if is_coroutinefunction(func) else run_sync(func)
+
+ async def _(session: Session):
+ if not await _func(session): # type: ignore
+ return STOP
+
+ return _
+
+
+filter_: Final[_Filter] = _Filter()
+F = filter_
+
+
+class Interval(Propagator):
def __init__(self, interval: float, limit_prompt: Optional[Union[str, MessageChain]] = None):
self.success = True
self.last_time = None
self.interval = interval
self.limit_prompt = limit_prompt
- @property
- def id(self):
- return "entari.filter/interval"
-
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
+ async def before(self, session: Optional[Session] = None):
if not self.last_time:
- return True
+ return
self.success = (datetime.now() - self.last_time).total_seconds() > self.interval
if not self.success:
- session = await interface.query(Session, "session", force_return=True)
if session and self.limit_prompt:
await session.send(self.limit_prompt)
- return self.success
+ return STOP
- async def on_cleanup(self, interface: Interface) -> Optional[bool]:
+ async def after(self):
if self.success:
self.last_time = datetime.now()
- return True
+ def compose(self):
+ yield self.before, True, 15
+ yield self.after, False, 60
-class Semaphore(BaseAuxiliary):
+
+class Semaphore(Propagator):
def __init__(self, count: int, limit_prompt: Optional[Union[str, MessageChain]] = None):
self.count = count
self.limit_prompt = limit_prompt
self.semaphore = asyncio.Semaphore(count)
- @property
- def id(self):
- return "entari.filter/access"
-
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
+ async def before(self, session: Optional[Session] = None):
if not await self.semaphore.acquire():
- session = await interface.query(Session, "session", force_return=True)
if session and self.limit_prompt:
await session.send(self.limit_prompt)
- return False
- return True
+ return STOP
- async def on_cleanup(self, interface: Interface) -> Optional[bool]:
+ async def after(self):
self.semaphore.release()
- return True
-
-
-class _Wrapper:
- def __getattr__(self, item):
- return getattr(Filter(), item)
-
-filter_ = cast(Filter, _Wrapper())
+ def compose(self):
+ yield self.before, True, 15
+ yield self.after, False, 60
diff --git a/arclet/entari/filter/common.py b/arclet/entari/filter/common.py
index 72af25d..f4e29cb 100644
--- a/arclet/entari/filter/common.py
+++ b/arclet/entari/filter/common.py
@@ -1,244 +1,145 @@
from collections.abc import Awaitable
-from typing import Callable, Optional, Union
-from typing_extensions import Self, TypeAlias
-
-from arclet.letoderea import BaseAuxiliary, Interface
-from arclet.letoderea import bind as _bind
-from arclet.letoderea.auxiliary import sort_auxiliaries
-from arclet.letoderea.typing import run_sync
-from satori import Channel, Guild, User
-from satori.client import Account
-from tarina import is_async
-
-from ..event.base import SatoriEvent
+from typing import Any, Callable, Optional, Union
+from typing_extensions import TypeAlias
+
+from arclet.letoderea import STOP, Propagator
+
from ..session import Session
-from .message import DirectMessageJudger, NoticeMeJudger, PublicMessageJudger, ReplyMeJudger, ToMeJudger
-from .op import ExcludeFilter, IntersectFilter, UnionFilter
+from .message import direct_message, notice_me, public_message, reply_me, to_me
-class PlatformFilter(BaseAuxiliary):
- def __init__(self, *platforms: str):
- self.platforms = set(platforms)
+def user(*ids: str):
+ async def check_user(session: Session):
+ return (None if session.user.id in ids else STOP) if ids else None
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- if not (account := await interface.query(Account, "account", force_return=True)):
- return False
- return account.platform in self.platforms
+ return check_user
- @property
- def id(self) -> str:
- return "entari.filter/platform"
+def channel(*ids: str):
+ async def check_channel(session: Session):
+ return (None if session.channel.id in ids else STOP) if ids else None
-class SelfFilter(BaseAuxiliary):
- def __init__(self, *self_ids: str):
- self.self_ids = set(self_ids)
+ return check_channel
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- if not (account := await interface.query(Account, "account", force_return=True)):
- return False
- return account.self_id in self.self_ids
- @property
- def id(self) -> str:
- return "entari.filter/self"
+def guild(*ids: str):
+ async def check_guild(session: Session):
+ return (None if session.guild.id in ids else STOP) if ids else None
- @property
- def before(self) -> set[str]:
- return {"entari.filter/platform"}
+ return check_guild
-class GuildFilter(BaseAuxiliary):
- def __init__(self, *guild_ids: str):
- self.guild_ids = set(guild_ids)
+def account(*ids: str):
+ async def check_account(session: Session):
+ return (None if session.account.self_id in ids else STOP) if ids else None
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- if not (guild := await interface.query(Guild, "guild", force_return=True)):
- return False
- return guild.id in self.guild_ids if self.guild_ids else True
+ return check_account
- @property
- def id(self) -> str:
- return "entari.filter/guild"
- @property
- def before(self) -> set[str]:
- return {"entari.filter/platform", "entari.filter/self"}
+def platform(*ids: str):
+ async def check_platform(session: Session):
+ return (None if session.account.platform in ids else STOP) if ids else None
+ return check_platform
-class ChannelFilter(BaseAuxiliary):
- def __init__(self, *channel_ids: str):
- self.channel_ids = set(channel_ids)
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- if not (channel := await interface.query(Channel, "channel", force_return=True)):
- return False
- return channel.id in self.channel_ids if self.channel_ids else True
+_keys = {
+ "user": (user, 2),
+ "guild": (guild, 3),
+ "channel": (channel, 4),
+ "self": (account, 1),
+ "platform": (platform, 0),
+ "direct": (lambda: direct_message, 5),
+ "private": (lambda: direct_message, 5),
+ "public": (lambda: public_message, 6),
+}
- @property
- def id(self) -> str:
- return "entari.filter/channel"
+_mess_keys = {
+ "reply_me": (reply_me, 7),
+ "notice_me": (notice_me, 8),
+ "to_me": (to_me, 9),
+}
- @property
- def before(self) -> set[str]:
- return {"entari.filter/platform", "entari.filter/self", "entari.filter/guild"}
+_op_keys = {
+ "$and": "and",
+ "$or": "or",
+ "$not": "not",
+ "$intersect": "and",
+ "$union": "or",
+ "$exclude": "not",
+}
+PATTERNS: TypeAlias = dict[str, Union[list[str], bool, "PATTERNS"]]
-class UserFilter(BaseAuxiliary):
- def __init__(self, *user_ids: str):
- self.user_ids = set(user_ids)
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- if not (user := await interface.query(User, "user", force_return=True)):
- return False
- return user.id in self.user_ids if self.user_ids else True
+class _Filter(Propagator):
+ def __init__(
+ self,
+ steps: list[Callable[[Session], Awaitable[bool]]],
+ mess: list[Callable[[bool, bool], bool]],
+ ops: list[tuple[str, "_Filter"]],
+ ):
+ self.steps = steps
+ self.mess = mess
+ self.ops = ops
+
+ async def check(self, session: Optional[Session] = None, is_reply_me: bool = False, is_notice_me: bool = False):
+ res = True
+ if session and self.steps:
+ res = all([await step(session) for step in self.steps])
+ if self.mess:
+ res = res and all(mess(is_reply_me, is_notice_me) for mess in self.mess)
+ for op, f_ in self.ops:
+ if op == "and":
+ res = res and (await f_.check(session, is_reply_me, is_notice_me)) is None
+ elif op == "or":
+ res = res or (await f_.check(session, is_reply_me, is_notice_me)) is None
+ else:
+ res = res and (await f_.check(session, is_reply_me, is_notice_me)) is STOP
+ return None if res else STOP
- @property
- def id(self) -> str:
- return "entari.filter/user"
+ def compose(self):
+ yield self.check, True, 0
- @property
- def before(self) -> set[str]:
- return {"entari.filter/platform", "entari.filter/self"}
+def _wrapper(func: Callable[[Session], Any]):
+ async def _(session: Session):
+ return True if await func(session) is None else False
-_SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]]
-_keys = {
- "user",
- "guild",
- "channel",
- "self",
- "platform",
- "direct",
- "private",
- "public",
- "reply_me",
- "notice_me",
- "to_me",
-}
+ return _
-PATTERNS: TypeAlias = dict[str, Union[list[str], bool, "PATTERNS"]]
+def parse(patterns: PATTERNS):
+ step: dict[int, Callable[[Session], Awaitable[bool]]] = {}
+ mess: dict[int, Callable[[bool, bool], bool]] = {}
+ ops: list[tuple[str, _Filter]] = []
-class Filter(BaseAuxiliary):
- def __init__(self, callback: Optional[_SessionFilter] = None):
- self.steps = []
- if callback:
- if is_async(callback):
- self.callback = callback
+ for key, value in patterns.items():
+ if key in _keys:
+ step[_keys[key][1]] = _wrapper(_keys[key][0](*value) if isinstance(value, list) else _keys[key][0]())
+ elif key in _mess_keys:
+ if key == "reply_me":
+ mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: (
+ True if _mess_keys[key][0](is_reply_me) is None else False
+ )
+ elif key == "notice_me":
+ mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: (
+ True if _mess_keys[key][0](is_notice_me) is None else False
+ )
else:
- self.callback = run_sync(callback)
+ mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: (
+ True if _mess_keys[key][0](is_reply_me, is_notice_me) is None else False
+ )
+ elif key in _op_keys:
+ op = _op_keys[key]
+ if not isinstance(value, dict):
+ raise ValueError(f"Expect a dict for operator {key}")
+ ops.append((op, parse(value)))
else:
- self.callback = None
-
- def __repr__(self):
- return f"{self.__class__.__name__}({self.steps})"
-
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- if not isinstance(interface.event, SatoriEvent): # we only care about event from satori
- return True
- for step in sort_auxiliaries(self.steps):
- if not await step.on_prepare(interface):
- return False
- if self.callback:
- session = await interface.query(Session, "session", force_return=True)
- if not session:
- return False
- if not await self.callback(session): # type: ignore
- return False
- return True
-
- @property
- def id(self) -> str:
- return "entari.filter"
-
- def platform(self, *platforms: str) -> Self:
- self.steps.append(PlatformFilter(*platforms))
- return self
-
- def self(self, *self_ids: str) -> Self:
- self.steps.append(SelfFilter(*self_ids))
- return self
-
- def guild(self, *guild_ids: str) -> Self:
- self.steps.append(GuildFilter(*guild_ids))
- return self
-
- def channel(self, *channel_ids: str) -> Self:
- self.steps.append(ChannelFilter(*channel_ids))
- return self
-
- def user(self, *user_ids: str) -> Self:
- self.steps.append(UserFilter(*user_ids))
- return self
-
- def direct(self) -> Self:
- self.steps.append(DirectMessageJudger())
- return self
-
- private = direct
-
- def public(self) -> Self:
- self.steps.append(PublicMessageJudger())
- return self
-
- def reply_me(self) -> Self:
- self.steps.append(ReplyMeJudger())
- return self
-
- def notice_me(self) -> Self:
- self.steps.append(NoticeMeJudger())
- return self
-
- def to_me(self) -> Self:
- self.steps.append(ToMeJudger())
- return self
-
- def __call__(self, func):
- return _bind(self)(func)
-
- bind = __call__
-
- def and_(self, other: Union["Filter", _SessionFilter]) -> "Filter":
- new = Filter()
- _other = other if isinstance(other, Filter) else Filter(callback=other)
- new.steps.append(IntersectFilter(self, _other))
- return new
-
- intersect = and_
-
- def or_(self, other: Union["Filter", _SessionFilter]) -> "Filter":
- new = Filter()
- _other = other if isinstance(other, Filter) else Filter(callback=other)
- new.steps.append(UnionFilter(self, _other))
- return new
-
- union = or_
-
- def not_(self, other: Union["Filter", _SessionFilter]) -> "Filter":
- new = Filter()
- _other = other if isinstance(other, Filter) else Filter(callback=other)
- new.steps.append(ExcludeFilter(self, _other))
- return new
-
- exclude = not_
-
- @classmethod
- def parse(cls, patterns: PATTERNS) -> Self:
- fter = cls()
- for key, value in patterns.items():
- if key in _keys:
- if isinstance(value, list):
- getattr(fter, key)(*value)
- elif isinstance(value, bool) and value:
- getattr(fter, key)()
- elif key in ("$and", "$or", "$not", "$intersect", "$union", "$exclude"):
- op = key[1:]
- if op in ("and", "or", "not"):
- op += "_"
- if not isinstance(value, dict):
- raise ValueError(f"Expect a dict for operator {key}")
- fter = getattr(fter, op)(cls.parse(value))
- else:
- raise ValueError(f"Unknown key: {key}")
- return fter
+ raise ValueError(f"Unknown key: {key}")
+
+ return _Filter(
+ steps=[slot[1] for slot in sorted(step.items(), key=lambda x: x[0])],
+ mess=[slot[1] for slot in sorted(mess.items(), key=lambda x: x[0])],
+ ops=ops,
+ )
diff --git a/arclet/entari/filter/message.py b/arclet/entari/filter/message.py
index e8eacbd..1630b55 100644
--- a/arclet/entari/filter/message.py
+++ b/arclet/entari/filter/message.py
@@ -1,71 +1,29 @@
-from typing import Optional
+from arclet.letoderea import STOP
+from satori import ChannelType
-from arclet.letoderea import BaseAuxiliary, Interface
-from satori import Channel, ChannelType
+from ..session import Session
-class DirectMessageJudger(BaseAuxiliary):
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- if not (channel := await interface.query(Channel, "channel", force_return=True)):
- return False
- return channel.type == ChannelType.DIRECT
+async def direct_message(sess: Session):
+ if sess.channel.type != ChannelType.DIRECT:
+ return STOP
- @property
- def id(self) -> str:
- return "entari.filter/direct_message"
+async def public_message(sess: Session):
+ if sess.channel.type == ChannelType.DIRECT:
+ return STOP
-class PublicMessageJudger(BaseAuxiliary):
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- if not (channel := await interface.query(Channel, "channel", force_return=True)):
- return False
- return channel.type != ChannelType.DIRECT
- @property
- def id(self) -> str:
- return "entari.filter/public_message"
+def reply_me(is_reply_me: bool = False):
+ if not is_reply_me:
+ return STOP
-class ReplyMeJudger(BaseAuxiliary):
+def notice_me(is_notice_me: bool = False):
+ if not is_notice_me:
+ return STOP
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- return interface.ctx.get("is_reply_me", False)
- @property
- def id(self) -> str:
- return "entari.filter/judge_reply_me"
-
-
-class NoticeMeJudger(BaseAuxiliary):
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- return interface.ctx.get("is_notice_me", False)
-
- @property
- def id(self) -> str:
- return "entari.filter/judge_notice_me"
-
- @property
- def before(self) -> set[str]:
- return {"entari.filter/judge_reply_me"}
-
-
-class ToMeJudger(BaseAuxiliary):
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- is_reply_me = interface.ctx.get("is_reply_me", False)
- is_notice_me = interface.ctx.get("is_notice_me", False)
- return is_reply_me or is_notice_me
-
- @property
- def id(self) -> str:
- return "entari.filter/judge_to_me"
-
- @property
- def before(self) -> set[str]:
- return {"entari.filter/judge_reply_me", "entari.filter/judge_notice_me"}
-
-
-public_message = PublicMessageJudger()
-direct_message = DirectMessageJudger()
-reply_me = ReplyMeJudger()
-notice_me = NoticeMeJudger()
-to_me = ToMeJudger()
+def to_me(is_reply_me: bool = False, is_notice_me: bool = False):
+ if not is_reply_me and not is_notice_me:
+ return STOP
diff --git a/arclet/entari/filter/op.py b/arclet/entari/filter/op.py
deleted file mode 100644
index e6a8516..0000000
--- a/arclet/entari/filter/op.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from typing import TYPE_CHECKING, Optional
-
-from arclet.letoderea import BaseAuxiliary, Interface
-
-if TYPE_CHECKING:
- from .common import Filter
-
-
-class IntersectFilter(BaseAuxiliary):
- def __init__(self, left: "Filter", right: "Filter"):
- self.left = left
- self.right = right
-
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- return await self.left.on_prepare(interface) and await self.right.on_prepare(interface)
-
- @property
- def id(self) -> str:
- return "entari.filter/intersect"
-
-
-class UnionFilter(BaseAuxiliary):
- def __init__(self, left: "Filter", right: "Filter"):
- self.left = left
- self.right = right
-
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- return await self.left.on_prepare(interface) or await self.right.on_prepare(interface)
-
- @property
- def id(self) -> str:
- return "entari.filter/union"
-
-
-class ExcludeFilter(BaseAuxiliary):
- def __init__(self, left: "Filter", right: "Filter"):
- self.left = left
- self.right = right
-
- async def on_prepare(self, interface: Interface) -> Optional[bool]:
- return await self.left.on_prepare(interface) and not await self.right.on_prepare(interface)
-
- @property
- def id(self) -> str:
- return "entari.filter/exclude"
diff --git a/arclet/entari/logger.py b/arclet/entari/logger.py
index c78b181..ed2fb72 100644
--- a/arclet/entari/logger.py
+++ b/arclet/entari/logger.py
@@ -20,17 +20,17 @@ def __init__(self):
def fork(self, child_name: str):
patched = logger.patch(lambda r: r.update(name=child_name))
- patched = patched.bind(name=child_name).opt(colors=True)
+ patched = patched.bind(name=child_name)
self.loggers[child_name] = patched
return patched
@property
def core(self):
- return self.loggers["[core]"]
+ return self.loggers["[core]"].opt(colors=True)
@property
def plugin(self):
- return self.loggers["[plugin]"]
+ return self.loggers["[plugin]"].opt(colors=True)
@property
def message(self):
@@ -38,7 +38,9 @@ def message(self):
def wrapper(self, name: str, color: str = "blue"):
patched = logger.patch(
- lambda r: r.update(name="entari", extra={"entari_plugin_name": name, "entari_plugin_color": color})
+ lambda r: r.update(
+ name="entari", extra=r["extra"] | {"entari_plugin_name": name, "entari_plugin_color": color}
+ )
)
patched = patched.bind(name=f"plugins.{name}")
self.loggers[f"plugin.{name}"] = patched
diff --git a/arclet/entari/plugin/__init__.py b/arclet/entari/plugin/__init__.py
index cc1d8ee..298fd37 100644
--- a/arclet/entari/plugin/__init__.py
+++ b/arclet/entari/plugin/__init__.py
@@ -94,9 +94,7 @@ def load_plugin(
if referent in recursive_guard:
continue
if referent in plugin_service.plugins:
- log.plugin.debug(
- f"reloading {mod.__name__}'s referent {referent!r}"
- )
+ log.plugin.debug(f"reloading {mod.__name__}'s referent {referent!r}")
unload_plugin(referent)
if not load_plugin(referent):
plugin_service._referents[mod.__name__].add(referent)
@@ -109,9 +107,7 @@ def load_plugin(
log.plugin.error(f"failed to load plugin {path!r}: {e.args[0]}")
es.publish(PluginLoadedFailed(path, e))
except Exception as e:
- log.plugin.exception(
- f"failed to load plugin {path!r} caused by {e!r}", exc_info=e
- )
+ log.plugin.exception(f"failed to load plugin {path!r} caused by {e!r}", exc_info=e)
es.publish(PluginLoadedFailed(path, e))
diff --git a/arclet/entari/plugin/model.py b/arclet/entari/plugin/model.py
index 5815d87..a6af046 100644
--- a/arclet/entari/plugin/model.py
+++ b/arclet/entari/plugin/model.py
@@ -8,14 +8,14 @@
from typing import Any, Callable, TypeVar, overload
from weakref import finalize, proxy
-from arclet.letoderea import BaseAuxiliary, Provider, ProviderFactory, Scope, StepOut, Subscriber, es
+from arclet.letoderea import Propagator, Provider, ProviderFactory, Scope, StepOut, Subscriber, es
from arclet.letoderea.publisher import Publisher, _publishers
from arclet.letoderea.typing import TTarget
from creart import it
from launart import Launart, Service
from tarina import ContextModel
-from ..filter import Filter
+from ..filter import parse
from ..logger import log
from .service import plugin_service
@@ -46,13 +46,12 @@ def __init__(
self.plugin = plugin
self._event = event
self.providers: list[Provider[Any] | ProviderFactory] = []
- self.auxiliaries: list[BaseAuxiliary] = []
+ self.propagators: list[Propagator] = []
def waiter(
self,
*events: Any,
providers: Sequence[Provider | type[Provider]] | None = None,
- auxiliaries: list[BaseAuxiliary] | None = None,
priority: int = 15,
block: bool = False,
) -> Callable[[TTarget[R]], StepOut[R]]:
@@ -60,7 +59,7 @@ def wrapper(func: TTarget[R]):
nonlocal events
if not events:
events = (self._event,)
- return StepOut(list(events), func, providers, auxiliaries, priority, block) # type: ignore
+ return StepOut(list(events), func, providers, priority, block) # type: ignore
return wrapper
@@ -70,7 +69,6 @@ def register(
func: Callable[..., Any],
*,
priority: int = 16,
- auxiliaries: list[BaseAuxiliary] | None = None,
providers: (
Sequence[Provider[Any] | type[Provider[Any]] | ProviderFactory | type[ProviderFactory]] | None
) = None,
@@ -82,7 +80,6 @@ def register(
self,
*,
priority: int = 16,
- auxiliaries: list[BaseAuxiliary] | None = None,
providers: (
Sequence[Provider[Any] | type[Provider[Any]] | ProviderFactory | type[ProviderFactory]] | None
) = None,
@@ -94,29 +91,27 @@ def register(
func: Callable[..., Any] | None = None,
*,
priority: int = 16,
- auxiliaries: list[BaseAuxiliary] | None = None,
providers: (
Sequence[Provider[Any] | type[Provider[Any]] | ProviderFactory | type[ProviderFactory]] | None
) = None,
temporary: bool = False,
):
- _auxiliaries = auxiliaries or []
_providers = providers or []
wrapper = self.plugin._scope.register(
priority=priority,
- auxiliaries=[*self.auxiliaries, *_auxiliaries],
providers=[*self.providers, *_providers],
temporary=temporary,
publisher=self.publisher,
)
- if func:
- self.plugin.validate(func) # type: ignore
- return wrapper(func)
def decorator(func1, /):
self.plugin.validate(func1)
- return wrapper(func1)
+ sub = wrapper(func1)
+ sub.propagates(*self.propagators)
+ return sub
+ if func:
+ return decorator(func)
return decorator
@overload
@@ -125,7 +120,6 @@ def once(
func: Callable[..., Any],
*,
priority: int = 16,
- auxiliaries: list[BaseAuxiliary] | None = None,
providers: (
Sequence[Provider[Any] | type[Provider[Any]] | ProviderFactory | type[ProviderFactory]] | None
) = None,
@@ -136,7 +130,6 @@ def once(
self,
*,
priority: int = 16,
- auxiliaries: list[BaseAuxiliary] | None = None,
providers: (
Sequence[Provider[Any] | type[Provider[Any]] | ProviderFactory | type[ProviderFactory]] | None
) = None,
@@ -147,14 +140,13 @@ def once(
func: Callable[..., Any] | None = None,
*,
priority: int = 16,
- auxiliaries: list[BaseAuxiliary] | None = None,
providers: (
Sequence[Provider[Any] | type[Provider[Any]] | ProviderFactory | type[ProviderFactory]] | None
) = None,
):
if func:
- return self.register(func, priority=priority, auxiliaries=auxiliaries, providers=providers, temporary=True)
- return self.register(priority=priority, auxiliaries=auxiliaries, providers=providers, temporary=True)
+ return self.register(func, priority=priority, providers=providers, temporary=True)
+ return self.register(priority=priority, providers=providers, temporary=True)
on = register
handle = register
@@ -224,21 +216,18 @@ def collect(self, *disposes: Callable[[], None]):
self._dispose_callbacks.extend(disposes)
return self
- def update_filter(self, allow: dict, deny: dict):
- if not allow and not deny:
- return
- fter = Filter()
- if allow:
- fter = fter.and_(Filter.parse(allow))
- if deny:
- fter = fter.not_(Filter.parse(deny))
- if fter.steps:
- plugin_service.filters[self.id] = fter
-
def __post_init__(self):
self._scope = es.scope(self.id)
plugin_service.plugins[self.id] = self
- self.update_filter(self.config.pop("$allow", {}), self.config.pop("$deny", {}))
+ allow = self.config.pop("$allow", {})
+ deny = self.config.pop("$deny", {})
+ pat = {}
+ if allow:
+ pat["$and"] = allow
+ if deny:
+ pat["$not"] = deny
+ if pat:
+ self._scope.propagators.append(parse(pat))
if "$static" in self.config:
self.is_static = True
self.config.pop("$static")
@@ -280,6 +269,7 @@ def dispose(self):
plugin_service.plugins.pop(subplug, None)
self.subplugins.clear()
self._scope.dispose()
+ self._scope.propagators.clear()
del plugin_service.plugins[self.id]
del self.module
@@ -294,7 +284,6 @@ def use(
pub: Any,
*,
priority: int = 16,
- auxiliaries: list[BaseAuxiliary] | None = None,
providers: (
Sequence[Provider[Any] | type[Provider[Any]] | ProviderFactory | type[ProviderFactory]] | None
) = None,
@@ -307,7 +296,6 @@ def use(
func: Callable[..., Any],
*,
priority: int = 16,
- auxiliaries: list[BaseAuxiliary] | None = None,
providers: (
Sequence[Provider[Any] | type[Provider[Any]] | ProviderFactory | type[ProviderFactory]] | None
) = None,
@@ -319,7 +307,6 @@ def use(
func: Callable[..., Any] | None = None,
*,
priority: int = 16,
- auxiliaries: list[BaseAuxiliary] | None = None,
providers: (
Sequence[Provider[Any] | type[Provider[Any]] | ProviderFactory | type[ProviderFactory]] | None
) = None,
@@ -336,8 +323,8 @@ def use(
raise LookupError(f"no publisher found: {pid}")
disp = PluginDispatcher(self, _publishers[pid].target)
if func:
- return disp.register(func=func, priority=priority, auxiliaries=auxiliaries, providers=providers)
- return disp.register(priority=priority, auxiliaries=auxiliaries, providers=providers)
+ return disp.register(func=func, priority=priority, providers=providers)
+ return disp.register(priority=priority, providers=providers)
def validate(self, func):
if func.__module__ != self.module.__name__:
diff --git a/arclet/entari/plugin/module.py b/arclet/entari/plugin/module.py
index 0f51f92..a40af83 100644
--- a/arclet/entari/plugin/module.py
+++ b/arclet/entari/plugin/module.py
@@ -13,7 +13,7 @@
from ..config import EntariConfig
from ..logger import log
from .model import Plugin, PluginMetadata, _current_plugin
-from .service import AccessAuxiliary, plugin_service
+from .service import plugin_service
_SUBMODULE_WAITLIST: dict[str, set[str]] = {}
_ENSURE_IS_PLUGIN: set[str] = set()
@@ -217,11 +217,9 @@ def create_module(self, spec) -> Optional[ModuleType]:
return super().create_module(spec)
def exec_module(self, module: ModuleType, config: Optional[dict[str, str]] = None) -> None:
- is_sub = False
if plugin := plugin_service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None:
plugin.subplugins.add(module.__name__)
plugin_service._subplugined[module.__name__] = plugin.id
- is_sub = True
if self.loaded:
return
@@ -238,8 +236,6 @@ def exec_module(self, module: ModuleType, config: Optional[dict[str, str]] = Non
# enter plugin context
token = _current_plugin.set(plugin)
if not plugin.is_static:
- if not is_sub:
- plugin._scope.auxiliaries.append(AccessAuxiliary(plugin.id))
token1 = scope_ctx.set(plugin._scope)
try:
super().exec_module(module)
diff --git a/arclet/entari/plugin/service.py b/arclet/entari/plugin/service.py
index 5f39899..e04d15f 100644
--- a/arclet/entari/plugin/service.py
+++ b/arclet/entari/plugin/service.py
@@ -1,12 +1,11 @@
from typing import TYPE_CHECKING, Any, Callable
-from arclet.letoderea import BaseAuxiliary, es
+from arclet.letoderea import es
from launart import Launart, Service
from launart.status import Phase
from ..event.lifespan import Cleanup, Ready, Startup
from ..event.plugin import PluginUnloaded
-from ..filter import Filter
from ..logger import log
if TYPE_CHECKING:
@@ -17,7 +16,6 @@ class PluginManagerService(Service):
id = "entari.plugin.manager"
plugins: dict[str, "Plugin"]
- filters: dict[str, Filter]
_keep_values: dict[str, dict[str, "KeepingVariable"]]
_referents: dict[str, set[str]]
_unloaded: set[str]
@@ -32,7 +30,6 @@ def __init__(self):
self._unloaded = set()
self._subplugined = {}
self._apply = {}
- self.filters = {}
@property
def required(self) -> set[str]:
@@ -74,21 +71,3 @@ async def launch(self, manager: Launart):
plugin_service = PluginManagerService()
-
-
-class AccessAuxiliary(BaseAuxiliary):
- def __init__(self, plugin_id: str):
- self.plugin_id = plugin_id
-
- @property
- def id(self):
- return f"entari.plugin.access:{self.plugin_id}"
-
- async def on_prepare(self, interface):
- if self.plugin_id in plugin_service.filters:
- return await plugin_service.filters[self.plugin_id].on_prepare(interface)
- return True
-
- @property
- def after(self) -> set[str]:
- return {"entari.filter"}
diff --git a/arclet/entari/session.py b/arclet/entari/session.py
index 7e1ccb0..86b8834 100644
--- a/arclet/entari/session.py
+++ b/arclet/entari/session.py
@@ -3,7 +3,7 @@
from collections.abc import Iterable
from typing import Generic, NoReturn, TypeVar, cast
-from arclet.letoderea import ParsingStop, StepOut, es
+from arclet.letoderea import HandlerStop, StepOut, es
from satori.client.account import Account
from satori.client.protocol import ApiProtocol
from satori.const import Api
@@ -134,11 +134,11 @@ async def waiter(content: MessageChain, session: Session[MessageEvent]):
result = await step.wait(timeout=timeout)
if not result:
await self.send(timeout_message)
- raise ParsingStop()
+ raise HandlerStop()
return result
def stop(self) -> NoReturn:
- raise ParsingStop()
+ raise HandlerStop()
@property
def user(self) -> User:
diff --git a/example_plugin.py b/example_plugin.py
index 1fa1c09..19bbcf3 100644
--- a/example_plugin.py
+++ b/example_plugin.py
@@ -9,6 +9,7 @@
keeping,
scheduler,
local_data,
+ propagate
# Entari,
)
from arclet.entari.filter import Interval
@@ -29,7 +30,7 @@ async def cleanup():
@plug.dispatch(MessageCreatedEvent)
-@filter_.public().bind
+@propagate(filter_.public)
async def _(session: Session):
if session.content == "test":
resp = await session.send("This message will recall in 5s...")
@@ -41,22 +42,26 @@ async def _():
disp_message = plug.dispatch(MessageCreatedEvent)
-@disp_message.on(auxiliaries=[filter_.public().to_me().and_(lambda sess: str(sess.content) == "aaa")])
+@disp_message.on()
+@propagate(filter_.public, filter_.to_me, filter_(lambda sess: str(sess.content) == "aaa"), prepend=True)
async def _(session: Session):
return await session.send("Filter: public message, to me, and content is 'aaa'")
@disp_message
-@filter_.public().to_me().not_(lambda sess: str(sess.content) != "aaa")
+@propagate(filter_.public, filter_.to_me, filter_(lambda sess: str(sess.content) != "aaa"), prepend=True)
async def _(session: Session):
return await session.send("Filter: public message, to me, but content is not 'aaa'")
-@command.on("add {a} {b}", [Interval(2, limit_prompt="太快了")])
+@command.on("add {a} {b}")
def add(a: int, b: int):
return f"{a + b =}"
+add.propagate(Interval(2, limit_prompt="太快了"))
+
+
kept_data = keeping("foo", [], lambda x: x.clear())
diff --git a/pdm.lock b/pdm.lock
index 2b0d9e8..848cd13 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
groups = ["default", "cron", "dev", "full", "reload", "yaml"]
strategy = ["inherit_metadata"]
lock_version = "4.5.0"
-content_hash = "sha256:0b5b6083d142646d440e0336a267320429384d75253c32037020d8d06fffa871"
+content_hash = "sha256:ddadba99063dcd3feacc48f681d38b2f7cc7f4afb66a46734b816e9c7944338a"
[[metadata.targets]]
requires_python = ">=3.9"
@@ -180,7 +180,7 @@ files = [
[[package]]
name = "arclet-letoderea"
-version = "0.14.9"
+version = "0.15.0"
requires_python = ">=3.9"
summary = "A high-performance, simple-structured event system, relies on asyncio"
groups = ["default"]
@@ -188,8 +188,8 @@ dependencies = [
"tarina>=0.6.7",
]
files = [
- {file = "arclet_letoderea-0.14.9-py3-none-any.whl", hash = "sha256:0b7bc10dd504bc5ebbe09187367b7328e6f9cfeb11ffc9062653ca215d1a1a2c"},
- {file = "arclet_letoderea-0.14.9.tar.gz", hash = "sha256:337ef8752031f73ebdfb6c2ac89ae12b27dd90812251d865859bf340224c85f1"},
+ {file = "arclet_letoderea-0.15.0-py3-none-any.whl", hash = "sha256:ce74f602ac0fd6573c0cd686078ae60878e2f4a97e2eda582c351af4ab5b9a4b"},
+ {file = "arclet_letoderea-0.15.0.tar.gz", hash = "sha256:33141d60aa147c700ab20d0b620ae49d22f1a3b06f19e32d4d0f01b4b8851156"},
]
[[package]]
@@ -206,13 +206,13 @@ files = [
[[package]]
name = "attrs"
-version = "24.3.0"
+version = "25.1.0"
requires_python = ">=3.8"
summary = "Classes Without Boilerplate"
groups = ["default"]
files = [
- {file = "attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308"},
- {file = "attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff"},
+ {file = "attrs-25.1.0-py3-none-any.whl", hash = "sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a"},
+ {file = "attrs-25.1.0.tar.gz", hash = "sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e"},
]
[[package]]
diff --git a/pyproject.toml b/pyproject.toml
index f977cf3..3c4c2af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ authors = [
{name = "RF-Tar-Railt",email = "rf_tar_railt@qq.com"},
]
dependencies = [
- "arclet-letoderea<0.15.0,>=0.14.9",
+ "arclet-letoderea<0.16.0,>=0.15.0",
"arclet-alconna<2.0,>=1.8.34",
"satori-python-core>=0.15.2",
"satori-python-client>=0.15.2",