diff --git a/maubot/__main__.py b/maubot/__main__.py
index a188ff5..586091e 100644
--- a/maubot/__main__.py
+++ b/maubot/__main__.py
@@ -59,8 +59,7 @@ init_zip_loader(config)
db_session = init_db(config)
clients = init_client_class(db_session, loop)
management_api = init_mgmt_api(config, loop)
-server = MaubotServer(config, loop)
-server.app.add_subapp(config["server.base_path"], management_api)
+server = MaubotServer(management_api, config, loop)
plugins = init_plugin_instance_class(db_session, config, server, loop)
for plugin in plugins:
diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py
index e07cae1..47ae7b4 100644
--- a/maubot/plugin_base.py
+++ b/maubot/plugin_base.py
@@ -24,7 +24,7 @@ from aiohttp import ClientSession
if TYPE_CHECKING:
from mautrix.util.config import BaseProxyConfig
from .client import MaubotMatrixClient
- from .server import PluginWebApp
+ from .plugin_server import PluginWebApp
class Plugin(ABC):
diff --git a/maubot/plugin_server.py b/maubot/plugin_server.py
new file mode 100644
index 0000000..a5dd49a
--- /dev/null
+++ b/maubot/plugin_server.py
@@ -0,0 +1,87 @@
+# maubot - A plugin-based Matrix bot system.
+# Copyright (C) 2018 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from typing import List, Callable, Awaitable
+from functools import partial
+
+from aiohttp import web, hdrs
+from yarl import URL
+
+Handler = Callable[[web.Request], Awaitable[web.Response]]
+Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]]
+
+
+class PluginWebApp(web.UrlDispatcher):
+ def __init__(self):
+ super().__init__()
+ self._middleware: List[Middleware] = []
+
+ def add_middleware(self, middleware: Middleware) -> None:
+ self._middleware.append(middleware)
+
+ def remove_middleware(self, middleware: Middleware) -> None:
+ self._middleware.remove(middleware)
+
+ def clear(self) -> None:
+ self._resources = []
+ self._named_resources = {}
+ self._middleware = []
+
+ async def handle(self, request: web.Request) -> web.Response:
+ match_info = await self.resolve(request)
+ match_info.freeze()
+ resp = None
+ request._match_info = match_info
+ expect = request.headers.get(hdrs.EXPECT)
+ if expect:
+ resp = await match_info.expect_handler(request)
+ await request.writer.drain()
+ if resp is None:
+ handler = match_info.handler
+ for middleware in self._middleware:
+ handler = partial(middleware, handler=handler)
+ resp = await handler(request)
+ return resp
+
+
+class PrefixResource(web.Resource):
+ def __init__(self, prefix, *, name=None):
+ assert not prefix or prefix.startswith('/'), prefix
+ assert prefix in ('', '/') or not prefix.endswith('/'), prefix
+ super().__init__(name=name)
+ self._prefix = URL.build(path=prefix).raw_path
+
+ @property
+ def canonical(self):
+ return self._prefix
+
+ def get_info(self):
+ return {'path': self._prefix}
+
+ def url_for(self):
+ return URL.build(path=self._prefix, encoded=True)
+
+ def add_prefix(self, prefix):
+ assert prefix.startswith('/')
+ assert not prefix.endswith('/')
+ assert len(prefix) > 1
+ self._prefix = prefix + self._prefix
+
+ def _match(self, path: str) -> dict:
+ return {} if self.raw_match(path) else None
+
+ def raw_match(self, path: str) -> bool:
+ return path and path.startswith(self._prefix)
+
diff --git a/maubot/server.py b/maubot/server.py
index 91e113f..73de847 100644
--- a/maubot/server.py
+++ b/maubot/server.py
@@ -13,18 +13,18 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Tuple, List, Dict, Callable, Awaitable
-from functools import partial
+from typing import Tuple, Dict
import logging
import asyncio
-from aiohttp import web, hdrs, URL
+from aiohttp import web, hdrs
from aiohttp.abc import AbstractAccessLogger
import pkg_resources
from mautrix.api import PathBuilder, Method
from .config import Config
+from .plugin_server import PrefixResource, PluginWebApp
from .__meta__ import __version__
@@ -35,78 +35,19 @@ class AccessLogger(AbstractAccessLogger):
f'in {round(time, 4)}s"')
-Handler = Callable[[web.Request], Awaitable[web.Response]]
-Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]]
-
-
-class PluginWebApp(web.UrlDispatcher):
- def __init__(self):
- super().__init__()
- self._middleware: List[Middleware] = []
-
- def add_middleware(self, middleware: Middleware) -> None:
- self._middleware.append(middleware)
-
- def remove_middleware(self, middleware: Middleware) -> None:
- self._middleware.remove(middleware)
-
- async def handle(self, request: web.Request) -> web.Response:
- match_info = await self.resolve(request)
- match_info.freeze()
- resp = None
- request._match_info = match_info
- expect = request.headers.get(hdrs.EXPECT)
- if expect:
- resp = await match_info.expect_handler(request)
- await request.writer.drain()
- if resp is None:
- handler = match_info.handler
- for middleware in self._middleware:
- handler = partial(middleware, handler=handler)
- resp = await handler(request)
- return resp
-
-
-class PrefixResource(web.Resource):
- def __init__(self, prefix, *, name=None):
- assert not prefix or prefix.startswith('/'), prefix
- assert prefix in ('', '/') or not prefix.endswith('/'), prefix
- super().__init__(name=name)
- self._prefix = URL.build(path=prefix).raw_path
-
- @property
- def canonical(self):
- return self._prefix
-
- def add_prefix(self, prefix):
- assert prefix.startswith('/')
- assert not prefix.endswith('/')
- assert len(prefix) > 1
- self._prefix = prefix + self._prefix
-
- def _match(self, path: str) -> dict:
- return {} if self.raw_match(path) else None
-
- def raw_match(self, path: str) -> bool:
- return path and path.startswith(self._prefix)
-
-
class MaubotServer:
log: logging.Logger = logging.getLogger("maubot.server")
+ plugin_routes: Dict[str, PluginWebApp]
- def __init__(self, config: Config, loop: asyncio.AbstractEventLoop) -> None:
+ def __init__(self, management_api: web.Application, config: Config,
+ loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop or asyncio.get_event_loop()
self.app = web.Application(loop=self.loop, client_max_size=100 * 1024 * 1024)
self.config = config
- as_path = PathBuilder(config["server.appservice_base_path"])
- self.add_route(Method.PUT, as_path.transactions, self.handle_transaction)
-
- self.plugin_routes: Dict[str, PluginWebApp] = {}
- resource = PrefixResource(config["server.plugin_base_path"])
- resource.add_route(hdrs.METH_ANY, self.handle_plugin_path)
- self.app.router.register_resource(resource)
-
+ self.setup_appservice()
+ self.app.add_subapp(config["server.base_path"], management_api)
+ self.setup_instance_subapps()
self.setup_management_ui()
self.runner = web.AppRunner(self.app, access_log_class=AccessLogger)
@@ -114,7 +55,8 @@ class MaubotServer:
async def handle_plugin_path(self, request: web.Request) -> web.Response:
for path, app in self.plugin_routes.items():
if request.path.startswith(path):
- request = request.clone(rel_url=request.path[len(path):])
+ request = request.clone(
+ rel_url=request.rel_url.with_path(request.rel_url.path[len(path):]))
return await app.handle(request)
return web.Response(status=404)
@@ -131,10 +73,20 @@ class MaubotServer:
def remove_instance_webapp(self, instance_id: str) -> None:
try:
subpath = self.config["server.plugin_base_path"] + instance_id
- self.plugin_routes.pop(subpath)
+ self.plugin_routes.pop(subpath).clear()
except KeyError:
return
+ def setup_instance_subapps(self) -> None:
+ self.plugin_routes = {}
+ resource = PrefixResource(self.config["server.plugin_base_path"].rstrip("/"))
+ resource.add_route(hdrs.METH_ANY, self.handle_plugin_path)
+ self.app.router.register_resource(resource)
+
+ def setup_appservice(self) -> None:
+ as_path = PathBuilder(self.config["server.appservice_base_path"])
+ self.add_route(Method.PUT, as_path.transactions, self.handle_transaction)
+
def setup_management_ui(self) -> None:
ui_base = self.config["server.ui_base_path"]
if ui_base == "/":