From 589fb582e245b1797adf7063181c5b453626d98d Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Tue, 1 Oct 2024 01:23:39 +0800 Subject: [PATCH] :sparkles: recursive_guard --- arclet/entari/plugin/__init__.py | 24 +++++++++++++++--------- arclet/entari/plugin/model.py | 3 +++ arclet/entari/plugin/module.py | 8 ++++---- arclet/entari/plugin/service.py | 2 ++ example_plugin.py | 2 +- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/arclet/entari/plugin/__init__.py b/arclet/entari/plugin/__init__.py index d2322e0..74e7a95 100644 --- a/arclet/entari/plugin/__init__.py +++ b/arclet/entari/plugin/__init__.py @@ -25,16 +25,19 @@ def dispatch(*events: type[Event], predicate: Callable[[Event], bool] | None = N return plugin.dispatch(*events, predicate=predicate) -_recrusive_guard = set() - - -def load_plugin(path: str) -> Plugin | None: +def load_plugin(path: str, recursive_guard: set[str] | None = None) -> Plugin | None: """ 以导入路径方式加载模块 Args: path (str): 模块路径 + recursive_guard (set[str]): 递归保护 """ + if recursive_guard is None: + recursive_guard = set() + if path in service._submoded: + logger.error(f"plugin {path!r} is already defined as submodule of {service._submoded[path]!r}") + return if path in service.plugins: return service.plugins[path] try: @@ -45,15 +48,18 @@ def load_plugin(path: str) -> Plugin | None: logger.success(f"loaded plugin {path!r}") if mod.__name__ in service._unloaded: if mod.__name__ in service._referents and service._referents[mod.__name__]: - for referent in service._referents[mod.__name__]: - if referent in _recrusive_guard: + referents = service._referents[mod.__name__].copy() + service._referents[mod.__name__].clear() + for referent in referents: + if referent in recursive_guard: continue - _recrusive_guard.add(referent) if referent in service.plugins: logger.debug(f"reloading {mod.__name__}'s referent {referent!r}") dispose(referent) - load_plugin(referent) - _recrusive_guard.clear() + if not load_plugin(referent): + service._referents[mod.__name__].add(referent) + else: + recursive_guard.add(referent) service._unloaded.discard(mod.__name__) return mod.__plugin__ except RegisterNotInPluginError as e: diff --git a/arclet/entari/plugin/model.py b/arclet/entari/plugin/model.py index 8c019a0..99b53e2 100644 --- a/arclet/entari/plugin/model.py +++ b/arclet/entari/plugin/model.py @@ -279,6 +279,9 @@ def __getattr__(self, name: str): if plug := inspect.currentframe().f_back.f_globals.get("__plugin__"): # type: ignore if plug.id != self.__plugin_id: service._referents[self.__plugin_id].add(plug.id) + elif plug := inspect.currentframe().f_back.f_back.f_globals.get("__plugin__"): # type: ignore + if plug.id != self.__plugin_id: + service._referents[self.__plugin_id].add(plug.id) return getattr(self.__get_module(), name) def __setattr__(self, name: str, value): diff --git a/arclet/entari/plugin/module.py b/arclet/entari/plugin/module.py index 62519ef..ae00e03 100644 --- a/arclet/entari/plugin/module.py +++ b/arclet/entari/plugin/module.py @@ -138,6 +138,7 @@ def exec_module(self, module: ModuleType) -> None: raise else: plugin.submodules[module.__name__] = module + service._submoded[module.__name__] = plugin.id return if self.loaded: @@ -196,10 +197,9 @@ def find_spec( if module_spec.name in service.plugins: module_spec.loader = PluginLoader(fullname, module_origin) return module_spec - for plug in service.plugins.values(): - if module_spec.name in plug.submodules: - module_spec.loader = PluginLoader(fullname, module_origin, plug.id) - return module_spec + if module_spec.name in service._submoded: + module_spec.loader = PluginLoader(fullname, module_origin, service._submoded[module_spec.name]) + return module_spec return diff --git a/arclet/entari/plugin/service.py b/arclet/entari/plugin/service.py index 59bda5c..9194e60 100644 --- a/arclet/entari/plugin/service.py +++ b/arclet/entari/plugin/service.py @@ -16,6 +16,7 @@ class PluginService(Service): _keep_values: dict[str, dict[str, "KeepingVariable"]] _referents: dict[str, set[str]] _unloaded: set[str] + _submoded: dict[str, str] def __init__(self): super().__init__() @@ -23,6 +24,7 @@ def __init__(self): self._keep_values = {} self._referents = {} self._unloaded = set() + self._submoded = {} @property def required(self) -> set[str]: diff --git a/example_plugin.py b/example_plugin.py index 06c047d..bbce7fa 100644 --- a/example_plugin.py +++ b/example_plugin.py @@ -87,7 +87,7 @@ async def append(data: str, session: Session): async def show(session: Session): await session.send_message(f"Data: {kept_data}") -TEST = 5 +TEST = 6 print([*Plugin.current().dispatchers.keys()]) print(Plugin.current().submodules)