diff --git a/maubot/__main__.py b/maubot/__main__.py index a188ff5..586091e 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -59,8 +59,7 @@ init_zip_loader(config) db_session = init_db(config) clients = init_client_class(db_session, loop) management_api = init_mgmt_api(config, loop) -server = MaubotServer(config, loop) -server.app.add_subapp(config["server.base_path"], management_api) +server = MaubotServer(management_api, config, loop) plugins = init_plugin_instance_class(db_session, config, server, loop) for plugin in plugins: diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index e07cae1..47ae7b4 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -24,7 +24,7 @@ from aiohttp import ClientSession if TYPE_CHECKING: from mautrix.util.config import BaseProxyConfig from .client import MaubotMatrixClient - from .server import PluginWebApp + from .plugin_server import PluginWebApp class Plugin(ABC): diff --git a/maubot/plugin_server.py b/maubot/plugin_server.py new file mode 100644 index 0000000..a5dd49a --- /dev/null +++ b/maubot/plugin_server.py @@ -0,0 +1,87 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2018 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import List, Callable, Awaitable +from functools import partial + +from aiohttp import web, hdrs +from yarl import URL + +Handler = Callable[[web.Request], Awaitable[web.Response]] +Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]] + + +class PluginWebApp(web.UrlDispatcher): + def __init__(self): + super().__init__() + self._middleware: List[Middleware] = [] + + def add_middleware(self, middleware: Middleware) -> None: + self._middleware.append(middleware) + + def remove_middleware(self, middleware: Middleware) -> None: + self._middleware.remove(middleware) + + def clear(self) -> None: + self._resources = [] + self._named_resources = {} + self._middleware = [] + + async def handle(self, request: web.Request) -> web.Response: + match_info = await self.resolve(request) + match_info.freeze() + resp = None + request._match_info = match_info + expect = request.headers.get(hdrs.EXPECT) + if expect: + resp = await match_info.expect_handler(request) + await request.writer.drain() + if resp is None: + handler = match_info.handler + for middleware in self._middleware: + handler = partial(middleware, handler=handler) + resp = await handler(request) + return resp + + +class PrefixResource(web.Resource): + def __init__(self, prefix, *, name=None): + assert not prefix or prefix.startswith('/'), prefix + assert prefix in ('', '/') or not prefix.endswith('/'), prefix + super().__init__(name=name) + self._prefix = URL.build(path=prefix).raw_path + + @property + def canonical(self): + return self._prefix + + def get_info(self): + return {'path': self._prefix} + + def url_for(self): + return URL.build(path=self._prefix, encoded=True) + + def add_prefix(self, prefix): + assert prefix.startswith('/') + assert not prefix.endswith('/') + assert len(prefix) > 1 + self._prefix = prefix + self._prefix + + def _match(self, path: str) -> dict: + return {} if self.raw_match(path) else None + + def raw_match(self, path: str) -> bool: + return path and path.startswith(self._prefix) + diff --git a/maubot/server.py b/maubot/server.py index 91e113f..73de847 100644 --- a/maubot/server.py +++ b/maubot/server.py @@ -13,18 +13,18 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Tuple, List, Dict, Callable, Awaitable -from functools import partial +from typing import Tuple, Dict import logging import asyncio -from aiohttp import web, hdrs, URL +from aiohttp import web, hdrs from aiohttp.abc import AbstractAccessLogger import pkg_resources from mautrix.api import PathBuilder, Method from .config import Config +from .plugin_server import PrefixResource, PluginWebApp from .__meta__ import __version__ @@ -35,78 +35,19 @@ class AccessLogger(AbstractAccessLogger): f'in {round(time, 4)}s"') -Handler = Callable[[web.Request], Awaitable[web.Response]] -Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]] - - -class PluginWebApp(web.UrlDispatcher): - def __init__(self): - super().__init__() - self._middleware: List[Middleware] = [] - - def add_middleware(self, middleware: Middleware) -> None: - self._middleware.append(middleware) - - def remove_middleware(self, middleware: Middleware) -> None: - self._middleware.remove(middleware) - - async def handle(self, request: web.Request) -> web.Response: - match_info = await self.resolve(request) - match_info.freeze() - resp = None - request._match_info = match_info - expect = request.headers.get(hdrs.EXPECT) - if expect: - resp = await match_info.expect_handler(request) - await request.writer.drain() - if resp is None: - handler = match_info.handler - for middleware in self._middleware: - handler = partial(middleware, handler=handler) - resp = await handler(request) - return resp - - -class PrefixResource(web.Resource): - def __init__(self, prefix, *, name=None): - assert not prefix or prefix.startswith('/'), prefix - assert prefix in ('', '/') or not prefix.endswith('/'), prefix - super().__init__(name=name) - self._prefix = URL.build(path=prefix).raw_path - - @property - def canonical(self): - return self._prefix - - def add_prefix(self, prefix): - assert prefix.startswith('/') - assert not prefix.endswith('/') - assert len(prefix) > 1 - self._prefix = prefix + self._prefix - - def _match(self, path: str) -> dict: - return {} if self.raw_match(path) else None - - def raw_match(self, path: str) -> bool: - return path and path.startswith(self._prefix) - - class MaubotServer: log: logging.Logger = logging.getLogger("maubot.server") + plugin_routes: Dict[str, PluginWebApp] - def __init__(self, config: Config, loop: asyncio.AbstractEventLoop) -> None: + def __init__(self, management_api: web.Application, config: Config, + loop: asyncio.AbstractEventLoop) -> None: self.loop = loop or asyncio.get_event_loop() self.app = web.Application(loop=self.loop, client_max_size=100 * 1024 * 1024) self.config = config - as_path = PathBuilder(config["server.appservice_base_path"]) - self.add_route(Method.PUT, as_path.transactions, self.handle_transaction) - - self.plugin_routes: Dict[str, PluginWebApp] = {} - resource = PrefixResource(config["server.plugin_base_path"]) - resource.add_route(hdrs.METH_ANY, self.handle_plugin_path) - self.app.router.register_resource(resource) - + self.setup_appservice() + self.app.add_subapp(config["server.base_path"], management_api) + self.setup_instance_subapps() self.setup_management_ui() self.runner = web.AppRunner(self.app, access_log_class=AccessLogger) @@ -114,7 +55,8 @@ class MaubotServer: async def handle_plugin_path(self, request: web.Request) -> web.Response: for path, app in self.plugin_routes.items(): if request.path.startswith(path): - request = request.clone(rel_url=request.path[len(path):]) + request = request.clone( + rel_url=request.rel_url.with_path(request.rel_url.path[len(path):])) return await app.handle(request) return web.Response(status=404) @@ -131,10 +73,20 @@ class MaubotServer: def remove_instance_webapp(self, instance_id: str) -> None: try: subpath = self.config["server.plugin_base_path"] + instance_id - self.plugin_routes.pop(subpath) + self.plugin_routes.pop(subpath).clear() except KeyError: return + def setup_instance_subapps(self) -> None: + self.plugin_routes = {} + resource = PrefixResource(self.config["server.plugin_base_path"].rstrip("/")) + resource.add_route(hdrs.METH_ANY, self.handle_plugin_path) + self.app.router.register_resource(resource) + + def setup_appservice(self) -> None: + as_path = PathBuilder(self.config["server.appservice_base_path"]) + self.add_route(Method.PUT, as_path.transactions, self.handle_transaction) + def setup_management_ui(self) -> None: ui_base = self.config["server.ui_base_path"] if ui_base == "/":