Switch to asyncpg/aiosqlite

Fixes #142
Fixes #98
Probably fixes #62
This commit is contained in:
Tulir Asokan 2022-03-25 19:45:48 +02:00
parent 068e268c63
commit 21ed971d2f
43 changed files with 911 additions and 955 deletions

View File

@ -15,7 +15,6 @@ RUN apk add --no-cache \
py3-attrs \
py3-bcrypt \
py3-cffi \
py3-psycopg2 \
py3-ruamel.yaml \
py3-jinja2 \
py3-click \
@ -49,7 +48,6 @@ COPY requirements.txt /opt/maubot/requirements.txt
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
WORKDIR /opt/maubot
RUN apk add --virtual .build-deps python3-dev build-base git \
&& sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \
&& pip3 install -r requirements.txt -r optional-requirements.txt \
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
&& apk del .build-deps

View File

@ -10,7 +10,6 @@ RUN apk add --no-cache \
py3-attrs \
py3-bcrypt \
py3-cffi \
py3-psycopg2 \
py3-ruamel.yaml \
py3-jinja2 \
py3-click \
@ -43,7 +42,6 @@ COPY requirements.txt /opt/maubot/requirements.txt
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
WORKDIR /opt/maubot
RUN apk add --virtual .build-deps python3-dev build-base git \
&& sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \
&& pip3 install -r requirements.txt -r optional-requirements.txt \
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
&& apk del .build-deps

View File

@ -1,83 +0,0 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s
# timezone to use when rendering the date
# within the migration file as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; this defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path
# version_locations = %(here)s/bar %(here)s/bat alembic/versions
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks=black
# black.type=console_scripts
# black.entrypoint=black
# black.options=-l 79
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@ -1 +0,0 @@
Generic single-database configuration.

View File

@ -1,92 +0,0 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
import sys
from os.path import abspath, dirname
sys.path.insert(0, dirname(dirname(abspath(__file__))))
from mautrix.util.db import Base
from maubot.config import Config
from maubot import db
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
maubot_config_path = context.get_x_argument(as_dictionary=True).get("config", "config.yaml")
maubot_config = Config(maubot_config_path, None)
maubot_config.load()
config.set_main_option("sqlalchemy.url", maubot_config["database"].replace("%", "%%"))
# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline():
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
render_as_batch=True,
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online():
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata,
render_as_batch=True,
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@ -1,24 +0,0 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade():
${upgrades if upgrades else "pass"}
def downgrade():
${downgrades if downgrades else "pass"}

View File

@ -1,32 +0,0 @@
"""Add device_id to clients
Revision ID: 4b93300852aa
Revises: fccd1f95544d
Create Date: 2020-07-11 15:49:38.831459
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4b93300852aa'
down_revision = 'fccd1f95544d'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('client', schema=None) as batch_op:
batch_op.add_column(sa.Column('device_id', sa.String(length=255), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('client', schema=None) as batch_op:
batch_op.drop_column('device_id')
# ### end Alembic commands ###

View File

@ -1,47 +0,0 @@
"""Add Matrix state store
Revision ID: 90aa88820eab
Revises: 4b93300852aa
Create Date: 2020-07-12 01:50:06.215623
"""
from alembic import op
import sqlalchemy as sa
from mautrix.client.state_store.sqlalchemy import SerializableType
from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent
# revision identifiers, used by Alembic.
revision = '90aa88820eab'
down_revision = '4b93300852aa'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('mx_room_state',
sa.Column('room_id', sa.String(length=255), nullable=False),
sa.Column('is_encrypted', sa.Boolean(), nullable=True),
sa.Column('has_full_member_list', sa.Boolean(), nullable=True),
sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True),
sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True),
sa.PrimaryKeyConstraint('room_id')
)
op.create_table('mx_user_profile',
sa.Column('room_id', sa.String(length=255), nullable=False),
sa.Column('user_id', sa.String(length=255), nullable=False),
sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False),
sa.Column('displayname', sa.String(), nullable=True),
sa.Column('avatar_url', sa.String(length=255), nullable=True),
sa.PrimaryKeyConstraint('room_id', 'user_id')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('mx_user_profile')
op.drop_table('mx_room_state')
# ### end Alembic commands ###

View File

@ -1,50 +0,0 @@
"""Initial revision
Revision ID: d295f8dcfa64
Revises:
Create Date: 2019-09-27 00:21:02.527915
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd295f8dcfa64'
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('client',
sa.Column('id', sa.String(length=255), nullable=False),
sa.Column('homeserver', sa.String(length=255), nullable=False),
sa.Column('access_token', sa.Text(), nullable=False),
sa.Column('enabled', sa.Boolean(), nullable=False),
sa.Column('next_batch', sa.String(length=255), nullable=False),
sa.Column('filter_id', sa.String(length=255), nullable=False),
sa.Column('sync', sa.Boolean(), nullable=False),
sa.Column('autojoin', sa.Boolean(), nullable=False),
sa.Column('displayname', sa.String(length=255), nullable=False),
sa.Column('avatar_url', sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('plugin',
sa.Column('id', sa.String(length=255), nullable=False),
sa.Column('type', sa.String(length=255), nullable=False),
sa.Column('enabled', sa.Boolean(), nullable=False),
sa.Column('primary_user', sa.String(length=255), nullable=False),
sa.Column('config', sa.Text(), nullable=False),
sa.ForeignKeyConstraint(['primary_user'], ['client.id'], onupdate='CASCADE', ondelete='RESTRICT'),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('plugin')
op.drop_table('client')
# ### end Alembic commands ###

View File

@ -1,30 +0,0 @@
"""Add online field to clients
Revision ID: fccd1f95544d
Revises: d295f8dcfa64
Create Date: 2020-03-06 15:07:50.136644
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fccd1f95544d'
down_revision = 'd295f8dcfa64'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("client") as batch_op:
batch_op.add_column(sa.Column('online', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true()))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("client") as batch_op:
batch_op.drop_column('online')
# ### end Alembic commands ###

View File

@ -1,7 +1,7 @@
#!/bin/sh
function fixperms {
chown -R $UID:$GID /var/log /data /opt/maubot
chown -R $UID:$GID /var/log /data
}
function fixdefault {

View File

@ -13,24 +13,37 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import asyncio
from __future__ import annotations
import asyncio
import sys
from mautrix.util.async_db import Database, DatabaseException
from mautrix.util.program import Program
from .__meta__ import __version__
from .client import Client, init as init_client_class
from .client import Client
from .config import Config
from .db import init as init_db
from .instance import init as init_plugin_instance_class
from .db import init as init_db, upgrade_table
from .instance import PluginInstance
from .lib.future_awaitable import FutureAwaitable
from .lib.state_store import PgStateStore
from .loader.zip import init as init_zip_loader
from .management.api import init as init_mgmt_api
from .server import MaubotServer
try:
from mautrix.crypto.store import PgCryptoStore
except ImportError:
PgCryptoStore = None
class Maubot(Program):
config: Config
server: MaubotServer
db: Database
crypto_db: Database | None
state_store: PgStateStore
config_class = Config
module = "maubot"
@ -45,6 +58,19 @@ class Maubot(Program):
init(self.loop)
self.add_shutdown_actions(FutureAwaitable(stop_all))
def prepare_arg_parser(self) -> None:
super().prepare_arg_parser()
self.parser.add_argument(
"--ignore-unsupported-database",
action="store_true",
help="Run even if the database schema is too new",
)
self.parser.add_argument(
"--ignore-foreign-tables",
action="store_true",
help="Run even if the database contains tables from other programs (like Synapse)",
)
def prepare(self) -> None:
super().prepare()
@ -52,21 +78,59 @@ class Maubot(Program):
self.prepare_log_websocket()
init_zip_loader(self.config)
init_db(self.config)
clients = init_client_class(self.config, self.loop)
self.add_startup_actions(*(client.start() for client in clients))
self.db = Database.create(
self.config["database"],
upgrade_table=upgrade_table,
db_args=self.config["database_opts"],
owner_name=self.name,
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
init_db(self.db)
if self.config["crypto_database"] == "default":
self.crypto_db = self.db
else:
self.crypto_db = Database.create(
self.config["crypto_database"],
upgrade_table=PgCryptoStore.upgrade_table,
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
Client.init_cls(self)
PluginInstance.init_cls(self)
management_api = init_mgmt_api(self.config, self.loop)
self.server = MaubotServer(management_api, self.config, self.loop)
self.state_store = PgStateStore(self.db)
plugins = init_plugin_instance_class(self.config, self.server, self.loop)
for plugin in plugins:
plugin.load()
async def start_db(self) -> None:
self.log.debug("Starting database...")
ignore_unsupported = self.args.ignore_unsupported_database
self.db.upgrade_table.allow_unsupported = ignore_unsupported
self.state_store.upgrade_table.allow_unsupported = ignore_unsupported
PgCryptoStore.upgrade_table.allow_unsupported = ignore_unsupported
try:
await self.db.start()
await self.state_store.upgrade_table.upgrade(self.db)
if self.crypto_db and self.crypto_db is not self.db:
await self.crypto_db.start()
else:
await PgCryptoStore.upgrade_table.upgrade(self.db)
except DatabaseException as e:
self.log.critical("Failed to initialize database", exc_info=e)
if e.explanation:
self.log.info(e.explanation)
sys.exit(25)
async def system_exit(self) -> None:
if hasattr(self, "db"):
self.log.trace("Stopping database due to SystemExit")
await self.db.stop()
async def start(self) -> None:
if Client.crypto_db:
self.log.debug("Starting client crypto database")
await Client.crypto_db.start()
await self.start_db()
await asyncio.gather(*[plugin.load() async for plugin in PluginInstance.all()])
await asyncio.gather(*[client.start() async for client in Client.all()])
await super().start()
async for plugin in PluginInstance.all():
await plugin.load()
await self.server.start()
async def stop(self) -> None:
@ -77,6 +141,7 @@ class Maubot(Program):
await asyncio.wait_for(self.server.stop(), 5)
except asyncio.TimeoutError:
self.log.warning("Stopping server timed out")
await self.db.stop()
Maubot().run()

View File

@ -1 +1 @@
__version__ = "0.2.1"
__version__ = "0.3.0+dev"

View File

@ -38,7 +38,7 @@ def logs(server: str, tail: int) -> None:
global history_count
history_count = tail
loop = asyncio.get_event_loop()
future = asyncio.ensure_future(view_logs(server, token), loop=loop)
future = asyncio.create_task(view_logs(server, token), loop=loop)
try:
loop.run_until_complete(future)
except KeyboardInterrupt:

View File

@ -15,14 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, cast
from collections import defaultdict
import asyncio
import logging
from aiohttp import ClientSession
from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (
ContentURI,
@ -41,69 +41,110 @@ from mautrix.types import (
SyncToken,
UserID,
)
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.logging import TraceLogger
from .db import DBClient
from .lib.store_proxy import SyncStoreProxy
from .db import Client as DBClient
from .matrix import MaubotMatrixClient
try:
from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore
from mautrix.util.async_db import Database as AsyncDatabase
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass
from mautrix.crypto import OlmMachine, PgCryptoStore
crypto_import_error = None
except ImportError as e:
OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None
SQLStateStore = BaseSQLStateStore
OlmMachine = PgCryptoStore = None
crypto_import_error = e
if TYPE_CHECKING:
from .config import Config
from .__main__ import Maubot
from .instance import PluginInstance
log = logging.getLogger("maubot.client")
class Client:
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
class Client(DBClient):
maubot: "Maubot" = None
cache: dict[UserID, Client] = {}
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
log: TraceLogger = logging.getLogger("maubot.client")
http_client: ClientSession = None
global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
crypto_db: AsyncDatabase | None = None
references: set[PluginInstance]
db_instance: DBClient
client: MaubotMatrixClient
crypto: OlmMachine | None
crypto_store: PgCryptoStore | None
started: bool
sync_ok: bool
remote_displayname: str | None
remote_avatar_url: ContentURI | None
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
def __init__(
self,
id: UserID,
homeserver: str,
access_token: str,
device_id: DeviceID,
enabled: bool = False,
next_batch: SyncToken = "",
filter_id: FilterID = "",
sync: bool = True,
autojoin: bool = True,
online: bool = True,
displayname: str = "disable",
avatar_url: str = "disable",
) -> None:
super().__init__(
id=id,
homeserver=homeserver,
access_token=access_token,
device_id=device_id,
enabled=bool(enabled),
next_batch=next_batch,
filter_id=filter_id,
sync=bool(sync),
autojoin=bool(autojoin),
online=bool(online),
displayname=displayname,
avatar_url=avatar_url,
)
self._postinited = False
def __hash__(self) -> int:
return hash(self.id)
@classmethod
def init_cls(cls, maubot: "Maubot") -> None:
cls.maubot = maubot
def _make_client(
self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None
) -> MaubotMatrixClient:
return MaubotMatrixClient(
mxid=self.id,
base_url=homeserver or self.homeserver,
token=token or self.access_token,
client_session=self.http_client,
log=self.log,
crypto_log=self.log.getChild("crypto"),
loop=self.maubot.loop,
device_id=device_id or self.device_id,
sync_store=self,
state_store=self.maubot.state_store,
)
def postinit(self) -> None:
if self._postinited:
raise RuntimeError("postinit() called twice")
self._postinited = True
self.cache[self.id] = self
self.log = log.getChild(self.id)
self.log = self.log.getChild(self.id)
self.http_client = ClientSession(loop=self.maubot.loop)
self.references = set()
self.started = False
self.sync_ok = True
self.remote_displayname = None
self.remote_avatar_url = None
self.client = MaubotMatrixClient(
mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=self.log,
loop=self.loop,
device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store,
)
self.client = self._make_client()
if self.enable_crypto:
self._prepare_crypto()
else:
@ -118,6 +159,12 @@ class Client:
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
async def handler(data: dict[str, Any]) -> None:
self.sync_ok = ok
return handler
@property
def enable_crypto(self) -> bool:
if not self.device_id:
@ -131,16 +178,21 @@ class Client:
# Clear the stack trace after it's logged once to avoid spamming logs
crypto_import_error = None
return False
elif not self.crypto_db:
elif not self.maubot.crypto_db:
self.log.warning("Client has device ID, but crypto database is not prepared")
return False
return True
def _prepare_crypto(self) -> None:
self.crypto_store = PgCryptoStore(
account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db
account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db
)
self.crypto = OlmMachine(
self.client,
self.crypto_store,
self.maubot.state_store,
log=self.client.crypto_log,
)
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
def _remove_crypto_event_handlers(self) -> None:
@ -156,12 +208,6 @@ class Client:
for event_type, func in handlers:
self.client.remove_event_handler(event_type, func)
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
async def handler(data: dict[str, Any]) -> None:
self.sync_ok = ok
return handler
async def start(self, try_n: int | None = 0) -> None:
try:
if try_n > 0:
@ -196,31 +242,34 @@ class Client:
whoami = await self.client.whoami()
except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client")
self.db_instance.enabled = False
self.enabled = False
await self.update()
return
except Exception as e:
if try_n >= 8:
self.log.exception("Failed to get /account/whoami, disabling client")
self.db_instance.enabled = False
self.enabled = False
await self.update()
else:
self.log.warning(
f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}"
f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}"
)
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
_ = asyncio.create_task(self.start(try_n + 1))
return
if whoami.user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
self.db_instance.enabled = False
self.enabled = False
await self.update()
return
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
self.log.error(
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
)
self.db_instance.enabled = False
self.enabled = False
await self.update()
return
if not self.filter_id:
self.db_instance.edit(
filter_id=await self.client.create_filter(
self.filter_id = await self.client.create_filter(
Filter(
room=RoomFilter(
timeline=RoomEventFilter(
@ -236,7 +285,7 @@ class Client:
),
)
)
)
await self.update()
if self.displayname != "disable":
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
@ -270,18 +319,13 @@ class Client:
if self.crypto:
await self.crypto_store.close()
def clear_cache(self) -> None:
async def clear_cache(self) -> None:
self.stop_sync()
self.db_instance.edit(filter_id="", next_batch="")
self.filter_id = FilterID("")
self.next_batch = SyncToken("")
await self.update()
self.start_sync()
def delete(self) -> None:
try:
del self.cache[self.id]
except KeyError:
pass
self.db_instance.delete()
def to_dict(self) -> dict:
return {
"id": self.id,
@ -304,20 +348,6 @@ class Client:
"instances": [instance.to_dict() for instance in self.references],
}
@classmethod
def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
try:
return cls.cache[user_id]
except KeyError:
db_instance = db_instance or DBClient.get(user_id)
if not db_instance:
return None
return Client(db_instance)
@classmethod
def all(cls) -> Iterable[Client]:
return (cls.get(user.id, user) for user in DBClient.all())
async def _handle_tombstone(self, evt: StateEvent) -> None:
if not evt.content.replacement_room:
self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring")
@ -329,7 +359,7 @@ class Client:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room(evt.room_id)
async def update_started(self, started: bool) -> None:
async def update_started(self, started: bool | None) -> None:
if started is None or started == self.started:
return
if started:
@ -337,23 +367,65 @@ class Client:
else:
await self.stop()
async def update_displayname(self, displayname: str) -> None:
async def update_enabled(self, enabled: bool | None, save: bool = True) -> None:
if enabled is None or enabled == self.enabled:
return
self.enabled = enabled
if save:
await self.update()
async def update_displayname(self, displayname: str | None, save: bool = True) -> None:
if displayname is None or displayname == self.displayname:
return
self.db_instance.displayname = displayname
self.displayname = displayname
if self.displayname != "disable":
await self.client.set_displayname(self.displayname)
else:
await self._update_remote_profile()
if save:
await self.update()
async def update_avatar_url(self, avatar_url: ContentURI) -> None:
async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None:
if avatar_url is None or avatar_url == self.avatar_url:
return
self.db_instance.avatar_url = avatar_url
self.avatar_url = avatar_url
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
else:
await self._update_remote_profile()
if save:
await self.update()
async def update_sync(self, sync: bool | None, save: bool = True) -> None:
if sync is None or self.sync == sync:
return
self.sync = sync
if self.started:
if sync:
self.start_sync()
else:
self.stop_sync()
if save:
await self.update()
async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None:
if autojoin is None or autojoin == self.autojoin:
return
if autojoin:
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
else:
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.autojoin = autojoin
if save:
await self.update()
async def update_online(self, online: bool | None, save: bool = True) -> None:
if online is None or online == self.online:
return
self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE
self.online = online
if save:
await self.update()
async def update_access_details(
self,
@ -373,22 +445,13 @@ class Client:
and device_id == self.device_id
):
return
new_client = MaubotMatrixClient(
mxid=self.id,
base_url=homeserver or self.homeserver,
token=access_token or self.access_token,
loop=self.loop,
device_id=device_id,
client_session=self.http_client,
log=self.log,
state_store=self.global_state_store,
)
new_client = self._make_client(homeserver, access_token, device_id)
whoami = await new_client.whoami()
if whoami.user_id != self.id:
raise ValueError(f"MXID mismatch: {whoami.user_id}")
elif whoami.device_id and device_id and whoami.device_id != device_id:
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
new_client.sync_store = SyncStoreProxy(self.db_instance)
new_client.sync_store = self
self.stop_sync()
# TODO this event handler transfer is pretty hacky
@ -398,9 +461,9 @@ class Client:
new_client.global_event_handlers = self.client.global_event_handlers
self.client = new_client
self.db_instance.homeserver = homeserver
self.db_instance.access_token = access_token
self.db_instance.device_id = device_id
self.homeserver = homeserver
self.access_token = access_token
self.device_id = device_id
if self.enable_crypto:
self._prepare_crypto()
await self._start_crypto()
@ -413,97 +476,53 @@ class Client:
profile = await self.client.get_profile(self.id)
self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url
# region Properties
async def delete(self) -> None:
try:
del self.cache[self.id]
except KeyError:
pass
await super().delete()
@property
def id(self) -> UserID:
return self.db_instance.id
@classmethod
@async_getter_lock
async def get(
cls,
user_id: UserID,
*,
homeserver: str | None = None,
access_token: str | None = None,
device_id: DeviceID | None = None,
) -> Client | None:
try:
return cls.cache[user_id]
except KeyError:
pass
@property
def homeserver(self) -> str:
return self.db_instance.homeserver
user = cast(cls, await super().get(user_id))
if user is not None:
user.postinit()
return user
@property
def access_token(self) -> str:
return self.db_instance.access_token
if homeserver and access_token:
user = cls(
user_id,
homeserver=homeserver,
access_token=access_token,
device_id=device_id or "",
)
await user.insert()
user.postinit()
return user
@property
def device_id(self) -> DeviceID:
return self.db_instance.device_id
return None
@property
def enabled(self) -> bool:
return self.db_instance.enabled
@enabled.setter
def enabled(self, value: bool) -> None:
self.db_instance.enabled = value
@property
def next_batch(self) -> SyncToken:
return self.db_instance.next_batch
@property
def filter_id(self) -> FilterID:
return self.db_instance.filter_id
@property
def sync(self) -> bool:
return self.db_instance.sync
@sync.setter
def sync(self, value: bool) -> None:
if value == self.db_instance.sync:
return
self.db_instance.sync = value
if self.started:
if value:
self.start_sync()
else:
self.stop_sync()
@property
def autojoin(self) -> bool:
return self.db_instance.autojoin
@autojoin.setter
def autojoin(self, value: bool) -> None:
if value == self.db_instance.autojoin:
return
if value:
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
else:
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.db_instance.autojoin = value
@property
def online(self) -> bool:
return self.db_instance.online
@online.setter
def online(self, value: bool) -> None:
self.client.presence = PresenceState.ONLINE if value else PresenceState.OFFLINE
self.db_instance.online = value
@property
def displayname(self) -> str:
return self.db_instance.displayname
@property
def avatar_url(self) -> ContentURI:
return self.db_instance.avatar_url
# endregion
def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
if OlmMachine:
db_url = config["crypto_database"]
if db_url == "default":
db_url = config["database"]
Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
return Client.all()
@classmethod
async def all(cls) -> AsyncGenerator[Client, None]:
users = await super().all()
user: cls
for user in users:
try:
yield cls.cache[user.id]
except KeyError:
user.postinit()
yield user

View File

@ -1,108 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Iterable, Optional
import logging
import sys
from sqlalchemy import Boolean, Column, ForeignKey, String, Text
from sqlalchemy.engine.base import Engine
import sqlalchemy as sql
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
from mautrix.util.db import Base
from .config import Config
class DBPlugin(Base):
__tablename__ = "plugin"
id: str = Column(String(255), primary_key=True)
type: str = Column(String(255), nullable=False)
enabled: bool = Column(Boolean, nullable=False, default=False)
primary_user: UserID = Column(
String(255),
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
nullable=False,
)
config: str = Column(Text, nullable=False, default="")
@classmethod
def all(cls) -> Iterable["DBPlugin"]:
return cls._select_all()
@classmethod
def get(cls, id: str) -> Optional["DBPlugin"]:
return cls._select_one_or_none(cls.c.id == id)
class DBClient(Base):
__tablename__ = "client"
id: UserID = Column(String(255), primary_key=True)
homeserver: str = Column(String(255), nullable=False)
access_token: str = Column(Text, nullable=False)
device_id: DeviceID = Column(String(255), nullable=True)
enabled: bool = Column(Boolean, nullable=False, default=False)
next_batch: SyncToken = Column(String(255), nullable=False, default="")
filter_id: FilterID = Column(String(255), nullable=False, default="")
sync: bool = Column(Boolean, nullable=False, default=True)
autojoin: bool = Column(Boolean, nullable=False, default=True)
online: bool = Column(Boolean, nullable=False, default=True)
displayname: str = Column(String(255), nullable=False, default="")
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
@classmethod
def all(cls) -> Iterable["DBClient"]:
return cls._select_all()
@classmethod
def get(cls, id: str) -> Optional["DBClient"]:
return cls._select_one_or_none(cls.c.id == id)
def init(config: Config) -> Engine:
db = sql.create_engine(config["database"])
Base.metadata.bind = db
for table in (DBPlugin, DBClient, RoomState, UserProfile):
table.bind(db)
if not db.has_table("alembic_version"):
log = logging.getLogger("maubot.db")
if db.has_table("client") and db.has_table("plugin"):
log.warning(
"alembic_version table not found, but client and plugin tables found. "
"Assuming pre-Alembic database and inserting version."
)
db.execute(
"CREATE TABLE IF NOT EXISTS alembic_version ("
" version_num VARCHAR(32) PRIMARY KEY"
");"
)
db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');")
else:
log.critical(
"alembic_version table not found. " "Did you forget to `alembic upgrade head`?"
)
sys.exit(10)
return db

13
maubot/db/__init__.py Normal file
View File

@ -0,0 +1,13 @@
from mautrix.util.async_db import Database
from .client import Client
from .instance import Instance
from .upgrade import upgrade_table
def init(db: Database) -> None:
for table in (Client, Instance):
table.db = db
__all__ = ["upgrade_table", "init", "Client", "Instance"]

114
maubot/db/client.py Normal file
View File

@ -0,0 +1,114 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
from attr import dataclass
from mautrix.client import SyncStore
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
from mautrix.util.async_db import Database
fake_db = Database.create("") if TYPE_CHECKING else None
@dataclass
class Client(SyncStore):
db: ClassVar[Database] = fake_db
id: UserID
homeserver: str
access_token: str
device_id: DeviceID
enabled: bool
next_batch: SyncToken
filter_id: FilterID
sync: bool
autojoin: bool
online: bool
displayname: str
avatar_url: ContentURI
@classmethod
def _from_row(cls, row: Record | None) -> Client | None:
if row is None:
return None
return cls(**row)
_columns = (
"id, homeserver, access_token, device_id, enabled, next_batch, filter_id, "
"sync, autojoin, online, displayname, avatar_url"
)
@property
def _values(self):
return (
self.id,
self.homeserver,
self.access_token,
self.device_id,
self.enabled,
self.next_batch,
self.filter_id,
self.sync,
self.autojoin,
self.online,
self.displayname,
self.avatar_url,
)
@classmethod
async def all(cls) -> list[Client]:
rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client")
return [cls._from_row(row) for row in rows]
@classmethod
async def get(cls, id: str) -> Client | None:
q = f"SELECT {cls._columns} FROM client WHERE id=$1"
return cls._from_row(await cls.db.fetchrow(q, id))
async def insert(self) -> None:
q = """
INSERT INTO client (
id, homeserver, access_token, device_id, enabled, next_batch, filter_id,
sync, autojoin, online, displayname, avatar_url
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
"""
await self.db.execute(q, *self._values)
async def put_next_batch(self, next_batch: SyncToken) -> None:
await self.db.execute("UPDATE client SET next_batch=$1 WHERE id=$2", next_batch, self.id)
self.next_batch = next_batch
async def get_next_batch(self) -> SyncToken:
return self.next_batch
async def update(self) -> None:
q = """
UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, enabled=$5,
next_batch=$6, filter_id=$7, sync=$8, autojoin=$9, online=$10,
displayname=$11, avatar_url=$12
WHERE id=$1
"""
await self.db.execute(q, *self._values)
async def delete(self) -> None:
await self.db.execute("DELETE FROM client WHERE id=$1", self.id)

75
maubot/db/instance.py Normal file
View File

@ -0,0 +1,75 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
from attr import dataclass
from mautrix.types import UserID
from mautrix.util.async_db import Database
fake_db = Database.create("") if TYPE_CHECKING else None
@dataclass
class Instance:
db: ClassVar[Database] = fake_db
id: str
type: str
enabled: bool
primary_user: UserID
config_str: str
@classmethod
def _from_row(cls, row: Record | None) -> Instance | None:
if row is None:
return None
return cls(**row)
@classmethod
async def all(cls) -> list[Instance]:
rows = await cls.db.fetch("SELECT id, type, enabled, primary_user, config FROM instance")
return [cls._from_row(row) for row in rows]
@classmethod
async def get(cls, id: str) -> Instance | None:
q = "SELECT id, type, enabled, primary_user, config FROM instance WHERE id=$1"
return cls._from_row(await cls.db.fetchrow(q, id))
async def update_id(self, new_id: str) -> None:
await self.db.execute("UPDATE instance SET id=$1 WHERE id=$2", new_id, self.id)
self.id = new_id
@property
def _values(self):
return self.id, self.type, self.enabled, self.primary_user, self.config_str
async def insert(self) -> None:
q = (
"INSERT INTO instance (id, type, enabled, primary_user, config) "
"VALUES ($1, $2, $3, $4, $5)"
)
await self.db.execute(q, *self._values)
async def update(self) -> None:
q = "UPDATE instance SET type=$2, enabled=$3, primary_user=$4, config=$5 WHERE id=$1"
await self.db.execute(q, *self._values)
async def delete(self) -> None:
await self.db.execute("DELETE FROM instance WHERE id=$1", self.id)

View File

@ -0,0 +1,5 @@
from mautrix.util.async_db import UpgradeTable
upgrade_table = UpgradeTable()
from . import v01_initial_revision

View File

@ -0,0 +1,136 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from mautrix.util.async_db import Connection, Scheme
from . import upgrade_table
legacy_version_query = "SELECT version_num FROM alembic_version"
last_legacy_version = "90aa88820eab"
@upgrade_table.register(description="Initial asyncpg revision")
async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
if await conn.table_exists("alembic_version"):
await migrate_legacy_to_v1(conn, scheme)
else:
return await create_v1_tables(conn)
async def create_v1_tables(conn: Connection) -> None:
await conn.execute(
"""CREATE TABLE client (
id TEXT PRIMARY KEY,
homeserver TEXT NOT NULL,
access_token TEXT NOT NULL,
device_id TEXT NOT NULL,
enabled BOOLEAN NOT NULL,
next_batch TEXT NOT NULL,
filter_id TEXT NOT NULL,
sync BOOLEAN NOT NULL,
autojoin BOOLEAN NOT NULL,
online BOOLEAN NOT NULL,
displayname TEXT NOT NULL,
avatar_url TEXT NOT NULL
)"""
)
await conn.execute(
"""CREATE TABLE instance (
id TEXT PRIMARY KEY,
type TEXT NOT NULL,
enabled BOOLEAN NOT NULL,
primary_user TEXT NOT NULL,
config TEXT NOT NULL,
FOREIGN KEY (primary_user) REFERENCES client(id) ON DELETE RESTRICT ON UPDATE CASCADE
)"""
)
async def migrate_legacy_to_v1(conn: Connection, scheme: Scheme) -> None:
legacy_version = await conn.fetchval(legacy_version_query)
if legacy_version != last_legacy_version:
raise RuntimeError(
"Legacy database is not on last version. "
"Please upgrade the old database with alembic or drop it completely first."
)
await conn.execute("ALTER TABLE plugin RENAME TO instance")
await update_state_store(conn, scheme)
if scheme != Scheme.SQLITE:
await varchar_to_text(conn)
await conn.execute("DROP TABLE alembic_version")
async def update_state_store(conn: Connection, scheme: Scheme) -> None:
# The Matrix state store already has more or less the correct schema, so set the version
await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)")
await conn.execute("INSERT INTO mx_version (version) VALUES (2)")
if scheme != Scheme.SQLITE:
# Remove old uppercase membership type and recreate it as lowercase
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT")
await conn.execute("DROP TYPE IF EXISTS membership")
await conn.execute(
"CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')"
)
await conn.execute(
"ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership "
"USING LOWER(membership)::membership"
)
else:
# Recreate table to remove CHECK constraint and lowercase everything
await conn.execute(
"""CREATE TABLE new_mx_user_profile (
room_id TEXT,
user_id TEXT,
membership TEXT NOT NULL
CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock')),
displayname TEXT,
avatar_url TEXT,
PRIMARY KEY (room_id, user_id)
)"""
)
await conn.execute(
"""
INSERT INTO new_mx_user_profile (room_id, user_id, membership, displayname, avatar_url)
SELECT room_id, user_id, LOWER(membership), displayname, avatar_url
FROM mx_user_profile
"""
)
await conn.execute("DROP TABLE mx_user_profile")
await conn.execute("ALTER TABLE new_mx_user_profile RENAME TO mx_user_profile")
async def varchar_to_text(conn: Connection) -> None:
columns_to_adjust = {
"client": (
"id",
"homeserver",
"device_id",
"next_batch",
"filter_id",
"displayname",
"avatar_url",
),
"instance": ("id", "type", "primary_user"),
"mx_room_state": ("room_id",),
"mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"),
}
for table, columns in columns_to_adjust.items():
for column in columns:
await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT')

View File

@ -6,9 +6,7 @@
database: sqlite:///maubot.db
# Separate database URL for the crypto database. "default" means use the same database as above.
# Due to concurrency issues, you should use a separate file when using SQLite rather than the same as above.
# When using postgres, using the same database for both is safe.
crypto_database: sqlite:///crypto.db
crypto_database: default
plugin_directories:
# The directory where uploaded new plugins should be stored.

View File

@ -15,8 +15,10 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, Iterable
from asyncio import AbstractEventLoop
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
from collections import defaultdict
import asyncio
import inspect
import io
import logging
import os.path
@ -26,16 +28,17 @@ from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql
from mautrix.types import UserID
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from .client import Client
from .config import Config
from .db import DBPlugin
from .db import Instance as DBInstance
from .loader import PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin
if TYPE_CHECKING:
from .server import MaubotServer, PluginWebApp
from .__main__ import Maubot
from .server import PluginWebApp
log = logging.getLogger("maubot.instance")
@ -44,29 +47,42 @@ yaml.indent(4)
yaml.width = 200
class PluginInstance:
webserver: MaubotServer = None
mb_config: Config = None
loop: AbstractEventLoop = None
class PluginInstance(DBInstance):
maubot: "Maubot" = None
cache: dict[str, PluginInstance] = {}
plugin_directories: list[str] = []
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
log: logging.Logger
loader: PluginLoader
client: Client
plugin: Plugin
config: BaseProxyConfig
loader: PluginLoader | None
client: Client | None
plugin: Plugin | None
config: BaseProxyConfig | None
base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: str | None
inst_db: sql.engine.Engine
inst_db_tables: dict[str, sql.Table]
inst_db: sql.engine.Engine | None
inst_db_tables: dict[str, sql.Table] | None
inst_webapp: PluginWebApp | None
inst_webapp_url: str | None
started: bool
def __init__(self, db_instance: DBPlugin):
self.db_instance = db_instance
def __init__(
self, id: str, type: str, enabled: bool, primary_user: UserID, config: str = ""
) -> None:
super().__init__(
id=id, type=type, enabled=bool(enabled), primary_user=primary_user, config_str=config
)
def __hash__(self) -> int:
return hash(self.id)
@classmethod
def init_cls(cls, maubot: "Maubot") -> None:
cls.maubot = maubot
def postinit(self) -> None:
self.log = log.getChild(self.id)
self.cache[self.id] = self
self.config = None
self.started = False
self.loader = None
@ -78,7 +94,6 @@ class PluginInstance:
self.inst_webapp_url = None
self.base_cfg = None
self.base_cfg_str = None
self.cache[self.id] = self
def to_dict(self) -> dict:
return {
@ -87,10 +102,10 @@ class PluginInstance:
"enabled": self.enabled,
"started": self.started,
"primary_user": self.primary_user,
"config": self.db_instance.config,
"config": self.config_str,
"base_config": self.base_cfg_str,
"database": (
self.inst_db is not None and self.mb_config["api_features.instance_database"]
self.inst_db is not None and self.maubot.config["api_features.instance_database"]
),
}
@ -101,19 +116,19 @@ class PluginInstance:
self.inst_db_tables = metadata.tables
return self.inst_db_tables
def load(self) -> bool:
async def load(self) -> bool:
if not self.loader:
try:
self.loader = PluginLoader.find(self.type)
except KeyError:
self.log.error(f"Failed to find loader for type {self.type}")
self.db_instance.enabled = False
await self.update_enabled(False)
return False
if not self.client:
self.client = Client.get(self.primary_user)
self.client = await Client.get(self.primary_user)
if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}")
self.db_instance.enabled = False
await self.update_enabled(False)
return False
if self.loader.meta.database:
self.enable_database()
@ -125,18 +140,18 @@ class PluginInstance:
return True
def enable_webapp(self) -> None:
self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id)
self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id)
def disable_webapp(self) -> None:
self.webserver.remove_instance_webapp(self.id)
self.maubot.server.remove_instance_webapp(self.id)
self.inst_webapp = None
self.inst_webapp_url = None
def enable_database(self) -> None:
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
def delete(self) -> None:
async def delete(self) -> None:
if self.loader is not None:
self.loader.references.remove(self)
if self.client is not None:
@ -145,23 +160,23 @@ class PluginInstance:
del self.cache[self.id]
except KeyError:
pass
self.db_instance.delete()
await super().delete()
if self.inst_db:
self.inst_db.dispose()
ZippedPluginLoader.trash(
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
reason="deleted",
)
if self.inst_webapp:
self.disable_webapp()
def load_config(self) -> CommentedMap:
return yaml.load(self.db_instance.config)
return yaml.load(self.config_str)
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
buf = io.StringIO()
yaml.dump(data, buf)
self.db_instance.config = buf.getvalue()
self.config_str = buf.getvalue()
async def start(self) -> None:
if self.started:
@ -172,7 +187,7 @@ class PluginInstance:
return
if not self.client or not self.loader:
self.log.warning("Missing plugin instance dependencies, attempting to load...")
if not self.load():
if not await self.load():
return
cls = await self.loader.load()
if self.loader.meta.webapp and self.inst_webapp is None:
@ -205,7 +220,7 @@ class PluginInstance:
self.config = config_class(self.load_config, base_cfg_func, self.save_config)
self.plugin = cls(
client=self.client.client,
loop=self.loop,
loop=self.maubot.loop,
http=self.client.http_client,
instance_id=self.id,
log=self.log,
@ -219,7 +234,7 @@ class PluginInstance:
await self.plugin.internal_start()
except Exception:
self.log.exception("Failed to start instance")
self.db_instance.enabled = False
await self.update_enabled(False)
return
self.started = True
self.inst_db_tables = None
@ -241,60 +256,51 @@ class PluginInstance:
self.plugin = None
self.inst_db_tables = None
@classmethod
def get(cls, instance_id: str, db_instance: DBPlugin | None = None) -> PluginInstance | None:
try:
return cls.cache[instance_id]
except KeyError:
db_instance = db_instance or DBPlugin.get(instance_id)
if not db_instance:
return None
return PluginInstance(db_instance)
async def update_id(self, new_id: str | None) -> None:
if new_id is not None and new_id.lower() != self.id:
await super().update_id(new_id.lower())
@classmethod
def all(cls) -> Iterable[PluginInstance]:
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
def update_id(self, new_id: str) -> None:
if new_id is not None and new_id != self.id:
self.db_instance.id = new_id.lower()
def update_config(self, config: str) -> None:
if not config or self.db_instance.config == config:
async def update_config(self, config: str | None) -> None:
if config is None or self.config_str == config:
return
self.db_instance.config = config
self.config_str = config
if self.started and self.plugin is not None:
self.plugin.on_external_config_update()
res = self.plugin.on_external_config_update()
if inspect.isawaitable(res):
await res
await self.update()
async def update_primary_user(self, primary_user: UserID) -> bool:
if not primary_user or primary_user == self.primary_user:
async def update_primary_user(self, primary_user: UserID | None) -> bool:
if primary_user is None or primary_user == self.primary_user:
return True
client = Client.get(primary_user)
client = await Client.get(primary_user)
if not client:
return False
await self.stop()
self.db_instance.primary_user = client.id
self.primary_user = client.id
if self.client:
self.client.references.remove(self)
self.client = client
self.client.references.add(self)
await self.update()
await self.start()
self.log.debug(f"Primary user switched to {self.client.id}")
return True
async def update_type(self, type: str) -> bool:
if not type or type == self.type:
async def update_type(self, type: str | None) -> bool:
if type is None or type == self.type:
return True
try:
loader = PluginLoader.find(type)
except KeyError:
return False
await self.stop()
self.db_instance.type = loader.meta.id
self.type = loader.meta.id
if self.loader:
self.loader.references.remove(self)
self.loader = loader
self.loader.references.add(self)
await self.update()
await self.start()
self.log.debug(f"Type switched to {self.loader.meta.id}")
return True
@ -303,39 +309,41 @@ class PluginInstance:
if started is not None and started != self.started:
await (self.start() if started else self.stop())
def update_enabled(self, enabled: bool) -> None:
async def update_enabled(self, enabled: bool) -> None:
if enabled is not None and enabled != self.enabled:
self.db_instance.enabled = enabled
self.enabled = enabled
await self.update()
# region Properties
@classmethod
@async_getter_lock
async def get(
cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None
) -> PluginInstance | None:
try:
return cls.cache[instance_id]
except KeyError:
pass
@property
def id(self) -> str:
return self.db_instance.id
instance = cast(cls, await super().get(instance_id))
if instance is not None:
instance.postinit()
return instance
@id.setter
def id(self, value: str) -> None:
self.db_instance.id = value
if type and primary_user:
instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user)
await instance.insert()
instance.postinit()
return instance
@property
def type(self) -> str:
return self.db_instance.type
return None
@property
def enabled(self) -> bool:
return self.db_instance.enabled
@property
def primary_user(self) -> UserID:
return self.db_instance.primary_user
# endregion
def init(
config: Config, webserver: MaubotServer, loop: AbstractEventLoop
) -> Iterable[PluginInstance]:
PluginInstance.mb_config = config
PluginInstance.loop = loop
PluginInstance.webserver = webserver
return PluginInstance.all()
@classmethod
async def all(cls) -> AsyncGenerator[PluginInstance, None]:
instances = await super().all()
instance: PluginInstance
for instance in instances:
try:
yield cls.cache[instance.id]
except KeyError:
instance.postinit()
yield instance

View File

@ -28,14 +28,19 @@ LOADER_COLOR = PREFIX + "36m" # blue
class ColorFormatter(BaseColorFormatter):
def _color_name(self, module: str) -> str:
client = "maubot.client"
if module.startswith(client):
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}"
if module.startswith(client + "."):
suffix = ""
if module.endswith(".crypto"):
suffix = f".{MAU_COLOR}crypto{RESET}"
module = module[: -len(".crypto")]
module = module[len(client) + 1 :]
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module}{RESET}{suffix}"
instance = "maubot.instance"
if module.startswith(instance):
if module.startswith(instance + "."):
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
loader = "maubot.loader"
if module.startswith(loader):
if module.startswith(loader + "."):
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
if module.startswith("maubot"):
if module.startswith("maubot."):
return f"{MAU_COLOR}{module}{RESET}"
return super()._color_name(module)

View File

@ -13,16 +13,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.client import SyncStore
from mautrix.types import SyncToken
from mautrix.client.state_store.asyncpg import PgStateStore as BasePgStateStore
try:
from mautrix.crypto import StateStore as CryptoStateStore
class SyncStoreProxy(SyncStore):
def __init__(self, db_instance) -> None:
self.db_instance = db_instance
class PgStateStore(BasePgStateStore, CryptoStateStore):
pass
async def put_next_batch(self, next_batch: SyncToken) -> None:
self.db_instance.edit(next_batch=next_batch)
except ImportError as e:
PgStateStore = BasePgStateStore
async def get_next_batch(self) -> SyncToken:
return self.db_instance.next_batch
__all__ = ["PgStateStore"]

View File

@ -13,17 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TYPE_CHECKING, Dict, List, Set, Type, TypeVar
from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar
from abc import ABC, abstractmethod
import asyncio
from attr import dataclass
from packaging.version import InvalidVersion, Version
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
from ..__meta__ import __version__
from ..plugin_base import Plugin
from .meta import PluginMeta
if TYPE_CHECKING:
from ..instance import PluginInstance
@ -35,36 +32,6 @@ class IDConflictError(Exception):
pass
@serializer(Version)
def serialize_version(version: Version) -> str:
return str(version)
@deserializer(Version)
def deserialize_version(version: str) -> Version:
try:
return Version(version)
except InvalidVersion as e:
raise SerializerError("Invalid version") from e
@dataclass
class PluginMeta(SerializableAttrs):
id: str
version: Version
modules: List[str]
main_class: str
maubot: Version = Version(__version__)
database: bool = False
config: bool = False
webapp: bool = False
license: str = ""
extra_files: List[str] = []
dependencies: List[str] = []
soft_dependencies: List[str] = []
class BasePluginLoader(ABC):
meta: PluginMeta
@ -80,25 +47,25 @@ class BasePluginLoader(ABC):
async def read_file(self, path: str) -> bytes:
pass
def sync_list_files(self, directory: str) -> List[str]:
def sync_list_files(self, directory: str) -> list[str]:
raise NotImplementedError("This loader doesn't support synchronous operations")
@abstractmethod
async def list_files(self, directory: str) -> List[str]:
async def list_files(self, directory: str) -> list[str]:
pass
class PluginLoader(BasePluginLoader, ABC):
id_cache: Dict[str, "PluginLoader"] = {}
id_cache: dict[str, PluginLoader] = {}
meta: PluginMeta
references: Set["PluginInstance"]
references: set[PluginInstance]
def __init__(self):
self.references = set()
@classmethod
def find(cls, plugin_id: str) -> "PluginLoader":
def find(cls, plugin_id: str) -> PluginLoader:
return cls.id_cache[plugin_id]
def to_dict(self) -> dict:
@ -119,11 +86,11 @@ class PluginLoader(BasePluginLoader, ABC):
)
@abstractmethod
async def load(self) -> Type[PluginClass]:
async def load(self) -> type[PluginClass]:
pass
@abstractmethod
async def reload(self) -> Type[PluginClass]:
async def reload(self) -> type[PluginClass]:
pass
@abstractmethod

53
maubot/loader/meta.py Normal file
View File

@ -0,0 +1,53 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List
from attr import dataclass
from packaging.version import InvalidVersion, Version
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
from ..__meta__ import __version__
@serializer(Version)
def serialize_version(version: Version) -> str:
return str(version)
@deserializer(Version)
def deserialize_version(version: str) -> Version:
try:
return Version(version)
except InvalidVersion as e:
raise SerializerError("Invalid version") from e
@dataclass
class PluginMeta(SerializableAttrs):
id: str
version: Version
modules: List[str]
main_class: str
maubot: Version = Version(__version__)
database: bool = False
config: bool = False
webapp: bool = False
license: str = ""
extra_files: List[str] = []
dependencies: List[str] = []
soft_dependencies: List[str] = []

View File

@ -29,7 +29,8 @@ from mautrix.types import SerializerError
from ..config import Config
from ..lib.zipimport import ZipImportError, zipimporter
from ..plugin_base import Plugin
from .abc import IDConflictError, PluginClass, PluginLoader, PluginMeta
from .abc import IDConflictError, PluginClass, PluginLoader
from .meta import PluginMeta
yaml = YAML()

View File

@ -20,7 +20,7 @@ from aiohttp import web
from ...config import Config
from .auth import check_token
from .base import get_config, routes, set_config, set_loop
from .base import get_config, routes, set_config
from .middleware import auth, error
@ -40,7 +40,6 @@ def features(request: web.Request) -> web.Response:
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
set_config(cfg)
set_loop(loop)
for pkg, enabled in cfg["api_features"].items():
if enabled:
importlib.import_module(f"maubot.management.api.{pkg}")

View File

@ -46,7 +46,7 @@ def create_token(user: UserID) -> str:
def get_token(request: web.Request) -> str:
token = request.headers.get("Authorization", "")
if not token or not token.startswith("Bearer "):
token = request.query.get("access_token", None)
token = request.query.get("access_token", "")
else:
token = token[len("Bearer ") :]
return token

View File

@ -24,7 +24,6 @@ from ...config import Config
routes: web.RouteTableDef = web.RouteTableDef()
_config: Config | None = None
_loop: asyncio.AbstractEventLoop | None = None
def set_config(config: Config) -> None:
@ -36,15 +35,6 @@ def get_config() -> Config:
return _config
def set_loop(loop: asyncio.AbstractEventLoop) -> None:
global _loop
_loop = loop
def get_loop() -> asyncio.AbstractEventLoop:
return _loop
@routes.get("/version")
async def version(_: web.Request) -> web.Response:
return web.json_response({"version": __version__})

View File

@ -24,7 +24,6 @@ from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequ
from mautrix.types import FilterID, SyncToken, UserID
from ...client import Client
from ...db import DBClient
from .base import routes
from .responses import resp
@ -37,7 +36,7 @@ async def get_clients(_: web.Request) -> web.Response:
@routes.get("/client/{id}")
async def get_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
client = await Client.get(user_id)
if not client:
return resp.client_not_found
return resp.found(client.to_dict())
@ -51,7 +50,6 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
mxid="@not:a.mxid",
base_url=homeserver,
token=access_token,
loop=Client.loop,
client_session=Client.http_client,
)
try:
@ -63,29 +61,23 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
except MatrixConnectionError:
return resp.bad_client_connection_details
if user_id is None:
existing_client = Client.get(whoami.user_id, None)
existing_client = await Client.get(whoami.user_id)
if existing_client is not None:
return resp.user_exists
elif whoami.user_id != user_id:
return resp.mxid_mismatch(whoami.user_id)
elif whoami.device_id and device_id and whoami.device_id != device_id:
return resp.device_id_mismatch(whoami.device_id)
db_instance = DBClient(
id=whoami.user_id,
homeserver=homeserver,
access_token=access_token,
enabled=data.get("enabled", True),
next_batch=SyncToken(""),
filter_id=FilterID(""),
sync=data.get("sync", True),
autojoin=data.get("autojoin", True),
online=data.get("online", True),
displayname=data.get("displayname", "disable"),
avatar_url=data.get("avatar_url", "disable"),
device_id=device_id,
client = await Client.get(
whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id
)
client = Client(db_instance)
client.db_instance.insert()
client.enabled = data.get("enabled", True)
client.sync = data.get("sync", True)
client.autojoin = data.get("autojoin", True)
client.online = data.get("online", True)
client.displayname = data.get("displayname", "disable")
client.avatar_url = data.get("avatar_url", "disable")
await client.update()
await client.start()
return resp.created(client.to_dict())
@ -93,9 +85,7 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
try:
await client.update_access_details(
data.get("access_token", None),
data.get("homeserver", None),
data.get("device_id", None),
data.get("access_token"), data.get("homeserver"), data.get("device_id")
)
except MatrixInvalidToken:
return resp.bad_client_access_token
@ -109,21 +99,21 @@ async def _update_client(client: Client, data: dict, is_login: bool = False) ->
return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :])
elif str_err.startswith("Device ID mismatch"):
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :])
with client.db_instance.edit_mode():
await client.update_avatar_url(data.get("avatar_url", None))
await client.update_displayname(data.get("displayname", None))
await client.update_started(data.get("started", None))
client.enabled = data.get("enabled", client.enabled)
client.autojoin = data.get("autojoin", client.autojoin)
client.online = data.get("online", client.online)
client.sync = data.get("sync", client.sync)
await client.update_avatar_url(data.get("avatar_url"), save=False)
await client.update_displayname(data.get("displayname"), save=False)
await client.update_started(data.get("started"))
await client.update_enabled(data.get("enabled"), save=False)
await client.update_autojoin(data.get("autojoin"), save=False)
await client.update_online(data.get("online"), save=False)
await client.update_sync(data.get("sync"), save=False)
await client.update()
return resp.updated(client.to_dict(), is_login=is_login)
async def _create_or_update_client(
user_id: UserID, data: dict, is_login: bool = False
) -> web.Response:
client = Client.get(user_id, None)
client = await Client.get(user_id)
if not client:
return await _create_client(user_id, data)
else:
@ -141,7 +131,7 @@ async def create_client(request: web.Request) -> web.Response:
@routes.put("/client/{id}")
async def update_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
user_id = request.match_info["id"]
try:
data = await request.json()
except JSONDecodeError:
@ -151,23 +141,23 @@ async def update_client(request: web.Request) -> web.Response:
@routes.delete("/client/{id}")
async def delete_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
user_id = request.match_info["id"]
client = await Client.get(user_id)
if not client:
return resp.client_not_found
if len(client.references) > 0:
return resp.client_in_use
if client.started:
await client.stop()
client.delete()
await client.delete()
return resp.deleted
@routes.post("/client/{id}/clearcache")
async def clear_client_cache(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
user_id = request.match_info["id"]
client = await Client.get(user_id)
if not client:
return resp.client_not_found
client.clear_cache()
await client.clear_cache()
return resp.ok

View File

@ -13,7 +13,9 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, NamedTuple, Optional, Tuple
from __future__ import annotations
from typing import NamedTuple
from http import HTTPStatus
from json import JSONDecodeError
import asyncio
@ -30,12 +32,12 @@ from mautrix.client import ClientAPI
from mautrix.errors import MatrixRequestError
from mautrix.types import LoginResponse, LoginType
from .base import get_config, get_loop, routes
from .base import get_config, routes
from .client import _create_client, _create_or_update_client
from .responses import resp
def known_homeservers() -> Dict[str, Dict[str, str]]:
def known_homeservers() -> dict[str, dict[str, str]]:
return get_config()["homeservers"]
@ -61,7 +63,7 @@ truthy_strings = ("1", "true", "yes")
async def read_client_auth_request(
request: web.Request,
) -> Tuple[Optional[AuthRequestInfo], Optional[web.Response]]:
) -> tuple[AuthRequestInfo | None, web.Response | None]:
server_name = request.match_info.get("server", None)
server = known_homeservers().get(server_name, None)
if not server:
@ -85,7 +87,7 @@ async def read_client_auth_request(
return (
AuthRequestInfo(
server_name=server_name,
client=ClientAPI(base_url=base_url, loop=get_loop()),
client=ClientAPI(base_url=base_url),
secret=server.get("secret"),
username=username,
password=password,
@ -189,11 +191,11 @@ async def _do_sso(req: AuthRequestInfo) -> web.Response:
sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query(
{"redirectUrl": str(public_url)}
)
sso_waiters[waiter_id] = req, get_loop().create_future()
sso_waiters[waiter_id] = req, asyncio.get_running_loop().create_future()
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response:
async def _do_login(req: AuthRequestInfo, login_token: str | None = None) -> web.Response:
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
device_id = f"maubot_{device_id}"
try:
@ -235,7 +237,7 @@ async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) ->
return web.json_response(res.serialize())
sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {}
sso_waiters: dict[str, tuple[AuthRequestInfo, asyncio.Future]] = {}
@routes.post("/client/auth/{server}/sso/{id}/wait")

View File

@ -25,7 +25,7 @@ PROXY_CHUNK_SIZE = 32 * 1024
@routes.view("/proxy/{id}/{path:_matrix/.+}")
async def proxy(request: web.Request) -> web.StreamResponse:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
client = await Client.get(user_id)
if not client:
return resp.client_not_found

View File

@ -18,7 +18,6 @@ from json import JSONDecodeError
from aiohttp import web
from ...client import Client
from ...db import DBPlugin
from ...instance import PluginInstance
from ...loader import PluginLoader
from .base import routes
@ -32,56 +31,49 @@ async def get_instances(_: web.Request) -> web.Response:
@routes.get("/instance/{id}")
async def get_instance(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "").lower()
instance = PluginInstance.get(instance_id, None)
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
if not instance:
return resp.instance_not_found
return resp.found(instance.to_dict())
async def _create_instance(instance_id: str, data: dict) -> web.Response:
plugin_type = data.get("type", None)
primary_user = data.get("primary_user", None)
plugin_type = data.get("type")
primary_user = data.get("primary_user")
if not plugin_type:
return resp.plugin_type_required
elif not primary_user:
return resp.primary_user_required
elif not Client.get(primary_user):
elif not await Client.get(primary_user):
return resp.primary_user_not_found
try:
PluginLoader.find(plugin_type)
except KeyError:
return resp.plugin_type_not_found
db_instance = DBPlugin(
id=instance_id,
type=plugin_type,
enabled=data.get("enabled", True),
primary_user=primary_user,
config=data.get("config", ""),
)
instance = PluginInstance(db_instance)
instance.load()
instance.db_instance.insert()
instance = await PluginInstance.get(instance_id, type=plugin_type, primary_user=primary_user)
instance.enabled = data.get("enabled", True)
instance.config_str = data.get("config") or ""
await instance.update()
await instance.start()
return resp.created(instance.to_dict())
async def _update_instance(instance: PluginInstance, data: dict) -> web.Response:
if not await instance.update_primary_user(data.get("primary_user", None)):
if not await instance.update_primary_user(data.get("primary_user")):
return resp.primary_user_not_found
with instance.db_instance.edit_mode():
instance.update_id(data.get("id", None))
instance.update_enabled(data.get("enabled", None))
instance.update_config(data.get("config", None))
await instance.update_started(data.get("started", None))
await instance.update_type(data.get("type", None))
await instance.update_id(data.get("id"))
await instance.update_enabled(data.get("enabled"))
await instance.update_config(data.get("config"))
await instance.update_started(data.get("started"))
await instance.update_type(data.get("type"))
return resp.updated(instance.to_dict())
@routes.put("/instance/{id}")
async def update_instance(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "").lower()
instance = PluginInstance.get(instance_id, None)
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
try:
data = await request.json()
except JSONDecodeError:
@ -94,11 +86,11 @@ async def update_instance(request: web.Request) -> web.Response:
@routes.delete("/instance/{id}")
async def delete_instance(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "").lower()
instance = PluginInstance.get(instance_id)
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
if not instance:
return resp.instance_not_found
if instance.started:
await instance.stop()
instance.delete()
await instance.delete()
return resp.deleted

View File

@ -29,8 +29,8 @@ from .responses import resp
@routes.get("/instance/{id}/database")
async def get_database(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "")
instance = PluginInstance.get(instance_id, None)
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
if not instance:
return resp.instance_not_found
elif not instance.inst_db:
@ -65,8 +65,8 @@ def check_type(val):
@routes.get("/instance/{id}/database/{table}")
async def get_table(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "")
instance = PluginInstance.get(instance_id, None)
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
if not instance:
return resp.instance_not_found
elif not instance.inst_db:
@ -86,14 +86,14 @@ async def get_table(request: web.Request) -> web.Response:
]
except KeyError:
order = []
limit = int(request.query.get("limit", 100))
limit = int(request.query.get("limit", "100"))
return execute_query(instance, table.select().order_by(*order).limit(limit))
@routes.post("/instance/{id}/database/query")
async def query(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "")
instance = PluginInstance.get(instance_id, None)
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
if not instance:
return resp.instance_not_found
elif not instance.inst_db:

View File

@ -23,7 +23,7 @@ import logging
from aiohttp import web, web_ws
from .auth import is_valid_token
from .base import get_loop, routes
from .base import routes
BUILTIN_ATTRS = {
"args",
@ -138,12 +138,12 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse:
authenticated = False
async def close_if_not_authenticated():
await asyncio.sleep(5, loop=get_loop())
await asyncio.sleep(5)
if not authenticated:
await ws.close(code=4000)
log.debug(f"Connection from {request.remote} terminated due to no authentication")
asyncio.ensure_future(close_if_not_authenticated())
asyncio.create_task(close_if_not_authenticated())
try:
msg: web_ws.WSMessage

View File

@ -29,8 +29,8 @@ async def get_plugins(_) -> web.Response:
@routes.get("/plugin/{id}")
async def get_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id, None)
plugin_id = request.match_info["id"]
plugin = PluginLoader.id_cache.get(plugin_id)
if not plugin:
return resp.plugin_not_found
return resp.found(plugin.to_dict())
@ -38,8 +38,8 @@ async def get_plugin(request: web.Request) -> web.Response:
@routes.delete("/plugin/{id}")
async def delete_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id, None)
plugin_id = request.match_info["id"]
plugin = PluginLoader.id_cache.get(plugin_id)
if not plugin:
return resp.plugin_not_found
elif len(plugin.references) > 0:
@ -50,8 +50,8 @@ async def delete_plugin(request: web.Request) -> web.Response:
@routes.post("/plugin/{id}/reload")
async def reload_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id, None)
plugin_id = request.match_info["id"]
plugin = PluginLoader.id_cache.get(plugin_id)
if not plugin:
return resp.plugin_not_found

View File

@ -29,7 +29,7 @@ from .responses import resp
@routes.put("/plugin/{id}")
async def put_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info.get("id", None)
plugin_id = request.match_info["id"]
content = await request.read()
file = BytesIO(content)
try:

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Awaitable
from abc import ABC
from asyncio import AbstractEventLoop
@ -124,6 +124,7 @@ class Plugin(ABC):
def get_config_class(cls) -> type[BaseProxyConfig] | None:
return None
def on_external_config_update(self) -> None:
def on_external_config_update(self) -> Awaitable[None] | None:
if self.config:
self.config.load_and_update()
return None

View File

@ -1,13 +1,10 @@
# Format: #/name defines a new extras_require group called name
# Uncommented lines after the group definition insert things into that group.
#/postgres
psycopg2-binary>=2,<3
asyncpg>=0.20,<0.26
#/sqlite
aiosqlite>=0.16,<0.18
#/encryption
asyncpg>=0.20,<0.26
aiosqlite>=0.16,<0.18
python-olm>=3,<4
pycryptodome>=3,<4
unpaddedbase64>=1,<3

View File

@ -1,7 +1,8 @@
mautrix>=0.15.0,<0.16
mautrix>=0.15.2,<0.16
aiohttp>=3,<4
yarl>=1,<2
SQLAlchemy>=1,<1.4
asyncpg>=0.20,<0.26
alembic>=1,<2
commonmark>=0.9,<1
ruamel.yaml>=0.15.35,<0.18

View File

@ -1,5 +1,4 @@
import setuptools
import glob
import os
with open("requirements.txt") as reqs:
@ -57,9 +56,7 @@ setuptools.setup(
mbc=maubot.cli:app
""",
data_files=[
(".", ["maubot/example-config.yaml", "alembic.ini"]),
("alembic", ["alembic/env.py"]),
("alembic/versions", glob.glob("alembic/versions/*.py")),
(".", ["maubot/example-config.yaml"]),
],
package_data={
"maubot": [