diff --git a/arclet/entari/plugin/__init__.py b/arclet/entari/plugin/__init__.py index f27f58f..87aab29 100644 --- a/arclet/entari/plugin/__init__.py +++ b/arclet/entari/plugin/__init__.py @@ -61,9 +61,10 @@ def load_plugins(dir_: str | PathLike | Path): def dispose(plugin: str): if plugin not in service.plugins: - return + return False _plugin = service.plugins[plugin] _plugin.dispose() + return True @init_spec(PluginMetadata) diff --git a/arclet/entari/plugin/model.py b/arclet/entari/plugin/model.py index 10f74fa..d3fba96 100644 --- a/arclet/entari/plugin/model.py +++ b/arclet/entari/plugin/model.py @@ -188,6 +188,10 @@ def validate(self, func): f"`package({func.__module__!r})` before import it." ) + @property + def proxy(self): + return _ProxyModule(self.id) + class KeepingVariable: def __init__(self, obj: T, dispose: Callable[[T], None] | None = None): @@ -213,3 +217,20 @@ def keeping(id_: str, obj: T, dispose: Callable[[T], None] | None = None) -> T: else: obj = service._keep_values[plug.id][id_].obj # type: ignore return obj + + +class _ProxyModule: + def __init__(self, plugin_id: str) -> None: + self.__plugin_id = plugin_id + + def __getattr__(self, name: str): + if self.__plugin_id not in service.plugins: + raise NameError(f"Plugin {self.__plugin_id!r} is not loaded") + return getattr(service.plugins[self.__plugin_id].module, name) + + def __setattr__(self, name: str, value): + if name == "_ProxyModule__plugin_id": + return super().__setattr__(name, value) + if self.__plugin_id not in service.plugins: + raise NameError(f"Plugin {self.__plugin_id!r} is not loaded") + setattr(service.plugins[self.__plugin_id].module, name, value) diff --git a/arclet/entari/plugin/module.py b/arclet/entari/plugin/module.py index 2fc28d2..c5ca641 100644 --- a/arclet/entari/plugin/module.py +++ b/arclet/entari/plugin/module.py @@ -28,6 +28,14 @@ def _unpack_import_from(__fullname: str, mod: str, aliases: list[str]): return tuple(getattr(_mod, alias) for alias in aliases) +def _check_import(name: str, plugin_name: str): + if name in service.plugins: + return service.plugins[name].proxy + if name in _SUBMODULE_WAITLIST.get(plugin_name, ()): + return import_plugin(name) + return __import__(name) + + class PluginLoader(SourceFileLoader): def __init__(self, fullname: str, path: str, parent_plugin_id: Optional[str] = None) -> None: self.loaded = False @@ -50,8 +58,9 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore ",".join(aliases) + f"=__unpack_import_from('{self.name}', '', {[alias.name for alias in body.names]!r})" ).body[0] - nodes.body[i].lineno = body.lineno - nodes.body[i].end_lineno = body.end_lineno + for node in ast.walk(nodes.body[i]): + node.lineno = body.lineno # type: ignore + node.end_lineno = body.end_lineno # type: ignore if body.level == 1: if body.module is None: aliases = [alias.asname or alias.name for alias in body.names] @@ -59,8 +68,9 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore ",".join(aliases) + f"=__unpack_import_from('{self.name}', '.', {[alias.name for alias in body.names]!r})" ).body[0] - nodes.body[i].lineno = body.lineno - nodes.body[i].end_lineno = body.end_lineno + for node in ast.walk(nodes.body[i]): + node.lineno = body.lineno # type: ignore + node.end_lineno = body.end_lineno # type: ignore else: aliases = [alias.asname or alias.name for alias in body.names] nodes.body[i] = ast.parse( @@ -70,8 +80,9 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore f"{[alias.name for alias in body.names]!r})" ) ).body[0] - nodes.body[i].lineno = body.lineno - nodes.body[i].end_lineno = body.end_lineno + for node in ast.walk(nodes.body[i]): + node.lineno = body.lineno # type: ignore + node.end_lineno = body.end_lineno # type: ignore elif ( isinstance(body, ast.Expr) and isinstance(body.value, ast.Call) @@ -85,17 +96,11 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore nodes.body[i] = ast.parse( ",".join(aliases) + "=" - + ",".join( - ( - f"__import_plugin({alias.name!r})" - if (alias.name in _SUBMODULE_WAITLIST.get(self.name, ()) or alias.name in service.plugins) - else f"__import__({alias.name!r})" - ) - for alias in body.names - ) + + ",".join((f"__check_import({alias.name!r}, {self.name!r})") for alias in body.names) ).body[0] - nodes.body[i].lineno = body.lineno - nodes.body[i].end_lineno = body.end_lineno + for node in ast.walk(nodes.body[i]): + node.lineno = body.lineno # type: ignore + node.end_lineno = body.end_lineno # type: ignore return _bootstrap._call_with_frames_removed(compile, nodes, path, "exec", dont_inherit=True, optimize=_optimize) def create_module(self, spec) -> Optional[ModuleType]: @@ -105,12 +110,12 @@ def create_module(self, spec) -> Optional[ModuleType]: return super().create_module(spec) def exec_module(self, module: ModuleType) -> None: - if plugin := _current_plugin.get(service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None): + if plugin := service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None: if module.__name__ == plugin.module.__name__: # from . import xxxx return setattr(module, "__plugin__", plugin) setattr(module, "__unpack_import_from", _unpack_import_from) - setattr(module, "__import_plugin", import_plugin) + setattr(module, "__check_import", _check_import) try: super().exec_module(module) except Exception: @@ -127,7 +132,7 @@ def exec_module(self, module: ModuleType) -> None: plugin = Plugin(module.__name__, module) setattr(module, "__plugin__", plugin) setattr(module, "__unpack_import_from", _unpack_import_from) - setattr(module, "__import_plugin", import_plugin) + setattr(module, "__check_import", _check_import) # enter plugin context _plugin_token = _current_plugin.set(plugin) @@ -148,6 +153,41 @@ def exec_module(self, module: ModuleType) -> None: return +class _PluginFinder(MetaPathFinder): + @classmethod + def find_spec( + cls, + fullname: str, + path: Optional[Sequence[str]], + target: Optional[ModuleType] = None, + ): + module_spec = PathFinder.find_spec(fullname, path, target) + if not module_spec: + return + module_origin = module_spec.origin + if not module_origin: + return + if plug := _current_plugin.get(None): + if plug.module.__spec__ and plug.module.__spec__.origin == module_spec.origin: + return plug.module.__spec__ + if module_spec.parent and module_spec.parent == plug.module.__name__: + module_spec.loader = PluginLoader(fullname, module_origin, plug.id) + return module_spec + elif module_spec.name in _SUBMODULE_WAITLIST.get(plug.module.__name__, ()): + module_spec.loader = PluginLoader(fullname, module_origin, plug.id) + # _SUBMODULE_WAITLIST[plug.module.__name__].remove(module_spec.name) + return module_spec + + if module_spec.name in service.plugins: + module_spec.loader = PluginLoader(fullname, module_origin) + return module_spec + for plug in service.plugins.values(): + if module_spec.name in plug.submodules: + module_spec.loader = PluginLoader(fullname, module_origin, plug.id) + return module_spec + return + + def find_spec(name, package=None): fullname = resolve_name(name, package) if name.startswith(".") else name parent_name = fullname.rpartition(".")[0] @@ -165,6 +205,8 @@ def find_spec(name, package=None): ) from e else: parent_path = None + if spec := _PluginFinder.find_spec(fullname, parent_path): + return spec module_spec = PathFinder.find_spec(fullname, parent_path, None) if not module_spec: return @@ -185,38 +227,4 @@ def import_plugin(name, package=None): return -class _PluginFinder(MetaPathFinder): - def find_spec( - self, - fullname: str, - path: Optional[Sequence[str]], - target: Optional[ModuleType] = None, - ): - module_spec = PathFinder.find_spec(fullname, path, target) - if not module_spec: - return - module_origin = module_spec.origin - if not module_origin: - return - if plug := _current_plugin.get(None): - if plug.module.__spec__ and plug.module.__spec__.origin == module_spec.origin: - return plug.module.__spec__ - if module_spec.parent and module_spec.parent == plug.module.__name__: - module_spec.loader = PluginLoader(fullname, module_origin, plug.id) - return module_spec - elif module_spec.name in _SUBMODULE_WAITLIST[plug.module.__name__]: - module_spec.loader = PluginLoader(fullname, module_origin, plug.id) - # _SUBMODULE_WAITLIST[plug.module.__name__].remove(module_spec.name) - return module_spec - - if module_spec.name in service.plugins: - module_spec.loader = PluginLoader(fullname, module_origin) - return module_spec - for plug in service.plugins.values(): - if module_spec.name in plug.submodules: - module_spec.loader = PluginLoader(fullname, module_origin, plug.id) - return module_spec - return - - sys.meta_path.insert(0, _PluginFinder()) diff --git a/example_plugin.py b/example_plugin.py index c2a312f..603a46d 100644 --- a/example_plugin.py +++ b/example_plugin.py @@ -1,4 +1,5 @@ import re +import sys from arclet.alconna import Alconna, AllParam, Args @@ -85,3 +86,9 @@ async def append(data: str, session: Session): @command.on("show") async def show(session: Session): await session.send_message(f"Data: {kept_data}") + +TEST = 2 + +print([*Plugin.current().dispatchers.keys()]) +print(Plugin.current().submodules) +print("example_plugin not in sys.modules (expect True):", "example_plugin" not in sys.modules) diff --git a/main.py b/main.py index be3f519..270d017 100644 --- a/main.py +++ b/main.py @@ -15,13 +15,17 @@ async def echoimg(img: Image, session: Session): @command.on("load {plugin}") async def load(plugin: str, session: Session): - load_plugin(plugin) - await session.send_message(f"Loaded {plugin}") + if load_plugin(plugin): + await session.send_message(f"Loaded {plugin}") + else: + await session.send_message(f"Failed to load {plugin}") @command.on("unload {plugin}") async def unload(plugin: str, session: Session): - dispose_plugin(plugin) - await session.send_message(f"Unloaded {plugin}") + if dispose_plugin(plugin): + await session.send_message(f"Unloaded {plugin}") + else: + await session.send_message(f"Failed to unload {plugin}") app.run()