1
- from typing import Callable , Union
1
+ from collections .abc import Awaitable
2
+ from typing import Any , Callable , Optional , Union
2
3
from typing_extensions import TypeAlias
3
4
4
- from arclet .letoderea import STOP , Contexts , Depends , Propagator , propagate
5
- from arclet .letoderea .typing import Result
6
- from satori import Channel , ChannelType , Guild , User
5
+ from arclet .letoderea import STOP , Propagator
7
6
8
7
from ..session import Session
8
+ from .message import direct_message , notice_me , public_message , reply_me , to_me
9
9
10
10
11
- def _user (* ids : str ):
12
- async def check_user (user : User ):
13
- return Result ( user .id in ids if ids else True )
11
+ def user (* ids : str ):
12
+ async def check_user (session : Session ):
13
+ return ( None if session . user .id in ids else STOP ) if ids else None
14
14
15
15
return check_user
16
16
17
17
18
- def _channel (* ids : str ):
19
- async def check_channel (channel : Channel ):
20
- return Result ( channel .id in ids if ids else True )
18
+ def channel (* ids : str ):
19
+ async def check_channel (session : Session ):
20
+ return ( None if session . channel .id in ids else STOP ) if ids else None
21
21
22
22
return check_channel
23
23
24
24
25
- def _guild (* ids : str ):
26
- async def check_guild (guild : Guild ):
27
- return Result ( guild .id in ids if ids else True )
25
+ def guild (* ids : str ):
26
+ async def check_guild (session : Session ):
27
+ return ( None if session . guild .id in ids else STOP ) if ids else None
28
28
29
29
return check_guild
30
30
31
31
32
- def _account (* ids : str ):
32
+ def account (* ids : str ):
33
33
async def check_account (session : Session ):
34
- return Result ( session .account .self_id in ids )
34
+ return ( None if session .account .self_id in ids else STOP ) if ids else None
35
35
36
36
return check_account
37
37
38
38
39
- def _platform (* ids : str ):
39
+ def platform (* ids : str ):
40
40
async def check_platform (session : Session ):
41
- return Result ( session .account .platform in ids )
41
+ return ( None if session .account .platform in ids else STOP ) if ids else None
42
42
43
43
return check_platform
44
44
45
45
46
46
_keys = {
47
- "user" : (_user , 2 ),
48
- "guild" : (_guild , 3 ),
49
- "channel" : (_channel , 4 ),
50
- "self" : (_account , 1 ),
51
- "platform" : (_platform , 0 ),
47
+ "user" : (user , 2 ),
48
+ "guild" : (guild , 3 ),
49
+ "channel" : (channel , 4 ),
50
+ "self" : (account , 1 ),
51
+ "platform" : (platform , 0 ),
52
+ "direct" : (lambda : direct_message , 5 ),
53
+ "private" : (lambda : direct_message , 5 ),
54
+ "public" : (lambda : public_message , 6 ),
52
55
}
53
56
54
57
_mess_keys = {
55
- "direct" : (lambda channel : Result (channel .type == ChannelType .DIRECT ), 5 ),
56
- "private" : (lambda channel : Result (channel .type == ChannelType .DIRECT ), 5 ),
57
- "public" : (lambda channel : Result (channel .type != ChannelType .DIRECT ), 6 ),
58
- "reply_me" : (lambda is_reply_me = False : Result (is_reply_me ), 7 ),
59
- "notice_me" : (lambda is_notice_me = False : Result (is_notice_me ), 8 ),
60
- "to_me" : (lambda is_reply_me = False , is_notice_me = False : Result (is_reply_me or is_notice_me ), 9 ),
58
+ "reply_me" : (reply_me , 7 ),
59
+ "notice_me" : (notice_me , 8 ),
60
+ "to_me" : (to_me , 9 ),
61
61
}
62
62
63
63
_op_keys = {
@@ -73,53 +73,73 @@ async def check_platform(session: Session):
73
73
74
74
75
75
class _Filter (Propagator ):
76
- def __init__ (self ):
77
- self . step : dict [ int , Callable ] = {}
78
- self . ops = []
79
-
80
- def get_flow ( self , entry : bool = False ):
81
- if not self . step :
82
- flow = lambda : True
83
-
84
- else :
85
- steps = [ slot [ 1 ] for slot in sorted ( self . step . items (), key = lambda x : x [ 0 ])]
86
-
87
- @ propagate ( * steps , prepend = True )
88
- async def flow ( ctx : Contexts ) :
89
- return ctx . get ( "$result" , False )
90
-
91
- other = []
76
+ def __init__ (
77
+ self ,
78
+ steps : list [ Callable [[ Session ], Awaitable [ bool ]]],
79
+ mess : list [ Callable [[ bool , bool ], bool ]],
80
+ ops : list [ tuple [ str , "_Filter" ]],
81
+ ) :
82
+ self . steps = steps
83
+ self . mess = mess
84
+ self . ops = ops
85
+
86
+ async def check ( self , session : Optional [ Session ] = None , is_reply_me : bool = False , is_notice_me : bool = False ):
87
+ res = True
88
+ if session and self . steps :
89
+ res = all ([ await step ( session ) for step in self . steps ] )
90
+ if self . mess :
91
+ res = res and all ( mess ( is_reply_me , is_notice_me ) for mess in self . mess )
92
92
for op , f_ in self .ops :
93
93
if op == "and" :
94
- other . append ( lambda result , res = Depends ( f_ .get_flow ()): Result ( result and res ))
94
+ res = res and ( await f_ .check ( session , is_reply_me , is_notice_me )) is None
95
95
elif op == "or" :
96
- other . append ( lambda result , res = Depends ( f_ .get_flow ()): Result ( result or res ))
96
+ res = res or ( await f_ .check ( session , is_reply_me , is_notice_me )) is None
97
97
else :
98
- other .append (lambda result , res = Depends (f_ .get_flow ()): Result (result and not res ))
99
- propagate (* other )(flow )
100
- if entry :
101
- propagate (lambda result : None if result else STOP )(flow )
102
- return flow
98
+ res = res and (await f_ .check (session , is_reply_me , is_notice_me )) is STOP
99
+ return None if res else STOP
103
100
104
101
def compose (self ):
105
- yield self .get_flow (entry = True ), True , 0
102
+ yield self .check , True , 0
103
+
104
+
105
+ def _wrapper (func : Callable [[Session ], Any ]):
106
+ async def _ (session : Session ):
107
+ return True if await func (session ) is None else False
108
+
109
+ return _
106
110
107
111
108
112
def parse (patterns : PATTERNS ):
109
- f = _Filter ()
113
+ step : dict [int , Callable [[Session ], Awaitable [bool ]]] = {}
114
+ mess : dict [int , Callable [[bool , bool ], bool ]] = {}
115
+ ops : list [tuple [str , _Filter ]] = []
110
116
111
117
for key , value in patterns .items ():
112
118
if key in _keys :
113
- f . step [_keys [key ][1 ]] = _keys [key ][0 ](* value )
119
+ step [_keys [key ][1 ]] = _wrapper ( _keys [key ][0 ](* value ) if isinstance ( value , list ) else _keys [ key ][ 0 ]() )
114
120
elif key in _mess_keys :
115
- if value is True :
116
- f .step [_mess_keys [key ][1 ]] = _mess_keys [key ][0 ]
121
+ if key == "reply_me" :
122
+ mess [_mess_keys [key ][1 ]] = lambda is_reply_me , is_notice_me : (
123
+ True if _mess_keys [key ][0 ](is_reply_me ) is None else False
124
+ )
125
+ elif key == "notice_me" :
126
+ mess [_mess_keys [key ][1 ]] = lambda is_reply_me , is_notice_me : (
127
+ True if _mess_keys [key ][0 ](is_notice_me ) is None else False
128
+ )
129
+ else :
130
+ mess [_mess_keys [key ][1 ]] = lambda is_reply_me , is_notice_me : (
131
+ True if _mess_keys [key ][0 ](is_reply_me , is_notice_me ) is None else False
132
+ )
117
133
elif key in _op_keys :
118
134
op = _op_keys [key ]
119
135
if not isinstance (value , dict ):
120
136
raise ValueError (f"Expect a dict for operator { key } " )
121
- f . ops .append ((op , parse (value )))
137
+ ops .append ((op , parse (value )))
122
138
else :
123
139
raise ValueError (f"Unknown key: { key } " )
124
140
125
- return f
141
+ return _Filter (
142
+ steps = [slot [1 ] for slot in sorted (step .items (), key = lambda x : x [0 ])],
143
+ mess = [slot [1 ] for slot in sorted (mess .items (), key = lambda x : x [0 ])],
144
+ ops = ops ,
145
+ )
0 commit comments