diff --git a/arclet/entari/filter/__init__.py b/arclet/entari/filter/__init__.py index 7b92e04..db89438 100644 --- a/arclet/entari/filter/__init__.py +++ b/arclet/entari/filter/__init__.py @@ -1,122 +1,71 @@ -from collections.abc import Awaitable -from typing import Callable, Optional, Union -from typing_extensions import Self, TypeAlias +import asyncio +from datetime import datetime +from typing import Optional, Union from arclet.letoderea import Interface, JudgeAuxiliary, Scope -from arclet.letoderea import bind as _bind -from arclet.letoderea.typing import run_sync -from tarina import is_async +from ..message import MessageChain from ..session import Session -from .common import ChannelFilter, GuildFilter, PlatformFilter, SelfFilter, UserFilter -from .message import DirectMessageJudger, NoticeMeJudger, PublicMessageJudger, ReplyMeJudger, ToMeJudger -from .op import ExcludeFilter, IntersectFilter, UnionFilter +from .common import Filter as Filter -_SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]] - -class Filter(JudgeAuxiliary): - def __init__(self, callback: Optional[_SessionFilter] = None, priority: int = 10): - super().__init__(priority=priority) - self.steps = [] - if callback: - if is_async(callback): - self.callback = callback - else: - self.callback = run_sync(callback) - else: - self.callback = None - - async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]: - for step in sorted(self.steps, key=lambda x: x.priority): - if not await step(scope, interface): - return False - if self.callback: - session = await interface.query(Session, "session", force_return=True) - if not session: - return False - if not await self.callback(session): # type: ignore - return False - return True - - @property - def scopes(self) -> set[Scope]: - return {Scope.prepare} - - @property - def id(self) -> str: - return "entari.filter" - - def user(self, *user_ids: str) -> Self: - self.steps.append(UserFilter(*user_ids, priority=6)) - return self - - def guild(self, *guild_ids: str) -> Self: - self.steps.append(GuildFilter(*guild_ids, priority=4)) - return self - - def channel(self, *channel_ids: str) -> Self: - self.steps.append(ChannelFilter(*channel_ids, priority=5)) - return self - - def self(self, *self_ids: str) -> Self: - self.steps.append(SelfFilter(*self_ids, priority=3)) - return self - - def platform(self, *platforms: str) -> Self: - self.steps.append(PlatformFilter(*platforms, priority=2)) - return self +class Interval(JudgeAuxiliary): + def __init__(self, interval: float, limit_prompt: Optional[Union[str, MessageChain]] = None): + self.success = True + self.last_time = None + self.interval = interval + self.limit_prompt = limit_prompt + super().__init__(priority=20) @property - def direct(self) -> Self: - self.steps.append(DirectMessageJudger(priority=8)) - return self - - private = direct + def id(self): + return "entari.filter/interval" @property - def public(self) -> Self: - self.steps.append(PublicMessageJudger(priority=8)) - return self + def scopes(self): + return {Scope.prepare, Scope.cleanup} - @property - def reply_me(self) -> Self: - self.steps.append(ReplyMeJudger(priority=9)) - return self + async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]: + if scope == Scope.prepare: + if not self.last_time: + return True + # if self.condition: + # if not await self.condition(scope, interface): + # self.success = False + # return False + self.success = (datetime.now() - self.last_time).total_seconds() > self.interval + if not self.success: + session = await interface.query(Session, "session", force_return=True) + if session and self.limit_prompt: + await session.send(self.limit_prompt) + return self.success + if self.success: + self.last_time = datetime.now() + return True + + +class Semaphore(JudgeAuxiliary): + def __init__(self, count: int, limit_prompt: Optional[Union[str, MessageChain]] = None): + self.count = count + self.limit_prompt = limit_prompt + self.semaphore = asyncio.Semaphore(count) + super().__init__(priority=20) @property - def notice_me(self) -> Self: - self.steps.append(NoticeMeJudger(priority=10)) - return self + def id(self): + return "entari.filter/access" @property - def to_me(self) -> Self: - self.steps.append(ToMeJudger(priority=11)) - return self - - def bind(self, func): - return _bind(self)(func) - - def and_(self, other: Union["Filter", _SessionFilter]) -> "Filter": - new = Filter(priority=self.priority) - _other = other if isinstance(other, Filter) else Filter(callback=other) - new.steps.append(IntersectFilter(self, _other, priority=1)) - return new + def scopes(self): + return {Scope.prepare, Scope.cleanup} - intersect = and_ - - def or_(self, other: Union["Filter", _SessionFilter]) -> "Filter": - new = Filter(priority=self.priority) - _other = other if isinstance(other, Filter) else Filter(callback=other) - new.steps.append(UnionFilter(self, _other, priority=1)) - return new - - union = or_ - - def not_(self, other: Union["Filter", _SessionFilter]) -> "Filter": - new = Filter(priority=self.priority) - _other = other if isinstance(other, Filter) else Filter(callback=other) - new.steps.append(ExcludeFilter(self, _other, priority=1)) - return new - - exclude = not_ + async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]: + if scope == Scope.prepare: + if not await self.semaphore.acquire(): + session = await interface.query(Session, "session", force_return=True) + if session and self.limit_prompt: + await session.send(self.limit_prompt) + return False + return True + self.semaphore.release() + return True diff --git a/arclet/entari/filter/common.py b/arclet/entari/filter/common.py index bf8bc34..959c206 100644 --- a/arclet/entari/filter/common.py +++ b/arclet/entari/filter/common.py @@ -1,8 +1,17 @@ -from typing import Optional +from collections.abc import Awaitable +from typing import Callable, Optional, Union +from typing_extensions import Self, TypeAlias from arclet.letoderea import Interface, JudgeAuxiliary, Scope +from arclet.letoderea import bind as _bind +from arclet.letoderea.typing import run_sync from satori import Channel, Guild, User from satori.client import Account +from tarina import is_async + +from ..session import Session +from .message import DirectMessageJudger, NoticeMeJudger, PublicMessageJudger, ReplyMeJudger, ToMeJudger +from .op import ExcludeFilter, IntersectFilter, UnionFilter class UserFilter(JudgeAuxiliary): @@ -98,3 +107,113 @@ def scopes(self) -> set[Scope]: @property def id(self) -> str: return "entari.filter/platform" + + +_SessionFilter: TypeAlias = Union[Callable[[Session], bool], Callable[[Session], Awaitable[bool]]] + + +class Filter(JudgeAuxiliary): + def __init__(self, callback: Optional[_SessionFilter] = None, priority: int = 10): + super().__init__(priority=priority) + self.steps = [] + if callback: + if is_async(callback): + self.callback = callback + else: + self.callback = run_sync(callback) + else: + self.callback = None + + async def __call__(self, scope: Scope, interface: Interface): + for step in sorted(self.steps, key=lambda x: x.priority): + if not await step(scope, interface): + return False + if self.callback: + session = await interface.query(Session, "session", force_return=True) + if not session: + return False + if not await self.callback(session): # type: ignore + return False + return True + + @property + def scopes(self) -> set[Scope]: + return {Scope.prepare} + + @property + def id(self) -> str: + return "entari.filter" + + def user(self, *user_ids: str) -> Self: + self.steps.append(UserFilter(*user_ids, priority=6)) + return self + + def guild(self, *guild_ids: str) -> Self: + self.steps.append(GuildFilter(*guild_ids, priority=4)) + return self + + def channel(self, *channel_ids: str) -> Self: + self.steps.append(ChannelFilter(*channel_ids, priority=5)) + return self + + def self(self, *self_ids: str) -> Self: + self.steps.append(SelfFilter(*self_ids, priority=3)) + return self + + def platform(self, *platforms: str) -> Self: + self.steps.append(PlatformFilter(*platforms, priority=2)) + return self + + @property + def direct(self) -> Self: + self.steps.append(DirectMessageJudger(priority=8)) + return self + + private = direct + + @property + def public(self) -> Self: + self.steps.append(PublicMessageJudger(priority=8)) + return self + + @property + def reply_me(self) -> Self: + self.steps.append(ReplyMeJudger(priority=9)) + return self + + @property + def notice_me(self) -> Self: + self.steps.append(NoticeMeJudger(priority=10)) + return self + + @property + def to_me(self) -> Self: + self.steps.append(ToMeJudger(priority=11)) + return self + + def bind(self, func): + return _bind(self)(func) + + def and_(self, other: Union["Filter", _SessionFilter]) -> "Filter": + new = Filter(priority=self.priority) + _other = other if isinstance(other, Filter) else Filter(callback=other) + new.steps.append(IntersectFilter(self, _other, priority=1)) + return new + + intersect = and_ + + def or_(self, other: Union["Filter", _SessionFilter]) -> "Filter": + new = Filter(priority=self.priority) + _other = other if isinstance(other, Filter) else Filter(callback=other) + new.steps.append(UnionFilter(self, _other, priority=1)) + return new + + union = or_ + + def not_(self, other: Union["Filter", _SessionFilter]) -> "Filter": + new = Filter(priority=self.priority) + _other = other if isinstance(other, Filter) else Filter(callback=other) + new.steps.append(ExcludeFilter(self, _other, priority=1)) + return new + + exclude = not_ diff --git a/arclet/entari/scheduler.py b/arclet/entari/scheduler.py index 3fe0297..79c4945 100644 --- a/arclet/entari/scheduler.py +++ b/arclet/entari/scheduler.py @@ -91,12 +91,12 @@ async def launch(self, manager: Launart): id = "entari.scheduler" -scheduler = Scheduler() +scheduler = service = Scheduler() @RootlessPlugin.apply("scheduler") def _(plg: RootlessPlugin): - plg.service(scheduler) + plg.service(service) def every_second(): @@ -186,16 +186,13 @@ def crontab(cron_str: str): cron_str (str): cron 表达式 """ - def _(): - now = datetime.now() - it = croniter(cron_str, now) - return it.get_next(datetime) - now + it = croniter(cron_str, datetime.now()) - return _ + return lambda iter=it: iter.get_next(datetime) - datetime.now() def cron(pattern: str): - return scheduler.schedule(crontab(pattern)) + return service.schedule(crontab(pattern)) def every( @@ -207,4 +204,4 @@ def every( "minute": every_minutes, "hour": every_hours, } - return scheduler.schedule(_TIMER_MAPPING[mode](value)) + return service.schedule(_TIMER_MAPPING[mode](value)) diff --git a/example.yml b/example.yml index 8de0f1b..de589ac 100644 --- a/example.yml +++ b/example.yml @@ -11,4 +11,5 @@ plugins: watch_dirs: ["."] ::echo: true example_plugin: true - ~record_message: true \ No newline at end of file + ~record_message: true + ~scheduler: true \ No newline at end of file diff --git a/example_plugin.py b/example_plugin.py index 5195520..8af10e2 100644 --- a/example_plugin.py +++ b/example_plugin.py @@ -9,7 +9,10 @@ command, metadata, keeping, + scheduler, + Entari, ) +from arclet.entari.filter import Interval metadata(__file__) @@ -50,7 +53,7 @@ async def _(session: Session): return await session.send("Filter: public message, to me, but content is not 'aaa'") -@command.on("add {a} {b}") +@command.on("add {a} {b}", [Interval(2, limit_prompt="太快了")]) def add(a: int, b: int): return f"{a + b =}" @@ -80,3 +83,11 @@ async def show(session: Session): @plug.use("::before_send") async def send_hook(message: MessageChain): return message + "喵" + + +@scheduler.cron("* * * * *") +async def broadcast(app: Entari): + for account in app.accounts.values(): + channels = [channel for guild in (await account.guild_list()).data for channel in (await account.channel_list(guild.id)).data] + for channel in channels: + await account.send_message(channel, "Hello, World!")