diff --git a/maubot/instance.py b/maubot/instance.py index 2905d12..8427e3c 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -25,7 +25,6 @@ import os.path from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap -import sqlalchemy as sql from mautrix.types import UserID from mautrix.util import background_task @@ -36,6 +35,7 @@ from mautrix.util.logging import TraceLogger from .client import Client from .db import DatabaseEngine, Instance as DBInstance +from .lib.optionalalchemy import Engine, MetaData, create_engine from .lib.plugin_db import ProxyPostgresDatabase from .loader import DatabaseType, PluginLoader, ZippedPluginLoader from .plugin_base import Plugin @@ -128,7 +128,7 @@ class PluginInstance(DBInstance): } def _introspect_sqlalchemy(self) -> dict: - metadata = sql.MetaData() + metadata = MetaData() metadata.reflect(self.inst_db) return { table.name: { @@ -214,7 +214,7 @@ class PluginInstance(DBInstance): async def get_db_tables(self) -> dict: if self.inst_db_tables is None: - if isinstance(self.inst_db, sql.engine.Engine): + if isinstance(self.inst_db, Engine): self.inst_db_tables = self._introspect_sqlalchemy() elif self.inst_db.scheme == Scheme.SQLITE: self.inst_db_tables = await self._introspect_sqlite() @@ -294,7 +294,7 @@ class PluginInstance(DBInstance): "Instance database engine is marked as Postgres, but plugin uses legacy " "database interface, which doesn't support postgres." ) - self.inst_db = sql.create_engine(f"sqlite:///{self._sqlite_db_path}") + self.inst_db = create_engine(f"sqlite:///{self._sqlite_db_path}") elif self.loader.meta.database_type == DatabaseType.ASYNCPG: if self.database_engine is None: if os.path.exists(self._sqlite_db_path) or not self.maubot.plugin_postgres_db: @@ -329,7 +329,7 @@ class PluginInstance(DBInstance): async def stop_database(self) -> None: if isinstance(self.inst_db, Database): await self.inst_db.stop() - elif isinstance(self.inst_db, sql.engine.Engine): + elif isinstance(self.inst_db, Engine): self.inst_db.dispose() else: raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}") diff --git a/maubot/lib/optionalalchemy.py b/maubot/lib/optionalalchemy.py new file mode 100644 index 0000000..ba94271 --- /dev/null +++ b/maubot/lib/optionalalchemy.py @@ -0,0 +1,19 @@ +try: + from sqlalchemy import MetaData, asc, create_engine, desc + from sqlalchemy.engine import Engine + from sqlalchemy.exc import IntegrityError, OperationalError +except ImportError: + + class FakeError(Exception): + pass + + class FakeType: + def __init__(self, *args, **kwargs): + raise Exception("SQLAlchemy is not installed") + + def create_engine(*args, **kwargs): + raise Exception("SQLAlchemy is not installed") + + MetaData = Engine = FakeType + IntegrityError = OperationalError = FakeError + asc = desc = lambda a: a diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py index 183d3e6..8642183 100644 --- a/maubot/loader/zip.py +++ b/maubot/loader/zip.py @@ -31,7 +31,7 @@ from ..config import Config from ..lib.zipimport import ZipImportError, zipimporter from ..plugin_base import Plugin from .abc import IDConflictError, PluginClass, PluginLoader -from .meta import PluginMeta +from .meta import DatabaseType, PluginMeta current_version = Version(__version__) yaml = YAML() @@ -155,9 +155,9 @@ class ZippedPluginLoader(PluginLoader): return file, meta @classmethod - def verify_meta(cls, source) -> tuple[str, Version]: + def verify_meta(cls, source) -> tuple[str, Version, DatabaseType | None]: _, meta = cls._read_meta(source) - return meta.id, meta.version + return meta.id, meta.version, meta.database_type if meta.database else None def _load_meta(self) -> None: file, meta = self._read_meta(self.path) diff --git a/maubot/management/api/instance_database.py b/maubot/management/api/instance_database.py index a40434f..2f8c37a 100644 --- a/maubot/management/api/instance_database.py +++ b/maubot/management/api/instance_database.py @@ -19,12 +19,12 @@ from datetime import datetime from aiohttp import web from asyncpg import PostgresError -from sqlalchemy import asc, desc, engine, exc import aiosqlite from mautrix.util.async_db import Database from ...instance import PluginInstance +from ...lib.optionalalchemy import Engine, IntegrityError, OperationalError, asc, desc from .base import routes from .responses import resp @@ -66,7 +66,7 @@ async def get_table(request: web.Request) -> web.Response: except KeyError: order = [] limit = int(request.query.get("limit", "100")) - if isinstance(instance.inst_db, engine.Engine): + if isinstance(instance.inst_db, Engine): return _execute_query_sqlalchemy(instance, table.select().order_by(*order).limit(limit)) @@ -84,7 +84,7 @@ async def query(request: web.Request) -> web.Response: except KeyError: return resp.query_missing rows_as_dict = data.get("rows_as_dict", False) - if isinstance(instance.inst_db, engine.Engine): + if isinstance(instance.inst_db, Engine): return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict) elif isinstance(instance.inst_db, Database): try: @@ -133,12 +133,12 @@ async def _execute_query_asyncpg( def _execute_query_sqlalchemy( instance: PluginInstance, sql_query: str, rows_as_dict: bool = False ) -> web.Response: - assert isinstance(instance.inst_db, engine.Engine) + assert isinstance(instance.inst_db, Engine) try: res = instance.inst_db.execute(sql_query) - except exc.IntegrityError as e: + except IntegrityError as e: return resp.sql_integrity_error(e, sql_query) - except exc.OperationalError as e: + except OperationalError as e: return resp.sql_operational_error(e, sql_query) data = { "ok": True, diff --git a/maubot/management/api/plugin_upload.py b/maubot/management/api/plugin_upload.py index ea4fd1f..4cd2c47 100644 --- a/maubot/management/api/plugin_upload.py +++ b/maubot/management/api/plugin_upload.py @@ -23,10 +23,17 @@ import traceback from aiohttp import web from packaging.version import Version -from ...loader import MaubotZipImportError, PluginLoader, ZippedPluginLoader +from ...loader import DatabaseType, MaubotZipImportError, PluginLoader, ZippedPluginLoader from .base import get_config, routes from .responses import resp +try: + import sqlalchemy + + has_alchemy = True +except ImportError: + has_alchemy = False + log = logging.getLogger("maubot.server.upload") @@ -36,9 +43,11 @@ async def put_plugin(request: web.Request) -> web.Response: content = await request.read() file = BytesIO(content) try: - pid, version = ZippedPluginLoader.verify_meta(file) + pid, version, db_type = ZippedPluginLoader.verify_meta(file) except MaubotZipImportError as e: return resp.plugin_import_error(str(e), traceback.format_exc()) + if db_type == DatabaseType.SQLALCHEMY and not has_alchemy: + return resp.sqlalchemy_not_installed if pid != plugin_id: return resp.pid_mismatch plugin = PluginLoader.id_cache.get(plugin_id, None) @@ -55,9 +64,11 @@ async def upload_plugin(request: web.Request) -> web.Response: content = await request.read() file = BytesIO(content) try: - pid, version = ZippedPluginLoader.verify_meta(file) + pid, version, db_type = ZippedPluginLoader.verify_meta(file) except MaubotZipImportError as e: return resp.plugin_import_error(str(e), traceback.format_exc()) + if db_type == DatabaseType.SQLALCHEMY and not has_alchemy: + return resp.sqlalchemy_not_installed plugin = PluginLoader.id_cache.get(pid, None) if not plugin: return await upload_new_plugin(content, pid, version) diff --git a/maubot/management/api/responses.py b/maubot/management/api/responses.py index 15f6a96..0f22caa 100644 --- a/maubot/management/api/responses.py +++ b/maubot/management/api/responses.py @@ -15,13 +15,16 @@ # along with this program. If not, see . from __future__ import annotations +from typing import TYPE_CHECKING from http import HTTPStatus from aiohttp import web from asyncpg import PostgresError -from sqlalchemy.exc import IntegrityError, OperationalError import aiosqlite +if TYPE_CHECKING: + from sqlalchemy.exc import IntegrityError, OperationalError + class _Response: @property @@ -324,6 +327,16 @@ class _Response: } ) + @property + def sqlalchemy_not_installed(self) -> web.Response: + return web.json_response( + { + "error": "This plugin requires a legacy database, but SQLAlchemy is not installed", + "errcode": "unsupported_plugin_database", + }, + status=HTTPStatus.NOT_IMPLEMENTED, + ) + @property def table_not_found(self) -> web.Response: return web.json_response( diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index 5e967dc..1be15e0 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -20,7 +20,6 @@ from abc import ABC from asyncio import AbstractEventLoop from aiohttp import ClientSession -from sqlalchemy.engine.base import Engine from yarl import URL from mautrix.util.async_db import Database, UpgradeTable @@ -30,6 +29,8 @@ from mautrix.util.logging import TraceLogger from .scheduler import BasicScheduler if TYPE_CHECKING: + from sqlalchemy.engine.base import Engine + from .client import MaubotMatrixClient from .loader import BasePluginLoader from .plugin_server import PluginWebApp @@ -56,7 +57,7 @@ class Plugin(ABC): instance_id: str, log: TraceLogger, config: BaseProxyConfig | None, - database: Engine | None, + database: Engine | Database | None, webapp: PluginWebApp | None, webapp_url: str | None, loader: BasePluginLoader, diff --git a/optional-requirements.txt b/optional-requirements.txt index 0e45b97..f5b378a 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -9,3 +9,6 @@ unpaddedbase64>=1,<3 #/testing pytest pytest-asyncio + +#/legacydb +SQLAlchemy>1,<1.4 diff --git a/requirements.txt b/requirements.txt index 30cf06e..7de02dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ mautrix>=0.20.6,<0.21 aiohttp>=3,<4 yarl>=1,<2 -SQLAlchemy>=1,<1.4 asyncpg>=0.20,<0.30 aiosqlite>=0.16,<0.21 commonmark>=0.9,<1