diff --git a/arclet/entari/filter/__init__.py b/arclet/entari/filter/__init__.py index c864c5d..a5b2ae2 100644 --- a/arclet/entari/filter/__init__.py +++ b/arclet/entari/filter/__init__.py @@ -1,51 +1,38 @@ import asyncio from collections.abc import Awaitable from datetime import datetime -from typing import Any, Callable, Final, Optional, Union -from typing_extensions import ParamSpec, TypeAlias +from typing import Callable, Final, Optional, Union +from typing_extensions import TypeAlias -from arclet.letoderea import STOP, Depends, Propagator -from arclet.letoderea.typing import Result, TTarget, run_sync +from arclet.letoderea import STOP, Propagator +from arclet.letoderea.typing import run_sync from tarina import is_coroutinefunction -from . import common +from . import common, message from ..message import MessageChain from ..session import Session from .common import parse as parse -from .message import direct_message, notice_me, public_message, reply_me, to_me _SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]] -P = ParamSpec("P") - - -def wrapper(func: Callable[P, TTarget]) -> Callable[P, Any]: - def _wrapper(*args: P.args, **kwargs: P.kwargs): - async def _(res: Result[bool] = Depends(func(*args, **kwargs))): - if res.value is False: - return STOP - - return _ - - return _wrapper class _Filter: - user = staticmethod(wrapper(common._user)) - guild = staticmethod(wrapper(common._guild)) - channel = staticmethod(wrapper(common._channel)) - self_ = staticmethod(wrapper(common._account)) - platform = staticmethod(wrapper(common._platform)) - direct = staticmethod(direct_message) - private = staticmethod(direct_message) - direct_message = staticmethod(direct_message) - public = staticmethod(public_message) - public_message = staticmethod(public_message) - notice_me = staticmethod(notice_me) - reply_me = staticmethod(reply_me) - to_me = staticmethod(to_me) + user = staticmethod(common.user) + guild = staticmethod(common.guild) + channel = staticmethod(common.channel) + self_ = staticmethod(common.account) + platform = staticmethod(common.platform) + direct = staticmethod(message.direct_message) + private = staticmethod(message.direct_message) + direct_message = staticmethod(message.direct_message) + public = staticmethod(message.public_message) + public_message = staticmethod(message.public_message) + notice_me = staticmethod(message.notice_me) + reply_me = staticmethod(message.reply_me) + to_me = staticmethod(message.to_me) def __call__(self, func: _SessionFilter): - _func = run_sync(func) if is_coroutinefunction(func) else func + _func = func if is_coroutinefunction(func) else run_sync(func) async def _(session: Session): if not await _func(session): # type: ignore @@ -74,10 +61,9 @@ async def before(self, session: Optional[Session] = None): await session.send(self.limit_prompt) return STOP - async def after(self) -> Optional[bool]: + async def after(self): if self.success: self.last_time = datetime.now() - return True def compose(self): yield self.before, True, 15 diff --git a/arclet/entari/filter/common.py b/arclet/entari/filter/common.py index b6b3fa7..f4e29cb 100644 --- a/arclet/entari/filter/common.py +++ b/arclet/entari/filter/common.py @@ -1,63 +1,63 @@ -from typing import Callable, Union +from collections.abc import Awaitable +from typing import Any, Callable, Optional, Union from typing_extensions import TypeAlias -from arclet.letoderea import STOP, Contexts, Depends, Propagator, propagate -from arclet.letoderea.typing import Result -from satori import Channel, ChannelType, Guild, User +from arclet.letoderea import STOP, Propagator from ..session import Session +from .message import direct_message, notice_me, public_message, reply_me, to_me -def _user(*ids: str): - async def check_user(user: User): - return Result(user.id in ids if ids else True) +def user(*ids: str): + async def check_user(session: Session): + return (None if session.user.id in ids else STOP) if ids else None return check_user -def _channel(*ids: str): - async def check_channel(channel: Channel): - return Result(channel.id in ids if ids else True) +def channel(*ids: str): + async def check_channel(session: Session): + return (None if session.channel.id in ids else STOP) if ids else None return check_channel -def _guild(*ids: str): - async def check_guild(guild: Guild): - return Result(guild.id in ids if ids else True) +def guild(*ids: str): + async def check_guild(session: Session): + return (None if session.guild.id in ids else STOP) if ids else None return check_guild -def _account(*ids: str): +def account(*ids: str): async def check_account(session: Session): - return Result(session.account.self_id in ids) + return (None if session.account.self_id in ids else STOP) if ids else None return check_account -def _platform(*ids: str): +def platform(*ids: str): async def check_platform(session: Session): - return Result(session.account.platform in ids) + return (None if session.account.platform in ids else STOP) if ids else None return check_platform _keys = { - "user": (_user, 2), - "guild": (_guild, 3), - "channel": (_channel, 4), - "self": (_account, 1), - "platform": (_platform, 0), + "user": (user, 2), + "guild": (guild, 3), + "channel": (channel, 4), + "self": (account, 1), + "platform": (platform, 0), + "direct": (lambda: direct_message, 5), + "private": (lambda: direct_message, 5), + "public": (lambda: public_message, 6), } _mess_keys = { - "direct": (lambda channel: Result(channel.type == ChannelType.DIRECT), 5), - "private": (lambda channel: Result(channel.type == ChannelType.DIRECT), 5), - "public": (lambda channel: Result(channel.type != ChannelType.DIRECT), 6), - "reply_me": (lambda is_reply_me=False: Result(is_reply_me), 7), - "notice_me": (lambda is_notice_me=False: Result(is_notice_me), 8), - "to_me": (lambda is_reply_me=False, is_notice_me=False: Result(is_reply_me or is_notice_me), 9), + "reply_me": (reply_me, 7), + "notice_me": (notice_me, 8), + "to_me": (to_me, 9), } _op_keys = { @@ -73,53 +73,73 @@ async def check_platform(session: Session): class _Filter(Propagator): - def __init__(self): - self.step: dict[int, Callable] = {} - self.ops = [] - - def get_flow(self, entry: bool = False): - if not self.step: - flow = lambda: True - - else: - steps = [slot[1] for slot in sorted(self.step.items(), key=lambda x: x[0])] - - @propagate(*steps, prepend=True) - async def flow(ctx: Contexts): - return ctx.get("$result", False) - - other = [] + def __init__( + self, + steps: list[Callable[[Session], Awaitable[bool]]], + mess: list[Callable[[bool, bool], bool]], + ops: list[tuple[str, "_Filter"]], + ): + self.steps = steps + self.mess = mess + self.ops = ops + + async def check(self, session: Optional[Session] = None, is_reply_me: bool = False, is_notice_me: bool = False): + res = True + if session and self.steps: + res = all([await step(session) for step in self.steps]) + if self.mess: + res = res and all(mess(is_reply_me, is_notice_me) for mess in self.mess) for op, f_ in self.ops: if op == "and": - other.append(lambda result, res=Depends(f_.get_flow()): Result(result and res)) + res = res and (await f_.check(session, is_reply_me, is_notice_me)) is None elif op == "or": - other.append(lambda result, res=Depends(f_.get_flow()): Result(result or res)) + res = res or (await f_.check(session, is_reply_me, is_notice_me)) is None else: - other.append(lambda result, res=Depends(f_.get_flow()): Result(result and not res)) - propagate(*other)(flow) - if entry: - propagate(lambda result: None if result else STOP)(flow) - return flow + res = res and (await f_.check(session, is_reply_me, is_notice_me)) is STOP + return None if res else STOP def compose(self): - yield self.get_flow(entry=True), True, 0 + yield self.check, True, 0 + + +def _wrapper(func: Callable[[Session], Any]): + async def _(session: Session): + return True if await func(session) is None else False + + return _ def parse(patterns: PATTERNS): - f = _Filter() + step: dict[int, Callable[[Session], Awaitable[bool]]] = {} + mess: dict[int, Callable[[bool, bool], bool]] = {} + ops: list[tuple[str, _Filter]] = [] for key, value in patterns.items(): if key in _keys: - f.step[_keys[key][1]] = _keys[key][0](*value) + step[_keys[key][1]] = _wrapper(_keys[key][0](*value) if isinstance(value, list) else _keys[key][0]()) elif key in _mess_keys: - if value is True: - f.step[_mess_keys[key][1]] = _mess_keys[key][0] + if key == "reply_me": + mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: ( + True if _mess_keys[key][0](is_reply_me) is None else False + ) + elif key == "notice_me": + mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: ( + True if _mess_keys[key][0](is_notice_me) is None else False + ) + else: + mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: ( + True if _mess_keys[key][0](is_reply_me, is_notice_me) is None else False + ) elif key in _op_keys: op = _op_keys[key] if not isinstance(value, dict): raise ValueError(f"Expect a dict for operator {key}") - f.ops.append((op, parse(value))) + ops.append((op, parse(value))) else: raise ValueError(f"Unknown key: {key}") - return f + return _Filter( + steps=[slot[1] for slot in sorted(step.items(), key=lambda x: x[0])], + mess=[slot[1] for slot in sorted(mess.items(), key=lambda x: x[0])], + ops=ops, + ) diff --git a/arclet/entari/filter/message.py b/arclet/entari/filter/message.py index d17c1b0..1630b55 100644 --- a/arclet/entari/filter/message.py +++ b/arclet/entari/filter/message.py @@ -1,27 +1,29 @@ from arclet.letoderea import STOP -from satori import Channel, ChannelType +from satori import ChannelType +from ..session import Session -async def direct_message(channel: Channel): - if channel.type != ChannelType.DIRECT: + +async def direct_message(sess: Session): + if sess.channel.type != ChannelType.DIRECT: return STOP -async def public_message(channel: Channel): - if channel.type == ChannelType.DIRECT: +async def public_message(sess: Session): + if sess.channel.type == ChannelType.DIRECT: return STOP -async def reply_me(is_reply_me: bool = False): +def reply_me(is_reply_me: bool = False): if not is_reply_me: return STOP -async def notice_me(is_notice_me: bool = False): +def notice_me(is_notice_me: bool = False): if not is_notice_me: return STOP -async def to_me(is_reply_me: bool = False, is_notice_me: bool = False): +def to_me(is_reply_me: bool = False, is_notice_me: bool = False): if not is_reply_me and not is_notice_me: return STOP