From 44f159cce5f3dbf098353250acab4beb2b7bedd1 Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Mon, 30 Sep 2024 18:32:37 +0800 Subject: [PATCH] :sparkles: use AST to resolve related import --- arclet/entari/command/__init__.py | 32 ++-------- arclet/entari/command/provider.py | 11 +--- arclet/entari/core.py | 14 +---- arclet/entari/plugin/__init__.py | 2 +- arclet/entari/plugin/model.py | 2 + arclet/entari/plugin/module.py | 100 +++++++++++++++++++++++++++--- arclet/entari/session.py | 34 +++------- pyproject.toml | 6 +- 8 files changed, 113 insertions(+), 88 deletions(-) diff --git a/arclet/entari/command/__init__.py b/arclet/entari/command/__init__.py index dabcd58..11e9e69 100644 --- a/arclet/entari/command/__init__.py +++ b/arclet/entari/command/__init__.py @@ -1,17 +1,7 @@ import asyncio from typing import Any, Callable, Optional, TypeVar, Union, cast, overload -from arclet.alconna import ( - Alconna, - Arg, - Args, - Arparma, - CommandMeta, - Namespace, - command_manager, - config, - output_manager, -) +from arclet.alconna import Alconna, Arg, Args, Arparma, CommandMeta, Namespace, command_manager, config, output_manager from arclet.alconna.tools.construct import AlconnaString, alconna_from_format from arclet.alconna.typing import TAValue from arclet.letoderea import BaseAuxiliary, Provider, Publisher, Scope, Subscriber @@ -54,9 +44,7 @@ async def listener(event: MessageCreatedEvent): if not msg: return if matches := list(self.trie.prefixes(msg)): - await asyncio.gather( - *(depend_handler(res.value, event, inner=True) for res in matches if res.value) - ) + await asyncio.gather(*(depend_handler(res.value, event, inner=True) for res in matches if res.value)) return # shortcut data = split(msg, (" ",)) @@ -116,9 +104,7 @@ def command( need_tome: bool = False, remove_tome: bool = True, auxiliaries: Optional[list[BaseAuxiliary]] = None, - providers: Optional[ - list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] - ] = None, + providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None, ): class Command(AlconnaString): def __call__(_cmd_self, func: TCallable) -> TCallable: @@ -133,9 +119,7 @@ def on( need_tome: bool = False, remove_tome: bool = True, auxiliaries: Optional[list[BaseAuxiliary]] = None, - providers: Optional[ - list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] - ] = None, + providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None, ) -> Callable[[TCallable], TCallable]: ... @overload @@ -145,9 +129,7 @@ def on( need_tome: bool = False, remove_tome: bool = True, auxiliaries: Optional[list[BaseAuxiliary]] = None, - providers: Optional[ - list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] - ] = None, + providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None, *, args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None, meta: Optional[CommandMeta] = None, @@ -159,9 +141,7 @@ def on( need_tome: bool = False, remove_tome: bool = True, auxiliaries: Optional[list[BaseAuxiliary]] = None, - providers: Optional[ - list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]] - ] = None, + providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None, *, args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None, meta: Optional[CommandMeta] = None, diff --git a/arclet/entari/command/provider.py b/arclet/entari/command/provider.py index f258c86..75d1118 100644 --- a/arclet/entari/command/provider.py +++ b/arclet/entari/command/provider.py @@ -4,16 +4,7 @@ from arclet.alconna import Alconna, Arparma, Duplication, Empty, output_manager from arclet.alconna.builtin import generate_duplication -from arclet.letoderea import ( - Contexts, - Interface, - JudgeAuxiliary, - Param, - Provider, - Scope, - Subscriber, - SupplyAuxiliary, -) +from arclet.letoderea import Contexts, Interface, JudgeAuxiliary, Param, Provider, Scope, Subscriber, SupplyAuxiliary from arclet.letoderea.provider import ProviderFactory from nepattern.util import CUnionType from satori.client import Account diff --git a/arclet/entari/core.py b/arclet/entari/core.py index 1987618..dab6dca 100644 --- a/arclet/entari/core.py +++ b/arclet/entari/core.py @@ -3,15 +3,7 @@ import asyncio from contextlib import suppress -from arclet.letoderea import ( - BaseAuxiliary, - Contexts, - EventSystem, - Param, - Provider, - ProviderFactory, - global_providers, -) +from arclet.letoderea import BaseAuxiliary, Contexts, EventSystem, Param, Provider, ProviderFactory, global_providers from launart import Launart from loguru import logger from satori import LoginStatus @@ -80,9 +72,7 @@ def on_message( auxiliaries: list[BaseAuxiliary] | None = None, providers: list[Provider | type[Provider] | ProviderFactory | type[ProviderFactory]] | None = None, ): - return self.event_system.on( - MessageEvent, priority=priority, auxiliaries=auxiliaries, providers=providers - ) + return self.event_system.on(MessageEvent, priority=priority, auxiliaries=auxiliaries, providers=providers) def ensure_manager(self, manager: Launart): self.manager = manager diff --git a/arclet/entari/plugin/__init__.py b/arclet/entari/plugin/__init__.py index 9c37346..f27f58f 100644 --- a/arclet/entari/plugin/__init__.py +++ b/arclet/entari/plugin/__init__.py @@ -46,7 +46,7 @@ def load_plugin(path: str) -> Plugin | None: except RegisterNotInPluginError as e: logger.exception(f"{e.args[0]}", exc_info=e) except Exception as e: - logger.error(f"failed to load plugin {path!r} caused by {e!r}") + logger.exception(f"failed to load plugin {path!r} caused by {e!r}", exc_info=e) def load_plugins(dir_: str | PathLike | Path): diff --git a/arclet/entari/plugin/model.py b/arclet/entari/plugin/model.py index d03b4fa..10f74fa 100644 --- a/arclet/entari/plugin/model.py +++ b/arclet/entari/plugin/model.py @@ -157,7 +157,9 @@ def dispose(self): if self.module.__spec__ and self.module.__spec__.cached: Path(self.module.__spec__.cached).unlink(missing_ok=True) sys.modules.pop(self.module.__name__, None) + delattr(self.module, "__plugin__") for submod in self.submodules.values(): + delattr(submod, "__plugin__") sys.modules.pop(submod.__name__, None) if submod.__spec__ and submod.__spec__.cached: Path(submod.__spec__.cached).unlink(missing_ok=True) diff --git a/arclet/entari/plugin/module.py b/arclet/entari/plugin/module.py index 9dd2289..2fc28d2 100644 --- a/arclet/entari/plugin/module.py +++ b/arclet/entari/plugin/module.py @@ -1,4 +1,6 @@ +import ast from collections.abc import Sequence +from importlib import _bootstrap # type: ignore from importlib.abc import MetaPathFinder from importlib.machinery import PathFinder, SourceFileLoader from importlib.util import module_from_spec, resolve_name @@ -9,12 +11,21 @@ from .model import Plugin, PluginMetadata, _current_plugin from .service import service -_SUBMODULE_WAITLIST = set() +_SUBMODULE_WAITLIST: dict[str, set[str]] = {} def package(*names: str): """手动指定特定模块作为插件的子模块""" - _SUBMODULE_WAITLIST.update(names) + if not (plugin := _current_plugin.get(None)): + raise LookupError("no plugin context found") + _SUBMODULE_WAITLIST.setdefault(plugin.module.__name__, set()).update(names) + + +def _unpack_import_from(__fullname: str, mod: str, aliases: list[str]): + if mod == ".": + return tuple(import_plugin(f".{alias}", __fullname) for alias in aliases) + _mod = import_plugin(f".{mod}", __fullname) if mod else import_plugin(__fullname) + return tuple(getattr(_mod, alias) for alias in aliases) class PluginLoader(SourceFileLoader): @@ -23,6 +34,70 @@ def __init__(self, fullname: str, path: str, parent_plugin_id: Optional[str] = N self.parent_plugin_id = parent_plugin_id super().__init__(fullname, path) + def source_to_code(self, data, path, *, _optimize=-1): # type: ignore + """Return the code object compiled from source. + + The 'data' argument can be any object type that compile() supports. + """ + nodes = ast.parse(data, type_comments=True) + for i, body in enumerate(nodes.body): + if isinstance(body, ast.ImportFrom): + if body.level == 0 and ( + body.module in _SUBMODULE_WAITLIST.get(self.name, ()) or body.module in service.plugins + ): + aliases = [alias.asname or alias.name for alias in body.names] + nodes.body[i] = ast.parse( + ",".join(aliases) + + f"=__unpack_import_from('{self.name}', '', {[alias.name for alias in body.names]!r})" + ).body[0] + nodes.body[i].lineno = body.lineno + nodes.body[i].end_lineno = body.end_lineno + if body.level == 1: + if body.module is None: + aliases = [alias.asname or alias.name for alias in body.names] + nodes.body[i] = ast.parse( + ",".join(aliases) + + f"=__unpack_import_from('{self.name}', '.', {[alias.name for alias in body.names]!r})" + ).body[0] + nodes.body[i].lineno = body.lineno + nodes.body[i].end_lineno = body.end_lineno + else: + aliases = [alias.asname or alias.name for alias in body.names] + nodes.body[i] = ast.parse( + ",".join(aliases) + + ( + f"=__unpack_import_from('{self.name}', {body.module!r}, " + f"{[alias.name for alias in body.names]!r})" + ) + ).body[0] + nodes.body[i].lineno = body.lineno + nodes.body[i].end_lineno = body.end_lineno + elif ( + isinstance(body, ast.Expr) + and isinstance(body.value, ast.Call) + and isinstance(body.value.func, ast.Name) + and body.value.func.id == "package" + ): + if body.value.args and isinstance(body.value.args[0], ast.Constant): + _SUBMODULE_WAITLIST.setdefault(self.name, set()).update(arg.value for arg in body.value.args) # type: ignore + elif isinstance(body, ast.Import): + aliases = [alias.asname or alias.name for alias in body.names] + nodes.body[i] = ast.parse( + ",".join(aliases) + + "=" + + ",".join( + ( + f"__import_plugin({alias.name!r})" + if (alias.name in _SUBMODULE_WAITLIST.get(self.name, ()) or alias.name in service.plugins) + else f"__import__({alias.name!r})" + ) + for alias in body.names + ) + ).body[0] + nodes.body[i].lineno = body.lineno + nodes.body[i].end_lineno = body.end_lineno + return _bootstrap._call_with_frames_removed(compile, nodes, path, "exec", dont_inherit=True, optimize=_optimize) + def create_module(self, spec) -> Optional[ModuleType]: if self.name in service.plugins: self.loaded = True @@ -30,12 +105,12 @@ def create_module(self, spec) -> Optional[ModuleType]: return super().create_module(spec) def exec_module(self, module: ModuleType) -> None: - if plugin := _current_plugin.get( - service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None - ): + if plugin := _current_plugin.get(service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None): if module.__name__ == plugin.module.__name__: # from . import xxxx return setattr(module, "__plugin__", plugin) + setattr(module, "__unpack_import_from", _unpack_import_from) + setattr(module, "__import_plugin", import_plugin) try: super().exec_module(module) except Exception: @@ -51,6 +126,8 @@ def exec_module(self, module: ModuleType) -> None: # create plugin before executing plugin = Plugin(module.__name__, module) setattr(module, "__plugin__", plugin) + setattr(module, "__unpack_import_from", _unpack_import_from) + setattr(module, "__import_plugin", import_plugin) # enter plugin context _plugin_token = _current_plugin.set(plugin) @@ -75,7 +152,10 @@ def find_spec(name, package=None): fullname = resolve_name(name, package) if name.startswith(".") else name parent_name = fullname.rpartition(".")[0] if parent_name: - parent = __import__(parent_name, fromlist=["__path__"]) + if parent_name in service.plugins: + parent = service.plugins[parent_name].module + else: + parent = __import__(parent_name, fromlist=["__path__"]) try: parent_path = parent.__path__ except AttributeError as e: @@ -122,11 +202,11 @@ def find_spec( if plug.module.__spec__ and plug.module.__spec__.origin == module_spec.origin: return plug.module.__spec__ if module_spec.parent and module_spec.parent == plug.module.__name__: - module_spec.loader = PluginLoader(fullname, module_origin) + module_spec.loader = PluginLoader(fullname, module_origin, plug.id) return module_spec - elif module_spec.name in _SUBMODULE_WAITLIST: - module_spec.loader = PluginLoader(fullname, module_origin) - _SUBMODULE_WAITLIST.remove(module_spec.name) + elif module_spec.name in _SUBMODULE_WAITLIST[plug.module.__name__]: + module_spec.loader = PluginLoader(fullname, module_origin, plug.id) + # _SUBMODULE_WAITLIST[plug.module.__name__].remove(module_spec.name) return module_spec if module_spec.name in service.plugins: diff --git a/arclet/entari/session.py b/arclet/entari/session.py index d50fc18..d197bfd 100644 --- a/arclet/entari/session.py +++ b/arclet/entari/session.py @@ -46,11 +46,7 @@ async def waiter(content: MessageChain, session: Session[MessageEvent]): if self.context.channel: if self.context.channel.id == session.context.channel.id and ( not keep_sender - or ( - self.context.user - and session.context.user - and self.context.user.id == session.context.user.id - ) + or (self.context.user and session.context.user and self.context.user.id == session.context.user.id) ): return content elif self.context.user: @@ -120,9 +116,7 @@ async def send( ) -> list[MessageObject]: if not protocol_cls: return await self.account.protocol.send(self.context, message) - return await self.account.custom(self.account.config, protocol_cls).send( - self.context._origin, message - ) + return await self.account.custom(self.account.config, protocol_cls).send(self.context._origin, message) async def send_message( self, @@ -163,9 +157,7 @@ async def update_message( raise RuntimeError("Event cannot be replied to!") if not self.context.message: raise RuntimeError("Event cannot update message") - return await self.account.protocol.update_message( - self.context.channel, self.context.message.id, message - ) + return await self.account.protocol.update_message(self.context.channel, self.context.message.id, message) async def message_create( self, @@ -252,9 +244,7 @@ async def guild_member_kick(self, user_id: str | None = None, permanent: bool = return await self.account.protocol.guild_member_kick(self.context.guild.id, user_id, permanent) if not self.context.user: raise RuntimeError("Event cannot use to kick member!") - return await self.account.protocol.guild_member_kick( - self.context.guild.id, self.context.user.id, permanent - ) + return await self.account.protocol.guild_member_kick(self.context.guild.id, self.context.user.id, permanent) async def guild_member_role_set(self, role_id: str, user_id: str | None = None) -> None: if not self.context.guild: @@ -263,22 +253,16 @@ async def guild_member_role_set(self, role_id: str, user_id: str | None = None) return await self.account.protocol.guild_member_role_set(self.context.guild.id, user_id, role_id) if not self.context.user: raise RuntimeError("Event cannot use to guild member role set!") - return await self.account.protocol.guild_member_role_set( - self.context.guild.id, self.context.user.id, role_id - ) + return await self.account.protocol.guild_member_role_set(self.context.guild.id, self.context.user.id, role_id) async def guild_member_role_unset(self, role_id: str, user_id: str | None = None) -> None: if not self.context.guild: raise RuntimeError("Event cannot use to guild member role unset!") if user_id: - return await self.account.protocol.guild_member_role_unset( - self.context.guild.id, user_id, role_id - ) + return await self.account.protocol.guild_member_role_unset(self.context.guild.id, user_id, role_id) if not self.context.user: raise RuntimeError("Event cannot use to guild member role unset!") - return await self.account.protocol.guild_member_role_unset( - self.context.guild.id, self.context.user.id, role_id - ) + return await self.account.protocol.guild_member_role_unset(self.context.guild.id, self.context.user.id, role_id) async def guild_role_list(self, next_token: str | None = None) -> PageResult[Role]: if not self.context.guild: @@ -318,9 +302,7 @@ async def reaction_create( raise RuntimeError("Event cannot be replied to!") if not self.context.message: raise RuntimeError("Event cannot create reaction") - return await self.account.protocol.reaction_create( - self.context.channel.id, self.context.message.id, emoji - ) + return await self.account.protocol.reaction_create(self.context.channel.id, self.context.message.id, emoji) async def reaction_delete( self, diff --git a/pyproject.toml b/pyproject.toml index 273ed46..a869ad9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dev = [ "fix-future-annotations>=0.5.0", ] [tool.black] -line-length = 110 +line-length = 120 target-version = ["py39", "py310", "py311", "py312"] include = '\.pyi?$' extend-exclude = ''' @@ -48,13 +48,13 @@ extend-exclude = ''' [tool.isort] profile = "black" -line_length = 110 +line_length = 120 skip_gitignore = true force_sort_within_sections = true extra_standard_library = ["typing_extensions"] [tool.ruff] -line-length = 110 +line-length = 120 target-version = "py39" include = ["arclet/**.py"]