Let plugins add their endpoints to the main webserver

This commit is contained in:
Tulir Asokan 2019-03-06 22:22:34 +02:00
parent 79c61d6889
commit f303bd66ab
7 changed files with 43 additions and 7 deletions

View File

@ -26,6 +26,8 @@ server:
base_path: /_matrix/maubot/v1 base_path: /_matrix/maubot/v1
# The base path for the UI. # The base path for the UI.
ui_base_path: /_matrix/maubot ui_base_path: /_matrix/maubot
# The base path for plugin endpoints. {id} is replaced with the ID of the instance.
plugin_base_path: /_matrix/maubot/plugin/{id}
# Override path from where to load UI resources. # Override path from where to load UI resources.
# Set to false to using pkg_resources to find the path. # Set to false to using pkg_resources to find the path.
override_resource_path: false override_resource_path: false

View File

@ -58,10 +58,10 @@ loop = asyncio.get_event_loop()
init_zip_loader(config) init_zip_loader(config)
db_session = init_db(config) db_session = init_db(config)
clients = init_client_class(db_session, loop) clients = init_client_class(db_session, loop)
plugins = init_plugin_instance_class(db_session, config, loop)
management_api = init_mgmt_api(config, loop) management_api = init_mgmt_api(config, loop)
server = MaubotServer(config, loop) server = MaubotServer(config, loop)
server.app.add_subapp(config["server.base_path"], management_api) server.app.add_subapp(config["server.base_path"], management_api)
plugins = init_plugin_instance_class(db_session, config, server.app, loop)
for plugin in plugins: for plugin in plugins:
plugin.load() plugin.load()

View File

@ -41,6 +41,7 @@ class Config(BaseFileConfig):
copy("server.listen") copy("server.listen")
copy("server.base_path") copy("server.base_path")
copy("server.ui_base_path") copy("server.ui_base_path")
copy("server.plugin_base_path")
copy("server.override_resource_path") copy("server.override_resource_path")
copy("server.appservice_base_path") copy("server.appservice_base_path")
shared_secret = self["server.unshared_secret"] shared_secret = self["server.unshared_secret"]

View File

@ -13,8 +13,9 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional from typing import Dict, List, Optional, TYPE_CHECKING
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from aiohttp import web
import os.path import os.path
import logging import logging
import io import io
@ -33,6 +34,9 @@ from .client import Client
from .loader import PluginLoader, ZippedPluginLoader from .loader import PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin from .plugin_base import Plugin
if TYPE_CHECKING:
from .server import MaubotServer
log = logging.getLogger("maubot.instance") log = logging.getLogger("maubot.instance")
yaml = YAML() yaml = YAML()
@ -41,6 +45,7 @@ yaml.indent(4)
class PluginInstance: class PluginInstance:
db: Session = None db: Session = None
webserver: 'MaubotServer' = None
mb_config: Config = None mb_config: Config = None
loop: AbstractEventLoop = None loop: AbstractEventLoop = None
cache: Dict[str, 'PluginInstance'] = {} cache: Dict[str, 'PluginInstance'] = {}
@ -54,6 +59,7 @@ class PluginInstance:
base_cfg: RecursiveDict[CommentedMap] base_cfg: RecursiveDict[CommentedMap]
inst_db: sql.engine.Engine inst_db: sql.engine.Engine
inst_db_tables: Dict[str, sql.Table] inst_db_tables: Dict[str, sql.Table]
inst_webapp: web.Application
started: bool started: bool
def __init__(self, db_instance: DBPlugin): def __init__(self, db_instance: DBPlugin):
@ -66,6 +72,7 @@ class PluginInstance:
self.plugin = None self.plugin = None
self.inst_db = None self.inst_db = None
self.inst_db_tables = None self.inst_db_tables = None
self.inst_webapp = None
self.base_cfg = None self.base_cfg = None
self.cache[self.id] = self self.cache[self.id] = self
@ -105,6 +112,8 @@ class PluginInstance:
if self.loader.meta.database: if self.loader.meta.database:
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id) db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db") self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
if self.loader.meta.webapp:
self.inst_webapp = self.webserver.get_instance_subapp(self.id)
self.log.debug("Plugin instance dependencies loaded") self.log.debug("Plugin instance dependencies loaded")
self.loader.references.add(self) self.loader.references.add(self)
self.client.references.add(self) self.client.references.add(self)
@ -126,6 +135,8 @@ class PluginInstance:
ZippedPluginLoader.trash( ZippedPluginLoader.trash(
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"), os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
reason="deleted") reason="deleted")
if self.inst_webapp:
self.webserver.remove_instance_webapp(self.id)
def load_config(self) -> CommentedMap: def load_config(self) -> CommentedMap:
return yaml.load(self.db_instance.config) return yaml.load(self.db_instance.config)
@ -157,7 +168,7 @@ class PluginInstance:
self.config = config_class(self.load_config, lambda: self.base_cfg, self.save_config) self.config = config_class(self.load_config, lambda: self.base_cfg, self.save_config)
self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client, self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client,
instance_id=self.id, log=self.log, config=self.config, instance_id=self.id, log=self.log, config=self.config,
database=self.inst_db) database=self.inst_db, webapp=self.inst_webapp)
try: try:
await self.plugin.start() await self.plugin.start()
except Exception: except Exception:
@ -274,8 +285,10 @@ class PluginInstance:
# endregion # endregion
def init(db: Session, config: Config, loop: AbstractEventLoop) -> List[PluginInstance]: def init(db: Session, config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop) -> List[
PluginInstance]:
PluginInstance.db = db PluginInstance.db = db
PluginInstance.mb_config = config PluginInstance.mb_config = config
PluginInstance.loop = loop PluginInstance.loop = loop
PluginInstance.webserver = webserver
return PluginInstance.all() return PluginInstance.all()

View File

@ -57,6 +57,7 @@ class PluginMeta(SerializableAttrs['PluginMeta']):
maubot: Version = Version(__version__) maubot: Version = Version(__version__)
database: bool = False database: bool = False
webapp: bool = False
license: str = "" license: str = ""
extra_files: List[str] = [] extra_files: List[str] = []
dependencies: List[str] = [] dependencies: List[str] = []

View File

@ -17,7 +17,7 @@ from typing import Type, Optional, TYPE_CHECKING
from abc import ABC from abc import ABC
from logging import Logger from logging import Logger
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
import functools from aiohttp.web import Application
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from aiohttp import ClientSession from aiohttp import ClientSession
@ -37,7 +37,7 @@ class Plugin(ABC):
def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession, def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession,
instance_id: str, log: Logger, config: Optional['BaseProxyConfig'], instance_id: str, log: Logger, config: Optional['BaseProxyConfig'],
database: Optional[Engine]) -> None: database: Optional[Engine], webapp: Optional[Application]) -> None:
self.client = client self.client = client
self.loop = loop self.loop = loop
self.http = http self.http = http
@ -45,6 +45,7 @@ class Plugin(ABC):
self.log = log self.log = log
self.config = config self.config = config
self.database = database self.database = database
self.webapp = webapp
self._handlers_at_startup = [] self._handlers_at_startup = []
async def start(self) -> None: async def start(self) -> None:

View File

@ -43,11 +43,29 @@ class MaubotServer:
as_path = PathBuilder(config["server.appservice_base_path"]) as_path = PathBuilder(config["server.appservice_base_path"])
self.add_route(Method.PUT, as_path.transactions, self.handle_transaction) self.add_route(Method.PUT, as_path.transactions, self.handle_transaction)
self.subapps = {}
self.setup_management_ui() self.setup_management_ui()
self.runner = web.AppRunner(self.app, access_log_class=AccessLogger) self.runner = web.AppRunner(self.app, access_log_class=AccessLogger)
def get_instance_subapp(self, instance_id: str) -> web.Application:
try:
return self.subapps[instance_id]
except KeyError:
app = web.Application(loop=self.loop)
self.app.add_subapp(self.config["server.plugin_base_path"].format(id=instance_id), app)
self.subapps[instance_id] = app
return app
def remove_instance_webapp(self, instance_id: str) -> None:
try:
subapp: web.Application = self.subapps.pop(instance_id)
except KeyError:
return
subapp.shutdown()
subapp.cleanup()
def setup_management_ui(self) -> None: def setup_management_ui(self) -> None:
ui_base = self.config["server.ui_base_path"] ui_base = self.config["server.ui_base_path"]
if ui_base == "/": if ui_base == "/":