Make new command handling actually somewhat work
This commit is contained in:
parent
682eab348d
commit
5ff5eae3c6
@ -118,6 +118,8 @@ def upload_plugin(output: Union[str, IO]) -> None:
|
|||||||
default=False)
|
default=False)
|
||||||
def build(path: str, output: str, upload: bool) -> None:
|
def build(path: str, output: str, upload: bool) -> None:
|
||||||
meta = read_meta(path)
|
meta = read_meta(path)
|
||||||
|
if not meta:
|
||||||
|
return
|
||||||
if output or not upload:
|
if output or not upload:
|
||||||
output = read_output_path(output, meta)
|
output = read_output_path(output, meta)
|
||||||
if not output:
|
if not output:
|
||||||
|
@ -55,7 +55,7 @@ class Client:
|
|||||||
token=self.access_token, client_session=self.http_client,
|
token=self.access_token, client_session=self.http_client,
|
||||||
log=self.log, loop=self.loop, store=self.db_instance)
|
log=self.log, loop=self.loop, store=self.db_instance)
|
||||||
if self.autojoin:
|
if self.autojoin:
|
||||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||||
|
|
||||||
async def start(self, try_n: Optional[int] = 0) -> None:
|
async def start(self, try_n: Optional[int] = 0) -> None:
|
||||||
try:
|
try:
|
||||||
@ -260,9 +260,9 @@ class Client:
|
|||||||
if value == self.db_instance.autojoin:
|
if value == self.db_instance.autojoin:
|
||||||
return
|
return
|
||||||
if value:
|
if value:
|
||||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||||
else:
|
else:
|
||||||
self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||||
self.db_instance.autojoin = value
|
self.db_instance.autojoin = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -29,7 +29,8 @@ class Config(BaseFileConfig):
|
|||||||
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
|
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
|
||||||
|
|
||||||
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
||||||
base, copy, _ = helper
|
base = helper.base
|
||||||
|
copy = helper.copy
|
||||||
copy("database")
|
copy("database")
|
||||||
copy("plugin_directories.upload")
|
copy("plugin_directories.upload")
|
||||||
copy("plugin_directories.load")
|
copy("plugin_directories.load")
|
||||||
|
@ -13,89 +13,196 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any
|
from typing import Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, Dict
|
||||||
import functools
|
import functools
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from mautrix.client import EventHandler
|
from mautrix.types import MessageType, EventType
|
||||||
from mautrix.types import MessageType
|
|
||||||
|
|
||||||
from ..matrix import MaubotMessageEvent
|
from ..matrix import MaubotMessageEvent
|
||||||
from .event import EventHandlerDecorator
|
from . import event
|
||||||
|
|
||||||
PrefixType = Union[str, Callable[[], str]]
|
PrefixType = Optional[Union[str, Callable[[], str]]]
|
||||||
CommandDecorator = Callable[[PrefixType, str], EventHandlerDecorator]
|
CommandHandlerFunc = NewType("CommandHandlerFunc",
|
||||||
|
Callable[[MaubotMessageEvent, Any], Awaitable[Any]])
|
||||||
|
CommandHandlerDecorator = NewType("CommandHandlerDecorator",
|
||||||
|
Callable[[Union['CommandHandler', CommandHandlerFunc]],
|
||||||
|
'CommandHandler'])
|
||||||
|
PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator",
|
||||||
|
Callable[[CommandHandlerFunc], CommandHandlerFunc])
|
||||||
|
|
||||||
|
|
||||||
def _get_subcommand_decorator(parent: EventHandler) -> CommandDecorator:
|
class CommandHandler:
|
||||||
def subcommand(name: PrefixType, help: str = None) -> EventHandlerDecorator:
|
def __init__(self, func: CommandHandlerFunc) -> None:
|
||||||
cmd_decorator = new(name=f"{parent.__mb_name__} {name}", help=help)
|
self.__mb_func__: CommandHandlerFunc = func
|
||||||
|
self.__mb_subcommands__: Dict[str, CommandHandler] = {}
|
||||||
|
self.__mb_arguments__: List[Argument] = []
|
||||||
|
self.__mb_help__: str = None
|
||||||
|
self.__mb_name__: str = None
|
||||||
|
self.__mb_prefix__: str = None
|
||||||
|
self.__mb_require_subcommand__: bool = True
|
||||||
|
self.__mb_event_handler__: bool = True
|
||||||
|
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
|
||||||
|
self.__class_instance: Any = None
|
||||||
|
|
||||||
def decorator(func: EventHandler) -> EventHandler:
|
async def __call__(self, evt: MaubotMessageEvent, *,
|
||||||
func = cmd_decorator(func)
|
_existing_args: Dict[str, Any] = None) -> Any:
|
||||||
parent.__mb_subcommands__.append(func)
|
body = evt.content.body
|
||||||
|
if evt.sender == evt.client.mxid or not body.startswith(self.__mb_prefix__):
|
||||||
|
return
|
||||||
|
call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {}
|
||||||
|
remaining_val = body[len(self.__mb_prefix__) + 1:]
|
||||||
|
# TODO update remaining_val somehow
|
||||||
|
for arg in self.__mb_arguments__:
|
||||||
|
try:
|
||||||
|
call_args[arg.name] = arg.match(remaining_val)
|
||||||
|
if arg.required and not call_args[arg.name]:
|
||||||
|
raise ValueError("Argument required")
|
||||||
|
except ArgumentSyntaxError as e:
|
||||||
|
await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else ""))
|
||||||
|
return
|
||||||
|
except ValueError as e:
|
||||||
|
await evt.reply(self.__mb_usage__)
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(self.__mb_subcommands__) > 0:
|
||||||
|
split = remaining_val.split(" ") if len(remaining_val) > 0 else []
|
||||||
|
try:
|
||||||
|
subcommand = self.__mb_subcommands__[split[0]]
|
||||||
|
return await subcommand(evt, _existing_args=call_args)
|
||||||
|
except (KeyError, IndexError):
|
||||||
|
if self.__mb_require_subcommand__:
|
||||||
|
await evt.reply(self.__mb_full_help__)
|
||||||
|
return
|
||||||
|
return (await self.__mb_func__(self.__class_instance, evt, **call_args)
|
||||||
|
if self.__class_instance
|
||||||
|
else await self.__mb_func__(evt, **call_args))
|
||||||
|
|
||||||
|
def __get__(self, instance, instancetype):
|
||||||
|
self.__class_instance = instance
|
||||||
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __mb_full_help__(self) -> str:
|
||||||
|
basic = self.__mb_usage__
|
||||||
|
usage = f"{basic} <subcommand> [...]\n\n"
|
||||||
|
usage += "\n".join(f"* {cmd.__mb_name__} {cmd.__mb_usage_args__} - {cmd.__mb_help__}"
|
||||||
|
for cmd in self.__mb_subcommands__.values())
|
||||||
|
return usage
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __mb_usage_args__(self) -> str:
|
||||||
|
return " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]"
|
||||||
|
for arg in self.__mb_arguments__)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __mb_usage__(self) -> str:
|
||||||
|
return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
|
||||||
|
|
||||||
|
def subcommand(self, name: PrefixType = None, help: str = None
|
||||||
|
) -> CommandHandlerDecorator:
|
||||||
|
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
|
||||||
|
if not isinstance(func, CommandHandler):
|
||||||
|
func = CommandHandler(func)
|
||||||
|
func.__mb_name__ = name or func.__name__
|
||||||
|
func.__mb_prefix__ = f"{self.__mb_prefix__} {func.__mb_name__}"
|
||||||
|
func.__mb_help__ = help
|
||||||
|
func.__mb_event_handler__ = False
|
||||||
|
self.__mb_subcommands__[func.__mb_name__] = func
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
return subcommand
|
|
||||||
|
class ArgumentSyntaxError(ValueError):
|
||||||
|
def __init__(self, message: str, show_usage: bool = True) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.show_usage = show_usage
|
||||||
|
|
||||||
|
|
||||||
def new(name: Union[str, Callable[[], str]], help: str = None) -> EventHandlerDecorator:
|
class Argument:
|
||||||
def decorator(func: EventHandler) -> EventHandler:
|
def __init__(self, name: str, label: str = None, *, required: bool = False,
|
||||||
func.__mb_subcommands__ = []
|
matches: Optional[str] = None, parser: Optional[Callable[[str], Any]] = None,
|
||||||
|
pass_raw: bool = False) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.required = required
|
||||||
|
self.label = label or name
|
||||||
|
|
||||||
|
if not parser:
|
||||||
|
if matches:
|
||||||
|
regex = re.compile(matches)
|
||||||
|
|
||||||
|
def parser(val: str) -> Optional[Sequence[str]]:
|
||||||
|
match = regex.match(val)
|
||||||
|
return match.groups() if match else None
|
||||||
|
else:
|
||||||
|
def parser(val: str) -> str:
|
||||||
|
return val
|
||||||
|
|
||||||
|
if not pass_raw:
|
||||||
|
o_parser = parser
|
||||||
|
|
||||||
|
def parser(val: str) -> Any:
|
||||||
|
val = val.strip().split(" ")
|
||||||
|
return o_parser(val[0])
|
||||||
|
|
||||||
|
self.parser = parser
|
||||||
|
|
||||||
|
def match(self, val: str) -> Any:
|
||||||
|
return self.parser(val)
|
||||||
|
|
||||||
|
def __call__(self, func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
|
||||||
|
if not isinstance(func, CommandHandler):
|
||||||
|
func = CommandHandler(func)
|
||||||
|
func.__mb_arguments__.append(self)
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def new(name: PrefixType, *, help: str = None, event_type: EventType = EventType.ROOM_MESSAGE,
|
||||||
|
require_subcommand: bool = True) -> CommandHandlerDecorator:
|
||||||
|
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
|
||||||
|
if not isinstance(func, CommandHandler):
|
||||||
|
func = CommandHandler(func)
|
||||||
func.__mb_help__ = help
|
func.__mb_help__ = help
|
||||||
func.__mb_name__ = name or func.__name__
|
func.__mb_name__ = name or func.__name__
|
||||||
func.subcommand = _get_subcommand_decorator(func)
|
func.__mb_require_subcommand__ = require_subcommand
|
||||||
|
func.__mb_prefix__ = f"!{func.__mb_name__}"
|
||||||
|
func.__mb_event_type__ = event_type
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
PassiveCommandHandler = Callable[[MaubotMessageEvent, ...], Awaitable[None]]
|
def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None,
|
||||||
PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator",
|
parser: Optional[Callable[[str], Any]] = None) -> CommandHandlerDecorator:
|
||||||
Callable[[PassiveCommandHandler], PassiveCommandHandler])
|
return Argument(name, label, required=required, matches=matches, parser=parser)
|
||||||
|
|
||||||
|
|
||||||
def passive(regex: Union[str, Pattern], msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
|
def passive(regex: Union[str, Pattern], msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
|
||||||
field: Callable[[MaubotMessageEvent], str] = lambda event: event.content.body
|
field: Callable[[MaubotMessageEvent], str] = lambda event: event.content.body,
|
||||||
) -> PassiveCommandHandlerDecorator:
|
event_type: EventType = EventType.ROOM_MESSAGE) -> PassiveCommandHandlerDecorator:
|
||||||
if not isinstance(regex, Pattern):
|
if not isinstance(regex, Pattern):
|
||||||
regex = re.compile(regex)
|
regex = re.compile(regex)
|
||||||
|
|
||||||
def decorator(func: PassiveCommandHandler) -> PassiveCommandHandler:
|
def decorator(func: CommandHandlerFunc) -> CommandHandlerFunc:
|
||||||
|
@event.on(event_type)
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def replacement(event: MaubotMessageEvent) -> None:
|
async def replacement(self, evt: MaubotMessageEvent) -> None:
|
||||||
if event.sender == event.client.mxid:
|
if isinstance(self, MaubotMessageEvent):
|
||||||
|
evt = self
|
||||||
|
self = None
|
||||||
|
if evt.sender == evt.client.mxid:
|
||||||
return
|
return
|
||||||
elif msgtypes and event.content.msgtype not in msgtypes:
|
elif msgtypes and evt.content.msgtype not in msgtypes:
|
||||||
return
|
return
|
||||||
match = regex.match(field(event))
|
match = regex.match(field(evt))
|
||||||
if match:
|
if match:
|
||||||
await func(event, *list(match.groups()))
|
if self:
|
||||||
|
await func(self, evt, *list(match.groups()))
|
||||||
|
else:
|
||||||
|
await func(evt, *list(match.groups()))
|
||||||
|
|
||||||
return replacement
|
return replacement
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
class _Argument:
|
|
||||||
def __init__(self, name: str, required: bool, matches: Optional[str],
|
|
||||||
parser: Optional[Callable[[str], Any]]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def argument(name: str, *, required: bool = True, matches: Optional[str] = None,
|
|
||||||
parser: Optional[Callable[[str], Any]] = None) -> EventHandlerDecorator:
|
|
||||||
def decorator(func: EventHandler) -> EventHandler:
|
|
||||||
if not hasattr(func, "__mb_arguments__"):
|
|
||||||
func.__mb_arguments__ = []
|
|
||||||
func.__mb_arguments__.append(_Argument(name, required, matches, parser))
|
|
||||||
return func
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def vararg(func: EventHandler) -> EventHandler:
|
|
||||||
func.__mb_vararg__ = True
|
|
||||||
return func
|
|
||||||
|
@ -13,11 +13,9 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Callable, Union, NewType, Any, Tuple, Optional
|
from typing import Callable, Union, NewType
|
||||||
import functools
|
|
||||||
import re
|
|
||||||
|
|
||||||
from mautrix.types import EventType, Event, EventContent, MessageEvent, MessageEventContent
|
from mautrix.types import EventType
|
||||||
from mautrix.client import EventHandler
|
from mautrix.client import EventHandler
|
||||||
|
|
||||||
EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler])
|
EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler])
|
||||||
@ -25,93 +23,12 @@ EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler]
|
|||||||
|
|
||||||
def on(var: Union[EventType, EventHandler]) -> Union[EventHandlerDecorator, EventHandler]:
|
def on(var: Union[EventType, EventHandler]) -> Union[EventHandlerDecorator, EventHandler]:
|
||||||
def decorator(func: EventHandler) -> EventHandler:
|
def decorator(func: EventHandler) -> EventHandler:
|
||||||
@functools.wraps(func)
|
func.__mb_event_handler__ = True
|
||||||
async def wrapper(event: Event) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
wrapper.__mb_event_handler__ = True
|
|
||||||
if isinstance(var, EventType):
|
if isinstance(var, EventType):
|
||||||
wrapper.__mb_event_type__ = var
|
func.__mb_event_type__ = var
|
||||||
else:
|
else:
|
||||||
wrapper.__mb_event_type__ = EventType.ALL
|
func.__mb_event_type__ = EventType.ALL
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator if isinstance(var, EventType) else decorator(var)
|
|
||||||
|
|
||||||
|
|
||||||
class Field:
|
|
||||||
body: Callable[[MessageEventContent], str] = lambda content: content.body
|
|
||||||
msgtype: Callable[[MessageEventContent], str] = lambda content: content.msgtype
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_key(key: str) -> Tuple[str, Optional[str]]:
|
|
||||||
if '.' not in key:
|
|
||||||
return key, None
|
|
||||||
key, next_key = key.split('.', 1)
|
|
||||||
if len(key) > 0 and key[0] == "[":
|
|
||||||
end_index = next_key.index("]")
|
|
||||||
key = key[1:] + "." + next_key[:end_index]
|
|
||||||
next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
|
|
||||||
return key, next_key
|
|
||||||
|
|
||||||
|
|
||||||
def _recursive_get(data: EventContent, key: str) -> Any:
|
|
||||||
key, next_key = _parse_key(key)
|
|
||||||
if next_key is not None:
|
|
||||||
next_data = data.get(key, None)
|
|
||||||
if next_data is None:
|
|
||||||
return None
|
|
||||||
return _recursive_get(next_data, next_key)
|
|
||||||
return data.get(key, None)
|
|
||||||
|
|
||||||
|
|
||||||
def _find_content_field(content: EventContent, field: str) -> Any:
|
|
||||||
val = _recursive_get(content, field)
|
|
||||||
if not val and hasattr(content, "unrecognized_"):
|
|
||||||
val = _recursive_get(content.unrecognized_, field)
|
|
||||||
return val
|
|
||||||
|
|
||||||
|
|
||||||
def handle_own_events(func: EventHandler) -> EventHandler:
|
|
||||||
func.__mb_handle_own_events__ = True
|
|
||||||
|
|
||||||
|
|
||||||
def filter_content(field: Union[str, Callable[[EventContent], Any]], substr: str = None,
|
|
||||||
pattern: str = None, exact: bool = False):
|
|
||||||
if substr and pattern:
|
|
||||||
raise ValueError("You can only provide one of substr or pattern.")
|
|
||||||
elif not substr and not pattern:
|
|
||||||
raise ValueError("You must provide either substr or pattern.")
|
|
||||||
|
|
||||||
if not callable(field):
|
|
||||||
field = functools.partial(_find_content_field, field=field)
|
|
||||||
|
|
||||||
if substr:
|
|
||||||
def func(evt: MessageEvent) -> bool:
|
|
||||||
val = field(evt.content)
|
|
||||||
if val is None:
|
|
||||||
return False
|
|
||||||
elif substr in val:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
pattern = re.compile(pattern)
|
|
||||||
|
|
||||||
def func(evt: MessageEvent) -> bool:
|
|
||||||
val = field(evt.content)
|
|
||||||
if val is None:
|
|
||||||
return False
|
|
||||||
elif pattern.match(val):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return filter(func)
|
|
||||||
|
|
||||||
|
|
||||||
def filter(func: Callable[[MessageEvent], bool]) -> EventHandlerDecorator:
|
|
||||||
def decorator(func: EventHandler) -> EventHandler:
|
|
||||||
if not hasattr(func, "__mb_event_filters__"):
|
|
||||||
func.__mb_event_filters__ = []
|
|
||||||
func.__mb_event_filters__.append(func)
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator if isinstance(var, EventType) else decorator(var)
|
||||||
|
@ -51,7 +51,7 @@ class Plugin(ABC):
|
|||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
for key in dir(self):
|
for key in dir(self):
|
||||||
val = getattr(self, key)
|
val = getattr(self, key)
|
||||||
if hasattr(val, "__mb_event_handler__"):
|
if hasattr(val, "__mb_event_handler__") and val.__mb_event_handler__:
|
||||||
self._handlers_at_startup.append((val, val.__mb_event_type__))
|
self._handlers_at_startup.append((val, val.__mb_event_type__))
|
||||||
self.client.add_event_handler(val.__mb_event_type__, val)
|
self.client.add_event_handler(val.__mb_event_type__, val)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user