diff --git a/arclet/entari/__init__.py b/arclet/entari/__init__.py index 48c12e6..ed09f27 100644 --- a/arclet/entari/__init__.py +++ b/arclet/entari/__init__.py @@ -48,6 +48,7 @@ from .plugin import Plugin as Plugin from .plugin import PluginMetadata as PluginMetadata from .plugin import dispose as dispose_plugin # noqa: F401 +from .plugin import keeping as keeping from .plugin import load_plugin as load_plugin from .plugin import load_plugins as load_plugins from .plugin import metadata as metadata diff --git a/arclet/entari/command/__init__.py b/arclet/entari/command/__init__.py index e0c906b..0d4088e 100644 --- a/arclet/entari/command/__init__.py +++ b/arclet/entari/command/__init__.py @@ -12,8 +12,8 @@ config, output_manager, ) -from arclet.alconna.typing import TAValue from arclet.alconna.tools.construct import AlconnaString, alconna_from_format +from arclet.alconna.typing import TAValue from arclet.letoderea import BaseAuxiliary, Provider, Publisher, Scope, Subscriber from arclet.letoderea.handler import depend_handler from arclet.letoderea.provider import ProviderFactory diff --git a/arclet/entari/event.py b/arclet/entari/event.py index ced0358..e65b002 100644 --- a/arclet/entari/event.py +++ b/arclet/entari/event.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime -from typing import Callable, ClassVar, TypeVar, Generic, Any +from typing import Any, Callable, ClassVar, Generic, TypeVar from arclet.letoderea import Contexts, Param, Provider from satori import ArgvInteraction, ButtonInteraction, Channel @@ -31,7 +31,7 @@ def __get__(self, instance: Event, owner: type[Event]) -> T: def __set__(self, instance: Event, value): raise AttributeError("can't set attribute") - + def attr(key: str | None = None) -> Any: return Attr(key) @@ -67,7 +67,18 @@ async def gather(self, context: Contexts): context["$account"] = self.account context["$origin_event"] = self._origin - for attrname in {"argv", "button", "channel", "guild", "login", "member", "message", "operator", "role", "user"}: + for attrname in { + "argv", + "button", + "channel", + "guild", + "login", + "member", + "message", + "operator", + "role", + "user", + }: value = getattr(self, attrname) if value is not None: context["$message_origin" if attrname == "message" else attrname] = value diff --git a/arclet/entari/plugin/__init__.py b/arclet/entari/plugin/__init__.py index e89b143..9c37346 100644 --- a/arclet/entari/plugin/__init__.py +++ b/arclet/entari/plugin/__init__.py @@ -8,9 +8,10 @@ from loguru import logger from tarina import init_spec -from .model import Plugin, RegisterNotInPluginError +from .model import Plugin from .model import PluginMetadata as PluginMetadata -from .model import _current_plugin +from .model import RegisterNotInPluginError, _current_plugin +from .model import keeping as keeping from .module import import_plugin from .module import package as package from .service import service @@ -40,6 +41,7 @@ def load_plugin(path: str) -> Plugin | None: logger.error(f"cannot found plugin {path!r}") return logger.success(f"loaded plugin {path!r}") + return mod.__plugin__ except RegisterNotInPluginError as e: logger.exception(f"{e.args[0]}", exc_info=e) diff --git a/arclet/entari/plugin/model.py b/arclet/entari/plugin/model.py index e066b33..be3f89c 100644 --- a/arclet/entari/plugin/model.py +++ b/arclet/entari/plugin/model.py @@ -4,14 +4,13 @@ from contextvars import ContextVar from dataclasses import dataclass, field from types import ModuleType -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, TypeVar from weakref import finalize from arclet.letoderea import BaseAuxiliary, Provider, Publisher, StepOut, system_ctx from arclet.letoderea.builtin.breakpoint import R from arclet.letoderea.typing import TTarget from satori.client import Account -from tarina import init_spec from .service import service @@ -66,6 +65,7 @@ def dispose(self): if TYPE_CHECKING: register = Publisher.register else: + def register(self, *args, **kwargs): wrapper = super().register(*args, **kwargs) @@ -144,6 +144,8 @@ def metadata(self) -> PluginMetadata | None: def __post_init__(self): service.plugins[self.id] = self + if self.id not in service._keep_values: + service._keep_values[self.id] = {} finalize(self, self.dispose) def dispose(self): @@ -169,6 +171,34 @@ def validate(self, func): if "__plugin__" in func.__globals__ and func.__globals__["__plugin__"] is self: return raise RegisterNotInPluginError( - f"Handler {func.__qualname__} should define in the same module as the plugin: {self.module.__name__}. " - f"Please use the `load_plugin({func.__module__!r})` or `package({func.__module__!r})` before import it." + f"Handler {func.__qualname__} should define " + f"in the same module as the plugin: {self.module.__name__}. " + f"Please use the `load_plugin({func.__module__!r})` or " + f"`package({func.__module__!r})` before import it." ) + + +class KeepingVariable: + def __init__(self, obj: T, dispose: Callable[[T], None] | None = None): + self.obj = obj + self._dispose = dispose + + def dispose(self): + if hasattr(self.obj, "dispose"): + self.obj.dispose() # type: ignore + elif self._dispose: + self._dispose(self.obj) + del self.obj + + +T = TypeVar("T") + + +def keeping(id_: str, obj: T, dispose: Callable[[T], None] | None = None) -> T: + if not (plug := _current_plugin.get(None)): + raise LookupError("no plugin context found") + if id_ not in service._keep_values[plug.id]: + service._keep_values[plug.id][id_] = KeepingVariable(obj, dispose) + else: + obj = service._keep_values[plug.id][id_].obj # type: ignore + return obj diff --git a/arclet/entari/plugin/module.py b/arclet/entari/plugin/module.py index 1230ce9..9af5f81 100644 --- a/arclet/entari/plugin/module.py +++ b/arclet/entari/plugin/module.py @@ -9,7 +9,6 @@ from .model import Plugin, PluginMetadata, _current_plugin from .service import service - _SUBMODULE_WAITLIST = set() diff --git a/arclet/entari/plugin/service.py b/arclet/entari/plugin/service.py index 553eccf..617184c 100644 --- a/arclet/entari/plugin/service.py +++ b/arclet/entari/plugin/service.py @@ -6,17 +6,19 @@ from loguru import logger if TYPE_CHECKING: - from .model import Plugin + from .model import KeepingVariable, Plugin class PluginService(Service): id = "arclet.entari.plugin_service" plugins: dict[str, "Plugin"] + _keep_values: dict[str, dict[str, "KeepingVariable"]] def __init__(self): super().__init__() self.plugins = {} + self._keep_values = {} @property def required(self) -> set[str]: @@ -46,6 +48,11 @@ async def launch(self, manager: Launart): except Exception as e: logger.error(f"failed to dispose plugin {plug.id} caused by {e!r}") self.plugins.pop(plug_id, None) + for values in self._keep_values.values(): + for value in values.values(): + value.dispose() + values.clear() + self._keep_values.clear() service = PluginService() diff --git a/arclet/entari/session.py b/arclet/entari/session.py index 13936d8..d50fc18 100644 --- a/arclet/entari/session.py +++ b/arclet/entari/session.py @@ -120,7 +120,9 @@ async def send( ) -> list[MessageObject]: if not protocol_cls: return await self.account.protocol.send(self.context, message) - return await self.account.custom(self.account.config, protocol_cls).send(self.context._origin, message) + return await self.account.custom(self.account.config, protocol_cls).send( + self.context._origin, message + ) async def send_message( self, @@ -269,7 +271,9 @@ async def guild_member_role_unset(self, role_id: str, user_id: str | None = None if not self.context.guild: raise RuntimeError("Event cannot use to guild member role unset!") if user_id: - return await self.account.protocol.guild_member_role_unset(self.context.guild.id, user_id, role_id) + return await self.account.protocol.guild_member_role_unset( + self.context.guild.id, user_id, role_id + ) if not self.context.user: raise RuntimeError("Event cannot use to guild member role unset!") return await self.account.protocol.guild_member_role_unset( diff --git a/example_plugin.py b/example_plugin.py index 53d849f..c2a312f 100644 --- a/example_plugin.py +++ b/example_plugin.py @@ -11,6 +11,7 @@ is_public_message, bind, metadata, + keeping ) from arclet.entari.command import Match @@ -70,3 +71,17 @@ async def _(content: Match[MessageChain], session: Session): @command.on("add {a} {b}") async def add(a: int, b: int, session: Session): await session.send_message(f"{a + b =}") + + +kept_data = keeping("foo", [], lambda x: x.clear()) + + +@command.on("append {data}") +async def append(data: str, session: Session): + kept_data.append(data) + await session.send_message(f"Appended {data}") + + +@command.on("show") +async def show(session: Session): + await session.send_message(f"Data: {kept_data}")