diff --git a/README.md b/README.md index b1c2b08..a882b02 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,6 @@ from arclet.entari import Session, Entari, WS app = Entari(WS(host="127.0.0.1", port=5140, path="satori")) - @app.on_message() async def repeat(session: Session): await session.send(session.content) @@ -34,10 +33,7 @@ app.run() 指令 `add {a} {b}`: ```python -from arclet.entari import Session, Entari, EntariCommands, WS - -command = EntariCommands() - +from arclet.entari import Session, Entari, WS, command @command.on("add {a} {b}") async def add(a: int, b: int, session: Session): @@ -47,3 +43,22 @@ async def add(a: int, b: int, session: Session): app = Entari(WS(port=5500, token="XXX")) app.run() ``` + +编写插件: + +```python +from arclet.entari import Session, MessageEvent, PluginMetadata + +__plugin_metadata__ = PluginMetadata( + name="Hello, World!", + author=["Arclet"], + version="0.1.0", + description="A simple plugin that replies 'Hello, World!' to every message." +) + +on_message = MessageEvent.dispatch() + +@on_message() +async def _(session: Session): + await session.send("Hello, World!") +``` diff --git a/arclet/entari/__init__.py b/arclet/entari/__init__.py index b1c91cc..bbb6544 100644 --- a/arclet/entari/__init__.py +++ b/arclet/entari/__init__.py @@ -38,15 +38,14 @@ from satori.config import WebhookInfo as WebhookInfo from satori.config import WebsocketsInfo as WebsocketsInfo -from .command import AlconnaDispatcher as AlconnaDispatcher -from .command import EntariCommands as EntariCommands from .core import Entari as Entari from .event import MessageCreatedEvent as MessageCreatedEvent from .event import MessageEvent as MessageEvent from .filter import is_direct_message as is_direct_message from .filter import is_public_message as is_public_message from .message import MessageChain as MessageChain -from .plugin import Plugin as Plugin +from .plugin import PluginMetadata as PluginMetadata +from .plugin import dispose as dispose_plugin # noqa: F401 from .plugin import load_plugin as load_plugin from .plugin import load_plugins as load_plugins from .session import Session as Session diff --git a/arclet/entari/command/__init__.py b/arclet/entari/command/__init__.py index baaa1a0..4ac9104 100644 --- a/arclet/entari/command/__init__.py +++ b/arclet/entari/command/__init__.py @@ -1,5 +1,221 @@ -from .main import EntariCommands as EntariCommands -from .model import CommandResult as CommandResult -from .model import Match as Match -from .model import Query as Query -from .plugin import AlconnaDispatcher as AlconnaDispatcher +import asyncio +from typing import Any, Callable, Optional, TypeVar, Union, cast, overload + +from arclet.alconna import ( + Alconna, + Arg, + Args, + Arparma, + CommandMeta, + Namespace, + command_manager, + config, + output_manager, +) +from arclet.alconna.args import TAValue +from arclet.alconna.tools.construct import AlconnaString, alconna_from_format +from arclet.letoderea import BaseAuxiliary, Provider, Publisher, Scope, Subscriber +from arclet.letoderea.handler import depend_handler +from arclet.letoderea.provider import ProviderFactory +from nepattern import DirectPattern +from pygtrie import CharTrie +from satori.element import At, Text +from tarina.string import split + +from ..event import MessageEvent +from ..message import MessageChain +from .argv import MessageArgv # noqa: F401 +from .model import CommandResult, Match, Query +from .plugin import mount +from .provider import AlconnaProviderFactory, AlconnaSuppiler, MessageJudger, get_cmd + +T = TypeVar("T") +TCallable = TypeVar("TCallable", bound=Callable[..., Any]) + + +class EntariCommands: + __namespace__ = "Entari" + + def __init__(self, need_tome: bool = False, remove_tome: bool = False): + self.trie: CharTrie = CharTrie() + self.publisher = Publisher("EntariCommands", MessageEvent) + self.publisher.providers.append(AlconnaProviderFactory()) + self.need_tome = need_tome + self.remove_tome = remove_tome + config.namespaces["Entari"] = Namespace( + self.__namespace__, + to_text=lambda x: x.text if x.__class__ is Text else None, + converter=lambda x: MessageChain(x), + ) + + @self.publisher.register(auxiliaries=[MessageJudger()]) + async def listener(event: MessageEvent): + msg = str(event.content.exclude(At)).lstrip() + if not msg: + return + if matches := list(self.trie.prefixes(msg)): + await asyncio.gather( + *(depend_handler(res.value, event, inner=True) for res in matches if res.value) + ) + return + # shortcut + data = split(msg, (" ",)) + for value in self.trie.values(): + try: + command_manager.find_shortcut(get_cmd(value), data) + except ValueError: + continue + await depend_handler(value, event, inner=True) + + @property + def all_helps(self) -> str: + return command_manager.all_command_help(namespace=self.__namespace__) + + def get_help(self, command: str) -> str: + return command_manager.get_command(f"{self.__namespace__}::{command}").get_help() + + async def execute(self, message: MessageChain): + async def _run(target: Subscriber, content: MessageChain): + aux = next((a for a in target.auxiliaries[Scope.prepare] if isinstance(a, AlconnaSuppiler)), None) + if not aux: + return + with output_manager.capture(aux.cmd.name) as cap: + output_manager.set_action(lambda x: x, aux.cmd.name) + try: + _res = aux.cmd.parse(content) + except Exception as e: + _res = Arparma(aux.cmd.path, message, False, error_info=e) + may_help_text: Optional[str] = cap.get("output", None) + if _res.matched: + args = {} + ctx = {"alc_result": CommandResult(aux.cmd, _res, may_help_text)} + for param in target.params: + args[param.name] = await param.solve(ctx) + return await target(**args) + elif may_help_text: + return may_help_text + + msg = str(message.exclude(At)).lstrip() + if matches := list(self.trie.prefixes(msg)): + return await asyncio.gather(*(_run(res.value, message) for res in matches if res.value)) + # shortcut + data = split(msg, (" ",)) + res = [] + for value in self.trie.values(): + try: + command_manager.find_shortcut(get_cmd(value), data) + except ValueError: + continue + res.append(await _run(value, message)) + return res + + def command( + self, + command: str, + help_text: Optional[str] = None, + need_tome: bool = False, + remove_tome: bool = False, + auxiliaries: Optional[list[BaseAuxiliary]] = None, + providers: Optional[ + list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] + ] = None, + ): + class Command(AlconnaString): + def __call__(_cmd_self, func: TCallable) -> TCallable: + return self.on(_cmd_self.build(), need_tome, remove_tome, auxiliaries, providers)(func) + + return Command(command, help_text) + + @overload + def on( + self, + command: Alconna, + need_tome: bool = False, + remove_tome: bool = False, + auxiliaries: Optional[list[BaseAuxiliary]] = None, + providers: Optional[ + list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] + ] = None, + ) -> Callable[[TCallable], TCallable]: ... + + @overload + def on( + self, + command: str, + need_tome: bool = False, + remove_tome: bool = False, + auxiliaries: Optional[list[BaseAuxiliary]] = None, + providers: Optional[ + list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] + ] = None, + *, + args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None, + meta: Optional[CommandMeta] = None, + ) -> Callable[[TCallable], TCallable]: ... + + def on( + self, + command: Union[Alconna, str], + need_tome: bool = False, + remove_tome: bool = False, + auxiliaries: Optional[list[BaseAuxiliary]] = None, + providers: Optional[ + list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] + ] = None, + *, + args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None, + meta: Optional[CommandMeta] = None, + ) -> Callable[[TCallable], TCallable]: + auxiliaries = auxiliaries or [] + providers = providers or [] + + def wrapper(func: TCallable) -> TCallable: + if isinstance(command, str): + mapping = {arg.name: arg.value for arg in Args.from_callable(func)[0]} + mapping.update(args or {}) # type: ignore + _command = alconna_from_format(command, mapping, meta, union=False) + _command.reset_namespace(self.__namespace__) + key = _command.name + "".join( + f" {arg.value.target}" for arg in _command.args if isinstance(arg.value, DirectPattern) + ) + auxiliaries.insert( + 0, AlconnaSuppiler(_command, need_tome or self.need_tome, remove_tome or self.remove_tome) + ) + target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func) + self.publisher.remove_subscriber(target) + self.trie[key] = target + else: + auxiliaries.insert( + 0, AlconnaSuppiler(command, need_tome or self.need_tome, remove_tome or self.remove_tome) + ) + target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func) + self.publisher.remove_subscriber(target) + if not isinstance(command.command, str): + raise TypeError("Command name must be a string.") + if not command.prefixes: + self.trie[command.command] = target + elif not all(isinstance(i, str) for i in command.prefixes): + raise TypeError("Command prefixes must be a list of string.") + else: + self.publisher.remove_subscriber(target) + for prefix in cast(list[str], command.prefixes): + self.trie[prefix + command.command] = target + command.reset_namespace(self.__namespace__) + return func + + return wrapper + + +_commands = EntariCommands() + + +def config_commands(need_tome: bool = False, remove_tome: bool = False): + _commands.need_tome = need_tome + _commands.remove_tome = remove_tome + + +command = _commands.command +on = _commands.on + + +__all__ = ["_commands", "config_commands", "Match", "Query", "CommandResult", "mount", "command", "on"] diff --git a/arclet/entari/command/main.py b/arclet/entari/command/main.py deleted file mode 100644 index 80991a5..0000000 --- a/arclet/entari/command/main.py +++ /dev/null @@ -1,217 +0,0 @@ -import asyncio -from typing import Any, Callable, Optional, TypeVar, Union, cast, overload - -from arclet.alconna import ( - Alconna, - Arg, - Args, - Arparma, - CommandMeta, - Namespace, - command_manager, - config, - output_manager, -) -from arclet.alconna.args import TAValue -from arclet.alconna.tools.construct import AlconnaString, alconna_from_format -from arclet.letoderea import BaseAuxiliary, Provider, Scope, Subscriber -from arclet.letoderea.handler import depend_handler -from arclet.letoderea.provider import ProviderFactory -from nepattern import DirectPattern -from pygtrie import CharTrie -from satori.element import At, Text -from tarina.context import ContextModel -from tarina.string import split - -from ..event import MessageEvent -from ..message import MessageChain -from ..plugin import Plugin, PluginDispatcher, dispatchers -from .argv import MessageArgv # noqa: F401 -from .model import CommandResult -from .provider import AlconnaProviderFactory, AlconnaSuppiler, MessageJudger, get_cmd - -T = TypeVar("T") -TCallable = TypeVar("TCallable", bound=Callable[..., Any]) - - -cx_command: ContextModel["EntariCommands"] = ContextModel("EntariCommands") - - -class EntariCommands: - __namespace__ = "Entari" - - @classmethod - def current(cls) -> "EntariCommands": - return cx_command.get() - - def __init__(self, need_tome: bool = False, remove_tome: bool = False): - self._plugin = Plugin(["RF-Tar-Railt"], "EntariCommands") - cx_command.set(self) - self.trie: CharTrie = CharTrie() - self.publisher = PluginDispatcher(self._plugin, MessageEvent) - self.publisher.providers.append(AlconnaProviderFactory()) - dispatchers["~command.EntariCommands"] = self.publisher - self.need_tome = need_tome - self.remove_tome = remove_tome - config.namespaces["Entari"] = Namespace( - self.__namespace__, - to_text=lambda x: x.text if x.__class__ is Text else None, - converter=lambda x: MessageChain(x), - ) - - @self.publisher.register(auxiliaries=[MessageJudger()]) - async def listener(event: MessageEvent): - msg = str(event.content.exclude(At)).lstrip() - if not msg: - return - if matches := list(self.trie.prefixes(msg)): - await asyncio.gather( - *(depend_handler(res.value, event, inner=True) for res in matches if res.value) - ) - return - # shortcut - data = split(msg, (" ",)) - for value in self.trie.values(): - try: - command_manager.find_shortcut(get_cmd(value), data) - except ValueError: - continue - await depend_handler(value, event, inner=True) - - @property - def all_helps(self) -> str: - return command_manager.all_command_help(namespace=self.__namespace__) - - def get_help(self, command: str) -> str: - return command_manager.get_command(f"{self.__namespace__}::{command}").get_help() - - async def execute(self, message: MessageChain): - async def _run(target: Subscriber, content: MessageChain): - aux = next((a for a in target.auxiliaries[Scope.prepare] if isinstance(a, AlconnaSuppiler)), None) - if not aux: - return - with output_manager.capture(aux.cmd.name) as cap: - output_manager.set_action(lambda x: x, aux.cmd.name) - try: - _res = aux.cmd.parse(content) - except Exception as e: - _res = Arparma(aux.cmd.path, message, False, error_info=e) - may_help_text: Optional[str] = cap.get("output", None) - if _res.matched: - args = {} - ctx = {"alc_result": CommandResult(aux.cmd, _res, may_help_text)} - for param in target.params: - args[param.name] = await param.solve(ctx) - return await target(**args) - elif may_help_text: - return may_help_text - - msg = str(message.exclude(At)).lstrip() - if matches := list(self.trie.prefixes(msg)): - return await asyncio.gather(*(_run(res.value, message) for res in matches if res.value)) - # shortcut - data = split(msg, (" ",)) - res = [] - for value in self.trie.values(): - try: - command_manager.find_shortcut(get_cmd(value), data) - except ValueError: - continue - res.append(await _run(value, message)) - return res - - def command( - self, - command: str, - help_text: Optional[str] = None, - need_tome: bool = False, - remove_tome: bool = False, - auxiliaries: Optional[list[BaseAuxiliary]] = None, - providers: Optional[ - list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] - ] = None, - ): - class Command(AlconnaString): - def __call__(_cmd_self, func: TCallable) -> TCallable: - return self.on(_cmd_self.build(), need_tome, remove_tome, auxiliaries, providers)(func) - - return Command(command, help_text) - - @overload - def on( - self, - command: Alconna, - need_tome: bool = False, - remove_tome: bool = False, - auxiliaries: Optional[list[BaseAuxiliary]] = None, - providers: Optional[ - list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] - ] = None, - ) -> Callable[[TCallable], TCallable]: ... - - @overload - def on( - self, - command: str, - need_tome: bool = False, - remove_tome: bool = False, - auxiliaries: Optional[list[BaseAuxiliary]] = None, - providers: Optional[ - list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] - ] = None, - *, - args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None, - meta: Optional[CommandMeta] = None, - ) -> Callable[[TCallable], TCallable]: ... - - def on( - self, - command: Union[Alconna, str], - need_tome: bool = False, - remove_tome: bool = False, - auxiliaries: Optional[list[BaseAuxiliary]] = None, - providers: Optional[ - list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] - ] = None, - *, - args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None, - meta: Optional[CommandMeta] = None, - ) -> Callable[[TCallable], TCallable]: - auxiliaries = auxiliaries or [] - providers = providers or [] - - def wrapper(func: TCallable) -> TCallable: - if isinstance(command, str): - mapping = {arg.name: arg.value for arg in Args.from_callable(func)[0]} - mapping.update(args or {}) # type: ignore - _command = alconna_from_format(command, mapping, meta, union=False) - _command.reset_namespace(self.__namespace__) - key = _command.name + "".join( - f" {arg.value.target}" for arg in _command.args if isinstance(arg.value, DirectPattern) - ) - auxiliaries.insert( - 0, AlconnaSuppiler(_command, need_tome or self.need_tome, remove_tome or self.remove_tome) - ) - target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func) - self.publisher.remove_subscriber(target) - self.trie[key] = target - else: - auxiliaries.insert( - 0, AlconnaSuppiler(command, need_tome or self.need_tome, remove_tome or self.remove_tome) - ) - target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func) - self.publisher.remove_subscriber(target) - if not isinstance(command.command, str): - raise TypeError("Command name must be a string.") - if not command.prefixes: - self.trie[command.command] = target - elif not all(isinstance(i, str) for i in command.prefixes): - raise TypeError("Command prefixes must be a list of string.") - else: - self.publisher.remove_subscriber(target) - for prefix in cast(list[str], command.prefixes): - self.trie[prefix + command.command] = target - command.reset_namespace(self.__namespace__) - return func - - return wrapper diff --git a/arclet/entari/command/plugin.py b/arclet/entari/command/plugin.py index ebc9cd3..229e8df 100644 --- a/arclet/entari/command/plugin.py +++ b/arclet/entari/command/plugin.py @@ -1,30 +1,38 @@ -from arclet.alconna import Alconna +from arclet.alconna import Alconna, command_manager from ..event import MessageEvent -from ..plugin import Plugin, PluginDispatcher, PluginDispatcherFactory, register_factory +from ..plugin import Plugin, PluginDispatcher +from .model import Match, Query from .provider import AlconnaProviderFactory, AlconnaSuppiler, MessageJudger -class AlconnaDispatcher(PluginDispatcherFactory): +class AlconnaPluginDispatcher(PluginDispatcher): def __init__( self, + plugin: Plugin, command: Alconna, need_tome: bool = False, remove_tome: bool = False, ): - self.command = command - self.need_tome = need_tome - self.remove_tome = remove_tome + self.supplier = AlconnaSuppiler(command, need_tome, remove_tome) + super().__init__(plugin, MessageEvent) - def dispatch(self, plugin: Plugin) -> PluginDispatcher: - disp = PluginDispatcher(plugin, MessageEvent) - disp.bind(MessageJudger(), AlconnaSuppiler(self.command, self.need_tome, self.remove_tome)) - disp.bind(AlconnaProviderFactory()) - return disp + self.bind(MessageJudger(), self.supplier) + self.bind(AlconnaProviderFactory()) + def dispose(self): + super().dispose() + command_manager.delete(self.supplier.cmd) + del self.supplier.cmd + del self.supplier -register_factory( - Alconna, - lambda cmd, *args, **kwargs: AlconnaDispatcher(cmd, *args, **kwargs), -) + Match = Match + Query = Query + + +def mount(cmd: Alconna, need_tome: bool = False, remove_tome: bool = False): + if not (plugin := Plugin.current()): + raise LookupError("no plugin context found") + disp = AlconnaPluginDispatcher(plugin, cmd, need_tome, remove_tome) + return disp diff --git a/arclet/entari/core.py b/arclet/entari/core.py index 50a6bf7..2c28bfe 100644 --- a/arclet/entari/core.py +++ b/arclet/entari/core.py @@ -20,8 +20,9 @@ from satori.model import Event from tarina.generic import get_origin +from .command import _commands from .event import MessageEvent, event_parse -from .plugin import dispatchers +from .plugin.model import _plugins from .session import Session @@ -49,6 +50,7 @@ class Entari(App): def __init__(self, *configs: Config): super().__init__(*configs) self.event_system = EventSystem() + self.event_system.register(_commands.publisher) self.register(self.handle_event) self._ref_tasks = set() @@ -77,12 +79,15 @@ async def event_parse_task(connection: Account, raw: Event): with suppress(NotImplementedError): ev = event_parse(connection, raw) self.event_system.publish(ev) - for disp in dispatchers.values(): - if not disp.validate(ev): - continue - task = loop.create_task(disp.publish(ev)) - self._ref_tasks.add(task) - task.add_done_callback(self._ref_tasks.discard) + for plugin in _plugins.values(): + for disp in plugin.dispatchers.values(): + if not disp.validate(ev): + continue + if disp._run_by_system: + continue + task = loop.create_task(disp.publish(ev)) + self._ref_tasks.add(task) + task.add_done_callback(self._ref_tasks.discard) return logger.warning(f"received unsupported event {raw.type}: {raw}") diff --git a/arclet/entari/event.py b/arclet/entari/event.py index bedec24..714b0b7 100644 --- a/arclet/entari/event.py +++ b/arclet/entari/event.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field, fields from datetime import datetime -from typing import ClassVar +from typing import Callable, ClassVar, TypeVar from arclet.letoderea import Contexts, Param, Provider from satori import ArgvInteraction, ButtonInteraction, Channel @@ -13,6 +13,9 @@ from tarina import gen_subclass from .message import MessageChain +from .plugin import dispatch + +TE = TypeVar("TE", bound="Event") @dataclass @@ -24,6 +27,10 @@ class Event: timestamp: datetime account: Account + @classmethod + def dispatch(cls: type[TE], predicate: Callable[[TE], bool] | None = None): + return dispatch(cls, predicate=predicate) + @classmethod def parse(cls, account: Account, origin: SatoriEvent): fs = fields(cls) diff --git a/arclet/entari/plugin.py b/arclet/entari/plugin.py deleted file mode 100644 index af01782..0000000 --- a/arclet/entari/plugin.py +++ /dev/null @@ -1,145 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -import importlib -import inspect -from os import PathLike -from pathlib import Path -from typing import Any, Callable, TypeVar, overload -from typing_extensions import Unpack - -from arclet.letoderea import BaseAuxiliary, Provider, Publisher, StepOut, system_ctx -from arclet.letoderea.builtin.breakpoint import R -from arclet.letoderea.typing import TTarget -from loguru import logger - -from .event import Event - -dispatchers: dict[str, PluginDispatcher] = {} - - -class PluginDispatcher(Publisher): - def __init__( - self, - plugin: Plugin, - *events: type[Event], - predicate: Callable[[Event], bool] | None = None, - ): - super().__init__(plugin.name, *events, predicate=predicate) # type: ignore - self.plugin = plugin - if es := system_ctx.get(): - es.register(self) - else: - dispatchers[self.id] = self - self._events = events - - def waiter( - self, - *events: type[Event], - providers: list[Provider | type[Provider]] | None = None, - auxiliaries: list[BaseAuxiliary] | None = None, - priority: int = 15, - block: bool = False, - ) -> Callable[[TTarget[R]], StepOut[R]]: - def wrapper(func: TTarget[R]): - nonlocal events - if not events: - events = self._events - return StepOut(list(events), func, providers, auxiliaries, priority, block) # type: ignore - - return wrapper - - on = Publisher.register - handle = Publisher.register - - -class PluginDispatcherFactory(ABC): - @abstractmethod - def dispatch(self, plugin: Plugin) -> PluginDispatcher: ... - - -MAPPING: dict[type, Callable[..., PluginDispatcherFactory]] = {} - -T = TypeVar("T") - - -def register_factory(cls: type[T], factory: Callable[[T, Unpack[tuple[Any, ...]]], PluginDispatcherFactory]): - MAPPING[cls] = factory - - -@dataclass -class Plugin: - author: list[str] = field(default_factory=list) - name: str | None = None - version: str | None = None - license: str | None = None - urls: dict[str, str] | None = None - description: str | None = None - icon: str | None = None - classifier: list[str] = field(default_factory=list) - dependencies: list[str] = field(default_factory=list) - - # standards: list[str] = field(default_factory=list) - # frameworks: list[str] = field(default_factory=list) - # config_endpoints: list[str] = field(default_factory=list) - # component_endpoints: list[str] = field(default_factory=list) - - _dispatchers: dict[str, PluginDispatcher] = field(default_factory=dict, init=False) - - def __post_init__(self): - self.name = self.name or self.__module__ - - def dispatch(self, *events: type[Event], predicate: Callable[[Event], bool] | None = None): - disp = PluginDispatcher(self, *events, predicate=predicate) - self._dispatchers[disp.id] = disp - return disp - - @overload - def mount(self, factory: PluginDispatcherFactory) -> PluginDispatcher: ... - - @overload - def mount(self, factory: object, *args, **kwargs) -> PluginDispatcher: ... - - def mount(self, factory: Any, *args, **kwargs): - if isinstance(factory, PluginDispatcherFactory): - disp = factory.dispatch(self) - elif factory_cls := MAPPING.get(factory.__class__): - disp = factory_cls(factory, *args, **kwargs).dispatch(self) - else: - raise TypeError(f"unsupported factory {factory!r}") - self._dispatchers[disp.id] = disp - return disp - - def dispose(self): - for disp in self._dispatchers.values(): - if disp.id in dispatchers: - del dispatchers[disp.id] - if es := system_ctx.get(): - es.publishers.pop(disp.id, None) - self._dispatchers.clear() - - -def load_plugin(path: str) -> list[Plugin] | None: - """ - 以导入路径方式加载模块 - - Args: - path (str): 模块路径 - """ - try: - imported_module = importlib.import_module(path, path) - logger.success(f"loaded plugin {path!r}") - return [m for _, m in inspect.getmembers(imported_module, lambda x: isinstance(x, Plugin))] - except Exception as e: - logger.error(f"failed to load plugin {path!r} caused by {e!r}") - - -def load_plugins(dir_: str | PathLike | Path): - path = dir_ if isinstance(dir_, Path) else Path(dir_) - if path.is_dir(): - for p in path.iterdir(): - if p.suffix in (".py", "") and p.stem not in {"__init__", "__pycache__"}: - load_plugin(".".join(p.parts[:-1:1]) + "." + p.stem) - elif path.is_file(): - load_plugin(".".join(path.parts[:-1:1]) + "." + path.stem) diff --git a/arclet/entari/plugin/__init__.py b/arclet/entari/plugin/__init__.py new file mode 100644 index 0000000..62007df --- /dev/null +++ b/arclet/entari/plugin/__init__.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from os import PathLike +from pathlib import Path +from typing import TYPE_CHECKING, Callable + +from loguru import logger + +from .model import Plugin, PluginDispatcher +from .model import PluginMetadata as PluginMetadata +from .model import _current_plugin, _plugins +from .module import import_plugin + +if TYPE_CHECKING: + from ..event import Event + + +def dispatch(*events: type[Event], predicate: Callable[[Event], bool] | None = None): + if not (plugin := _current_plugin.get()): + raise LookupError("no plugin context found") + disp = PluginDispatcher(plugin, *events, predicate=predicate) + return disp + + +def load_plugin(path: str) -> Plugin | None: + """ + 以导入路径方式加载模块 + + Args: + path (str): 模块路径 + """ + if path in _plugins: + return _plugins[path] + try: + mod = import_plugin(path) + if not mod: + logger.error(f"cannot found plugin {path!r}") + return + logger.success(f"loaded plugin {path!r}") + return mod.__plugin__ + except Exception as e: + logger.error(f"failed to load plugin {path!r} caused by {e!r}") + + +def load_plugins(dir_: str | PathLike | Path): + path = dir_ if isinstance(dir_, Path) else Path(dir_) + if path.is_dir(): + for p in path.iterdir(): + if p.suffix in (".py", "") and p.stem not in {"__init__", "__pycache__"}: + load_plugin(".".join(p.parts[:-1:1]) + "." + p.stem) + elif path.is_file(): + load_plugin(".".join(path.parts[:-1:1]) + "." + path.stem) + + +def dispose(plugin: str): + if plugin not in _plugins: + return + _plugin = _plugins[plugin] + _plugin.dispose() diff --git a/arclet/entari/plugin/model.py b/arclet/entari/plugin/model.py new file mode 100644 index 0000000..97a2a57 --- /dev/null +++ b/arclet/entari/plugin/model.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from contextvars import ContextVar +from dataclasses import dataclass, field +from types import ModuleType +from typing import TYPE_CHECKING, Callable +from weakref import WeakValueDictionary, finalize + +from arclet.letoderea import BaseAuxiliary, Provider, Publisher, StepOut, system_ctx +from arclet.letoderea.builtin.breakpoint import R +from arclet.letoderea.typing import TTarget + +if TYPE_CHECKING: + from ..event import Event + +_current_plugin: ContextVar[Plugin | None] = ContextVar("_current_plugin", default=None) + +_plugins: dict[str, Plugin] = {} + + +class PluginDispatcher(Publisher): + def __init__( + self, + plugin: Plugin, + *events: type[Event], + predicate: Callable[[Event], bool] | None = None, + ): + super().__init__(f"{plugin.id}@{id(plugin)}", *events, predicate=predicate) # type: ignore + self.plugin = plugin + plugin.dispatchers[self.id] = self + self._run_by_system = False + if es := system_ctx.get(): + es.register(self) + self._run_by_system = True + self._events = events + + def waiter( + self, + *events: type[Event], + providers: list[Provider | type[Provider]] | None = None, + auxiliaries: list[BaseAuxiliary] | None = None, + priority: int = 15, + block: bool = False, + ) -> Callable[[TTarget[R]], StepOut[R]]: + def wrapper(func: TTarget[R]): + nonlocal events + if not events: + events = self._events + return StepOut(list(events), func, providers, auxiliaries, priority, block) # type: ignore + + return wrapper + + def dispose(self): + if self._run_by_system: + if es := system_ctx.get(): + es.publishers.pop(self.id, None) + self._run_by_system = False + self.subscribers.clear() + + on = Publisher.register + handle = Publisher.register + __call__ = Publisher.register + + +@dataclass +class PluginMetadata: + name: str + author: list[str] = field(default_factory=list) + version: str | None = None + license: str | None = None + urls: dict[str, str] | None = None + description: str | None = None + icon: str | None = None + classifier: list[str] = field(default_factory=list) + dependencies: list[str] = field(default_factory=list) + + # standards: list[str] = field(default_factory=list) + # frameworks: list[str] = field(default_factory=list) + # config_endpoints: list[str] = field(default_factory=list) + # component_endpoints: list[str] = field(default_factory=list) + + +@dataclass +class Plugin: + id: str + module: ModuleType + dispatchers: WeakValueDictionary[str, PluginDispatcher] = field(default_factory=WeakValueDictionary) + metadata: PluginMetadata | None = None + _is_disposed: bool = False + + @staticmethod + def current() -> Plugin | None: + return _current_plugin.get() + + def __post_init__(self): + _plugins[self.id] = self + finalize(self, self.dispose) + + def dispose(self): + if self._is_disposed: + return + self._is_disposed = True + for disp in self.dispatchers.values(): + disp.dispose() + self.dispatchers.clear() + del _plugins[self.id] + del self.module diff --git a/arclet/entari/plugin/module.py b/arclet/entari/plugin/module.py new file mode 100644 index 0000000..d279e1d --- /dev/null +++ b/arclet/entari/plugin/module.py @@ -0,0 +1,97 @@ +from collections.abc import Sequence +from importlib.abc import MetaPathFinder +from importlib.machinery import PathFinder, SourceFileLoader +from importlib.util import module_from_spec, resolve_name +import sys +from types import ModuleType +from typing import Optional + +from .model import Plugin, PluginMetadata, _current_plugin, _plugins + + +class PluginFinder(MetaPathFinder): + def find_spec( + self, + fullname: str, + path: Optional[Sequence[str]], + target: Optional[ModuleType] = None, + ): + module_spec = PathFinder.find_spec(fullname, path, target) + if not module_spec: + return + module_origin = module_spec.origin + if not module_origin: + return + + module_spec.loader = PluginLoader(fullname, module_origin) + return module_spec + + +class PluginLoader(SourceFileLoader): + def __init__(self, fullname: str, path: str) -> None: + self.loaded = False + super().__init__(fullname, path) + + def create_module(self, spec) -> Optional[ModuleType]: + if self.name in _plugins: + self.loaded = True + return _plugins[self.name].module + return super().create_module(spec) + + def exec_module(self, module: ModuleType) -> None: + if self.loaded: + return + + # create plugin before executing + plugin = Plugin(module.__name__, module) + setattr(module, "__plugin__", plugin) + + # enter plugin context + _plugin_token = _current_plugin.set(plugin) + + try: + super().exec_module(module) + # except Exception: + # # _revert_plugin(plugin) + # raise + finally: + # leave plugin context + _current_plugin.reset(_plugin_token) + + # get plugin metadata + metadata: Optional[PluginMetadata] = getattr(module, "__plugin_metadata__", None) + plugin.metadata = metadata + return + + +_finder = PluginFinder() + + +def find_spec(name, package=None): + fullname = resolve_name(name, package) if name.startswith(".") else name + parent_name = fullname.rpartition(".")[0] + if parent_name: + parent = __import__(parent_name, fromlist=["__path__"]) + try: + parent_path = parent.__path__ + except AttributeError as e: + raise ModuleNotFoundError( + f"__path__ attribute not found on {parent_name!r} " f"while trying to find {fullname!r}", + name=fullname, + ) from e + else: + parent_path = None + return _finder.find_spec(fullname, parent_path) + + +def import_plugin(name, package=None): + spec = find_spec(name, package) + if spec: + mod = module_from_spec(spec) + if spec.loader: + spec.loader.exec_module(mod) + return mod + return + + +sys.meta_path.insert(0, _finder) diff --git a/example_plugin.py b/example_plugin.py index b9c7c8c..4a08f46 100644 --- a/example_plugin.py +++ b/example_plugin.py @@ -1,21 +1,33 @@ +import re + from arclet.alconna import Alconna, AllParam, Args from arclet.entari import ( Session, - EntariCommands, MessageChain, MessageCreatedEvent, - Plugin, + PluginMetadata, + command, is_direct_message, ) from arclet.entari.command import Match -plug = Plugin() +__plugin_metadata__ = PluginMetadata() + +disp_message = MessageCreatedEvent.dispatch() + + +@disp_message.on() +async def _(msg: MessageChain): + content = msg.extract_plain_text() + if re.match(r"(.{0,3})(上传|设定)(.{0,3})(上传|设定)(.{0,3})", content): + return "上传设定的帮助是..." + -disp_message = plug.dispatch(MessageCreatedEvent) from satori import select, Author + @disp_message.on(auxiliaries=[]) async def _(event: MessageCreatedEvent): print(event.content) @@ -24,10 +36,10 @@ async def _(event: MessageCreatedEvent): reply_self = author.id == event.account.self_id -on_alconna = plug.mount(Alconna("echo", Args["content?", AllParam])) +on_alconna = command.mount(Alconna("echo", Args["content?", AllParam])) -@on_alconna.on() +@on_alconna() async def _(content: Match[MessageChain], session: Session): if content.available: await session.send(content.result) @@ -36,9 +48,6 @@ async def _(content: Match[MessageChain], session: Session): await session.send(await session.prompt("请输入内容")) -commands = EntariCommands.current() - - -@commands.on("add {a} {b}") +@command.on("add {a} {b}") async def add(a: int, b: int, session: Session): await session.send_message(f"{a + b =}") diff --git a/main.py b/main.py index 3a1200a..be3f519 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,9 @@ from satori import Image -from arclet.entari import Session, Entari, EntariCommands, WebsocketsInfo, load_plugin +from arclet.entari import Session, Entari, command, WebsocketsInfo, load_plugin, dispose_plugin -commands = EntariCommands() - -@commands.on("echoimg {img}") +@command.on("echoimg {img}") async def echoimg(img: Image, session: Session): await session.send_message([img]) @@ -15,8 +13,15 @@ async def echoimg(img: Image, session: Session): app = Entari(WebsocketsInfo(host="127.0.0.1", port=5140, path="satori")) -@app.on_message() -async def repeat(session: Session): - await session.send(session.content) +@command.on("load {plugin}") +async def load(plugin: str, session: Session): + load_plugin(plugin) + await session.send_message(f"Loaded {plugin}") + + +@command.on("unload {plugin}") +async def unload(plugin: str, session: Session): + dispose_plugin(plugin) + await session.send_message(f"Unloaded {plugin}") app.run() diff --git a/pyproject.toml b/pyproject.toml index 7952218..cc8a7b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arclet-entari" -version = "0.5.1" +version = "0.6.0" description = "Simple IM Framework based on satori-python" authors = [ {name = "RF-Tar-Railt",email = "rf_tar_railt@qq.com"},