diff --git a/maubot/__main__.py b/maubot/__main__.py index 9f8cafe..3e18a68 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -13,8 +13,6 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from sqlalchemy import orm -import sqlalchemy as sql import logging.config import argparse import asyncio @@ -23,11 +21,11 @@ import copy import sys from .config import Config -from .db import Base, init as init_db +from .db import init as init_db from .server import MaubotServer -from .client import Client, init as init_client -from .loader import ZippedPluginLoader -from .instance import PluginInstance, init as init_plugin_instance_class +from .client import Client, init as init_client_class +from .loader.zip import init as init_zip_loader +from .instance import init as init_plugin_instance_class from .management.api import init as init_management from .__meta__ import __version__ @@ -46,57 +44,48 @@ config.update() logging.config.dictConfig(copy.deepcopy(config["logging"])) log = logging.getLogger("maubot.init") -log.debug(f"Initializing maubot {__version__}") - -db_engine: sql.engine.Engine = sql.create_engine(config["database"]) -db_factory = orm.sessionmaker(bind=db_engine) -db_session = orm.scoping.scoped_session(db_factory) -Base.metadata.bind = db_engine -Base.metadata.create_all() +log.info(f"Initializing maubot {__version__}") loop = asyncio.get_event_loop() -init_db(db_session) -clients = init_client(loop) -init_plugin_instance_class(db_session, config, loop) +init_zip_loader(config) +db_session = init_db(config) +clients = init_client_class(db_session, loop) +plugins = init_plugin_instance_class(db_session, config, loop) management_api = init_management(config, loop) server = MaubotServer(config, management_api, loop) -ZippedPluginLoader.trash_path = config["plugin_directories.trash"] -ZippedPluginLoader.directories = config["plugin_directories.load"] -ZippedPluginLoader.load_all() - -plugins = PluginInstance.all() - for plugin in plugins: plugin.load() signal.signal(signal.SIGINT, signal.default_int_handler) signal.signal(signal.SIGTERM, signal.default_int_handler) -stop = False - async def periodic_commit(): - while not stop: + while True: await asyncio.sleep(60) db_session.commit() +periodic_commit_task: asyncio.Future = None + try: - log.debug("Starting server") + log.info("Starting server") loop.run_until_complete(server.start()) - log.debug("Starting clients and plugins") + log.info("Starting clients and plugins") loop.run_until_complete(asyncio.gather(*[client.start() for client in clients])) - log.debug("Startup actions complete, running forever") - loop.run_until_complete(periodic_commit()) + log.info("Startup actions complete, running forever") + periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop) loop.run_forever() except KeyboardInterrupt: log.debug("Interrupt received, stopping HTTP clients/servers and saving database") - stop = True + if periodic_commit_task is not None: + periodic_commit_task.cancel() for client in Client.cache.values(): client.stop() db_session.commit() loop.run_until_complete(server.stop()) + loop.close() log.debug("Everything stopped, shutting down") sys.exit(0) diff --git a/maubot/client.py b/maubot/client.py index dee1d8c..684eaea 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -14,10 +14,12 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import Dict, List, Optional, Set, TYPE_CHECKING -from aiohttp import ClientSession import asyncio import logging +from sqlalchemy.orm import Session +from aiohttp import ClientSession + from mautrix.errors import MatrixInvalidToken, MatrixRequestError from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership, EventType, Filter, RoomFilter, RoomEventFilter) @@ -32,6 +34,7 @@ log = logging.getLogger("maubot.client") class Client: + db: Session = None log: logging.Logger = None loop: asyncio.AbstractEventLoop = None cache: Dict[UserID, 'Client'] = {} @@ -73,12 +76,12 @@ class Client: user_id = await self.client.whoami() except MatrixInvalidToken as e: self.log.error(f"Invalid token: {e}. Disabling client") - self.enabled = False + self.db_instance.enabled = False return except MatrixRequestError: if try_n >= 5: self.log.exception("Failed to get /account/whoami, disabling client") - self.enabled = False + self.db_instance.enabled = False else: self.log.exception(f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s") @@ -86,7 +89,7 @@ class Client: return if user_id != self.id: self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}") - self.enabled = False + self.db_instance.enabled = False return if not self.filter_id: self.filter_id = await self.client.create_filter(Filter( @@ -100,8 +103,7 @@ class Client: await self.client.set_displayname(self.displayname) if self.avatar_url != "disable": await self.client.set_avatar_url(self.avatar_url) - if self.sync: - self.client.start(self.filter_id) + self.start_sync() self.started = True self.log.info("Client started, starting plugin instances...") await self.start_plugins() @@ -110,12 +112,19 @@ class Client: await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop) async def stop_plugins(self) -> None: - await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running], + await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.started], loop=self.loop) + def start_sync(self) -> None: + if self.sync: + self.client.start(self.filter_id) + + def stop_sync(self) -> None: + self.client.stop() + def stop(self) -> None: self.started = False - self.client.stop() + self.stop_sync() def to_dict(self) -> dict: return { @@ -233,7 +242,8 @@ class Client: # endregion -def init(loop: asyncio.AbstractEventLoop) -> List[Client]: +def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]: + Client.db = db Client.http_client = ClientSession(loop=loop) Client.loop = loop return Client.all() diff --git a/maubot/db.py b/maubot/db.py index d53658c..2373fb2 100644 --- a/maubot/db.py +++ b/maubot/db.py @@ -13,12 +13,17 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import cast + from sqlalchemy import Column, String, Boolean, ForeignKey, Text -from sqlalchemy.orm import Query, scoped_session +from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session from sqlalchemy.ext.declarative import declarative_base +import sqlalchemy as sql from mautrix.types import UserID, FilterID, SyncToken, ContentURI +from .config import Config + Base: declarative_base = declarative_base() @@ -54,6 +59,14 @@ class DBClient(Base): avatar_url: ContentURI = Column(String(255), nullable=False, default="") -def init(session: scoped_session) -> None: - DBPlugin.query = session.query_property() - DBClient.query = session.query_property() +def init(config: Config) -> Session: + db_engine: sql.engine.Engine = sql.create_engine(config["database"]) + db_factory = sessionmaker(bind=db_engine) + db_session = scoped_session(db_factory) + Base.metadata.bind = db_engine + Base.metadata.create_all() + + DBPlugin.query = db_session.query_property() + DBClient.query = db_session.query_property() + + return cast(Session, db_session) diff --git a/maubot/instance.py b/maubot/instance.py index 51c9dcd..6e3590f 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -48,13 +48,14 @@ class PluginInstance: client: Client plugin: Plugin config: BaseProxyConfig - running: bool + base_cfg: RecursiveDict[CommentedMap] + started: bool def __init__(self, db_instance: DBPlugin): self.db_instance = db_instance self.log = logging.getLogger(f"maubot.plugin.{self.id}") self.config = None - self.running = False + self.started = False self.cache[self.id] = self def to_dict(self) -> dict: @@ -62,7 +63,7 @@ class PluginInstance: "id": self.id, "type": self.type, "enabled": self.enabled, - "running": self.running, + "started": self.started, "primary_user": self.primary_user, } @@ -71,19 +72,26 @@ class PluginInstance: self.loader = PluginLoader.find(self.type) except KeyError: self.log.error(f"Failed to find loader for type {self.type}") - self.enabled = False + self.db_instance.enabled = False return self.client = Client.get(self.primary_user) if not self.client: self.log.error(f"Failed to get client for user {self.primary_user}") - self.enabled = False + self.db_instance.enabled = False return self.log.debug("Plugin instance dependencies loaded") self.loader.references.add(self) self.client.references.add(self) def delete(self) -> None: - self.loader.references.remove(self) + if self.loader is not None: + self.loader.references.remove(self) + if self.client is not None: + self.client.references.remove(self) + try: + del self.cache[self.id] + except KeyError: + pass self.db.delete(self.db_instance) # TODO delete plugin db @@ -96,7 +104,7 @@ class PluginInstance: self.db_instance.config = buf.getvalue() async def start(self) -> None: - if self.running: + if self.started: self.log.warning("Ignoring start() call to already started plugin") return elif not self.enabled: @@ -107,28 +115,28 @@ class PluginInstance: if config_class: try: base = await self.loader.read_file("base-config.yaml") - base_file = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap) + self.base_cfg = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap) except (FileNotFoundError, KeyError): - base_file = None - self.config = config_class(self.load_config, lambda: base_file, self.save_config) + self.base_cfg = None + self.config = config_class(self.load_config, lambda: self.base_cfg, self.save_config) self.plugin = cls(self.client.client, self.loop, self.client.http_client, self.id, self.log, self.config, self.mb_config["plugin_directories.db"]) try: await self.plugin.start() except Exception: self.log.exception("Failed to start instance") - self.enabled = False + self.db_instance.enabled = False return - self.running = True + self.started = True self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} " f"with user {self.client.id}") async def stop(self) -> None: - if not self.running: + if not self.started: self.log.warning("Ignoring stop() call to non-running plugin") return self.log.debug("Stopping plugin instance...") - self.running = False + self.started = False try: await self.plugin.stop() except Exception: @@ -150,6 +158,37 @@ class PluginInstance: def all(cls) -> List['PluginInstance']: return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()] + def update_id(self, new_id: str) -> None: + if new_id is not None and new_id != self.id: + self.db_instance.id = new_id + + def update_config(self, config: str) -> None: + if not config or self.db_instance.config == config: + return + self.db_instance.config = config + if self.started and self.plugin is not None: + self.plugin.on_external_config_update() + + async def update_primary_user(self, primary_user: UserID) -> bool: + client = Client.get(primary_user) + if not client: + return False + await self.stop() + self.db_instance.primary_user = client.id + self.client.references.remove(self) + self.client = client + await self.start() + self.log.debug(f"Primary user switched to {self.client.id}") + return True + + async def update_started(self, started: bool) -> None: + if started is not None and started != self.started: + await (self.start() if started else self.stop()) + + def update_enabled(self, enabled: bool) -> None: + if enabled is not None and enabled != self.enabled: + self.db_instance.enabled = enabled + # region Properties @property @@ -168,22 +207,15 @@ class PluginInstance: def enabled(self) -> bool: return self.db_instance.enabled - @enabled.setter - def enabled(self, value: bool) -> None: - self.db_instance.enabled = value - @property def primary_user(self) -> UserID: return self.db_instance.primary_user - @primary_user.setter - def primary_user(self, value: UserID) -> None: - self.db_instance.primary_user = value - # endregion -def init(db: Session, config: Config, loop: AbstractEventLoop): +def init(db: Session, config: Config, loop: AbstractEventLoop) -> List[PluginInstance]: PluginInstance.db = db PluginInstance.mb_config = config PluginInstance.loop = loop + return PluginInstance.all() diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py index 4dbb299..111e469 100644 --- a/maubot/loader/abc.py +++ b/maubot/loader/abc.py @@ -61,7 +61,7 @@ class PluginLoader(ABC): async def stop_instances(self) -> None: await asyncio.gather(*[instance.stop() for instance - in self.references if instance.running]) + in self.references if instance.started]) async def start_instances(self) -> None: await asyncio.gather(*[instance.start() for instance diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py index a4f1d6e..b8c54ea 100644 --- a/maubot/loader/zip.py +++ b/maubot/loader/zip.py @@ -23,6 +23,7 @@ import os from ..lib.zipimport import zipimporter, ZipImportError from ..plugin_base import Plugin +from ..config import Config from .abc import PluginLoader, PluginClass, IDConflictError @@ -264,3 +265,9 @@ class ZippedPluginLoader(PluginLoader): except IDConflictError: cls.log.error(f"Duplicate plugin ID at {path}, trashing...") cls.trash(path) + + +def init(config: Config) -> None: + ZippedPluginLoader.trash_path = config["plugin_directories.trash"] + ZippedPluginLoader.directories = config["plugin_directories.load"] + ZippedPluginLoader.load_all() diff --git a/maubot/management/api/__init__.py b/maubot/management/api/__init__.py index e66b527..d8d1917 100644 --- a/maubot/management/api/__init__.py +++ b/maubot/management/api/__init__.py @@ -21,6 +21,8 @@ from .base import routes, set_config from .middleware import auth, error from .auth import web as _ from .plugin import web as _ +from .instance import web as _ +from .client import web as _ def init(cfg: Config, loop: AbstractEventLoop) -> web.Application: diff --git a/maubot/management/api/client.py b/maubot/management/api/client.py index a7ec14e..8f5eed0 100644 --- a/maubot/management/api/client.py +++ b/maubot/management/api/client.py @@ -13,27 +13,57 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from json import JSONDecodeError + from aiohttp import web +from mautrix.types import UserID + +from ...client import Client from .base import routes -from .responses import ErrNotImplemented +from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON @routes.get("/clients") -def get_clients(request: web.Request) -> web.Response: - return ErrNotImplemented +async def get_clients(request: web.Request) -> web.Response: + return web.json_response([client.to_dict() for client in Client.cache.values()]) @routes.get("/client/{id}") -def get_client(request: web.Request) -> web.Response: +async def get_client(request: web.Request) -> web.Response: + user_id = request.match_info.get("id", None) + client = Client.get(user_id, None) + if not client: + return ErrClientNotFound + return web.json_response(client.to_dict()) + + +async def create_client(user_id: UserID, data: dict) -> web.Response: + return ErrNotImplemented + + +async def update_client(client: Client, data: dict) -> web.Response: return ErrNotImplemented @routes.put("/client/{id}") -def update_client(request: web.Request) -> web.Response: - return ErrNotImplemented +async def update_client(request: web.Request) -> web.Response: + user_id = request.match_info.get("id", None) + client = Client.get(user_id, None) + try: + data = await request.json() + except JSONDecodeError: + return ErrBodyNotJSON + if not client: + return await create_client(user_id, data) + else: + return await update_client(client, data) @routes.delete("/client/{id}") -def delete_client(request: web.Request) -> web.Response: +async def delete_client(request: web.Request) -> web.Response: + user_id = request.match_info.get("id", None) + client = Client.get(user_id, None) + if not client: + return ErrClientNotFound return ErrNotImplemented diff --git a/maubot/management/api/instance.py b/maubot/management/api/instance.py index 064c3f8..166108a 100644 --- a/maubot/management/api/instance.py +++ b/maubot/management/api/instance.py @@ -13,57 +13,88 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from aiohttp import web from json import JSONDecodeError -from mautrix.types import UserID +from aiohttp import web -from ...db import DBClient +from ...db import DBPlugin +from ...instance import PluginInstance +from ...loader import PluginLoader from ...client import Client from .base import routes -from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON +from .responses import (ErrInstanceNotFound, ErrBodyNotJSON, RespDeleted, ErrPrimaryUserNotFound, + ErrPluginTypeRequired, ErrPrimaryUserRequired, ErrPluginTypeNotFound) @routes.get("/instances") async def get_instances(_: web.Request) -> web.Response: - return web.json_response([client.to_dict() for client in Client.cache.values()]) + return web.json_response([instance.to_dict() for instance in PluginInstance.cache.values()]) @routes.get("/instance/{id}") async def get_instance(request: web.Request) -> web.Response: - user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) - if not client: - return ErrClientNotFound - return web.json_response(client.to_dict()) + instance_id = request.match_info.get("id", "").lower() + instance = PluginInstance.get(instance_id, None) + if not instance: + return ErrInstanceNotFound + return web.json_response(instance.to_dict()) -async def create_instance(user_id: UserID, data: dict) -> web.Response: - return ErrNotImplemented +async def create_instance(instance_id: str, data: dict) -> web.Response: + plugin_type = data.get("type", None) + primary_user = data.get("primary_user", None) + if not plugin_type: + return ErrPluginTypeRequired + elif not primary_user: + return ErrPrimaryUserRequired + elif not Client.get(primary_user): + return ErrPrimaryUserNotFound + try: + PluginLoader.find(plugin_type) + except KeyError: + return ErrPluginTypeNotFound + db_instance = DBPlugin(id=instance_id, type=plugin_type, enabled=data.get("enabled", True), + primary_user=primary_user, config=data.get("config", "")) + instance = PluginInstance(db_instance) + instance.load() + PluginInstance.db.add(db_instance) + PluginInstance.db.commit() + await instance.start() + return web.json_response(instance.to_dict()) -async def update_instance(client: Client, data: dict) -> web.Response: - return ErrNotImplemented +async def update_instance(instance: PluginInstance, data: dict) -> web.Response: + if not await instance.update_primary_user(data.get("primary_user")): + return ErrPrimaryUserNotFound + instance.update_id(data.get("id", None)) + instance.update_enabled(data.get("enabled", None)) + instance.update_config(data.get("config", None)) + await instance.update_started(data.get("started", None)) + instance.db.commit() + return web.json_response(instance.to_dict()) @routes.put("/instance/{id}") async def update_instance(request: web.Request) -> web.Response: - user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) + instance_id = request.match_info.get("id", "").lower() + instance = PluginInstance.get(instance_id, None) try: data = await request.json() except JSONDecodeError: return ErrBodyNotJSON - if not client: - return await create_instance(user_id, data) + if not instance: + return await create_instance(instance_id, data) else: - return await update_instance(client, data) + return await update_instance(instance, data) @routes.delete("/instance/{id}") async def delete_instance(request: web.Request) -> web.Response: - user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) - if not client: - return ErrClientNotFound - return ErrNotImplemented + instance_id = request.match_info.get("id", "").lower() + instance = PluginInstance.get(instance_id, None) + if not instance: + return ErrInstanceNotFound + if instance.started: + await instance.stop() + instance.delete() + return RespDeleted diff --git a/maubot/management/api/responses.py b/maubot/management/api/responses.py index ad3033b..16efecb 100644 --- a/maubot/management/api/responses.py +++ b/maubot/management/api/responses.py @@ -41,6 +41,31 @@ ErrClientNotFound = web.json_response({ "errcode": "client_not_found", }, status=HTTPStatus.NOT_FOUND) +ErrPrimaryUserNotFound = web.json_response({ + "error": "Client for given primary user not found", + "errcode": "primary_user_not_found", +}, status=HTTPStatus.NOT_FOUND) + +ErrInstanceNotFound = web.json_response({ + "error": "Plugin instance not found", + "errcode": "instance_not_found", +}, status=HTTPStatus.NOT_FOUND) + +ErrPluginTypeNotFound = web.json_response({ + "error": "Given plugin type not found", + "errcode": "plugin_type_not_found", +}, status=HTTPStatus.NOT_FOUND) + +ErrPluginTypeRequired = web.json_response({ + "error": "Plugin type is required when creating plugin instances", + "errcode": "plugin_type_required", +}, status=HTTPStatus.BAD_REQUEST) + +ErrPrimaryUserRequired = web.json_response({ + "error": "Primary user is required when creating plugin instances", + "errcode": "primary_user_required", +}, status=HTTPStatus.BAD_REQUEST) + ErrPathNotFound = web.json_response({ "error": "Resource not found", "errcode": "resource_not_found", diff --git a/maubot/management/api/spec.yaml b/maubot/management/api/spec.yaml index 8cfde21..7b7ebd9 100644 --- a/maubot/management/api/spec.yaml +++ b/maubot/management/api/spec.yaml @@ -346,6 +346,9 @@ components: primary_user: type: string example: '@putkiteippi:maunium.net' + config: + type: string + example: "YAML" MatrixClient: type: object properties: diff --git a/maubot/management/frontend/src/index.js b/maubot/management/frontend/src/index.js index 294a680..8f50972 100644 --- a/maubot/management/frontend/src/index.js +++ b/maubot/management/frontend/src/index.js @@ -15,7 +15,7 @@ // along with this program. If not, see . import React from "react" import ReactDOM from "react-dom" -import "./style/base" +import "./style/index.sass" import MaubotManager from "./MaubotManager" ReactDOM.render(, document.getElementById("root")) diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index 9b394f1..3925513 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -28,7 +28,6 @@ if TYPE_CHECKING: from .command_spec import CommandSpec from mautrix.util.config import BaseProxyConfig - DatabaseNotConfigured = ValueError("A database for this maubot instance has not been configured.") @@ -69,3 +68,7 @@ class Plugin(ABC): @classmethod def get_config_class(cls) -> Optional[Type['BaseProxyConfig']]: return None + + def on_external_config_update(self) -> None: + if self.config: + self.config.load_and_update() diff --git a/setup.py b/setup.py index 0d8dad0..2995b45 100644 --- a/setup.py +++ b/setup.py @@ -48,4 +48,8 @@ setuptools.setup( data_files=[ (".", ["example-config.yaml"]), ], + package_data={ + "maubot": ["management/frontend/build/*", "management/frontend/build/static/css/*", + "management/frontend/build/static/js/*"], + }, )