Skip to content

Commit f5dc762

Browse files
committed
🍻 revert _Filter
1 parent 3021fa5 commit f5dc762

File tree

3 files changed

+108
-100
lines changed

3 files changed

+108
-100
lines changed

arclet/entari/filter/__init__.py

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,38 @@
11
import asyncio
22
from collections.abc import Awaitable
33
from datetime import datetime
4-
from typing import Any, Callable, Final, Optional, Union
5-
from typing_extensions import ParamSpec, TypeAlias
4+
from typing import Callable, Final, Optional, Union
5+
from typing_extensions import TypeAlias
66

7-
from arclet.letoderea import STOP, Depends, Propagator
8-
from arclet.letoderea.typing import Result, TTarget, run_sync
7+
from arclet.letoderea import STOP, Propagator
8+
from arclet.letoderea.typing import run_sync
99
from tarina import is_coroutinefunction
1010

11-
from . import common
11+
from . import common, message
1212
from ..message import MessageChain
1313
from ..session import Session
1414
from .common import parse as parse
15-
from .message import direct_message, notice_me, public_message, reply_me, to_me
1615

1716
_SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]]
18-
P = ParamSpec("P")
19-
20-
21-
def wrapper(func: Callable[P, TTarget]) -> Callable[P, Any]:
22-
def _wrapper(*args: P.args, **kwargs: P.kwargs):
23-
async def _(res: Result[bool] = Depends(func(*args, **kwargs))):
24-
if res.value is False:
25-
return STOP
26-
27-
return _
28-
29-
return _wrapper
3017

3118

3219
class _Filter:
33-
user = staticmethod(wrapper(common._user))
34-
guild = staticmethod(wrapper(common._guild))
35-
channel = staticmethod(wrapper(common._channel))
36-
self_ = staticmethod(wrapper(common._account))
37-
platform = staticmethod(wrapper(common._platform))
38-
direct = staticmethod(direct_message)
39-
private = staticmethod(direct_message)
40-
direct_message = staticmethod(direct_message)
41-
public = staticmethod(public_message)
42-
public_message = staticmethod(public_message)
43-
notice_me = staticmethod(notice_me)
44-
reply_me = staticmethod(reply_me)
45-
to_me = staticmethod(to_me)
20+
user = staticmethod(common.user)
21+
guild = staticmethod(common.guild)
22+
channel = staticmethod(common.channel)
23+
self_ = staticmethod(common.account)
24+
platform = staticmethod(common.platform)
25+
direct = staticmethod(message.direct_message)
26+
private = staticmethod(message.direct_message)
27+
direct_message = staticmethod(message.direct_message)
28+
public = staticmethod(message.public_message)
29+
public_message = staticmethod(message.public_message)
30+
notice_me = staticmethod(message.notice_me)
31+
reply_me = staticmethod(message.reply_me)
32+
to_me = staticmethod(message.to_me)
4633

4734
def __call__(self, func: _SessionFilter):
48-
_func = run_sync(func) if is_coroutinefunction(func) else func
35+
_func = func if is_coroutinefunction(func) else run_sync(func)
4936

5037
async def _(session: Session):
5138
if not await _func(session): # type: ignore
@@ -74,10 +61,9 @@ async def before(self, session: Optional[Session] = None):
7461
await session.send(self.limit_prompt)
7562
return STOP
7663

77-
async def after(self) -> Optional[bool]:
64+
async def after(self):
7865
if self.success:
7966
self.last_time = datetime.now()
80-
return True
8167

8268
def compose(self):
8369
yield self.before, True, 15

arclet/entari/filter/common.py

Lines changed: 78 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,63 @@
1-
from typing import Callable, Union
1+
from collections.abc import Awaitable
2+
from typing import Any, Callable, Optional, Union
23
from typing_extensions import TypeAlias
34

4-
from arclet.letoderea import STOP, Contexts, Depends, Propagator, propagate
5-
from arclet.letoderea.typing import Result
6-
from satori import Channel, ChannelType, Guild, User
5+
from arclet.letoderea import STOP, Propagator
76

87
from ..session import Session
8+
from .message import direct_message, notice_me, public_message, reply_me, to_me
99

1010

11-
def _user(*ids: str):
12-
async def check_user(user: User):
13-
return Result(user.id in ids if ids else True)
11+
def user(*ids: str):
12+
async def check_user(session: Session):
13+
return (None if session.user.id in ids else STOP) if ids else None
1414

1515
return check_user
1616

1717

18-
def _channel(*ids: str):
19-
async def check_channel(channel: Channel):
20-
return Result(channel.id in ids if ids else True)
18+
def channel(*ids: str):
19+
async def check_channel(session: Session):
20+
return (None if session.channel.id in ids else STOP) if ids else None
2121

2222
return check_channel
2323

2424

25-
def _guild(*ids: str):
26-
async def check_guild(guild: Guild):
27-
return Result(guild.id in ids if ids else True)
25+
def guild(*ids: str):
26+
async def check_guild(session: Session):
27+
return (None if session.guild.id in ids else STOP) if ids else None
2828

2929
return check_guild
3030

3131

32-
def _account(*ids: str):
32+
def account(*ids: str):
3333
async def check_account(session: Session):
34-
return Result(session.account.self_id in ids)
34+
return (None if session.account.self_id in ids else STOP) if ids else None
3535

3636
return check_account
3737

3838

39-
def _platform(*ids: str):
39+
def platform(*ids: str):
4040
async def check_platform(session: Session):
41-
return Result(session.account.platform in ids)
41+
return (None if session.account.platform in ids else STOP) if ids else None
4242

4343
return check_platform
4444

4545

4646
_keys = {
47-
"user": (_user, 2),
48-
"guild": (_guild, 3),
49-
"channel": (_channel, 4),
50-
"self": (_account, 1),
51-
"platform": (_platform, 0),
47+
"user": (user, 2),
48+
"guild": (guild, 3),
49+
"channel": (channel, 4),
50+
"self": (account, 1),
51+
"platform": (platform, 0),
52+
"direct": (lambda: direct_message, 5),
53+
"private": (lambda: direct_message, 5),
54+
"public": (lambda: public_message, 6),
5255
}
5356

5457
_mess_keys = {
55-
"direct": (lambda channel: Result(channel.type == ChannelType.DIRECT), 5),
56-
"private": (lambda channel: Result(channel.type == ChannelType.DIRECT), 5),
57-
"public": (lambda channel: Result(channel.type != ChannelType.DIRECT), 6),
58-
"reply_me": (lambda is_reply_me=False: Result(is_reply_me), 7),
59-
"notice_me": (lambda is_notice_me=False: Result(is_notice_me), 8),
60-
"to_me": (lambda is_reply_me=False, is_notice_me=False: Result(is_reply_me or is_notice_me), 9),
58+
"reply_me": (reply_me, 7),
59+
"notice_me": (notice_me, 8),
60+
"to_me": (to_me, 9),
6161
}
6262

6363
_op_keys = {
@@ -73,53 +73,73 @@ async def check_platform(session: Session):
7373

7474

7575
class _Filter(Propagator):
76-
def __init__(self):
77-
self.step: dict[int, Callable] = {}
78-
self.ops = []
79-
80-
def get_flow(self, entry: bool = False):
81-
if not self.step:
82-
flow = lambda: True
83-
84-
else:
85-
steps = [slot[1] for slot in sorted(self.step.items(), key=lambda x: x[0])]
86-
87-
@propagate(*steps, prepend=True)
88-
async def flow(ctx: Contexts):
89-
return ctx.get("$result", False)
90-
91-
other = []
76+
def __init__(
77+
self,
78+
steps: list[Callable[[Session], Awaitable[bool]]],
79+
mess: list[Callable[[bool, bool], bool]],
80+
ops: list[tuple[str, "_Filter"]],
81+
):
82+
self.steps = steps
83+
self.mess = mess
84+
self.ops = ops
85+
86+
async def check(self, session: Optional[Session] = None, is_reply_me: bool = False, is_notice_me: bool = False):
87+
res = True
88+
if session and self.steps:
89+
res = all([await step(session) for step in self.steps])
90+
if self.mess:
91+
res = res and all(mess(is_reply_me, is_notice_me) for mess in self.mess)
9292
for op, f_ in self.ops:
9393
if op == "and":
94-
other.append(lambda result, res=Depends(f_.get_flow()): Result(result and res))
94+
res = res and (await f_.check(session, is_reply_me, is_notice_me)) is None
9595
elif op == "or":
96-
other.append(lambda result, res=Depends(f_.get_flow()): Result(result or res))
96+
res = res or (await f_.check(session, is_reply_me, is_notice_me)) is None
9797
else:
98-
other.append(lambda result, res=Depends(f_.get_flow()): Result(result and not res))
99-
propagate(*other)(flow)
100-
if entry:
101-
propagate(lambda result: None if result else STOP)(flow)
102-
return flow
98+
res = res and (await f_.check(session, is_reply_me, is_notice_me)) is STOP
99+
return None if res else STOP
103100

104101
def compose(self):
105-
yield self.get_flow(entry=True), True, 0
102+
yield self.check, True, 0
103+
104+
105+
def _wrapper(func: Callable[[Session], Any]):
106+
async def _(session: Session):
107+
return True if await func(session) is None else False
108+
109+
return _
106110

107111

108112
def parse(patterns: PATTERNS):
109-
f = _Filter()
113+
step: dict[int, Callable[[Session], Awaitable[bool]]] = {}
114+
mess: dict[int, Callable[[bool, bool], bool]] = {}
115+
ops: list[tuple[str, _Filter]] = []
110116

111117
for key, value in patterns.items():
112118
if key in _keys:
113-
f.step[_keys[key][1]] = _keys[key][0](*value)
119+
step[_keys[key][1]] = _wrapper(_keys[key][0](*value) if isinstance(value, list) else _keys[key][0]())
114120
elif key in _mess_keys:
115-
if value is True:
116-
f.step[_mess_keys[key][1]] = _mess_keys[key][0]
121+
if key == "reply_me":
122+
mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: (
123+
True if _mess_keys[key][0](is_reply_me) is None else False
124+
)
125+
elif key == "notice_me":
126+
mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: (
127+
True if _mess_keys[key][0](is_notice_me) is None else False
128+
)
129+
else:
130+
mess[_mess_keys[key][1]] = lambda is_reply_me, is_notice_me: (
131+
True if _mess_keys[key][0](is_reply_me, is_notice_me) is None else False
132+
)
117133
elif key in _op_keys:
118134
op = _op_keys[key]
119135
if not isinstance(value, dict):
120136
raise ValueError(f"Expect a dict for operator {key}")
121-
f.ops.append((op, parse(value)))
137+
ops.append((op, parse(value)))
122138
else:
123139
raise ValueError(f"Unknown key: {key}")
124140

125-
return f
141+
return _Filter(
142+
steps=[slot[1] for slot in sorted(step.items(), key=lambda x: x[0])],
143+
mess=[slot[1] for slot in sorted(mess.items(), key=lambda x: x[0])],
144+
ops=ops,
145+
)

arclet/entari/filter/message.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
from arclet.letoderea import STOP
2-
from satori import Channel, ChannelType
2+
from satori import ChannelType
33

4+
from ..session import Session
45

5-
async def direct_message(channel: Channel):
6-
if channel.type != ChannelType.DIRECT:
6+
7+
async def direct_message(sess: Session):
8+
if sess.channel.type != ChannelType.DIRECT:
79
return STOP
810

911

10-
async def public_message(channel: Channel):
11-
if channel.type == ChannelType.DIRECT:
12+
async def public_message(sess: Session):
13+
if sess.channel.type == ChannelType.DIRECT:
1214
return STOP
1315

1416

15-
async def reply_me(is_reply_me: bool = False):
17+
def reply_me(is_reply_me: bool = False):
1618
if not is_reply_me:
1719
return STOP
1820

1921

20-
async def notice_me(is_notice_me: bool = False):
22+
def notice_me(is_notice_me: bool = False):
2123
if not is_notice_me:
2224
return STOP
2325

2426

25-
async def to_me(is_reply_me: bool = False, is_notice_me: bool = False):
27+
def to_me(is_reply_me: bool = False, is_notice_me: bool = False):
2628
if not is_reply_me and not is_notice_me:
2729
return STOP

0 commit comments

Comments
 (0)