Skip to content

Commit

Permalink
🍻 revert _Filter
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Jan 27, 2025
1 parent 3021fa5 commit f5dc762
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 100 deletions.
54 changes: 20 additions & 34 deletions arclet/entari/filter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,38 @@
import asyncio
from collections.abc import Awaitable
from datetime import datetime
from typing import Any, Callable, Final, Optional, Union
from typing_extensions import ParamSpec, TypeAlias
from typing import Callable, Final, Optional, Union
from typing_extensions import TypeAlias

from arclet.letoderea import STOP, Depends, Propagator
from arclet.letoderea.typing import Result, TTarget, run_sync
from arclet.letoderea import STOP, Propagator
from arclet.letoderea.typing import run_sync
from tarina import is_coroutinefunction

from . import common
from . import common, message
from ..message import MessageChain
from ..session import Session
from .common import parse as parse
from .message import direct_message, notice_me, public_message, reply_me, to_me

_SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]]
P = ParamSpec("P")


def wrapper(func: Callable[P, TTarget]) -> Callable[P, Any]:
def _wrapper(*args: P.args, **kwargs: P.kwargs):
async def _(res: Result[bool] = Depends(func(*args, **kwargs))):
if res.value is False:
return STOP

return _

return _wrapper


class _Filter:
user = staticmethod(wrapper(common._user))
guild = staticmethod(wrapper(common._guild))
channel = staticmethod(wrapper(common._channel))
self_ = staticmethod(wrapper(common._account))
platform = staticmethod(wrapper(common._platform))
direct = staticmethod(direct_message)
private = staticmethod(direct_message)
direct_message = staticmethod(direct_message)
public = staticmethod(public_message)
public_message = staticmethod(public_message)
notice_me = staticmethod(notice_me)
reply_me = staticmethod(reply_me)
to_me = staticmethod(to_me)
user = staticmethod(common.user)
guild = staticmethod(common.guild)
channel = staticmethod(common.channel)
self_ = staticmethod(common.account)
platform = staticmethod(common.platform)
direct = staticmethod(message.direct_message)
private = staticmethod(message.direct_message)
direct_message = staticmethod(message.direct_message)
public = staticmethod(message.public_message)
public_message = staticmethod(message.public_message)
notice_me = staticmethod(message.notice_me)
reply_me = staticmethod(message.reply_me)
to_me = staticmethod(message.to_me)

def __call__(self, func: _SessionFilter):
_func = run_sync(func) if is_coroutinefunction(func) else func
_func = func if is_coroutinefunction(func) else run_sync(func)

async def _(session: Session):
if not await _func(session): # type: ignore
Expand Down Expand Up @@ -74,10 +61,9 @@ async def before(self, session: Optional[Session] = None):
await session.send(self.limit_prompt)
return STOP

async def after(self) -> Optional[bool]:
async def after(self):
if self.success:
self.last_time = datetime.now()
return True

def compose(self):
yield self.before, True, 15
Expand Down
136 changes: 78 additions & 58 deletions arclet/entari/filter/common.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,63 @@
from typing import Callable, Union
from collections.abc import Awaitable
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeAlias

from arclet.letoderea import STOP, Contexts, Depends, Propagator, propagate
from arclet.letoderea.typing import Result
from satori import Channel, ChannelType, Guild, User
from arclet.letoderea import STOP, Propagator

from ..session import Session
from .message import direct_message, notice_me, public_message, reply_me, to_me


def _user(*ids: str):
async def check_user(user: User):
return Result(user.id in ids if ids else True)
def user(*ids: str):
async def check_user(session: Session):
return (None if session.user.id in ids else STOP) if ids else None

return check_user


def _channel(*ids: str):
async def check_channel(channel: Channel):
return Result(channel.id in ids if ids else True)
def channel(*ids: str):
async def check_channel(session: Session):
return (None if session.channel.id in ids else STOP) if ids else None

return check_channel


def _guild(*ids: str):
async def check_guild(guild: Guild):
return Result(guild.id in ids if ids else True)
def guild(*ids: str):
async def check_guild(session: Session):
return (None if session.guild.id in ids else STOP) if ids else None

return check_guild


def _account(*ids: str):
def account(*ids: str):
async def check_account(session: Session):
return Result(session.account.self_id in ids)
return (None if session.account.self_id in ids else STOP) if ids else None

return check_account


def _platform(*ids: str):
def platform(*ids: str):
async def check_platform(session: Session):
return Result(session.account.platform in ids)
return (None if session.account.platform in ids else STOP) if ids else None

return check_platform


_keys = {
"user": (_user, 2),
"guild": (_guild, 3),
"channel": (_channel, 4),
"self": (_account, 1),
"platform": (_platform, 0),
"user": (user, 2),
"guild": (guild, 3),
"channel": (channel, 4),
"self": (account, 1),
"platform": (platform, 0),
"direct": (lambda: direct_message, 5),
"private": (lambda: direct_message, 5),
"public": (lambda: public_message, 6),
}

_mess_keys = {
"direct": (lambda channel: Result(channel.type == ChannelType.DIRECT), 5),
"private": (lambda channel: Result(channel.type == ChannelType.DIRECT), 5),
"public": (lambda channel: Result(channel.type != ChannelType.DIRECT), 6),
"reply_me": (lambda is_reply_me=False: Result(is_reply_me), 7),
"notice_me": (lambda is_notice_me=False: Result(is_notice_me), 8),
"to_me": (lambda is_reply_me=False, is_notice_me=False: Result(is_reply_me or is_notice_me), 9),
"reply_me": (reply_me, 7),
"notice_me": (notice_me, 8),
"to_me": (to_me, 9),
}

_op_keys = {
Expand All @@ -73,53 +73,73 @@ async def check_platform(session: Session):


class _Filter(Propagator):
def __init__(self):
self.step: dict[int, Callable] = {}
self.ops = []

def get_flow(self, entry: bool = False):
if not self.step:
flow = lambda: True

else:
steps = [slot[1] for slot in sorted(self.step.items(), key=lambda x: x[0])]

@propagate(*steps, prepend=True)
async def flow(ctx: Contexts):
return ctx.get("$result", False)

other = []
def __init__(
self,
steps: list[Callable[[Session], Awaitable[bool]]],
mess: list[Callable[[bool, bool], bool]],
ops: list[tuple[str, "_Filter"]],
):
self.steps = steps
self.mess = mess
self.ops = ops

async def check(self, session: Optional[Session] = None, is_reply_me: bool = False, is_notice_me: bool = False):
res = True
if session and self.steps:
res = all([await step(session) for step in self.steps])
if self.mess:
res = res and all(mess(is_reply_me, is_notice_me) for mess in self.mess)
for op, f_ in self.ops:
if op == "and":
other.append(lambda result, res=Depends(f_.get_flow()): Result(result and res))
res = res and (await f_.check(session, is_reply_me, is_notice_me)) is None
elif op == "or":
other.append(lambda result, res=Depends(f_.get_flow()): Result(result or res))
res = res or (await f_.check(session, is_reply_me, is_notice_me)) is None
else:
other.append(lambda result, res=Depends(f_.get_flow()): Result(result and not res))
propagate(*other)(flow)
if entry:
propagate(lambda result: None if result else STOP)(flow)
return flow
res = res and (await f_.check(session, is_reply_me, is_notice_me)) is STOP
return None if res else STOP

def compose(self):
yield self.get_flow(entry=True), True, 0
yield self.check, True, 0


def _wrapper(func: Callable[[Session], Any]):
async def _(session: Session):
return True if await func(session) is None else False

return _


def parse(patterns: PATTERNS):
f = _Filter()
step: dict[int, Callable[[Session], Awaitable[bool]]] = {}
mess: dict[int, Callable[[bool, bool], bool]] = {}
ops: list[tuple[str, _Filter]] = []

for key, value in patterns.items():
if key in _keys:
f.step[_keys[key][1]] = _keys[key][0](*value)
step[_keys[key][1]] = _wrapper(_keys[key][0](*value) if isinstance(value, list) else _keys[key][0]())
elif key in _mess_keys:
if value is True:
f.step[_mess_keys[key][1]] = _mess_keys[key][0]
if key == "reply_me":
mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: (
True if _mess_keys[key][0](is_reply_me) is None else False
)
elif key == "notice_me":
mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: (
True if _mess_keys[key][0](is_notice_me) is None else False
)
else:
mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: (
True if _mess_keys[key][0](is_reply_me, is_notice_me) is None else False
)
elif key in _op_keys:
op = _op_keys[key]
if not isinstance(value, dict):
raise ValueError(f"Expect a dict for operator {key}")
f.ops.append((op, parse(value)))
ops.append((op, parse(value)))
else:
raise ValueError(f"Unknown key: {key}")

return f
return _Filter(
steps=[slot[1] for slot in sorted(step.items(), key=lambda x: x[0])],
mess=[slot[1] for slot in sorted(mess.items(), key=lambda x: x[0])],
ops=ops,
)
18 changes: 10 additions & 8 deletions arclet/entari/filter/message.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
from arclet.letoderea import STOP
from satori import Channel, ChannelType
from satori import ChannelType

from ..session import Session

async def direct_message(channel: Channel):
if channel.type != ChannelType.DIRECT:

async def direct_message(sess: Session):
if sess.channel.type != ChannelType.DIRECT:
return STOP


async def public_message(channel: Channel):
if channel.type == ChannelType.DIRECT:
async def public_message(sess: Session):
if sess.channel.type == ChannelType.DIRECT:
return STOP


async def reply_me(is_reply_me: bool = False):
def reply_me(is_reply_me: bool = False):
if not is_reply_me:
return STOP


async def notice_me(is_notice_me: bool = False):
def notice_me(is_notice_me: bool = False):
if not is_notice_me:
return STOP


async def to_me(is_reply_me: bool = False, is_notice_me: bool = False):
def to_me(is_reply_me: bool = False, is_notice_me: bool = False):
if not is_reply_me and not is_notice_me:
return STOP

0 comments on commit f5dc762

Please sign in to comment.