diff --git a/docker/example-config.yaml b/docker/example-config.yaml index 64176c8..77a86b9 100644 --- a/docker/example-config.yaml +++ b/docker/example-config.yaml @@ -22,10 +22,14 @@ server: # The IP and port to listen to. hostname: 0.0.0.0 port: 29316 + # Public base URL where the server is visible. + public_url: https://example.com # The base management API path. base_path: /_matrix/maubot/v1 # The base path for the UI. ui_base_path: /_matrix/maubot + # The base path for plugin endpoints. The instance ID will be appended directly. + plugin_base_path: /_matrix/maubot/plugin/ # Override path from where to load UI resources. # Set to false to using pkg_resources to find the path. override_resource_path: /opt/maubot/frontend diff --git a/example-config.yaml b/example-config.yaml index d4701b8..6579918 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -22,10 +22,14 @@ server: # The IP and port to listen to. hostname: 0.0.0.0 port: 29316 + # Public base URL where the server is visible. + public_url: https://example.com # The base management API path. base_path: /_matrix/maubot/v1 # The base path for the UI. ui_base_path: /_matrix/maubot + # The base path for plugin endpoints. The instance ID will be appended directly. + plugin_base_path: /_matrix/maubot/plugin/ # Override path from where to load UI resources. # Set to false to using pkg_resources to find the path. override_resource_path: false diff --git a/maubot/__main__.py b/maubot/__main__.py index 2a6a3f2..586091e 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -58,10 +58,9 @@ loop = asyncio.get_event_loop() init_zip_loader(config) db_session = init_db(config) clients = init_client_class(db_session, loop) -plugins = init_plugin_instance_class(db_session, config, loop) management_api = init_mgmt_api(config, loop) -server = MaubotServer(config, loop) -server.app.add_subapp(config["server.base_path"], management_api) +server = MaubotServer(management_api, config, loop) +plugins = init_plugin_instance_class(db_session, config, server, loop) for plugin in plugins: plugin.load() diff --git a/maubot/config.py b/maubot/config.py index f3b56e3..57e6552 100644 --- a/maubot/config.py +++ b/maubot/config.py @@ -38,9 +38,11 @@ class Config(BaseFileConfig): copy("plugin_directories.db") copy("server.hostname") copy("server.port") + copy("server.public_url") copy("server.listen") copy("server.base_path") copy("server.ui_base_path") + copy("server.plugin_base_path") copy("server.override_resource_path") copy("server.appservice_base_path") shared_secret = self["server.unshared_secret"] diff --git a/maubot/instance.py b/maubot/instance.py index fd6fa57..114858d 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -13,8 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, List, Optional +from typing import Dict, List, Optional, TYPE_CHECKING from asyncio import AbstractEventLoop +from aiohttp import web import os.path import logging import io @@ -33,6 +34,9 @@ from .client import Client from .loader import PluginLoader, ZippedPluginLoader from .plugin_base import Plugin +if TYPE_CHECKING: + from .server import MaubotServer, PluginWebApp + log = logging.getLogger("maubot.instance") yaml = YAML() @@ -41,6 +45,7 @@ yaml.indent(4) class PluginInstance: db: Session = None + webserver: 'MaubotServer' = None mb_config: Config = None loop: AbstractEventLoop = None cache: Dict[str, 'PluginInstance'] = {} @@ -54,6 +59,8 @@ class PluginInstance: base_cfg: RecursiveDict[CommentedMap] inst_db: sql.engine.Engine inst_db_tables: Dict[str, sql.Table] + inst_webapp: 'PluginWebApp' + inst_webapp_url: str started: bool def __init__(self, db_instance: DBPlugin): @@ -66,6 +73,8 @@ class PluginInstance: self.plugin = None self.inst_db = None self.inst_db_tables = None + self.inst_webapp = None + self.inst_webapp_url = None self.base_cfg = None self.cache[self.id] = self @@ -105,6 +114,8 @@ class PluginInstance: if self.loader.meta.database: db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id) self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db") + if self.loader.meta.webapp: + self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id) self.log.debug("Plugin instance dependencies loaded") self.loader.references.add(self) self.client.references.add(self) @@ -126,6 +137,8 @@ class PluginInstance: ZippedPluginLoader.trash( os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"), reason="deleted") + if self.inst_webapp: + self.webserver.remove_instance_webapp(self.id) def load_config(self) -> CommentedMap: return yaml.load(self.db_instance.config) @@ -157,7 +170,8 @@ class PluginInstance: self.config = config_class(self.load_config, lambda: self.base_cfg, self.save_config) self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client, instance_id=self.id, log=self.log, config=self.config, - database=self.inst_db) + database=self.inst_db, webapp=self.inst_webapp, + webapp_url=self.inst_webapp_url) try: await self.plugin.start() except Exception: @@ -274,8 +288,10 @@ class PluginInstance: # endregion -def init(db: Session, config: Config, loop: AbstractEventLoop) -> List[PluginInstance]: +def init(db: Session, config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop) -> List[ + PluginInstance]: PluginInstance.db = db PluginInstance.mb_config = config PluginInstance.loop = loop + PluginInstance.webserver = webserver return PluginInstance.all() diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py index f4b62a7..4f0c5ed 100644 --- a/maubot/loader/abc.py +++ b/maubot/loader/abc.py @@ -57,6 +57,7 @@ class PluginMeta(SerializableAttrs['PluginMeta']): maubot: Version = Version(__version__) database: bool = False + webapp: bool = False license: str = "" extra_files: List[str] = [] dependencies: List[str] = [] diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index fbc7891..47ae7b4 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -17,7 +17,6 @@ from typing import Type, Optional, TYPE_CHECKING from abc import ABC from logging import Logger from asyncio import AbstractEventLoop -import functools from sqlalchemy.engine.base import Engine from aiohttp import ClientSession @@ -25,6 +24,7 @@ from aiohttp import ClientSession if TYPE_CHECKING: from mautrix.util.config import BaseProxyConfig from .client import MaubotMatrixClient + from .plugin_server import PluginWebApp class Plugin(ABC): @@ -34,10 +34,13 @@ class Plugin(ABC): loop: AbstractEventLoop config: Optional['BaseProxyConfig'] database: Optional[Engine] + webapp: Optional['PluginWebApp'] + webapp_url: Optional[str] def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession, instance_id: str, log: Logger, config: Optional['BaseProxyConfig'], - database: Optional[Engine]) -> None: + database: Optional[Engine], webapp: Optional['PluginWebApp'], + webapp_url: Optional[str]) -> None: self.client = client self.loop = loop self.http = http @@ -45,6 +48,8 @@ class Plugin(ABC): self.log = log self.config = config self.database = database + self.webapp = webapp + self.webapp_url = webapp_url self._handlers_at_startup = [] async def start(self) -> None: diff --git a/maubot/plugin_server.py b/maubot/plugin_server.py new file mode 100644 index 0000000..a5dd49a --- /dev/null +++ b/maubot/plugin_server.py @@ -0,0 +1,87 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2018 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import List, Callable, Awaitable +from functools import partial + +from aiohttp import web, hdrs +from yarl import URL + +Handler = Callable[[web.Request], Awaitable[web.Response]] +Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]] + + +class PluginWebApp(web.UrlDispatcher): + def __init__(self): + super().__init__() + self._middleware: List[Middleware] = [] + + def add_middleware(self, middleware: Middleware) -> None: + self._middleware.append(middleware) + + def remove_middleware(self, middleware: Middleware) -> None: + self._middleware.remove(middleware) + + def clear(self) -> None: + self._resources = [] + self._named_resources = {} + self._middleware = [] + + async def handle(self, request: web.Request) -> web.Response: + match_info = await self.resolve(request) + match_info.freeze() + resp = None + request._match_info = match_info + expect = request.headers.get(hdrs.EXPECT) + if expect: + resp = await match_info.expect_handler(request) + await request.writer.drain() + if resp is None: + handler = match_info.handler + for middleware in self._middleware: + handler = partial(middleware, handler=handler) + resp = await handler(request) + return resp + + +class PrefixResource(web.Resource): + def __init__(self, prefix, *, name=None): + assert not prefix or prefix.startswith('/'), prefix + assert prefix in ('', '/') or not prefix.endswith('/'), prefix + super().__init__(name=name) + self._prefix = URL.build(path=prefix).raw_path + + @property + def canonical(self): + return self._prefix + + def get_info(self): + return {'path': self._prefix} + + def url_for(self): + return URL.build(path=self._prefix, encoded=True) + + def add_prefix(self, prefix): + assert prefix.startswith('/') + assert not prefix.endswith('/') + assert len(prefix) > 1 + self._prefix = prefix + self._prefix + + def _match(self, path: str) -> dict: + return {} if self.raw_match(path) else None + + def raw_match(self, path: str) -> bool: + return path and path.startswith(self._prefix) + diff --git a/maubot/server.py b/maubot/server.py index 35682ee..9a10cd6 100644 --- a/maubot/server.py +++ b/maubot/server.py @@ -13,16 +13,18 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Tuple, Dict import logging import asyncio -from aiohttp import web +from aiohttp import web, hdrs from aiohttp.abc import AbstractAccessLogger import pkg_resources from mautrix.api import PathBuilder, Method from .config import Config +from .plugin_server import PrefixResource, PluginWebApp from .__meta__ import __version__ @@ -35,19 +37,57 @@ class AccessLogger(AbstractAccessLogger): class MaubotServer: log: logging.Logger = logging.getLogger("maubot.server") + plugin_routes: Dict[str, PluginWebApp] - def __init__(self, config: Config, loop: asyncio.AbstractEventLoop) -> None: + def __init__(self, management_api: web.Application, config: Config, + loop: asyncio.AbstractEventLoop) -> None: self.loop = loop or asyncio.get_event_loop() - self.app = web.Application(loop=self.loop, client_max_size=100*1024*1024) + self.app = web.Application(loop=self.loop, client_max_size=100 * 1024 * 1024) self.config = config - as_path = PathBuilder(config["server.appservice_base_path"]) - self.add_route(Method.PUT, as_path.transactions, self.handle_transaction) - + self.setup_appservice() + self.app.add_subapp(config["server.base_path"], management_api) + self.setup_instance_subapps() self.setup_management_ui() self.runner = web.AppRunner(self.app, access_log_class=AccessLogger) + async def handle_plugin_path(self, request: web.Request) -> web.Response: + for path, app in self.plugin_routes.items(): + if request.path.startswith(path): + request = request.clone(rel_url=request.rel_url + .with_path(request.rel_url.path[len(path):]) + .with_query(request.query_string)) + return await app.handle(request) + return web.Response(status=404) + + def get_instance_subapp(self, instance_id: str) -> Tuple[PluginWebApp, str]: + subpath = self.config["server.plugin_base_path"] + instance_id + url = self.config["server.public_url"] + subpath + try: + return self.plugin_routes[subpath], url + except KeyError: + app = PluginWebApp() + self.plugin_routes[subpath] = app + return app, url + + def remove_instance_webapp(self, instance_id: str) -> None: + try: + subpath = self.config["server.plugin_base_path"] + instance_id + self.plugin_routes.pop(subpath).clear() + except KeyError: + return + + def setup_instance_subapps(self) -> None: + self.plugin_routes = {} + resource = PrefixResource(self.config["server.plugin_base_path"].rstrip("/")) + resource.add_route(hdrs.METH_ANY, self.handle_plugin_path) + self.app.router.register_resource(resource) + + def setup_appservice(self) -> None: + as_path = PathBuilder(self.config["server.appservice_base_path"]) + self.add_route(Method.PUT, as_path.transactions, self.handle_transaction) + def setup_management_ui(self) -> None: ui_base = self.config["server.ui_base_path"] if ui_base == "/": diff --git a/setup.py b/setup.py index a25d1a8..b58bfca 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ setuptools.setup( packages=setuptools.find_packages(), install_requires=[ - "mautrix>=0.4.dev20,<0.5", + "mautrix>=0.4.dev24,<0.5", "aiohttp>=3.0.1,<4", "SQLAlchemy>=1.2.3,<2", "alembic>=1.0.0,<2",