Skip to content

Commit

Permalink
✨ referent update chain
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Sep 30, 2024
1 parent fddc228 commit 5e3299c
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 22 deletions.
8 changes: 5 additions & 3 deletions arclet/entari/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from tarina.generic import get_origin

from .command import _commands
from .event import MessageEvent, event_parse
from .event import MessageCreatedEvent, event_parse
from .plugin.service import service
from .session import Session

Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self, *configs: Config):
self._ref_tasks = set()

@self.on_message(priority=0)
def log(event: MessageEvent):
def log(event: MessageCreatedEvent):
logger.info(
f"[{event.channel.name or event.channel.id}] "
f"{event.member.nick if event.member else (event.user.name or event.user.id)}"
Expand All @@ -72,7 +72,9 @@ def on_message(
auxiliaries: list[BaseAuxiliary] | None = None,
providers: list[Provider | type[Provider] | ProviderFactory | type[ProviderFactory]] | None = None,
):
return self.event_system.on(MessageEvent, priority=priority, auxiliaries=auxiliaries, providers=providers)
return self.event_system.on(
MessageCreatedEvent, priority=priority, auxiliaries=auxiliaries, providers=providers
)

def ensure_manager(self, manager: Launart):
self.manager = manager
Expand Down
21 changes: 18 additions & 3 deletions arclet/entari/plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import inspect
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, Callable
Expand All @@ -26,6 +25,9 @@ def dispatch(*events: type[Event], predicate: Callable[[Event], bool] | None = N
return plugin.dispatch(*events, predicate=predicate)


_recrusive_guard = set()


def load_plugin(path: str) -> Plugin | None:
"""
以导入路径方式加载模块
Expand All @@ -41,7 +43,18 @@ def load_plugin(path: str) -> Plugin | None:
logger.error(f"cannot found plugin {path!r}")
return
logger.success(f"loaded plugin {path!r}")

if mod.__name__ in service._unloaded:
if mod.__name__ in service._referents and service._referents[mod.__name__]:
for referent in service._referents[mod.__name__]:
if referent in _recrusive_guard:
continue
_recrusive_guard.add(referent)
if referent in service.plugins:
logger.debug(f"reloading {mod.__name__}'s referent {referent!r}")
dispose(referent)
load_plugin(referent)
_recrusive_guard.clear()
service._unloaded.discard(mod.__name__)
return mod.__plugin__
except RegisterNotInPluginError as e:
logger.exception(f"{e.args[0]}", exc_info=e)
Expand Down Expand Up @@ -69,4 +82,6 @@ def dispose(plugin: str):

@init_spec(PluginMetadata)
def metadata(data: PluginMetadata):
inspect.currentframe().f_back.f_globals["__plugin_metadata__"] = data # type: ignore
if not (plugin := _current_plugin.get(None)):
raise LookupError("no plugin context found")
plugin._metadata = data # type: ignore
79 changes: 70 additions & 9 deletions arclet/entari/plugin/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Awaitable
from contextvars import ContextVar
from dataclasses import dataclass, field
import inspect
from pathlib import Path
import sys
from types import ModuleType
Expand Down Expand Up @@ -138,7 +139,10 @@ def on_disconnect(self, func: _AccountUpdate):

@staticmethod
def current() -> Plugin:
return _current_plugin.get() # type: ignore
try:
return _current_plugin.get() # type: ignore
except LookupError:
raise LookupError("no plugin context found") from None

@property
def metadata(self) -> PluginMetadata | None:
Expand All @@ -148,9 +152,12 @@ def __post_init__(self):
service.plugins[self.id] = self
if self.id not in service._keep_values:
service._keep_values[self.id] = {}
if self.id not in service._referents:
service._referents[self.id] = set()
finalize(self, self.dispose)

def dispose(self):
service._unloaded.add(self.id)
if self._is_disposed:
return
self._is_disposed = True
Expand Down Expand Up @@ -188,10 +195,12 @@ def validate(self, func):
f"`package({func.__module__!r})` before import it."
)

@property
def proxy(self):
return _ProxyModule(self.id)

def subproxy(self, sub_id: str):
return _ProxyModule(self.id, sub_id)


class KeepingVariable:
def __init__(self, obj: T, dispose: Callable[[T], None] | None = None):
Expand Down Expand Up @@ -219,18 +228,70 @@ def keeping(id_: str, obj: T, dispose: Callable[[T], None] | None = None) -> T:
return obj


class _ProxyModule:
def __init__(self, plugin_id: str) -> None:
class _ProxyModule(ModuleType):

def __get_module(self):
if self.__plugin_id not in service.plugins:
raise NameError(f"Plugin {self.__plugin_id!r} is not loaded")
if self.__sub_id:
return service.plugins[self.__plugin_id].submodules[self.__sub_id]
return service.plugins[self.__plugin_id].module

def __init__(self, plugin_id: str, sub_id: str | None = None) -> None:
self.__plugin_id = plugin_id
self.__sub_id = sub_id

super().__init__(self.__get_module().__name__)
self.__doc__ = self.__get_module().__doc__
self.__file__ = self.__get_module().__file__
self.__loader__ = self.__get_module().__loader__
self.__package__ = self.__get_module().__package__
if path := getattr(self.__get_module(), "__path__", None):
self.__path__ = path
self.__spec__ = self.__get_module().__spec__

def __repr__(self):
if self.__sub_id:
return f"<ProxyModule {self.__sub_id!r}>"
return f"<ProxyModule {self.__plugin_id!r}>"

@property
def __dict__(self) -> dict[str, Any]:
if self.__plugin_id not in service.plugins:
raise NameError(f"Plugin {self.__plugin_id!r} is not loaded")
return self.__get_module().__dict__

def __getattr__(self, name: str):
if name in (
"_ProxyModule__plugin_id",
"_ProxyModule__sub_id",
"__name__",
"__doc__",
"__file__",
"__loader__",
"__package__",
"__path__",
"__spec__",
):
return super().__getattribute__(name)
if self.__plugin_id not in service.plugins:
raise NameError(f"Plugin {self.__plugin_id!r} is not loaded")
return getattr(service.plugins[self.__plugin_id].module, name)
if plug := inspect.currentframe().f_back.f_globals.get("__plugin__"): # type: ignore
if plug.id != self.__plugin_id:
service._referents[self.__plugin_id].add(plug.id)
return getattr(self.__get_module(), name)

def __setattr__(self, name: str, value):
if name == "_ProxyModule__plugin_id":
if name in (
"_ProxyModule__plugin_id",
"_ProxyModule__sub_id",
"__name__",
"__doc__",
"__file__",
"__loader__",
"__package__",
"__path__",
"__spec__",
):
return super().__setattr__(name, value)
if self.__plugin_id not in service.plugins:
raise NameError(f"Plugin {self.__plugin_id!r} is not loaded")
setattr(service.plugins[self.__plugin_id].module, name, value)
setattr(self.__get_module(), name, value)
27 changes: 21 additions & 6 deletions arclet/entari/plugin/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,33 @@ def package(*names: str):
_SUBMODULE_WAITLIST.setdefault(plugin.module.__name__, set()).update(names)


def _check_mod(name, package=None):
module = import_plugin(name, package)
if not module:
raise ModuleNotFoundError(f"module {name!r} not found")
if hasattr(module, "__plugin__"):
return module.__plugin__.subproxy(f"{package}{name}") if package else module.__plugin__.proxy()
return module


def _unpack_import_from(__fullname: str, mod: str, aliases: list[str]):
if mod == ".":
return tuple(import_plugin(f".{alias}", __fullname) for alias in aliases)
_mod = import_plugin(f".{mod}", __fullname) if mod else import_plugin(__fullname)
if len(aliases) == 1:
return _check_mod(f".{aliases[0]}", __fullname)
return tuple(_check_mod(f".{alias}", __fullname) for alias in aliases)
_mod = _check_mod(f".{mod}", __fullname) if mod else _check_mod(__fullname)
if len(aliases) == 1:
return getattr(_mod, aliases[0])
return tuple(getattr(_mod, alias) for alias in aliases)


def _check_import(name: str, plugin_name: str):
if name in service.plugins:
return service.plugins[name].proxy
return service.plugins[name].proxy()
if name in _SUBMODULE_WAITLIST.get(plugin_name, ()):
return import_plugin(name)
mod = import_plugin(name)
if mod:
return mod.__plugin__.subproxy(name)
return __import__(name)


Expand All @@ -56,7 +71,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ f"=__unpack_import_from('{self.name}', '', {[alias.name for alias in body.names]!r})"
+ f"=__unpack_import_from('{body.module}', '', {[alias.name for alias in body.names]!r})"
).body[0]
for node in ast.walk(nodes.body[i]):
node.lineno = body.lineno # type: ignore
Expand Down Expand Up @@ -106,7 +121,7 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
def create_module(self, spec) -> Optional[ModuleType]:
if self.name in service.plugins:
self.loaded = True
return service.plugins[self.name].module
return service.plugins[self.name].proxy()
return super().create_module(spec)

def exec_module(self, module: ModuleType) -> None:
Expand Down
4 changes: 4 additions & 0 deletions arclet/entari/plugin/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@ class PluginService(Service):

plugins: dict[str, "Plugin"]
_keep_values: dict[str, dict[str, "KeepingVariable"]]
_referents: dict[str, set[str]]
_unloaded: set[str]

def __init__(self):
super().__init__()
self.plugins = {}
self._keep_values = {}
self._referents = {}
self._unloaded = set()

@property
def required(self) -> set[str]:
Expand Down
2 changes: 1 addition & 1 deletion example_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def append(data: str, session: Session):
async def show(session: Session):
await session.send_message(f"Data: {kept_data}")

TEST = 2
TEST = 5

print([*Plugin.current().dispatchers.keys()])
print(Plugin.current().submodules)
Expand Down

0 comments on commit 5e3299c

Please sign in to comment.