diff --git a/maubot/__init__.py b/maubot/__init__.py index 366b951..3e2ffac 100644 --- a/maubot/__init__.py +++ b/maubot/__init__.py @@ -1,3 +1,2 @@ from .plugin_base import Plugin from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent -from .handlers import event, command diff --git a/maubot/handlers/command.py b/maubot/handlers/command.py index 2dbc58d..c2e8199 100644 --- a/maubot/handlers/command.py +++ b/maubot/handlers/command.py @@ -13,3 +13,89 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any +import functools +import re + +from mautrix.client import EventHandler +from mautrix.types import MessageType + +from ..matrix import MaubotMessageEvent +from .event import EventHandlerDecorator + +PrefixType = Union[str, Callable[[], str]] +CommandDecorator = Callable[[PrefixType, str], EventHandlerDecorator] + + +def _get_subcommand_decorator(parent: EventHandler) -> CommandDecorator: + def subcommand(name: PrefixType, help: str = None) -> EventHandlerDecorator: + cmd_decorator = new(name=f"{parent.__mb_name__} {name}", help=help) + + def decorator(func: EventHandler) -> EventHandler: + func = cmd_decorator(func) + parent.__mb_subcommands__.append(func) + return func + + return decorator + + return subcommand + + +def new(name: Union[str, Callable[[], str]], help: str = None) -> EventHandlerDecorator: + def decorator(func: EventHandler) -> EventHandler: + func.__mb_subcommands__ = [] + func.__mb_help__ = help + func.__mb_name__ = name or func.__name__ + func.subcommand = _get_subcommand_decorator(func) + return func + + return decorator + + +PassiveCommandHandler = Callable[[MaubotMessageEvent, ...], Awaitable[None]] +PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator", + Callable[[PassiveCommandHandler], PassiveCommandHandler]) + + +def passive(regex: Union[str, Pattern], msgtypes: Sequence[MessageType] = (MessageType.TEXT,), + field: Callable[[MaubotMessageEvent], str] = lambda event: event.content.body + ) -> PassiveCommandHandlerDecorator: + if not isinstance(regex, Pattern): + regex = re.compile(regex) + + def decorator(func: PassiveCommandHandler) -> PassiveCommandHandler: + @functools.wraps(func) + async def replacement(event: MaubotMessageEvent) -> None: + if event.sender == event.client.mxid: + return + elif msgtypes and event.content.msgtype not in msgtypes: + return + match = regex.match(field(event)) + if match: + await func(event, *list(match.groups())) + + return replacement + + return decorator + + +class _Argument: + def __init__(self, name: str, required: bool, matches: Optional[str], + parser: Optional[Callable[[str], Any]]) -> None: + pass + + +def argument(name: str, *, required: bool = True, matches: Optional[str] = None, + parser: Optional[Callable[[str], Any]] = None) -> EventHandlerDecorator: + def decorator(func: EventHandler) -> EventHandler: + if not hasattr(func, "__mb_arguments__"): + func.__mb_arguments__ = [] + func.__mb_arguments__.append(_Argument(name, required, matches, parser)) + return func + + return decorator + + +def vararg(func: EventHandler) -> EventHandler: + func.__mb_vararg__ = True + return func diff --git a/maubot/handlers/event.py b/maubot/handlers/event.py index 2562bf3..7bec17e 100644 --- a/maubot/handlers/event.py +++ b/maubot/handlers/event.py @@ -23,20 +23,21 @@ from mautrix.client import EventHandler EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler]) -def handler(var: Union[EventType, EventHandler]) -> Union[EventHandlerDecorator, EventHandler]: +def on(var: Union[EventType, EventHandler]) -> Union[EventHandlerDecorator, EventHandler]: def decorator(func: EventHandler) -> EventHandler: - func.__mb_event_handler__ = True + @functools.wraps(func) + async def wrapper(event: Event) -> None: + pass + + wrapper.__mb_event_handler__ = True if isinstance(var, EventType): - func.__mb_event_type__ = var + wrapper.__mb_event_type__ = var else: - func.__mb_event_type__ = EventType.ALL + wrapper.__mb_event_type__ = EventType.ALL - return func + return wrapper - if isinstance(var, EventType): - return decorator - else: - decorator(var) + return decorator if isinstance(var, EventType) else decorator(var) class Field: diff --git a/maubot/matrix.py b/maubot/matrix.py index 2969c9e..1cb2d6e 100644 --- a/maubot/matrix.py +++ b/maubot/matrix.py @@ -21,7 +21,7 @@ import attr from mautrix import Client as MatrixClient from mautrix.util.formatter import parse_html from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent, - MessageType, TextMessageEventContent, Format, RelatesTo) + MessageType, TextMessageEventContent, Format, RelatesTo, StateEvent) class EscapeHTML(Extension): @@ -39,12 +39,12 @@ def parse_markdown(markdown: str, allow_html: bool = False) -> Tuple[str, str]: class MaubotMessageEvent(MessageEvent): - _client: MatrixClient + client: MatrixClient def __init__(self, base: MessageEvent, client: MatrixClient): super().__init__(**{a.name.lstrip("_"): getattr(base, a.name) for a in attr.fields(MessageEvent)}) - self._client = client + self.client = client def respond(self, content: Union[str, MessageEventContent], event_type: EventType = EventType.ROOM_MESSAGE, @@ -56,7 +56,7 @@ class MaubotMessageEvent(MessageEvent): content.body, content.formatted_body = parse_markdown(content.body) if reply: content.set_reply(self) - return self._client.send_message_event(self.room_id, event_type, content) + return self.client.send_message_event(self.room_id, event_type, content) def reply(self, content: Union[str, MessageEventContent], event_type: EventType = EventType.ROOM_MESSAGE, @@ -64,7 +64,7 @@ class MaubotMessageEvent(MessageEvent): return self.respond(content, event_type, markdown, reply=True) def mark_read(self) -> Awaitable[None]: - return self._client.send_receipt(self.room_id, self.event_id, "m.read") + return self.client.send_receipt(self.room_id, self.event_id, "m.read") class MaubotMatrixClient(MatrixClient): @@ -79,10 +79,14 @@ class MaubotMatrixClient(MatrixClient): async def call_handlers(self, event: Event) -> None: if isinstance(event, MessageEvent): event = MaubotMessageEvent(event, self) + else: + event.client = self return await super().call_handlers(event) async def get_event(self, room_id: RoomID, event_id: EventID) -> Event: event = await super().get_event(room_id, event_id) if isinstance(event, MessageEvent): return MaubotMessageEvent(event, self) + else: + event.client = self return event diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index d41ae9c..2835b66 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -52,18 +52,8 @@ class Plugin(ABC): for key in dir(self): val = getattr(self, key) if hasattr(val, "__mb_event_handler__"): - handle_own_events = hasattr(val, "__mb_handle_own_events__") - - @functools.wraps(val) - async def handler(event: Event) -> None: - if not handle_own_events and getattr(event, "sender", "") == self.client.mxid: - return - for filter in val.__mb_event_filters__: - if not filter(event): - return - await val(event) - self._handlers_at_startup.append((handler, val.__mb_event_type__)) - self.client.add_event_handler(val.__mb_event_type__, handler) + self._handlers_at_startup.append((val, val.__mb_event_type__)) + self.client.add_event_handler(val.__mb_event_type__, val) async def stop(self) -> None: for func, event_type in self._handlers_at_startup: