Skip to content

Commit 318543e

Browse files
committed
✨ builtin plugin auto_reload
1 parent 2b1c788 commit 318543e

File tree

11 files changed

+313
-61
lines changed

11 files changed

+313
-61
lines changed

arclet/entari/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from .command import _commands
1818
from .event import MessageCreatedEvent, event_parse
19-
from .plugin.service import service
19+
from .plugin.service import plugin_service
2020
from .session import Session
2121

2222

@@ -78,15 +78,15 @@ def on_message(
7878

7979
def ensure_manager(self, manager: Launart):
8080
self.manager = manager
81-
manager.add_component(service)
81+
manager.add_component(plugin_service)
8282

8383
async def handle_event(self, account: Account, event: Event):
8484
async def event_parse_task(connection: Account, raw: Event):
8585
loop = asyncio.get_running_loop()
8686
with suppress(NotImplementedError):
8787
ev = event_parse(connection, raw)
8888
self.event_system.publish(ev)
89-
for plugin in service.plugins.values():
89+
for plugin in plugin_service.plugins.values():
9090
for disp in plugin.dispatchers.values():
9191
if not disp.validate(ev):
9292
continue
@@ -104,7 +104,7 @@ async def event_parse_task(connection: Account, raw: Event):
104104
async def account_hook(self, account: Account, state: LoginStatus):
105105
_connected = []
106106
_disconnected = []
107-
for plug in service.plugins.values():
107+
for plug in plugin_service.plugins.values():
108108
_connected.extend([func(account) for func in plug._connected])
109109
_disconnected.extend([func(account) for func in plug._disconnected])
110110
if state == LoginStatus.CONNECT:

arclet/entari/plugin/__init__.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .model import keeping as keeping
1414
from .module import import_plugin
1515
from .module import package as package
16-
from .service import service
16+
from .service import plugin_service
1717

1818
if TYPE_CHECKING:
1919
from ..event import Event
@@ -25,42 +25,44 @@ def dispatch(*events: type[Event], predicate: Callable[[Event], bool] | None = N
2525
return plugin.dispatch(*events, predicate=predicate)
2626

2727

28-
def load_plugin(path: str, recursive_guard: set[str] | None = None) -> Plugin | None:
28+
def load_plugin(path: str, config: dict | None = None, recursive_guard: set[str] | None = None) -> Plugin | None:
2929
"""
3030
以导入路径方式加载模块
3131
3232
Args:
3333
path (str): 模块路径
34+
config (dict): 模块配置
3435
recursive_guard (set[str]): 递归保护
3536
"""
3637
if recursive_guard is None:
3738
recursive_guard = set()
38-
if path in service._submoded:
39-
logger.error(f"plugin {path!r} is already defined as submodule of {service._submoded[path]!r}")
39+
path = path.replace("::", "arclet.entari.plugins.")
40+
if path in plugin_service._submoded:
41+
logger.error(f"plugin {path!r} is already defined as submodule of {plugin_service._submoded[path]!r}")
4042
return
41-
if path in service.plugins:
42-
return service.plugins[path]
43+
if path in plugin_service.plugins:
44+
return plugin_service.plugins[path]
4345
try:
44-
mod = import_plugin(path)
46+
mod = import_plugin(path, config=config)
4547
if not mod:
4648
logger.error(f"cannot found plugin {path!r}")
4749
return
4850
logger.success(f"loaded plugin {path!r}")
49-
if mod.__name__ in service._unloaded:
50-
if mod.__name__ in service._referents and service._referents[mod.__name__]:
51-
referents = service._referents[mod.__name__].copy()
52-
service._referents[mod.__name__].clear()
51+
if mod.__name__ in plugin_service._unloaded:
52+
if mod.__name__ in plugin_service._referents and plugin_service._referents[mod.__name__]:
53+
referents = plugin_service._referents[mod.__name__].copy()
54+
plugin_service._referents[mod.__name__].clear()
5355
for referent in referents:
5456
if referent in recursive_guard:
5557
continue
56-
if referent in service.plugins:
58+
if referent in plugin_service.plugins:
5759
logger.debug(f"reloading {mod.__name__}'s referent {referent!r}")
5860
dispose(referent)
5961
if not load_plugin(referent):
60-
service._referents[mod.__name__].add(referent)
62+
plugin_service._referents[mod.__name__].add(referent)
6163
else:
6264
recursive_guard.add(referent)
63-
service._unloaded.discard(mod.__name__)
65+
plugin_service._unloaded.discard(mod.__name__)
6466
return mod.__plugin__
6567
except RegisterNotInPluginError as e:
6668
logger.exception(f"{e.args[0]}", exc_info=e)
@@ -79,9 +81,9 @@ def load_plugins(dir_: str | PathLike | Path):
7981

8082

8183
def dispose(plugin: str):
82-
if plugin not in service.plugins:
84+
if plugin not in plugin_service.plugins:
8385
return False
84-
_plugin = service.plugins[plugin]
86+
_plugin = plugin_service.plugins[plugin]
8587
_plugin.dispose()
8688
return True
8789

@@ -91,3 +93,31 @@ def metadata(data: PluginMetadata):
9193
if not (plugin := _current_plugin.get(None)):
9294
raise LookupError("no plugin context found")
9395
plugin._metadata = data # type: ignore
96+
97+
98+
def find_plugin(name: str) -> Plugin | None:
99+
if name in plugin_service.plugins:
100+
return plugin_service.plugins[name]
101+
if name in plugin_service._submoded:
102+
return plugin_service.plugins[plugin_service._submoded[name]]
103+
return None
104+
105+
106+
def find_plugin_by_file(file: str) -> Plugin | None:
107+
path = Path(file).resolve()
108+
for plugin in plugin_service.plugins.values():
109+
if plugin.module.__file__ == str(path):
110+
return plugin
111+
if plugin.module.__file__ and Path(plugin.module.__file__).parent == path:
112+
return plugin
113+
for submod in plugin.submodules.values():
114+
if submod.__file__ == str(path):
115+
return plugin
116+
if submod.__file__ and Path(submod.__file__).parent == path:
117+
return plugin
118+
path1 = Path(path)
119+
while path1.parent != path1:
120+
if str(path1) == plugin.module.__file__:
121+
return plugin
122+
path1 = path1.parent
123+
return None

arclet/entari/plugin/model.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
from arclet.letoderea import BaseAuxiliary, Provider, Publisher, StepOut, system_ctx
1414
from arclet.letoderea.builtin.breakpoint import R
1515
from arclet.letoderea.typing import TTarget
16+
from creart import it
17+
from launart import Launart, Service
1618
from satori.client import Account
1719

18-
from .service import service
20+
from .service import plugin_service
1921

2022
if TYPE_CHECKING:
2123
from ..event import Event
@@ -113,6 +115,7 @@ class Plugin:
113115
module: ModuleType
114116
dispatchers: dict[str, PluginDispatcher] = field(default_factory=dict)
115117
submodules: dict[str, ModuleType] = field(default_factory=dict)
118+
config: dict[str, Any] = field(default_factory=dict)
116119
_metadata: PluginMetadata | None = None
117120
_is_disposed: bool = False
118121

@@ -121,6 +124,8 @@ class Plugin:
121124
_connected: list[_AccountUpdate] = field(init=False, default_factory=list)
122125
_disconnected: list[_AccountUpdate] = field(init=False, default_factory=list)
123126

127+
_services: dict[str, Service] = field(init=False, default_factory=dict)
128+
124129
def on_prepare(self, func: _Lifespan):
125130
self._preparing.append(func)
126131
return func
@@ -149,15 +154,15 @@ def metadata(self) -> PluginMetadata | None:
149154
return self._metadata
150155

151156
def __post_init__(self):
152-
service.plugins[self.id] = self
153-
if self.id not in service._keep_values:
154-
service._keep_values[self.id] = {}
155-
if self.id not in service._referents:
156-
service._referents[self.id] = set()
157+
plugin_service.plugins[self.id] = self
158+
if self.id not in plugin_service._keep_values:
159+
plugin_service._keep_values[self.id] = {}
160+
if self.id not in plugin_service._referents:
161+
plugin_service._referents[self.id] = set()
157162
finalize(self, self.dispose)
158163

159164
def dispose(self):
160-
service._unloaded.add(self.id)
165+
plugin_service._unloaded.add(self.id)
161166
if self._is_disposed:
162167
return
163168
self._is_disposed = True
@@ -168,15 +173,21 @@ def dispose(self):
168173
for submod in self.submodules.values():
169174
delattr(submod, "__plugin__")
170175
sys.modules.pop(submod.__name__, None)
171-
service._submoded.pop(submod.__name__, None)
176+
plugin_service._submoded.pop(submod.__name__, None)
172177
if submod.__spec__ and submod.__spec__.cached:
173178
Path(submod.__spec__.cached).unlink(missing_ok=True)
174179
self.submodules.clear()
175180
for disp in self.dispatchers.values():
176181
disp.dispose()
177182
self.dispatchers.clear()
178-
del service.plugins[self.id]
183+
del plugin_service.plugins[self.id]
179184
del self.module
185+
for serv in self._services.values():
186+
try:
187+
it(Launart).remove_component(serv)
188+
except ValueError:
189+
pass
190+
self._services.clear()
180191

181192
def dispatch(self, *events: type[Event], predicate: Callable[[Event], bool] | None = None):
182193
disp = PluginDispatcher(self, *events, predicate=predicate)
@@ -202,6 +213,14 @@ def proxy(self):
202213
def subproxy(self, sub_id: str):
203214
return _ProxyModule(self.id, sub_id)
204215

216+
def service(self, serv: Service | type[Service]):
217+
if isinstance(serv, type):
218+
serv = serv()
219+
self._services[serv.id] = serv
220+
if plugin_service.status.blocking:
221+
it(Launart).add_component(serv)
222+
return serv
223+
205224

206225
class KeepingVariable:
207226
def __init__(self, obj: T, dispose: Callable[[T], None] | None = None):
@@ -222,10 +241,10 @@ def dispose(self):
222241
def keeping(id_: str, obj: T, dispose: Callable[[T], None] | None = None) -> T:
223242
if not (plug := _current_plugin.get(None)):
224243
raise LookupError("no plugin context found")
225-
if id_ not in service._keep_values[plug.id]:
226-
service._keep_values[plug.id][id_] = KeepingVariable(obj, dispose)
244+
if id_ not in plugin_service._keep_values[plug.id]:
245+
plugin_service._keep_values[plug.id][id_] = KeepingVariable(obj, dispose)
227246
else:
228-
obj = service._keep_values[plug.id][id_].obj # type: ignore
247+
obj = plugin_service._keep_values[plug.id][id_].obj # type: ignore
229248
return obj
230249

231250

@@ -240,12 +259,12 @@ def __get_module(self) -> ModuleType:
240259
def __init__(self, plugin_id: str, sub_id: str | None = None) -> None:
241260
self.__plugin_id = plugin_id
242261
self.__sub_id = sub_id
243-
if self.__plugin_id not in service.plugins:
262+
if self.__plugin_id not in plugin_service.plugins:
244263
raise NameError(f"Plugin {self.__plugin_id!r} is not loaded")
245264
if self.__sub_id:
246-
self.__origin = ref(service.plugins[self.__plugin_id].submodules[self.__sub_id])
265+
self.__origin = ref(plugin_service.plugins[self.__plugin_id].submodules[self.__sub_id])
247266
else:
248-
self.__origin = ref(service.plugins[self.__plugin_id].module)
267+
self.__origin = ref(plugin_service.plugins[self.__plugin_id].module)
249268
super().__init__(self.__get_module().__name__)
250269
self.__doc__ = self.__get_module().__doc__
251270
self.__file__ = self.__get_module().__file__
@@ -278,14 +297,14 @@ def __getattr__(self, name: str):
278297
"__spec__",
279298
):
280299
return super().__getattribute__(name)
281-
if self.__plugin_id not in service.plugins:
300+
if self.__plugin_id not in plugin_service.plugins:
282301
raise NameError(f"Plugin {self.__plugin_id!r} is not loaded")
283302
if plug := inspect.currentframe().f_back.f_globals.get("__plugin__"): # type: ignore
284303
if plug.id != self.__plugin_id:
285-
service._referents[self.__plugin_id].add(plug.id)
304+
plugin_service._referents[self.__plugin_id].add(plug.id)
286305
elif plug := inspect.currentframe().f_back.f_back.f_globals.get("__plugin__"): # type: ignore
287306
if plug.id != self.__plugin_id:
288-
service._referents[self.__plugin_id].add(plug.id)
307+
plugin_service._referents[self.__plugin_id].add(plug.id)
289308
return getattr(self.__get_module(), name)
290309

291310
def __setattr__(self, name: str, value):

0 commit comments

Comments
 (0)