diff --git a/maubot/cli/commands/build.py b/maubot/cli/commands/build.py index 0f3201c..823cb2e 100644 --- a/maubot/cli/commands/build.py +++ b/maubot/cli/commands/build.py @@ -118,6 +118,8 @@ def upload_plugin(output: Union[str, IO]) -> None: default=False) def build(path: str, output: str, upload: bool) -> None: meta = read_meta(path) + if not meta: + return if output or not upload: output = read_output_path(output, meta) if not output: diff --git a/maubot/client.py b/maubot/client.py index 5b91d8c..1ec6d4f 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -55,7 +55,7 @@ class Client: token=self.access_token, client_session=self.http_client, log=self.log, loop=self.loop, store=self.db_instance) if self.autojoin: - self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER) + self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite) async def start(self, try_n: Optional[int] = 0) -> None: try: @@ -260,9 +260,9 @@ class Client: if value == self.db_instance.autojoin: return if value: - self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER) + self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite) else: - self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER) + self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite) self.db_instance.autojoin = value @property diff --git a/maubot/config.py b/maubot/config.py index ab3e080..b36a5c5 100644 --- a/maubot/config.py +++ b/maubot/config.py @@ -29,7 +29,8 @@ class Config(BaseFileConfig): return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) def do_update(self, helper: ConfigUpdateHelper) -> None: - base, copy, _ = helper + base = helper.base + copy = helper.copy copy("database") copy("plugin_directories.upload") copy("plugin_directories.load") diff --git a/maubot/handlers/command.py b/maubot/handlers/command.py index c2e8199..5e1b894 100644 --- a/maubot/handlers/command.py +++ b/maubot/handlers/command.py @@ -13,89 +13,196 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any +from typing import Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, Dict import functools import re -from mautrix.client import EventHandler -from mautrix.types import MessageType +from mautrix.types import MessageType, EventType from ..matrix import MaubotMessageEvent -from .event import EventHandlerDecorator +from . import event -PrefixType = Union[str, Callable[[], str]] -CommandDecorator = Callable[[PrefixType, str], EventHandlerDecorator] +PrefixType = Optional[Union[str, Callable[[], str]]] +CommandHandlerFunc = NewType("CommandHandlerFunc", + Callable[[MaubotMessageEvent, Any], Awaitable[Any]]) +CommandHandlerDecorator = NewType("CommandHandlerDecorator", + Callable[[Union['CommandHandler', CommandHandlerFunc]], + 'CommandHandler']) +PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator", + Callable[[CommandHandlerFunc], CommandHandlerFunc]) -def _get_subcommand_decorator(parent: EventHandler) -> CommandDecorator: - def subcommand(name: PrefixType, help: str = None) -> EventHandlerDecorator: - cmd_decorator = new(name=f"{parent.__mb_name__} {name}", help=help) +class CommandHandler: + def __init__(self, func: CommandHandlerFunc) -> None: + self.__mb_func__: CommandHandlerFunc = func + self.__mb_subcommands__: Dict[str, CommandHandler] = {} + self.__mb_arguments__: List[Argument] = [] + self.__mb_help__: str = None + self.__mb_name__: str = None + self.__mb_prefix__: str = None + self.__mb_require_subcommand__: bool = True + self.__mb_event_handler__: bool = True + self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE + self.__class_instance: Any = None - def decorator(func: EventHandler) -> EventHandler: - func = cmd_decorator(func) - parent.__mb_subcommands__.append(func) + async def __call__(self, evt: MaubotMessageEvent, *, + _existing_args: Dict[str, Any] = None) -> Any: + body = evt.content.body + if evt.sender == evt.client.mxid or not body.startswith(self.__mb_prefix__): + return + call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {} + remaining_val = body[len(self.__mb_prefix__) + 1:] + # TODO update remaining_val somehow + for arg in self.__mb_arguments__: + try: + call_args[arg.name] = arg.match(remaining_val) + if arg.required and not call_args[arg.name]: + raise ValueError("Argument required") + except ArgumentSyntaxError as e: + await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else "")) + return + except ValueError as e: + await evt.reply(self.__mb_usage__) + return + + if len(self.__mb_subcommands__) > 0: + split = remaining_val.split(" ") if len(remaining_val) > 0 else [] + try: + subcommand = self.__mb_subcommands__[split[0]] + return await subcommand(evt, _existing_args=call_args) + except (KeyError, IndexError): + if self.__mb_require_subcommand__: + await evt.reply(self.__mb_full_help__) + return + return (await self.__mb_func__(self.__class_instance, evt, **call_args) + if self.__class_instance + else await self.__mb_func__(evt, **call_args)) + + def __get__(self, instance, instancetype): + self.__class_instance = instance + return self + + @property + def __mb_full_help__(self) -> str: + basic = self.__mb_usage__ + usage = f"{basic} [...]\n\n" + usage += "\n".join(f"* {cmd.__mb_name__} {cmd.__mb_usage_args__} - {cmd.__mb_help__}" + for cmd in self.__mb_subcommands__.values()) + return usage + + @property + def __mb_usage_args__(self) -> str: + return " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]" + for arg in self.__mb_arguments__) + + @property + def __mb_usage__(self) -> str: + return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}" + + def subcommand(self, name: PrefixType = None, help: str = None + ) -> CommandHandlerDecorator: + def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: + if not isinstance(func, CommandHandler): + func = CommandHandler(func) + func.__mb_name__ = name or func.__name__ + func.__mb_prefix__ = f"{self.__mb_prefix__} {func.__mb_name__}" + func.__mb_help__ = help + func.__mb_event_handler__ = False + self.__mb_subcommands__[func.__mb_name__] = func return func return decorator - return subcommand + +class ArgumentSyntaxError(ValueError): + def __init__(self, message: str, show_usage: bool = True) -> None: + super().__init__(message) + self.message = message + self.show_usage = show_usage -def new(name: Union[str, Callable[[], str]], help: str = None) -> EventHandlerDecorator: - def decorator(func: EventHandler) -> EventHandler: - func.__mb_subcommands__ = [] +class Argument: + def __init__(self, name: str, label: str = None, *, required: bool = False, + matches: Optional[str] = None, parser: Optional[Callable[[str], Any]] = None, + pass_raw: bool = False) -> None: + self.name = name + self.required = required + self.label = label or name + + if not parser: + if matches: + regex = re.compile(matches) + + def parser(val: str) -> Optional[Sequence[str]]: + match = regex.match(val) + return match.groups() if match else None + else: + def parser(val: str) -> str: + return val + + if not pass_raw: + o_parser = parser + + def parser(val: str) -> Any: + val = val.strip().split(" ") + return o_parser(val[0]) + + self.parser = parser + + def match(self, val: str) -> Any: + return self.parser(val) + + def __call__(self, func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: + if not isinstance(func, CommandHandler): + func = CommandHandler(func) + func.__mb_arguments__.append(self) + return func + + +def new(name: PrefixType, *, help: str = None, event_type: EventType = EventType.ROOM_MESSAGE, + require_subcommand: bool = True) -> CommandHandlerDecorator: + def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: + if not isinstance(func, CommandHandler): + func = CommandHandler(func) func.__mb_help__ = help func.__mb_name__ = name or func.__name__ - func.subcommand = _get_subcommand_decorator(func) + func.__mb_require_subcommand__ = require_subcommand + func.__mb_prefix__ = f"!{func.__mb_name__}" + func.__mb_event_type__ = event_type return func return decorator -PassiveCommandHandler = Callable[[MaubotMessageEvent, ...], Awaitable[None]] -PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator", - Callable[[PassiveCommandHandler], PassiveCommandHandler]) +def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None, + parser: Optional[Callable[[str], Any]] = None) -> CommandHandlerDecorator: + return Argument(name, label, required=required, matches=matches, parser=parser) def passive(regex: Union[str, Pattern], msgtypes: Sequence[MessageType] = (MessageType.TEXT,), - field: Callable[[MaubotMessageEvent], str] = lambda event: event.content.body - ) -> PassiveCommandHandlerDecorator: + field: Callable[[MaubotMessageEvent], str] = lambda event: event.content.body, + event_type: EventType = EventType.ROOM_MESSAGE) -> PassiveCommandHandlerDecorator: if not isinstance(regex, Pattern): regex = re.compile(regex) - def decorator(func: PassiveCommandHandler) -> PassiveCommandHandler: + def decorator(func: CommandHandlerFunc) -> CommandHandlerFunc: + @event.on(event_type) @functools.wraps(func) - async def replacement(event: MaubotMessageEvent) -> None: - if event.sender == event.client.mxid: + async def replacement(self, evt: MaubotMessageEvent) -> None: + if isinstance(self, MaubotMessageEvent): + evt = self + self = None + if evt.sender == evt.client.mxid: return - elif msgtypes and event.content.msgtype not in msgtypes: + elif msgtypes and evt.content.msgtype not in msgtypes: return - match = regex.match(field(event)) + match = regex.match(field(evt)) if match: - await func(event, *list(match.groups())) + if self: + await func(self, evt, *list(match.groups())) + else: + await func(evt, *list(match.groups())) return replacement return decorator - - -class _Argument: - def __init__(self, name: str, required: bool, matches: Optional[str], - parser: Optional[Callable[[str], Any]]) -> None: - pass - - -def argument(name: str, *, required: bool = True, matches: Optional[str] = None, - parser: Optional[Callable[[str], Any]] = None) -> EventHandlerDecorator: - def decorator(func: EventHandler) -> EventHandler: - if not hasattr(func, "__mb_arguments__"): - func.__mb_arguments__ = [] - func.__mb_arguments__.append(_Argument(name, required, matches, parser)) - return func - - return decorator - - -def vararg(func: EventHandler) -> EventHandler: - func.__mb_vararg__ = True - return func diff --git a/maubot/handlers/event.py b/maubot/handlers/event.py index 7bec17e..d6bde8e 100644 --- a/maubot/handlers/event.py +++ b/maubot/handlers/event.py @@ -13,11 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Callable, Union, NewType, Any, Tuple, Optional -import functools -import re +from typing import Callable, Union, NewType -from mautrix.types import EventType, Event, EventContent, MessageEvent, MessageEventContent +from mautrix.types import EventType from mautrix.client import EventHandler EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler]) @@ -25,93 +23,12 @@ EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler] def on(var: Union[EventType, EventHandler]) -> Union[EventHandlerDecorator, EventHandler]: def decorator(func: EventHandler) -> EventHandler: - @functools.wraps(func) - async def wrapper(event: Event) -> None: - pass - - wrapper.__mb_event_handler__ = True + func.__mb_event_handler__ = True if isinstance(var, EventType): - wrapper.__mb_event_type__ = var + func.__mb_event_type__ = var else: - wrapper.__mb_event_type__ = EventType.ALL + func.__mb_event_type__ = EventType.ALL - return wrapper - - return decorator if isinstance(var, EventType) else decorator(var) - - -class Field: - body: Callable[[MessageEventContent], str] = lambda content: content.body - msgtype: Callable[[MessageEventContent], str] = lambda content: content.msgtype - - -def _parse_key(key: str) -> Tuple[str, Optional[str]]: - if '.' not in key: - return key, None - key, next_key = key.split('.', 1) - if len(key) > 0 and key[0] == "[": - end_index = next_key.index("]") - key = key[1:] + "." + next_key[:end_index] - next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None - return key, next_key - - -def _recursive_get(data: EventContent, key: str) -> Any: - key, next_key = _parse_key(key) - if next_key is not None: - next_data = data.get(key, None) - if next_data is None: - return None - return _recursive_get(next_data, next_key) - return data.get(key, None) - - -def _find_content_field(content: EventContent, field: str) -> Any: - val = _recursive_get(content, field) - if not val and hasattr(content, "unrecognized_"): - val = _recursive_get(content.unrecognized_, field) - return val - - -def handle_own_events(func: EventHandler) -> EventHandler: - func.__mb_handle_own_events__ = True - - -def filter_content(field: Union[str, Callable[[EventContent], Any]], substr: str = None, - pattern: str = None, exact: bool = False): - if substr and pattern: - raise ValueError("You can only provide one of substr or pattern.") - elif not substr and not pattern: - raise ValueError("You must provide either substr or pattern.") - - if not callable(field): - field = functools.partial(_find_content_field, field=field) - - if substr: - def func(evt: MessageEvent) -> bool: - val = field(evt.content) - if val is None: - return False - elif substr in val: - return True - else: - pattern = re.compile(pattern) - - def func(evt: MessageEvent) -> bool: - val = field(evt.content) - if val is None: - return False - elif pattern.match(val): - return True - - return filter(func) - - -def filter(func: Callable[[MessageEvent], bool]) -> EventHandlerDecorator: - def decorator(func: EventHandler) -> EventHandler: - if not hasattr(func, "__mb_event_filters__"): - func.__mb_event_filters__ = [] - func.__mb_event_filters__.append(func) return func - return decorator + return decorator if isinstance(var, EventType) else decorator(var) diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index 2835b66..c720cc5 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -51,7 +51,7 @@ class Plugin(ABC): async def start(self) -> None: for key in dir(self): val = getattr(self, key) - if hasattr(val, "__mb_event_handler__"): + if hasattr(val, "__mb_event_handler__") and val.__mb_event_handler__: self._handlers_at_startup.append((val, val.__mb_event_type__)) self.client.add_event_handler(val.__mb_event_type__, val)