diff --git a/maubot/__init__.py b/maubot/__init__.py index 9c3cbb8..5dd2035 100644 --- a/maubot/__init__.py +++ b/maubot/__init__.py @@ -1,3 +1,4 @@ from .plugin_base import Plugin from .command_spec import CommandSpec, Command, PassiveCommand, Argument from .event import FakeEvent as Event +from .client import MaubotMatrixClient as Client diff --git a/maubot/client.py b/maubot/client.py index 539ca54..30cccff 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -13,33 +13,80 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Optional +from typing import Dict, List, Optional, Union, Callable from aiohttp import ClientSession +import asyncio import logging from mautrix import Client as MatrixClient +from mautrix.client import EventHandler from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership, - EventType) + EventType, MessageEvent) +from .command_spec import ParsedCommand from .db import DBClient log = logging.getLogger("maubot.client") +class MaubotMatrixClient(MatrixClient): + def __init__(self, maubot_client: 'Client', *args, **kwargs): + super().__init__(*args, **kwargs) + self._maubot_client = maubot_client + self.command_handlers: Dict[str, List[EventHandler]] = {} + self.commands: List[ParsedCommand] = [] + + self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE) + + async def _command_event_handler(self, evt: MessageEvent) -> None: + for command in self.commands: + if command.match(evt): + await self._trigger_command(command, evt) + return + + async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None: + for handler in self.command_handlers.get(command.name, []): + await handler(evt) + + def on(self, var: Union[EventHandler, EventType, str] + ) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]: + if isinstance(var, str): + def decorator(func: EventHandler) -> EventHandler: + self.add_command_handler(var, func) + return func + + return decorator + return super().on(var) + + def add_command_handler(self, command: str, handler: EventHandler) -> None: + self.command_handlers.setdefault(command, []).append(handler) + + def remove_command_handler(self, command: str, handler: EventHandler) -> None: + try: + self.command_handlers[command].remove(handler) + except (KeyError, ValueError): + pass + + class Client: cache: Dict[UserID, 'Client'] = {} http_client: ClientSession = None + db_instance: DBClient + client: MaubotMatrixClient + def __init__(self, db_instance: DBClient) -> None: - self.db_instance: DBClient = db_instance + self.db_instance = db_instance self.cache[self.id] = self - self.client: MatrixClient = MatrixClient(mxid=self.id, - base_url=self.homeserver, - token=self.access_token, - client_session=self.http_client, - log=log.getChild(self.id)) + self.client = MaubotMatrixClient(maubot_client=self, + store=self.db_instance, + mxid=self.id, + base_url=self.homeserver, + token=self.access_token, + client_session=self.http_client, + log=log.getChild(self.id)) if self.autojoin: - self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER) + self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER) @classmethod def get(cls, user_id: UserID) -> Optional['Client']: @@ -103,9 +150,9 @@ class Client: if value == self.db_instance.autojoin: return if value: - self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER) + self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER) else: - self.client.remove_event_handler(self.handle_invite, EventType.ROOM_MEMBER) + self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER) self.db_instance.autojoin = value @property @@ -126,6 +173,6 @@ class Client: # endregion - async def handle_invite(self, evt: StateEvent) -> None: + async def _handle_invite(self, evt: StateEvent) -> None: if evt.state_key == self.id and evt.content.membership == Membership.INVITE: await self.client.join_room_by_id(evt.room_id) diff --git a/maubot/command_spec.py b/maubot/command_spec.py index 8ac6e23..8a8dc6a 100644 --- a/maubot/command_spec.py +++ b/maubot/command_spec.py @@ -13,10 +13,11 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, Dict +from typing import List, Dict, Pattern, Union, Tuple, Optional, Any from attr import dataclass +import re -from mautrix.types import Event +from mautrix.types import MessageEvent, MatchedCommand, MatchedPassiveCommand from mautrix.client.api.types.util import SerializableAttrs @@ -39,7 +40,103 @@ class PassiveCommand(SerializableAttrs['PassiveCommand']): name: str matches: str match_against: str - match_event: Event = None + match_event: MessageEvent = None + + +class ParsedCommand: + name: str + is_passive: bool + arguments: List[str] + starts_with: str + matches: Pattern + match_against: str + match_event: MessageEvent + + def __init__(self, command: Union[PassiveCommand, Command]) -> None: + if isinstance(command, PassiveCommand): + self._init_passive(command) + elif isinstance(command, Command): + self._init_active(command) + else: + raise ValueError("Command parameter must be a Command or a PassiveCommand.") + + def _init_passive(self, command: PassiveCommand) -> None: + self.name = command.name + self.is_passive = True + self.match_against = command.match_against + self.matches = re.compile(command.matches) + self.match_event = command.match_event + + def _init_active(self, command: Command) -> None: + self.name = command.syntax + self.is_passive = False + + regex_builder = [] + sw_builder = [] + argument_encountered = False + + for word in command.syntax.split(" "): + arg = command.arguments.get(word, None) + if arg is not None and len(word) > 0: + argument_encountered = True + regex = f"({arg.matches})" if arg.required else f"(?:{arg.matches})?" + self.arguments.append(word) + regex_builder.append(regex) + else: + if not argument_encountered: + sw_builder.append(word) + regex_builder.append(re.escape(word)) + self.starts_with = "!" + " ".join(sw_builder) + self.matches = re.compile("^!" + " ".join(regex_builder) + "$") + self.match_against = "body" + + def match(self, evt: MessageEvent) -> bool: + return self._match_passive(evt) if self.is_passive else self._match_active(evt) + + @staticmethod + def _parse_key(key: str) -> Tuple[str, Optional[str]]: + if '.' not in key: + return key, None + key, next_key = key.split('.', 1) + if len(key) > 0 and key[0] == "[": + end_index = next_key.index("]") + key = key[1:] + "." + next_key[:end_index] + next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None + return key, next_key + + @classmethod + def _recursive_get(cls, data: Any, key: str) -> Any: + if not data: + return None + key, next_key = cls._parse_key(key) + if next_key is not None: + return cls._recursive_get(data[key], next_key) + return data[key] + + def _match_passive(self, evt: MessageEvent) -> bool: + try: + match_against = self._recursive_get(evt.content, self.match_against) + except KeyError: + match_against = None + match_against = match_against or evt.content.body + matches = [[match.string[match.start():match.end()]] + list(match.groups()) + for match in self.matches.finditer(match_against)] + if not matches: + return False + if evt.unsigned.passive_command is None: + evt.unsigned.passive_command = {} + evt.unsigned.passive_command[self.name] = MatchedPassiveCommand(captured=matches) + return True + + def _match_active(self, evt: MessageEvent) -> bool: + if not evt.content.body.startswith(self.starts_with): + return False + match = self.matches.match(evt.content.body) + if not match: + return False + evt.content.command = MatchedCommand(matched=self.name, + arguments=dict(zip(self.arguments, match.groups()))) + return True @dataclass @@ -50,3 +147,6 @@ class CommandSpec(SerializableAttrs['CommandSpec']): def __add__(self, other: 'CommandSpec') -> 'CommandSpec': return CommandSpec(commands=self.commands + other.commands, passive_commands=self.passive_commands + other.passive_commands) + + def parse(self) -> List[ParsedCommand]: + return [ParsedCommand(command) for command in self.commands + self.passive_commands] diff --git a/maubot/db.py b/maubot/db.py index 841dbe8..9c4ccc1 100644 --- a/maubot/db.py +++ b/maubot/db.py @@ -13,31 +13,39 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Type from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator) from sqlalchemy.orm import Query from sqlalchemy.ext.declarative import declarative_base import json -from mautrix.types import JSON, UserID, FilterID, SyncToken, ContentURI +from mautrix.types import UserID, FilterID, SyncToken, ContentURI +from mautrix.client.api.types.util import Serializable +from mautrix import ClientStore + +from .command_spec import CommandSpec Base: declarative_base = declarative_base() -class JSONEncodedDict(TypeDecorator): - impl = Text +def make_serializable_alchemy(serializable_type: Type[Serializable]): + class SerializableAlchemy(TypeDecorator): + impl = Text - @property - def python_type(self): - return dict + @property + def python_type(self): + return serializable_type - def process_literal_param(self, value, _): - return json.dumps(value) if value is not None else None + def process_literal_param(self, value: Serializable, _) -> str: + return json.dumps(value.serialize()) if value is not None else None - def process_bind_param(self, value, _): - return json.dumps(value) if value is not None else None + def process_bind_param(self, value: Serializable, _) -> str: + return json.dumps(value.serialize()) if value is not None else None - def process_result_value(self, value, _): - return json.loads(value) if value is not None else None + def process_result_value(self, value: str, _) -> serializable_type: + return serializable_type.deserialize(json.loads(value)) if value is not None else None + + return SerializableAlchemy class DBPlugin(Base): @@ -52,7 +60,7 @@ class DBPlugin(Base): nullable=False) -class DBClient(Base): +class DBClient(ClientStore, Base): query: Query __tablename__ = "client" @@ -74,10 +82,10 @@ class DBCommandSpec(Base): query: Query __tablename__ = "command_spec" - owner: str = Column(String(255), - ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"), - primary_key=True) + plugin: str = Column(String(255), + ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"), + primary_key=True) client: UserID = Column(String(255), ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"), primary_key=True) - spec: JSON = Column(JSONEncodedDict, nullable=False) + spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False) diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index 489352a..0f22c00 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -17,13 +17,17 @@ from typing import TYPE_CHECKING from abc import ABC if TYPE_CHECKING: - from mautrix import Client as MatrixClient + from .client import MaubotMatrixClient + from .command_spec import CommandSpec class Plugin(ABC): - def __init__(self, client: 'MatrixClient') -> None: + def __init__(self, client: 'MaubotMatrixClient') -> None: self.client = client + def set_command_spec(self, spec: 'CommandSpec') -> None: + pass + async def start(self) -> None: pass diff --git a/requirements.txt b/requirements.txt index 318c4e2..f8671cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ SQLAlchemy alembic commonmark ruamel.yaml +attrs diff --git a/setup.py b/setup.py index ca44f69..aaae281 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ setuptools.setup( "alembic>=1.0.0,<2", "commonmark>=0.8.1,<1", "ruamel.yaml>=0.15.35,<0.16", + "attrs>=18.2.0,<19", ], classifiers=[