Skip to content

Commit

Permalink
❇️ support command prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Dec 5, 2024
1 parent e8b04ff commit bf515b9
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 22 deletions.
38 changes: 26 additions & 12 deletions arclet/entari/command/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@
class EntariCommands:
__namespace__ = "Entari"

def __init__(self, need_tome: bool = False, remove_tome: bool = True):
def __init__(self, need_tome: bool = False, remove_tome: bool = True, use_config_prefix: bool = True):
self.trie: CharTrie[Subscriber] = CharTrie()
self.publisher = Publisher("entari.command", MessageCreatedEvent)
self.publisher.bind(AlconnaProviderFactory())
self.need_tome = need_tome
self.remove_tome = remove_tome
self.use_config_prefix = use_config_prefix
config.namespaces["Entari"] = Namespace(
self.__namespace__,
to_text=lambda x: x.text if x.__class__ is Text else None,
Expand Down Expand Up @@ -102,23 +103,27 @@ def command(
self,
command: str,
help_text: Optional[str] = None,
need_tome: bool = False,
remove_tome: bool = True,
need_tome: Optional[bool] = None,
remove_tome: Optional[bool] = None,
use_config_prefix: Optional[bool] = None,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
):
class Command(AlconnaString):
def __call__(_cmd_self, func: TTarget[T]) -> Subscriber[T]:
return self.on(_cmd_self.build(), need_tome, remove_tome, auxiliaries, providers)(func)
return self.on(_cmd_self.build(), need_tome, remove_tome, use_config_prefix, auxiliaries, providers)(
func
)

return Command(command, help_text)

@overload
def on(
self,
command: Alconna,
need_tome: bool = False,
remove_tome: bool = True,
need_tome: Optional[bool] = None,
remove_tome: Optional[bool] = None,
use_config_prefix: Optional[bool] = None,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
) -> Callable[[TTarget[T]], Subscriber[T]]: ...
Expand All @@ -127,8 +132,9 @@ def on(
def on(
self,
command: str,
need_tome: bool = False,
remove_tome: bool = True,
need_tome: Optional[bool] = None,
remove_tome: Optional[bool] = None,
use_config_prefix: Optional[bool] = None,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
*,
Expand All @@ -139,8 +145,9 @@ def on(
def on(
self,
command: Union[Alconna, str],
need_tome: bool = False,
remove_tome: bool = True,
need_tome: Optional[bool] = None,
remove_tome: Optional[bool] = None,
use_config_prefix: Optional[bool] = None,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]] = None,
*,
Expand All @@ -160,7 +167,13 @@ def wrapper(func: TTarget[T]) -> Subscriber[T]:
f" {arg.value.target}" for arg in _command.args if isinstance(arg.value, DirectPattern)
)
auxiliaries.insert(
0, AlconnaSuppiler(_command, need_tome or self.need_tome, remove_tome or self.remove_tome)
0,
AlconnaSuppiler(
_command,
self.need_tome if need_tome is None else need_tome,
self.remove_tome if remove_tome is None else remove_tome,
self.use_config_prefix if use_config_prefix is None else use_config_prefix,
),
)
target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func)
self.publisher.remove_subscriber(target)
Expand Down Expand Up @@ -203,9 +216,10 @@ def _remove(_):
_commands = EntariCommands()


def config_commands(need_tome: bool = False, remove_tome: bool = True):
def config_commands(need_tome: bool = False, remove_tome: bool = True, use_config_prefix: bool = True):
_commands.need_tome = need_tome
_commands.remove_tome = remove_tome
_commands.use_config_prefix = use_config_prefix


command = _commands.command
Expand Down
14 changes: 10 additions & 4 deletions arclet/entari/command/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ def __init__(
command: Alconna,
need_tome: bool = False,
remove_tome: bool = True,
use_config_prefix: bool = True,
):
self.supplier = AlconnaSuppiler(command, need_tome, remove_tome)
self.supplier = AlconnaSuppiler(command, need_tome, remove_tome, use_config_prefix)
super().__init__(plugin, MessageCreatedEvent)

self.publisher.bind(MessageJudger(), self.supplier)
Expand Down Expand Up @@ -55,7 +56,7 @@ def on_execute(
providers: list[Provider | type[Provider] | ProviderFactory | type[ProviderFactory]] | None = None,
):
_auxiliaries = auxiliaries or []
_auxiliaries.append(ExecuteSuppiler(self.supplier.cmd))
_auxiliaries.append(ExecuteSuppiler(self.supplier.cmd, self.supplier.use_config_prefix))

def wrapper(func):
sub = execute_handles.register(func, priority=priority, auxiliaries=_auxiliaries, providers=providers)
Expand All @@ -68,10 +69,15 @@ def wrapper(func):
Query = Query


def mount(cmd: Alconna, need_tome: bool = False, remove_tome: bool = True) -> AlconnaPluginDispatcher:
def mount(
cmd: Alconna,
need_tome: bool = False,
remove_tome: bool = True,
use_config_prefix: bool = True,
) -> AlconnaPluginDispatcher:
if not (plugin := Plugin.current()):
raise LookupError("no plugin context found")
disp = AlconnaPluginDispatcher(plugin, cmd, need_tome, remove_tome)
disp = AlconnaPluginDispatcher(plugin, cmd, need_tome, remove_tome, use_config_prefix)
if disp.publisher.id in plugin.dispatchers:
return plugin.dispatchers[disp.id] # type: ignore
plugin.dispatchers[disp.publisher.id] = disp
Expand Down
27 changes: 23 additions & 4 deletions arclet/entari/command/provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from copy import deepcopy
import inspect
from typing import Any, Literal, Optional, Union, get_args

Expand All @@ -11,6 +10,7 @@
from satori.element import At, Text
from tarina.generic import get_origin

from ..config import EntariConfig
from ..message import MessageChain
from .model import CommandResult, Match, Query

Expand All @@ -25,7 +25,7 @@ def _is_tome(message: MessageChain, account: Account):

def _remove_tome(message: MessageChain, account: Account):
if _is_tome(message, account):
message = deepcopy(message)
message = message.copy()
message.pop(0)
if message and isinstance(message[0], Text):
text = message[0].text.lstrip() # type: ignore
Expand All @@ -37,6 +37,19 @@ def _remove_tome(message: MessageChain, account: Account):
return message


def _remove_config_prefix(message: MessageChain):
if not (command_prefix := EntariConfig.instance.basic.get("command_prefix", [])):
return message
if message and isinstance(message[0], Text):
text = message[0].text # type: ignore
for prefix in command_prefix:
if text.startswith(prefix):
message = message.copy()
message[0] = Text(text[len(prefix) :])
return message
return MessageChain()


class MessageJudger(JudgeAuxiliary):
async def __call__(self, scope: Scope, interface: Interface) -> Optional[bool]:
return "$message_content" in interface.ctx
Expand All @@ -55,11 +68,12 @@ class AlconnaSuppiler(SupplyAuxiliary):
need_tome: bool
remove_tome: bool

def __init__(self, cmd: Alconna, need_tome: bool, remove_tome: bool):
def __init__(self, cmd: Alconna, need_tome: bool, remove_tome: bool, use_config_prefix: bool = True):
super().__init__(priority=2)
self.cmd = cmd
self.need_tome = need_tome
self.remove_tome = remove_tome
self.use_config_prefix = use_config_prefix

async def __call__(self, scope: Scope, interface: Interface) -> Optional[Union[bool, Interface.Update]]:
account: Account = interface.ctx["account"]
Expand All @@ -70,6 +84,8 @@ async def __call__(self, scope: Scope, interface: Interface) -> Optional[Union[b
output_manager.set_action(lambda x: x, self.cmd.name)
if self.remove_tome:
message = _remove_tome(message, account)
if self.use_config_prefix and not (message := _remove_config_prefix(message)):
return False
try:
_res = self.cmd.parse(message)
except Exception as e:
Expand All @@ -92,12 +108,15 @@ def id(self) -> str:


class ExecuteSuppiler(SupplyAuxiliary):
def __init__(self, cmd: Alconna):
def __init__(self, cmd: Alconna, use_config_prefix: bool = True):
self.cmd = cmd
self.use_config_prefix = use_config_prefix
super().__init__(priority=1)

async def __call__(self, scope: Scope, interface: Interface):
message = interface.query(MessageChain, "command")
if self.use_config_prefix and not (message := _remove_config_prefix(message)):
return False
with output_manager.capture(self.cmd.name) as cap:
output_manager.set_action(lambda x: x, self.cmd.name)
try:
Expand Down
12 changes: 10 additions & 2 deletions arclet/entari/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@
import json
import os
from pathlib import Path
from typing import Callable, ClassVar
from typing import Any, Callable, ClassVar, TypedDict


class BasicConfig(TypedDict, total=False):
network: list[dict[str, Any]]
ignore_self_message: bool
record_message: bool
log_level: int | str
command_prefix: list[str]


@dataclass
class EntariConfig:
path: Path
basic: dict = field(default_factory=dict, init=False)
basic: BasicConfig = field(default_factory=dict, init=False) # type: ignore
plugin: dict[str, dict | bool] = field(default_factory=dict, init=False)
updater: Callable[[EntariConfig], None]

Expand Down

0 comments on commit bf515b9

Please sign in to comment.