Skip to content

Commit

Permalink
💥 version 0.6.0
Browse files Browse the repository at this point in the history
changed plugin design
  • Loading branch information
RF-Tar-Railt committed Jul 10, 2024
1 parent a59c2d2 commit 5e5f13f
Show file tree
Hide file tree
Showing 14 changed files with 581 additions and 416 deletions.
25 changes: 20 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ from arclet.entari import Session, Entari, WS

app = Entari(WS(host="127.0.0.1", port=5140, path="satori"))


@app.on_message()
async def repeat(session: Session):
await session.send(session.content)
Expand All @@ -34,10 +33,7 @@ app.run()
指令 `add {a} {b}`:

```python
from arclet.entari import Session, Entari, EntariCommands, WS

command = EntariCommands()

from arclet.entari import Session, Entari, WS, command

@command.on("add {a} {b}")
async def add(a: int, b: int, session: Session):
Expand All @@ -47,3 +43,22 @@ async def add(a: int, b: int, session: Session):
app = Entari(WS(port=5500, token="XXX"))
app.run()
```

编写插件:

```python
from arclet.entari import Session, MessageEvent, PluginMetadata

__plugin_metadata__ = PluginMetadata(
name="Hello, World!",
author=["Arclet"],
version="0.1.0",
description="A simple plugin that replies 'Hello, World!' to every message."
)

on_message = MessageEvent.dispatch()

@on_message()
async def _(session: Session):
await session.send("Hello, World!")
```
5 changes: 2 additions & 3 deletions arclet/entari/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@
from satori.config import WebhookInfo as WebhookInfo
from satori.config import WebsocketsInfo as WebsocketsInfo

from .command import AlconnaDispatcher as AlconnaDispatcher
from .command import EntariCommands as EntariCommands
from .core import Entari as Entari
from .event import MessageCreatedEvent as MessageCreatedEvent
from .event import MessageEvent as MessageEvent
from .filter import is_direct_message as is_direct_message
from .filter import is_public_message as is_public_message
from .message import MessageChain as MessageChain
from .plugin import Plugin as Plugin
from .plugin import PluginMetadata as PluginMetadata
from .plugin import dispose as dispose_plugin # noqa: F401
from .plugin import load_plugin as load_plugin
from .plugin import load_plugins as load_plugins
from .session import Session as Session
Expand Down
226 changes: 221 additions & 5 deletions arclet/entari/command/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,221 @@
from .main import EntariCommands as EntariCommands
from .model import CommandResult as CommandResult
from .model import Match as Match
from .model import Query as Query
from .plugin import AlconnaDispatcher as AlconnaDispatcher
import asyncio
from typing import Any, Callable, Optional, TypeVar, Union, cast, overload

from arclet.alconna import (
Alconna,
Arg,
Args,
Arparma,
CommandMeta,
Namespace,
command_manager,
config,
output_manager,
)
from arclet.alconna.args import TAValue
from arclet.alconna.tools.construct import AlconnaString, alconna_from_format
from arclet.letoderea import BaseAuxiliary, Provider, Publisher, Scope, Subscriber
from arclet.letoderea.handler import depend_handler
from arclet.letoderea.provider import ProviderFactory
from nepattern import DirectPattern
from pygtrie import CharTrie
from satori.element import At, Text
from tarina.string import split

from ..event import MessageEvent
from ..message import MessageChain
from .argv import MessageArgv # noqa: F401
from .model import CommandResult, Match, Query
from .plugin import mount
from .provider import AlconnaProviderFactory, AlconnaSuppiler, MessageJudger, get_cmd

T = TypeVar("T")
TCallable = TypeVar("TCallable", bound=Callable[..., Any])


class EntariCommands:
__namespace__ = "Entari"

def __init__(self, need_tome: bool = False, remove_tome: bool = False):
self.trie: CharTrie = CharTrie()
self.publisher = Publisher("EntariCommands", MessageEvent)
self.publisher.providers.append(AlconnaProviderFactory())
self.need_tome = need_tome
self.remove_tome = remove_tome
config.namespaces["Entari"] = Namespace(
self.__namespace__,
to_text=lambda x: x.text if x.__class__ is Text else None,
converter=lambda x: MessageChain(x),
)

@self.publisher.register(auxiliaries=[MessageJudger()])
async def listener(event: MessageEvent):
msg = str(event.content.exclude(At)).lstrip()
if not msg:
return
if matches := list(self.trie.prefixes(msg)):
await asyncio.gather(
*(depend_handler(res.value, event, inner=True) for res in matches if res.value)
)
return
# shortcut
data = split(msg, (" ",))
for value in self.trie.values():
try:
command_manager.find_shortcut(get_cmd(value), data)
except ValueError:
continue
await depend_handler(value, event, inner=True)

@property
def all_helps(self) -> str:
return command_manager.all_command_help(namespace=self.__namespace__)

def get_help(self, command: str) -> str:
return command_manager.get_command(f"{self.__namespace__}::{command}").get_help()

async def execute(self, message: MessageChain):
async def _run(target: Subscriber, content: MessageChain):
aux = next((a for a in target.auxiliaries[Scope.prepare] if isinstance(a, AlconnaSuppiler)), None)
if not aux:
return
with output_manager.capture(aux.cmd.name) as cap:
output_manager.set_action(lambda x: x, aux.cmd.name)
try:
_res = aux.cmd.parse(content)
except Exception as e:
_res = Arparma(aux.cmd.path, message, False, error_info=e)
may_help_text: Optional[str] = cap.get("output", None)
if _res.matched:
args = {}
ctx = {"alc_result": CommandResult(aux.cmd, _res, may_help_text)}
for param in target.params:
args[param.name] = await param.solve(ctx)
return await target(**args)
elif may_help_text:
return may_help_text

msg = str(message.exclude(At)).lstrip()
if matches := list(self.trie.prefixes(msg)):
return await asyncio.gather(*(_run(res.value, message) for res in matches if res.value))
# shortcut
data = split(msg, (" ",))
res = []
for value in self.trie.values():
try:
command_manager.find_shortcut(get_cmd(value), data)
except ValueError:
continue
res.append(await _run(value, message))
return res

def command(
self,
command: str,
help_text: Optional[str] = None,
need_tome: bool = False,
remove_tome: bool = False,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[
list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]
] = None,
):
class Command(AlconnaString):
def __call__(_cmd_self, func: TCallable) -> TCallable:
return self.on(_cmd_self.build(), need_tome, remove_tome, auxiliaries, providers)(func)

return Command(command, help_text)

@overload
def on(
self,
command: Alconna,
need_tome: bool = False,
remove_tome: bool = False,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[
list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]
] = None,
) -> Callable[[TCallable], TCallable]: ...

@overload
def on(
self,
command: str,
need_tome: bool = False,
remove_tome: bool = False,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[
list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]
] = None,
*,
args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None,
meta: Optional[CommandMeta] = None,
) -> Callable[[TCallable], TCallable]: ...

def on(
self,
command: Union[Alconna, str],
need_tome: bool = False,
remove_tome: bool = False,
auxiliaries: Optional[list[BaseAuxiliary]] = None,
providers: Optional[
list[Union[Provider, type[Provider], ProviderFactory, type[ProviderFactory]]]
] = None,
*,
args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None,
meta: Optional[CommandMeta] = None,
) -> Callable[[TCallable], TCallable]:
auxiliaries = auxiliaries or []
providers = providers or []

def wrapper(func: TCallable) -> TCallable:
if isinstance(command, str):
mapping = {arg.name: arg.value for arg in Args.from_callable(func)[0]}
mapping.update(args or {}) # type: ignore
_command = alconna_from_format(command, mapping, meta, union=False)
_command.reset_namespace(self.__namespace__)
key = _command.name + "".join(
f" {arg.value.target}" for arg in _command.args if isinstance(arg.value, DirectPattern)
)
auxiliaries.insert(
0, AlconnaSuppiler(_command, need_tome or self.need_tome, remove_tome or self.remove_tome)
)
target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func)
self.publisher.remove_subscriber(target)
self.trie[key] = target
else:
auxiliaries.insert(
0, AlconnaSuppiler(command, need_tome or self.need_tome, remove_tome or self.remove_tome)
)
target = self.publisher.register(auxiliaries=auxiliaries, providers=providers)(func)
self.publisher.remove_subscriber(target)
if not isinstance(command.command, str):
raise TypeError("Command name must be a string.")
if not command.prefixes:
self.trie[command.command] = target
elif not all(isinstance(i, str) for i in command.prefixes):
raise TypeError("Command prefixes must be a list of string.")
else:
self.publisher.remove_subscriber(target)
for prefix in cast(list[str], command.prefixes):
self.trie[prefix + command.command] = target
command.reset_namespace(self.__namespace__)
return func

return wrapper


_commands = EntariCommands()


def config_commands(need_tome: bool = False, remove_tome: bool = False):
_commands.need_tome = need_tome
_commands.remove_tome = remove_tome


command = _commands.command
on = _commands.on


__all__ = ["_commands", "config_commands", "Match", "Query", "CommandResult", "mount", "command", "on"]
Loading

0 comments on commit 5e5f13f

Please sign in to comment.