parent
068e268c63
commit
21ed971d2f
@ -15,7 +15,6 @@ RUN apk add --no-cache \
|
||||
py3-attrs \
|
||||
py3-bcrypt \
|
||||
py3-cffi \
|
||||
py3-psycopg2 \
|
||||
py3-ruamel.yaml \
|
||||
py3-jinja2 \
|
||||
py3-click \
|
||||
@ -49,7 +48,6 @@ COPY requirements.txt /opt/maubot/requirements.txt
|
||||
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
|
||||
WORKDIR /opt/maubot
|
||||
RUN apk add --virtual .build-deps python3-dev build-base git \
|
||||
&& sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \
|
||||
&& pip3 install -r requirements.txt -r optional-requirements.txt \
|
||||
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
|
||||
&& apk del .build-deps
|
||||
|
@ -10,7 +10,6 @@ RUN apk add --no-cache \
|
||||
py3-attrs \
|
||||
py3-bcrypt \
|
||||
py3-cffi \
|
||||
py3-psycopg2 \
|
||||
py3-ruamel.yaml \
|
||||
py3-jinja2 \
|
||||
py3-click \
|
||||
@ -43,7 +42,6 @@ COPY requirements.txt /opt/maubot/requirements.txt
|
||||
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
|
||||
WORKDIR /opt/maubot
|
||||
RUN apk add --virtual .build-deps python3-dev build-base git \
|
||||
&& sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \
|
||||
&& pip3 install -r requirements.txt -r optional-requirements.txt \
|
||||
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
|
||||
&& apk del .build-deps
|
||||
|
83
alembic.ini
83
alembic.ini
@ -1,83 +0,0 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = alembic
|
||||
|
||||
# template used to generate migration files
|
||||
# file_template = %%(rev)s_%%(slug)s
|
||||
|
||||
# timezone to use when rendering the date
|
||||
# within the migration file as well as the filename.
|
||||
# string value is passed to dateutil.tz.gettz()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; this defaults
|
||||
# to alembic/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path
|
||||
# version_locations = %(here)s/bar %(here)s/bat alembic/versions
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks=black
|
||||
# black.type=console_scripts
|
||||
# black.entrypoint=black
|
||||
# black.options=-l 79
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
@ -1 +0,0 @@
|
||||
Generic single-database configuration.
|
@ -1,92 +0,0 @@
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
import sys
|
||||
from os.path import abspath, dirname
|
||||
|
||||
sys.path.insert(0, dirname(dirname(abspath(__file__))))
|
||||
|
||||
from mautrix.util.db import Base
|
||||
from maubot.config import Config
|
||||
from maubot import db
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
maubot_config_path = context.get_x_argument(as_dictionary=True).get("config", "config.yaml")
|
||||
maubot_config = Config(maubot_config_path, None)
|
||||
maubot_config.load()
|
||||
config.set_main_option("sqlalchemy.url", maubot_config["database"].replace("%", "%%"))
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline():
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
render_as_batch=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online():
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata,
|
||||
render_as_batch=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
@ -1,24 +0,0 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade():
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade():
|
||||
${downgrades if downgrades else "pass"}
|
@ -1,32 +0,0 @@
|
||||
"""Add device_id to clients
|
||||
|
||||
Revision ID: 4b93300852aa
|
||||
Revises: fccd1f95544d
|
||||
Create Date: 2020-07-11 15:49:38.831459
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '4b93300852aa'
|
||||
down_revision = 'fccd1f95544d'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('client', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('device_id', sa.String(length=255), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('client', schema=None) as batch_op:
|
||||
batch_op.drop_column('device_id')
|
||||
|
||||
# ### end Alembic commands ###
|
@ -1,47 +0,0 @@
|
||||
"""Add Matrix state store
|
||||
|
||||
Revision ID: 90aa88820eab
|
||||
Revises: 4b93300852aa
|
||||
Create Date: 2020-07-12 01:50:06.215623
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from mautrix.client.state_store.sqlalchemy import SerializableType
|
||||
from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '90aa88820eab'
|
||||
down_revision = '4b93300852aa'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('mx_room_state',
|
||||
sa.Column('room_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('is_encrypted', sa.Boolean(), nullable=True),
|
||||
sa.Column('has_full_member_list', sa.Boolean(), nullable=True),
|
||||
sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True),
|
||||
sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True),
|
||||
sa.PrimaryKeyConstraint('room_id')
|
||||
)
|
||||
op.create_table('mx_user_profile',
|
||||
sa.Column('room_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False),
|
||||
sa.Column('displayname', sa.String(), nullable=True),
|
||||
sa.Column('avatar_url', sa.String(length=255), nullable=True),
|
||||
sa.PrimaryKeyConstraint('room_id', 'user_id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('mx_user_profile')
|
||||
op.drop_table('mx_room_state')
|
||||
# ### end Alembic commands ###
|
@ -1,50 +0,0 @@
|
||||
"""Initial revision
|
||||
|
||||
Revision ID: d295f8dcfa64
|
||||
Revises:
|
||||
Create Date: 2019-09-27 00:21:02.527915
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd295f8dcfa64'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('client',
|
||||
sa.Column('id', sa.String(length=255), nullable=False),
|
||||
sa.Column('homeserver', sa.String(length=255), nullable=False),
|
||||
sa.Column('access_token', sa.Text(), nullable=False),
|
||||
sa.Column('enabled', sa.Boolean(), nullable=False),
|
||||
sa.Column('next_batch', sa.String(length=255), nullable=False),
|
||||
sa.Column('filter_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('sync', sa.Boolean(), nullable=False),
|
||||
sa.Column('autojoin', sa.Boolean(), nullable=False),
|
||||
sa.Column('displayname', sa.String(length=255), nullable=False),
|
||||
sa.Column('avatar_url', sa.String(length=255), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('plugin',
|
||||
sa.Column('id', sa.String(length=255), nullable=False),
|
||||
sa.Column('type', sa.String(length=255), nullable=False),
|
||||
sa.Column('enabled', sa.Boolean(), nullable=False),
|
||||
sa.Column('primary_user', sa.String(length=255), nullable=False),
|
||||
sa.Column('config', sa.Text(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['primary_user'], ['client.id'], onupdate='CASCADE', ondelete='RESTRICT'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('plugin')
|
||||
op.drop_table('client')
|
||||
# ### end Alembic commands ###
|
@ -1,30 +0,0 @@
|
||||
"""Add online field to clients
|
||||
|
||||
Revision ID: fccd1f95544d
|
||||
Revises: d295f8dcfa64
|
||||
Create Date: 2020-03-06 15:07:50.136644
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'fccd1f95544d'
|
||||
down_revision = 'd295f8dcfa64'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("client") as batch_op:
|
||||
batch_op.add_column(sa.Column('online', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true()))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("client") as batch_op:
|
||||
batch_op.drop_column('online')
|
||||
# ### end Alembic commands ###
|
@ -1,7 +1,7 @@
|
||||
#!/bin/sh
|
||||
|
||||
function fixperms {
|
||||
chown -R $UID:$GID /var/log /data /opt/maubot
|
||||
chown -R $UID:$GID /var/log /data
|
||||
}
|
||||
|
||||
function fixdefault {
|
||||
|
@ -13,24 +13,37 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
import asyncio
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from mautrix.util.async_db import Database, DatabaseException
|
||||
from mautrix.util.program import Program
|
||||
|
||||
from .__meta__ import __version__
|
||||
from .client import Client, init as init_client_class
|
||||
from .client import Client
|
||||
from .config import Config
|
||||
from .db import init as init_db
|
||||
from .instance import init as init_plugin_instance_class
|
||||
from .db import init as init_db, upgrade_table
|
||||
from .instance import PluginInstance
|
||||
from .lib.future_awaitable import FutureAwaitable
|
||||
from .lib.state_store import PgStateStore
|
||||
from .loader.zip import init as init_zip_loader
|
||||
from .management.api import init as init_mgmt_api
|
||||
from .server import MaubotServer
|
||||
|
||||
try:
|
||||
from mautrix.crypto.store import PgCryptoStore
|
||||
except ImportError:
|
||||
PgCryptoStore = None
|
||||
|
||||
|
||||
class Maubot(Program):
|
||||
config: Config
|
||||
server: MaubotServer
|
||||
db: Database
|
||||
crypto_db: Database | None
|
||||
state_store: PgStateStore
|
||||
|
||||
config_class = Config
|
||||
module = "maubot"
|
||||
@ -45,6 +58,19 @@ class Maubot(Program):
|
||||
init(self.loop)
|
||||
self.add_shutdown_actions(FutureAwaitable(stop_all))
|
||||
|
||||
def prepare_arg_parser(self) -> None:
|
||||
super().prepare_arg_parser()
|
||||
self.parser.add_argument(
|
||||
"--ignore-unsupported-database",
|
||||
action="store_true",
|
||||
help="Run even if the database schema is too new",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--ignore-foreign-tables",
|
||||
action="store_true",
|
||||
help="Run even if the database contains tables from other programs (like Synapse)",
|
||||
)
|
||||
|
||||
def prepare(self) -> None:
|
||||
super().prepare()
|
||||
|
||||
@ -52,21 +78,59 @@ class Maubot(Program):
|
||||
self.prepare_log_websocket()
|
||||
|
||||
init_zip_loader(self.config)
|
||||
init_db(self.config)
|
||||
clients = init_client_class(self.config, self.loop)
|
||||
self.add_startup_actions(*(client.start() for client in clients))
|
||||
self.db = Database.create(
|
||||
self.config["database"],
|
||||
upgrade_table=upgrade_table,
|
||||
db_args=self.config["database_opts"],
|
||||
owner_name=self.name,
|
||||
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
||||
)
|
||||
init_db(self.db)
|
||||
if self.config["crypto_database"] == "default":
|
||||
self.crypto_db = self.db
|
||||
else:
|
||||
self.crypto_db = Database.create(
|
||||
self.config["crypto_database"],
|
||||
upgrade_table=PgCryptoStore.upgrade_table,
|
||||
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
||||
)
|
||||
Client.init_cls(self)
|
||||
PluginInstance.init_cls(self)
|
||||
management_api = init_mgmt_api(self.config, self.loop)
|
||||
self.server = MaubotServer(management_api, self.config, self.loop)
|
||||
self.state_store = PgStateStore(self.db)
|
||||
|
||||
plugins = init_plugin_instance_class(self.config, self.server, self.loop)
|
||||
for plugin in plugins:
|
||||
plugin.load()
|
||||
async def start_db(self) -> None:
|
||||
self.log.debug("Starting database...")
|
||||
ignore_unsupported = self.args.ignore_unsupported_database
|
||||
self.db.upgrade_table.allow_unsupported = ignore_unsupported
|
||||
self.state_store.upgrade_table.allow_unsupported = ignore_unsupported
|
||||
PgCryptoStore.upgrade_table.allow_unsupported = ignore_unsupported
|
||||
try:
|
||||
await self.db.start()
|
||||
await self.state_store.upgrade_table.upgrade(self.db)
|
||||
if self.crypto_db and self.crypto_db is not self.db:
|
||||
await self.crypto_db.start()
|
||||
else:
|
||||
await PgCryptoStore.upgrade_table.upgrade(self.db)
|
||||
except DatabaseException as e:
|
||||
self.log.critical("Failed to initialize database", exc_info=e)
|
||||
if e.explanation:
|
||||
self.log.info(e.explanation)
|
||||
sys.exit(25)
|
||||
|
||||
async def system_exit(self) -> None:
|
||||
if hasattr(self, "db"):
|
||||
self.log.trace("Stopping database due to SystemExit")
|
||||
await self.db.stop()
|
||||
|
||||
async def start(self) -> None:
|
||||
if Client.crypto_db:
|
||||
self.log.debug("Starting client crypto database")
|
||||
await Client.crypto_db.start()
|
||||
await self.start_db()
|
||||
await asyncio.gather(*[plugin.load() async for plugin in PluginInstance.all()])
|
||||
await asyncio.gather(*[client.start() async for client in Client.all()])
|
||||
await super().start()
|
||||
async for plugin in PluginInstance.all():
|
||||
await plugin.load()
|
||||
await self.server.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
@ -77,6 +141,7 @@ class Maubot(Program):
|
||||
await asyncio.wait_for(self.server.stop(), 5)
|
||||
except asyncio.TimeoutError:
|
||||
self.log.warning("Stopping server timed out")
|
||||
await self.db.stop()
|
||||
|
||||
|
||||
Maubot().run()
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "0.2.1"
|
||||
__version__ = "0.3.0+dev"
|
||||
|
@ -38,7 +38,7 @@ def logs(server: str, tail: int) -> None:
|
||||
global history_count
|
||||
history_count = tail
|
||||
loop = asyncio.get_event_loop()
|
||||
future = asyncio.ensure_future(view_logs(server, token), loop=loop)
|
||||
future = asyncio.create_task(view_logs(server, token), loop=loop)
|
||||
try:
|
||||
loop.run_until_complete(future)
|
||||
except KeyboardInterrupt:
|
||||
|
407
maubot/client.py
407
maubot/client.py
@ -15,14 +15,14 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, cast
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from mautrix.client import InternalEventType
|
||||
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
|
||||
from mautrix.errors import MatrixInvalidToken
|
||||
from mautrix.types import (
|
||||
ContentURI,
|
||||
@ -41,69 +41,110 @@ from mautrix.types import (
|
||||
SyncToken,
|
||||
UserID,
|
||||
)
|
||||
from mautrix.util.async_getter_lock import async_getter_lock
|
||||
from mautrix.util.logging import TraceLogger
|
||||
|
||||
from .db import DBClient
|
||||
from .lib.store_proxy import SyncStoreProxy
|
||||
from .db import Client as DBClient
|
||||
from .matrix import MaubotMatrixClient
|
||||
|
||||
try:
|
||||
from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore
|
||||
from mautrix.util.async_db import Database as AsyncDatabase
|
||||
|
||||
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
|
||||
pass
|
||||
from mautrix.crypto import OlmMachine, PgCryptoStore
|
||||
|
||||
crypto_import_error = None
|
||||
except ImportError as e:
|
||||
OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None
|
||||
SQLStateStore = BaseSQLStateStore
|
||||
OlmMachine = PgCryptoStore = None
|
||||
crypto_import_error = e
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import Config
|
||||
from .__main__ import Maubot
|
||||
from .instance import PluginInstance
|
||||
|
||||
log = logging.getLogger("maubot.client")
|
||||
|
||||
|
||||
class Client:
|
||||
log: logging.Logger = None
|
||||
loop: asyncio.AbstractEventLoop = None
|
||||
class Client(DBClient):
|
||||
maubot: "Maubot" = None
|
||||
cache: dict[UserID, Client] = {}
|
||||
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
|
||||
log: TraceLogger = logging.getLogger("maubot.client")
|
||||
|
||||
http_client: ClientSession = None
|
||||
global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
|
||||
crypto_db: AsyncDatabase | None = None
|
||||
|
||||
references: set[PluginInstance]
|
||||
db_instance: DBClient
|
||||
client: MaubotMatrixClient
|
||||
crypto: OlmMachine | None
|
||||
crypto_store: PgCryptoStore | None
|
||||
started: bool
|
||||
sync_ok: bool
|
||||
|
||||
remote_displayname: str | None
|
||||
remote_avatar_url: ContentURI | None
|
||||
|
||||
def __init__(self, db_instance: DBClient) -> None:
|
||||
self.db_instance = db_instance
|
||||
def __init__(
|
||||
self,
|
||||
id: UserID,
|
||||
homeserver: str,
|
||||
access_token: str,
|
||||
device_id: DeviceID,
|
||||
enabled: bool = False,
|
||||
next_batch: SyncToken = "",
|
||||
filter_id: FilterID = "",
|
||||
sync: bool = True,
|
||||
autojoin: bool = True,
|
||||
online: bool = True,
|
||||
displayname: str = "disable",
|
||||
avatar_url: str = "disable",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
homeserver=homeserver,
|
||||
access_token=access_token,
|
||||
device_id=device_id,
|
||||
enabled=bool(enabled),
|
||||
next_batch=next_batch,
|
||||
filter_id=filter_id,
|
||||
sync=bool(sync),
|
||||
autojoin=bool(autojoin),
|
||||
online=bool(online),
|
||||
displayname=displayname,
|
||||
avatar_url=avatar_url,
|
||||
)
|
||||
self._postinited = False
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
@classmethod
|
||||
def init_cls(cls, maubot: "Maubot") -> None:
|
||||
cls.maubot = maubot
|
||||
|
||||
def _make_client(
|
||||
self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None
|
||||
) -> MaubotMatrixClient:
|
||||
return MaubotMatrixClient(
|
||||
mxid=self.id,
|
||||
base_url=homeserver or self.homeserver,
|
||||
token=token or self.access_token,
|
||||
client_session=self.http_client,
|
||||
log=self.log,
|
||||
crypto_log=self.log.getChild("crypto"),
|
||||
loop=self.maubot.loop,
|
||||
device_id=device_id or self.device_id,
|
||||
sync_store=self,
|
||||
state_store=self.maubot.state_store,
|
||||
)
|
||||
|
||||
def postinit(self) -> None:
|
||||
if self._postinited:
|
||||
raise RuntimeError("postinit() called twice")
|
||||
self._postinited = True
|
||||
self.cache[self.id] = self
|
||||
self.log = log.getChild(self.id)
|
||||
self.log = self.log.getChild(self.id)
|
||||
self.http_client = ClientSession(loop=self.maubot.loop)
|
||||
self.references = set()
|
||||
self.started = False
|
||||
self.sync_ok = True
|
||||
self.remote_displayname = None
|
||||
self.remote_avatar_url = None
|
||||
self.client = MaubotMatrixClient(
|
||||
mxid=self.id,
|
||||
base_url=self.homeserver,
|
||||
token=self.access_token,
|
||||
client_session=self.http_client,
|
||||
log=self.log,
|
||||
loop=self.loop,
|
||||
device_id=self.device_id,
|
||||
sync_store=SyncStoreProxy(self.db_instance),
|
||||
state_store=self.global_state_store,
|
||||
)
|
||||
self.client = self._make_client()
|
||||
if self.enable_crypto:
|
||||
self._prepare_crypto()
|
||||
else:
|
||||
@ -118,6 +159,12 @@ class Client:
|
||||
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
||||
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
||||
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: dict[str, Any]) -> None:
|
||||
self.sync_ok = ok
|
||||
|
||||
return handler
|
||||
|
||||
@property
|
||||
def enable_crypto(self) -> bool:
|
||||
if not self.device_id:
|
||||
@ -131,16 +178,21 @@ class Client:
|
||||
# Clear the stack trace after it's logged once to avoid spamming logs
|
||||
crypto_import_error = None
|
||||
return False
|
||||
elif not self.crypto_db:
|
||||
elif not self.maubot.crypto_db:
|
||||
self.log.warning("Client has device ID, but crypto database is not prepared")
|
||||
return False
|
||||
return True
|
||||
|
||||
def _prepare_crypto(self) -> None:
|
||||
self.crypto_store = PgCryptoStore(
|
||||
account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db
|
||||
account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db
|
||||
)
|
||||
self.crypto = OlmMachine(
|
||||
self.client,
|
||||
self.crypto_store,
|
||||
self.maubot.state_store,
|
||||
log=self.client.crypto_log,
|
||||
)
|
||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||
self.client.crypto = self.crypto
|
||||
|
||||
def _remove_crypto_event_handlers(self) -> None:
|
||||
@ -156,12 +208,6 @@ class Client:
|
||||
for event_type, func in handlers:
|
||||
self.client.remove_event_handler(event_type, func)
|
||||
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: dict[str, Any]) -> None:
|
||||
self.sync_ok = ok
|
||||
|
||||
return handler
|
||||
|
||||
async def start(self, try_n: int | None = 0) -> None:
|
||||
try:
|
||||
if try_n > 0:
|
||||
@ -196,47 +242,50 @@ class Client:
|
||||
whoami = await self.client.whoami()
|
||||
except MatrixInvalidToken as e:
|
||||
self.log.error(f"Invalid token: {e}. Disabling client")
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
return
|
||||
except Exception as e:
|
||||
if try_n >= 8:
|
||||
self.log.exception("Failed to get /account/whoami, disabling client")
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
else:
|
||||
self.log.warning(
|
||||
f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}"
|
||||
f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}"
|
||||
)
|
||||
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
|
||||
_ = asyncio.create_task(self.start(try_n + 1))
|
||||
return
|
||||
if whoami.user_id != self.id:
|
||||
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
return
|
||||
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
|
||||
self.log.error(
|
||||
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
|
||||
)
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
return
|
||||
if not self.filter_id:
|
||||
self.db_instance.edit(
|
||||
filter_id=await self.client.create_filter(
|
||||
Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
lazy_load_members=True,
|
||||
),
|
||||
state=StateFilter(
|
||||
lazy_load_members=True,
|
||||
),
|
||||
self.filter_id = await self.client.create_filter(
|
||||
Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
lazy_load_members=True,
|
||||
),
|
||||
presence=EventFilter(
|
||||
not_types=[EventType.PRESENCE],
|
||||
state=StateFilter(
|
||||
lazy_load_members=True,
|
||||
),
|
||||
)
|
||||
),
|
||||
presence=EventFilter(
|
||||
not_types=[EventType.PRESENCE],
|
||||
),
|
||||
)
|
||||
)
|
||||
await self.update()
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
@ -270,18 +319,13 @@ class Client:
|
||||
if self.crypto:
|
||||
await self.crypto_store.close()
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
async def clear_cache(self) -> None:
|
||||
self.stop_sync()
|
||||
self.db_instance.edit(filter_id="", next_batch="")
|
||||
self.filter_id = FilterID("")
|
||||
self.next_batch = SyncToken("")
|
||||
await self.update()
|
||||
self.start_sync()
|
||||
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.db_instance.delete()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
@ -304,20 +348,6 @@ class Client:
|
||||
"instances": [instance.to_dict() for instance in self.references],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
|
||||
try:
|
||||
return cls.cache[user_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBClient.get(user_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return Client(db_instance)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable[Client]:
|
||||
return (cls.get(user.id, user) for user in DBClient.all())
|
||||
|
||||
async def _handle_tombstone(self, evt: StateEvent) -> None:
|
||||
if not evt.content.replacement_room:
|
||||
self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring")
|
||||
@ -329,7 +359,7 @@ class Client:
|
||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||
await self.client.join_room(evt.room_id)
|
||||
|
||||
async def update_started(self, started: bool) -> None:
|
||||
async def update_started(self, started: bool | None) -> None:
|
||||
if started is None or started == self.started:
|
||||
return
|
||||
if started:
|
||||
@ -337,23 +367,65 @@ class Client:
|
||||
else:
|
||||
await self.stop()
|
||||
|
||||
async def update_displayname(self, displayname: str) -> None:
|
||||
async def update_enabled(self, enabled: bool | None, save: bool = True) -> None:
|
||||
if enabled is None or enabled == self.enabled:
|
||||
return
|
||||
self.enabled = enabled
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_displayname(self, displayname: str | None, save: bool = True) -> None:
|
||||
if displayname is None or displayname == self.displayname:
|
||||
return
|
||||
self.db_instance.displayname = displayname
|
||||
self.displayname = displayname
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
else:
|
||||
await self._update_remote_profile()
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_avatar_url(self, avatar_url: ContentURI) -> None:
|
||||
async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None:
|
||||
if avatar_url is None or avatar_url == self.avatar_url:
|
||||
return
|
||||
self.db_instance.avatar_url = avatar_url
|
||||
self.avatar_url = avatar_url
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
else:
|
||||
await self._update_remote_profile()
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_sync(self, sync: bool | None, save: bool = True) -> None:
|
||||
if sync is None or self.sync == sync:
|
||||
return
|
||||
self.sync = sync
|
||||
if self.started:
|
||||
if sync:
|
||||
self.start_sync()
|
||||
else:
|
||||
self.stop_sync()
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None:
|
||||
if autojoin is None or autojoin == self.autojoin:
|
||||
return
|
||||
if autojoin:
|
||||
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
else:
|
||||
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
self.autojoin = autojoin
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_online(self, online: bool | None, save: bool = True) -> None:
|
||||
if online is None or online == self.online:
|
||||
return
|
||||
self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE
|
||||
self.online = online
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_access_details(
|
||||
self,
|
||||
@ -373,22 +445,13 @@ class Client:
|
||||
and device_id == self.device_id
|
||||
):
|
||||
return
|
||||
new_client = MaubotMatrixClient(
|
||||
mxid=self.id,
|
||||
base_url=homeserver or self.homeserver,
|
||||
token=access_token or self.access_token,
|
||||
loop=self.loop,
|
||||
device_id=device_id,
|
||||
client_session=self.http_client,
|
||||
log=self.log,
|
||||
state_store=self.global_state_store,
|
||||
)
|
||||
new_client = self._make_client(homeserver, access_token, device_id)
|
||||
whoami = await new_client.whoami()
|
||||
if whoami.user_id != self.id:
|
||||
raise ValueError(f"MXID mismatch: {whoami.user_id}")
|
||||
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
||||
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
|
||||
new_client.sync_store = SyncStoreProxy(self.db_instance)
|
||||
new_client.sync_store = self
|
||||
self.stop_sync()
|
||||
|
||||
# TODO this event handler transfer is pretty hacky
|
||||
@ -398,9 +461,9 @@ class Client:
|
||||
new_client.global_event_handlers = self.client.global_event_handlers
|
||||
|
||||
self.client = new_client
|
||||
self.db_instance.homeserver = homeserver
|
||||
self.db_instance.access_token = access_token
|
||||
self.db_instance.device_id = device_id
|
||||
self.homeserver = homeserver
|
||||
self.access_token = access_token
|
||||
self.device_id = device_id
|
||||
if self.enable_crypto:
|
||||
self._prepare_crypto()
|
||||
await self._start_crypto()
|
||||
@ -413,97 +476,53 @@ class Client:
|
||||
profile = await self.client.get_profile(self.id)
|
||||
self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url
|
||||
|
||||
# region Properties
|
||||
async def delete(self) -> None:
|
||||
try:
|
||||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
await super().delete()
|
||||
|
||||
@property
|
||||
def id(self) -> UserID:
|
||||
return self.db_instance.id
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get(
|
||||
cls,
|
||||
user_id: UserID,
|
||||
*,
|
||||
homeserver: str | None = None,
|
||||
access_token: str | None = None,
|
||||
device_id: DeviceID | None = None,
|
||||
) -> Client | None:
|
||||
try:
|
||||
return cls.cache[user_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def homeserver(self) -> str:
|
||||
return self.db_instance.homeserver
|
||||
user = cast(cls, await super().get(user_id))
|
||||
if user is not None:
|
||||
user.postinit()
|
||||
return user
|
||||
|
||||
@property
|
||||
def access_token(self) -> str:
|
||||
return self.db_instance.access_token
|
||||
if homeserver and access_token:
|
||||
user = cls(
|
||||
user_id,
|
||||
homeserver=homeserver,
|
||||
access_token=access_token,
|
||||
device_id=device_id or "",
|
||||
)
|
||||
await user.insert()
|
||||
user.postinit()
|
||||
return user
|
||||
|
||||
@property
|
||||
def device_id(self) -> DeviceID:
|
||||
return self.db_instance.device_id
|
||||
return None
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.db_instance.enabled
|
||||
|
||||
@enabled.setter
|
||||
def enabled(self, value: bool) -> None:
|
||||
self.db_instance.enabled = value
|
||||
|
||||
@property
|
||||
def next_batch(self) -> SyncToken:
|
||||
return self.db_instance.next_batch
|
||||
|
||||
@property
|
||||
def filter_id(self) -> FilterID:
|
||||
return self.db_instance.filter_id
|
||||
|
||||
@property
|
||||
def sync(self) -> bool:
|
||||
return self.db_instance.sync
|
||||
|
||||
@sync.setter
|
||||
def sync(self, value: bool) -> None:
|
||||
if value == self.db_instance.sync:
|
||||
return
|
||||
self.db_instance.sync = value
|
||||
if self.started:
|
||||
if value:
|
||||
self.start_sync()
|
||||
else:
|
||||
self.stop_sync()
|
||||
|
||||
@property
|
||||
def autojoin(self) -> bool:
|
||||
return self.db_instance.autojoin
|
||||
|
||||
@autojoin.setter
|
||||
def autojoin(self, value: bool) -> None:
|
||||
if value == self.db_instance.autojoin:
|
||||
return
|
||||
if value:
|
||||
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
else:
|
||||
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
self.db_instance.autojoin = value
|
||||
|
||||
@property
|
||||
def online(self) -> bool:
|
||||
return self.db_instance.online
|
||||
|
||||
@online.setter
|
||||
def online(self, value: bool) -> None:
|
||||
self.client.presence = PresenceState.ONLINE if value else PresenceState.OFFLINE
|
||||
self.db_instance.online = value
|
||||
|
||||
@property
|
||||
def displayname(self) -> str:
|
||||
return self.db_instance.displayname
|
||||
|
||||
@property
|
||||
def avatar_url(self) -> ContentURI:
|
||||
return self.db_instance.avatar_url
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
|
||||
if OlmMachine:
|
||||
db_url = config["crypto_database"]
|
||||
if db_url == "default":
|
||||
db_url = config["database"]
|
||||
Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
|
||||
|
||||
return Client.all()
|
||||
@classmethod
|
||||
async def all(cls) -> AsyncGenerator[Client, None]:
|
||||
users = await super().all()
|
||||
user: cls
|
||||
for user in users:
|
||||
try:
|
||||
yield cls.cache[user.id]
|
||||
except KeyError:
|
||||
user.postinit()
|
||||
yield user
|
||||
|
108
maubot/db.py
108
maubot/db.py
@ -1,108 +0,0 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Iterable, Optional
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, String, Text
|
||||
from sqlalchemy.engine.base import Engine
|
||||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
|
||||
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
|
||||
from mautrix.util.db import Base
|
||||
|
||||
from .config import Config
|
||||
|
||||
|
||||
class DBPlugin(Base):
|
||||
__tablename__ = "plugin"
|
||||
|
||||
id: str = Column(String(255), primary_key=True)
|
||||
type: str = Column(String(255), nullable=False)
|
||||
enabled: bool = Column(Boolean, nullable=False, default=False)
|
||||
primary_user: UserID = Column(
|
||||
String(255),
|
||||
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
)
|
||||
config: str = Column(Text, nullable=False, default="")
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable["DBPlugin"]:
|
||||
return cls._select_all()
|
||||
|
||||
@classmethod
|
||||
def get(cls, id: str) -> Optional["DBPlugin"]:
|
||||
return cls._select_one_or_none(cls.c.id == id)
|
||||
|
||||
|
||||
class DBClient(Base):
|
||||
__tablename__ = "client"
|
||||
|
||||
id: UserID = Column(String(255), primary_key=True)
|
||||
homeserver: str = Column(String(255), nullable=False)
|
||||
access_token: str = Column(Text, nullable=False)
|
||||
device_id: DeviceID = Column(String(255), nullable=True)
|
||||
enabled: bool = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
next_batch: SyncToken = Column(String(255), nullable=False, default="")
|
||||
filter_id: FilterID = Column(String(255), nullable=False, default="")
|
||||
|
||||
sync: bool = Column(Boolean, nullable=False, default=True)
|
||||
autojoin: bool = Column(Boolean, nullable=False, default=True)
|
||||
online: bool = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
displayname: str = Column(String(255), nullable=False, default="")
|
||||
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable["DBClient"]:
|
||||
return cls._select_all()
|
||||
|
||||
@classmethod
|
||||
def get(cls, id: str) -> Optional["DBClient"]:
|
||||
return cls._select_one_or_none(cls.c.id == id)
|
||||
|
||||
|
||||
def init(config: Config) -> Engine:
|
||||
db = sql.create_engine(config["database"])
|
||||
Base.metadata.bind = db
|
||||
|
||||
for table in (DBPlugin, DBClient, RoomState, UserProfile):
|
||||
table.bind(db)
|
||||
|
||||
if not db.has_table("alembic_version"):
|
||||
log = logging.getLogger("maubot.db")
|
||||
|
||||
if db.has_table("client") and db.has_table("plugin"):
|
||||
log.warning(
|
||||
"alembic_version table not found, but client and plugin tables found. "
|
||||
"Assuming pre-Alembic database and inserting version."
|
||||
)
|
||||
db.execute(
|
||||
"CREATE TABLE IF NOT EXISTS alembic_version ("
|
||||
" version_num VARCHAR(32) PRIMARY KEY"
|
||||
");"
|
||||
)
|
||||
db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');")
|
||||
else:
|
||||
log.critical(
|
||||
"alembic_version table not found. " "Did you forget to `alembic upgrade head`?"
|
||||
)
|
||||
sys.exit(10)
|
||||
|
||||
return db
|
13
maubot/db/__init__.py
Normal file
13
maubot/db/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from .client import Client
|
||||
from .instance import Instance
|
||||
from .upgrade import upgrade_table
|
||||
|
||||
|
||||
def init(db: Database) -> None:
|
||||
for table in (Client, Instance):
|
||||
table.db = db
|
||||
|
||||
|
||||
__all__ = ["upgrade_table", "init", "Client", "Instance"]
|
114
maubot/db/client.py
Normal file
114
maubot/db/client.py
Normal file
@ -0,0 +1,114 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.client import SyncStore
|
||||
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Client(SyncStore):
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
id: UserID
|
||||
homeserver: str
|
||||
access_token: str
|
||||
device_id: DeviceID
|
||||
enabled: bool
|
||||
|
||||
next_batch: SyncToken
|
||||
filter_id: FilterID
|
||||
|
||||
sync: bool
|
||||
autojoin: bool
|
||||
online: bool
|
||||
|
||||
displayname: str
|
||||
avatar_url: ContentURI
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record | None) -> Client | None:
|
||||
if row is None:
|
||||
return None
|
||||
return cls(**row)
|
||||
|
||||
_columns = (
|
||||
"id, homeserver, access_token, device_id, enabled, next_batch, filter_id, "
|
||||
"sync, autojoin, online, displayname, avatar_url"
|
||||
)
|
||||
|
||||
@property
|
||||
def _values(self):
|
||||
return (
|
||||
self.id,
|
||||
self.homeserver,
|
||||
self.access_token,
|
||||
self.device_id,
|
||||
self.enabled,
|
||||
self.next_batch,
|
||||
self.filter_id,
|
||||
self.sync,
|
||||
self.autojoin,
|
||||
self.online,
|
||||
self.displayname,
|
||||
self.avatar_url,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def all(cls) -> list[Client]:
|
||||
rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client")
|
||||
return [cls._from_row(row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
async def get(cls, id: str) -> Client | None:
|
||||
q = f"SELECT {cls._columns} FROM client WHERE id=$1"
|
||||
return cls._from_row(await cls.db.fetchrow(q, id))
|
||||
|
||||
async def insert(self) -> None:
|
||||
q = """
|
||||
INSERT INTO client (
|
||||
id, homeserver, access_token, device_id, enabled, next_batch, filter_id,
|
||||
sync, autojoin, online, displayname, avatar_url
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
"""
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def put_next_batch(self, next_batch: SyncToken) -> None:
|
||||
await self.db.execute("UPDATE client SET next_batch=$1 WHERE id=$2", next_batch, self.id)
|
||||
self.next_batch = next_batch
|
||||
|
||||
async def get_next_batch(self) -> SyncToken:
|
||||
return self.next_batch
|
||||
|
||||
async def update(self) -> None:
|
||||
q = """
|
||||
UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, enabled=$5,
|
||||
next_batch=$6, filter_id=$7, sync=$8, autojoin=$9, online=$10,
|
||||
displayname=$11, avatar_url=$12
|
||||
WHERE id=$1
|
||||
"""
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def delete(self) -> None:
|
||||
await self.db.execute("DELETE FROM client WHERE id=$1", self.id)
|
75
maubot/db/instance.py
Normal file
75
maubot/db/instance.py
Normal file
@ -0,0 +1,75 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.types import UserID
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Instance:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
id: str
|
||||
type: str
|
||||
enabled: bool
|
||||
primary_user: UserID
|
||||
config_str: str
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record | None) -> Instance | None:
|
||||
if row is None:
|
||||
return None
|
||||
return cls(**row)
|
||||
|
||||
@classmethod
|
||||
async def all(cls) -> list[Instance]:
|
||||
rows = await cls.db.fetch("SELECT id, type, enabled, primary_user, config FROM instance")
|
||||
return [cls._from_row(row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
async def get(cls, id: str) -> Instance | None:
|
||||
q = "SELECT id, type, enabled, primary_user, config FROM instance WHERE id=$1"
|
||||
return cls._from_row(await cls.db.fetchrow(q, id))
|
||||
|
||||
async def update_id(self, new_id: str) -> None:
|
||||
await self.db.execute("UPDATE instance SET id=$1 WHERE id=$2", new_id, self.id)
|
||||
self.id = new_id
|
||||
|
||||
@property
|
||||
def _values(self):
|
||||
return self.id, self.type, self.enabled, self.primary_user, self.config_str
|
||||
|
||||
async def insert(self) -> None:
|
||||
q = (
|
||||
"INSERT INTO instance (id, type, enabled, primary_user, config) "
|
||||
"VALUES ($1, $2, $3, $4, $5)"
|
||||
)
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def update(self) -> None:
|
||||
q = "UPDATE instance SET type=$2, enabled=$3, primary_user=$4, config=$5 WHERE id=$1"
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def delete(self) -> None:
|
||||
await self.db.execute("DELETE FROM instance WHERE id=$1", self.id)
|
5
maubot/db/upgrade/__init__.py
Normal file
5
maubot/db/upgrade/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from mautrix.util.async_db import UpgradeTable
|
||||
|
||||
upgrade_table = UpgradeTable()
|
||||
|
||||
from . import v01_initial_revision
|
136
maubot/db/upgrade/v01_initial_revision.py
Normal file
136
maubot/db/upgrade/v01_initial_revision.py
Normal file
@ -0,0 +1,136 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from mautrix.util.async_db import Connection, Scheme
|
||||
|
||||
from . import upgrade_table
|
||||
|
||||
legacy_version_query = "SELECT version_num FROM alembic_version"
|
||||
last_legacy_version = "90aa88820eab"
|
||||
|
||||
|
||||
@upgrade_table.register(description="Initial asyncpg revision")
|
||||
async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
|
||||
if await conn.table_exists("alembic_version"):
|
||||
await migrate_legacy_to_v1(conn, scheme)
|
||||
else:
|
||||
return await create_v1_tables(conn)
|
||||
|
||||
|
||||
async def create_v1_tables(conn: Connection) -> None:
|
||||
await conn.execute(
|
||||
"""CREATE TABLE client (
|
||||
id TEXT PRIMARY KEY,
|
||||
homeserver TEXT NOT NULL,
|
||||
access_token TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
enabled BOOLEAN NOT NULL,
|
||||
|
||||
next_batch TEXT NOT NULL,
|
||||
filter_id TEXT NOT NULL,
|
||||
|
||||
sync BOOLEAN NOT NULL,
|
||||
autojoin BOOLEAN NOT NULL,
|
||||
online BOOLEAN NOT NULL,
|
||||
|
||||
displayname TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE instance (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL,
|
||||
enabled BOOLEAN NOT NULL,
|
||||
primary_user TEXT NOT NULL,
|
||||
config TEXT NOT NULL,
|
||||
FOREIGN KEY (primary_user) REFERENCES client(id) ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
)"""
|
||||
)
|
||||
|
||||
|
||||
async def migrate_legacy_to_v1(conn: Connection, scheme: Scheme) -> None:
|
||||
legacy_version = await conn.fetchval(legacy_version_query)
|
||||
if legacy_version != last_legacy_version:
|
||||
raise RuntimeError(
|
||||
"Legacy database is not on last version. "
|
||||
"Please upgrade the old database with alembic or drop it completely first."
|
||||
)
|
||||
await conn.execute("ALTER TABLE plugin RENAME TO instance")
|
||||
await update_state_store(conn, scheme)
|
||||
if scheme != Scheme.SQLITE:
|
||||
await varchar_to_text(conn)
|
||||
await conn.execute("DROP TABLE alembic_version")
|
||||
|
||||
|
||||
async def update_state_store(conn: Connection, scheme: Scheme) -> None:
|
||||
# The Matrix state store already has more or less the correct schema, so set the version
|
||||
await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)")
|
||||
await conn.execute("INSERT INTO mx_version (version) VALUES (2)")
|
||||
if scheme != Scheme.SQLITE:
|
||||
# Remove old uppercase membership type and recreate it as lowercase
|
||||
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT")
|
||||
await conn.execute("DROP TYPE IF EXISTS membership")
|
||||
await conn.execute(
|
||||
"CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')"
|
||||
)
|
||||
await conn.execute(
|
||||
"ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership "
|
||||
"USING LOWER(membership)::membership"
|
||||
)
|
||||
else:
|
||||
# Recreate table to remove CHECK constraint and lowercase everything
|
||||
await conn.execute(
|
||||
"""CREATE TABLE new_mx_user_profile (
|
||||
room_id TEXT,
|
||||
user_id TEXT,
|
||||
membership TEXT NOT NULL
|
||||
CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock')),
|
||||
displayname TEXT,
|
||||
avatar_url TEXT,
|
||||
PRIMARY KEY (room_id, user_id)
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO new_mx_user_profile (room_id, user_id, membership, displayname, avatar_url)
|
||||
SELECT room_id, user_id, LOWER(membership), displayname, avatar_url
|
||||
FROM mx_user_profile
|
||||
"""
|
||||
)
|
||||
await conn.execute("DROP TABLE mx_user_profile")
|
||||
await conn.execute("ALTER TABLE new_mx_user_profile RENAME TO mx_user_profile")
|
||||
|
||||
|
||||
async def varchar_to_text(conn: Connection) -> None:
|
||||
columns_to_adjust = {
|
||||
"client": (
|
||||
"id",
|
||||
"homeserver",
|
||||
"device_id",
|
||||
"next_batch",
|
||||
"filter_id",
|
||||
"displayname",
|
||||
"avatar_url",
|
||||
),
|
||||
"instance": ("id", "type", "primary_user"),
|
||||
"mx_room_state": ("room_id",),
|
||||
"mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"),
|
||||
}
|
||||
for table, columns in columns_to_adjust.items():
|
||||
for column in columns:
|
||||
await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT')
|
@ -6,9 +6,7 @@
|
||||
database: sqlite:///maubot.db
|
||||
|
||||
# Separate database URL for the crypto database. "default" means use the same database as above.
|
||||
# Due to concurrency issues, you should use a separate file when using SQLite rather than the same as above.
|
||||
# When using postgres, using the same database for both is safe.
|
||||
crypto_database: sqlite:///crypto.db
|
||||
crypto_database: default
|
||||
|
||||
plugin_directories:
|
||||
# The directory where uploaded new plugins should be stored.
|
||||
|
@ -15,8 +15,10 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
from asyncio import AbstractEventLoop
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
import inspect
|
||||
import io
|
||||
import logging
|
||||
import os.path
|
||||
@ -26,16 +28,17 @@ from ruamel.yaml.comments import CommentedMap
|
||||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.types import UserID
|
||||
from mautrix.util.async_getter_lock import async_getter_lock
|
||||
from mautrix.util.config import BaseProxyConfig, RecursiveDict
|
||||
|
||||
from .client import Client
|
||||
from .config import Config
|
||||
from .db import DBPlugin
|
||||
from .db import Instance as DBInstance
|
||||
from .loader import PluginLoader, ZippedPluginLoader
|
||||
from .plugin_base import Plugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import MaubotServer, PluginWebApp
|
||||
from .__main__ import Maubot
|
||||
from .server import PluginWebApp
|
||||
|
||||
log = logging.getLogger("maubot.instance")
|
||||
|
||||
@ -44,29 +47,42 @@ yaml.indent(4)
|
||||
yaml.width = 200
|
||||
|
||||
|
||||
class PluginInstance:
|
||||
webserver: MaubotServer = None
|
||||
mb_config: Config = None
|
||||
loop: AbstractEventLoop = None
|
||||
class PluginInstance(DBInstance):
|
||||
maubot: "Maubot" = None
|
||||
cache: dict[str, PluginInstance] = {}
|
||||
plugin_directories: list[str] = []
|
||||
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
|
||||
|
||||
log: logging.Logger
|
||||
loader: PluginLoader
|
||||
client: Client
|
||||
plugin: Plugin
|
||||
config: BaseProxyConfig
|
||||
loader: PluginLoader | None
|
||||
client: Client | None
|
||||
plugin: Plugin | None
|
||||
config: BaseProxyConfig | None
|
||||
base_cfg: RecursiveDict[CommentedMap] | None
|
||||
base_cfg_str: str | None
|
||||
inst_db: sql.engine.Engine
|
||||
inst_db_tables: dict[str, sql.Table]
|
||||
inst_db: sql.engine.Engine | None
|
||||
inst_db_tables: dict[str, sql.Table] | None
|
||||
inst_webapp: PluginWebApp | None
|
||||
inst_webapp_url: str | None
|
||||
started: bool
|
||||
|
||||
def __init__(self, db_instance: DBPlugin):
|
||||
self.db_instance = db_instance
|
||||
def __init__(
|
||||
self, id: str, type: str, enabled: bool, primary_user: UserID, config: str = ""
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id, type=type, enabled=bool(enabled), primary_user=primary_user, config_str=config
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
@classmethod
|
||||
def init_cls(cls, maubot: "Maubot") -> None:
|
||||
cls.maubot = maubot
|
||||
|
||||
def postinit(self) -> None:
|
||||
self.log = log.getChild(self.id)
|
||||
self.cache[self.id] = self
|
||||
self.config = None
|
||||
self.started = False
|
||||
self.loader = None
|
||||
@ -78,7 +94,6 @@ class PluginInstance:
|
||||
self.inst_webapp_url = None
|
||||
self.base_cfg = None
|
||||
self.base_cfg_str = None
|
||||
self.cache[self.id] = self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
@ -87,10 +102,10 @@ class PluginInstance:
|
||||
"enabled": self.enabled,
|
||||
"started": self.started,
|
||||
"primary_user": self.primary_user,
|
||||
"config": self.db_instance.config,
|
||||
"config": self.config_str,
|
||||
"base_config": self.base_cfg_str,
|
||||
"database": (
|
||||
self.inst_db is not None and self.mb_config["api_features.instance_database"]
|
||||
self.inst_db is not None and self.maubot.config["api_features.instance_database"]
|
||||
),
|
||||
}
|
||||
|
||||
@ -101,19 +116,19 @@ class PluginInstance:
|
||||
self.inst_db_tables = metadata.tables
|
||||
return self.inst_db_tables
|
||||
|
||||
def load(self) -> bool:
|
||||
async def load(self) -> bool:
|
||||
if not self.loader:
|
||||
try:
|
||||
self.loader = PluginLoader.find(self.type)
|
||||
except KeyError:
|
||||
self.log.error(f"Failed to find loader for type {self.type}")
|
||||
self.db_instance.enabled = False
|
||||
await self.update_enabled(False)
|
||||
return False
|
||||
if not self.client:
|
||||
self.client = Client.get(self.primary_user)
|
||||
self.client = await Client.get(self.primary_user)
|
||||
if not self.client:
|
||||
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||
self.db_instance.enabled = False
|
||||
await self.update_enabled(False)
|
||||
return False
|
||||
if self.loader.meta.database:
|
||||
self.enable_database()
|
||||
@ -125,18 +140,18 @@ class PluginInstance:
|
||||
return True
|
||||
|
||||
def enable_webapp(self) -> None:
|
||||
self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id)
|
||||
self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id)
|
||||
|
||||
def disable_webapp(self) -> None:
|
||||
self.webserver.remove_instance_webapp(self.id)
|
||||
self.maubot.server.remove_instance_webapp(self.id)
|
||||
self.inst_webapp = None
|
||||
self.inst_webapp_url = None
|
||||
|
||||
def enable_database(self) -> None:
|
||||
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
|
||||
db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
|
||||
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
|
||||
|
||||
def delete(self) -> None:
|
||||
async def delete(self) -> None:
|
||||
if self.loader is not None:
|
||||
self.loader.references.remove(self)
|
||||
if self.client is not None:
|
||||
@ -145,23 +160,23 @@ class PluginInstance:
|
||||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.db_instance.delete()
|
||||
await super().delete()
|
||||
if self.inst_db:
|
||||
self.inst_db.dispose()
|
||||
ZippedPluginLoader.trash(
|
||||
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
|
||||
os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
|
||||
reason="deleted",
|
||||
)
|
||||
if self.inst_webapp:
|
||||
self.disable_webapp()
|
||||
|
||||
def load_config(self) -> CommentedMap:
|
||||
return yaml.load(self.db_instance.config)
|
||||
return yaml.load(self.config_str)
|
||||
|
||||
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
|
||||
buf = io.StringIO()
|
||||
yaml.dump(data, buf)
|
||||
self.db_instance.config = buf.getvalue()
|
||||
self.config_str = buf.getvalue()
|
||||
|
||||
async def start(self) -> None:
|
||||
if self.started:
|
||||
@ -172,7 +187,7 @@ class PluginInstance:
|
||||
return
|
||||
if not self.client or not self.loader:
|
||||
self.log.warning("Missing plugin instance dependencies, attempting to load...")
|
||||
if not self.load():
|
||||
if not await self.load():
|
||||
return
|
||||
cls = await self.loader.load()
|
||||
if self.loader.meta.webapp and self.inst_webapp is None:
|
||||
@ -205,7 +220,7 @@ class PluginInstance:
|
||||
self.config = config_class(self.load_config, base_cfg_func, self.save_config)
|
||||
self.plugin = cls(
|
||||
client=self.client.client,
|
||||
loop=self.loop,
|
||||
loop=self.maubot.loop,
|
||||
http=self.client.http_client,
|
||||
instance_id=self.id,
|
||||
log=self.log,
|
||||
@ -219,7 +234,7 @@ class PluginInstance:
|
||||
await self.plugin.internal_start()
|
||||
except Exception:
|
||||
self.log.exception("Failed to start instance")
|
||||
self.db_instance.enabled = False
|
||||
await self.update_enabled(False)
|
||||
return
|
||||
self.started = True
|
||||
self.inst_db_tables = None
|
||||
@ -241,60 +256,51 @@ class PluginInstance:
|
||||
self.plugin = None
|
||||
self.inst_db_tables = None
|
||||
|
||||
@classmethod
|
||||
def get(cls, instance_id: str, db_instance: DBPlugin | None = None) -> PluginInstance | None:
|
||||
try:
|
||||
return cls.cache[instance_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBPlugin.get(instance_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return PluginInstance(db_instance)
|
||||
async def update_id(self, new_id: str | None) -> None:
|
||||
if new_id is not None and new_id.lower() != self.id:
|
||||
await super().update_id(new_id.lower())
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable[PluginInstance]:
|
||||
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
|
||||
|
||||
def update_id(self, new_id: str) -> None:
|
||||
if new_id is not None and new_id != self.id:
|
||||
self.db_instance.id = new_id.lower()
|
||||
|
||||
def update_config(self, config: str) -> None:
|
||||
if not config or self.db_instance.config == config:
|
||||
async def update_config(self, config: str | None) -> None:
|
||||
if config is None or self.config_str == config:
|
||||
return
|
||||
self.db_instance.config = config
|
||||
self.config_str = config
|
||||
if self.started and self.plugin is not None:
|
||||
self.plugin.on_external_config_update()
|
||||
res = self.plugin.on_external_config_update()
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
await self.update()
|
||||
|
||||
async def update_primary_user(self, primary_user: UserID) -> bool:
|
||||
if not primary_user or primary_user == self.primary_user:
|
||||
async def update_primary_user(self, primary_user: UserID | None) -> bool:
|
||||
if primary_user is None or primary_user == self.primary_user:
|
||||
return True
|
||||
client = Client.get(primary_user)
|
||||
client = await Client.get(primary_user)
|
||||
if not client:
|
||||
return False
|
||||
await self.stop()
|
||||
self.db_instance.primary_user = client.id
|
||||
self.primary_user = client.id
|
||||
if self.client:
|
||||
self.client.references.remove(self)
|
||||
self.client = client
|
||||
self.client.references.add(self)
|
||||
await self.update()
|
||||
await self.start()
|
||||
self.log.debug(f"Primary user switched to {self.client.id}")
|
||||
return True
|
||||
|
||||
async def update_type(self, type: str) -> bool:
|
||||
if not type or type == self.type:
|
||||
async def update_type(self, type: str | None) -> bool:
|
||||
if type is None or type == self.type:
|
||||
return True
|
||||
try:
|
||||
loader = PluginLoader.find(type)
|
||||
except KeyError:
|
||||
return False
|
||||
await self.stop()
|
||||
self.db_instance.type = loader.meta.id
|
||||
self.type = loader.meta.id
|
||||
if self.loader:
|
||||
self.loader.references.remove(self)
|
||||
self.loader = loader
|
||||
self.loader.references.add(self)
|
||||
await self.update()
|
||||
await self.start()
|
||||
self.log.debug(f"Type switched to {self.loader.meta.id}")
|
||||
return True
|
||||
@ -303,39 +309,41 @@ class PluginInstance:
|
||||
if started is not None and started != self.started:
|
||||
await (self.start() if started else self.stop())
|
||||
|
||||
def update_enabled(self, enabled: bool) -> None:
|
||||
async def update_enabled(self, enabled: bool) -> None:
|
||||
if enabled is not None and enabled != self.enabled:
|
||||
self.db_instance.enabled = enabled
|
||||
self.enabled = enabled
|
||||
await self.update()
|
||||
|
||||
# region Properties
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get(
|
||||
cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None
|
||||
) -> PluginInstance | None:
|
||||
try:
|
||||
return cls.cache[instance_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.db_instance.id
|
||||
instance = cast(cls, await super().get(instance_id))
|
||||
if instance is not None:
|
||||
instance.postinit()
|
||||
return instance
|
||||
|
||||
@id.setter
|
||||
def id(self, value: str) -> None:
|
||||
self.db_instance.id = value
|
||||
if type and primary_user:
|
||||
instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user)
|
||||
await instance.insert()
|
||||
instance.postinit()
|
||||
return instance
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.db_instance.type
|
||||
return None
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.db_instance.enabled
|
||||
|
||||
@property
|
||||
def primary_user(self) -> UserID:
|
||||
return self.db_instance.primary_user
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def init(
|
||||
config: Config, webserver: MaubotServer, loop: AbstractEventLoop
|
||||
) -> Iterable[PluginInstance]:
|
||||
PluginInstance.mb_config = config
|
||||
PluginInstance.loop = loop
|
||||
PluginInstance.webserver = webserver
|
||||
return PluginInstance.all()
|
||||
@classmethod
|
||||
async def all(cls) -> AsyncGenerator[PluginInstance, None]:
|
||||
instances = await super().all()
|
||||
instance: PluginInstance
|
||||
for instance in instances:
|
||||
try:
|
||||
yield cls.cache[instance.id]
|
||||
except KeyError:
|
||||
instance.postinit()
|
||||
yield instance
|
||||
|
@ -28,14 +28,19 @@ LOADER_COLOR = PREFIX + "36m" # blue
|
||||
class ColorFormatter(BaseColorFormatter):
|
||||
def _color_name(self, module: str) -> str:
|
||||
client = "maubot.client"
|
||||
if module.startswith(client):
|
||||
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}"
|
||||
if module.startswith(client + "."):
|
||||
suffix = ""
|
||||
if module.endswith(".crypto"):
|
||||
suffix = f".{MAU_COLOR}crypto{RESET}"
|
||||
module = module[: -len(".crypto")]
|
||||
module = module[len(client) + 1 :]
|
||||
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module}{RESET}{suffix}"
|
||||
instance = "maubot.instance"
|
||||
if module.startswith(instance):
|
||||
if module.startswith(instance + "."):
|
||||
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
|
||||
loader = "maubot.loader"
|
||||
if module.startswith(loader):
|
||||
if module.startswith(loader + "."):
|
||||
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
|
||||
if module.startswith("maubot"):
|
||||
if module.startswith("maubot."):
|
||||
return f"{MAU_COLOR}{module}{RESET}"
|
||||
return super()._color_name(module)
|
||||
|
@ -13,16 +13,15 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from mautrix.client import SyncStore
|
||||
from mautrix.types import SyncToken
|
||||
from mautrix.client.state_store.asyncpg import PgStateStore as BasePgStateStore
|
||||
|
||||
try:
|
||||
from mautrix.crypto import StateStore as CryptoStateStore
|
||||
|
||||
class SyncStoreProxy(SyncStore):
|
||||
def __init__(self, db_instance) -> None:
|
||||
self.db_instance = db_instance
|
||||
class PgStateStore(BasePgStateStore, CryptoStateStore):
|
||||
pass
|
||||
|
||||
async def put_next_batch(self, next_batch: SyncToken) -> None:
|
||||
self.db_instance.edit(next_batch=next_batch)
|
||||
except ImportError as e:
|
||||
PgStateStore = BasePgStateStore
|
||||
|
||||
async def get_next_batch(self) -> SyncToken:
|
||||
return self.db_instance.next_batch
|
||||
__all__ = ["PgStateStore"]
|
@ -13,17 +13,14 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import TYPE_CHECKING, Dict, List, Set, Type, TypeVar
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
|
||||
from attr import dataclass
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
|
||||
|
||||
from ..__meta__ import __version__
|
||||
from ..plugin_base import Plugin
|
||||
from .meta import PluginMeta
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..instance import PluginInstance
|
||||
@ -35,36 +32,6 @@ class IDConflictError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@serializer(Version)
|
||||
def serialize_version(version: Version) -> str:
|
||||
return str(version)
|
||||
|
||||
|
||||
@deserializer(Version)
|
||||
def deserialize_version(version: str) -> Version:
|
||||
try:
|
||||
return Version(version)
|
||||
except InvalidVersion as e:
|
||||
raise SerializerError("Invalid version") from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginMeta(SerializableAttrs):
|
||||
id: str
|
||||
version: Version
|
||||
modules: List[str]
|
||||
main_class: str
|
||||
|
||||
maubot: Version = Version(__version__)
|
||||
database: bool = False
|
||||
config: bool = False
|
||||
webapp: bool = False
|
||||
license: str = ""
|
||||
extra_files: List[str] = []
|
||||
dependencies: List[str] = []
|
||||
soft_dependencies: List[str] = []
|
||||
|
||||
|
||||
class BasePluginLoader(ABC):
|
||||
meta: PluginMeta
|
||||
|
||||
@ -80,25 +47,25 @@ class BasePluginLoader(ABC):
|
||||
async def read_file(self, path: str) -> bytes:
|
||||
pass
|
||||
|
||||
def sync_list_files(self, directory: str) -> List[str]:
|
||||
def sync_list_files(self, directory: str) -> list[str]:
|
||||
raise NotImplementedError("This loader doesn't support synchronous operations")
|
||||
|
||||
@abstractmethod
|
||||
async def list_files(self, directory: str) -> List[str]:
|
||||
async def list_files(self, directory: str) -> list[str]:
|
||||
pass
|
||||
|
||||
|
||||
class PluginLoader(BasePluginLoader, ABC):
|
||||
id_cache: Dict[str, "PluginLoader"] = {}
|
||||
id_cache: dict[str, PluginLoader] = {}
|
||||
|
||||
meta: PluginMeta
|
||||
references: Set["PluginInstance"]
|
||||
references: set[PluginInstance]
|
||||
|
||||
def __init__(self):
|
||||
self.references = set()
|
||||
|
||||
@classmethod
|
||||
def find(cls, plugin_id: str) -> "PluginLoader":
|
||||
def find(cls, plugin_id: str) -> PluginLoader:
|
||||
return cls.id_cache[plugin_id]
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
@ -119,11 +86,11 @@ class PluginLoader(BasePluginLoader, ABC):
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def load(self) -> Type[PluginClass]:
|
||||
async def load(self) -> type[PluginClass]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def reload(self) -> Type[PluginClass]:
|
||||
async def reload(self) -> type[PluginClass]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
53
maubot/loader/meta.py
Normal file
53
maubot/loader/meta.py
Normal file
@ -0,0 +1,53 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import List
|
||||
|
||||
from attr import dataclass
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
|
||||
|
||||
from ..__meta__ import __version__
|
||||
|
||||
|
||||
@serializer(Version)
|
||||
def serialize_version(version: Version) -> str:
|
||||
return str(version)
|
||||
|
||||
|
||||
@deserializer(Version)
|
||||
def deserialize_version(version: str) -> Version:
|
||||
try:
|
||||
return Version(version)
|
||||
except InvalidVersion as e:
|
||||
raise SerializerError("Invalid version") from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginMeta(SerializableAttrs):
|
||||
id: str
|
||||
version: Version
|
||||
modules: List[str]
|
||||
main_class: str
|
||||
|
||||
maubot: Version = Version(__version__)
|
||||
database: bool = False
|
||||
config: bool = False
|
||||
webapp: bool = False
|
||||
license: str = ""
|
||||
extra_files: List[str] = []
|
||||
dependencies: List[str] = []
|
||||
soft_dependencies: List[str] = []
|
@ -29,7 +29,8 @@ from mautrix.types import SerializerError
|
||||
from ..config import Config
|
||||
from ..lib.zipimport import ZipImportError, zipimporter
|
||||
from ..plugin_base import Plugin
|
||||
from .abc import IDConflictError, PluginClass, PluginLoader, PluginMeta
|
||||
from .abc import IDConflictError, PluginClass, PluginLoader
|
||||
from .meta import PluginMeta
|
||||
|
||||
yaml = YAML()
|
||||
|
||||
|
@ -20,7 +20,7 @@ from aiohttp import web
|
||||
|
||||
from ...config import Config
|
||||
from .auth import check_token
|
||||
from .base import get_config, routes, set_config, set_loop
|
||||
from .base import get_config, routes, set_config
|
||||
from .middleware import auth, error
|
||||
|
||||
|
||||
@ -40,7 +40,6 @@ def features(request: web.Request) -> web.Response:
|
||||
|
||||
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
|
||||
set_config(cfg)
|
||||
set_loop(loop)
|
||||
for pkg, enabled in cfg["api_features"].items():
|
||||
if enabled:
|
||||
importlib.import_module(f"maubot.management.api.{pkg}")
|
||||
|
@ -46,7 +46,7 @@ def create_token(user: UserID) -> str:
|
||||
def get_token(request: web.Request) -> str:
|
||||
token = request.headers.get("Authorization", "")
|
||||
if not token or not token.startswith("Bearer "):
|
||||
token = request.query.get("access_token", None)
|
||||
token = request.query.get("access_token", "")
|
||||
else:
|
||||
token = token[len("Bearer ") :]
|
||||
return token
|
||||
|
@ -24,7 +24,6 @@ from ...config import Config
|
||||
|
||||
routes: web.RouteTableDef = web.RouteTableDef()
|
||||
_config: Config | None = None
|
||||
_loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
|
||||
def set_config(config: Config) -> None:
|
||||
@ -36,15 +35,6 @@ def get_config() -> Config:
|
||||
return _config
|
||||
|
||||
|
||||
def set_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
global _loop
|
||||
_loop = loop
|
||||
|
||||
|
||||
def get_loop() -> asyncio.AbstractEventLoop:
|
||||
return _loop
|
||||
|
||||
|
||||
@routes.get("/version")
|
||||
async def version(_: web.Request) -> web.Response:
|
||||
return web.json_response({"version": __version__})
|
||||
|
@ -24,7 +24,6 @@ from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequ
|
||||
from mautrix.types import FilterID, SyncToken, UserID
|
||||
|
||||
from ...client import Client
|
||||
from ...db import DBClient
|
||||
from .base import routes
|
||||
from .responses import resp
|
||||
|
||||
@ -37,7 +36,7 @@ async def get_clients(_: web.Request) -> web.Response:
|
||||
@routes.get("/client/{id}")
|
||||
async def get_client(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
client = await Client.get(user_id)
|
||||
if not client:
|
||||
return resp.client_not_found
|
||||
return resp.found(client.to_dict())
|
||||
@ -51,7 +50,6 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
|
||||
mxid="@not:a.mxid",
|
||||
base_url=homeserver,
|
||||
token=access_token,
|
||||
loop=Client.loop,
|
||||
client_session=Client.http_client,
|
||||
)
|
||||
try:
|
||||
@ -63,29 +61,23 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
|
||||
except MatrixConnectionError:
|
||||
return resp.bad_client_connection_details
|
||||
if user_id is None:
|
||||
existing_client = Client.get(whoami.user_id, None)
|
||||
existing_client = await Client.get(whoami.user_id)
|
||||
if existing_client is not None:
|
||||
return resp.user_exists
|
||||
elif whoami.user_id != user_id:
|
||||
return resp.mxid_mismatch(whoami.user_id)
|
||||
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
||||
return resp.device_id_mismatch(whoami.device_id)
|
||||
db_instance = DBClient(
|
||||
id=whoami.user_id,
|
||||
homeserver=homeserver,
|
||||
access_token=access_token,
|
||||
enabled=data.get("enabled", True),
|
||||
next_batch=SyncToken(""),
|
||||
filter_id=FilterID(""),
|
||||
sync=data.get("sync", True),
|
||||
autojoin=data.get("autojoin", True),
|
||||
online=data.get("online", True),
|
||||
displayname=data.get("displayname", "disable"),
|
||||
avatar_url=data.get("avatar_url", "disable"),
|
||||
device_id=device_id,
|
||||
client = await Client.get(
|
||||
whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id
|
||||
)
|
||||
client = Client(db_instance)
|
||||
client.db_instance.insert()
|
||||
client.enabled = data.get("enabled", True)
|
||||
client.sync = data.get("sync", True)
|
||||
client.autojoin = data.get("autojoin", True)
|
||||
client.online = data.get("online", True)
|
||||
client.displayname = data.get("displayname", "disable")
|
||||
client.avatar_url = data.get("avatar_url", "disable")
|
||||
await client.update()
|
||||
await client.start()
|
||||
return resp.created(client.to_dict())
|
||||
|
||||
@ -93,9 +85,7 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
|
||||
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
|
||||
try:
|
||||
await client.update_access_details(
|
||||
data.get("access_token", None),
|
||||
data.get("homeserver", None),
|
||||
data.get("device_id", None),
|
||||
data.get("access_token"), data.get("homeserver"), data.get("device_id")
|
||||
)
|
||||
except MatrixInvalidToken:
|
||||
return resp.bad_client_access_token
|
||||
@ -109,21 +99,21 @@ async def _update_client(client: Client, data: dict, is_login: bool = False) ->
|
||||
return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :])
|
||||
elif str_err.startswith("Device ID mismatch"):
|
||||
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :])
|
||||
with client.db_instance.edit_mode():
|
||||
await client.update_avatar_url(data.get("avatar_url", None))
|
||||
await client.update_displayname(data.get("displayname", None))
|
||||
await client.update_started(data.get("started", None))
|
||||
client.enabled = data.get("enabled", client.enabled)
|
||||
client.autojoin = data.get("autojoin", client.autojoin)
|
||||
client.online = data.get("online", client.online)
|
||||
client.sync = data.get("sync", client.sync)
|
||||
return resp.updated(client.to_dict(), is_login=is_login)
|
||||
await client.update_avatar_url(data.get("avatar_url"), save=False)
|
||||
await client.update_displayname(data.get("displayname"), save=False)
|
||||
await client.update_started(data.get("started"))
|
||||
await client.update_enabled(data.get("enabled"), save=False)
|
||||
await client.update_autojoin(data.get("autojoin"), save=False)
|
||||
await client.update_online(data.get("online"), save=False)
|
||||
await client.update_sync(data.get("sync"), save=False)
|
||||
await client.update()
|
||||
return resp.updated(client.to_dict(), is_login=is_login)
|
||||
|
||||
|
||||
async def _create_or_update_client(
|
||||
user_id: UserID, data: dict, is_login: bool = False
|
||||
) -> web.Response:
|
||||
client = Client.get(user_id, None)
|
||||
client = await Client.get(user_id)
|
||||
if not client:
|
||||
return await _create_client(user_id, data)
|
||||
else:
|
||||
@ -141,7 +131,7 @@ async def create_client(request: web.Request) -> web.Response:
|
||||
|
||||
@routes.put("/client/{id}")
|
||||
async def update_client(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
user_id = request.match_info["id"]
|
||||
try:
|
||||
data = await request.json()
|
||||
except JSONDecodeError:
|
||||
@ -151,23 +141,23 @@ async def update_client(request: web.Request) -> web.Response:
|
||||
|
||||
@routes.delete("/client/{id}")
|
||||
async def delete_client(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
user_id = request.match_info["id"]
|
||||
client = await Client.get(user_id)
|
||||
if not client:
|
||||
return resp.client_not_found
|
||||
if len(client.references) > 0:
|
||||
return resp.client_in_use
|
||||
if client.started:
|
||||
await client.stop()
|
||||
client.delete()
|
||||
await client.delete()
|
||||
return resp.deleted
|
||||
|
||||
|
||||
@routes.post("/client/{id}/clearcache")
|
||||
async def clear_client_cache(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
user_id = request.match_info["id"]
|
||||
client = await Client.get(user_id)
|
||||
if not client:
|
||||
return resp.client_not_found
|
||||
client.clear_cache()
|
||||
await client.clear_cache()
|
||||
return resp.ok
|
||||
|
@ -13,7 +13,9 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, NamedTuple, Optional, Tuple
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
from http import HTTPStatus
|
||||
from json import JSONDecodeError
|
||||
import asyncio
|
||||
@ -30,12 +32,12 @@ from mautrix.client import ClientAPI
|
||||
from mautrix.errors import MatrixRequestError
|
||||
from mautrix.types import LoginResponse, LoginType
|
||||
|
||||
from .base import get_config, get_loop, routes
|
||||
from .base import get_config, routes
|
||||
from .client import _create_client, _create_or_update_client
|
||||
from .responses import resp
|
||||
|
||||
|
||||
def known_homeservers() -> Dict[str, Dict[str, str]]:
|
||||
def known_homeservers() -> dict[str, dict[str, str]]:
|
||||
return get_config()["homeservers"]
|
||||
|
||||
|
||||
@ -61,7 +63,7 @@ truthy_strings = ("1", "true", "yes")
|
||||
|
||||
async def read_client_auth_request(
|
||||
request: web.Request,
|
||||
) -> Tuple[Optional[AuthRequestInfo], Optional[web.Response]]:
|
||||
) -> tuple[AuthRequestInfo | None, web.Response | None]:
|
||||
server_name = request.match_info.get("server", None)
|
||||
server = known_homeservers().get(server_name, None)
|
||||
if not server:
|
||||
@ -85,7 +87,7 @@ async def read_client_auth_request(
|
||||
return (
|
||||
AuthRequestInfo(
|
||||
server_name=server_name,
|
||||
client=ClientAPI(base_url=base_url, loop=get_loop()),
|
||||
client=ClientAPI(base_url=base_url),
|
||||
secret=server.get("secret"),
|
||||
username=username,
|
||||
password=password,
|
||||
@ -189,11 +191,11 @@ async def _do_sso(req: AuthRequestInfo) -> web.Response:
|
||||
sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query(
|
||||
{"redirectUrl": str(public_url)}
|
||||
)
|
||||
sso_waiters[waiter_id] = req, get_loop().create_future()
|
||||
sso_waiters[waiter_id] = req, asyncio.get_running_loop().create_future()
|
||||
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
|
||||
|
||||
|
||||
async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response:
|
||||
async def _do_login(req: AuthRequestInfo, login_token: str | None = None) -> web.Response:
|
||||
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
|
||||
device_id = f"maubot_{device_id}"
|
||||
try:
|
||||
@ -235,7 +237,7 @@ async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) ->
|
||||
return web.json_response(res.serialize())
|
||||
|
||||
|
||||
sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {}
|
||||
sso_waiters: dict[str, tuple[AuthRequestInfo, asyncio.Future]] = {}
|
||||
|
||||
|
||||
@routes.post("/client/auth/{server}/sso/{id}/wait")
|
||||
|
@ -25,7 +25,7 @@ PROXY_CHUNK_SIZE = 32 * 1024
|
||||
@routes.view("/proxy/{id}/{path:_matrix/.+}")
|
||||
async def proxy(request: web.Request) -> web.StreamResponse:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
client = await Client.get(user_id)
|
||||
if not client:
|
||||
return resp.client_not_found
|
||||
|
||||
|
@ -18,7 +18,6 @@ from json import JSONDecodeError
|
||||
from aiohttp import web
|
||||
|
||||
from ...client import Client
|
||||
from ...db import DBPlugin
|
||||
from ...instance import PluginInstance
|
||||
from ...loader import PluginLoader
|
||||
from .base import routes
|
||||
@ -32,56 +31,49 @@ async def get_instances(_: web.Request) -> web.Response:
|
||||
|
||||
@routes.get("/instance/{id}")
|
||||
async def get_instance(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "").lower()
|
||||
instance = PluginInstance.get(instance_id, None)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
if not instance:
|
||||
return resp.instance_not_found
|
||||
return resp.found(instance.to_dict())
|
||||
|
||||
|
||||
async def _create_instance(instance_id: str, data: dict) -> web.Response:
|
||||
plugin_type = data.get("type", None)
|
||||
primary_user = data.get("primary_user", None)
|
||||
plugin_type = data.get("type")
|
||||
primary_user = data.get("primary_user")
|
||||
if not plugin_type:
|
||||
return resp.plugin_type_required
|
||||
elif not primary_user:
|
||||
return resp.primary_user_required
|
||||
elif not Client.get(primary_user):
|
||||
elif not await Client.get(primary_user):
|
||||
return resp.primary_user_not_found
|
||||
try:
|
||||
PluginLoader.find(plugin_type)
|
||||
except KeyError:
|
||||
return resp.plugin_type_not_found
|
||||
db_instance = DBPlugin(
|
||||
id=instance_id,
|
||||
type=plugin_type,
|
||||
enabled=data.get("enabled", True),
|
||||
primary_user=primary_user,
|
||||
config=data.get("config", ""),
|
||||
)
|
||||
instance = PluginInstance(db_instance)
|
||||
instance.load()
|
||||
instance.db_instance.insert()
|
||||
instance = await PluginInstance.get(instance_id, type=plugin_type, primary_user=primary_user)
|
||||
instance.enabled = data.get("enabled", True)
|
||||
instance.config_str = data.get("config") or ""
|
||||
await instance.update()
|
||||
await instance.start()
|
||||
return resp.created(instance.to_dict())
|
||||
|
||||
|
||||
async def _update_instance(instance: PluginInstance, data: dict) -> web.Response:
|
||||
if not await instance.update_primary_user(data.get("primary_user", None)):
|
||||
if not await instance.update_primary_user(data.get("primary_user")):
|
||||
return resp.primary_user_not_found
|
||||
with instance.db_instance.edit_mode():
|
||||
instance.update_id(data.get("id", None))
|
||||
instance.update_enabled(data.get("enabled", None))
|
||||
instance.update_config(data.get("config", None))
|
||||
await instance.update_started(data.get("started", None))
|
||||
await instance.update_type(data.get("type", None))
|
||||
return resp.updated(instance.to_dict())
|
||||
await instance.update_id(data.get("id"))
|
||||
await instance.update_enabled(data.get("enabled"))
|
||||
await instance.update_config(data.get("config"))
|
||||
await instance.update_started(data.get("started"))
|
||||
await instance.update_type(data.get("type"))
|
||||
return resp.updated(instance.to_dict())
|
||||
|
||||
|
||||
@routes.put("/instance/{id}")
|
||||
async def update_instance(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "").lower()
|
||||
instance = PluginInstance.get(instance_id, None)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
try:
|
||||
data = await request.json()
|
||||
except JSONDecodeError:
|
||||
@ -94,11 +86,11 @@ async def update_instance(request: web.Request) -> web.Response:
|
||||
|
||||
@routes.delete("/instance/{id}")
|
||||
async def delete_instance(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "").lower()
|
||||
instance = PluginInstance.get(instance_id)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
if not instance:
|
||||
return resp.instance_not_found
|
||||
if instance.started:
|
||||
await instance.stop()
|
||||
instance.delete()
|
||||
await instance.delete()
|
||||
return resp.deleted
|
||||
|
@ -29,8 +29,8 @@ from .responses import resp
|
||||
|
||||
@routes.get("/instance/{id}/database")
|
||||
async def get_database(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "")
|
||||
instance = PluginInstance.get(instance_id, None)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
if not instance:
|
||||
return resp.instance_not_found
|
||||
elif not instance.inst_db:
|
||||
@ -65,8 +65,8 @@ def check_type(val):
|
||||
|
||||
@routes.get("/instance/{id}/database/{table}")
|
||||
async def get_table(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "")
|
||||
instance = PluginInstance.get(instance_id, None)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
if not instance:
|
||||
return resp.instance_not_found
|
||||
elif not instance.inst_db:
|
||||
@ -86,14 +86,14 @@ async def get_table(request: web.Request) -> web.Response:
|
||||
]
|
||||
except KeyError:
|
||||
order = []
|
||||
limit = int(request.query.get("limit", 100))
|
||||
limit = int(request.query.get("limit", "100"))
|
||||
return execute_query(instance, table.select().order_by(*order).limit(limit))
|
||||
|
||||
|
||||
@routes.post("/instance/{id}/database/query")
|
||||
async def query(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "")
|
||||
instance = PluginInstance.get(instance_id, None)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
if not instance:
|
||||
return resp.instance_not_found
|
||||
elif not instance.inst_db:
|
||||
|
@ -23,7 +23,7 @@ import logging
|
||||
from aiohttp import web, web_ws
|
||||
|
||||
from .auth import is_valid_token
|
||||
from .base import get_loop, routes
|
||||
from .base import routes
|
||||
|
||||
BUILTIN_ATTRS = {
|
||||
"args",
|
||||
@ -138,12 +138,12 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse:
|
||||
authenticated = False
|
||||
|
||||
async def close_if_not_authenticated():
|
||||
await asyncio.sleep(5, loop=get_loop())
|
||||
await asyncio.sleep(5)
|
||||
if not authenticated:
|
||||
await ws.close(code=4000)
|
||||
log.debug(f"Connection from {request.remote} terminated due to no authentication")
|
||||
|
||||
asyncio.ensure_future(close_if_not_authenticated())
|
||||
asyncio.create_task(close_if_not_authenticated())
|
||||
|
||||
try:
|
||||
msg: web_ws.WSMessage
|
||||
|
@ -29,8 +29,8 @@ async def get_plugins(_) -> web.Response:
|
||||
|
||||
@routes.get("/plugin/{id}")
|
||||
async def get_plugin(request: web.Request) -> web.Response:
|
||||
plugin_id = request.match_info.get("id", None)
|
||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
||||
plugin_id = request.match_info["id"]
|
||||
plugin = PluginLoader.id_cache.get(plugin_id)
|
||||
if not plugin:
|
||||
return resp.plugin_not_found
|
||||
return resp.found(plugin.to_dict())
|
||||
@ -38,8 +38,8 @@ async def get_plugin(request: web.Request) -> web.Response:
|
||||
|
||||
@routes.delete("/plugin/{id}")
|
||||
async def delete_plugin(request: web.Request) -> web.Response:
|
||||
plugin_id = request.match_info.get("id", None)
|
||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
||||
plugin_id = request.match_info["id"]
|
||||
plugin = PluginLoader.id_cache.get(plugin_id)
|
||||
if not plugin:
|
||||
return resp.plugin_not_found
|
||||
elif len(plugin.references) > 0:
|
||||
@ -50,8 +50,8 @@ async def delete_plugin(request: web.Request) -> web.Response:
|
||||
|
||||
@routes.post("/plugin/{id}/reload")
|
||||
async def reload_plugin(request: web.Request) -> web.Response:
|
||||
plugin_id = request.match_info.get("id", None)
|
||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
||||
plugin_id = request.match_info["id"]
|
||||
plugin = PluginLoader.id_cache.get(plugin_id)
|
||||
if not plugin:
|
||||
return resp.plugin_not_found
|
||||
|
||||
|
@ -29,7 +29,7 @@ from .responses import resp
|
||||
|
||||
@routes.put("/plugin/{id}")
|
||||
async def put_plugin(request: web.Request) -> web.Response:
|
||||
plugin_id = request.match_info.get("id", None)
|
||||
plugin_id = request.match_info["id"]
|
||||
content = await request.read()
|
||||
file = BytesIO(content)
|
||||
try:
|
||||
|
@ -15,7 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Awaitable
|
||||
from abc import ABC
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
@ -124,6 +124,7 @@ class Plugin(ABC):
|
||||
def get_config_class(cls) -> type[BaseProxyConfig] | None:
|
||||
return None
|
||||
|
||||
def on_external_config_update(self) -> None:
|
||||
def on_external_config_update(self) -> Awaitable[None] | None:
|
||||
if self.config:
|
||||
self.config.load_and_update()
|
||||
return None
|
||||
|
@ -1,13 +1,10 @@
|
||||
# Format: #/name defines a new extras_require group called name
|
||||
# Uncommented lines after the group definition insert things into that group.
|
||||
|
||||
#/postgres
|
||||
psycopg2-binary>=2,<3
|
||||
asyncpg>=0.20,<0.26
|
||||
#/sqlite
|
||||
aiosqlite>=0.16,<0.18
|
||||
|
||||
#/encryption
|
||||
asyncpg>=0.20,<0.26
|
||||
aiosqlite>=0.16,<0.18
|
||||
python-olm>=3,<4
|
||||
pycryptodome>=3,<4
|
||||
unpaddedbase64>=1,<3
|
||||
|
@ -1,7 +1,8 @@
|
||||
mautrix>=0.15.0,<0.16
|
||||
mautrix>=0.15.2,<0.16
|
||||
aiohttp>=3,<4
|
||||
yarl>=1,<2
|
||||
SQLAlchemy>=1,<1.4
|
||||
asyncpg>=0.20,<0.26
|
||||
alembic>=1,<2
|
||||
commonmark>=0.9,<1
|
||||
ruamel.yaml>=0.15.35,<0.18
|
||||
|
5
setup.py
5
setup.py
@ -1,5 +1,4 @@
|
||||
import setuptools
|
||||
import glob
|
||||
import os
|
||||
|
||||
with open("requirements.txt") as reqs:
|
||||
@ -57,9 +56,7 @@ setuptools.setup(
|
||||
mbc=maubot.cli:app
|
||||
""",
|
||||
data_files=[
|
||||
(".", ["maubot/example-config.yaml", "alembic.ini"]),
|
||||
("alembic", ["alembic/env.py"]),
|
||||
("alembic/versions", glob.glob("alembic/versions/*.py")),
|
||||
(".", ["maubot/example-config.yaml"]),
|
||||
],
|
||||
package_data={
|
||||
"maubot": [
|
||||
|
Loading…
Reference in New Issue
Block a user