Skip to content

Commit 23cb1a7

Browse files
committed
✨ Basic ConfigModel
1 parent 8ef06cd commit 23cb1a7

File tree

6 files changed

+121
-32
lines changed

6 files changed

+121
-32
lines changed

Diff for: arclet/entari/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
from . import command as command
4343
from . import scheduler as scheduler
44+
from .config import BasicConfModel as BasicConfModel
4445
from .config import load_config as load_config
4546
from .core import Entari as Entari
4647
from .event import MessageCreatedEvent as MessageCreatedEvent

Diff for: arclet/entari/builtins/auto_reload.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
raise ImportError("Please install `watchfiles` first. Install with `pip install arclet-entari[reload]`")
1313

1414
from arclet.entari import add_service, declare_static, load_plugin, metadata, plugin_config, unload_plugin
15-
from arclet.entari.config import EntariConfig
15+
from arclet.entari.config import BasicConfModel, EntariConfig, field
1616
from arclet.entari.event.config import ConfigReload
1717
from arclet.entari.logger import log
1818
from arclet.entari.plugin import find_plugin, find_plugin_by_file
1919

2020
declare_static()
2121

2222

23-
class Config:
24-
watch_dirs: list[str] = ["."]
23+
class Config(BasicConfModel):
24+
watch_dirs: list[Union[str, Path]] = field(default_factory=lambda: ["."])
2525
watch_config: bool = False
2626

2727

@@ -190,11 +190,9 @@ async def launch(self, manager: Launart):
190190
self.fail.clear()
191191

192192

193-
conf = plugin_config()
194-
watch_dirs = conf.get("watch_dirs", ["."])
195-
watch_config = conf.get("watch_config", False)
193+
conf = plugin_config(Config)
196194

197-
add_service(serv := Watcher(watch_dirs, watch_config))
195+
add_service(serv := Watcher(conf.watch_dirs, conf.watch_config))
198196

199197

200198
@es.on(ConfigReload)
@@ -203,6 +201,7 @@ def handle_config_reload(event: ConfigReload):
203201
return
204202
if event.key not in ("::auto_reload", "arclet.entari.builtins.auto_reload"):
205203
return
206-
serv.dirs = event.value.get("watch_dirs", ["."])
207-
serv.is_watch_config = event.value.get("watch_config", False)
204+
new_conf = event.plugin_config(Config)
205+
serv.dirs = new_conf.watch_dirs
206+
serv.is_watch_config = new_conf.watch_config
208207
return True

Diff for: arclet/entari/builtins/help.py

+15-17
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import random
2+
from dataclasses import field
23
from typing import Optional
34

45
from arclet.alconna import (
@@ -16,22 +17,19 @@
1617
)
1718
from tarina import lang
1819

19-
from arclet.entari import Session, command, metadata, plugin_config
20+
from arclet.entari import BasicConfModel, Session, command, metadata, plugin_config
2021

21-
config = plugin_config()
22-
help_command: str = config.get("help_command", "help")
23-
help_alias: list[str] = config.get("help_alias", ["帮助", "命令帮助"])
24-
help_all_alias: list[str] = config.get("help_all_alias", ["所有帮助", "所有命令帮助"])
25-
page_size: Optional[int] = config.get("page_size", None)
2622

27-
28-
class Config:
23+
class Config(BasicConfModel):
2924
help_command: str = "help"
30-
help_alias: list[str] = ["帮助", "命令帮助"]
31-
help_all_alias: list[str] = ["所有帮助", "所有命令帮助"]
25+
help_alias: list[str] = field(default_factory=lambda: ["帮助", "命令帮助"])
26+
help_all_alias: list[str] = field(default_factory=lambda: ["所有帮助", "所有命令帮助"])
3227
page_size: Optional[int] = None
3328

3429

30+
config = plugin_config(Config)
31+
32+
3533
metadata(
3634
"help",
3735
["RF-Tar-Railt <[email protected]>"],
@@ -44,7 +42,7 @@ class Config:
4442
ns.disable_builtin_options = {"shortcut"}
4543

4644
help_cmd = Alconna(
47-
help_command,
45+
config.help_command,
4846
Args[
4947
"query#选择某条命令的id或者名称查看具体帮助;/?",
5048
str,
@@ -70,13 +68,13 @@ class Config:
7068
meta=CommandMeta(
7169
description="显示所有命令帮助",
7270
usage="可以使用 --hide 参数来显示隐藏命令,使用 -P 参数来显示命令所属插件名称",
73-
example=f"${help_command} 1",
71+
example=f"${config.help_command} 1",
7472
),
7573
)
7674

77-
for alias in set(help_alias):
75+
for alias in set(config.help_alias):
7876
help_cmd.shortcut(alias, {"prefix": True, "fuzzy": False})
79-
for alias in set(help_all_alias):
77+
for alias in set(config.help_all_alias):
8078
help_cmd.shortcut(alias, {"args": ["--hide"], "prefix": True, "fuzzy": False})
8179

8280

@@ -122,7 +120,7 @@ def help_cmd_handle(arp: Arparma, interactive: bool = False):
122120
return f"{command_string}\n{footer}"
123121
return slot.get_help()
124122

125-
if not page_size:
123+
if not config.page_size:
126124
header = lang.require("manager", "help_header")
127125
command_string = "\n".join(
128126
(
@@ -134,10 +132,10 @@ def help_cmd_handle(arp: Arparma, interactive: bool = False):
134132
)
135133
return f"{header}\n{command_string}\n{footer}"
136134

137-
max_page = len(cmds) // page_size + 1
135+
max_page = len(cmds) // config.page_size + 1
138136
if page < 1 or page > max_page:
139137
page = 1
140-
max_length = page_size
138+
max_length = config.page_size
141139
if interactive:
142140
footer += "\n" + "输入 '<', 'a' 或 '>', 'd' 来翻页"
143141

Diff for: arclet/entari/config.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
from __future__ import annotations
22

3-
from dataclasses import dataclass, field
3+
from dataclasses import dataclass, fields, is_dataclass
4+
from dataclasses import field as field
5+
from inspect import Signature
46
import json
57
import os
68
from pathlib import Path
7-
from typing import Any, Callable, ClassVar, TypedDict
9+
from typing import Any, Callable, ClassVar, TypedDict, TypeVar, get_args, get_origin
10+
from typing_extensions import dataclass_transform
811
import warnings
912

13+
_available_dc_attrs = set(Signature.from_callable(dataclass).parameters.keys())
14+
1015

1116
class BasicConfig(TypedDict, total=False):
1217
network: list[dict[str, Any]]
@@ -124,3 +129,64 @@ def _updater(self: EntariConfig):
124129

125130

126131
load_config = EntariConfig.load
132+
133+
134+
_config_model_validators = {}
135+
136+
C = TypeVar("C")
137+
138+
139+
def config_validator_register(base: type):
140+
def wrapper(func: Callable[[dict[str, Any], type[C]], C]):
141+
_config_model_validators[base] = func
142+
return func
143+
144+
return wrapper
145+
146+
147+
def config_model_validate(base: type[C], data: dict[str, Any]) -> C:
148+
for b in base.__mro__[-2::-1]:
149+
if b in _config_model_validators:
150+
return _config_model_validators[b](data, base)
151+
return base(**data)
152+
153+
154+
@dataclass_transform(kw_only_default=True)
155+
class BasicConfModel:
156+
def __init__(self, **kwargs):
157+
for k, v in kwargs.items():
158+
setattr(self, k, v)
159+
160+
def __init_subclass__(cls, **kwargs):
161+
super().__init_subclass__(**kwargs)
162+
dataclass(**{k: v for k, v in kwargs.items() if k in _available_dc_attrs})(cls)
163+
164+
165+
@config_validator_register(BasicConfModel)
166+
def _basic_config_validate(data: dict[str, Any], base: type[C]) -> C:
167+
def _nested_validate(namespace: dict[str, Any], cls):
168+
result = {}
169+
for field_ in fields(cls):
170+
if field_.name not in namespace:
171+
continue
172+
if is_dataclass(field_.type):
173+
result[field_.name] = _nested_validate(namespace[field_.name], field_.type)
174+
elif get_origin(field_.type) is list and is_dataclass(get_args(field_.type)[0]):
175+
result[field_.name] = [_nested_validate(d, get_args(field_.type)[0]) for d in namespace[field_.name]]
176+
elif get_origin(field_.type) is set and is_dataclass(get_args(field_.type)[0]):
177+
result[field_.name] = {_nested_validate(d, get_args(field_.type)[0]) for d in namespace[field_.name]}
178+
elif get_origin(field_.type) is dict and is_dataclass(get_args(field_.type)[1]):
179+
result[field_.name] = {
180+
k: _nested_validate(v, get_args(field_.type)[1]) for k, v in namespace[field_.name].items()
181+
}
182+
elif get_origin(field_.type) is tuple:
183+
args = get_args(field_.type)
184+
result[field_.name] = tuple(
185+
_nested_validate(d, args[i]) if is_dataclass(args[i]) else d
186+
for i, d in enumerate(namespace[field_.name])
187+
)
188+
else:
189+
result[field_.name] = namespace[field_.name]
190+
return cls(**result)
191+
192+
return _nested_validate(data, base)

Diff for: arclet/entari/event/config.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from dataclasses import dataclass
2-
from typing import Any, Optional
2+
from typing import Any, Optional, overload
33

44
from arclet.letoderea import make_event
55

6+
from ..config import C, config_model_validate
7+
68

79
@dataclass
810
@make_event(name="entari.event/config_reload")
@@ -13,3 +15,16 @@ class ConfigReload:
1315
old: Optional[Any] = None
1416

1517
__result_type__: type[bool] = bool
18+
19+
@overload
20+
def plugin_config(self) -> dict[str, Any]: ...
21+
22+
@overload
23+
def plugin_config(self, model_type: type[C]) -> C: ...
24+
25+
def plugin_config(self, model_type: Optional[type[C]] = None):
26+
if self.scope != "plugin":
27+
raise ValueError("not a plugin config")
28+
if model_type:
29+
return config_model_validate(model_type, self.value)
30+
return self.value

Diff for: arclet/entari/plugin/__init__.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from os import PathLike
44
from pathlib import Path
5-
from typing import Any
5+
from typing import Any, overload
66

77
from tarina import init_spec
88

9-
from ..config import EntariConfig
9+
from ..config import C, EntariConfig, config_model_validate
1010
from ..logger import log
1111
from .model import PluginMetadata as PluginMetadata
1212
from .model import RegisterNotInPluginError
@@ -116,10 +116,20 @@ def metadata(data: PluginMetadata):
116116
plugin._metadata = data # type: ignore
117117

118118

119-
def plugin_config() -> dict[str, Any]:
119+
@overload
120+
def plugin_config() -> dict[str, Any]: ...
121+
122+
123+
@overload
124+
def plugin_config(model_type: type[C]) -> C: ...
125+
126+
127+
def plugin_config(model_type: type[C] | None = None):
120128
"""获取当前插件的配置"""
121129
if not (plugin := _current_plugin.get(None)):
122130
raise LookupError("no plugin context found")
131+
if model_type:
132+
return config_model_validate(model_type, plugin.config)
123133
return plugin.config
124134

125135

0 commit comments

Comments
 (0)