From 21ed971d2f74fd2ea7d128e4e0b3be3dcc908e06 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 25 Mar 2022 19:45:48 +0200 Subject: [PATCH] Switch to asyncpg/aiosqlite Fixes #142 Fixes #98 Probably fixes #62 --- Dockerfile | 2 - Dockerfile.ci | 2 - alembic.ini | 83 ---- alembic/README | 1 - alembic/env.py | 92 ---- alembic/script.py.mako | 24 -- .../4b93300852aa_add_device_id_to_clients.py | 32 -- .../90aa88820eab_add_matrix_state_store.py | 47 -- .../versions/d295f8dcfa64_initial_revision.py | 50 --- ...ccd1f95544d_add_online_field_to_clients.py | 30 -- docker/run.sh | 2 +- maubot/__main__.py | 91 +++- maubot/__meta__.py | 2 +- maubot/cli/commands/logs.py | 2 +- maubot/client.py | 407 +++++++++--------- maubot/db.py | 108 ----- maubot/db/__init__.py | 13 + maubot/db/client.py | 114 +++++ maubot/db/instance.py | 75 ++++ maubot/db/upgrade/__init__.py | 5 + maubot/db/upgrade/v01_initial_revision.py | 136 ++++++ maubot/example-config.yaml | 4 +- maubot/instance.py | 194 +++++---- maubot/lib/color_log.py | 15 +- maubot/lib/{store_proxy.py => state_store.py} | 17 +- maubot/loader/abc.py | 55 +-- maubot/loader/meta.py | 53 +++ maubot/loader/zip.py | 3 +- maubot/management/api/__init__.py | 3 +- maubot/management/api/auth.py | 2 +- maubot/management/api/base.py | 10 - maubot/management/api/client.py | 68 ++- maubot/management/api/client_auth.py | 18 +- maubot/management/api/client_proxy.py | 2 +- maubot/management/api/instance.py | 50 +-- maubot/management/api/instance_database.py | 14 +- maubot/management/api/log.py | 6 +- maubot/management/api/plugin.py | 12 +- maubot/management/api/plugin_upload.py | 2 +- maubot/plugin_base.py | 5 +- optional-requirements.txt | 7 +- requirements.txt | 3 +- setup.py | 5 +- 43 files changed, 911 insertions(+), 955 deletions(-) delete mode 100644 alembic.ini delete mode 100644 alembic/README delete mode 100644 alembic/env.py delete mode 100644 alembic/script.py.mako delete mode 100644 alembic/versions/4b93300852aa_add_device_id_to_clients.py delete mode 100644 alembic/versions/90aa88820eab_add_matrix_state_store.py delete mode 100644 alembic/versions/d295f8dcfa64_initial_revision.py delete mode 100644 alembic/versions/fccd1f95544d_add_online_field_to_clients.py delete mode 100644 maubot/db.py create mode 100644 maubot/db/__init__.py create mode 100644 maubot/db/client.py create mode 100644 maubot/db/instance.py create mode 100644 maubot/db/upgrade/__init__.py create mode 100644 maubot/db/upgrade/v01_initial_revision.py rename maubot/lib/{store_proxy.py => state_store.py} (64%) create mode 100644 maubot/loader/meta.py diff --git a/Dockerfile b/Dockerfile index bd27ebe..4fd6f2b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,6 @@ RUN apk add --no-cache \ py3-attrs \ py3-bcrypt \ py3-cffi \ - py3-psycopg2 \ py3-ruamel.yaml \ py3-jinja2 \ py3-click \ @@ -49,7 +48,6 @@ COPY requirements.txt /opt/maubot/requirements.txt COPY optional-requirements.txt /opt/maubot/optional-requirements.txt WORKDIR /opt/maubot RUN apk add --virtual .build-deps python3-dev build-base git \ - && sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \ && pip3 install -r requirements.txt -r optional-requirements.txt \ dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \ && apk del .build-deps diff --git a/Dockerfile.ci b/Dockerfile.ci index 7719d33..47655a0 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -10,7 +10,6 @@ RUN apk add --no-cache \ py3-attrs \ py3-bcrypt \ py3-cffi \ - py3-psycopg2 \ py3-ruamel.yaml \ py3-jinja2 \ py3-click \ @@ -43,7 +42,6 @@ COPY requirements.txt /opt/maubot/requirements.txt COPY optional-requirements.txt /opt/maubot/optional-requirements.txt WORKDIR /opt/maubot RUN apk add --virtual .build-deps python3-dev build-base git \ - && sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \ && pip3 install -r requirements.txt -r optional-requirements.txt \ dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \ && apk del .build-deps diff --git a/alembic.ini b/alembic.ini deleted file mode 100644 index 0d78e89..0000000 --- a/alembic.ini +++ /dev/null @@ -1,83 +0,0 @@ -# A generic, single database configuration. - -[alembic] -# path to migration scripts -script_location = alembic - -# template used to generate migration files -# file_template = %%(rev)s_%%(slug)s - -# timezone to use when rendering the date -# within the migration file as well as the filename. -# string value is passed to dateutil.tz.gettz() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the -# "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version location specification; this defaults -# to alembic/versions. When using multiple version -# directories, initial revisions must be specified with --version-path -# version_locations = %(here)s/bar %(here)s/bat alembic/versions - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = utf-8 - - -[post_write_hooks] -# post_write_hooks defines scripts or Python functions that are run -# on newly generated revision scripts. See the documentation for further -# detail and examples - -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks=black -# black.type=console_scripts -# black.entrypoint=black -# black.options=-l 79 - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARN -handlers = console -qualname = - -[logger_sqlalchemy] -level = WARN -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README deleted file mode 100644 index 98e4f9c..0000000 --- a/alembic/README +++ /dev/null @@ -1 +0,0 @@ -Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py deleted file mode 100644 index 9946810..0000000 --- a/alembic/env.py +++ /dev/null @@ -1,92 +0,0 @@ -from logging.config import fileConfig - -from sqlalchemy import engine_from_config, pool - -from alembic import context - -import sys -from os.path import abspath, dirname - -sys.path.insert(0, dirname(dirname(abspath(__file__)))) - -from mautrix.util.db import Base -from maubot.config import Config -from maubot import db - -# this is the Alembic Config object, which provides -# access to the values within the .ini file in use. -config = context.config - -maubot_config_path = context.get_x_argument(as_dictionary=True).get("config", "config.yaml") -maubot_config = Config(maubot_config_path, None) -maubot_config.load() -config.set_main_option("sqlalchemy.url", maubot_config["database"].replace("%", "%%")) - -# Interpret the config file for Python logging. -# This line sets up loggers basically. -fileConfig(config.config_file_name) - -# add your model's MetaData object here -# for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -target_metadata = Base.metadata - -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") -# ... etc. - - -def run_migrations_offline(): - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - - Calls to context.execute() here emit the given string to the - script output. - - """ - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, - render_as_batch=True, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def run_migrations_online(): - """Run migrations in 'online' mode. - - In this scenario we need to create an Engine - and associate a connection with the context. - - """ - connectable = engine_from_config( - config.get_section(config.config_ini_section), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata, - render_as_batch=True, - ) - - with context.begin_transaction(): - context.run_migrations() - - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako deleted file mode 100644 index 2c01563..0000000 --- a/alembic/script.py.mako +++ /dev/null @@ -1,24 +0,0 @@ -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. -revision = ${repr(up_revision)} -down_revision = ${repr(down_revision)} -branch_labels = ${repr(branch_labels)} -depends_on = ${repr(depends_on)} - - -def upgrade(): - ${upgrades if upgrades else "pass"} - - -def downgrade(): - ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/4b93300852aa_add_device_id_to_clients.py b/alembic/versions/4b93300852aa_add_device_id_to_clients.py deleted file mode 100644 index efc71cd..0000000 --- a/alembic/versions/4b93300852aa_add_device_id_to_clients.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Add device_id to clients - -Revision ID: 4b93300852aa -Revises: fccd1f95544d -Create Date: 2020-07-11 15:49:38.831459 - -""" -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = '4b93300852aa' -down_revision = 'fccd1f95544d' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('client', schema=None) as batch_op: - batch_op.add_column(sa.Column('device_id', sa.String(length=255), nullable=True)) - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('client', schema=None) as batch_op: - batch_op.drop_column('device_id') - - # ### end Alembic commands ### diff --git a/alembic/versions/90aa88820eab_add_matrix_state_store.py b/alembic/versions/90aa88820eab_add_matrix_state_store.py deleted file mode 100644 index 37a68eb..0000000 --- a/alembic/versions/90aa88820eab_add_matrix_state_store.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Add Matrix state store - -Revision ID: 90aa88820eab -Revises: 4b93300852aa -Create Date: 2020-07-12 01:50:06.215623 - -""" -from alembic import op -import sqlalchemy as sa - -from mautrix.client.state_store.sqlalchemy import SerializableType -from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent - - -# revision identifiers, used by Alembic. -revision = '90aa88820eab' -down_revision = '4b93300852aa' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('mx_room_state', - sa.Column('room_id', sa.String(length=255), nullable=False), - sa.Column('is_encrypted', sa.Boolean(), nullable=True), - sa.Column('has_full_member_list', sa.Boolean(), nullable=True), - sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True), - sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True), - sa.PrimaryKeyConstraint('room_id') - ) - op.create_table('mx_user_profile', - sa.Column('room_id', sa.String(length=255), nullable=False), - sa.Column('user_id', sa.String(length=255), nullable=False), - sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False), - sa.Column('displayname', sa.String(), nullable=True), - sa.Column('avatar_url', sa.String(length=255), nullable=True), - sa.PrimaryKeyConstraint('room_id', 'user_id') - ) - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('mx_user_profile') - op.drop_table('mx_room_state') - # ### end Alembic commands ### diff --git a/alembic/versions/d295f8dcfa64_initial_revision.py b/alembic/versions/d295f8dcfa64_initial_revision.py deleted file mode 100644 index ffa502f..0000000 --- a/alembic/versions/d295f8dcfa64_initial_revision.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Initial revision - -Revision ID: d295f8dcfa64 -Revises: -Create Date: 2019-09-27 00:21:02.527915 - -""" -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = 'd295f8dcfa64' -down_revision = None -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('client', - sa.Column('id', sa.String(length=255), nullable=False), - sa.Column('homeserver', sa.String(length=255), nullable=False), - sa.Column('access_token', sa.Text(), nullable=False), - sa.Column('enabled', sa.Boolean(), nullable=False), - sa.Column('next_batch', sa.String(length=255), nullable=False), - sa.Column('filter_id', sa.String(length=255), nullable=False), - sa.Column('sync', sa.Boolean(), nullable=False), - sa.Column('autojoin', sa.Boolean(), nullable=False), - sa.Column('displayname', sa.String(length=255), nullable=False), - sa.Column('avatar_url', sa.String(length=255), nullable=False), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('plugin', - sa.Column('id', sa.String(length=255), nullable=False), - sa.Column('type', sa.String(length=255), nullable=False), - sa.Column('enabled', sa.Boolean(), nullable=False), - sa.Column('primary_user', sa.String(length=255), nullable=False), - sa.Column('config', sa.Text(), nullable=False), - sa.ForeignKeyConstraint(['primary_user'], ['client.id'], onupdate='CASCADE', ondelete='RESTRICT'), - sa.PrimaryKeyConstraint('id') - ) - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('plugin') - op.drop_table('client') - # ### end Alembic commands ### diff --git a/alembic/versions/fccd1f95544d_add_online_field_to_clients.py b/alembic/versions/fccd1f95544d_add_online_field_to_clients.py deleted file mode 100644 index 1f7eabe..0000000 --- a/alembic/versions/fccd1f95544d_add_online_field_to_clients.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Add online field to clients - -Revision ID: fccd1f95544d -Revises: d295f8dcfa64 -Create Date: 2020-03-06 15:07:50.136644 - -""" -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = 'fccd1f95544d' -down_revision = 'd295f8dcfa64' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table("client") as batch_op: - batch_op.add_column(sa.Column('online', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true())) - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table("client") as batch_op: - batch_op.drop_column('online') - # ### end Alembic commands ### diff --git a/docker/run.sh b/docker/run.sh index a9a40e1..9ca3a3f 100755 --- a/docker/run.sh +++ b/docker/run.sh @@ -1,7 +1,7 @@ #!/bin/sh function fixperms { - chown -R $UID:$GID /var/log /data /opt/maubot + chown -R $UID:$GID /var/log /data } function fixdefault { diff --git a/maubot/__main__.py b/maubot/__main__.py index a29f347..f7865ea 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -13,24 +13,37 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import asyncio +from __future__ import annotations +import asyncio +import sys + +from mautrix.util.async_db import Database, DatabaseException from mautrix.util.program import Program from .__meta__ import __version__ -from .client import Client, init as init_client_class +from .client import Client from .config import Config -from .db import init as init_db -from .instance import init as init_plugin_instance_class +from .db import init as init_db, upgrade_table +from .instance import PluginInstance from .lib.future_awaitable import FutureAwaitable +from .lib.state_store import PgStateStore from .loader.zip import init as init_zip_loader from .management.api import init as init_mgmt_api from .server import MaubotServer +try: + from mautrix.crypto.store import PgCryptoStore +except ImportError: + PgCryptoStore = None + class Maubot(Program): config: Config server: MaubotServer + db: Database + crypto_db: Database | None + state_store: PgStateStore config_class = Config module = "maubot" @@ -45,6 +58,19 @@ class Maubot(Program): init(self.loop) self.add_shutdown_actions(FutureAwaitable(stop_all)) + def prepare_arg_parser(self) -> None: + super().prepare_arg_parser() + self.parser.add_argument( + "--ignore-unsupported-database", + action="store_true", + help="Run even if the database schema is too new", + ) + self.parser.add_argument( + "--ignore-foreign-tables", + action="store_true", + help="Run even if the database contains tables from other programs (like Synapse)", + ) + def prepare(self) -> None: super().prepare() @@ -52,21 +78,59 @@ class Maubot(Program): self.prepare_log_websocket() init_zip_loader(self.config) - init_db(self.config) - clients = init_client_class(self.config, self.loop) - self.add_startup_actions(*(client.start() for client in clients)) + self.db = Database.create( + self.config["database"], + upgrade_table=upgrade_table, + db_args=self.config["database_opts"], + owner_name=self.name, + ignore_foreign_tables=self.args.ignore_foreign_tables, + ) + init_db(self.db) + if self.config["crypto_database"] == "default": + self.crypto_db = self.db + else: + self.crypto_db = Database.create( + self.config["crypto_database"], + upgrade_table=PgCryptoStore.upgrade_table, + ignore_foreign_tables=self.args.ignore_foreign_tables, + ) + Client.init_cls(self) + PluginInstance.init_cls(self) management_api = init_mgmt_api(self.config, self.loop) self.server = MaubotServer(management_api, self.config, self.loop) + self.state_store = PgStateStore(self.db) - plugins = init_plugin_instance_class(self.config, self.server, self.loop) - for plugin in plugins: - plugin.load() + async def start_db(self) -> None: + self.log.debug("Starting database...") + ignore_unsupported = self.args.ignore_unsupported_database + self.db.upgrade_table.allow_unsupported = ignore_unsupported + self.state_store.upgrade_table.allow_unsupported = ignore_unsupported + PgCryptoStore.upgrade_table.allow_unsupported = ignore_unsupported + try: + await self.db.start() + await self.state_store.upgrade_table.upgrade(self.db) + if self.crypto_db and self.crypto_db is not self.db: + await self.crypto_db.start() + else: + await PgCryptoStore.upgrade_table.upgrade(self.db) + except DatabaseException as e: + self.log.critical("Failed to initialize database", exc_info=e) + if e.explanation: + self.log.info(e.explanation) + sys.exit(25) + + async def system_exit(self) -> None: + if hasattr(self, "db"): + self.log.trace("Stopping database due to SystemExit") + await self.db.stop() async def start(self) -> None: - if Client.crypto_db: - self.log.debug("Starting client crypto database") - await Client.crypto_db.start() + await self.start_db() + await asyncio.gather(*[plugin.load() async for plugin in PluginInstance.all()]) + await asyncio.gather(*[client.start() async for client in Client.all()]) await super().start() + async for plugin in PluginInstance.all(): + await plugin.load() await self.server.start() async def stop(self) -> None: @@ -77,6 +141,7 @@ class Maubot(Program): await asyncio.wait_for(self.server.stop(), 5) except asyncio.TimeoutError: self.log.warning("Stopping server timed out") + await self.db.stop() Maubot().run() diff --git a/maubot/__meta__.py b/maubot/__meta__.py index 3ced358..690354d 100644 --- a/maubot/__meta__.py +++ b/maubot/__meta__.py @@ -1 +1 @@ -__version__ = "0.2.1" +__version__ = "0.3.0+dev" diff --git a/maubot/cli/commands/logs.py b/maubot/cli/commands/logs.py index 98879ee..9a9c644 100644 --- a/maubot/cli/commands/logs.py +++ b/maubot/cli/commands/logs.py @@ -38,7 +38,7 @@ def logs(server: str, tail: int) -> None: global history_count history_count = tail loop = asyncio.get_event_loop() - future = asyncio.ensure_future(view_logs(server, token), loop=loop) + future = asyncio.create_task(view_logs(server, token), loop=loop) try: loop.run_until_complete(future) except KeyboardInterrupt: diff --git a/maubot/client.py b/maubot/client.py index fa6c851..315f217 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -15,14 +15,14 @@ # along with this program. If not, see . from __future__ import annotations -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable +from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, cast +from collections import defaultdict import asyncio import logging from aiohttp import ClientSession from mautrix.client import InternalEventType -from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore from mautrix.errors import MatrixInvalidToken from mautrix.types import ( ContentURI, @@ -41,69 +41,110 @@ from mautrix.types import ( SyncToken, UserID, ) +from mautrix.util.async_getter_lock import async_getter_lock +from mautrix.util.logging import TraceLogger -from .db import DBClient -from .lib.store_proxy import SyncStoreProxy +from .db import Client as DBClient from .matrix import MaubotMatrixClient try: - from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore - from mautrix.util.async_db import Database as AsyncDatabase - - class SQLStateStore(BaseSQLStateStore, CryptoStateStore): - pass + from mautrix.crypto import OlmMachine, PgCryptoStore crypto_import_error = None except ImportError as e: - OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None - SQLStateStore = BaseSQLStateStore + OlmMachine = PgCryptoStore = None crypto_import_error = e if TYPE_CHECKING: - from .config import Config + from .__main__ import Maubot from .instance import PluginInstance -log = logging.getLogger("maubot.client") - -class Client: - log: logging.Logger = None - loop: asyncio.AbstractEventLoop = None +class Client(DBClient): + maubot: "Maubot" = None cache: dict[UserID, Client] = {} + _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) + log: TraceLogger = logging.getLogger("maubot.client") + http_client: ClientSession = None - global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore() - crypto_db: AsyncDatabase | None = None references: set[PluginInstance] - db_instance: DBClient client: MaubotMatrixClient crypto: OlmMachine | None crypto_store: PgCryptoStore | None started: bool + sync_ok: bool remote_displayname: str | None remote_avatar_url: ContentURI | None - def __init__(self, db_instance: DBClient) -> None: - self.db_instance = db_instance + def __init__( + self, + id: UserID, + homeserver: str, + access_token: str, + device_id: DeviceID, + enabled: bool = False, + next_batch: SyncToken = "", + filter_id: FilterID = "", + sync: bool = True, + autojoin: bool = True, + online: bool = True, + displayname: str = "disable", + avatar_url: str = "disable", + ) -> None: + super().__init__( + id=id, + homeserver=homeserver, + access_token=access_token, + device_id=device_id, + enabled=bool(enabled), + next_batch=next_batch, + filter_id=filter_id, + sync=bool(sync), + autojoin=bool(autojoin), + online=bool(online), + displayname=displayname, + avatar_url=avatar_url, + ) + self._postinited = False + + def __hash__(self) -> int: + return hash(self.id) + + @classmethod + def init_cls(cls, maubot: "Maubot") -> None: + cls.maubot = maubot + + def _make_client( + self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None + ) -> MaubotMatrixClient: + return MaubotMatrixClient( + mxid=self.id, + base_url=homeserver or self.homeserver, + token=token or self.access_token, + client_session=self.http_client, + log=self.log, + crypto_log=self.log.getChild("crypto"), + loop=self.maubot.loop, + device_id=device_id or self.device_id, + sync_store=self, + state_store=self.maubot.state_store, + ) + + def postinit(self) -> None: + if self._postinited: + raise RuntimeError("postinit() called twice") + self._postinited = True self.cache[self.id] = self - self.log = log.getChild(self.id) + self.log = self.log.getChild(self.id) + self.http_client = ClientSession(loop=self.maubot.loop) self.references = set() self.started = False self.sync_ok = True self.remote_displayname = None self.remote_avatar_url = None - self.client = MaubotMatrixClient( - mxid=self.id, - base_url=self.homeserver, - token=self.access_token, - client_session=self.http_client, - log=self.log, - loop=self.loop, - device_id=self.device_id, - sync_store=SyncStoreProxy(self.db_instance), - state_store=self.global_state_store, - ) + self.client = self._make_client() if self.enable_crypto: self._prepare_crypto() else: @@ -118,6 +159,12 @@ class Client: self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False)) self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True)) + def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]: + async def handler(data: dict[str, Any]) -> None: + self.sync_ok = ok + + return handler + @property def enable_crypto(self) -> bool: if not self.device_id: @@ -131,16 +178,21 @@ class Client: # Clear the stack trace after it's logged once to avoid spamming logs crypto_import_error = None return False - elif not self.crypto_db: + elif not self.maubot.crypto_db: self.log.warning("Client has device ID, but crypto database is not prepared") return False return True def _prepare_crypto(self) -> None: self.crypto_store = PgCryptoStore( - account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db + account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db + ) + self.crypto = OlmMachine( + self.client, + self.crypto_store, + self.maubot.state_store, + log=self.client.crypto_log, ) - self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store) self.client.crypto = self.crypto def _remove_crypto_event_handlers(self) -> None: @@ -156,12 +208,6 @@ class Client: for event_type, func in handlers: self.client.remove_event_handler(event_type, func) - def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]: - async def handler(data: dict[str, Any]) -> None: - self.sync_ok = ok - - return handler - async def start(self, try_n: int | None = 0) -> None: try: if try_n > 0: @@ -196,47 +242,50 @@ class Client: whoami = await self.client.whoami() except MatrixInvalidToken as e: self.log.error(f"Invalid token: {e}. Disabling client") - self.db_instance.enabled = False + self.enabled = False + await self.update() return except Exception as e: if try_n >= 8: self.log.exception("Failed to get /account/whoami, disabling client") - self.db_instance.enabled = False + self.enabled = False + await self.update() else: self.log.warning( - f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}" + f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}" ) - _ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop) + _ = asyncio.create_task(self.start(try_n + 1)) return if whoami.user_id != self.id: self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}") - self.db_instance.enabled = False + self.enabled = False + await self.update() return elif whoami.device_id and self.device_id and whoami.device_id != self.device_id: self.log.error( f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}" ) - self.db_instance.enabled = False + self.enabled = False + await self.update() return if not self.filter_id: - self.db_instance.edit( - filter_id=await self.client.create_filter( - Filter( - room=RoomFilter( - timeline=RoomEventFilter( - limit=50, - lazy_load_members=True, - ), - state=StateFilter( - lazy_load_members=True, - ), + self.filter_id = await self.client.create_filter( + Filter( + room=RoomFilter( + timeline=RoomEventFilter( + limit=50, + lazy_load_members=True, ), - presence=EventFilter( - not_types=[EventType.PRESENCE], + state=StateFilter( + lazy_load_members=True, ), - ) + ), + presence=EventFilter( + not_types=[EventType.PRESENCE], + ), ) ) + await self.update() if self.displayname != "disable": await self.client.set_displayname(self.displayname) if self.avatar_url != "disable": @@ -270,18 +319,13 @@ class Client: if self.crypto: await self.crypto_store.close() - def clear_cache(self) -> None: + async def clear_cache(self) -> None: self.stop_sync() - self.db_instance.edit(filter_id="", next_batch="") + self.filter_id = FilterID("") + self.next_batch = SyncToken("") + await self.update() self.start_sync() - def delete(self) -> None: - try: - del self.cache[self.id] - except KeyError: - pass - self.db_instance.delete() - def to_dict(self) -> dict: return { "id": self.id, @@ -304,20 +348,6 @@ class Client: "instances": [instance.to_dict() for instance in self.references], } - @classmethod - def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None: - try: - return cls.cache[user_id] - except KeyError: - db_instance = db_instance or DBClient.get(user_id) - if not db_instance: - return None - return Client(db_instance) - - @classmethod - def all(cls) -> Iterable[Client]: - return (cls.get(user.id, user) for user in DBClient.all()) - async def _handle_tombstone(self, evt: StateEvent) -> None: if not evt.content.replacement_room: self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring") @@ -329,7 +359,7 @@ class Client: if evt.state_key == self.id and evt.content.membership == Membership.INVITE: await self.client.join_room(evt.room_id) - async def update_started(self, started: bool) -> None: + async def update_started(self, started: bool | None) -> None: if started is None or started == self.started: return if started: @@ -337,23 +367,65 @@ class Client: else: await self.stop() - async def update_displayname(self, displayname: str) -> None: + async def update_enabled(self, enabled: bool | None, save: bool = True) -> None: + if enabled is None or enabled == self.enabled: + return + self.enabled = enabled + if save: + await self.update() + + async def update_displayname(self, displayname: str | None, save: bool = True) -> None: if displayname is None or displayname == self.displayname: return - self.db_instance.displayname = displayname + self.displayname = displayname if self.displayname != "disable": await self.client.set_displayname(self.displayname) else: await self._update_remote_profile() + if save: + await self.update() - async def update_avatar_url(self, avatar_url: ContentURI) -> None: + async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None: if avatar_url is None or avatar_url == self.avatar_url: return - self.db_instance.avatar_url = avatar_url + self.avatar_url = avatar_url if self.avatar_url != "disable": await self.client.set_avatar_url(self.avatar_url) else: await self._update_remote_profile() + if save: + await self.update() + + async def update_sync(self, sync: bool | None, save: bool = True) -> None: + if sync is None or self.sync == sync: + return + self.sync = sync + if self.started: + if sync: + self.start_sync() + else: + self.stop_sync() + if save: + await self.update() + + async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None: + if autojoin is None or autojoin == self.autojoin: + return + if autojoin: + self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite) + else: + self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite) + self.autojoin = autojoin + if save: + await self.update() + + async def update_online(self, online: bool | None, save: bool = True) -> None: + if online is None or online == self.online: + return + self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE + self.online = online + if save: + await self.update() async def update_access_details( self, @@ -373,22 +445,13 @@ class Client: and device_id == self.device_id ): return - new_client = MaubotMatrixClient( - mxid=self.id, - base_url=homeserver or self.homeserver, - token=access_token or self.access_token, - loop=self.loop, - device_id=device_id, - client_session=self.http_client, - log=self.log, - state_store=self.global_state_store, - ) + new_client = self._make_client(homeserver, access_token, device_id) whoami = await new_client.whoami() if whoami.user_id != self.id: raise ValueError(f"MXID mismatch: {whoami.user_id}") elif whoami.device_id and device_id and whoami.device_id != device_id: raise ValueError(f"Device ID mismatch: {whoami.device_id}") - new_client.sync_store = SyncStoreProxy(self.db_instance) + new_client.sync_store = self self.stop_sync() # TODO this event handler transfer is pretty hacky @@ -398,9 +461,9 @@ class Client: new_client.global_event_handlers = self.client.global_event_handlers self.client = new_client - self.db_instance.homeserver = homeserver - self.db_instance.access_token = access_token - self.db_instance.device_id = device_id + self.homeserver = homeserver + self.access_token = access_token + self.device_id = device_id if self.enable_crypto: self._prepare_crypto() await self._start_crypto() @@ -413,97 +476,53 @@ class Client: profile = await self.client.get_profile(self.id) self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url - # region Properties + async def delete(self) -> None: + try: + del self.cache[self.id] + except KeyError: + pass + await super().delete() - @property - def id(self) -> UserID: - return self.db_instance.id + @classmethod + @async_getter_lock + async def get( + cls, + user_id: UserID, + *, + homeserver: str | None = None, + access_token: str | None = None, + device_id: DeviceID | None = None, + ) -> Client | None: + try: + return cls.cache[user_id] + except KeyError: + pass - @property - def homeserver(self) -> str: - return self.db_instance.homeserver + user = cast(cls, await super().get(user_id)) + if user is not None: + user.postinit() + return user - @property - def access_token(self) -> str: - return self.db_instance.access_token + if homeserver and access_token: + user = cls( + user_id, + homeserver=homeserver, + access_token=access_token, + device_id=device_id or "", + ) + await user.insert() + user.postinit() + return user - @property - def device_id(self) -> DeviceID: - return self.db_instance.device_id + return None - @property - def enabled(self) -> bool: - return self.db_instance.enabled - - @enabled.setter - def enabled(self, value: bool) -> None: - self.db_instance.enabled = value - - @property - def next_batch(self) -> SyncToken: - return self.db_instance.next_batch - - @property - def filter_id(self) -> FilterID: - return self.db_instance.filter_id - - @property - def sync(self) -> bool: - return self.db_instance.sync - - @sync.setter - def sync(self, value: bool) -> None: - if value == self.db_instance.sync: - return - self.db_instance.sync = value - if self.started: - if value: - self.start_sync() - else: - self.stop_sync() - - @property - def autojoin(self) -> bool: - return self.db_instance.autojoin - - @autojoin.setter - def autojoin(self, value: bool) -> None: - if value == self.db_instance.autojoin: - return - if value: - self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite) - else: - self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite) - self.db_instance.autojoin = value - - @property - def online(self) -> bool: - return self.db_instance.online - - @online.setter - def online(self, value: bool) -> None: - self.client.presence = PresenceState.ONLINE if value else PresenceState.OFFLINE - self.db_instance.online = value - - @property - def displayname(self) -> str: - return self.db_instance.displayname - - @property - def avatar_url(self) -> ContentURI: - return self.db_instance.avatar_url - - # endregion - - -def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]: - Client.http_client = ClientSession(loop=loop) - Client.loop = loop - - if OlmMachine: - db_url = config["crypto_database"] - if db_url == "default": - db_url = config["database"] - Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table) - - return Client.all() + @classmethod + async def all(cls) -> AsyncGenerator[Client, None]: + users = await super().all() + user: cls + for user in users: + try: + yield cls.cache[user.id] + except KeyError: + user.postinit() + yield user diff --git a/maubot/db.py b/maubot/db.py deleted file mode 100644 index 9f388d3..0000000 --- a/maubot/db.py +++ /dev/null @@ -1,108 +0,0 @@ -# maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from typing import Iterable, Optional -import logging -import sys - -from sqlalchemy import Boolean, Column, ForeignKey, String, Text -from sqlalchemy.engine.base import Engine -import sqlalchemy as sql - -from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile -from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID -from mautrix.util.db import Base - -from .config import Config - - -class DBPlugin(Base): - __tablename__ = "plugin" - - id: str = Column(String(255), primary_key=True) - type: str = Column(String(255), nullable=False) - enabled: bool = Column(Boolean, nullable=False, default=False) - primary_user: UserID = Column( - String(255), - ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"), - nullable=False, - ) - config: str = Column(Text, nullable=False, default="") - - @classmethod - def all(cls) -> Iterable["DBPlugin"]: - return cls._select_all() - - @classmethod - def get(cls, id: str) -> Optional["DBPlugin"]: - return cls._select_one_or_none(cls.c.id == id) - - -class DBClient(Base): - __tablename__ = "client" - - id: UserID = Column(String(255), primary_key=True) - homeserver: str = Column(String(255), nullable=False) - access_token: str = Column(Text, nullable=False) - device_id: DeviceID = Column(String(255), nullable=True) - enabled: bool = Column(Boolean, nullable=False, default=False) - - next_batch: SyncToken = Column(String(255), nullable=False, default="") - filter_id: FilterID = Column(String(255), nullable=False, default="") - - sync: bool = Column(Boolean, nullable=False, default=True) - autojoin: bool = Column(Boolean, nullable=False, default=True) - online: bool = Column(Boolean, nullable=False, default=True) - - displayname: str = Column(String(255), nullable=False, default="") - avatar_url: ContentURI = Column(String(255), nullable=False, default="") - - @classmethod - def all(cls) -> Iterable["DBClient"]: - return cls._select_all() - - @classmethod - def get(cls, id: str) -> Optional["DBClient"]: - return cls._select_one_or_none(cls.c.id == id) - - -def init(config: Config) -> Engine: - db = sql.create_engine(config["database"]) - Base.metadata.bind = db - - for table in (DBPlugin, DBClient, RoomState, UserProfile): - table.bind(db) - - if not db.has_table("alembic_version"): - log = logging.getLogger("maubot.db") - - if db.has_table("client") and db.has_table("plugin"): - log.warning( - "alembic_version table not found, but client and plugin tables found. " - "Assuming pre-Alembic database and inserting version." - ) - db.execute( - "CREATE TABLE IF NOT EXISTS alembic_version (" - " version_num VARCHAR(32) PRIMARY KEY" - ");" - ) - db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');") - else: - log.critical( - "alembic_version table not found. " "Did you forget to `alembic upgrade head`?" - ) - sys.exit(10) - - return db diff --git a/maubot/db/__init__.py b/maubot/db/__init__.py new file mode 100644 index 0000000..d6aeb09 --- /dev/null +++ b/maubot/db/__init__.py @@ -0,0 +1,13 @@ +from mautrix.util.async_db import Database + +from .client import Client +from .instance import Instance +from .upgrade import upgrade_table + + +def init(db: Database) -> None: + for table in (Client, Instance): + table.db = db + + +__all__ = ["upgrade_table", "init", "Client", "Instance"] diff --git a/maubot/db/client.py b/maubot/db/client.py new file mode 100644 index 0000000..52f3a20 --- /dev/null +++ b/maubot/db/client.py @@ -0,0 +1,114 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +from asyncpg import Record +from attr import dataclass + +from mautrix.client import SyncStore +from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID +from mautrix.util.async_db import Database + +fake_db = Database.create("") if TYPE_CHECKING else None + + +@dataclass +class Client(SyncStore): + db: ClassVar[Database] = fake_db + + id: UserID + homeserver: str + access_token: str + device_id: DeviceID + enabled: bool + + next_batch: SyncToken + filter_id: FilterID + + sync: bool + autojoin: bool + online: bool + + displayname: str + avatar_url: ContentURI + + @classmethod + def _from_row(cls, row: Record | None) -> Client | None: + if row is None: + return None + return cls(**row) + + _columns = ( + "id, homeserver, access_token, device_id, enabled, next_batch, filter_id, " + "sync, autojoin, online, displayname, avatar_url" + ) + + @property + def _values(self): + return ( + self.id, + self.homeserver, + self.access_token, + self.device_id, + self.enabled, + self.next_batch, + self.filter_id, + self.sync, + self.autojoin, + self.online, + self.displayname, + self.avatar_url, + ) + + @classmethod + async def all(cls) -> list[Client]: + rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client") + return [cls._from_row(row) for row in rows] + + @classmethod + async def get(cls, id: str) -> Client | None: + q = f"SELECT {cls._columns} FROM client WHERE id=$1" + return cls._from_row(await cls.db.fetchrow(q, id)) + + async def insert(self) -> None: + q = """ + INSERT INTO client ( + id, homeserver, access_token, device_id, enabled, next_batch, filter_id, + sync, autojoin, online, displayname, avatar_url + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + """ + await self.db.execute(q, *self._values) + + async def put_next_batch(self, next_batch: SyncToken) -> None: + await self.db.execute("UPDATE client SET next_batch=$1 WHERE id=$2", next_batch, self.id) + self.next_batch = next_batch + + async def get_next_batch(self) -> SyncToken: + return self.next_batch + + async def update(self) -> None: + q = """ + UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, enabled=$5, + next_batch=$6, filter_id=$7, sync=$8, autojoin=$9, online=$10, + displayname=$11, avatar_url=$12 + WHERE id=$1 + """ + await self.db.execute(q, *self._values) + + async def delete(self) -> None: + await self.db.execute("DELETE FROM client WHERE id=$1", self.id) diff --git a/maubot/db/instance.py b/maubot/db/instance.py new file mode 100644 index 0000000..dff7064 --- /dev/null +++ b/maubot/db/instance.py @@ -0,0 +1,75 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +from asyncpg import Record +from attr import dataclass + +from mautrix.types import UserID +from mautrix.util.async_db import Database + +fake_db = Database.create("") if TYPE_CHECKING else None + + +@dataclass +class Instance: + db: ClassVar[Database] = fake_db + + id: str + type: str + enabled: bool + primary_user: UserID + config_str: str + + @classmethod + def _from_row(cls, row: Record | None) -> Instance | None: + if row is None: + return None + return cls(**row) + + @classmethod + async def all(cls) -> list[Instance]: + rows = await cls.db.fetch("SELECT id, type, enabled, primary_user, config FROM instance") + return [cls._from_row(row) for row in rows] + + @classmethod + async def get(cls, id: str) -> Instance | None: + q = "SELECT id, type, enabled, primary_user, config FROM instance WHERE id=$1" + return cls._from_row(await cls.db.fetchrow(q, id)) + + async def update_id(self, new_id: str) -> None: + await self.db.execute("UPDATE instance SET id=$1 WHERE id=$2", new_id, self.id) + self.id = new_id + + @property + def _values(self): + return self.id, self.type, self.enabled, self.primary_user, self.config_str + + async def insert(self) -> None: + q = ( + "INSERT INTO instance (id, type, enabled, primary_user, config) " + "VALUES ($1, $2, $3, $4, $5)" + ) + await self.db.execute(q, *self._values) + + async def update(self) -> None: + q = "UPDATE instance SET type=$2, enabled=$3, primary_user=$4, config=$5 WHERE id=$1" + await self.db.execute(q, *self._values) + + async def delete(self) -> None: + await self.db.execute("DELETE FROM instance WHERE id=$1", self.id) diff --git a/maubot/db/upgrade/__init__.py b/maubot/db/upgrade/__init__.py new file mode 100644 index 0000000..146e713 --- /dev/null +++ b/maubot/db/upgrade/__init__.py @@ -0,0 +1,5 @@ +from mautrix.util.async_db import UpgradeTable + +upgrade_table = UpgradeTable() + +from . import v01_initial_revision diff --git a/maubot/db/upgrade/v01_initial_revision.py b/maubot/db/upgrade/v01_initial_revision.py new file mode 100644 index 0000000..2da8aff --- /dev/null +++ b/maubot/db/upgrade/v01_initial_revision.py @@ -0,0 +1,136 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from __future__ import annotations + +from mautrix.util.async_db import Connection, Scheme + +from . import upgrade_table + +legacy_version_query = "SELECT version_num FROM alembic_version" +last_legacy_version = "90aa88820eab" + + +@upgrade_table.register(description="Initial asyncpg revision") +async def upgrade_v1(conn: Connection, scheme: Scheme) -> None: + if await conn.table_exists("alembic_version"): + await migrate_legacy_to_v1(conn, scheme) + else: + return await create_v1_tables(conn) + + +async def create_v1_tables(conn: Connection) -> None: + await conn.execute( + """CREATE TABLE client ( + id TEXT PRIMARY KEY, + homeserver TEXT NOT NULL, + access_token TEXT NOT NULL, + device_id TEXT NOT NULL, + enabled BOOLEAN NOT NULL, + + next_batch TEXT NOT NULL, + filter_id TEXT NOT NULL, + + sync BOOLEAN NOT NULL, + autojoin BOOLEAN NOT NULL, + online BOOLEAN NOT NULL, + + displayname TEXT NOT NULL, + avatar_url TEXT NOT NULL + )""" + ) + await conn.execute( + """CREATE TABLE instance ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + enabled BOOLEAN NOT NULL, + primary_user TEXT NOT NULL, + config TEXT NOT NULL, + FOREIGN KEY (primary_user) REFERENCES client(id) ON DELETE RESTRICT ON UPDATE CASCADE + )""" + ) + + +async def migrate_legacy_to_v1(conn: Connection, scheme: Scheme) -> None: + legacy_version = await conn.fetchval(legacy_version_query) + if legacy_version != last_legacy_version: + raise RuntimeError( + "Legacy database is not on last version. " + "Please upgrade the old database with alembic or drop it completely first." + ) + await conn.execute("ALTER TABLE plugin RENAME TO instance") + await update_state_store(conn, scheme) + if scheme != Scheme.SQLITE: + await varchar_to_text(conn) + await conn.execute("DROP TABLE alembic_version") + + +async def update_state_store(conn: Connection, scheme: Scheme) -> None: + # The Matrix state store already has more or less the correct schema, so set the version + await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)") + await conn.execute("INSERT INTO mx_version (version) VALUES (2)") + if scheme != Scheme.SQLITE: + # Remove old uppercase membership type and recreate it as lowercase + await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT") + await conn.execute("DROP TYPE IF EXISTS membership") + await conn.execute( + "CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')" + ) + await conn.execute( + "ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership " + "USING LOWER(membership)::membership" + ) + else: + # Recreate table to remove CHECK constraint and lowercase everything + await conn.execute( + """CREATE TABLE new_mx_user_profile ( + room_id TEXT, + user_id TEXT, + membership TEXT NOT NULL + CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock')), + displayname TEXT, + avatar_url TEXT, + PRIMARY KEY (room_id, user_id) + )""" + ) + await conn.execute( + """ + INSERT INTO new_mx_user_profile (room_id, user_id, membership, displayname, avatar_url) + SELECT room_id, user_id, LOWER(membership), displayname, avatar_url + FROM mx_user_profile + """ + ) + await conn.execute("DROP TABLE mx_user_profile") + await conn.execute("ALTER TABLE new_mx_user_profile RENAME TO mx_user_profile") + + +async def varchar_to_text(conn: Connection) -> None: + columns_to_adjust = { + "client": ( + "id", + "homeserver", + "device_id", + "next_batch", + "filter_id", + "displayname", + "avatar_url", + ), + "instance": ("id", "type", "primary_user"), + "mx_room_state": ("room_id",), + "mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"), + } + for table, columns in columns_to_adjust.items(): + for column in columns: + await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT') diff --git a/maubot/example-config.yaml b/maubot/example-config.yaml index eb9bfe2..0f82e12 100644 --- a/maubot/example-config.yaml +++ b/maubot/example-config.yaml @@ -6,9 +6,7 @@ database: sqlite:///maubot.db # Separate database URL for the crypto database. "default" means use the same database as above. -# Due to concurrency issues, you should use a separate file when using SQLite rather than the same as above. -# When using postgres, using the same database for both is safe. -crypto_database: sqlite:///crypto.db +crypto_database: default plugin_directories: # The directory where uploaded new plugins should be stored. diff --git a/maubot/instance.py b/maubot/instance.py index 7d7900b..d615a72 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -15,8 +15,10 @@ # along with this program. If not, see . from __future__ import annotations -from typing import TYPE_CHECKING, Iterable -from asyncio import AbstractEventLoop +from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast +from collections import defaultdict +import asyncio +import inspect import io import logging import os.path @@ -26,16 +28,17 @@ from ruamel.yaml.comments import CommentedMap import sqlalchemy as sql from mautrix.types import UserID +from mautrix.util.async_getter_lock import async_getter_lock from mautrix.util.config import BaseProxyConfig, RecursiveDict from .client import Client -from .config import Config -from .db import DBPlugin +from .db import Instance as DBInstance from .loader import PluginLoader, ZippedPluginLoader from .plugin_base import Plugin if TYPE_CHECKING: - from .server import MaubotServer, PluginWebApp + from .__main__ import Maubot + from .server import PluginWebApp log = logging.getLogger("maubot.instance") @@ -44,29 +47,42 @@ yaml.indent(4) yaml.width = 200 -class PluginInstance: - webserver: MaubotServer = None - mb_config: Config = None - loop: AbstractEventLoop = None +class PluginInstance(DBInstance): + maubot: "Maubot" = None cache: dict[str, PluginInstance] = {} plugin_directories: list[str] = [] + _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) log: logging.Logger - loader: PluginLoader - client: Client - plugin: Plugin - config: BaseProxyConfig + loader: PluginLoader | None + client: Client | None + plugin: Plugin | None + config: BaseProxyConfig | None base_cfg: RecursiveDict[CommentedMap] | None base_cfg_str: str | None - inst_db: sql.engine.Engine - inst_db_tables: dict[str, sql.Table] + inst_db: sql.engine.Engine | None + inst_db_tables: dict[str, sql.Table] | None inst_webapp: PluginWebApp | None inst_webapp_url: str | None started: bool - def __init__(self, db_instance: DBPlugin): - self.db_instance = db_instance + def __init__( + self, id: str, type: str, enabled: bool, primary_user: UserID, config: str = "" + ) -> None: + super().__init__( + id=id, type=type, enabled=bool(enabled), primary_user=primary_user, config_str=config + ) + + def __hash__(self) -> int: + return hash(self.id) + + @classmethod + def init_cls(cls, maubot: "Maubot") -> None: + cls.maubot = maubot + + def postinit(self) -> None: self.log = log.getChild(self.id) + self.cache[self.id] = self self.config = None self.started = False self.loader = None @@ -78,7 +94,6 @@ class PluginInstance: self.inst_webapp_url = None self.base_cfg = None self.base_cfg_str = None - self.cache[self.id] = self def to_dict(self) -> dict: return { @@ -87,10 +102,10 @@ class PluginInstance: "enabled": self.enabled, "started": self.started, "primary_user": self.primary_user, - "config": self.db_instance.config, + "config": self.config_str, "base_config": self.base_cfg_str, "database": ( - self.inst_db is not None and self.mb_config["api_features.instance_database"] + self.inst_db is not None and self.maubot.config["api_features.instance_database"] ), } @@ -101,19 +116,19 @@ class PluginInstance: self.inst_db_tables = metadata.tables return self.inst_db_tables - def load(self) -> bool: + async def load(self) -> bool: if not self.loader: try: self.loader = PluginLoader.find(self.type) except KeyError: self.log.error(f"Failed to find loader for type {self.type}") - self.db_instance.enabled = False + await self.update_enabled(False) return False if not self.client: - self.client = Client.get(self.primary_user) + self.client = await Client.get(self.primary_user) if not self.client: self.log.error(f"Failed to get client for user {self.primary_user}") - self.db_instance.enabled = False + await self.update_enabled(False) return False if self.loader.meta.database: self.enable_database() @@ -125,18 +140,18 @@ class PluginInstance: return True def enable_webapp(self) -> None: - self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id) + self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id) def disable_webapp(self) -> None: - self.webserver.remove_instance_webapp(self.id) + self.maubot.server.remove_instance_webapp(self.id) self.inst_webapp = None self.inst_webapp_url = None def enable_database(self) -> None: - db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id) + db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id) self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db") - def delete(self) -> None: + async def delete(self) -> None: if self.loader is not None: self.loader.references.remove(self) if self.client is not None: @@ -145,23 +160,23 @@ class PluginInstance: del self.cache[self.id] except KeyError: pass - self.db_instance.delete() + await super().delete() if self.inst_db: self.inst_db.dispose() ZippedPluginLoader.trash( - os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"), + os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"), reason="deleted", ) if self.inst_webapp: self.disable_webapp() def load_config(self) -> CommentedMap: - return yaml.load(self.db_instance.config) + return yaml.load(self.config_str) def save_config(self, data: RecursiveDict[CommentedMap]) -> None: buf = io.StringIO() yaml.dump(data, buf) - self.db_instance.config = buf.getvalue() + self.config_str = buf.getvalue() async def start(self) -> None: if self.started: @@ -172,7 +187,7 @@ class PluginInstance: return if not self.client or not self.loader: self.log.warning("Missing plugin instance dependencies, attempting to load...") - if not self.load(): + if not await self.load(): return cls = await self.loader.load() if self.loader.meta.webapp and self.inst_webapp is None: @@ -205,7 +220,7 @@ class PluginInstance: self.config = config_class(self.load_config, base_cfg_func, self.save_config) self.plugin = cls( client=self.client.client, - loop=self.loop, + loop=self.maubot.loop, http=self.client.http_client, instance_id=self.id, log=self.log, @@ -219,7 +234,7 @@ class PluginInstance: await self.plugin.internal_start() except Exception: self.log.exception("Failed to start instance") - self.db_instance.enabled = False + await self.update_enabled(False) return self.started = True self.inst_db_tables = None @@ -241,60 +256,51 @@ class PluginInstance: self.plugin = None self.inst_db_tables = None - @classmethod - def get(cls, instance_id: str, db_instance: DBPlugin | None = None) -> PluginInstance | None: - try: - return cls.cache[instance_id] - except KeyError: - db_instance = db_instance or DBPlugin.get(instance_id) - if not db_instance: - return None - return PluginInstance(db_instance) + async def update_id(self, new_id: str | None) -> None: + if new_id is not None and new_id.lower() != self.id: + await super().update_id(new_id.lower()) - @classmethod - def all(cls) -> Iterable[PluginInstance]: - return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all()) - - def update_id(self, new_id: str) -> None: - if new_id is not None and new_id != self.id: - self.db_instance.id = new_id.lower() - - def update_config(self, config: str) -> None: - if not config or self.db_instance.config == config: + async def update_config(self, config: str | None) -> None: + if config is None or self.config_str == config: return - self.db_instance.config = config + self.config_str = config if self.started and self.plugin is not None: - self.plugin.on_external_config_update() + res = self.plugin.on_external_config_update() + if inspect.isawaitable(res): + await res + await self.update() - async def update_primary_user(self, primary_user: UserID) -> bool: - if not primary_user or primary_user == self.primary_user: + async def update_primary_user(self, primary_user: UserID | None) -> bool: + if primary_user is None or primary_user == self.primary_user: return True - client = Client.get(primary_user) + client = await Client.get(primary_user) if not client: return False await self.stop() - self.db_instance.primary_user = client.id + self.primary_user = client.id if self.client: self.client.references.remove(self) self.client = client self.client.references.add(self) + await self.update() await self.start() self.log.debug(f"Primary user switched to {self.client.id}") return True - async def update_type(self, type: str) -> bool: - if not type or type == self.type: + async def update_type(self, type: str | None) -> bool: + if type is None or type == self.type: return True try: loader = PluginLoader.find(type) except KeyError: return False await self.stop() - self.db_instance.type = loader.meta.id + self.type = loader.meta.id if self.loader: self.loader.references.remove(self) self.loader = loader self.loader.references.add(self) + await self.update() await self.start() self.log.debug(f"Type switched to {self.loader.meta.id}") return True @@ -303,39 +309,41 @@ class PluginInstance: if started is not None and started != self.started: await (self.start() if started else self.stop()) - def update_enabled(self, enabled: bool) -> None: + async def update_enabled(self, enabled: bool) -> None: if enabled is not None and enabled != self.enabled: - self.db_instance.enabled = enabled + self.enabled = enabled + await self.update() - # region Properties + @classmethod + @async_getter_lock + async def get( + cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None + ) -> PluginInstance | None: + try: + return cls.cache[instance_id] + except KeyError: + pass - @property - def id(self) -> str: - return self.db_instance.id + instance = cast(cls, await super().get(instance_id)) + if instance is not None: + instance.postinit() + return instance - @id.setter - def id(self, value: str) -> None: - self.db_instance.id = value + if type and primary_user: + instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user) + await instance.insert() + instance.postinit() + return instance - @property - def type(self) -> str: - return self.db_instance.type + return None - @property - def enabled(self) -> bool: - return self.db_instance.enabled - - @property - def primary_user(self) -> UserID: - return self.db_instance.primary_user - - # endregion - - -def init( - config: Config, webserver: MaubotServer, loop: AbstractEventLoop -) -> Iterable[PluginInstance]: - PluginInstance.mb_config = config - PluginInstance.loop = loop - PluginInstance.webserver = webserver - return PluginInstance.all() + @classmethod + async def all(cls) -> AsyncGenerator[PluginInstance, None]: + instances = await super().all() + instance: PluginInstance + for instance in instances: + try: + yield cls.cache[instance.id] + except KeyError: + instance.postinit() + yield instance diff --git a/maubot/lib/color_log.py b/maubot/lib/color_log.py index 104e9f7..4fb94e0 100644 --- a/maubot/lib/color_log.py +++ b/maubot/lib/color_log.py @@ -28,14 +28,19 @@ LOADER_COLOR = PREFIX + "36m" # blue class ColorFormatter(BaseColorFormatter): def _color_name(self, module: str) -> str: client = "maubot.client" - if module.startswith(client): - return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}" + if module.startswith(client + "."): + suffix = "" + if module.endswith(".crypto"): + suffix = f".{MAU_COLOR}crypto{RESET}" + module = module[: -len(".crypto")] + module = module[len(client) + 1 :] + return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module}{RESET}{suffix}" instance = "maubot.instance" - if module.startswith(instance): + if module.startswith(instance + "."): return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}" loader = "maubot.loader" - if module.startswith(loader): + if module.startswith(loader + "."): return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}" - if module.startswith("maubot"): + if module.startswith("maubot."): return f"{MAU_COLOR}{module}{RESET}" return super()._color_name(module) diff --git a/maubot/lib/store_proxy.py b/maubot/lib/state_store.py similarity index 64% rename from maubot/lib/store_proxy.py rename to maubot/lib/state_store.py index d8fa234..81fb5fd 100644 --- a/maubot/lib/store_proxy.py +++ b/maubot/lib/state_store.py @@ -13,16 +13,15 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from mautrix.client import SyncStore -from mautrix.types import SyncToken +from mautrix.client.state_store.asyncpg import PgStateStore as BasePgStateStore +try: + from mautrix.crypto import StateStore as CryptoStateStore -class SyncStoreProxy(SyncStore): - def __init__(self, db_instance) -> None: - self.db_instance = db_instance + class PgStateStore(BasePgStateStore, CryptoStateStore): + pass - async def put_next_batch(self, next_batch: SyncToken) -> None: - self.db_instance.edit(next_batch=next_batch) +except ImportError as e: + PgStateStore = BasePgStateStore - async def get_next_batch(self) -> SyncToken: - return self.db_instance.next_batch +__all__ = ["PgStateStore"] diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py index f99358c..c669398 100644 --- a/maubot/loader/abc.py +++ b/maubot/loader/abc.py @@ -13,17 +13,14 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import TYPE_CHECKING, Dict, List, Set, Type, TypeVar +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar from abc import ABC, abstractmethod import asyncio -from attr import dataclass -from packaging.version import InvalidVersion, Version - -from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer - -from ..__meta__ import __version__ from ..plugin_base import Plugin +from .meta import PluginMeta if TYPE_CHECKING: from ..instance import PluginInstance @@ -35,36 +32,6 @@ class IDConflictError(Exception): pass -@serializer(Version) -def serialize_version(version: Version) -> str: - return str(version) - - -@deserializer(Version) -def deserialize_version(version: str) -> Version: - try: - return Version(version) - except InvalidVersion as e: - raise SerializerError("Invalid version") from e - - -@dataclass -class PluginMeta(SerializableAttrs): - id: str - version: Version - modules: List[str] - main_class: str - - maubot: Version = Version(__version__) - database: bool = False - config: bool = False - webapp: bool = False - license: str = "" - extra_files: List[str] = [] - dependencies: List[str] = [] - soft_dependencies: List[str] = [] - - class BasePluginLoader(ABC): meta: PluginMeta @@ -80,25 +47,25 @@ class BasePluginLoader(ABC): async def read_file(self, path: str) -> bytes: pass - def sync_list_files(self, directory: str) -> List[str]: + def sync_list_files(self, directory: str) -> list[str]: raise NotImplementedError("This loader doesn't support synchronous operations") @abstractmethod - async def list_files(self, directory: str) -> List[str]: + async def list_files(self, directory: str) -> list[str]: pass class PluginLoader(BasePluginLoader, ABC): - id_cache: Dict[str, "PluginLoader"] = {} + id_cache: dict[str, PluginLoader] = {} meta: PluginMeta - references: Set["PluginInstance"] + references: set[PluginInstance] def __init__(self): self.references = set() @classmethod - def find(cls, plugin_id: str) -> "PluginLoader": + def find(cls, plugin_id: str) -> PluginLoader: return cls.id_cache[plugin_id] def to_dict(self) -> dict: @@ -119,11 +86,11 @@ class PluginLoader(BasePluginLoader, ABC): ) @abstractmethod - async def load(self) -> Type[PluginClass]: + async def load(self) -> type[PluginClass]: pass @abstractmethod - async def reload(self) -> Type[PluginClass]: + async def reload(self) -> type[PluginClass]: pass @abstractmethod diff --git a/maubot/loader/meta.py b/maubot/loader/meta.py new file mode 100644 index 0000000..7d44483 --- /dev/null +++ b/maubot/loader/meta.py @@ -0,0 +1,53 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import List + +from attr import dataclass +from packaging.version import InvalidVersion, Version + +from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer + +from ..__meta__ import __version__ + + +@serializer(Version) +def serialize_version(version: Version) -> str: + return str(version) + + +@deserializer(Version) +def deserialize_version(version: str) -> Version: + try: + return Version(version) + except InvalidVersion as e: + raise SerializerError("Invalid version") from e + + +@dataclass +class PluginMeta(SerializableAttrs): + id: str + version: Version + modules: List[str] + main_class: str + + maubot: Version = Version(__version__) + database: bool = False + config: bool = False + webapp: bool = False + license: str = "" + extra_files: List[str] = [] + dependencies: List[str] = [] + soft_dependencies: List[str] = [] diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py index 62db112..739656f 100644 --- a/maubot/loader/zip.py +++ b/maubot/loader/zip.py @@ -29,7 +29,8 @@ from mautrix.types import SerializerError from ..config import Config from ..lib.zipimport import ZipImportError, zipimporter from ..plugin_base import Plugin -from .abc import IDConflictError, PluginClass, PluginLoader, PluginMeta +from .abc import IDConflictError, PluginClass, PluginLoader +from .meta import PluginMeta yaml = YAML() diff --git a/maubot/management/api/__init__.py b/maubot/management/api/__init__.py index 1c4d7d3..c2e5f24 100644 --- a/maubot/management/api/__init__.py +++ b/maubot/management/api/__init__.py @@ -20,7 +20,7 @@ from aiohttp import web from ...config import Config from .auth import check_token -from .base import get_config, routes, set_config, set_loop +from .base import get_config, routes, set_config from .middleware import auth, error @@ -40,7 +40,6 @@ def features(request: web.Request) -> web.Response: def init(cfg: Config, loop: AbstractEventLoop) -> web.Application: set_config(cfg) - set_loop(loop) for pkg, enabled in cfg["api_features"].items(): if enabled: importlib.import_module(f"maubot.management.api.{pkg}") diff --git a/maubot/management/api/auth.py b/maubot/management/api/auth.py index 76ddcf3..0abc3ad 100644 --- a/maubot/management/api/auth.py +++ b/maubot/management/api/auth.py @@ -46,7 +46,7 @@ def create_token(user: UserID) -> str: def get_token(request: web.Request) -> str: token = request.headers.get("Authorization", "") if not token or not token.startswith("Bearer "): - token = request.query.get("access_token", None) + token = request.query.get("access_token", "") else: token = token[len("Bearer ") :] return token diff --git a/maubot/management/api/base.py b/maubot/management/api/base.py index 73b2508..3d7693a 100644 --- a/maubot/management/api/base.py +++ b/maubot/management/api/base.py @@ -24,7 +24,6 @@ from ...config import Config routes: web.RouteTableDef = web.RouteTableDef() _config: Config | None = None -_loop: asyncio.AbstractEventLoop | None = None def set_config(config: Config) -> None: @@ -36,15 +35,6 @@ def get_config() -> Config: return _config -def set_loop(loop: asyncio.AbstractEventLoop) -> None: - global _loop - _loop = loop - - -def get_loop() -> asyncio.AbstractEventLoop: - return _loop - - @routes.get("/version") async def version(_: web.Request) -> web.Response: return web.json_response({"version": __version__}) diff --git a/maubot/management/api/client.py b/maubot/management/api/client.py index 0b3a239..d95286b 100644 --- a/maubot/management/api/client.py +++ b/maubot/management/api/client.py @@ -24,7 +24,6 @@ from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequ from mautrix.types import FilterID, SyncToken, UserID from ...client import Client -from ...db import DBClient from .base import routes from .responses import resp @@ -37,7 +36,7 @@ async def get_clients(_: web.Request) -> web.Response: @routes.get("/client/{id}") async def get_client(request: web.Request) -> web.Response: user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) + client = await Client.get(user_id) if not client: return resp.client_not_found return resp.found(client.to_dict()) @@ -51,7 +50,6 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response: mxid="@not:a.mxid", base_url=homeserver, token=access_token, - loop=Client.loop, client_session=Client.http_client, ) try: @@ -63,29 +61,23 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response: except MatrixConnectionError: return resp.bad_client_connection_details if user_id is None: - existing_client = Client.get(whoami.user_id, None) + existing_client = await Client.get(whoami.user_id) if existing_client is not None: return resp.user_exists elif whoami.user_id != user_id: return resp.mxid_mismatch(whoami.user_id) elif whoami.device_id and device_id and whoami.device_id != device_id: return resp.device_id_mismatch(whoami.device_id) - db_instance = DBClient( - id=whoami.user_id, - homeserver=homeserver, - access_token=access_token, - enabled=data.get("enabled", True), - next_batch=SyncToken(""), - filter_id=FilterID(""), - sync=data.get("sync", True), - autojoin=data.get("autojoin", True), - online=data.get("online", True), - displayname=data.get("displayname", "disable"), - avatar_url=data.get("avatar_url", "disable"), - device_id=device_id, + client = await Client.get( + whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id ) - client = Client(db_instance) - client.db_instance.insert() + client.enabled = data.get("enabled", True) + client.sync = data.get("sync", True) + client.autojoin = data.get("autojoin", True) + client.online = data.get("online", True) + client.displayname = data.get("displayname", "disable") + client.avatar_url = data.get("avatar_url", "disable") + await client.update() await client.start() return resp.created(client.to_dict()) @@ -93,9 +85,7 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response: async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response: try: await client.update_access_details( - data.get("access_token", None), - data.get("homeserver", None), - data.get("device_id", None), + data.get("access_token"), data.get("homeserver"), data.get("device_id") ) except MatrixInvalidToken: return resp.bad_client_access_token @@ -109,21 +99,21 @@ async def _update_client(client: Client, data: dict, is_login: bool = False) -> return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :]) elif str_err.startswith("Device ID mismatch"): return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :]) - with client.db_instance.edit_mode(): - await client.update_avatar_url(data.get("avatar_url", None)) - await client.update_displayname(data.get("displayname", None)) - await client.update_started(data.get("started", None)) - client.enabled = data.get("enabled", client.enabled) - client.autojoin = data.get("autojoin", client.autojoin) - client.online = data.get("online", client.online) - client.sync = data.get("sync", client.sync) - return resp.updated(client.to_dict(), is_login=is_login) + await client.update_avatar_url(data.get("avatar_url"), save=False) + await client.update_displayname(data.get("displayname"), save=False) + await client.update_started(data.get("started")) + await client.update_enabled(data.get("enabled"), save=False) + await client.update_autojoin(data.get("autojoin"), save=False) + await client.update_online(data.get("online"), save=False) + await client.update_sync(data.get("sync"), save=False) + await client.update() + return resp.updated(client.to_dict(), is_login=is_login) async def _create_or_update_client( user_id: UserID, data: dict, is_login: bool = False ) -> web.Response: - client = Client.get(user_id, None) + client = await Client.get(user_id) if not client: return await _create_client(user_id, data) else: @@ -141,7 +131,7 @@ async def create_client(request: web.Request) -> web.Response: @routes.put("/client/{id}") async def update_client(request: web.Request) -> web.Response: - user_id = request.match_info.get("id", None) + user_id = request.match_info["id"] try: data = await request.json() except JSONDecodeError: @@ -151,23 +141,23 @@ async def update_client(request: web.Request) -> web.Response: @routes.delete("/client/{id}") async def delete_client(request: web.Request) -> web.Response: - user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) + user_id = request.match_info["id"] + client = await Client.get(user_id) if not client: return resp.client_not_found if len(client.references) > 0: return resp.client_in_use if client.started: await client.stop() - client.delete() + await client.delete() return resp.deleted @routes.post("/client/{id}/clearcache") async def clear_client_cache(request: web.Request) -> web.Response: - user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) + user_id = request.match_info["id"] + client = await Client.get(user_id) if not client: return resp.client_not_found - client.clear_cache() + await client.clear_cache() return resp.ok diff --git a/maubot/management/api/client_auth.py b/maubot/management/api/client_auth.py index 754c0d7..c5baade 100644 --- a/maubot/management/api/client_auth.py +++ b/maubot/management/api/client_auth.py @@ -13,7 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple from http import HTTPStatus from json import JSONDecodeError import asyncio @@ -30,12 +32,12 @@ from mautrix.client import ClientAPI from mautrix.errors import MatrixRequestError from mautrix.types import LoginResponse, LoginType -from .base import get_config, get_loop, routes +from .base import get_config, routes from .client import _create_client, _create_or_update_client from .responses import resp -def known_homeservers() -> Dict[str, Dict[str, str]]: +def known_homeservers() -> dict[str, dict[str, str]]: return get_config()["homeservers"] @@ -61,7 +63,7 @@ truthy_strings = ("1", "true", "yes") async def read_client_auth_request( request: web.Request, -) -> Tuple[Optional[AuthRequestInfo], Optional[web.Response]]: +) -> tuple[AuthRequestInfo | None, web.Response | None]: server_name = request.match_info.get("server", None) server = known_homeservers().get(server_name, None) if not server: @@ -85,7 +87,7 @@ async def read_client_auth_request( return ( AuthRequestInfo( server_name=server_name, - client=ClientAPI(base_url=base_url, loop=get_loop()), + client=ClientAPI(base_url=base_url), secret=server.get("secret"), username=username, password=password, @@ -189,11 +191,11 @@ async def _do_sso(req: AuthRequestInfo) -> web.Response: sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query( {"redirectUrl": str(public_url)} ) - sso_waiters[waiter_id] = req, get_loop().create_future() + sso_waiters[waiter_id] = req, asyncio.get_running_loop().create_future() return web.json_response({"sso_url": str(sso_url), "id": waiter_id}) -async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response: +async def _do_login(req: AuthRequestInfo, login_token: str | None = None) -> web.Response: device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) device_id = f"maubot_{device_id}" try: @@ -235,7 +237,7 @@ async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> return web.json_response(res.serialize()) -sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {} +sso_waiters: dict[str, tuple[AuthRequestInfo, asyncio.Future]] = {} @routes.post("/client/auth/{server}/sso/{id}/wait") diff --git a/maubot/management/api/client_proxy.py b/maubot/management/api/client_proxy.py index dca741f..3fa682b 100644 --- a/maubot/management/api/client_proxy.py +++ b/maubot/management/api/client_proxy.py @@ -25,7 +25,7 @@ PROXY_CHUNK_SIZE = 32 * 1024 @routes.view("/proxy/{id}/{path:_matrix/.+}") async def proxy(request: web.Request) -> web.StreamResponse: user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) + client = await Client.get(user_id) if not client: return resp.client_not_found diff --git a/maubot/management/api/instance.py b/maubot/management/api/instance.py index c875c6a..edc34bd 100644 --- a/maubot/management/api/instance.py +++ b/maubot/management/api/instance.py @@ -18,7 +18,6 @@ from json import JSONDecodeError from aiohttp import web from ...client import Client -from ...db import DBPlugin from ...instance import PluginInstance from ...loader import PluginLoader from .base import routes @@ -32,56 +31,49 @@ async def get_instances(_: web.Request) -> web.Response: @routes.get("/instance/{id}") async def get_instance(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "").lower() - instance = PluginInstance.get(instance_id, None) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) if not instance: return resp.instance_not_found return resp.found(instance.to_dict()) async def _create_instance(instance_id: str, data: dict) -> web.Response: - plugin_type = data.get("type", None) - primary_user = data.get("primary_user", None) + plugin_type = data.get("type") + primary_user = data.get("primary_user") if not plugin_type: return resp.plugin_type_required elif not primary_user: return resp.primary_user_required - elif not Client.get(primary_user): + elif not await Client.get(primary_user): return resp.primary_user_not_found try: PluginLoader.find(plugin_type) except KeyError: return resp.plugin_type_not_found - db_instance = DBPlugin( - id=instance_id, - type=plugin_type, - enabled=data.get("enabled", True), - primary_user=primary_user, - config=data.get("config", ""), - ) - instance = PluginInstance(db_instance) - instance.load() - instance.db_instance.insert() + instance = await PluginInstance.get(instance_id, type=plugin_type, primary_user=primary_user) + instance.enabled = data.get("enabled", True) + instance.config_str = data.get("config") or "" + await instance.update() await instance.start() return resp.created(instance.to_dict()) async def _update_instance(instance: PluginInstance, data: dict) -> web.Response: - if not await instance.update_primary_user(data.get("primary_user", None)): + if not await instance.update_primary_user(data.get("primary_user")): return resp.primary_user_not_found - with instance.db_instance.edit_mode(): - instance.update_id(data.get("id", None)) - instance.update_enabled(data.get("enabled", None)) - instance.update_config(data.get("config", None)) - await instance.update_started(data.get("started", None)) - await instance.update_type(data.get("type", None)) - return resp.updated(instance.to_dict()) + await instance.update_id(data.get("id")) + await instance.update_enabled(data.get("enabled")) + await instance.update_config(data.get("config")) + await instance.update_started(data.get("started")) + await instance.update_type(data.get("type")) + return resp.updated(instance.to_dict()) @routes.put("/instance/{id}") async def update_instance(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "").lower() - instance = PluginInstance.get(instance_id, None) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) try: data = await request.json() except JSONDecodeError: @@ -94,11 +86,11 @@ async def update_instance(request: web.Request) -> web.Response: @routes.delete("/instance/{id}") async def delete_instance(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "").lower() - instance = PluginInstance.get(instance_id) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) if not instance: return resp.instance_not_found if instance.started: await instance.stop() - instance.delete() + await instance.delete() return resp.deleted diff --git a/maubot/management/api/instance_database.py b/maubot/management/api/instance_database.py index ef7da30..25869ce 100644 --- a/maubot/management/api/instance_database.py +++ b/maubot/management/api/instance_database.py @@ -29,8 +29,8 @@ from .responses import resp @routes.get("/instance/{id}/database") async def get_database(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "") - instance = PluginInstance.get(instance_id, None) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) if not instance: return resp.instance_not_found elif not instance.inst_db: @@ -65,8 +65,8 @@ def check_type(val): @routes.get("/instance/{id}/database/{table}") async def get_table(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "") - instance = PluginInstance.get(instance_id, None) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) if not instance: return resp.instance_not_found elif not instance.inst_db: @@ -86,14 +86,14 @@ async def get_table(request: web.Request) -> web.Response: ] except KeyError: order = [] - limit = int(request.query.get("limit", 100)) + limit = int(request.query.get("limit", "100")) return execute_query(instance, table.select().order_by(*order).limit(limit)) @routes.post("/instance/{id}/database/query") async def query(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "") - instance = PluginInstance.get(instance_id, None) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) if not instance: return resp.instance_not_found elif not instance.inst_db: diff --git a/maubot/management/api/log.py b/maubot/management/api/log.py index 1c5df93..05c11d3 100644 --- a/maubot/management/api/log.py +++ b/maubot/management/api/log.py @@ -23,7 +23,7 @@ import logging from aiohttp import web, web_ws from .auth import is_valid_token -from .base import get_loop, routes +from .base import routes BUILTIN_ATTRS = { "args", @@ -138,12 +138,12 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse: authenticated = False async def close_if_not_authenticated(): - await asyncio.sleep(5, loop=get_loop()) + await asyncio.sleep(5) if not authenticated: await ws.close(code=4000) log.debug(f"Connection from {request.remote} terminated due to no authentication") - asyncio.ensure_future(close_if_not_authenticated()) + asyncio.create_task(close_if_not_authenticated()) try: msg: web_ws.WSMessage diff --git a/maubot/management/api/plugin.py b/maubot/management/api/plugin.py index ecd3c6a..94d8d9d 100644 --- a/maubot/management/api/plugin.py +++ b/maubot/management/api/plugin.py @@ -29,8 +29,8 @@ async def get_plugins(_) -> web.Response: @routes.get("/plugin/{id}") async def get_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info.get("id", None) - plugin = PluginLoader.id_cache.get(plugin_id, None) + plugin_id = request.match_info["id"] + plugin = PluginLoader.id_cache.get(plugin_id) if not plugin: return resp.plugin_not_found return resp.found(plugin.to_dict()) @@ -38,8 +38,8 @@ async def get_plugin(request: web.Request) -> web.Response: @routes.delete("/plugin/{id}") async def delete_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info.get("id", None) - plugin = PluginLoader.id_cache.get(plugin_id, None) + plugin_id = request.match_info["id"] + plugin = PluginLoader.id_cache.get(plugin_id) if not plugin: return resp.plugin_not_found elif len(plugin.references) > 0: @@ -50,8 +50,8 @@ async def delete_plugin(request: web.Request) -> web.Response: @routes.post("/plugin/{id}/reload") async def reload_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info.get("id", None) - plugin = PluginLoader.id_cache.get(plugin_id, None) + plugin_id = request.match_info["id"] + plugin = PluginLoader.id_cache.get(plugin_id) if not plugin: return resp.plugin_not_found diff --git a/maubot/management/api/plugin_upload.py b/maubot/management/api/plugin_upload.py index f187c71..ffedbb8 100644 --- a/maubot/management/api/plugin_upload.py +++ b/maubot/management/api/plugin_upload.py @@ -29,7 +29,7 @@ from .responses import resp @routes.put("/plugin/{id}") async def put_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info.get("id", None) + plugin_id = request.match_info["id"] content = await request.read() file = BytesIO(content) try: diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index f6bb578..3fae788 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -15,7 +15,7 @@ # along with this program. If not, see . from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Awaitable from abc import ABC from asyncio import AbstractEventLoop @@ -124,6 +124,7 @@ class Plugin(ABC): def get_config_class(cls) -> type[BaseProxyConfig] | None: return None - def on_external_config_update(self) -> None: + def on_external_config_update(self) -> Awaitable[None] | None: if self.config: self.config.load_and_update() + return None diff --git a/optional-requirements.txt b/optional-requirements.txt index 0397722..f42cab6 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -1,13 +1,10 @@ # Format: #/name defines a new extras_require group called name # Uncommented lines after the group definition insert things into that group. -#/postgres -psycopg2-binary>=2,<3 -asyncpg>=0.20,<0.26 +#/sqlite +aiosqlite>=0.16,<0.18 #/encryption -asyncpg>=0.20,<0.26 -aiosqlite>=0.16,<0.18 python-olm>=3,<4 pycryptodome>=3,<4 unpaddedbase64>=1,<3 diff --git a/requirements.txt b/requirements.txt index ee3bc37..dd541e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ -mautrix>=0.15.0,<0.16 +mautrix>=0.15.2,<0.16 aiohttp>=3,<4 yarl>=1,<2 SQLAlchemy>=1,<1.4 +asyncpg>=0.20,<0.26 alembic>=1,<2 commonmark>=0.9,<1 ruamel.yaml>=0.15.35,<0.18 diff --git a/setup.py b/setup.py index 574a1c6..cba8c20 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,4 @@ import setuptools -import glob import os with open("requirements.txt") as reqs: @@ -57,9 +56,7 @@ setuptools.setup( mbc=maubot.cli:app """, data_files=[ - (".", ["maubot/example-config.yaml", "alembic.ini"]), - ("alembic", ["alembic/env.py"]), - ("alembic/versions", glob.glob("alembic/versions/*.py")), + (".", ["maubot/example-config.yaml"]), ], package_data={ "maubot": [