Add support for asyncpg plugin databases
This commit is contained in:
parent
4b234e4d34
commit
4d8e1475e6
@ -17,8 +17,8 @@ function fixconfig {
|
||||
fixdefault '.plugin_directories.upload' './plugins' '/data/plugins'
|
||||
fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins'
|
||||
fixdefault '.plugin_directories.trash' './trash' '/data/trash'
|
||||
fixdefault '.plugin_directories.db' './plugins' '/data/dbs'
|
||||
fixdefault '.plugin_directories.db' './dbs' '/data/dbs'
|
||||
fixdefault '.plugin_databases.sqlite' './plugins' '/data/dbs'
|
||||
fixdefault '.plugin_databases.sqlite' './dbs' '/data/dbs'
|
||||
fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log'
|
||||
# This doesn't need to be configurable
|
||||
yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml
|
||||
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from mautrix.util.async_db import Database, DatabaseException
|
||||
from mautrix.util.async_db import Database, DatabaseException, PostgresDatabase, Scheme
|
||||
from mautrix.util.program import Program
|
||||
|
||||
from .__meta__ import __version__
|
||||
@ -43,6 +43,7 @@ class Maubot(Program):
|
||||
server: MaubotServer
|
||||
db: Database
|
||||
crypto_db: Database | None
|
||||
plugin_postgres_db: PostgresDatabase | None
|
||||
state_store: PgStateStore
|
||||
|
||||
config_class = Config
|
||||
@ -71,13 +72,7 @@ class Maubot(Program):
|
||||
help="Run even if the database contains tables from other programs (like Synapse)",
|
||||
)
|
||||
|
||||
def prepare(self) -> None:
|
||||
super().prepare()
|
||||
|
||||
if self.config["api_features.log"]:
|
||||
self.prepare_log_websocket()
|
||||
|
||||
init_zip_loader(self.config)
|
||||
def prepare_db(self) -> None:
|
||||
self.db = Database.create(
|
||||
self.config["database"],
|
||||
upgrade_table=upgrade_table,
|
||||
@ -86,6 +81,7 @@ class Maubot(Program):
|
||||
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
||||
)
|
||||
init_db(self.db)
|
||||
|
||||
if self.config["crypto_database"] == "default":
|
||||
self.crypto_db = self.db
|
||||
else:
|
||||
@ -94,6 +90,40 @@ class Maubot(Program):
|
||||
upgrade_table=PgCryptoStore.upgrade_table,
|
||||
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
||||
)
|
||||
|
||||
if self.config["plugin_databases.postgres"] == "default":
|
||||
if self.db.scheme != Scheme.POSTGRES:
|
||||
self.log.critical(
|
||||
'Using "default" as the postgres plugin database URL is only allowed if '
|
||||
"the default database is postgres."
|
||||
)
|
||||
sys.exit(24)
|
||||
assert isinstance(self.db, PostgresDatabase)
|
||||
self.plugin_postgres_db = self.db
|
||||
elif self.config["plugin_databases.postgres"]:
|
||||
plugin_db = Database.create(
|
||||
self.config["plugin_databases.postgres"],
|
||||
db_args={
|
||||
**self.config["database_opts"],
|
||||
**self.config["plugin_databases.postgres_opts"],
|
||||
},
|
||||
)
|
||||
if plugin_db.scheme != Scheme.POSTGRES:
|
||||
self.log.critical("The plugin postgres database URL must be a postgres database")
|
||||
sys.exit(24)
|
||||
assert isinstance(plugin_db, PostgresDatabase)
|
||||
self.plugin_postgres_db = plugin_db
|
||||
else:
|
||||
self.plugin_postgres_db = None
|
||||
|
||||
def prepare(self) -> None:
|
||||
super().prepare()
|
||||
|
||||
if self.config["api_features.log"]:
|
||||
self.prepare_log_websocket()
|
||||
|
||||
init_zip_loader(self.config)
|
||||
self.prepare_db()
|
||||
Client.init_cls(self)
|
||||
PluginInstance.init_cls(self)
|
||||
management_api = init_mgmt_api(self.config, self.loop)
|
||||
|
@ -42,7 +42,12 @@ class Config(BaseFileConfig):
|
||||
copy("plugin_directories.upload")
|
||||
copy("plugin_directories.load")
|
||||
copy("plugin_directories.trash")
|
||||
copy("plugin_directories.db")
|
||||
if "plugin_directories.db" in self:
|
||||
base["plugin_databases.sqlite"] = self["plugin_directories.db"]
|
||||
else:
|
||||
copy("plugin_databases.sqlite")
|
||||
copy("plugin_databases.postgres")
|
||||
copy("plugin_databases.postgres_opts")
|
||||
copy("server.hostname")
|
||||
copy("server.port")
|
||||
copy("server.public_url")
|
||||
|
@ -16,6 +16,7 @@ database_opts:
|
||||
min_size: 1
|
||||
max_size: 10
|
||||
|
||||
# Configuration for storing plugin .mbp files
|
||||
plugin_directories:
|
||||
# The directory where uploaded new plugins should be stored.
|
||||
upload: ./plugins
|
||||
@ -26,8 +27,27 @@ plugin_directories:
|
||||
# The directory where old plugin versions and conflicting plugins should be moved.
|
||||
# Set to "delete" to delete files immediately.
|
||||
trash: ./trash
|
||||
# The directory where plugin databases should be stored.
|
||||
db: ./plugins
|
||||
|
||||
# Configuration for storing plugin databases
|
||||
plugin_databases:
|
||||
# The directory where SQLite plugin databases should be stored.
|
||||
sqlite: ./plugins
|
||||
# The connection URL for plugin databases. If null, all plugins will get SQLite databases.
|
||||
# If set, plugins using the new asyncpg interface will get a Postgres connection instead.
|
||||
# Plugins using the legacy SQLAlchemy interface will always get a SQLite connection.
|
||||
#
|
||||
# To use the same connection pool as the default database, set to "default"
|
||||
# (the default database above must be postgres to do this).
|
||||
#
|
||||
# When enabled, maubot will create separate Postgres schemas in the database for each plugin.
|
||||
# To view schemas in psql, use `\dn`. To view enter and interact with a specific schema,
|
||||
# use `SET search_path = name` (where `name` is the name found with `\dn`) and then use normal
|
||||
# SQL queries/psql commands.
|
||||
postgres: null
|
||||
# Maximum number of connections per plugin instance.
|
||||
postgres_max_conns_per_plugin: 3
|
||||
# Overrides for the default database_opts when using a non-"default" postgres connection string.
|
||||
postgres_opts: {}
|
||||
|
||||
server:
|
||||
# The IP and port to listen to.
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
import inspect
|
||||
@ -28,19 +28,23 @@ from ruamel.yaml.comments import CommentedMap
|
||||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.types import UserID
|
||||
from mautrix.util.async_db import Database, SQLiteDatabase, UpgradeTable
|
||||
from mautrix.util.async_getter_lock import async_getter_lock
|
||||
from mautrix.util.config import BaseProxyConfig, RecursiveDict
|
||||
from mautrix.util.logging import TraceLogger
|
||||
|
||||
from .client import Client
|
||||
from .db import Instance as DBInstance
|
||||
from .loader import PluginLoader, ZippedPluginLoader
|
||||
from .lib.plugin_db import ProxyPostgresDatabase
|
||||
from .loader import DatabaseType, PluginLoader, ZippedPluginLoader
|
||||
from .plugin_base import Plugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .__main__ import Maubot
|
||||
from .server import PluginWebApp
|
||||
|
||||
log = logging.getLogger("maubot.instance")
|
||||
log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance"))
|
||||
db_log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance_db"))
|
||||
|
||||
yaml = YAML()
|
||||
yaml.indent(4)
|
||||
@ -60,7 +64,7 @@ class PluginInstance(DBInstance):
|
||||
config: BaseProxyConfig | None
|
||||
base_cfg: RecursiveDict[CommentedMap] | None
|
||||
base_cfg_str: str | None
|
||||
inst_db: sql.engine.Engine | None
|
||||
inst_db: sql.engine.Engine | Database | None
|
||||
inst_db_tables: dict[str, sql.Table] | None
|
||||
inst_webapp: PluginWebApp | None
|
||||
inst_webapp_url: str | None
|
||||
@ -130,8 +134,6 @@ class PluginInstance(DBInstance):
|
||||
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||
await self.update_enabled(False)
|
||||
return False
|
||||
if self.loader.meta.database:
|
||||
self.enable_database()
|
||||
if self.loader.meta.webapp:
|
||||
self.enable_webapp()
|
||||
self.log.debug("Plugin instance dependencies loaded")
|
||||
@ -147,9 +149,9 @@ class PluginInstance(DBInstance):
|
||||
self.inst_webapp = None
|
||||
self.inst_webapp_url = None
|
||||
|
||||
def enable_database(self) -> None:
|
||||
db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
|
||||
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
|
||||
@property
|
||||
def _sqlite_db_path(self) -> str:
|
||||
return os.path.join(self.maubot.config["plugin_databases.sqlite"], f"{self.id}.db")
|
||||
|
||||
async def delete(self) -> None:
|
||||
if self.loader is not None:
|
||||
@ -162,11 +164,8 @@ class PluginInstance(DBInstance):
|
||||
pass
|
||||
await super().delete()
|
||||
if self.inst_db:
|
||||
self.inst_db.dispose()
|
||||
ZippedPluginLoader.trash(
|
||||
os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
|
||||
reason="deleted",
|
||||
)
|
||||
await self.stop_database()
|
||||
await self.delete_database()
|
||||
if self.inst_webapp:
|
||||
self.disable_webapp()
|
||||
|
||||
@ -178,6 +177,56 @@ class PluginInstance(DBInstance):
|
||||
yaml.dump(data, buf)
|
||||
self.config_str = buf.getvalue()
|
||||
|
||||
async def start_database(
|
||||
self, upgrade_table: UpgradeTable | None = None, actually_start: bool = True
|
||||
) -> None:
|
||||
if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
|
||||
self.inst_db = sql.create_engine(f"sqlite:///{self._sqlite_db_path}")
|
||||
elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
|
||||
instance_db_log = db_log.getChild(self.id)
|
||||
# TODO should there be a way to choose between SQLite and Postgres
|
||||
# for individual instances? Maybe checking the existence of the SQLite file.
|
||||
if self.maubot.plugin_postgres_db:
|
||||
self.inst_db = ProxyPostgresDatabase(
|
||||
pool=self.maubot.plugin_postgres_db,
|
||||
instance_id=self.id,
|
||||
max_conns=self.maubot.config["plugin_databases.postgres_max_conns_per_plugin"],
|
||||
upgrade_table=upgrade_table,
|
||||
log=instance_db_log,
|
||||
)
|
||||
else:
|
||||
self.inst_db = Database.create(
|
||||
f"sqlite:///{self._sqlite_db_path}",
|
||||
upgrade_table=upgrade_table,
|
||||
log=instance_db_log,
|
||||
)
|
||||
if actually_start:
|
||||
await self.inst_db.start()
|
||||
else:
|
||||
raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
|
||||
|
||||
async def stop_database(self) -> None:
|
||||
if isinstance(self.inst_db, Database):
|
||||
await self.inst_db.stop()
|
||||
elif isinstance(self.inst_db, sql.engine.Engine):
|
||||
self.inst_db.dispose()
|
||||
else:
|
||||
raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}")
|
||||
|
||||
async def delete_database(self) -> None:
|
||||
if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
|
||||
ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
|
||||
elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
|
||||
if self.inst_db is None:
|
||||
await self.start_database(None, actually_start=False)
|
||||
if isinstance(self.inst_db, ProxyPostgresDatabase):
|
||||
await self.inst_db.delete()
|
||||
else:
|
||||
ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
|
||||
else:
|
||||
raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
|
||||
self.inst_db = None
|
||||
|
||||
async def start(self) -> None:
|
||||
if self.started:
|
||||
self.log.warning("Ignoring start() call to already started plugin")
|
||||
@ -196,9 +245,8 @@ class PluginInstance(DBInstance):
|
||||
elif not self.loader.meta.webapp and self.inst_webapp is not None:
|
||||
self.log.debug("Disabling webapp after plugin meta reload")
|
||||
self.disable_webapp()
|
||||
if self.loader.meta.database and self.inst_db is None:
|
||||
self.log.debug("Enabling database after plugin meta reload")
|
||||
self.enable_database()
|
||||
if self.loader.meta.database:
|
||||
await self.start_database(cls.get_db_upgrade_table())
|
||||
config_class = cls.get_config_class()
|
||||
if config_class:
|
||||
try:
|
||||
@ -254,6 +302,11 @@ class PluginInstance(DBInstance):
|
||||
except Exception:
|
||||
self.log.exception("Failed to stop instance")
|
||||
self.plugin = None
|
||||
if self.inst_db:
|
||||
try:
|
||||
await self.stop_database()
|
||||
except Exception:
|
||||
self.log.exception("Failed to stop instance database")
|
||||
self.inst_db_tables = None
|
||||
|
||||
async def update_id(self, new_id: str | None) -> None:
|
||||
|
94
maubot/lib/plugin_db.py
Normal file
94
maubot/lib/plugin_db.py
Normal file
@ -0,0 +1,94 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
from mautrix.util.async_db import Database, PostgresDatabase, Scheme, UpgradeTable
|
||||
from mautrix.util.async_db.connection import LoggingConnection
|
||||
from mautrix.util.logging import TraceLogger
|
||||
|
||||
remove_double_quotes = str.maketrans({'"': "_"})
|
||||
|
||||
|
||||
class ProxyPostgresDatabase(Database):
|
||||
scheme = Scheme.POSTGRES
|
||||
_underlying_pool: PostgresDatabase
|
||||
_schema: str
|
||||
_default_search_path: str
|
||||
_conn_sema: asyncio.Semaphore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pool: PostgresDatabase,
|
||||
instance_id: str,
|
||||
max_conns: int,
|
||||
upgrade_table: UpgradeTable | None,
|
||||
log: TraceLogger | None = None,
|
||||
) -> None:
|
||||
super().__init__(pool.url, upgrade_table=upgrade_table, log=log)
|
||||
self._underlying_pool = pool
|
||||
# Simple accidental SQL injection prevention.
|
||||
# Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway.
|
||||
self._schema = f'"mbp_{instance_id.translate(remove_double_quotes)}"'
|
||||
self._default_search_path = '"$user", public'
|
||||
self._conn_sema = asyncio.BoundedSemaphore(max_conns)
|
||||
|
||||
async def start(self) -> None:
|
||||
async with self._underlying_pool.acquire() as conn:
|
||||
self._default_search_path = await conn.fetchval("SHOW search_path")
|
||||
self.log.debug(f"Found default search path: {self._default_search_path}")
|
||||
await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._schema}")
|
||||
await super().start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
while not self._conn_sema.locked():
|
||||
try:
|
||||
await asyncio.wait_for(self._conn_sema.acquire(), timeout=3)
|
||||
except asyncio.TimeoutError:
|
||||
self.log.warning(
|
||||
"Failed to drain plugin database connection pool, "
|
||||
"the plugin may be leaking database connections"
|
||||
)
|
||||
break
|
||||
|
||||
async def delete(self) -> None:
|
||||
self.log.debug(f"Deleting schema {self._schema} and all data in it")
|
||||
try:
|
||||
await self._underlying_pool.execute(f"DROP SCHEMA IF EXISTS {self._schema} CASCADE")
|
||||
except Exception:
|
||||
self.log.warning("Failed to delete schema", exc_info=True)
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire(self) -> LoggingConnection:
|
||||
conn: LoggingConnection
|
||||
async with self._conn_sema, self._underlying_pool.acquire() as conn:
|
||||
await conn.execute(f"SET search_path = {self._default_search_path}")
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
if not conn.wrapped.is_closed():
|
||||
try:
|
||||
await conn.execute(f"SET search_path = {self._default_search_path}")
|
||||
except Exception:
|
||||
self.log.exception("Error resetting search_path after use")
|
||||
await conn.wrapped.close()
|
||||
else:
|
||||
self.log.debug("Connection was closed after use, not resetting search_path")
|
||||
|
||||
|
||||
__all__ = ["ProxyPostgresDatabase"]
|
@ -1,2 +1,3 @@
|
||||
from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader, PluginMeta
|
||||
from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader
|
||||
from .meta import DatabaseType, PluginMeta
|
||||
from .zip import MaubotZipImportError, ZippedPluginLoader
|
||||
|
@ -18,7 +18,13 @@ from typing import List
|
||||
from attr import dataclass
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
|
||||
from mautrix.types import (
|
||||
ExtensibleEnum,
|
||||
SerializableAttrs,
|
||||
SerializerError,
|
||||
deserializer,
|
||||
serializer,
|
||||
)
|
||||
|
||||
from ..__meta__ import __version__
|
||||
|
||||
@ -36,6 +42,11 @@ def deserialize_version(version: str) -> Version:
|
||||
raise SerializerError("Invalid version") from e
|
||||
|
||||
|
||||
class DatabaseType(ExtensibleEnum):
|
||||
SQLALCHEMY = "sqlalchemy"
|
||||
ASYNCPG = "asyncpg"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginMeta(SerializableAttrs):
|
||||
id: str
|
||||
@ -45,6 +56,7 @@ class PluginMeta(SerializableAttrs):
|
||||
|
||||
maubot: Version = Version(__version__)
|
||||
database: bool = False
|
||||
database_type: DatabaseType = DatabaseType.SQLALCHEMY
|
||||
config: bool = False
|
||||
webapp: bool = False
|
||||
license: str = ""
|
||||
|
@ -55,6 +55,7 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
|
||||
instance.enabled = data.get("enabled", True)
|
||||
instance.config_str = data.get("config") or ""
|
||||
await instance.update()
|
||||
await instance.load()
|
||||
await instance.start()
|
||||
return resp.created(instance.to_dict())
|
||||
|
||||
|
@ -23,10 +23,11 @@ from aiohttp import ClientSession
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from yarl import URL
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mautrix.util.config import BaseProxyConfig
|
||||
from mautrix.util.logging import TraceLogger
|
||||
from mautrix.util.async_db import Database, UpgradeTable
|
||||
from mautrix.util.config import BaseProxyConfig
|
||||
from mautrix.util.logging import TraceLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import MaubotMatrixClient
|
||||
from .loader import BasePluginLoader
|
||||
from .plugin_server import PluginWebApp
|
||||
@ -40,7 +41,7 @@ class Plugin(ABC):
|
||||
loop: AbstractEventLoop
|
||||
loader: BasePluginLoader
|
||||
config: BaseProxyConfig | None
|
||||
database: Engine | None
|
||||
database: Engine | Database | None
|
||||
webapp: PluginWebApp | None
|
||||
webapp_url: URL | None
|
||||
|
||||
@ -124,6 +125,10 @@ class Plugin(ABC):
|
||||
def get_config_class(cls) -> type[BaseProxyConfig] | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_db_upgrade_table(cls) -> UpgradeTable | None:
|
||||
return None
|
||||
|
||||
def on_external_config_update(self) -> Awaitable[None] | None:
|
||||
if self.config:
|
||||
self.config.load_and_update()
|
||||
|
@ -1,9 +1,6 @@
|
||||
# Format: #/name defines a new extras_require group called name
|
||||
# Uncommented lines after the group definition insert things into that group.
|
||||
|
||||
#/sqlite
|
||||
aiosqlite>=0.16,<0.18
|
||||
|
||||
#/encryption
|
||||
python-olm>=3,<4
|
||||
pycryptodome>=3,<4
|
||||
|
@ -3,6 +3,7 @@ aiohttp>=3,<4
|
||||
yarl>=1,<2
|
||||
SQLAlchemy>=1,<1.4
|
||||
asyncpg>=0.20,<0.26
|
||||
aiosqlite>=0.16,<0.18
|
||||
alembic>=1,<2
|
||||
commonmark>=0.9,<1
|
||||
ruamel.yaml>=0.15.35,<0.18
|
||||
|
Loading…
Reference in New Issue
Block a user