diff --git a/maubot/__init__.py b/maubot/__init__.py
index 9c3cbb8..5dd2035 100644
--- a/maubot/__init__.py
+++ b/maubot/__init__.py
@@ -1,3 +1,4 @@
from .plugin_base import Plugin
from .command_spec import CommandSpec, Command, PassiveCommand, Argument
from .event import FakeEvent as Event
+from .client import MaubotMatrixClient as Client
diff --git a/maubot/client.py b/maubot/client.py
index 539ca54..30cccff 100644
--- a/maubot/client.py
+++ b/maubot/client.py
@@ -13,33 +13,80 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Union, Callable
from aiohttp import ClientSession
+import asyncio
import logging
from mautrix import Client as MatrixClient
+from mautrix.client import EventHandler
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
- EventType)
+ EventType, MessageEvent)
+from .command_spec import ParsedCommand
from .db import DBClient
log = logging.getLogger("maubot.client")
+class MaubotMatrixClient(MatrixClient):
+ def __init__(self, maubot_client: 'Client', *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._maubot_client = maubot_client
+ self.command_handlers: Dict[str, List[EventHandler]] = {}
+ self.commands: List[ParsedCommand] = []
+
+ self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
+
+ async def _command_event_handler(self, evt: MessageEvent) -> None:
+ for command in self.commands:
+ if command.match(evt):
+ await self._trigger_command(command, evt)
+ return
+
+ async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
+ for handler in self.command_handlers.get(command.name, []):
+ await handler(evt)
+
+ def on(self, var: Union[EventHandler, EventType, str]
+ ) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
+ if isinstance(var, str):
+ def decorator(func: EventHandler) -> EventHandler:
+ self.add_command_handler(var, func)
+ return func
+
+ return decorator
+ return super().on(var)
+
+ def add_command_handler(self, command: str, handler: EventHandler) -> None:
+ self.command_handlers.setdefault(command, []).append(handler)
+
+ def remove_command_handler(self, command: str, handler: EventHandler) -> None:
+ try:
+ self.command_handlers[command].remove(handler)
+ except (KeyError, ValueError):
+ pass
+
+
class Client:
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
+ db_instance: DBClient
+ client: MaubotMatrixClient
+
def __init__(self, db_instance: DBClient) -> None:
- self.db_instance: DBClient = db_instance
+ self.db_instance = db_instance
self.cache[self.id] = self
- self.client: MatrixClient = MatrixClient(mxid=self.id,
- base_url=self.homeserver,
- token=self.access_token,
- client_session=self.http_client,
- log=log.getChild(self.id))
+ self.client = MaubotMatrixClient(maubot_client=self,
+ store=self.db_instance,
+ mxid=self.id,
+ base_url=self.homeserver,
+ token=self.access_token,
+ client_session=self.http_client,
+ log=log.getChild(self.id))
if self.autojoin:
- self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
+ self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
@classmethod
def get(cls, user_id: UserID) -> Optional['Client']:
@@ -103,9 +150,9 @@ class Client:
if value == self.db_instance.autojoin:
return
if value:
- self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
+ self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
else:
- self.client.remove_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
+ self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
self.db_instance.autojoin = value
@property
@@ -126,6 +173,6 @@ class Client:
# endregion
- async def handle_invite(self, evt: StateEvent) -> None:
+ async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id)
diff --git a/maubot/command_spec.py b/maubot/command_spec.py
index 8ac6e23..8a8dc6a 100644
--- a/maubot/command_spec.py
+++ b/maubot/command_spec.py
@@ -13,10 +13,11 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import List, Dict
+from typing import List, Dict, Pattern, Union, Tuple, Optional, Any
from attr import dataclass
+import re
-from mautrix.types import Event
+from mautrix.types import MessageEvent, MatchedCommand, MatchedPassiveCommand
from mautrix.client.api.types.util import SerializableAttrs
@@ -39,7 +40,103 @@ class PassiveCommand(SerializableAttrs['PassiveCommand']):
name: str
matches: str
match_against: str
- match_event: Event = None
+ match_event: MessageEvent = None
+
+
+class ParsedCommand:
+ name: str
+ is_passive: bool
+ arguments: List[str]
+ starts_with: str
+ matches: Pattern
+ match_against: str
+ match_event: MessageEvent
+
+ def __init__(self, command: Union[PassiveCommand, Command]) -> None:
+ if isinstance(command, PassiveCommand):
+ self._init_passive(command)
+ elif isinstance(command, Command):
+ self._init_active(command)
+ else:
+ raise ValueError("Command parameter must be a Command or a PassiveCommand.")
+
+ def _init_passive(self, command: PassiveCommand) -> None:
+ self.name = command.name
+ self.is_passive = True
+ self.match_against = command.match_against
+ self.matches = re.compile(command.matches)
+ self.match_event = command.match_event
+
+ def _init_active(self, command: Command) -> None:
+ self.name = command.syntax
+ self.is_passive = False
+
+ regex_builder = []
+ sw_builder = []
+ argument_encountered = False
+
+ for word in command.syntax.split(" "):
+ arg = command.arguments.get(word, None)
+ if arg is not None and len(word) > 0:
+ argument_encountered = True
+ regex = f"({arg.matches})" if arg.required else f"(?:{arg.matches})?"
+ self.arguments.append(word)
+ regex_builder.append(regex)
+ else:
+ if not argument_encountered:
+ sw_builder.append(word)
+ regex_builder.append(re.escape(word))
+ self.starts_with = "!" + " ".join(sw_builder)
+ self.matches = re.compile("^!" + " ".join(regex_builder) + "$")
+ self.match_against = "body"
+
+ def match(self, evt: MessageEvent) -> bool:
+ return self._match_passive(evt) if self.is_passive else self._match_active(evt)
+
+ @staticmethod
+ def _parse_key(key: str) -> Tuple[str, Optional[str]]:
+ if '.' not in key:
+ return key, None
+ key, next_key = key.split('.', 1)
+ if len(key) > 0 and key[0] == "[":
+ end_index = next_key.index("]")
+ key = key[1:] + "." + next_key[:end_index]
+ next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
+ return key, next_key
+
+ @classmethod
+ def _recursive_get(cls, data: Any, key: str) -> Any:
+ if not data:
+ return None
+ key, next_key = cls._parse_key(key)
+ if next_key is not None:
+ return cls._recursive_get(data[key], next_key)
+ return data[key]
+
+ def _match_passive(self, evt: MessageEvent) -> bool:
+ try:
+ match_against = self._recursive_get(evt.content, self.match_against)
+ except KeyError:
+ match_against = None
+ match_against = match_against or evt.content.body
+ matches = [[match.string[match.start():match.end()]] + list(match.groups())
+ for match in self.matches.finditer(match_against)]
+ if not matches:
+ return False
+ if evt.unsigned.passive_command is None:
+ evt.unsigned.passive_command = {}
+ evt.unsigned.passive_command[self.name] = MatchedPassiveCommand(captured=matches)
+ return True
+
+ def _match_active(self, evt: MessageEvent) -> bool:
+ if not evt.content.body.startswith(self.starts_with):
+ return False
+ match = self.matches.match(evt.content.body)
+ if not match:
+ return False
+ evt.content.command = MatchedCommand(matched=self.name,
+ arguments=dict(zip(self.arguments, match.groups())))
+ return True
@dataclass
@@ -50,3 +147,6 @@ class CommandSpec(SerializableAttrs['CommandSpec']):
def __add__(self, other: 'CommandSpec') -> 'CommandSpec':
return CommandSpec(commands=self.commands + other.commands,
passive_commands=self.passive_commands + other.passive_commands)
+
+ def parse(self) -> List[ParsedCommand]:
+ return [ParsedCommand(command) for command in self.commands + self.passive_commands]
diff --git a/maubot/db.py b/maubot/db.py
index 841dbe8..9c4ccc1 100644
--- a/maubot/db.py
+++ b/maubot/db.py
@@ -13,31 +13,39 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Type
from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator)
from sqlalchemy.orm import Query
from sqlalchemy.ext.declarative import declarative_base
import json
-from mautrix.types import JSON, UserID, FilterID, SyncToken, ContentURI
+from mautrix.types import UserID, FilterID, SyncToken, ContentURI
+from mautrix.client.api.types.util import Serializable
+from mautrix import ClientStore
+
+from .command_spec import CommandSpec
Base: declarative_base = declarative_base()
-class JSONEncodedDict(TypeDecorator):
- impl = Text
+def make_serializable_alchemy(serializable_type: Type[Serializable]):
+ class SerializableAlchemy(TypeDecorator):
+ impl = Text
- @property
- def python_type(self):
- return dict
+ @property
+ def python_type(self):
+ return serializable_type
- def process_literal_param(self, value, _):
- return json.dumps(value) if value is not None else None
+ def process_literal_param(self, value: Serializable, _) -> str:
+ return json.dumps(value.serialize()) if value is not None else None
- def process_bind_param(self, value, _):
- return json.dumps(value) if value is not None else None
+ def process_bind_param(self, value: Serializable, _) -> str:
+ return json.dumps(value.serialize()) if value is not None else None
- def process_result_value(self, value, _):
- return json.loads(value) if value is not None else None
+ def process_result_value(self, value: str, _) -> serializable_type:
+ return serializable_type.deserialize(json.loads(value)) if value is not None else None
+
+ return SerializableAlchemy
class DBPlugin(Base):
@@ -52,7 +60,7 @@ class DBPlugin(Base):
nullable=False)
-class DBClient(Base):
+class DBClient(ClientStore, Base):
query: Query
__tablename__ = "client"
@@ -74,10 +82,10 @@ class DBCommandSpec(Base):
query: Query
__tablename__ = "command_spec"
- owner: str = Column(String(255),
- ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"),
- primary_key=True)
+ plugin: str = Column(String(255),
+ ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"),
+ primary_key=True)
client: UserID = Column(String(255),
ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True)
- spec: JSON = Column(JSONEncodedDict, nullable=False)
+ spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False)
diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py
index 489352a..0f22c00 100644
--- a/maubot/plugin_base.py
+++ b/maubot/plugin_base.py
@@ -17,13 +17,17 @@ from typing import TYPE_CHECKING
from abc import ABC
if TYPE_CHECKING:
- from mautrix import Client as MatrixClient
+ from .client import MaubotMatrixClient
+ from .command_spec import CommandSpec
class Plugin(ABC):
- def __init__(self, client: 'MatrixClient') -> None:
+ def __init__(self, client: 'MaubotMatrixClient') -> None:
self.client = client
+ def set_command_spec(self, spec: 'CommandSpec') -> None:
+ pass
+
async def start(self) -> None:
pass
diff --git a/requirements.txt b/requirements.txt
index 318c4e2..f8671cd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ SQLAlchemy
alembic
commonmark
ruamel.yaml
+attrs
diff --git a/setup.py b/setup.py
index ca44f69..aaae281 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,7 @@ setuptools.setup(
"alembic>=1.0.0,<2",
"commonmark>=0.8.1,<1",
"ruamel.yaml>=0.15.35,<0.16",
+ "attrs>=18.2.0,<19",
],
classifiers=[