Skip to content

Commit b3aacf4

Browse files
committed
✨ support submodules plugin
1 parent caa8ebb commit b3aacf4

File tree

4 files changed

+89
-18
lines changed

4 files changed

+89
-18
lines changed

arclet/entari/plugin/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
from __future__ import annotations
22

3+
import inspect
34
from os import PathLike
45
from pathlib import Path
56
from typing import TYPE_CHECKING, Callable
67

78
from loguru import logger
9+
from tarina import init_spec
810

9-
from .model import Plugin
11+
from .model import Plugin, RegisterNotInPluginError
1012
from .model import PluginMetadata as PluginMetadata
1113
from .model import _current_plugin
1214
from .module import import_plugin
15+
from .module import package as package
1316
from .service import service
1417

1518
if TYPE_CHECKING:
1619
from ..event import Event
1720

1821

1922
def dispatch(*events: type[Event], predicate: Callable[[Event], bool] | None = None):
20-
if not (plugin := _current_plugin.get()):
23+
if not (plugin := _current_plugin.get(None)):
2124
raise LookupError("no plugin context found")
2225
return plugin.dispatch(*events, predicate=predicate)
2326

@@ -38,6 +41,8 @@ def load_plugin(path: str) -> Plugin | None:
3841
return
3942
logger.success(f"loaded plugin {path!r}")
4043
return mod.__plugin__
44+
except RegisterNotInPluginError as e:
45+
logger.exception(f"{e.args[0]}", exc_info=e)
4146
except Exception as e:
4247
logger.error(f"failed to load plugin {path!r} caused by {e!r}")
4348

@@ -57,3 +62,8 @@ def dispose(plugin: str):
5762
return
5863
_plugin = service.plugins[plugin]
5964
_plugin.dispose()
65+
66+
67+
@init_spec(PluginMetadata)
68+
def metadata(data: PluginMetadata):
69+
inspect.currentframe().f_back.f_globals["__plugin_metadata__"] = data # type: ignore

arclet/entari/plugin/model.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
if TYPE_CHECKING:
1919
from ..event import Event
2020

21-
_current_plugin: ContextVar[Plugin | None] = ContextVar("_current_plugin", default=None)
21+
_current_plugin: ContextVar[Plugin] = ContextVar("_current_plugin")
22+
23+
24+
class RegisterNotInPluginError(Exception):
25+
pass
2226

2327

2428
class PluginDispatcher(Publisher):
@@ -59,8 +63,20 @@ def dispose(self):
5963
self._run_by_system = False
6064
self.subscribers.clear()
6165

62-
on = Publisher.register
63-
handle = Publisher.register
66+
if TYPE_CHECKING:
67+
register = Publisher.register
68+
else:
69+
def register(self, *args, **kwargs):
70+
wrapper = super().register(*args, **kwargs)
71+
72+
def decorator(func):
73+
self.plugin.validate(func)
74+
return wrapper(func)
75+
76+
return decorator
77+
78+
on = register
79+
handle = register
6480

6581
def __call__(self, func):
6682
return self.register()(func)
@@ -93,7 +109,8 @@ class Plugin:
93109
id: str
94110
module: ModuleType
95111
dispatchers: dict[str, PluginDispatcher] = field(default_factory=dict)
96-
metadata: PluginMetadata | None = None
112+
submodules: dict[str, ModuleType] = field(default_factory=dict)
113+
_metadata: PluginMetadata | None = None
97114
_is_disposed: bool = False
98115

99116
_preparing: list[_Lifespan] = field(init=False, default_factory=list)
@@ -121,19 +138,19 @@ def on_disconnect(self, func: _AccountUpdate):
121138
def current() -> Plugin:
122139
return _current_plugin.get() # type: ignore
123140

141+
@property
142+
def metadata(self) -> PluginMetadata | None:
143+
return self._metadata
144+
124145
def __post_init__(self):
125146
service.plugins[self.id] = self
126147
finalize(self, self.dispose)
127148

128-
@init_spec(PluginMetadata, True)
129-
def meta(self, metadata: PluginMetadata):
130-
self.metadata = metadata
131-
return self
132-
133149
def dispose(self):
134150
if self._is_disposed:
135151
return
136152
self._is_disposed = True
153+
self.submodules.clear()
137154
for disp in self.dispatchers.values():
138155
disp.dispose()
139156
self.dispatchers.clear()
@@ -146,3 +163,12 @@ def dispatch(self, *events: type[Event], predicate: Callable[[Event], bool] | No
146163
return self.dispatchers[disp.id]
147164
self.dispatchers[disp.id] = disp
148165
return disp
166+
167+
def validate(self, func):
168+
if func.__module__ != self.module.__name__:
169+
if "__plugin__" in func.__globals__ and func.__globals__["__plugin__"] is self:
170+
return
171+
raise RegisterNotInPluginError(
172+
f"Handler {func.__qualname__} should define in the same module as the plugin: {self.module.__name__}. "
173+
f"Please use the `load_plugin({func.__module__!r})` or `package({func.__module__!r})` before import it."
174+
)

arclet/entari/plugin/module.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@
1010
from .service import service
1111

1212

13+
_SUBMODULE_WAITLIST = set()
14+
15+
16+
def package(*names: str):
17+
"""手动指定特定模块作为插件的子模块"""
18+
_SUBMODULE_WAITLIST.update(names)
19+
20+
1321
class PluginLoader(SourceFileLoader):
1422
def __init__(self, fullname: str, path: str) -> None:
1523
self.loaded = False
@@ -22,6 +30,19 @@ def create_module(self, spec) -> Optional[ModuleType]:
2230
return super().create_module(spec)
2331

2432
def exec_module(self, module: ModuleType) -> None:
33+
if plugin := _current_plugin.get(None):
34+
if module.__name__ == plugin.module.__name__: # from . import xxxx
35+
return
36+
setattr(module, "__plugin__", plugin)
37+
try:
38+
super().exec_module(module)
39+
except Exception:
40+
delattr(module, "__plugin__")
41+
raise
42+
else:
43+
plugin.submodules[module.__name__] = module
44+
return
45+
2546
if self.loaded:
2647
return
2748

@@ -44,7 +65,7 @@ def exec_module(self, module: ModuleType) -> None:
4465
# get plugin metadata
4566
metadata: Optional[PluginMetadata] = getattr(module, "__plugin_metadata__", None)
4667
if metadata and not plugin.metadata:
47-
plugin.metadata = metadata
68+
plugin._metadata = metadata
4869
return
4970

5071

@@ -95,6 +116,17 @@ def find_spec(
95116
module_origin = module_spec.origin
96117
if not module_origin:
97118
return
119+
if plug := _current_plugin.get(None):
120+
if plug.module.__spec__ and plug.module.__spec__.origin == module_spec.origin:
121+
return plug.module.__spec__
122+
if module_spec.parent and module_spec.parent == plug.module.__name__:
123+
module_spec.loader = PluginLoader(fullname, module_origin)
124+
return module_spec
125+
elif module_spec.name in _SUBMODULE_WAITLIST:
126+
module_spec.loader = PluginLoader(fullname, module_origin)
127+
_SUBMODULE_WAITLIST.remove(module_spec.name)
128+
return module_spec
129+
98130
if module_spec.name in service.plugins:
99131
module_spec.loader = PluginLoader(fullname, module_origin)
100132
return module_spec

example_plugin.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,34 @@
1010
command,
1111
is_public_message,
1212
bind,
13+
metadata,
1314
)
1415
from arclet.entari.command import Match
1516

16-
plug = Plugin.current().meta(__file__)
17+
metadata(__file__)
18+
19+
plug = Plugin.current()
1720

1821

1922
@plug.on_prepare
2023
async def prepare():
21-
print("Preparing")
24+
print("example: Preparing")
2225

2326

2427
@plug.on_cleanup
2528
async def cleanup():
26-
print("Cleanup")
29+
print("example: Cleanup")
2730

2831

2932
disp_message = MessageCreatedEvent.dispatch()
3033

3134

3235
@disp_message
3336
@bind(is_public_message)
34-
async def _(msg: MessageChain):
37+
async def _(msg: MessageChain, session: Session):
3538
content = msg.extract_plain_text()
3639
if re.match(r"(.{0,3})(上传|设定)(.{0,3})(上传|设定)(.{0,3})", content):
37-
return "上传设定的帮助是..."
40+
return await session.send("上传设定的帮助是...")
3841

3942

4043
disp_message1 = plug.dispatch(MessageCreatedEvent)
@@ -45,11 +48,11 @@ async def _(msg: MessageChain):
4548

4649
@disp_message1.on(auxiliaries=[is_public_message])
4750
async def _(event: MessageCreatedEvent):
48-
print(event.content)
4951
if event.quote and (authors := select(event.quote, Author)):
5052
author = authors[0]
5153
reply_self = author.id == event.account.self_id
5254
print(reply_self)
55+
print(event.content)
5356

5457

5558
on_alc = command.mount(Alconna("echo", Args["content?", AllParam]))

0 commit comments

Comments
 (0)