diff --git a/maubot/__init__.py b/maubot/__init__.py index e829d9c..c3148ae 100644 --- a/maubot/__init__.py +++ b/maubot/__init__.py @@ -1,3 +1,3 @@ from .plugin_base import Plugin from .command_spec import CommandSpec, Command, PassiveCommand, Argument -from .client import MaubotMatrixClient as Client +from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent diff --git a/maubot/matrix.py b/maubot/matrix.py index 11b6eba..aa18afc 100644 --- a/maubot/matrix.py +++ b/maubot/matrix.py @@ -13,15 +13,51 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, List, Union, Callable +from typing import Dict, List, Union, Callable, Awaitable +import attr +import commonmark from mautrix import Client as MatrixClient from mautrix.client import EventHandler -from mautrix.types import EventType, MessageEvent +from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent, + MessageType, TextMessageEventContent, Format) from .command_spec import ParsedCommand, CommandSpec +class MaubotMessageEvent(MessageEvent): + _client: MatrixClient + + def __init__(self, base: MessageEvent, client: MatrixClient): + super().__init__(**{a.name.lstrip("_"): getattr(base, a.name) + for a in attr.fields(MessageEvent)}) + self._client = client + + def respond(self, content: Union[str, MessageEventContent], + event_type: EventType = EventType.ROOM_MESSAGE, + markdown: bool = True) -> Awaitable[EventID]: + if isinstance(content, str): + content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=content) + if markdown: + content.format = Format.HTML + content.formatted_body = commonmark.commonmark(content.body) + return self._client.send_message_event(self.room_id, event_type, content) + + def reply(self, content: Union[str, MessageEventContent], + event_type: EventType = EventType.ROOM_MESSAGE, + markdown: bool = True) -> Awaitable[EventID]: + if isinstance(content, str): + content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=content) + if markdown: + content.format = Format.HTML + content.formatted_body = commonmark.commonmark(content.body) + content.set_reply(self) + return self._client.send_message_event(self.room_id, event_type, content) + + def mark_read(self) -> Awaitable[None]: + return self._client.send_receipt(self.room_id, self.event_id, "m.read") + + class MaubotMatrixClient(MatrixClient): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -75,3 +111,16 @@ class MaubotMatrixClient(MatrixClient): self.command_handlers[command].remove(handler) except (KeyError, ValueError): pass + + def call_handlers(self, event: Event) -> Awaitable[None]: + if isinstance(event, MessageEvent): + if event.sender == self.mxid: + return + event = MaubotMessageEvent(event, self) + return super().call_handlers(event) + + async def get_event(self, room_id: RoomID, event_id: EventID) -> Event: + event = await super().get_event(room_id, event_id) + if isinstance(event, MessageEvent): + return MaubotMessageEvent(event, self) + return event diff --git a/maubot/plugin.py b/maubot/plugin.py index 850a45e..691bf1e 100644 --- a/maubot/plugin.py +++ b/maubot/plugin.py @@ -45,22 +45,27 @@ class PluginInstance: self.loader = PluginLoader.find(self.type) except KeyError: self.log.error(f"Failed to find loader for type {self.type}") - self.db_instance.enabled = False + self.enabled = False return self.client = Client.get(self.primary_user) if not self.client: self.log.error(f"Failed to get client for user {self.primary_user}") - self.db_instance.enabled = False + self.enabled = False + self.log.debug("Plugin instance dependencies loaded") async def start(self) -> None: - self.log.debug(f"Starting...") + if not self.enabled: + self.log.warn(f"Plugin disabled, not starting.") + return cls = self.loader.load() self.plugin = cls(self.client.client, self.id, self.log) self.loader.references |= {self} await self.plugin.start() + self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} " + f"with user {self.client.id}") async def stop(self) -> None: - self.log.debug("Stopping...") + self.log.debug("Stopping plugin instance...") self.loader.references -= {self} await self.plugin.stop() self.plugin = None