Switch to asyncpg/aiosqlite

Fixes #142
Fixes #98
Probably fixes #62
This commit is contained in:
Tulir Asokan 2022-03-25 19:45:48 +02:00
parent 068e268c63
commit 21ed971d2f
43 changed files with 911 additions and 955 deletions

View File

@ -15,7 +15,6 @@ RUN apk add --no-cache \
py3-attrs \ py3-attrs \
py3-bcrypt \ py3-bcrypt \
py3-cffi \ py3-cffi \
py3-psycopg2 \
py3-ruamel.yaml \ py3-ruamel.yaml \
py3-jinja2 \ py3-jinja2 \
py3-click \ py3-click \
@ -49,7 +48,6 @@ COPY requirements.txt /opt/maubot/requirements.txt
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
WORKDIR /opt/maubot WORKDIR /opt/maubot
RUN apk add --virtual .build-deps python3-dev build-base git \ RUN apk add --virtual .build-deps python3-dev build-base git \
&& sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \
&& pip3 install -r requirements.txt -r optional-requirements.txt \ && pip3 install -r requirements.txt -r optional-requirements.txt \
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \ dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
&& apk del .build-deps && apk del .build-deps

View File

@ -10,7 +10,6 @@ RUN apk add --no-cache \
py3-attrs \ py3-attrs \
py3-bcrypt \ py3-bcrypt \
py3-cffi \ py3-cffi \
py3-psycopg2 \
py3-ruamel.yaml \ py3-ruamel.yaml \
py3-jinja2 \ py3-jinja2 \
py3-click \ py3-click \
@ -43,7 +42,6 @@ COPY requirements.txt /opt/maubot/requirements.txt
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
WORKDIR /opt/maubot WORKDIR /opt/maubot
RUN apk add --virtual .build-deps python3-dev build-base git \ RUN apk add --virtual .build-deps python3-dev build-base git \
&& sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \
&& pip3 install -r requirements.txt -r optional-requirements.txt \ && pip3 install -r requirements.txt -r optional-requirements.txt \
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \ dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
&& apk del .build-deps && apk del .build-deps

View File

@ -1,83 +0,0 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s
# timezone to use when rendering the date
# within the migration file as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; this defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path
# version_locations = %(here)s/bar %(here)s/bat alembic/versions
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks=black
# black.type=console_scripts
# black.entrypoint=black
# black.options=-l 79
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@ -1 +0,0 @@
Generic single-database configuration.

View File

@ -1,92 +0,0 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
import sys
from os.path import abspath, dirname
sys.path.insert(0, dirname(dirname(abspath(__file__))))
from mautrix.util.db import Base
from maubot.config import Config
from maubot import db
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
maubot_config_path = context.get_x_argument(as_dictionary=True).get("config", "config.yaml")
maubot_config = Config(maubot_config_path, None)
maubot_config.load()
config.set_main_option("sqlalchemy.url", maubot_config["database"].replace("%", "%%"))
# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline():
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
render_as_batch=True,
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online():
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata,
render_as_batch=True,
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@ -1,24 +0,0 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade():
${upgrades if upgrades else "pass"}
def downgrade():
${downgrades if downgrades else "pass"}

View File

@ -1,32 +0,0 @@
"""Add device_id to clients
Revision ID: 4b93300852aa
Revises: fccd1f95544d
Create Date: 2020-07-11 15:49:38.831459
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4b93300852aa'
down_revision = 'fccd1f95544d'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('client', schema=None) as batch_op:
batch_op.add_column(sa.Column('device_id', sa.String(length=255), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('client', schema=None) as batch_op:
batch_op.drop_column('device_id')
# ### end Alembic commands ###

View File

@ -1,47 +0,0 @@
"""Add Matrix state store
Revision ID: 90aa88820eab
Revises: 4b93300852aa
Create Date: 2020-07-12 01:50:06.215623
"""
from alembic import op
import sqlalchemy as sa
from mautrix.client.state_store.sqlalchemy import SerializableType
from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent
# revision identifiers, used by Alembic.
revision = '90aa88820eab'
down_revision = '4b93300852aa'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('mx_room_state',
sa.Column('room_id', sa.String(length=255), nullable=False),
sa.Column('is_encrypted', sa.Boolean(), nullable=True),
sa.Column('has_full_member_list', sa.Boolean(), nullable=True),
sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True),
sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True),
sa.PrimaryKeyConstraint('room_id')
)
op.create_table('mx_user_profile',
sa.Column('room_id', sa.String(length=255), nullable=False),
sa.Column('user_id', sa.String(length=255), nullable=False),
sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False),
sa.Column('displayname', sa.String(), nullable=True),
sa.Column('avatar_url', sa.String(length=255), nullable=True),
sa.PrimaryKeyConstraint('room_id', 'user_id')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('mx_user_profile')
op.drop_table('mx_room_state')
# ### end Alembic commands ###

View File

@ -1,50 +0,0 @@
"""Initial revision
Revision ID: d295f8dcfa64
Revises:
Create Date: 2019-09-27 00:21:02.527915
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd295f8dcfa64'
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('client',
sa.Column('id', sa.String(length=255), nullable=False),
sa.Column('homeserver', sa.String(length=255), nullable=False),
sa.Column('access_token', sa.Text(), nullable=False),
sa.Column('enabled', sa.Boolean(), nullable=False),
sa.Column('next_batch', sa.String(length=255), nullable=False),
sa.Column('filter_id', sa.String(length=255), nullable=False),
sa.Column('sync', sa.Boolean(), nullable=False),
sa.Column('autojoin', sa.Boolean(), nullable=False),
sa.Column('displayname', sa.String(length=255), nullable=False),
sa.Column('avatar_url', sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('plugin',
sa.Column('id', sa.String(length=255), nullable=False),
sa.Column('type', sa.String(length=255), nullable=False),
sa.Column('enabled', sa.Boolean(), nullable=False),
sa.Column('primary_user', sa.String(length=255), nullable=False),
sa.Column('config', sa.Text(), nullable=False),
sa.ForeignKeyConstraint(['primary_user'], ['client.id'], onupdate='CASCADE', ondelete='RESTRICT'),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('plugin')
op.drop_table('client')
# ### end Alembic commands ###

View File

@ -1,30 +0,0 @@
"""Add online field to clients
Revision ID: fccd1f95544d
Revises: d295f8dcfa64
Create Date: 2020-03-06 15:07:50.136644
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fccd1f95544d'
down_revision = 'd295f8dcfa64'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("client") as batch_op:
batch_op.add_column(sa.Column('online', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true()))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("client") as batch_op:
batch_op.drop_column('online')
# ### end Alembic commands ###

View File

@ -1,7 +1,7 @@
#!/bin/sh #!/bin/sh
function fixperms { function fixperms {
chown -R $UID:$GID /var/log /data /opt/maubot chown -R $UID:$GID /var/log /data
} }
function fixdefault { function fixdefault {

View File

@ -13,24 +13,37 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
import asyncio from __future__ import annotations
import asyncio
import sys
from mautrix.util.async_db import Database, DatabaseException
from mautrix.util.program import Program from mautrix.util.program import Program
from .__meta__ import __version__ from .__meta__ import __version__
from .client import Client, init as init_client_class from .client import Client
from .config import Config from .config import Config
from .db import init as init_db from .db import init as init_db, upgrade_table
from .instance import init as init_plugin_instance_class from .instance import PluginInstance
from .lib.future_awaitable import FutureAwaitable from .lib.future_awaitable import FutureAwaitable
from .lib.state_store import PgStateStore
from .loader.zip import init as init_zip_loader from .loader.zip import init as init_zip_loader
from .management.api import init as init_mgmt_api from .management.api import init as init_mgmt_api
from .server import MaubotServer from .server import MaubotServer
try:
from mautrix.crypto.store import PgCryptoStore
except ImportError:
PgCryptoStore = None
class Maubot(Program): class Maubot(Program):
config: Config config: Config
server: MaubotServer server: MaubotServer
db: Database
crypto_db: Database | None
state_store: PgStateStore
config_class = Config config_class = Config
module = "maubot" module = "maubot"
@ -45,6 +58,19 @@ class Maubot(Program):
init(self.loop) init(self.loop)
self.add_shutdown_actions(FutureAwaitable(stop_all)) self.add_shutdown_actions(FutureAwaitable(stop_all))
def prepare_arg_parser(self) -> None:
super().prepare_arg_parser()
self.parser.add_argument(
"--ignore-unsupported-database",
action="store_true",
help="Run even if the database schema is too new",
)
self.parser.add_argument(
"--ignore-foreign-tables",
action="store_true",
help="Run even if the database contains tables from other programs (like Synapse)",
)
def prepare(self) -> None: def prepare(self) -> None:
super().prepare() super().prepare()
@ -52,21 +78,59 @@ class Maubot(Program):
self.prepare_log_websocket() self.prepare_log_websocket()
init_zip_loader(self.config) init_zip_loader(self.config)
init_db(self.config) self.db = Database.create(
clients = init_client_class(self.config, self.loop) self.config["database"],
self.add_startup_actions(*(client.start() for client in clients)) upgrade_table=upgrade_table,
db_args=self.config["database_opts"],
owner_name=self.name,
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
init_db(self.db)
if self.config["crypto_database"] == "default":
self.crypto_db = self.db
else:
self.crypto_db = Database.create(
self.config["crypto_database"],
upgrade_table=PgCryptoStore.upgrade_table,
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
Client.init_cls(self)
PluginInstance.init_cls(self)
management_api = init_mgmt_api(self.config, self.loop) management_api = init_mgmt_api(self.config, self.loop)
self.server = MaubotServer(management_api, self.config, self.loop) self.server = MaubotServer(management_api, self.config, self.loop)
self.state_store = PgStateStore(self.db)
plugins = init_plugin_instance_class(self.config, self.server, self.loop) async def start_db(self) -> None:
for plugin in plugins: self.log.debug("Starting database...")
plugin.load() ignore_unsupported = self.args.ignore_unsupported_database
self.db.upgrade_table.allow_unsupported = ignore_unsupported
self.state_store.upgrade_table.allow_unsupported = ignore_unsupported
PgCryptoStore.upgrade_table.allow_unsupported = ignore_unsupported
try:
await self.db.start()
await self.state_store.upgrade_table.upgrade(self.db)
if self.crypto_db and self.crypto_db is not self.db:
await self.crypto_db.start()
else:
await PgCryptoStore.upgrade_table.upgrade(self.db)
except DatabaseException as e:
self.log.critical("Failed to initialize database", exc_info=e)
if e.explanation:
self.log.info(e.explanation)
sys.exit(25)
async def system_exit(self) -> None:
if hasattr(self, "db"):
self.log.trace("Stopping database due to SystemExit")
await self.db.stop()
async def start(self) -> None: async def start(self) -> None:
if Client.crypto_db: await self.start_db()
self.log.debug("Starting client crypto database") await asyncio.gather(*[plugin.load() async for plugin in PluginInstance.all()])
await Client.crypto_db.start() await asyncio.gather(*[client.start() async for client in Client.all()])
await super().start() await super().start()
async for plugin in PluginInstance.all():
await plugin.load()
await self.server.start() await self.server.start()
async def stop(self) -> None: async def stop(self) -> None:
@ -77,6 +141,7 @@ class Maubot(Program):
await asyncio.wait_for(self.server.stop(), 5) await asyncio.wait_for(self.server.stop(), 5)
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.log.warning("Stopping server timed out") self.log.warning("Stopping server timed out")
await self.db.stop()
Maubot().run() Maubot().run()

View File

@ -1 +1 @@
__version__ = "0.2.1" __version__ = "0.3.0+dev"

View File

@ -38,7 +38,7 @@ def logs(server: str, tail: int) -> None:
global history_count global history_count
history_count = tail history_count = tail
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
future = asyncio.ensure_future(view_logs(server, token), loop=loop) future = asyncio.create_task(view_logs(server, token), loop=loop)
try: try:
loop.run_until_complete(future) loop.run_until_complete(future)
except KeyboardInterrupt: except KeyboardInterrupt:

View File

@ -15,14 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, cast
from collections import defaultdict
import asyncio import asyncio
import logging import logging
from aiohttp import ClientSession from aiohttp import ClientSession
from mautrix.client import InternalEventType from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
from mautrix.errors import MatrixInvalidToken from mautrix.errors import MatrixInvalidToken
from mautrix.types import ( from mautrix.types import (
ContentURI, ContentURI,
@ -41,69 +41,110 @@ from mautrix.types import (
SyncToken, SyncToken,
UserID, UserID,
) )
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.logging import TraceLogger
from .db import DBClient from .db import Client as DBClient
from .lib.store_proxy import SyncStoreProxy
from .matrix import MaubotMatrixClient from .matrix import MaubotMatrixClient
try: try:
from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore from mautrix.crypto import OlmMachine, PgCryptoStore
from mautrix.util.async_db import Database as AsyncDatabase
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass
crypto_import_error = None crypto_import_error = None
except ImportError as e: except ImportError as e:
OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None OlmMachine = PgCryptoStore = None
SQLStateStore = BaseSQLStateStore
crypto_import_error = e crypto_import_error = e
if TYPE_CHECKING: if TYPE_CHECKING:
from .config import Config from .__main__ import Maubot
from .instance import PluginInstance from .instance import PluginInstance
log = logging.getLogger("maubot.client")
class Client(DBClient):
class Client: maubot: "Maubot" = None
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
cache: dict[UserID, Client] = {} cache: dict[UserID, Client] = {}
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
log: TraceLogger = logging.getLogger("maubot.client")
http_client: ClientSession = None http_client: ClientSession = None
global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
crypto_db: AsyncDatabase | None = None
references: set[PluginInstance] references: set[PluginInstance]
db_instance: DBClient
client: MaubotMatrixClient client: MaubotMatrixClient
crypto: OlmMachine | None crypto: OlmMachine | None
crypto_store: PgCryptoStore | None crypto_store: PgCryptoStore | None
started: bool started: bool
sync_ok: bool
remote_displayname: str | None remote_displayname: str | None
remote_avatar_url: ContentURI | None remote_avatar_url: ContentURI | None
def __init__(self, db_instance: DBClient) -> None: def __init__(
self.db_instance = db_instance self,
id: UserID,
homeserver: str,
access_token: str,
device_id: DeviceID,
enabled: bool = False,
next_batch: SyncToken = "",
filter_id: FilterID = "",
sync: bool = True,
autojoin: bool = True,
online: bool = True,
displayname: str = "disable",
avatar_url: str = "disable",
) -> None:
super().__init__(
id=id,
homeserver=homeserver,
access_token=access_token,
device_id=device_id,
enabled=bool(enabled),
next_batch=next_batch,
filter_id=filter_id,
sync=bool(sync),
autojoin=bool(autojoin),
online=bool(online),
displayname=displayname,
avatar_url=avatar_url,
)
self._postinited = False
def __hash__(self) -> int:
return hash(self.id)
@classmethod
def init_cls(cls, maubot: "Maubot") -> None:
cls.maubot = maubot
def _make_client(
self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None
) -> MaubotMatrixClient:
return MaubotMatrixClient(
mxid=self.id,
base_url=homeserver or self.homeserver,
token=token or self.access_token,
client_session=self.http_client,
log=self.log,
crypto_log=self.log.getChild("crypto"),
loop=self.maubot.loop,
device_id=device_id or self.device_id,
sync_store=self,
state_store=self.maubot.state_store,
)
def postinit(self) -> None:
if self._postinited:
raise RuntimeError("postinit() called twice")
self._postinited = True
self.cache[self.id] = self self.cache[self.id] = self
self.log = log.getChild(self.id) self.log = self.log.getChild(self.id)
self.http_client = ClientSession(loop=self.maubot.loop)
self.references = set() self.references = set()
self.started = False self.started = False
self.sync_ok = True self.sync_ok = True
self.remote_displayname = None self.remote_displayname = None
self.remote_avatar_url = None self.remote_avatar_url = None
self.client = MaubotMatrixClient( self.client = self._make_client()
mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=self.log,
loop=self.loop,
device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store,
)
if self.enable_crypto: if self.enable_crypto:
self._prepare_crypto() self._prepare_crypto()
else: else:
@ -118,6 +159,12 @@ class Client:
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False)) self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True)) self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
async def handler(data: dict[str, Any]) -> None:
self.sync_ok = ok
return handler
@property @property
def enable_crypto(self) -> bool: def enable_crypto(self) -> bool:
if not self.device_id: if not self.device_id:
@ -131,16 +178,21 @@ class Client:
# Clear the stack trace after it's logged once to avoid spamming logs # Clear the stack trace after it's logged once to avoid spamming logs
crypto_import_error = None crypto_import_error = None
return False return False
elif not self.crypto_db: elif not self.maubot.crypto_db:
self.log.warning("Client has device ID, but crypto database is not prepared") self.log.warning("Client has device ID, but crypto database is not prepared")
return False return False
return True return True
def _prepare_crypto(self) -> None: def _prepare_crypto(self) -> None:
self.crypto_store = PgCryptoStore( self.crypto_store = PgCryptoStore(
account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db
)
self.crypto = OlmMachine(
self.client,
self.crypto_store,
self.maubot.state_store,
log=self.client.crypto_log,
) )
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto self.client.crypto = self.crypto
def _remove_crypto_event_handlers(self) -> None: def _remove_crypto_event_handlers(self) -> None:
@ -156,12 +208,6 @@ class Client:
for event_type, func in handlers: for event_type, func in handlers:
self.client.remove_event_handler(event_type, func) self.client.remove_event_handler(event_type, func)
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
async def handler(data: dict[str, Any]) -> None:
self.sync_ok = ok
return handler
async def start(self, try_n: int | None = 0) -> None: async def start(self, try_n: int | None = 0) -> None:
try: try:
if try_n > 0: if try_n > 0:
@ -196,31 +242,34 @@ class Client:
whoami = await self.client.whoami() whoami = await self.client.whoami()
except MatrixInvalidToken as e: except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client") self.log.error(f"Invalid token: {e}. Disabling client")
self.db_instance.enabled = False self.enabled = False
await self.update()
return return
except Exception as e: except Exception as e:
if try_n >= 8: if try_n >= 8:
self.log.exception("Failed to get /account/whoami, disabling client") self.log.exception("Failed to get /account/whoami, disabling client")
self.db_instance.enabled = False self.enabled = False
await self.update()
else: else:
self.log.warning( self.log.warning(
f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}" f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}"
) )
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop) _ = asyncio.create_task(self.start(try_n + 1))
return return
if whoami.user_id != self.id: if whoami.user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}") self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
self.db_instance.enabled = False self.enabled = False
await self.update()
return return
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id: elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
self.log.error( self.log.error(
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}" f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
) )
self.db_instance.enabled = False self.enabled = False
await self.update()
return return
if not self.filter_id: if not self.filter_id:
self.db_instance.edit( self.filter_id = await self.client.create_filter(
filter_id=await self.client.create_filter(
Filter( Filter(
room=RoomFilter( room=RoomFilter(
timeline=RoomEventFilter( timeline=RoomEventFilter(
@ -236,7 +285,7 @@ class Client:
), ),
) )
) )
) await self.update()
if self.displayname != "disable": if self.displayname != "disable":
await self.client.set_displayname(self.displayname) await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable": if self.avatar_url != "disable":
@ -270,18 +319,13 @@ class Client:
if self.crypto: if self.crypto:
await self.crypto_store.close() await self.crypto_store.close()
def clear_cache(self) -> None: async def clear_cache(self) -> None:
self.stop_sync() self.stop_sync()
self.db_instance.edit(filter_id="", next_batch="") self.filter_id = FilterID("")
self.next_batch = SyncToken("")
await self.update()
self.start_sync() self.start_sync()
def delete(self) -> None:
try:
del self.cache[self.id]
except KeyError:
pass
self.db_instance.delete()
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
"id": self.id, "id": self.id,
@ -304,20 +348,6 @@ class Client:
"instances": [instance.to_dict() for instance in self.references], "instances": [instance.to_dict() for instance in self.references],
} }
@classmethod
def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
try:
return cls.cache[user_id]
except KeyError:
db_instance = db_instance or DBClient.get(user_id)
if not db_instance:
return None
return Client(db_instance)
@classmethod
def all(cls) -> Iterable[Client]:
return (cls.get(user.id, user) for user in DBClient.all())
async def _handle_tombstone(self, evt: StateEvent) -> None: async def _handle_tombstone(self, evt: StateEvent) -> None:
if not evt.content.replacement_room: if not evt.content.replacement_room:
self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring") self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring")
@ -329,7 +359,7 @@ class Client:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE: if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room(evt.room_id) await self.client.join_room(evt.room_id)
async def update_started(self, started: bool) -> None: async def update_started(self, started: bool | None) -> None:
if started is None or started == self.started: if started is None or started == self.started:
return return
if started: if started:
@ -337,23 +367,65 @@ class Client:
else: else:
await self.stop() await self.stop()
async def update_displayname(self, displayname: str) -> None: async def update_enabled(self, enabled: bool | None, save: bool = True) -> None:
if enabled is None or enabled == self.enabled:
return
self.enabled = enabled
if save:
await self.update()
async def update_displayname(self, displayname: str | None, save: bool = True) -> None:
if displayname is None or displayname == self.displayname: if displayname is None or displayname == self.displayname:
return return
self.db_instance.displayname = displayname self.displayname = displayname
if self.displayname != "disable": if self.displayname != "disable":
await self.client.set_displayname(self.displayname) await self.client.set_displayname(self.displayname)
else: else:
await self._update_remote_profile() await self._update_remote_profile()
if save:
await self.update()
async def update_avatar_url(self, avatar_url: ContentURI) -> None: async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None:
if avatar_url is None or avatar_url == self.avatar_url: if avatar_url is None or avatar_url == self.avatar_url:
return return
self.db_instance.avatar_url = avatar_url self.avatar_url = avatar_url
if self.avatar_url != "disable": if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url) await self.client.set_avatar_url(self.avatar_url)
else: else:
await self._update_remote_profile() await self._update_remote_profile()
if save:
await self.update()
async def update_sync(self, sync: bool | None, save: bool = True) -> None:
if sync is None or self.sync == sync:
return
self.sync = sync
if self.started:
if sync:
self.start_sync()
else:
self.stop_sync()
if save:
await self.update()
async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None:
if autojoin is None or autojoin == self.autojoin:
return
if autojoin:
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
else:
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.autojoin = autojoin
if save:
await self.update()
async def update_online(self, online: bool | None, save: bool = True) -> None:
if online is None or online == self.online:
return
self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE
self.online = online
if save:
await self.update()
async def update_access_details( async def update_access_details(
self, self,
@ -373,22 +445,13 @@ class Client:
and device_id == self.device_id and device_id == self.device_id
): ):
return return
new_client = MaubotMatrixClient( new_client = self._make_client(homeserver, access_token, device_id)
mxid=self.id,
base_url=homeserver or self.homeserver,
token=access_token or self.access_token,
loop=self.loop,
device_id=device_id,
client_session=self.http_client,
log=self.log,
state_store=self.global_state_store,
)
whoami = await new_client.whoami() whoami = await new_client.whoami()
if whoami.user_id != self.id: if whoami.user_id != self.id:
raise ValueError(f"MXID mismatch: {whoami.user_id}") raise ValueError(f"MXID mismatch: {whoami.user_id}")
elif whoami.device_id and device_id and whoami.device_id != device_id: elif whoami.device_id and device_id and whoami.device_id != device_id:
raise ValueError(f"Device ID mismatch: {whoami.device_id}") raise ValueError(f"Device ID mismatch: {whoami.device_id}")
new_client.sync_store = SyncStoreProxy(self.db_instance) new_client.sync_store = self
self.stop_sync() self.stop_sync()
# TODO this event handler transfer is pretty hacky # TODO this event handler transfer is pretty hacky
@ -398,9 +461,9 @@ class Client:
new_client.global_event_handlers = self.client.global_event_handlers new_client.global_event_handlers = self.client.global_event_handlers
self.client = new_client self.client = new_client
self.db_instance.homeserver = homeserver self.homeserver = homeserver
self.db_instance.access_token = access_token self.access_token = access_token
self.db_instance.device_id = device_id self.device_id = device_id
if self.enable_crypto: if self.enable_crypto:
self._prepare_crypto() self._prepare_crypto()
await self._start_crypto() await self._start_crypto()
@ -413,97 +476,53 @@ class Client:
profile = await self.client.get_profile(self.id) profile = await self.client.get_profile(self.id)
self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url
# region Properties async def delete(self) -> None:
try:
del self.cache[self.id]
except KeyError:
pass
await super().delete()
@property @classmethod
def id(self) -> UserID: @async_getter_lock
return self.db_instance.id async def get(
cls,
user_id: UserID,
*,
homeserver: str | None = None,
access_token: str | None = None,
device_id: DeviceID | None = None,
) -> Client | None:
try:
return cls.cache[user_id]
except KeyError:
pass
@property user = cast(cls, await super().get(user_id))
def homeserver(self) -> str: if user is not None:
return self.db_instance.homeserver user.postinit()
return user
@property if homeserver and access_token:
def access_token(self) -> str: user = cls(
return self.db_instance.access_token user_id,
homeserver=homeserver,
access_token=access_token,
device_id=device_id or "",
)
await user.insert()
user.postinit()
return user
@property return None
def device_id(self) -> DeviceID:
return self.db_instance.device_id
@property @classmethod
def enabled(self) -> bool: async def all(cls) -> AsyncGenerator[Client, None]:
return self.db_instance.enabled users = await super().all()
user: cls
@enabled.setter for user in users:
def enabled(self, value: bool) -> None: try:
self.db_instance.enabled = value yield cls.cache[user.id]
except KeyError:
@property user.postinit()
def next_batch(self) -> SyncToken: yield user
return self.db_instance.next_batch
@property
def filter_id(self) -> FilterID:
return self.db_instance.filter_id
@property
def sync(self) -> bool:
return self.db_instance.sync
@sync.setter
def sync(self, value: bool) -> None:
if value == self.db_instance.sync:
return
self.db_instance.sync = value
if self.started:
if value:
self.start_sync()
else:
self.stop_sync()
@property
def autojoin(self) -> bool:
return self.db_instance.autojoin
@autojoin.setter
def autojoin(self, value: bool) -> None:
if value == self.db_instance.autojoin:
return
if value:
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
else:
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.db_instance.autojoin = value
@property
def online(self) -> bool:
return self.db_instance.online
@online.setter
def online(self, value: bool) -> None:
self.client.presence = PresenceState.ONLINE if value else PresenceState.OFFLINE
self.db_instance.online = value
@property
def displayname(self) -> str:
return self.db_instance.displayname
@property
def avatar_url(self) -> ContentURI:
return self.db_instance.avatar_url
# endregion
def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
if OlmMachine:
db_url = config["crypto_database"]
if db_url == "default":
db_url = config["database"]
Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
return Client.all()

View File

@ -1,108 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Iterable, Optional
import logging
import sys
from sqlalchemy import Boolean, Column, ForeignKey, String, Text
from sqlalchemy.engine.base import Engine
import sqlalchemy as sql
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
from mautrix.util.db import Base
from .config import Config
class DBPlugin(Base):
__tablename__ = "plugin"
id: str = Column(String(255), primary_key=True)
type: str = Column(String(255), nullable=False)
enabled: bool = Column(Boolean, nullable=False, default=False)
primary_user: UserID = Column(
String(255),
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
nullable=False,
)
config: str = Column(Text, nullable=False, default="")
@classmethod
def all(cls) -> Iterable["DBPlugin"]:
return cls._select_all()
@classmethod
def get(cls, id: str) -> Optional["DBPlugin"]:
return cls._select_one_or_none(cls.c.id == id)
class DBClient(Base):
__tablename__ = "client"
id: UserID = Column(String(255), primary_key=True)
homeserver: str = Column(String(255), nullable=False)
access_token: str = Column(Text, nullable=False)
device_id: DeviceID = Column(String(255), nullable=True)
enabled: bool = Column(Boolean, nullable=False, default=False)
next_batch: SyncToken = Column(String(255), nullable=False, default="")
filter_id: FilterID = Column(String(255), nullable=False, default="")
sync: bool = Column(Boolean, nullable=False, default=True)
autojoin: bool = Column(Boolean, nullable=False, default=True)
online: bool = Column(Boolean, nullable=False, default=True)
displayname: str = Column(String(255), nullable=False, default="")
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
@classmethod
def all(cls) -> Iterable["DBClient"]:
return cls._select_all()
@classmethod
def get(cls, id: str) -> Optional["DBClient"]:
return cls._select_one_or_none(cls.c.id == id)
def init(config: Config) -> Engine:
db = sql.create_engine(config["database"])
Base.metadata.bind = db
for table in (DBPlugin, DBClient, RoomState, UserProfile):
table.bind(db)
if not db.has_table("alembic_version"):
log = logging.getLogger("maubot.db")
if db.has_table("client") and db.has_table("plugin"):
log.warning(
"alembic_version table not found, but client and plugin tables found. "
"Assuming pre-Alembic database and inserting version."
)
db.execute(
"CREATE TABLE IF NOT EXISTS alembic_version ("
" version_num VARCHAR(32) PRIMARY KEY"
");"
)
db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');")
else:
log.critical(
"alembic_version table not found. " "Did you forget to `alembic upgrade head`?"
)
sys.exit(10)
return db

13
maubot/db/__init__.py Normal file
View File

@ -0,0 +1,13 @@
from mautrix.util.async_db import Database
from .client import Client
from .instance import Instance
from .upgrade import upgrade_table
def init(db: Database) -> None:
for table in (Client, Instance):
table.db = db
__all__ = ["upgrade_table", "init", "Client", "Instance"]

114
maubot/db/client.py Normal file
View File

@ -0,0 +1,114 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
from attr import dataclass
from mautrix.client import SyncStore
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
from mautrix.util.async_db import Database
fake_db = Database.create("") if TYPE_CHECKING else None
@dataclass
class Client(SyncStore):
db: ClassVar[Database] = fake_db
id: UserID
homeserver: str
access_token: str
device_id: DeviceID
enabled: bool
next_batch: SyncToken
filter_id: FilterID
sync: bool
autojoin: bool
online: bool
displayname: str
avatar_url: ContentURI
@classmethod
def _from_row(cls, row: Record | None) -> Client | None:
if row is None:
return None
return cls(**row)
_columns = (
"id, homeserver, access_token, device_id, enabled, next_batch, filter_id, "
"sync, autojoin, online, displayname, avatar_url"
)
@property
def _values(self):
return (
self.id,
self.homeserver,
self.access_token,
self.device_id,
self.enabled,
self.next_batch,
self.filter_id,
self.sync,
self.autojoin,
self.online,
self.displayname,
self.avatar_url,
)
@classmethod
async def all(cls) -> list[Client]:
rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client")
return [cls._from_row(row) for row in rows]
@classmethod
async def get(cls, id: str) -> Client | None:
q = f"SELECT {cls._columns} FROM client WHERE id=$1"
return cls._from_row(await cls.db.fetchrow(q, id))
async def insert(self) -> None:
q = """
INSERT INTO client (
id, homeserver, access_token, device_id, enabled, next_batch, filter_id,
sync, autojoin, online, displayname, avatar_url
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
"""
await self.db.execute(q, *self._values)
async def put_next_batch(self, next_batch: SyncToken) -> None:
await self.db.execute("UPDATE client SET next_batch=$1 WHERE id=$2", next_batch, self.id)
self.next_batch = next_batch
async def get_next_batch(self) -> SyncToken:
return self.next_batch
async def update(self) -> None:
q = """
UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, enabled=$5,
next_batch=$6, filter_id=$7, sync=$8, autojoin=$9, online=$10,
displayname=$11, avatar_url=$12
WHERE id=$1
"""
await self.db.execute(q, *self._values)
async def delete(self) -> None:
await self.db.execute("DELETE FROM client WHERE id=$1", self.id)

75
maubot/db/instance.py Normal file
View File

@ -0,0 +1,75 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
from attr import dataclass
from mautrix.types import UserID
from mautrix.util.async_db import Database
fake_db = Database.create("") if TYPE_CHECKING else None
@dataclass
class Instance:
db: ClassVar[Database] = fake_db
id: str
type: str
enabled: bool
primary_user: UserID
config_str: str
@classmethod
def _from_row(cls, row: Record | None) -> Instance | None:
if row is None:
return None
return cls(**row)
@classmethod
async def all(cls) -> list[Instance]:
rows = await cls.db.fetch("SELECT id, type, enabled, primary_user, config FROM instance")
return [cls._from_row(row) for row in rows]
@classmethod
async def get(cls, id: str) -> Instance | None:
q = "SELECT id, type, enabled, primary_user, config FROM instance WHERE id=$1"
return cls._from_row(await cls.db.fetchrow(q, id))
async def update_id(self, new_id: str) -> None:
await self.db.execute("UPDATE instance SET id=$1 WHERE id=$2", new_id, self.id)
self.id = new_id
@property
def _values(self):
return self.id, self.type, self.enabled, self.primary_user, self.config_str
async def insert(self) -> None:
q = (
"INSERT INTO instance (id, type, enabled, primary_user, config) "
"VALUES ($1, $2, $3, $4, $5)"
)
await self.db.execute(q, *self._values)
async def update(self) -> None:
q = "UPDATE instance SET type=$2, enabled=$3, primary_user=$4, config=$5 WHERE id=$1"
await self.db.execute(q, *self._values)
async def delete(self) -> None:
await self.db.execute("DELETE FROM instance WHERE id=$1", self.id)

View File

@ -0,0 +1,5 @@
from mautrix.util.async_db import UpgradeTable
upgrade_table = UpgradeTable()
from . import v01_initial_revision

View File

@ -0,0 +1,136 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from mautrix.util.async_db import Connection, Scheme
from . import upgrade_table
legacy_version_query = "SELECT version_num FROM alembic_version"
last_legacy_version = "90aa88820eab"
@upgrade_table.register(description="Initial asyncpg revision")
async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
if await conn.table_exists("alembic_version"):
await migrate_legacy_to_v1(conn, scheme)
else:
return await create_v1_tables(conn)
async def create_v1_tables(conn: Connection) -> None:
await conn.execute(
"""CREATE TABLE client (
id TEXT PRIMARY KEY,
homeserver TEXT NOT NULL,
access_token TEXT NOT NULL,
device_id TEXT NOT NULL,
enabled BOOLEAN NOT NULL,
next_batch TEXT NOT NULL,
filter_id TEXT NOT NULL,
sync BOOLEAN NOT NULL,
autojoin BOOLEAN NOT NULL,
online BOOLEAN NOT NULL,
displayname TEXT NOT NULL,
avatar_url TEXT NOT NULL
)"""
)
await conn.execute(
"""CREATE TABLE instance (
id TEXT PRIMARY KEY,
type TEXT NOT NULL,
enabled BOOLEAN NOT NULL,
primary_user TEXT NOT NULL,
config TEXT NOT NULL,
FOREIGN KEY (primary_user) REFERENCES client(id) ON DELETE RESTRICT ON UPDATE CASCADE
)"""
)
async def migrate_legacy_to_v1(conn: Connection, scheme: Scheme) -> None:
legacy_version = await conn.fetchval(legacy_version_query)
if legacy_version != last_legacy_version:
raise RuntimeError(
"Legacy database is not on last version. "
"Please upgrade the old database with alembic or drop it completely first."
)
await conn.execute("ALTER TABLE plugin RENAME TO instance")
await update_state_store(conn, scheme)
if scheme != Scheme.SQLITE:
await varchar_to_text(conn)
await conn.execute("DROP TABLE alembic_version")
async def update_state_store(conn: Connection, scheme: Scheme) -> None:
# The Matrix state store already has more or less the correct schema, so set the version
await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)")
await conn.execute("INSERT INTO mx_version (version) VALUES (2)")
if scheme != Scheme.SQLITE:
# Remove old uppercase membership type and recreate it as lowercase
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT")
await conn.execute("DROP TYPE IF EXISTS membership")
await conn.execute(
"CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')"
)
await conn.execute(
"ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership "
"USING LOWER(membership)::membership"
)
else:
# Recreate table to remove CHECK constraint and lowercase everything
await conn.execute(
"""CREATE TABLE new_mx_user_profile (
room_id TEXT,
user_id TEXT,
membership TEXT NOT NULL
CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock')),
displayname TEXT,
avatar_url TEXT,
PRIMARY KEY (room_id, user_id)
)"""
)
await conn.execute(
"""
INSERT INTO new_mx_user_profile (room_id, user_id, membership, displayname, avatar_url)
SELECT room_id, user_id, LOWER(membership), displayname, avatar_url
FROM mx_user_profile
"""
)
await conn.execute("DROP TABLE mx_user_profile")
await conn.execute("ALTER TABLE new_mx_user_profile RENAME TO mx_user_profile")
async def varchar_to_text(conn: Connection) -> None:
columns_to_adjust = {
"client": (
"id",
"homeserver",
"device_id",
"next_batch",
"filter_id",
"displayname",
"avatar_url",
),
"instance": ("id", "type", "primary_user"),
"mx_room_state": ("room_id",),
"mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"),
}
for table, columns in columns_to_adjust.items():
for column in columns:
await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT')

View File

@ -6,9 +6,7 @@
database: sqlite:///maubot.db database: sqlite:///maubot.db
# Separate database URL for the crypto database. "default" means use the same database as above. # Separate database URL for the crypto database. "default" means use the same database as above.
# Due to concurrency issues, you should use a separate file when using SQLite rather than the same as above. crypto_database: default
# When using postgres, using the same database for both is safe.
crypto_database: sqlite:///crypto.db
plugin_directories: plugin_directories:
# The directory where uploaded new plugins should be stored. # The directory where uploaded new plugins should be stored.

View File

@ -15,8 +15,10 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Iterable from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
from asyncio import AbstractEventLoop from collections import defaultdict
import asyncio
import inspect
import io import io
import logging import logging
import os.path import os.path
@ -26,16 +28,17 @@ from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql import sqlalchemy as sql
from mautrix.types import UserID from mautrix.types import UserID
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.config import BaseProxyConfig, RecursiveDict from mautrix.util.config import BaseProxyConfig, RecursiveDict
from .client import Client from .client import Client
from .config import Config from .db import Instance as DBInstance
from .db import DBPlugin
from .loader import PluginLoader, ZippedPluginLoader from .loader import PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin from .plugin_base import Plugin
if TYPE_CHECKING: if TYPE_CHECKING:
from .server import MaubotServer, PluginWebApp from .__main__ import Maubot
from .server import PluginWebApp
log = logging.getLogger("maubot.instance") log = logging.getLogger("maubot.instance")
@ -44,29 +47,42 @@ yaml.indent(4)
yaml.width = 200 yaml.width = 200
class PluginInstance: class PluginInstance(DBInstance):
webserver: MaubotServer = None maubot: "Maubot" = None
mb_config: Config = None
loop: AbstractEventLoop = None
cache: dict[str, PluginInstance] = {} cache: dict[str, PluginInstance] = {}
plugin_directories: list[str] = [] plugin_directories: list[str] = []
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
log: logging.Logger log: logging.Logger
loader: PluginLoader loader: PluginLoader | None
client: Client client: Client | None
plugin: Plugin plugin: Plugin | None
config: BaseProxyConfig config: BaseProxyConfig | None
base_cfg: RecursiveDict[CommentedMap] | None base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: str | None base_cfg_str: str | None
inst_db: sql.engine.Engine inst_db: sql.engine.Engine | None
inst_db_tables: dict[str, sql.Table] inst_db_tables: dict[str, sql.Table] | None
inst_webapp: PluginWebApp | None inst_webapp: PluginWebApp | None
inst_webapp_url: str | None inst_webapp_url: str | None
started: bool started: bool
def __init__(self, db_instance: DBPlugin): def __init__(
self.db_instance = db_instance self, id: str, type: str, enabled: bool, primary_user: UserID, config: str = ""
) -> None:
super().__init__(
id=id, type=type, enabled=bool(enabled), primary_user=primary_user, config_str=config
)
def __hash__(self) -> int:
return hash(self.id)
@classmethod
def init_cls(cls, maubot: "Maubot") -> None:
cls.maubot = maubot
def postinit(self) -> None:
self.log = log.getChild(self.id) self.log = log.getChild(self.id)
self.cache[self.id] = self
self.config = None self.config = None
self.started = False self.started = False
self.loader = None self.loader = None
@ -78,7 +94,6 @@ class PluginInstance:
self.inst_webapp_url = None self.inst_webapp_url = None
self.base_cfg = None self.base_cfg = None
self.base_cfg_str = None self.base_cfg_str = None
self.cache[self.id] = self
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
@ -87,10 +102,10 @@ class PluginInstance:
"enabled": self.enabled, "enabled": self.enabled,
"started": self.started, "started": self.started,
"primary_user": self.primary_user, "primary_user": self.primary_user,
"config": self.db_instance.config, "config": self.config_str,
"base_config": self.base_cfg_str, "base_config": self.base_cfg_str,
"database": ( "database": (
self.inst_db is not None and self.mb_config["api_features.instance_database"] self.inst_db is not None and self.maubot.config["api_features.instance_database"]
), ),
} }
@ -101,19 +116,19 @@ class PluginInstance:
self.inst_db_tables = metadata.tables self.inst_db_tables = metadata.tables
return self.inst_db_tables return self.inst_db_tables
def load(self) -> bool: async def load(self) -> bool:
if not self.loader: if not self.loader:
try: try:
self.loader = PluginLoader.find(self.type) self.loader = PluginLoader.find(self.type)
except KeyError: except KeyError:
self.log.error(f"Failed to find loader for type {self.type}") self.log.error(f"Failed to find loader for type {self.type}")
self.db_instance.enabled = False await self.update_enabled(False)
return False return False
if not self.client: if not self.client:
self.client = Client.get(self.primary_user) self.client = await Client.get(self.primary_user)
if not self.client: if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}") self.log.error(f"Failed to get client for user {self.primary_user}")
self.db_instance.enabled = False await self.update_enabled(False)
return False return False
if self.loader.meta.database: if self.loader.meta.database:
self.enable_database() self.enable_database()
@ -125,18 +140,18 @@ class PluginInstance:
return True return True
def enable_webapp(self) -> None: def enable_webapp(self) -> None:
self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id) self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id)
def disable_webapp(self) -> None: def disable_webapp(self) -> None:
self.webserver.remove_instance_webapp(self.id) self.maubot.server.remove_instance_webapp(self.id)
self.inst_webapp = None self.inst_webapp = None
self.inst_webapp_url = None self.inst_webapp_url = None
def enable_database(self) -> None: def enable_database(self) -> None:
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id) db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db") self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
def delete(self) -> None: async def delete(self) -> None:
if self.loader is not None: if self.loader is not None:
self.loader.references.remove(self) self.loader.references.remove(self)
if self.client is not None: if self.client is not None:
@ -145,23 +160,23 @@ class PluginInstance:
del self.cache[self.id] del self.cache[self.id]
except KeyError: except KeyError:
pass pass
self.db_instance.delete() await super().delete()
if self.inst_db: if self.inst_db:
self.inst_db.dispose() self.inst_db.dispose()
ZippedPluginLoader.trash( ZippedPluginLoader.trash(
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"), os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
reason="deleted", reason="deleted",
) )
if self.inst_webapp: if self.inst_webapp:
self.disable_webapp() self.disable_webapp()
def load_config(self) -> CommentedMap: def load_config(self) -> CommentedMap:
return yaml.load(self.db_instance.config) return yaml.load(self.config_str)
def save_config(self, data: RecursiveDict[CommentedMap]) -> None: def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
buf = io.StringIO() buf = io.StringIO()
yaml.dump(data, buf) yaml.dump(data, buf)
self.db_instance.config = buf.getvalue() self.config_str = buf.getvalue()
async def start(self) -> None: async def start(self) -> None:
if self.started: if self.started:
@ -172,7 +187,7 @@ class PluginInstance:
return return
if not self.client or not self.loader: if not self.client or not self.loader:
self.log.warning("Missing plugin instance dependencies, attempting to load...") self.log.warning("Missing plugin instance dependencies, attempting to load...")
if not self.load(): if not await self.load():
return return
cls = await self.loader.load() cls = await self.loader.load()
if self.loader.meta.webapp and self.inst_webapp is None: if self.loader.meta.webapp and self.inst_webapp is None:
@ -205,7 +220,7 @@ class PluginInstance:
self.config = config_class(self.load_config, base_cfg_func, self.save_config) self.config = config_class(self.load_config, base_cfg_func, self.save_config)
self.plugin = cls( self.plugin = cls(
client=self.client.client, client=self.client.client,
loop=self.loop, loop=self.maubot.loop,
http=self.client.http_client, http=self.client.http_client,
instance_id=self.id, instance_id=self.id,
log=self.log, log=self.log,
@ -219,7 +234,7 @@ class PluginInstance:
await self.plugin.internal_start() await self.plugin.internal_start()
except Exception: except Exception:
self.log.exception("Failed to start instance") self.log.exception("Failed to start instance")
self.db_instance.enabled = False await self.update_enabled(False)
return return
self.started = True self.started = True
self.inst_db_tables = None self.inst_db_tables = None
@ -241,60 +256,51 @@ class PluginInstance:
self.plugin = None self.plugin = None
self.inst_db_tables = None self.inst_db_tables = None
@classmethod async def update_id(self, new_id: str | None) -> None:
def get(cls, instance_id: str, db_instance: DBPlugin | None = None) -> PluginInstance | None: if new_id is not None and new_id.lower() != self.id:
try: await super().update_id(new_id.lower())
return cls.cache[instance_id]
except KeyError:
db_instance = db_instance or DBPlugin.get(instance_id)
if not db_instance:
return None
return PluginInstance(db_instance)
@classmethod async def update_config(self, config: str | None) -> None:
def all(cls) -> Iterable[PluginInstance]: if config is None or self.config_str == config:
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
def update_id(self, new_id: str) -> None:
if new_id is not None and new_id != self.id:
self.db_instance.id = new_id.lower()
def update_config(self, config: str) -> None:
if not config or self.db_instance.config == config:
return return
self.db_instance.config = config self.config_str = config
if self.started and self.plugin is not None: if self.started and self.plugin is not None:
self.plugin.on_external_config_update() res = self.plugin.on_external_config_update()
if inspect.isawaitable(res):
await res
await self.update()
async def update_primary_user(self, primary_user: UserID) -> bool: async def update_primary_user(self, primary_user: UserID | None) -> bool:
if not primary_user or primary_user == self.primary_user: if primary_user is None or primary_user == self.primary_user:
return True return True
client = Client.get(primary_user) client = await Client.get(primary_user)
if not client: if not client:
return False return False
await self.stop() await self.stop()
self.db_instance.primary_user = client.id self.primary_user = client.id
if self.client: if self.client:
self.client.references.remove(self) self.client.references.remove(self)
self.client = client self.client = client
self.client.references.add(self) self.client.references.add(self)
await self.update()
await self.start() await self.start()
self.log.debug(f"Primary user switched to {self.client.id}") self.log.debug(f"Primary user switched to {self.client.id}")
return True return True
async def update_type(self, type: str) -> bool: async def update_type(self, type: str | None) -> bool:
if not type or type == self.type: if type is None or type == self.type:
return True return True
try: try:
loader = PluginLoader.find(type) loader = PluginLoader.find(type)
except KeyError: except KeyError:
return False return False
await self.stop() await self.stop()
self.db_instance.type = loader.meta.id self.type = loader.meta.id
if self.loader: if self.loader:
self.loader.references.remove(self) self.loader.references.remove(self)
self.loader = loader self.loader = loader
self.loader.references.add(self) self.loader.references.add(self)
await self.update()
await self.start() await self.start()
self.log.debug(f"Type switched to {self.loader.meta.id}") self.log.debug(f"Type switched to {self.loader.meta.id}")
return True return True
@ -303,39 +309,41 @@ class PluginInstance:
if started is not None and started != self.started: if started is not None and started != self.started:
await (self.start() if started else self.stop()) await (self.start() if started else self.stop())
def update_enabled(self, enabled: bool) -> None: async def update_enabled(self, enabled: bool) -> None:
if enabled is not None and enabled != self.enabled: if enabled is not None and enabled != self.enabled:
self.db_instance.enabled = enabled self.enabled = enabled
await self.update()
# region Properties @classmethod
@async_getter_lock
async def get(
cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None
) -> PluginInstance | None:
try:
return cls.cache[instance_id]
except KeyError:
pass
@property instance = cast(cls, await super().get(instance_id))
def id(self) -> str: if instance is not None:
return self.db_instance.id instance.postinit()
return instance
@id.setter if type and primary_user:
def id(self, value: str) -> None: instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user)
self.db_instance.id = value await instance.insert()
instance.postinit()
return instance
@property return None
def type(self) -> str:
return self.db_instance.type
@property @classmethod
def enabled(self) -> bool: async def all(cls) -> AsyncGenerator[PluginInstance, None]:
return self.db_instance.enabled instances = await super().all()
instance: PluginInstance
@property for instance in instances:
def primary_user(self) -> UserID: try:
return self.db_instance.primary_user yield cls.cache[instance.id]
except KeyError:
# endregion instance.postinit()
yield instance
def init(
config: Config, webserver: MaubotServer, loop: AbstractEventLoop
) -> Iterable[PluginInstance]:
PluginInstance.mb_config = config
PluginInstance.loop = loop
PluginInstance.webserver = webserver
return PluginInstance.all()

View File

@ -28,14 +28,19 @@ LOADER_COLOR = PREFIX + "36m" # blue
class ColorFormatter(BaseColorFormatter): class ColorFormatter(BaseColorFormatter):
def _color_name(self, module: str) -> str: def _color_name(self, module: str) -> str:
client = "maubot.client" client = "maubot.client"
if module.startswith(client): if module.startswith(client + "."):
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}" suffix = ""
if module.endswith(".crypto"):
suffix = f".{MAU_COLOR}crypto{RESET}"
module = module[: -len(".crypto")]
module = module[len(client) + 1 :]
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module}{RESET}{suffix}"
instance = "maubot.instance" instance = "maubot.instance"
if module.startswith(instance): if module.startswith(instance + "."):
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}" return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
loader = "maubot.loader" loader = "maubot.loader"
if module.startswith(loader): if module.startswith(loader + "."):
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}" return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
if module.startswith("maubot"): if module.startswith("maubot."):
return f"{MAU_COLOR}{module}{RESET}" return f"{MAU_COLOR}{module}{RESET}"
return super()._color_name(module) return super()._color_name(module)

View File

@ -13,16 +13,15 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.client import SyncStore from mautrix.client.state_store.asyncpg import PgStateStore as BasePgStateStore
from mautrix.types import SyncToken
try:
from mautrix.crypto import StateStore as CryptoStateStore
class SyncStoreProxy(SyncStore): class PgStateStore(BasePgStateStore, CryptoStateStore):
def __init__(self, db_instance) -> None: pass
self.db_instance = db_instance
async def put_next_batch(self, next_batch: SyncToken) -> None: except ImportError as e:
self.db_instance.edit(next_batch=next_batch) PgStateStore = BasePgStateStore
async def get_next_batch(self) -> SyncToken: __all__ = ["PgStateStore"]
return self.db_instance.next_batch

View File

@ -13,17 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TYPE_CHECKING, Dict, List, Set, Type, TypeVar from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
from attr import dataclass
from packaging.version import InvalidVersion, Version
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
from ..__meta__ import __version__
from ..plugin_base import Plugin from ..plugin_base import Plugin
from .meta import PluginMeta
if TYPE_CHECKING: if TYPE_CHECKING:
from ..instance import PluginInstance from ..instance import PluginInstance
@ -35,36 +32,6 @@ class IDConflictError(Exception):
pass pass
@serializer(Version)
def serialize_version(version: Version) -> str:
return str(version)
@deserializer(Version)
def deserialize_version(version: str) -> Version:
try:
return Version(version)
except InvalidVersion as e:
raise SerializerError("Invalid version") from e
@dataclass
class PluginMeta(SerializableAttrs):
id: str
version: Version
modules: List[str]
main_class: str
maubot: Version = Version(__version__)
database: bool = False
config: bool = False
webapp: bool = False
license: str = ""
extra_files: List[str] = []
dependencies: List[str] = []
soft_dependencies: List[str] = []
class BasePluginLoader(ABC): class BasePluginLoader(ABC):
meta: PluginMeta meta: PluginMeta
@ -80,25 +47,25 @@ class BasePluginLoader(ABC):
async def read_file(self, path: str) -> bytes: async def read_file(self, path: str) -> bytes:
pass pass
def sync_list_files(self, directory: str) -> List[str]: def sync_list_files(self, directory: str) -> list[str]:
raise NotImplementedError("This loader doesn't support synchronous operations") raise NotImplementedError("This loader doesn't support synchronous operations")
@abstractmethod @abstractmethod
async def list_files(self, directory: str) -> List[str]: async def list_files(self, directory: str) -> list[str]:
pass pass
class PluginLoader(BasePluginLoader, ABC): class PluginLoader(BasePluginLoader, ABC):
id_cache: Dict[str, "PluginLoader"] = {} id_cache: dict[str, PluginLoader] = {}
meta: PluginMeta meta: PluginMeta
references: Set["PluginInstance"] references: set[PluginInstance]
def __init__(self): def __init__(self):
self.references = set() self.references = set()
@classmethod @classmethod
def find(cls, plugin_id: str) -> "PluginLoader": def find(cls, plugin_id: str) -> PluginLoader:
return cls.id_cache[plugin_id] return cls.id_cache[plugin_id]
def to_dict(self) -> dict: def to_dict(self) -> dict:
@ -119,11 +86,11 @@ class PluginLoader(BasePluginLoader, ABC):
) )
@abstractmethod @abstractmethod
async def load(self) -> Type[PluginClass]: async def load(self) -> type[PluginClass]:
pass pass
@abstractmethod @abstractmethod
async def reload(self) -> Type[PluginClass]: async def reload(self) -> type[PluginClass]:
pass pass
@abstractmethod @abstractmethod

53
maubot/loader/meta.py Normal file
View File

@ -0,0 +1,53 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List
from attr import dataclass
from packaging.version import InvalidVersion, Version
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
from ..__meta__ import __version__
@serializer(Version)
def serialize_version(version: Version) -> str:
return str(version)
@deserializer(Version)
def deserialize_version(version: str) -> Version:
try:
return Version(version)
except InvalidVersion as e:
raise SerializerError("Invalid version") from e
@dataclass
class PluginMeta(SerializableAttrs):
id: str
version: Version
modules: List[str]
main_class: str
maubot: Version = Version(__version__)
database: bool = False
config: bool = False
webapp: bool = False
license: str = ""
extra_files: List[str] = []
dependencies: List[str] = []
soft_dependencies: List[str] = []

View File

@ -29,7 +29,8 @@ from mautrix.types import SerializerError
from ..config import Config from ..config import Config
from ..lib.zipimport import ZipImportError, zipimporter from ..lib.zipimport import ZipImportError, zipimporter
from ..plugin_base import Plugin from ..plugin_base import Plugin
from .abc import IDConflictError, PluginClass, PluginLoader, PluginMeta from .abc import IDConflictError, PluginClass, PluginLoader
from .meta import PluginMeta
yaml = YAML() yaml = YAML()

View File

@ -20,7 +20,7 @@ from aiohttp import web
from ...config import Config from ...config import Config
from .auth import check_token from .auth import check_token
from .base import get_config, routes, set_config, set_loop from .base import get_config, routes, set_config
from .middleware import auth, error from .middleware import auth, error
@ -40,7 +40,6 @@ def features(request: web.Request) -> web.Response:
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application: def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
set_config(cfg) set_config(cfg)
set_loop(loop)
for pkg, enabled in cfg["api_features"].items(): for pkg, enabled in cfg["api_features"].items():
if enabled: if enabled:
importlib.import_module(f"maubot.management.api.{pkg}") importlib.import_module(f"maubot.management.api.{pkg}")

View File

@ -46,7 +46,7 @@ def create_token(user: UserID) -> str:
def get_token(request: web.Request) -> str: def get_token(request: web.Request) -> str:
token = request.headers.get("Authorization", "") token = request.headers.get("Authorization", "")
if not token or not token.startswith("Bearer "): if not token or not token.startswith("Bearer "):
token = request.query.get("access_token", None) token = request.query.get("access_token", "")
else: else:
token = token[len("Bearer ") :] token = token[len("Bearer ") :]
return token return token

View File

@ -24,7 +24,6 @@ from ...config import Config
routes: web.RouteTableDef = web.RouteTableDef() routes: web.RouteTableDef = web.RouteTableDef()
_config: Config | None = None _config: Config | None = None
_loop: asyncio.AbstractEventLoop | None = None
def set_config(config: Config) -> None: def set_config(config: Config) -> None:
@ -36,15 +35,6 @@ def get_config() -> Config:
return _config return _config
def set_loop(loop: asyncio.AbstractEventLoop) -> None:
global _loop
_loop = loop
def get_loop() -> asyncio.AbstractEventLoop:
return _loop
@routes.get("/version") @routes.get("/version")
async def version(_: web.Request) -> web.Response: async def version(_: web.Request) -> web.Response:
return web.json_response({"version": __version__}) return web.json_response({"version": __version__})

View File

@ -24,7 +24,6 @@ from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequ
from mautrix.types import FilterID, SyncToken, UserID from mautrix.types import FilterID, SyncToken, UserID
from ...client import Client from ...client import Client
from ...db import DBClient
from .base import routes from .base import routes
from .responses import resp from .responses import resp
@ -37,7 +36,7 @@ async def get_clients(_: web.Request) -> web.Response:
@routes.get("/client/{id}") @routes.get("/client/{id}")
async def get_client(request: web.Request) -> web.Response: async def get_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None) user_id = request.match_info.get("id", None)
client = Client.get(user_id, None) client = await Client.get(user_id)
if not client: if not client:
return resp.client_not_found return resp.client_not_found
return resp.found(client.to_dict()) return resp.found(client.to_dict())
@ -51,7 +50,6 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
mxid="@not:a.mxid", mxid="@not:a.mxid",
base_url=homeserver, base_url=homeserver,
token=access_token, token=access_token,
loop=Client.loop,
client_session=Client.http_client, client_session=Client.http_client,
) )
try: try:
@ -63,29 +61,23 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
except MatrixConnectionError: except MatrixConnectionError:
return resp.bad_client_connection_details return resp.bad_client_connection_details
if user_id is None: if user_id is None:
existing_client = Client.get(whoami.user_id, None) existing_client = await Client.get(whoami.user_id)
if existing_client is not None: if existing_client is not None:
return resp.user_exists return resp.user_exists
elif whoami.user_id != user_id: elif whoami.user_id != user_id:
return resp.mxid_mismatch(whoami.user_id) return resp.mxid_mismatch(whoami.user_id)
elif whoami.device_id and device_id and whoami.device_id != device_id: elif whoami.device_id and device_id and whoami.device_id != device_id:
return resp.device_id_mismatch(whoami.device_id) return resp.device_id_mismatch(whoami.device_id)
db_instance = DBClient( client = await Client.get(
id=whoami.user_id, whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id
homeserver=homeserver,
access_token=access_token,
enabled=data.get("enabled", True),
next_batch=SyncToken(""),
filter_id=FilterID(""),
sync=data.get("sync", True),
autojoin=data.get("autojoin", True),
online=data.get("online", True),
displayname=data.get("displayname", "disable"),
avatar_url=data.get("avatar_url", "disable"),
device_id=device_id,
) )
client = Client(db_instance) client.enabled = data.get("enabled", True)
client.db_instance.insert() client.sync = data.get("sync", True)
client.autojoin = data.get("autojoin", True)
client.online = data.get("online", True)
client.displayname = data.get("displayname", "disable")
client.avatar_url = data.get("avatar_url", "disable")
await client.update()
await client.start() await client.start()
return resp.created(client.to_dict()) return resp.created(client.to_dict())
@ -93,9 +85,7 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response: async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
try: try:
await client.update_access_details( await client.update_access_details(
data.get("access_token", None), data.get("access_token"), data.get("homeserver"), data.get("device_id")
data.get("homeserver", None),
data.get("device_id", None),
) )
except MatrixInvalidToken: except MatrixInvalidToken:
return resp.bad_client_access_token return resp.bad_client_access_token
@ -109,21 +99,21 @@ async def _update_client(client: Client, data: dict, is_login: bool = False) ->
return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :]) return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :])
elif str_err.startswith("Device ID mismatch"): elif str_err.startswith("Device ID mismatch"):
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :]) return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :])
with client.db_instance.edit_mode(): await client.update_avatar_url(data.get("avatar_url"), save=False)
await client.update_avatar_url(data.get("avatar_url", None)) await client.update_displayname(data.get("displayname"), save=False)
await client.update_displayname(data.get("displayname", None)) await client.update_started(data.get("started"))
await client.update_started(data.get("started", None)) await client.update_enabled(data.get("enabled"), save=False)
client.enabled = data.get("enabled", client.enabled) await client.update_autojoin(data.get("autojoin"), save=False)
client.autojoin = data.get("autojoin", client.autojoin) await client.update_online(data.get("online"), save=False)
client.online = data.get("online", client.online) await client.update_sync(data.get("sync"), save=False)
client.sync = data.get("sync", client.sync) await client.update()
return resp.updated(client.to_dict(), is_login=is_login) return resp.updated(client.to_dict(), is_login=is_login)
async def _create_or_update_client( async def _create_or_update_client(
user_id: UserID, data: dict, is_login: bool = False user_id: UserID, data: dict, is_login: bool = False
) -> web.Response: ) -> web.Response:
client = Client.get(user_id, None) client = await Client.get(user_id)
if not client: if not client:
return await _create_client(user_id, data) return await _create_client(user_id, data)
else: else:
@ -141,7 +131,7 @@ async def create_client(request: web.Request) -> web.Response:
@routes.put("/client/{id}") @routes.put("/client/{id}")
async def update_client(request: web.Request) -> web.Response: async def update_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None) user_id = request.match_info["id"]
try: try:
data = await request.json() data = await request.json()
except JSONDecodeError: except JSONDecodeError:
@ -151,23 +141,23 @@ async def update_client(request: web.Request) -> web.Response:
@routes.delete("/client/{id}") @routes.delete("/client/{id}")
async def delete_client(request: web.Request) -> web.Response: async def delete_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None) user_id = request.match_info["id"]
client = Client.get(user_id, None) client = await Client.get(user_id)
if not client: if not client:
return resp.client_not_found return resp.client_not_found
if len(client.references) > 0: if len(client.references) > 0:
return resp.client_in_use return resp.client_in_use
if client.started: if client.started:
await client.stop() await client.stop()
client.delete() await client.delete()
return resp.deleted return resp.deleted
@routes.post("/client/{id}/clearcache") @routes.post("/client/{id}/clearcache")
async def clear_client_cache(request: web.Request) -> web.Response: async def clear_client_cache(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None) user_id = request.match_info["id"]
client = Client.get(user_id, None) client = await Client.get(user_id)
if not client: if not client:
return resp.client_not_found return resp.client_not_found
client.clear_cache() await client.clear_cache()
return resp.ok return resp.ok

View File

@ -13,7 +13,9 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, NamedTuple, Optional, Tuple from __future__ import annotations
from typing import NamedTuple
from http import HTTPStatus from http import HTTPStatus
from json import JSONDecodeError from json import JSONDecodeError
import asyncio import asyncio
@ -30,12 +32,12 @@ from mautrix.client import ClientAPI
from mautrix.errors import MatrixRequestError from mautrix.errors import MatrixRequestError
from mautrix.types import LoginResponse, LoginType from mautrix.types import LoginResponse, LoginType
from .base import get_config, get_loop, routes from .base import get_config, routes
from .client import _create_client, _create_or_update_client from .client import _create_client, _create_or_update_client
from .responses import resp from .responses import resp
def known_homeservers() -> Dict[str, Dict[str, str]]: def known_homeservers() -> dict[str, dict[str, str]]:
return get_config()["homeservers"] return get_config()["homeservers"]
@ -61,7 +63,7 @@ truthy_strings = ("1", "true", "yes")
async def read_client_auth_request( async def read_client_auth_request(
request: web.Request, request: web.Request,
) -> Tuple[Optional[AuthRequestInfo], Optional[web.Response]]: ) -> tuple[AuthRequestInfo | None, web.Response | None]:
server_name = request.match_info.get("server", None) server_name = request.match_info.get("server", None)
server = known_homeservers().get(server_name, None) server = known_homeservers().get(server_name, None)
if not server: if not server:
@ -85,7 +87,7 @@ async def read_client_auth_request(
return ( return (
AuthRequestInfo( AuthRequestInfo(
server_name=server_name, server_name=server_name,
client=ClientAPI(base_url=base_url, loop=get_loop()), client=ClientAPI(base_url=base_url),
secret=server.get("secret"), secret=server.get("secret"),
username=username, username=username,
password=password, password=password,
@ -189,11 +191,11 @@ async def _do_sso(req: AuthRequestInfo) -> web.Response:
sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query( sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query(
{"redirectUrl": str(public_url)} {"redirectUrl": str(public_url)}
) )
sso_waiters[waiter_id] = req, get_loop().create_future() sso_waiters[waiter_id] = req, asyncio.get_running_loop().create_future()
return web.json_response({"sso_url": str(sso_url), "id": waiter_id}) return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response: async def _do_login(req: AuthRequestInfo, login_token: str | None = None) -> web.Response:
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
device_id = f"maubot_{device_id}" device_id = f"maubot_{device_id}"
try: try:
@ -235,7 +237,7 @@ async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) ->
return web.json_response(res.serialize()) return web.json_response(res.serialize())
sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {} sso_waiters: dict[str, tuple[AuthRequestInfo, asyncio.Future]] = {}
@routes.post("/client/auth/{server}/sso/{id}/wait") @routes.post("/client/auth/{server}/sso/{id}/wait")

View File

@ -25,7 +25,7 @@ PROXY_CHUNK_SIZE = 32 * 1024
@routes.view("/proxy/{id}/{path:_matrix/.+}") @routes.view("/proxy/{id}/{path:_matrix/.+}")
async def proxy(request: web.Request) -> web.StreamResponse: async def proxy(request: web.Request) -> web.StreamResponse:
user_id = request.match_info.get("id", None) user_id = request.match_info.get("id", None)
client = Client.get(user_id, None) client = await Client.get(user_id)
if not client: if not client:
return resp.client_not_found return resp.client_not_found

View File

@ -18,7 +18,6 @@ from json import JSONDecodeError
from aiohttp import web from aiohttp import web
from ...client import Client from ...client import Client
from ...db import DBPlugin
from ...instance import PluginInstance from ...instance import PluginInstance
from ...loader import PluginLoader from ...loader import PluginLoader
from .base import routes from .base import routes
@ -32,56 +31,49 @@ async def get_instances(_: web.Request) -> web.Response:
@routes.get("/instance/{id}") @routes.get("/instance/{id}")
async def get_instance(request: web.Request) -> web.Response: async def get_instance(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "").lower() instance_id = request.match_info["id"].lower()
instance = PluginInstance.get(instance_id, None) instance = await PluginInstance.get(instance_id)
if not instance: if not instance:
return resp.instance_not_found return resp.instance_not_found
return resp.found(instance.to_dict()) return resp.found(instance.to_dict())
async def _create_instance(instance_id: str, data: dict) -> web.Response: async def _create_instance(instance_id: str, data: dict) -> web.Response:
plugin_type = data.get("type", None) plugin_type = data.get("type")
primary_user = data.get("primary_user", None) primary_user = data.get("primary_user")
if not plugin_type: if not plugin_type:
return resp.plugin_type_required return resp.plugin_type_required
elif not primary_user: elif not primary_user:
return resp.primary_user_required return resp.primary_user_required
elif not Client.get(primary_user): elif not await Client.get(primary_user):
return resp.primary_user_not_found return resp.primary_user_not_found
try: try:
PluginLoader.find(plugin_type) PluginLoader.find(plugin_type)
except KeyError: except KeyError:
return resp.plugin_type_not_found return resp.plugin_type_not_found
db_instance = DBPlugin( instance = await PluginInstance.get(instance_id, type=plugin_type, primary_user=primary_user)
id=instance_id, instance.enabled = data.get("enabled", True)
type=plugin_type, instance.config_str = data.get("config") or ""
enabled=data.get("enabled", True), await instance.update()
primary_user=primary_user,
config=data.get("config", ""),
)
instance = PluginInstance(db_instance)
instance.load()
instance.db_instance.insert()
await instance.start() await instance.start()
return resp.created(instance.to_dict()) return resp.created(instance.to_dict())
async def _update_instance(instance: PluginInstance, data: dict) -> web.Response: async def _update_instance(instance: PluginInstance, data: dict) -> web.Response:
if not await instance.update_primary_user(data.get("primary_user", None)): if not await instance.update_primary_user(data.get("primary_user")):
return resp.primary_user_not_found return resp.primary_user_not_found
with instance.db_instance.edit_mode(): await instance.update_id(data.get("id"))
instance.update_id(data.get("id", None)) await instance.update_enabled(data.get("enabled"))
instance.update_enabled(data.get("enabled", None)) await instance.update_config(data.get("config"))
instance.update_config(data.get("config", None)) await instance.update_started(data.get("started"))
await instance.update_started(data.get("started", None)) await instance.update_type(data.get("type"))
await instance.update_type(data.get("type", None))
return resp.updated(instance.to_dict()) return resp.updated(instance.to_dict())
@routes.put("/instance/{id}") @routes.put("/instance/{id}")
async def update_instance(request: web.Request) -> web.Response: async def update_instance(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "").lower() instance_id = request.match_info["id"].lower()
instance = PluginInstance.get(instance_id, None) instance = await PluginInstance.get(instance_id)
try: try:
data = await request.json() data = await request.json()
except JSONDecodeError: except JSONDecodeError:
@ -94,11 +86,11 @@ async def update_instance(request: web.Request) -> web.Response:
@routes.delete("/instance/{id}") @routes.delete("/instance/{id}")
async def delete_instance(request: web.Request) -> web.Response: async def delete_instance(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "").lower() instance_id = request.match_info["id"].lower()
instance = PluginInstance.get(instance_id) instance = await PluginInstance.get(instance_id)
if not instance: if not instance:
return resp.instance_not_found return resp.instance_not_found
if instance.started: if instance.started:
await instance.stop() await instance.stop()
instance.delete() await instance.delete()
return resp.deleted return resp.deleted

View File

@ -29,8 +29,8 @@ from .responses import resp
@routes.get("/instance/{id}/database") @routes.get("/instance/{id}/database")
async def get_database(request: web.Request) -> web.Response: async def get_database(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "") instance_id = request.match_info["id"].lower()
instance = PluginInstance.get(instance_id, None) instance = await PluginInstance.get(instance_id)
if not instance: if not instance:
return resp.instance_not_found return resp.instance_not_found
elif not instance.inst_db: elif not instance.inst_db:
@ -65,8 +65,8 @@ def check_type(val):
@routes.get("/instance/{id}/database/{table}") @routes.get("/instance/{id}/database/{table}")
async def get_table(request: web.Request) -> web.Response: async def get_table(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "") instance_id = request.match_info["id"].lower()
instance = PluginInstance.get(instance_id, None) instance = await PluginInstance.get(instance_id)
if not instance: if not instance:
return resp.instance_not_found return resp.instance_not_found
elif not instance.inst_db: elif not instance.inst_db:
@ -86,14 +86,14 @@ async def get_table(request: web.Request) -> web.Response:
] ]
except KeyError: except KeyError:
order = [] order = []
limit = int(request.query.get("limit", 100)) limit = int(request.query.get("limit", "100"))
return execute_query(instance, table.select().order_by(*order).limit(limit)) return execute_query(instance, table.select().order_by(*order).limit(limit))
@routes.post("/instance/{id}/database/query") @routes.post("/instance/{id}/database/query")
async def query(request: web.Request) -> web.Response: async def query(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "") instance_id = request.match_info["id"].lower()
instance = PluginInstance.get(instance_id, None) instance = await PluginInstance.get(instance_id)
if not instance: if not instance:
return resp.instance_not_found return resp.instance_not_found
elif not instance.inst_db: elif not instance.inst_db:

View File

@ -23,7 +23,7 @@ import logging
from aiohttp import web, web_ws from aiohttp import web, web_ws
from .auth import is_valid_token from .auth import is_valid_token
from .base import get_loop, routes from .base import routes
BUILTIN_ATTRS = { BUILTIN_ATTRS = {
"args", "args",
@ -138,12 +138,12 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse:
authenticated = False authenticated = False
async def close_if_not_authenticated(): async def close_if_not_authenticated():
await asyncio.sleep(5, loop=get_loop()) await asyncio.sleep(5)
if not authenticated: if not authenticated:
await ws.close(code=4000) await ws.close(code=4000)
log.debug(f"Connection from {request.remote} terminated due to no authentication") log.debug(f"Connection from {request.remote} terminated due to no authentication")
asyncio.ensure_future(close_if_not_authenticated()) asyncio.create_task(close_if_not_authenticated())
try: try:
msg: web_ws.WSMessage msg: web_ws.WSMessage

View File

@ -29,8 +29,8 @@ async def get_plugins(_) -> web.Response:
@routes.get("/plugin/{id}") @routes.get("/plugin/{id}")
async def get_plugin(request: web.Request) -> web.Response: async def get_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info.get("id", None) plugin_id = request.match_info["id"]
plugin = PluginLoader.id_cache.get(plugin_id, None) plugin = PluginLoader.id_cache.get(plugin_id)
if not plugin: if not plugin:
return resp.plugin_not_found return resp.plugin_not_found
return resp.found(plugin.to_dict()) return resp.found(plugin.to_dict())
@ -38,8 +38,8 @@ async def get_plugin(request: web.Request) -> web.Response:
@routes.delete("/plugin/{id}") @routes.delete("/plugin/{id}")
async def delete_plugin(request: web.Request) -> web.Response: async def delete_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info.get("id", None) plugin_id = request.match_info["id"]
plugin = PluginLoader.id_cache.get(plugin_id, None) plugin = PluginLoader.id_cache.get(plugin_id)
if not plugin: if not plugin:
return resp.plugin_not_found return resp.plugin_not_found
elif len(plugin.references) > 0: elif len(plugin.references) > 0:
@ -50,8 +50,8 @@ async def delete_plugin(request: web.Request) -> web.Response:
@routes.post("/plugin/{id}/reload") @routes.post("/plugin/{id}/reload")
async def reload_plugin(request: web.Request) -> web.Response: async def reload_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info.get("id", None) plugin_id = request.match_info["id"]
plugin = PluginLoader.id_cache.get(plugin_id, None) plugin = PluginLoader.id_cache.get(plugin_id)
if not plugin: if not plugin:
return resp.plugin_not_found return resp.plugin_not_found

View File

@ -29,7 +29,7 @@ from .responses import resp
@routes.put("/plugin/{id}") @routes.put("/plugin/{id}")
async def put_plugin(request: web.Request) -> web.Response: async def put_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info.get("id", None) plugin_id = request.match_info["id"]
content = await request.read() content = await request.read()
file = BytesIO(content) file = BytesIO(content)
try: try:

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Awaitable
from abc import ABC from abc import ABC
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
@ -124,6 +124,7 @@ class Plugin(ABC):
def get_config_class(cls) -> type[BaseProxyConfig] | None: def get_config_class(cls) -> type[BaseProxyConfig] | None:
return None return None
def on_external_config_update(self) -> None: def on_external_config_update(self) -> Awaitable[None] | None:
if self.config: if self.config:
self.config.load_and_update() self.config.load_and_update()
return None

View File

@ -1,13 +1,10 @@
# Format: #/name defines a new extras_require group called name # Format: #/name defines a new extras_require group called name
# Uncommented lines after the group definition insert things into that group. # Uncommented lines after the group definition insert things into that group.
#/postgres #/sqlite
psycopg2-binary>=2,<3 aiosqlite>=0.16,<0.18
asyncpg>=0.20,<0.26
#/encryption #/encryption
asyncpg>=0.20,<0.26
aiosqlite>=0.16,<0.18
python-olm>=3,<4 python-olm>=3,<4
pycryptodome>=3,<4 pycryptodome>=3,<4
unpaddedbase64>=1,<3 unpaddedbase64>=1,<3

View File

@ -1,7 +1,8 @@
mautrix>=0.15.0,<0.16 mautrix>=0.15.2,<0.16
aiohttp>=3,<4 aiohttp>=3,<4
yarl>=1,<2 yarl>=1,<2
SQLAlchemy>=1,<1.4 SQLAlchemy>=1,<1.4
asyncpg>=0.20,<0.26
alembic>=1,<2 alembic>=1,<2
commonmark>=0.9,<1 commonmark>=0.9,<1
ruamel.yaml>=0.15.35,<0.18 ruamel.yaml>=0.15.35,<0.18

View File

@ -1,5 +1,4 @@
import setuptools import setuptools
import glob
import os import os
with open("requirements.txt") as reqs: with open("requirements.txt") as reqs:
@ -57,9 +56,7 @@ setuptools.setup(
mbc=maubot.cli:app mbc=maubot.cli:app
""", """,
data_files=[ data_files=[
(".", ["maubot/example-config.yaml", "alembic.ini"]), (".", ["maubot/example-config.yaml"]),
("alembic", ["alembic/env.py"]),
("alembic/versions", glob.glob("alembic/versions/*.py")),
], ],
package_data={ package_data={
"maubot": [ "maubot": [