diff --git a/maubot/__main__.py b/maubot/__main__.py
index 9f8cafe..3e18a68 100644
--- a/maubot/__main__.py
+++ b/maubot/__main__.py
@@ -13,8 +13,6 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from sqlalchemy import orm
-import sqlalchemy as sql
import logging.config
import argparse
import asyncio
@@ -23,11 +21,11 @@ import copy
import sys
from .config import Config
-from .db import Base, init as init_db
+from .db import init as init_db
from .server import MaubotServer
-from .client import Client, init as init_client
-from .loader import ZippedPluginLoader
-from .instance import PluginInstance, init as init_plugin_instance_class
+from .client import Client, init as init_client_class
+from .loader.zip import init as init_zip_loader
+from .instance import init as init_plugin_instance_class
from .management.api import init as init_management
from .__meta__ import __version__
@@ -46,57 +44,48 @@ config.update()
logging.config.dictConfig(copy.deepcopy(config["logging"]))
log = logging.getLogger("maubot.init")
-log.debug(f"Initializing maubot {__version__}")
-
-db_engine: sql.engine.Engine = sql.create_engine(config["database"])
-db_factory = orm.sessionmaker(bind=db_engine)
-db_session = orm.scoping.scoped_session(db_factory)
-Base.metadata.bind = db_engine
-Base.metadata.create_all()
+log.info(f"Initializing maubot {__version__}")
loop = asyncio.get_event_loop()
-init_db(db_session)
-clients = init_client(loop)
-init_plugin_instance_class(db_session, config, loop)
+init_zip_loader(config)
+db_session = init_db(config)
+clients = init_client_class(db_session, loop)
+plugins = init_plugin_instance_class(db_session, config, loop)
management_api = init_management(config, loop)
server = MaubotServer(config, management_api, loop)
-ZippedPluginLoader.trash_path = config["plugin_directories.trash"]
-ZippedPluginLoader.directories = config["plugin_directories.load"]
-ZippedPluginLoader.load_all()
-
-plugins = PluginInstance.all()
-
for plugin in plugins:
plugin.load()
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
-stop = False
-
async def periodic_commit():
- while not stop:
+ while True:
await asyncio.sleep(60)
db_session.commit()
+periodic_commit_task: asyncio.Future = None
+
try:
- log.debug("Starting server")
+ log.info("Starting server")
loop.run_until_complete(server.start())
- log.debug("Starting clients and plugins")
+ log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
- log.debug("Startup actions complete, running forever")
- loop.run_until_complete(periodic_commit())
+ log.info("Startup actions complete, running forever")
+ periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
loop.run_forever()
except KeyboardInterrupt:
log.debug("Interrupt received, stopping HTTP clients/servers and saving database")
- stop = True
+ if periodic_commit_task is not None:
+ periodic_commit_task.cancel()
for client in Client.cache.values():
client.stop()
db_session.commit()
loop.run_until_complete(server.stop())
+ loop.close()
log.debug("Everything stopped, shutting down")
sys.exit(0)
diff --git a/maubot/client.py b/maubot/client.py
index dee1d8c..684eaea 100644
--- a/maubot/client.py
+++ b/maubot/client.py
@@ -14,10 +14,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
from typing import Dict, List, Optional, Set, TYPE_CHECKING
-from aiohttp import ClientSession
import asyncio
import logging
+from sqlalchemy.orm import Session
+from aiohttp import ClientSession
+
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
EventType, Filter, RoomFilter, RoomEventFilter)
@@ -32,6 +34,7 @@ log = logging.getLogger("maubot.client")
class Client:
+ db: Session = None
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
@@ -73,12 +76,12 @@ class Client:
user_id = await self.client.whoami()
except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client")
- self.enabled = False
+ self.db_instance.enabled = False
return
except MatrixRequestError:
if try_n >= 5:
self.log.exception("Failed to get /account/whoami, disabling client")
- self.enabled = False
+ self.db_instance.enabled = False
else:
self.log.exception(f"Failed to get /account/whoami, "
f"retrying in {(try_n + 1) * 10}s")
@@ -86,7 +89,7 @@ class Client:
return
if user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
- self.enabled = False
+ self.db_instance.enabled = False
return
if not self.filter_id:
self.filter_id = await self.client.create_filter(Filter(
@@ -100,8 +103,7 @@ class Client:
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
- if self.sync:
- self.client.start(self.filter_id)
+ self.start_sync()
self.started = True
self.log.info("Client started, starting plugin instances...")
await self.start_plugins()
@@ -110,12 +112,19 @@ class Client:
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
async def stop_plugins(self) -> None:
- await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running],
+ await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.started],
loop=self.loop)
+ def start_sync(self) -> None:
+ if self.sync:
+ self.client.start(self.filter_id)
+
+ def stop_sync(self) -> None:
+ self.client.stop()
+
def stop(self) -> None:
self.started = False
- self.client.stop()
+ self.stop_sync()
def to_dict(self) -> dict:
return {
@@ -233,7 +242,8 @@ class Client:
# endregion
-def init(loop: asyncio.AbstractEventLoop) -> List[Client]:
+def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]:
+ Client.db = db
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
return Client.all()
diff --git a/maubot/db.py b/maubot/db.py
index d53658c..2373fb2 100644
--- a/maubot/db.py
+++ b/maubot/db.py
@@ -13,12 +13,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import cast
+
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
-from sqlalchemy.orm import Query, scoped_session
+from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
+import sqlalchemy as sql
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
+from .config import Config
+
Base: declarative_base = declarative_base()
@@ -54,6 +59,14 @@ class DBClient(Base):
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
-def init(session: scoped_session) -> None:
- DBPlugin.query = session.query_property()
- DBClient.query = session.query_property()
+def init(config: Config) -> Session:
+ db_engine: sql.engine.Engine = sql.create_engine(config["database"])
+ db_factory = sessionmaker(bind=db_engine)
+ db_session = scoped_session(db_factory)
+ Base.metadata.bind = db_engine
+ Base.metadata.create_all()
+
+ DBPlugin.query = db_session.query_property()
+ DBClient.query = db_session.query_property()
+
+ return cast(Session, db_session)
diff --git a/maubot/instance.py b/maubot/instance.py
index 51c9dcd..6e3590f 100644
--- a/maubot/instance.py
+++ b/maubot/instance.py
@@ -48,13 +48,14 @@ class PluginInstance:
client: Client
plugin: Plugin
config: BaseProxyConfig
- running: bool
+ base_cfg: RecursiveDict[CommentedMap]
+ started: bool
def __init__(self, db_instance: DBPlugin):
self.db_instance = db_instance
self.log = logging.getLogger(f"maubot.plugin.{self.id}")
self.config = None
- self.running = False
+ self.started = False
self.cache[self.id] = self
def to_dict(self) -> dict:
@@ -62,7 +63,7 @@ class PluginInstance:
"id": self.id,
"type": self.type,
"enabled": self.enabled,
- "running": self.running,
+ "started": self.started,
"primary_user": self.primary_user,
}
@@ -71,19 +72,26 @@ class PluginInstance:
self.loader = PluginLoader.find(self.type)
except KeyError:
self.log.error(f"Failed to find loader for type {self.type}")
- self.enabled = False
+ self.db_instance.enabled = False
return
self.client = Client.get(self.primary_user)
if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}")
- self.enabled = False
+ self.db_instance.enabled = False
return
self.log.debug("Plugin instance dependencies loaded")
self.loader.references.add(self)
self.client.references.add(self)
def delete(self) -> None:
- self.loader.references.remove(self)
+ if self.loader is not None:
+ self.loader.references.remove(self)
+ if self.client is not None:
+ self.client.references.remove(self)
+ try:
+ del self.cache[self.id]
+ except KeyError:
+ pass
self.db.delete(self.db_instance)
# TODO delete plugin db
@@ -96,7 +104,7 @@ class PluginInstance:
self.db_instance.config = buf.getvalue()
async def start(self) -> None:
- if self.running:
+ if self.started:
self.log.warning("Ignoring start() call to already started plugin")
return
elif not self.enabled:
@@ -107,28 +115,28 @@ class PluginInstance:
if config_class:
try:
base = await self.loader.read_file("base-config.yaml")
- base_file = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
+ self.base_cfg = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
except (FileNotFoundError, KeyError):
- base_file = None
- self.config = config_class(self.load_config, lambda: base_file, self.save_config)
+ self.base_cfg = None
+ self.config = config_class(self.load_config, lambda: self.base_cfg, self.save_config)
self.plugin = cls(self.client.client, self.loop, self.client.http_client, self.id,
self.log, self.config, self.mb_config["plugin_directories.db"])
try:
await self.plugin.start()
except Exception:
self.log.exception("Failed to start instance")
- self.enabled = False
+ self.db_instance.enabled = False
return
- self.running = True
+ self.started = True
self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} "
f"with user {self.client.id}")
async def stop(self) -> None:
- if not self.running:
+ if not self.started:
self.log.warning("Ignoring stop() call to non-running plugin")
return
self.log.debug("Stopping plugin instance...")
- self.running = False
+ self.started = False
try:
await self.plugin.stop()
except Exception:
@@ -150,6 +158,37 @@ class PluginInstance:
def all(cls) -> List['PluginInstance']:
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
+ def update_id(self, new_id: str) -> None:
+ if new_id is not None and new_id != self.id:
+ self.db_instance.id = new_id
+
+ def update_config(self, config: str) -> None:
+ if not config or self.db_instance.config == config:
+ return
+ self.db_instance.config = config
+ if self.started and self.plugin is not None:
+ self.plugin.on_external_config_update()
+
+ async def update_primary_user(self, primary_user: UserID) -> bool:
+ client = Client.get(primary_user)
+ if not client:
+ return False
+ await self.stop()
+ self.db_instance.primary_user = client.id
+ self.client.references.remove(self)
+ self.client = client
+ await self.start()
+ self.log.debug(f"Primary user switched to {self.client.id}")
+ return True
+
+ async def update_started(self, started: bool) -> None:
+ if started is not None and started != self.started:
+ await (self.start() if started else self.stop())
+
+ def update_enabled(self, enabled: bool) -> None:
+ if enabled is not None and enabled != self.enabled:
+ self.db_instance.enabled = enabled
+
# region Properties
@property
@@ -168,22 +207,15 @@ class PluginInstance:
def enabled(self) -> bool:
return self.db_instance.enabled
- @enabled.setter
- def enabled(self, value: bool) -> None:
- self.db_instance.enabled = value
-
@property
def primary_user(self) -> UserID:
return self.db_instance.primary_user
- @primary_user.setter
- def primary_user(self, value: UserID) -> None:
- self.db_instance.primary_user = value
-
# endregion
-def init(db: Session, config: Config, loop: AbstractEventLoop):
+def init(db: Session, config: Config, loop: AbstractEventLoop) -> List[PluginInstance]:
PluginInstance.db = db
PluginInstance.mb_config = config
PluginInstance.loop = loop
+ return PluginInstance.all()
diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py
index 4dbb299..111e469 100644
--- a/maubot/loader/abc.py
+++ b/maubot/loader/abc.py
@@ -61,7 +61,7 @@ class PluginLoader(ABC):
async def stop_instances(self) -> None:
await asyncio.gather(*[instance.stop() for instance
- in self.references if instance.running])
+ in self.references if instance.started])
async def start_instances(self) -> None:
await asyncio.gather(*[instance.start() for instance
diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py
index a4f1d6e..b8c54ea 100644
--- a/maubot/loader/zip.py
+++ b/maubot/loader/zip.py
@@ -23,6 +23,7 @@ import os
from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin
+from ..config import Config
from .abc import PluginLoader, PluginClass, IDConflictError
@@ -264,3 +265,9 @@ class ZippedPluginLoader(PluginLoader):
except IDConflictError:
cls.log.error(f"Duplicate plugin ID at {path}, trashing...")
cls.trash(path)
+
+
+def init(config: Config) -> None:
+ ZippedPluginLoader.trash_path = config["plugin_directories.trash"]
+ ZippedPluginLoader.directories = config["plugin_directories.load"]
+ ZippedPluginLoader.load_all()
diff --git a/maubot/management/api/__init__.py b/maubot/management/api/__init__.py
index e66b527..d8d1917 100644
--- a/maubot/management/api/__init__.py
+++ b/maubot/management/api/__init__.py
@@ -21,6 +21,8 @@ from .base import routes, set_config
from .middleware import auth, error
from .auth import web as _
from .plugin import web as _
+from .instance import web as _
+from .client import web as _
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
diff --git a/maubot/management/api/client.py b/maubot/management/api/client.py
index a7ec14e..8f5eed0 100644
--- a/maubot/management/api/client.py
+++ b/maubot/management/api/client.py
@@ -13,27 +13,57 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from json import JSONDecodeError
+
from aiohttp import web
+from mautrix.types import UserID
+
+from ...client import Client
from .base import routes
-from .responses import ErrNotImplemented
+from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
@routes.get("/clients")
-def get_clients(request: web.Request) -> web.Response:
- return ErrNotImplemented
+async def get_clients(request: web.Request) -> web.Response:
+ return web.json_response([client.to_dict() for client in Client.cache.values()])
@routes.get("/client/{id}")
-def get_client(request: web.Request) -> web.Response:
+async def get_client(request: web.Request) -> web.Response:
+ user_id = request.match_info.get("id", None)
+ client = Client.get(user_id, None)
+ if not client:
+ return ErrClientNotFound
+ return web.json_response(client.to_dict())
+
+
+async def create_client(user_id: UserID, data: dict) -> web.Response:
+ return ErrNotImplemented
+
+
+async def update_client(client: Client, data: dict) -> web.Response:
return ErrNotImplemented
@routes.put("/client/{id}")
-def update_client(request: web.Request) -> web.Response:
- return ErrNotImplemented
+async def update_client(request: web.Request) -> web.Response:
+ user_id = request.match_info.get("id", None)
+ client = Client.get(user_id, None)
+ try:
+ data = await request.json()
+ except JSONDecodeError:
+ return ErrBodyNotJSON
+ if not client:
+ return await create_client(user_id, data)
+ else:
+ return await update_client(client, data)
@routes.delete("/client/{id}")
-def delete_client(request: web.Request) -> web.Response:
+async def delete_client(request: web.Request) -> web.Response:
+ user_id = request.match_info.get("id", None)
+ client = Client.get(user_id, None)
+ if not client:
+ return ErrClientNotFound
return ErrNotImplemented
diff --git a/maubot/management/api/instance.py b/maubot/management/api/instance.py
index 064c3f8..166108a 100644
--- a/maubot/management/api/instance.py
+++ b/maubot/management/api/instance.py
@@ -13,57 +13,88 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from aiohttp import web
from json import JSONDecodeError
-from mautrix.types import UserID
+from aiohttp import web
-from ...db import DBClient
+from ...db import DBPlugin
+from ...instance import PluginInstance
+from ...loader import PluginLoader
from ...client import Client
from .base import routes
-from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
+from .responses import (ErrInstanceNotFound, ErrBodyNotJSON, RespDeleted, ErrPrimaryUserNotFound,
+ ErrPluginTypeRequired, ErrPrimaryUserRequired, ErrPluginTypeNotFound)
@routes.get("/instances")
async def get_instances(_: web.Request) -> web.Response:
- return web.json_response([client.to_dict() for client in Client.cache.values()])
+ return web.json_response([instance.to_dict() for instance in PluginInstance.cache.values()])
@routes.get("/instance/{id}")
async def get_instance(request: web.Request) -> web.Response:
- user_id = request.match_info.get("id", None)
- client = Client.get(user_id, None)
- if not client:
- return ErrClientNotFound
- return web.json_response(client.to_dict())
+ instance_id = request.match_info.get("id", "").lower()
+ instance = PluginInstance.get(instance_id, None)
+ if not instance:
+ return ErrInstanceNotFound
+ return web.json_response(instance.to_dict())
-async def create_instance(user_id: UserID, data: dict) -> web.Response:
- return ErrNotImplemented
+async def create_instance(instance_id: str, data: dict) -> web.Response:
+ plugin_type = data.get("type", None)
+ primary_user = data.get("primary_user", None)
+ if not plugin_type:
+ return ErrPluginTypeRequired
+ elif not primary_user:
+ return ErrPrimaryUserRequired
+ elif not Client.get(primary_user):
+ return ErrPrimaryUserNotFound
+ try:
+ PluginLoader.find(plugin_type)
+ except KeyError:
+ return ErrPluginTypeNotFound
+ db_instance = DBPlugin(id=instance_id, type=plugin_type, enabled=data.get("enabled", True),
+ primary_user=primary_user, config=data.get("config", ""))
+ instance = PluginInstance(db_instance)
+ instance.load()
+ PluginInstance.db.add(db_instance)
+ PluginInstance.db.commit()
+ await instance.start()
+ return web.json_response(instance.to_dict())
-async def update_instance(client: Client, data: dict) -> web.Response:
- return ErrNotImplemented
+async def update_instance(instance: PluginInstance, data: dict) -> web.Response:
+ if not await instance.update_primary_user(data.get("primary_user")):
+ return ErrPrimaryUserNotFound
+ instance.update_id(data.get("id", None))
+ instance.update_enabled(data.get("enabled", None))
+ instance.update_config(data.get("config", None))
+ await instance.update_started(data.get("started", None))
+ instance.db.commit()
+ return web.json_response(instance.to_dict())
@routes.put("/instance/{id}")
async def update_instance(request: web.Request) -> web.Response:
- user_id = request.match_info.get("id", None)
- client = Client.get(user_id, None)
+ instance_id = request.match_info.get("id", "").lower()
+ instance = PluginInstance.get(instance_id, None)
try:
data = await request.json()
except JSONDecodeError:
return ErrBodyNotJSON
- if not client:
- return await create_instance(user_id, data)
+ if not instance:
+ return await create_instance(instance_id, data)
else:
- return await update_instance(client, data)
+ return await update_instance(instance, data)
@routes.delete("/instance/{id}")
async def delete_instance(request: web.Request) -> web.Response:
- user_id = request.match_info.get("id", None)
- client = Client.get(user_id, None)
- if not client:
- return ErrClientNotFound
- return ErrNotImplemented
+ instance_id = request.match_info.get("id", "").lower()
+ instance = PluginInstance.get(instance_id, None)
+ if not instance:
+ return ErrInstanceNotFound
+ if instance.started:
+ await instance.stop()
+ instance.delete()
+ return RespDeleted
diff --git a/maubot/management/api/responses.py b/maubot/management/api/responses.py
index ad3033b..16efecb 100644
--- a/maubot/management/api/responses.py
+++ b/maubot/management/api/responses.py
@@ -41,6 +41,31 @@ ErrClientNotFound = web.json_response({
"errcode": "client_not_found",
}, status=HTTPStatus.NOT_FOUND)
+ErrPrimaryUserNotFound = web.json_response({
+ "error": "Client for given primary user not found",
+ "errcode": "primary_user_not_found",
+}, status=HTTPStatus.NOT_FOUND)
+
+ErrInstanceNotFound = web.json_response({
+ "error": "Plugin instance not found",
+ "errcode": "instance_not_found",
+}, status=HTTPStatus.NOT_FOUND)
+
+ErrPluginTypeNotFound = web.json_response({
+ "error": "Given plugin type not found",
+ "errcode": "plugin_type_not_found",
+}, status=HTTPStatus.NOT_FOUND)
+
+ErrPluginTypeRequired = web.json_response({
+ "error": "Plugin type is required when creating plugin instances",
+ "errcode": "plugin_type_required",
+}, status=HTTPStatus.BAD_REQUEST)
+
+ErrPrimaryUserRequired = web.json_response({
+ "error": "Primary user is required when creating plugin instances",
+ "errcode": "primary_user_required",
+}, status=HTTPStatus.BAD_REQUEST)
+
ErrPathNotFound = web.json_response({
"error": "Resource not found",
"errcode": "resource_not_found",
diff --git a/maubot/management/api/spec.yaml b/maubot/management/api/spec.yaml
index 8cfde21..7b7ebd9 100644
--- a/maubot/management/api/spec.yaml
+++ b/maubot/management/api/spec.yaml
@@ -346,6 +346,9 @@ components:
primary_user:
type: string
example: '@putkiteippi:maunium.net'
+ config:
+ type: string
+ example: "YAML"
MatrixClient:
type: object
properties:
diff --git a/maubot/management/frontend/src/index.js b/maubot/management/frontend/src/index.js
index 294a680..8f50972 100644
--- a/maubot/management/frontend/src/index.js
+++ b/maubot/management/frontend/src/index.js
@@ -15,7 +15,7 @@
// along with this program. If not, see .
import React from "react"
import ReactDOM from "react-dom"
-import "./style/base"
+import "./style/index.sass"
import MaubotManager from "./MaubotManager"
ReactDOM.render(, document.getElementById("root"))
diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py
index 9b394f1..3925513 100644
--- a/maubot/plugin_base.py
+++ b/maubot/plugin_base.py
@@ -28,7 +28,6 @@ if TYPE_CHECKING:
from .command_spec import CommandSpec
from mautrix.util.config import BaseProxyConfig
-
DatabaseNotConfigured = ValueError("A database for this maubot instance has not been configured.")
@@ -69,3 +68,7 @@ class Plugin(ABC):
@classmethod
def get_config_class(cls) -> Optional[Type['BaseProxyConfig']]:
return None
+
+ def on_external_config_update(self) -> None:
+ if self.config:
+ self.config.load_and_update()
diff --git a/setup.py b/setup.py
index 0d8dad0..2995b45 100644
--- a/setup.py
+++ b/setup.py
@@ -48,4 +48,8 @@ setuptools.setup(
data_files=[
(".", ["example-config.yaml"]),
],
+ package_data={
+ "maubot": ["management/frontend/build/*", "management/frontend/build/static/css/*",
+ "management/frontend/build/static/js/*"],
+ },
)