diff --git a/example-config.yaml b/example-config.yaml
index a951ae9..3920e3e 100644
--- a/example-config.yaml
+++ b/example-config.yaml
@@ -10,8 +10,9 @@ plugin_directories:
- ./plugins
server:
- # The IP:port to listen to.
- listen: 0.0.0.0:29316
+ # The IP and port to listen to.
+ hostname: 0.0.0.0
+ port: 29316
# The base management API path.
base_path: /_matrix/maubot
# The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1.
diff --git a/maubot/__main__.py b/maubot/__main__.py
index 2b97589..76dc72a 100644
--- a/maubot/__main__.py
+++ b/maubot/__main__.py
@@ -17,9 +17,14 @@ from sqlalchemy import orm
import sqlalchemy as sql
import logging.config
import argparse
+import asyncio
import copy
+import sys
from .config import Config
+from .db import Base, init as init_db
+from .server import MaubotServer
+from .client import Client, init as init_client
from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
@@ -36,7 +41,23 @@ logging.config.dictConfig(copy.deepcopy(config["logging"]))
log = logging.getLogger("maubot")
log.debug(f"Initializing maubot {__version__}")
-db_engine = sql.create_engine(config["database"])
+db_engine: sql.engine.Engine = sql.create_engine(config["database"])
db_factory = orm.sessionmaker(bind=db_engine)
db_session = orm.scoping.scoped_session(db_factory)
Base.metadata.bind=db_engine
+
+loop = asyncio.get_event_loop()
+
+init_db(db_session)
+init_client(loop)
+server = MaubotServer(config, loop)
+
+try:
+ loop.run_until_complete(server.start())
+ loop.run_forever()
+except KeyboardInterrupt:
+ log.debug("Keyboard interrupt received, stopping...")
+ for client in Client.cache.values():
+ client.stop()
+ loop.run_until_complete(server.stop())
+ sys.exit(0)
diff --git a/maubot/client.py b/maubot/client.py
index 30cccff..f4ed9c3 100644
--- a/maubot/client.py
+++ b/maubot/client.py
@@ -13,62 +13,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict, List, Optional, Union, Callable
+from typing import Dict, List, Optional
from aiohttp import ClientSession
import asyncio
import logging
-from mautrix import Client as MatrixClient
-from mautrix.client import EventHandler
-from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
- EventType, MessageEvent)
+from mautrix.types import UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership, EventType
-from .command_spec import ParsedCommand
from .db import DBClient
+from .matrix import MaubotMatrixClient
log = logging.getLogger("maubot.client")
-class MaubotMatrixClient(MatrixClient):
- def __init__(self, maubot_client: 'Client', *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._maubot_client = maubot_client
- self.command_handlers: Dict[str, List[EventHandler]] = {}
- self.commands: List[ParsedCommand] = []
-
- self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
-
- async def _command_event_handler(self, evt: MessageEvent) -> None:
- for command in self.commands:
- if command.match(evt):
- await self._trigger_command(command, evt)
- return
-
- async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
- for handler in self.command_handlers.get(command.name, []):
- await handler(evt)
-
- def on(self, var: Union[EventHandler, EventType, str]
- ) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
- if isinstance(var, str):
- def decorator(func: EventHandler) -> EventHandler:
- self.add_command_handler(var, func)
- return func
-
- return decorator
- return super().on(var)
-
- def add_command_handler(self, command: str, handler: EventHandler) -> None:
- self.command_handlers.setdefault(command, []).append(handler)
-
- def remove_command_handler(self, command: str, handler: EventHandler) -> None:
- try:
- self.command_handlers[command].remove(handler)
- except (KeyError, ValueError):
- pass
-
-
class Client:
+ loop: asyncio.AbstractEventLoop
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
@@ -78,26 +37,33 @@ class Client:
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
self.cache[self.id] = self
- self.client = MaubotMatrixClient(maubot_client=self,
- store=self.db_instance,
- mxid=self.id,
- base_url=self.homeserver,
- token=self.access_token,
- client_session=self.http_client,
+ self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance,
+ mxid=self.id, base_url=self.homeserver,
+ token=self.access_token, client_session=self.http_client,
log=log.getChild(self.id))
if self.autojoin:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
+ def start(self) -> None:
+ asyncio.ensure_future(self.client.start(), loop=self.loop)
+
+ def stop(self) -> None:
+ self.client.stop()
+
@classmethod
- def get(cls, user_id: UserID) -> Optional['Client']:
+ def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
try:
return cls.cache[user_id]
except KeyError:
- db_instance = DBClient.query.get(user_id)
+ db_instance = db_instance or DBClient.query.get(user_id)
if not db_instance:
return None
return Client(db_instance)
+ @classmethod
+ def all(cls) -> List['Client']:
+ return [cls.get(user.id, user) for user in DBClient.query.all()]
+
# region Properties
@property
@@ -176,3 +142,9 @@ class Client:
async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id)
+
+
+def init(loop: asyncio.AbstractEventLoop) -> None:
+ Client.loop = loop
+ for client in Client.all():
+ client.start()
diff --git a/maubot/db.py b/maubot/db.py
index 9c4ccc1..fd4e4cc 100644
--- a/maubot/db.py
+++ b/maubot/db.py
@@ -15,7 +15,7 @@
# along with this program. If not, see .
from typing import Type
from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator)
-from sqlalchemy.orm import Query
+from sqlalchemy.orm import Query, scoped_session
from sqlalchemy.ext.declarative import declarative_base
import json
@@ -89,3 +89,9 @@ class DBCommandSpec(Base):
ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True)
spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False)
+
+
+def init(session: scoped_session) -> None:
+ DBPlugin.query = session.query_property()
+ DBClient.query = session.query_property()
+ DBCommandSpec.query = session.query_property()
diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py
index e7b323d..71c6ce3 100644
--- a/maubot/loader/abc.py
+++ b/maubot/loader/abc.py
@@ -13,14 +13,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import TypeVar, Type
+from typing import TypeVar, Type, Dict
from abc import ABC, abstractmethod
+
from ..plugin_base import Plugin
PluginClass = TypeVar("PluginClass", bound=Plugin)
class PluginLoader(ABC):
+ id_cache: Dict[str, 'PluginLoader'] = {}
+
id: str
version: str
diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py
index bb44979..a69ccec 100644
--- a/maubot/loader/zip.py
+++ b/maubot/loader/zip.py
@@ -29,7 +29,6 @@ class MaubotZipImportError(Exception):
class ZippedPluginLoader(PluginLoader):
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
- id_cache: Dict[str, 'ZippedPluginLoader'] = {}
path: str
id: str
diff --git a/maubot/matrix.py b/maubot/matrix.py
new file mode 100644
index 0000000..11b6eba
--- /dev/null
+++ b/maubot/matrix.py
@@ -0,0 +1,77 @@
+# maubot - A plugin-based Matrix bot system.
+# Copyright (C) 2018 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from typing import Dict, List, Union, Callable
+
+from mautrix import Client as MatrixClient
+from mautrix.client import EventHandler
+from mautrix.types import EventType, MessageEvent
+
+from .command_spec import ParsedCommand, CommandSpec
+
+
+class MaubotMatrixClient(MatrixClient):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.command_handlers: Dict[str, List[EventHandler]] = {}
+ self.commands: List[ParsedCommand] = []
+ self.command_specs: Dict[str, CommandSpec] = {}
+
+ self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
+
+ def set_command_spec(self, plugin_id: str, spec: CommandSpec) -> None:
+ self.command_specs[plugin_id] = spec
+ self._reparse_command_specs()
+
+ def _reparse_command_specs(self) -> None:
+ self.commands = [parsed_command
+ for spec in self.command_specs.values()
+ for parsed_command in spec.parse()]
+
+ def remove_command_spec(self, plugin_id: str) -> None:
+ try:
+ del self.command_specs[plugin_id]
+ self._reparse_command_specs()
+ except KeyError:
+ pass
+
+ async def _command_event_handler(self, evt: MessageEvent) -> None:
+ for command in self.commands:
+ if command.match(evt):
+ await self._trigger_command(command, evt)
+ return
+
+ async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
+ for handler in self.command_handlers.get(command.name, []):
+ await handler(evt)
+
+ def on(self, var: Union[EventHandler, EventType, str]
+ ) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
+ if isinstance(var, str):
+ def decorator(func: EventHandler) -> EventHandler:
+ self.add_command_handler(var, func)
+ return func
+
+ return decorator
+ return super().on(var)
+
+ def add_command_handler(self, command: str, handler: EventHandler) -> None:
+ self.command_handlers.setdefault(command, []).append(handler)
+
+ def remove_command_handler(self, command: str, handler: EventHandler) -> None:
+ try:
+ self.command_handlers[command].remove(handler)
+ except (KeyError, ValueError):
+ pass
diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py
index 0f22c00..69dedf5 100644
--- a/maubot/plugin_base.py
+++ b/maubot/plugin_base.py
@@ -22,11 +22,12 @@ if TYPE_CHECKING:
class Plugin(ABC):
- def __init__(self, client: 'MaubotMatrixClient') -> None:
+ def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str) -> None:
self.client = client
+ self.id = plugin_instance_id
def set_command_spec(self, spec: 'CommandSpec') -> None:
- pass
+ self.client.set_command_spec(self.id, spec)
async def start(self) -> None:
pass
diff --git a/maubot/server.py b/maubot/server.py
new file mode 100644
index 0000000..55a4b4a
--- /dev/null
+++ b/maubot/server.py
@@ -0,0 +1,54 @@
+# maubot - A plugin-based Matrix bot system.
+# Copyright (C) 2018 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from aiohttp import web
+import asyncio
+
+from mautrix.api import PathBuilder
+
+from .config import Config
+from .__meta__ import __version__
+
+
+class MaubotServer:
+ def __init__(self, config: Config, loop: asyncio.AbstractEventLoop):
+ self.loop = loop or asyncio.get_event_loop()
+ self.app = web.Application(loop=self.loop)
+ self.config = config
+
+ path = PathBuilder(config["server.base_path"])
+ self.app.router.add_get(path.version, self.version)
+
+ as_path = PathBuilder(config["server.appservice_base_path"])
+ self.app.router.add_put(as_path.transactions, self.handle_transaction)
+
+ self.runner = web.AppRunner(self.app)
+
+ async def start(self) -> None:
+ await self.runner.setup()
+ site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"])
+ await site.start()
+
+ async def stop(self) -> None:
+ await self.runner.cleanup()
+
+ @staticmethod
+ async def version(_: web.Request) -> web.Response:
+ return web.json_response({
+ "version": __version__
+ })
+
+ async def handle_transaction(self, request: web.Request) -> web.Response:
+ return web.Response(status=501)