From 4d8e1475e68d2b253e692458ea520930e1b449a6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 26 Mar 2022 13:59:49 +0200 Subject: [PATCH] Add support for asyncpg plugin databases --- docker/run.sh | 4 +- maubot/__main__.py | 46 ++++++++++++--- maubot/config.py | 7 ++- maubot/example-config.yaml | 24 +++++++- maubot/instance.py | 87 ++++++++++++++++++++++------ maubot/lib/plugin_db.py | 94 +++++++++++++++++++++++++++++++ maubot/loader/__init__.py | 3 +- maubot/loader/meta.py | 14 ++++- maubot/management/api/instance.py | 1 + maubot/plugin_base.py | 13 +++-- optional-requirements.txt | 3 - requirements.txt | 1 + 12 files changed, 258 insertions(+), 39 deletions(-) create mode 100644 maubot/lib/plugin_db.py diff --git a/docker/run.sh b/docker/run.sh index 9ca3a3f..2cd9b45 100755 --- a/docker/run.sh +++ b/docker/run.sh @@ -17,8 +17,8 @@ function fixconfig { fixdefault '.plugin_directories.upload' './plugins' '/data/plugins' fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins' fixdefault '.plugin_directories.trash' './trash' '/data/trash' - fixdefault '.plugin_directories.db' './plugins' '/data/dbs' - fixdefault '.plugin_directories.db' './dbs' '/data/dbs' + fixdefault '.plugin_databases.sqlite' './plugins' '/data/dbs' + fixdefault '.plugin_databases.sqlite' './dbs' '/data/dbs' fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log' # This doesn't need to be configurable yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml diff --git a/maubot/__main__.py b/maubot/__main__.py index f7865ea..4d2e24e 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -18,7 +18,7 @@ from __future__ import annotations import asyncio import sys -from mautrix.util.async_db import Database, DatabaseException +from mautrix.util.async_db import Database, DatabaseException, PostgresDatabase, Scheme from mautrix.util.program import Program from .__meta__ import __version__ @@ -43,6 +43,7 @@ class Maubot(Program): server: MaubotServer db: Database crypto_db: Database | None + plugin_postgres_db: PostgresDatabase | None state_store: PgStateStore config_class = Config @@ -71,13 +72,7 @@ class Maubot(Program): help="Run even if the database contains tables from other programs (like Synapse)", ) - def prepare(self) -> None: - super().prepare() - - if self.config["api_features.log"]: - self.prepare_log_websocket() - - init_zip_loader(self.config) + def prepare_db(self) -> None: self.db = Database.create( self.config["database"], upgrade_table=upgrade_table, @@ -86,6 +81,7 @@ class Maubot(Program): ignore_foreign_tables=self.args.ignore_foreign_tables, ) init_db(self.db) + if self.config["crypto_database"] == "default": self.crypto_db = self.db else: @@ -94,6 +90,40 @@ class Maubot(Program): upgrade_table=PgCryptoStore.upgrade_table, ignore_foreign_tables=self.args.ignore_foreign_tables, ) + + if self.config["plugin_databases.postgres"] == "default": + if self.db.scheme != Scheme.POSTGRES: + self.log.critical( + 'Using "default" as the postgres plugin database URL is only allowed if ' + "the default database is postgres." + ) + sys.exit(24) + assert isinstance(self.db, PostgresDatabase) + self.plugin_postgres_db = self.db + elif self.config["plugin_databases.postgres"]: + plugin_db = Database.create( + self.config["plugin_databases.postgres"], + db_args={ + **self.config["database_opts"], + **self.config["plugin_databases.postgres_opts"], + }, + ) + if plugin_db.scheme != Scheme.POSTGRES: + self.log.critical("The plugin postgres database URL must be a postgres database") + sys.exit(24) + assert isinstance(plugin_db, PostgresDatabase) + self.plugin_postgres_db = plugin_db + else: + self.plugin_postgres_db = None + + def prepare(self) -> None: + super().prepare() + + if self.config["api_features.log"]: + self.prepare_log_websocket() + + init_zip_loader(self.config) + self.prepare_db() Client.init_cls(self) PluginInstance.init_cls(self) management_api = init_mgmt_api(self.config, self.loop) diff --git a/maubot/config.py b/maubot/config.py index 2ee635e..e11fe1c 100644 --- a/maubot/config.py +++ b/maubot/config.py @@ -42,7 +42,12 @@ class Config(BaseFileConfig): copy("plugin_directories.upload") copy("plugin_directories.load") copy("plugin_directories.trash") - copy("plugin_directories.db") + if "plugin_directories.db" in self: + base["plugin_databases.sqlite"] = self["plugin_directories.db"] + else: + copy("plugin_databases.sqlite") + copy("plugin_databases.postgres") + copy("plugin_databases.postgres_opts") copy("server.hostname") copy("server.port") copy("server.public_url") diff --git a/maubot/example-config.yaml b/maubot/example-config.yaml index a1d6ce2..d157269 100644 --- a/maubot/example-config.yaml +++ b/maubot/example-config.yaml @@ -16,6 +16,7 @@ database_opts: min_size: 1 max_size: 10 +# Configuration for storing plugin .mbp files plugin_directories: # The directory where uploaded new plugins should be stored. upload: ./plugins @@ -26,8 +27,27 @@ plugin_directories: # The directory where old plugin versions and conflicting plugins should be moved. # Set to "delete" to delete files immediately. trash: ./trash - # The directory where plugin databases should be stored. - db: ./plugins + +# Configuration for storing plugin databases +plugin_databases: + # The directory where SQLite plugin databases should be stored. + sqlite: ./plugins + # The connection URL for plugin databases. If null, all plugins will get SQLite databases. + # If set, plugins using the new asyncpg interface will get a Postgres connection instead. + # Plugins using the legacy SQLAlchemy interface will always get a SQLite connection. + # + # To use the same connection pool as the default database, set to "default" + # (the default database above must be postgres to do this). + # + # When enabled, maubot will create separate Postgres schemas in the database for each plugin. + # To view schemas in psql, use `\dn`. To view enter and interact with a specific schema, + # use `SET search_path = name` (where `name` is the name found with `\dn`) and then use normal + # SQL queries/psql commands. + postgres: null + # Maximum number of connections per plugin instance. + postgres_max_conns_per_plugin: 3 + # Overrides for the default database_opts when using a non-"default" postgres connection string. + postgres_opts: {} server: # The IP and port to listen to. diff --git a/maubot/instance.py b/maubot/instance.py index d615a72..8d9bd6e 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -15,7 +15,7 @@ # along with this program. If not, see . from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator, cast from collections import defaultdict import asyncio import inspect @@ -28,19 +28,23 @@ from ruamel.yaml.comments import CommentedMap import sqlalchemy as sql from mautrix.types import UserID +from mautrix.util.async_db import Database, SQLiteDatabase, UpgradeTable from mautrix.util.async_getter_lock import async_getter_lock from mautrix.util.config import BaseProxyConfig, RecursiveDict +from mautrix.util.logging import TraceLogger from .client import Client from .db import Instance as DBInstance -from .loader import PluginLoader, ZippedPluginLoader +from .lib.plugin_db import ProxyPostgresDatabase +from .loader import DatabaseType, PluginLoader, ZippedPluginLoader from .plugin_base import Plugin if TYPE_CHECKING: from .__main__ import Maubot from .server import PluginWebApp -log = logging.getLogger("maubot.instance") +log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance")) +db_log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance_db")) yaml = YAML() yaml.indent(4) @@ -60,7 +64,7 @@ class PluginInstance(DBInstance): config: BaseProxyConfig | None base_cfg: RecursiveDict[CommentedMap] | None base_cfg_str: str | None - inst_db: sql.engine.Engine | None + inst_db: sql.engine.Engine | Database | None inst_db_tables: dict[str, sql.Table] | None inst_webapp: PluginWebApp | None inst_webapp_url: str | None @@ -130,8 +134,6 @@ class PluginInstance(DBInstance): self.log.error(f"Failed to get client for user {self.primary_user}") await self.update_enabled(False) return False - if self.loader.meta.database: - self.enable_database() if self.loader.meta.webapp: self.enable_webapp() self.log.debug("Plugin instance dependencies loaded") @@ -147,9 +149,9 @@ class PluginInstance(DBInstance): self.inst_webapp = None self.inst_webapp_url = None - def enable_database(self) -> None: - db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id) - self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db") + @property + def _sqlite_db_path(self) -> str: + return os.path.join(self.maubot.config["plugin_databases.sqlite"], f"{self.id}.db") async def delete(self) -> None: if self.loader is not None: @@ -162,11 +164,8 @@ class PluginInstance(DBInstance): pass await super().delete() if self.inst_db: - self.inst_db.dispose() - ZippedPluginLoader.trash( - os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"), - reason="deleted", - ) + await self.stop_database() + await self.delete_database() if self.inst_webapp: self.disable_webapp() @@ -178,6 +177,56 @@ class PluginInstance(DBInstance): yaml.dump(data, buf) self.config_str = buf.getvalue() + async def start_database( + self, upgrade_table: UpgradeTable | None = None, actually_start: bool = True + ) -> None: + if self.loader.meta.database_type == DatabaseType.SQLALCHEMY: + self.inst_db = sql.create_engine(f"sqlite:///{self._sqlite_db_path}") + elif self.loader.meta.database_type == DatabaseType.ASYNCPG: + instance_db_log = db_log.getChild(self.id) + # TODO should there be a way to choose between SQLite and Postgres + # for individual instances? Maybe checking the existence of the SQLite file. + if self.maubot.plugin_postgres_db: + self.inst_db = ProxyPostgresDatabase( + pool=self.maubot.plugin_postgres_db, + instance_id=self.id, + max_conns=self.maubot.config["plugin_databases.postgres_max_conns_per_plugin"], + upgrade_table=upgrade_table, + log=instance_db_log, + ) + else: + self.inst_db = Database.create( + f"sqlite:///{self._sqlite_db_path}", + upgrade_table=upgrade_table, + log=instance_db_log, + ) + if actually_start: + await self.inst_db.start() + else: + raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}") + + async def stop_database(self) -> None: + if isinstance(self.inst_db, Database): + await self.inst_db.stop() + elif isinstance(self.inst_db, sql.engine.Engine): + self.inst_db.dispose() + else: + raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}") + + async def delete_database(self) -> None: + if self.loader.meta.database_type == DatabaseType.SQLALCHEMY: + ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted") + elif self.loader.meta.database_type == DatabaseType.ASYNCPG: + if self.inst_db is None: + await self.start_database(None, actually_start=False) + if isinstance(self.inst_db, ProxyPostgresDatabase): + await self.inst_db.delete() + else: + ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted") + else: + raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}") + self.inst_db = None + async def start(self) -> None: if self.started: self.log.warning("Ignoring start() call to already started plugin") @@ -196,9 +245,8 @@ class PluginInstance(DBInstance): elif not self.loader.meta.webapp and self.inst_webapp is not None: self.log.debug("Disabling webapp after plugin meta reload") self.disable_webapp() - if self.loader.meta.database and self.inst_db is None: - self.log.debug("Enabling database after plugin meta reload") - self.enable_database() + if self.loader.meta.database: + await self.start_database(cls.get_db_upgrade_table()) config_class = cls.get_config_class() if config_class: try: @@ -254,6 +302,11 @@ class PluginInstance(DBInstance): except Exception: self.log.exception("Failed to stop instance") self.plugin = None + if self.inst_db: + try: + await self.stop_database() + except Exception: + self.log.exception("Failed to stop instance database") self.inst_db_tables = None async def update_id(self, new_id: str | None) -> None: diff --git a/maubot/lib/plugin_db.py b/maubot/lib/plugin_db.py new file mode 100644 index 0000000..a99d461 --- /dev/null +++ b/maubot/lib/plugin_db.py @@ -0,0 +1,94 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from __future__ import annotations + +from contextlib import asynccontextmanager +import asyncio + +from mautrix.util.async_db import Database, PostgresDatabase, Scheme, UpgradeTable +from mautrix.util.async_db.connection import LoggingConnection +from mautrix.util.logging import TraceLogger + +remove_double_quotes = str.maketrans({'"': "_"}) + + +class ProxyPostgresDatabase(Database): + scheme = Scheme.POSTGRES + _underlying_pool: PostgresDatabase + _schema: str + _default_search_path: str + _conn_sema: asyncio.Semaphore + + def __init__( + self, + pool: PostgresDatabase, + instance_id: str, + max_conns: int, + upgrade_table: UpgradeTable | None, + log: TraceLogger | None = None, + ) -> None: + super().__init__(pool.url, upgrade_table=upgrade_table, log=log) + self._underlying_pool = pool + # Simple accidental SQL injection prevention. + # Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway. + self._schema = f'"mbp_{instance_id.translate(remove_double_quotes)}"' + self._default_search_path = '"$user", public' + self._conn_sema = asyncio.BoundedSemaphore(max_conns) + + async def start(self) -> None: + async with self._underlying_pool.acquire() as conn: + self._default_search_path = await conn.fetchval("SHOW search_path") + self.log.debug(f"Found default search path: {self._default_search_path}") + await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._schema}") + await super().start() + + async def stop(self) -> None: + while not self._conn_sema.locked(): + try: + await asyncio.wait_for(self._conn_sema.acquire(), timeout=3) + except asyncio.TimeoutError: + self.log.warning( + "Failed to drain plugin database connection pool, " + "the plugin may be leaking database connections" + ) + break + + async def delete(self) -> None: + self.log.debug(f"Deleting schema {self._schema} and all data in it") + try: + await self._underlying_pool.execute(f"DROP SCHEMA IF EXISTS {self._schema} CASCADE") + except Exception: + self.log.warning("Failed to delete schema", exc_info=True) + + @asynccontextmanager + async def acquire(self) -> LoggingConnection: + conn: LoggingConnection + async with self._conn_sema, self._underlying_pool.acquire() as conn: + await conn.execute(f"SET search_path = {self._default_search_path}") + try: + yield conn + finally: + if not conn.wrapped.is_closed(): + try: + await conn.execute(f"SET search_path = {self._default_search_path}") + except Exception: + self.log.exception("Error resetting search_path after use") + await conn.wrapped.close() + else: + self.log.debug("Connection was closed after use, not resetting search_path") + + +__all__ = ["ProxyPostgresDatabase"] diff --git a/maubot/loader/__init__.py b/maubot/loader/__init__.py index c4291ed..d61be5c 100644 --- a/maubot/loader/__init__.py +++ b/maubot/loader/__init__.py @@ -1,2 +1,3 @@ -from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader, PluginMeta +from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader +from .meta import DatabaseType, PluginMeta from .zip import MaubotZipImportError, ZippedPluginLoader diff --git a/maubot/loader/meta.py b/maubot/loader/meta.py index 7d44483..f16937b 100644 --- a/maubot/loader/meta.py +++ b/maubot/loader/meta.py @@ -18,7 +18,13 @@ from typing import List from attr import dataclass from packaging.version import InvalidVersion, Version -from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer +from mautrix.types import ( + ExtensibleEnum, + SerializableAttrs, + SerializerError, + deserializer, + serializer, +) from ..__meta__ import __version__ @@ -36,6 +42,11 @@ def deserialize_version(version: str) -> Version: raise SerializerError("Invalid version") from e +class DatabaseType(ExtensibleEnum): + SQLALCHEMY = "sqlalchemy" + ASYNCPG = "asyncpg" + + @dataclass class PluginMeta(SerializableAttrs): id: str @@ -45,6 +56,7 @@ class PluginMeta(SerializableAttrs): maubot: Version = Version(__version__) database: bool = False + database_type: DatabaseType = DatabaseType.SQLALCHEMY config: bool = False webapp: bool = False license: str = "" diff --git a/maubot/management/api/instance.py b/maubot/management/api/instance.py index edc34bd..4043221 100644 --- a/maubot/management/api/instance.py +++ b/maubot/management/api/instance.py @@ -55,6 +55,7 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response: instance.enabled = data.get("enabled", True) instance.config_str = data.get("config") or "" await instance.update() + await instance.load() await instance.start() return resp.created(instance.to_dict()) diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index 3fae788..396a7b6 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -23,10 +23,11 @@ from aiohttp import ClientSession from sqlalchemy.engine.base import Engine from yarl import URL -if TYPE_CHECKING: - from mautrix.util.config import BaseProxyConfig - from mautrix.util.logging import TraceLogger +from mautrix.util.async_db import Database, UpgradeTable +from mautrix.util.config import BaseProxyConfig +from mautrix.util.logging import TraceLogger +if TYPE_CHECKING: from .client import MaubotMatrixClient from .loader import BasePluginLoader from .plugin_server import PluginWebApp @@ -40,7 +41,7 @@ class Plugin(ABC): loop: AbstractEventLoop loader: BasePluginLoader config: BaseProxyConfig | None - database: Engine | None + database: Engine | Database | None webapp: PluginWebApp | None webapp_url: URL | None @@ -124,6 +125,10 @@ class Plugin(ABC): def get_config_class(cls) -> type[BaseProxyConfig] | None: return None + @classmethod + def get_db_upgrade_table(cls) -> UpgradeTable | None: + return None + def on_external_config_update(self) -> Awaitable[None] | None: if self.config: self.config.load_and_update() diff --git a/optional-requirements.txt b/optional-requirements.txt index f42cab6..6d87db3 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -1,9 +1,6 @@ # Format: #/name defines a new extras_require group called name # Uncommented lines after the group definition insert things into that group. -#/sqlite -aiosqlite>=0.16,<0.18 - #/encryption python-olm>=3,<4 pycryptodome>=3,<4 diff --git a/requirements.txt b/requirements.txt index dd541e4..b51244e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ aiohttp>=3,<4 yarl>=1,<2 SQLAlchemy>=1,<1.4 asyncpg>=0.20,<0.26 +aiosqlite>=0.16,<0.18 alembic>=1,<2 commonmark>=0.9,<1 ruamel.yaml>=0.15.35,<0.18