Skip to content

Commit

Permalink
✨ lazy subscriber & tome check
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Dec 9, 2024
1 parent 47ea02f commit c8e1e8a
Show file tree
Hide file tree
Showing 14 changed files with 355 additions and 105 deletions.
8 changes: 6 additions & 2 deletions arclet/entari/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,16 @@
from satori.config import WebhookInfo as WebhookInfo
from satori.config import WebsocketsInfo as WebsocketsInfo

from . import command as command
from .config import load_config as load_config
from .core import Entari as Entari
from .event import MessageCreatedEvent as MessageCreatedEvent
from .event import MessageEvent as MessageEvent
from .filter import is_direct_message as is_direct_message
from .filter import is_public_message as is_public_message
from .filter import direct_message as direct_message
from .filter import notice_me as notice_me
from .filter import public_message as public_message
from .filter import reply_me as reply_me
from .filter import to_me as to_me
from .message import MessageChain as MessageChain
from .plugin import Plugin as Plugin
from .plugin import PluginMetadata as PluginMetadata
Expand Down
28 changes: 28 additions & 0 deletions arclet/entari/_subscriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Callable, TypeVar

from arclet.letoderea import Subscriber
from arclet.letoderea.typing import TTarget

T = TypeVar("T")


class SubscribeLoader:
sub: Subscriber

def __init__(self, func: TTarget[T], caller: Callable[[TTarget[T]], Subscriber[T]]):
self.func = func
self.caller = caller
self.loaded = False

def load(self):
if not self.loaded:
self.sub = self.caller(self.func)
self.loaded = True
return self.sub

def dispose(self):
if self.loaded:
self.sub.dispose()
self.loaded = False
del self.func
del self.caller
23 changes: 21 additions & 2 deletions arclet/entari/builtins/auto_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,15 @@ async def watch(self):
logger("DEBUG", f"Detected change in <blue>{plugin.id!r}</blue>, ignored")
continue
logger("INFO", f"Detected change in <blue>{plugin.id!r}</blue>, reloading...")
await plugin._cleanup()
pid = plugin.id
del plugin
dispose_plugin(pid)
if plugin := load_plugin(pid):
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
plugin._load()
await plugin._startup()
await plugin._ready()
del plugin
else:
logger("ERROR", f"Failed to reload <blue>{pid!r}</blue>")
Expand All @@ -61,6 +65,10 @@ async def watch(self):
logger("INFO", f"Detected change in {change[1]!r} which failed to reload, retrying...")
if plugin := load_plugin(self.fail[change[1]]):
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
plugin._load()
await plugin._startup()
await plugin._ready()
del plugin
del self.fail[change[1]]
else:
logger("ERROR", f"Failed to reload <blue>{self.fail[change[1]]!r}</blue>")
Expand Down Expand Up @@ -93,7 +101,9 @@ async def watch_config(self):
if (
plugin_name not in EntariConfig.instance.plugin
or EntariConfig.instance.plugin[plugin_name] is False
):
) and (plugin := find_plugin(pid)):
await plugin._cleanup()
del plugin
dispose_plugin(pid)
logger("INFO", f"Disposed plugin <blue>{pid!r}</blue>")
continue
Expand All @@ -112,8 +122,12 @@ async def watch_config(self):
if plugin := find_plugin(pid):
logger("INFO", f"Detected <blue>{pid!r}</blue>'s config change, reloading...")
plugin_file = str(plugin.module.__file__)
await plugin._cleanup()
dispose_plugin(plugin_name)
if plugin := load_plugin(plugin_name):
plugin._load()
await plugin._startup()
await plugin._ready()
logger("INFO", f"Reloaded <blue>{plugin.id!r}</blue>")
del plugin
else:
Expand All @@ -124,7 +138,12 @@ async def watch_config(self):
load_plugin(plugin_name)
if new := (set(EntariConfig.instance.plugin) - set(old_plugin)):
for plugin_name in new:
load_plugin(plugin_name)
if not (plugin := load_plugin(plugin_name)):
continue
plugin._load()
await plugin._startup()
await plugin._ready()
del plugin

async def launch(self, manager: Launart):
async with self.stage("blocking"):
Expand Down
85 changes: 55 additions & 30 deletions arclet/entari/command/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from arclet.alconna import Alconna, Arg, Args, CommandMeta, Namespace, command_manager, config
from arclet.alconna.tools.construct import AlconnaString, alconna_from_format
from arclet.alconna.typing import TAValue
from arclet.letoderea import BaseAuxiliary, Provider, Publisher, Subscriber, es
from arclet.letoderea import BackendPublisher, BaseAuxiliary, Provider, Subscriber, es
from arclet.letoderea.event import get_providers
from arclet.letoderea.handler import generate_contexts
from arclet.letoderea.provider import ProviderFactory
from arclet.letoderea.typing import Contexts, TTarget
Expand All @@ -14,8 +15,10 @@
from tarina.trie import CharTrie

from ..event.command import CommandExecute
from ..event.config import ConfigReload
from ..event.protocol import MessageCreatedEvent
from ..message import MessageChain
from ..plugin import RootlessPlugin
from ..session import Session
from .argv import MessageArgv # noqa: F401
from .model import CommandResult, Match, Query
Expand All @@ -28,37 +31,17 @@
class EntariCommands:
__namespace__ = "Entari"

def __init__(self, need_tome: bool = False, remove_tome: bool = True, use_config_prefix: bool = True):
def __init__(self, need_notice_me: bool = False, need_reply_me: bool = False, use_config_prefix: bool = True):
self.trie: CharTrie[Subscriber[Optional[Union[str, MessageChain]]]] = CharTrie()
self.publisher = Publisher("entari.command", MessageCreatedEvent)
self.publisher.bind(AlconnaProviderFactory())
self.judge = MessageJudges(need_tome, remove_tome, use_config_prefix)
self.publisher = BackendPublisher("entari.command")
self.publisher.bind(*get_providers(MessageCreatedEvent), AlconnaProviderFactory())
self.judge = MessageJudges(need_notice_me, need_reply_me, use_config_prefix)
config.namespaces["Entari"] = Namespace(
self.__namespace__,
to_text=lambda x: x.text if x.__class__ is Text else None,
converter=lambda x: MessageChain(x),
)

@es.on(CommandExecute)
async def _execute(event: CommandExecute):
ctx = await generate_contexts(event)
msg = str(event.command)
if matches := list(self.trie.prefixes(msg)):
results = await asyncio.gather(
*(res.value.handle(ctx.copy(), inner=True) for res in matches if res.value)
)
for result in results:
if result is not None:
return result
data = split(msg, " ")
for value in self.trie.values():
try:
command_manager.find_shortcut(get_cmd(value), data)
except ValueError:
continue
result = await value.handle(ctx.copy(), inner=True)
if result is not None:
return result
es.on(CommandExecute, self.execute)

@property
def all_helps(self) -> str:
Expand All @@ -67,7 +50,7 @@ def all_helps(self) -> str:
def get_help(self, command: str) -> str:
return command_manager.get_command(f"{self.__namespace__}::{command}").get_help()

async def execute(self, message: MessageChain, session: Session, ctx: Contexts):
async def handle(self, session: Session, message: MessageChain, ctx: Contexts):
msg = str(message).lstrip()
if not msg:
return
Expand All @@ -88,6 +71,24 @@ async def execute(self, message: MessageChain, session: Session, ctx: Contexts):
if result is not None:
await session.send(result)

async def execute(self, event: CommandExecute):
ctx = await generate_contexts(event)
msg = str(event.command)
if matches := list(self.trie.prefixes(msg)):
results = await asyncio.gather(*(res.value.handle(ctx.copy(), inner=True) for res in matches if res.value))
for result in results:
if result is not None:
return result
data = split(msg, " ")
for value in self.trie.values():
try:
command_manager.find_shortcut(get_cmd(value), data)
except ValueError:
continue
result = await value.handle(ctx.copy(), inner=True)
if result is not None:
return result

def command(
self,
command: str,
Expand Down Expand Up @@ -183,9 +184,9 @@ def _remove(_):
_commands = EntariCommands()


def config_commands(need_tome: bool = False, remove_tome: bool = True, use_config_prefix: bool = True):
_commands.judge.need_tome = need_tome
_commands.judge.remove_tome = remove_tome
def config_commands(need_notice_me: bool = False, need_reply_me: bool = False, use_config_prefix: bool = True):
_commands.judge.need_notice_me = need_notice_me
_commands.judge.need_reply_me = need_reply_me
_commands.judge.use_config_prefix = use_config_prefix


Expand All @@ -199,4 +200,28 @@ async def execute(message: Union[str, MessageChain]):
return res.value


@RootlessPlugin.apply("commands")
def _(plg: RootlessPlugin):
if "need_notice_me" in plg.config:
_commands.judge.need_notice_me = plg.config["need_notice_me"]
if "need_reply_me" in plg.config:
_commands.judge.need_reply_me = plg.config["need_reply_me"]
if "use_config_prefix" in plg.config:
_commands.judge.use_config_prefix = plg.config["use_config_prefix"]

@plg.use(ConfigReload)
def update(event: ConfigReload):
if event.scope != "plugin":
return
if event.key != "~commands":
return
if "need_notice_me" in event.value:
_commands.judge.need_notice_me = event.value["need_notice_me"]
if "need_reply_me" in event.value:
_commands.judge.need_reply_me = event.value["need_reply_me"]
if "use_config_prefix" in event.value:
_commands.judge.use_config_prefix = event.value["use_config_prefix"]
return True


__all__ = ["_commands", "config_commands", "Match", "Query", "execute", "CommandResult", "mount", "command", "on"]
16 changes: 9 additions & 7 deletions arclet/entari/command/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from arclet.alconna import Alconna, command_manager
from arclet.letoderea import BaseAuxiliary, Provider, ProviderFactory

from .._subscriber import SubscribeLoader
from ..event import MessageCreatedEvent
from ..event.command import pub as execute_handles
from ..plugin.model import Plugin, PluginDispatcher
Expand All @@ -20,14 +21,14 @@ def __init__(
self,
plugin: Plugin,
command: Alconna,
need_tome: bool = False,
remove_tome: bool = True,
need_reply_me: bool = False,
need_notice_me: bool = False,
use_config_prefix: bool = True,
):
self.supplier = AlconnaSuppiler(command)
super().__init__(plugin, MessageCreatedEvent)

self.publisher.bind(MessageJudges(need_tome, remove_tome, use_config_prefix), self.supplier)
self.publisher.bind(MessageJudges(need_reply_me, need_notice_me, use_config_prefix), self.supplier)
self.publisher.bind(AlconnaProviderFactory())

def assign(
Expand Down Expand Up @@ -59,7 +60,8 @@ def on_execute(
_auxiliaries.append(self.supplier)

def wrapper(func):
sub = execute_handles.register(func, priority=priority, auxiliaries=_auxiliaries, providers=providers)
caller = execute_handles.register(priority=priority, auxiliaries=_auxiliaries, providers=providers)
sub = SubscribeLoader(func, caller)
self._subscribers.append(sub)
return sub

Expand All @@ -71,13 +73,13 @@ def wrapper(func):

def mount(
cmd: Alconna,
need_tome: bool = False,
remove_tome: bool = True,
need_reply_me: bool = False,
need_notice_me: bool = False,
use_config_prefix: bool = True,
) -> AlconnaPluginDispatcher:
if not (plugin := Plugin.current()):
raise LookupError("no plugin context found")
disp = AlconnaPluginDispatcher(plugin, cmd, need_tome, remove_tome, use_config_prefix)
disp = AlconnaPluginDispatcher(plugin, cmd, need_reply_me, need_notice_me, use_config_prefix)
if disp.publisher.id in plugin.dispatchers:
return plugin.dispatchers[disp.id] # type: ignore
plugin.dispatchers[disp.publisher.id] = disp
Expand Down
40 changes: 8 additions & 32 deletions arclet/entari/command/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from arclet.letoderea import Contexts, Interface, JudgeAuxiliary, Param, Provider, Scope, Subscriber, SupplyAuxiliary
from arclet.letoderea.provider import ProviderFactory
from nepattern.util import CUnionType
from satori.client import Account
from satori.element import At, Text
from satori.element import Text
from tarina.generic import get_origin

from ..config import EntariConfig
Expand All @@ -16,28 +15,6 @@
from .model import CommandResult, Match, Query


def _is_tome(message: MessageChain, account: Account):
if message and isinstance(message[0], At):
at: At = message[0] # type: ignore
if at.id and at.id == account.self_id:
return True
return False


def _remove_tome(message: MessageChain, account: Account):
if _is_tome(message, account):
message = message.copy()
message.pop(0)
if message and isinstance(message[0], Text):
text = message[0].text.lstrip() # type: ignore
if not text:
message.pop(0)
else:
message[0] = Text(text)
return message
return message


def _remove_config_prefix(message: MessageChain):
if not (command_prefix := EntariConfig.instance.basic.get("prefix", [])):
return message
Expand All @@ -54,22 +31,21 @@ def _remove_config_prefix(message: MessageChain):


class MessageJudges(JudgeAuxiliary):
def __init__(self, need_tome: bool, remove_tome: bool, use_config_prefix: bool):
def __init__(self, need_reply_me: bool, need_notice_me: bool, use_config_prefix: bool):
super().__init__(priority=10)
self.need_tome = need_tome
self.remove_tome = remove_tome
self.need_reply_me = need_reply_me
self.need_notice_me = need_notice_me
self.use_config_prefix = use_config_prefix

async def __call__(self, scope: Scope, interface: Interface):
if "$message_content" in interface.ctx:
message: MessageChain = interface.ctx["$message_content"]
account = await interface.query(Account, "account", force_return=True)
if not account:
is_reply_me = interface.ctx.get("is_reply_me", False)
is_notice_me = interface.ctx.get("is_notice_me", False)
if self.need_reply_me and not is_reply_me:
return False
if self.need_tome and not _is_tome(message, account):
if self.need_notice_me and not is_notice_me:
return False
if self.remove_tome:
message = _remove_tome(message, account)
if self.use_config_prefix and not (message := _remove_config_prefix(message)):
return False
return interface.update(**{"$message_content": message})
Expand Down
Loading

0 comments on commit c8e1e8a

Please sign in to comment.