diff --git a/docker/run.sh b/docker/run.sh
index 9ca3a3f..2cd9b45 100755
--- a/docker/run.sh
+++ b/docker/run.sh
@@ -17,8 +17,8 @@ function fixconfig {
fixdefault '.plugin_directories.upload' './plugins' '/data/plugins'
fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins'
fixdefault '.plugin_directories.trash' './trash' '/data/trash'
- fixdefault '.plugin_directories.db' './plugins' '/data/dbs'
- fixdefault '.plugin_directories.db' './dbs' '/data/dbs'
+ fixdefault '.plugin_databases.sqlite' './plugins' '/data/dbs'
+ fixdefault '.plugin_databases.sqlite' './dbs' '/data/dbs'
fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log'
# This doesn't need to be configurable
yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml
diff --git a/maubot/__main__.py b/maubot/__main__.py
index f7865ea..4d2e24e 100644
--- a/maubot/__main__.py
+++ b/maubot/__main__.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import asyncio
import sys
-from mautrix.util.async_db import Database, DatabaseException
+from mautrix.util.async_db import Database, DatabaseException, PostgresDatabase, Scheme
from mautrix.util.program import Program
from .__meta__ import __version__
@@ -43,6 +43,7 @@ class Maubot(Program):
server: MaubotServer
db: Database
crypto_db: Database | None
+ plugin_postgres_db: PostgresDatabase | None
state_store: PgStateStore
config_class = Config
@@ -71,13 +72,7 @@ class Maubot(Program):
help="Run even if the database contains tables from other programs (like Synapse)",
)
- def prepare(self) -> None:
- super().prepare()
-
- if self.config["api_features.log"]:
- self.prepare_log_websocket()
-
- init_zip_loader(self.config)
+ def prepare_db(self) -> None:
self.db = Database.create(
self.config["database"],
upgrade_table=upgrade_table,
@@ -86,6 +81,7 @@ class Maubot(Program):
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
init_db(self.db)
+
if self.config["crypto_database"] == "default":
self.crypto_db = self.db
else:
@@ -94,6 +90,40 @@ class Maubot(Program):
upgrade_table=PgCryptoStore.upgrade_table,
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
+
+ if self.config["plugin_databases.postgres"] == "default":
+ if self.db.scheme != Scheme.POSTGRES:
+ self.log.critical(
+ 'Using "default" as the postgres plugin database URL is only allowed if '
+ "the default database is postgres."
+ )
+ sys.exit(24)
+ assert isinstance(self.db, PostgresDatabase)
+ self.plugin_postgres_db = self.db
+ elif self.config["plugin_databases.postgres"]:
+ plugin_db = Database.create(
+ self.config["plugin_databases.postgres"],
+ db_args={
+ **self.config["database_opts"],
+ **self.config["plugin_databases.postgres_opts"],
+ },
+ )
+ if plugin_db.scheme != Scheme.POSTGRES:
+ self.log.critical("The plugin postgres database URL must be a postgres database")
+ sys.exit(24)
+ assert isinstance(plugin_db, PostgresDatabase)
+ self.plugin_postgres_db = plugin_db
+ else:
+ self.plugin_postgres_db = None
+
+ def prepare(self) -> None:
+ super().prepare()
+
+ if self.config["api_features.log"]:
+ self.prepare_log_websocket()
+
+ init_zip_loader(self.config)
+ self.prepare_db()
Client.init_cls(self)
PluginInstance.init_cls(self)
management_api = init_mgmt_api(self.config, self.loop)
diff --git a/maubot/config.py b/maubot/config.py
index 2ee635e..e11fe1c 100644
--- a/maubot/config.py
+++ b/maubot/config.py
@@ -42,7 +42,12 @@ class Config(BaseFileConfig):
copy("plugin_directories.upload")
copy("plugin_directories.load")
copy("plugin_directories.trash")
- copy("plugin_directories.db")
+ if "plugin_directories.db" in self:
+ base["plugin_databases.sqlite"] = self["plugin_directories.db"]
+ else:
+ copy("plugin_databases.sqlite")
+ copy("plugin_databases.postgres")
+ copy("plugin_databases.postgres_opts")
copy("server.hostname")
copy("server.port")
copy("server.public_url")
diff --git a/maubot/example-config.yaml b/maubot/example-config.yaml
index a1d6ce2..d157269 100644
--- a/maubot/example-config.yaml
+++ b/maubot/example-config.yaml
@@ -16,6 +16,7 @@ database_opts:
min_size: 1
max_size: 10
+# Configuration for storing plugin .mbp files
plugin_directories:
# The directory where uploaded new plugins should be stored.
upload: ./plugins
@@ -26,8 +27,27 @@ plugin_directories:
# The directory where old plugin versions and conflicting plugins should be moved.
# Set to "delete" to delete files immediately.
trash: ./trash
- # The directory where plugin databases should be stored.
- db: ./plugins
+
+# Configuration for storing plugin databases
+plugin_databases:
+ # The directory where SQLite plugin databases should be stored.
+ sqlite: ./plugins
+ # The connection URL for plugin databases. If null, all plugins will get SQLite databases.
+ # If set, plugins using the new asyncpg interface will get a Postgres connection instead.
+ # Plugins using the legacy SQLAlchemy interface will always get a SQLite connection.
+ #
+ # To use the same connection pool as the default database, set to "default"
+ # (the default database above must be postgres to do this).
+ #
+ # When enabled, maubot will create separate Postgres schemas in the database for each plugin.
+ # To view schemas in psql, use `\dn`. To view enter and interact with a specific schema,
+ # use `SET search_path = name` (where `name` is the name found with `\dn`) and then use normal
+ # SQL queries/psql commands.
+ postgres: null
+ # Maximum number of connections per plugin instance.
+ postgres_max_conns_per_plugin: 3
+ # Overrides for the default database_opts when using a non-"default" postgres connection string.
+ postgres_opts: {}
server:
# The IP and port to listen to.
diff --git a/maubot/instance.py b/maubot/instance.py
index d615a72..8d9bd6e 100644
--- a/maubot/instance.py
+++ b/maubot/instance.py
@@ -15,7 +15,7 @@
# along with this program. If not, see .
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
+from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
from collections import defaultdict
import asyncio
import inspect
@@ -28,19 +28,23 @@ from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql
from mautrix.types import UserID
+from mautrix.util.async_db import Database, SQLiteDatabase, UpgradeTable
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.config import BaseProxyConfig, RecursiveDict
+from mautrix.util.logging import TraceLogger
from .client import Client
from .db import Instance as DBInstance
-from .loader import PluginLoader, ZippedPluginLoader
+from .lib.plugin_db import ProxyPostgresDatabase
+from .loader import DatabaseType, PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin
if TYPE_CHECKING:
from .__main__ import Maubot
from .server import PluginWebApp
-log = logging.getLogger("maubot.instance")
+log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance"))
+db_log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance_db"))
yaml = YAML()
yaml.indent(4)
@@ -60,7 +64,7 @@ class PluginInstance(DBInstance):
config: BaseProxyConfig | None
base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: str | None
- inst_db: sql.engine.Engine | None
+ inst_db: sql.engine.Engine | Database | None
inst_db_tables: dict[str, sql.Table] | None
inst_webapp: PluginWebApp | None
inst_webapp_url: str | None
@@ -130,8 +134,6 @@ class PluginInstance(DBInstance):
self.log.error(f"Failed to get client for user {self.primary_user}")
await self.update_enabled(False)
return False
- if self.loader.meta.database:
- self.enable_database()
if self.loader.meta.webapp:
self.enable_webapp()
self.log.debug("Plugin instance dependencies loaded")
@@ -147,9 +149,9 @@ class PluginInstance(DBInstance):
self.inst_webapp = None
self.inst_webapp_url = None
- def enable_database(self) -> None:
- db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
- self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
+ @property
+ def _sqlite_db_path(self) -> str:
+ return os.path.join(self.maubot.config["plugin_databases.sqlite"], f"{self.id}.db")
async def delete(self) -> None:
if self.loader is not None:
@@ -162,11 +164,8 @@ class PluginInstance(DBInstance):
pass
await super().delete()
if self.inst_db:
- self.inst_db.dispose()
- ZippedPluginLoader.trash(
- os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
- reason="deleted",
- )
+ await self.stop_database()
+ await self.delete_database()
if self.inst_webapp:
self.disable_webapp()
@@ -178,6 +177,56 @@ class PluginInstance(DBInstance):
yaml.dump(data, buf)
self.config_str = buf.getvalue()
+ async def start_database(
+ self, upgrade_table: UpgradeTable | None = None, actually_start: bool = True
+ ) -> None:
+ if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
+ self.inst_db = sql.create_engine(f"sqlite:///{self._sqlite_db_path}")
+ elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
+ instance_db_log = db_log.getChild(self.id)
+ # TODO should there be a way to choose between SQLite and Postgres
+ # for individual instances? Maybe checking the existence of the SQLite file.
+ if self.maubot.plugin_postgres_db:
+ self.inst_db = ProxyPostgresDatabase(
+ pool=self.maubot.plugin_postgres_db,
+ instance_id=self.id,
+ max_conns=self.maubot.config["plugin_databases.postgres_max_conns_per_plugin"],
+ upgrade_table=upgrade_table,
+ log=instance_db_log,
+ )
+ else:
+ self.inst_db = Database.create(
+ f"sqlite:///{self._sqlite_db_path}",
+ upgrade_table=upgrade_table,
+ log=instance_db_log,
+ )
+ if actually_start:
+ await self.inst_db.start()
+ else:
+ raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
+
+ async def stop_database(self) -> None:
+ if isinstance(self.inst_db, Database):
+ await self.inst_db.stop()
+ elif isinstance(self.inst_db, sql.engine.Engine):
+ self.inst_db.dispose()
+ else:
+ raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}")
+
+ async def delete_database(self) -> None:
+ if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
+ ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
+ elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
+ if self.inst_db is None:
+ await self.start_database(None, actually_start=False)
+ if isinstance(self.inst_db, ProxyPostgresDatabase):
+ await self.inst_db.delete()
+ else:
+ ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
+ else:
+ raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
+ self.inst_db = None
+
async def start(self) -> None:
if self.started:
self.log.warning("Ignoring start() call to already started plugin")
@@ -196,9 +245,8 @@ class PluginInstance(DBInstance):
elif not self.loader.meta.webapp and self.inst_webapp is not None:
self.log.debug("Disabling webapp after plugin meta reload")
self.disable_webapp()
- if self.loader.meta.database and self.inst_db is None:
- self.log.debug("Enabling database after plugin meta reload")
- self.enable_database()
+ if self.loader.meta.database:
+ await self.start_database(cls.get_db_upgrade_table())
config_class = cls.get_config_class()
if config_class:
try:
@@ -254,6 +302,11 @@ class PluginInstance(DBInstance):
except Exception:
self.log.exception("Failed to stop instance")
self.plugin = None
+ if self.inst_db:
+ try:
+ await self.stop_database()
+ except Exception:
+ self.log.exception("Failed to stop instance database")
self.inst_db_tables = None
async def update_id(self, new_id: str | None) -> None:
diff --git a/maubot/lib/plugin_db.py b/maubot/lib/plugin_db.py
new file mode 100644
index 0000000..a99d461
--- /dev/null
+++ b/maubot/lib/plugin_db.py
@@ -0,0 +1,94 @@
+# maubot - A plugin-based Matrix bot system.
+# Copyright (C) 2022 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+from contextlib import asynccontextmanager
+import asyncio
+
+from mautrix.util.async_db import Database, PostgresDatabase, Scheme, UpgradeTable
+from mautrix.util.async_db.connection import LoggingConnection
+from mautrix.util.logging import TraceLogger
+
+remove_double_quotes = str.maketrans({'"': "_"})
+
+
+class ProxyPostgresDatabase(Database):
+ scheme = Scheme.POSTGRES
+ _underlying_pool: PostgresDatabase
+ _schema: str
+ _default_search_path: str
+ _conn_sema: asyncio.Semaphore
+
+ def __init__(
+ self,
+ pool: PostgresDatabase,
+ instance_id: str,
+ max_conns: int,
+ upgrade_table: UpgradeTable | None,
+ log: TraceLogger | None = None,
+ ) -> None:
+ super().__init__(pool.url, upgrade_table=upgrade_table, log=log)
+ self._underlying_pool = pool
+ # Simple accidental SQL injection prevention.
+ # Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway.
+ self._schema = f'"mbp_{instance_id.translate(remove_double_quotes)}"'
+ self._default_search_path = '"$user", public'
+ self._conn_sema = asyncio.BoundedSemaphore(max_conns)
+
+ async def start(self) -> None:
+ async with self._underlying_pool.acquire() as conn:
+ self._default_search_path = await conn.fetchval("SHOW search_path")
+ self.log.debug(f"Found default search path: {self._default_search_path}")
+ await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._schema}")
+ await super().start()
+
+ async def stop(self) -> None:
+ while not self._conn_sema.locked():
+ try:
+ await asyncio.wait_for(self._conn_sema.acquire(), timeout=3)
+ except asyncio.TimeoutError:
+ self.log.warning(
+ "Failed to drain plugin database connection pool, "
+ "the plugin may be leaking database connections"
+ )
+ break
+
+ async def delete(self) -> None:
+ self.log.debug(f"Deleting schema {self._schema} and all data in it")
+ try:
+ await self._underlying_pool.execute(f"DROP SCHEMA IF EXISTS {self._schema} CASCADE")
+ except Exception:
+ self.log.warning("Failed to delete schema", exc_info=True)
+
+ @asynccontextmanager
+ async def acquire(self) -> LoggingConnection:
+ conn: LoggingConnection
+ async with self._conn_sema, self._underlying_pool.acquire() as conn:
+ await conn.execute(f"SET search_path = {self._default_search_path}")
+ try:
+ yield conn
+ finally:
+ if not conn.wrapped.is_closed():
+ try:
+ await conn.execute(f"SET search_path = {self._default_search_path}")
+ except Exception:
+ self.log.exception("Error resetting search_path after use")
+ await conn.wrapped.close()
+ else:
+ self.log.debug("Connection was closed after use, not resetting search_path")
+
+
+__all__ = ["ProxyPostgresDatabase"]
diff --git a/maubot/loader/__init__.py b/maubot/loader/__init__.py
index c4291ed..d61be5c 100644
--- a/maubot/loader/__init__.py
+++ b/maubot/loader/__init__.py
@@ -1,2 +1,3 @@
-from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader, PluginMeta
+from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader
+from .meta import DatabaseType, PluginMeta
from .zip import MaubotZipImportError, ZippedPluginLoader
diff --git a/maubot/loader/meta.py b/maubot/loader/meta.py
index 7d44483..f16937b 100644
--- a/maubot/loader/meta.py
+++ b/maubot/loader/meta.py
@@ -18,7 +18,13 @@ from typing import List
from attr import dataclass
from packaging.version import InvalidVersion, Version
-from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
+from mautrix.types import (
+ ExtensibleEnum,
+ SerializableAttrs,
+ SerializerError,
+ deserializer,
+ serializer,
+)
from ..__meta__ import __version__
@@ -36,6 +42,11 @@ def deserialize_version(version: str) -> Version:
raise SerializerError("Invalid version") from e
+class DatabaseType(ExtensibleEnum):
+ SQLALCHEMY = "sqlalchemy"
+ ASYNCPG = "asyncpg"
+
+
@dataclass
class PluginMeta(SerializableAttrs):
id: str
@@ -45,6 +56,7 @@ class PluginMeta(SerializableAttrs):
maubot: Version = Version(__version__)
database: bool = False
+ database_type: DatabaseType = DatabaseType.SQLALCHEMY
config: bool = False
webapp: bool = False
license: str = ""
diff --git a/maubot/management/api/instance.py b/maubot/management/api/instance.py
index edc34bd..4043221 100644
--- a/maubot/management/api/instance.py
+++ b/maubot/management/api/instance.py
@@ -55,6 +55,7 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
instance.enabled = data.get("enabled", True)
instance.config_str = data.get("config") or ""
await instance.update()
+ await instance.load()
await instance.start()
return resp.created(instance.to_dict())
diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py
index 3fae788..396a7b6 100644
--- a/maubot/plugin_base.py
+++ b/maubot/plugin_base.py
@@ -23,10 +23,11 @@ from aiohttp import ClientSession
from sqlalchemy.engine.base import Engine
from yarl import URL
-if TYPE_CHECKING:
- from mautrix.util.config import BaseProxyConfig
- from mautrix.util.logging import TraceLogger
+from mautrix.util.async_db import Database, UpgradeTable
+from mautrix.util.config import BaseProxyConfig
+from mautrix.util.logging import TraceLogger
+if TYPE_CHECKING:
from .client import MaubotMatrixClient
from .loader import BasePluginLoader
from .plugin_server import PluginWebApp
@@ -40,7 +41,7 @@ class Plugin(ABC):
loop: AbstractEventLoop
loader: BasePluginLoader
config: BaseProxyConfig | None
- database: Engine | None
+ database: Engine | Database | None
webapp: PluginWebApp | None
webapp_url: URL | None
@@ -124,6 +125,10 @@ class Plugin(ABC):
def get_config_class(cls) -> type[BaseProxyConfig] | None:
return None
+ @classmethod
+ def get_db_upgrade_table(cls) -> UpgradeTable | None:
+ return None
+
def on_external_config_update(self) -> Awaitable[None] | None:
if self.config:
self.config.load_and_update()
diff --git a/optional-requirements.txt b/optional-requirements.txt
index f42cab6..6d87db3 100644
--- a/optional-requirements.txt
+++ b/optional-requirements.txt
@@ -1,9 +1,6 @@
# Format: #/name defines a new extras_require group called name
# Uncommented lines after the group definition insert things into that group.
-#/sqlite
-aiosqlite>=0.16,<0.18
-
#/encryption
python-olm>=3,<4
pycryptodome>=3,<4
diff --git a/requirements.txt b/requirements.txt
index dd541e4..b51244e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ aiohttp>=3,<4
yarl>=1,<2
SQLAlchemy>=1,<1.4
asyncpg>=0.20,<0.26
+aiosqlite>=0.16,<0.18
alembic>=1,<2
commonmark>=0.9,<1
ruamel.yaml>=0.15.35,<0.18