parent
068e268c63
commit
21ed971d2f
@ -15,7 +15,6 @@ RUN apk add --no-cache \
|
|||||||
py3-attrs \
|
py3-attrs \
|
||||||
py3-bcrypt \
|
py3-bcrypt \
|
||||||
py3-cffi \
|
py3-cffi \
|
||||||
py3-psycopg2 \
|
|
||||||
py3-ruamel.yaml \
|
py3-ruamel.yaml \
|
||||||
py3-jinja2 \
|
py3-jinja2 \
|
||||||
py3-click \
|
py3-click \
|
||||||
@ -49,7 +48,6 @@ COPY requirements.txt /opt/maubot/requirements.txt
|
|||||||
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
|
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
|
||||||
WORKDIR /opt/maubot
|
WORKDIR /opt/maubot
|
||||||
RUN apk add --virtual .build-deps python3-dev build-base git \
|
RUN apk add --virtual .build-deps python3-dev build-base git \
|
||||||
&& sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \
|
|
||||||
&& pip3 install -r requirements.txt -r optional-requirements.txt \
|
&& pip3 install -r requirements.txt -r optional-requirements.txt \
|
||||||
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
|
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
|
||||||
&& apk del .build-deps
|
&& apk del .build-deps
|
||||||
|
@ -10,7 +10,6 @@ RUN apk add --no-cache \
|
|||||||
py3-attrs \
|
py3-attrs \
|
||||||
py3-bcrypt \
|
py3-bcrypt \
|
||||||
py3-cffi \
|
py3-cffi \
|
||||||
py3-psycopg2 \
|
|
||||||
py3-ruamel.yaml \
|
py3-ruamel.yaml \
|
||||||
py3-jinja2 \
|
py3-jinja2 \
|
||||||
py3-click \
|
py3-click \
|
||||||
@ -43,7 +42,6 @@ COPY requirements.txt /opt/maubot/requirements.txt
|
|||||||
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
|
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
|
||||||
WORKDIR /opt/maubot
|
WORKDIR /opt/maubot
|
||||||
RUN apk add --virtual .build-deps python3-dev build-base git \
|
RUN apk add --virtual .build-deps python3-dev build-base git \
|
||||||
&& sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \
|
|
||||||
&& pip3 install -r requirements.txt -r optional-requirements.txt \
|
&& pip3 install -r requirements.txt -r optional-requirements.txt \
|
||||||
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
|
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
|
||||||
&& apk del .build-deps
|
&& apk del .build-deps
|
||||||
|
83
alembic.ini
83
alembic.ini
@ -1,83 +0,0 @@
|
|||||||
# A generic, single database configuration.
|
|
||||||
|
|
||||||
[alembic]
|
|
||||||
# path to migration scripts
|
|
||||||
script_location = alembic
|
|
||||||
|
|
||||||
# template used to generate migration files
|
|
||||||
# file_template = %%(rev)s_%%(slug)s
|
|
||||||
|
|
||||||
# timezone to use when rendering the date
|
|
||||||
# within the migration file as well as the filename.
|
|
||||||
# string value is passed to dateutil.tz.gettz()
|
|
||||||
# leave blank for localtime
|
|
||||||
# timezone =
|
|
||||||
|
|
||||||
# max length of characters to apply to the
|
|
||||||
# "slug" field
|
|
||||||
# truncate_slug_length = 40
|
|
||||||
|
|
||||||
# set to 'true' to run the environment during
|
|
||||||
# the 'revision' command, regardless of autogenerate
|
|
||||||
# revision_environment = false
|
|
||||||
|
|
||||||
# set to 'true' to allow .pyc and .pyo files without
|
|
||||||
# a source .py file to be detected as revisions in the
|
|
||||||
# versions/ directory
|
|
||||||
# sourceless = false
|
|
||||||
|
|
||||||
# version location specification; this defaults
|
|
||||||
# to alembic/versions. When using multiple version
|
|
||||||
# directories, initial revisions must be specified with --version-path
|
|
||||||
# version_locations = %(here)s/bar %(here)s/bat alembic/versions
|
|
||||||
|
|
||||||
# the output encoding used when revision files
|
|
||||||
# are written from script.py.mako
|
|
||||||
# output_encoding = utf-8
|
|
||||||
|
|
||||||
|
|
||||||
[post_write_hooks]
|
|
||||||
# post_write_hooks defines scripts or Python functions that are run
|
|
||||||
# on newly generated revision scripts. See the documentation for further
|
|
||||||
# detail and examples
|
|
||||||
|
|
||||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
|
||||||
# hooks=black
|
|
||||||
# black.type=console_scripts
|
|
||||||
# black.entrypoint=black
|
|
||||||
# black.options=-l 79
|
|
||||||
|
|
||||||
# Logging configuration
|
|
||||||
[loggers]
|
|
||||||
keys = root,sqlalchemy,alembic
|
|
||||||
|
|
||||||
[handlers]
|
|
||||||
keys = console
|
|
||||||
|
|
||||||
[formatters]
|
|
||||||
keys = generic
|
|
||||||
|
|
||||||
[logger_root]
|
|
||||||
level = WARN
|
|
||||||
handlers = console
|
|
||||||
qualname =
|
|
||||||
|
|
||||||
[logger_sqlalchemy]
|
|
||||||
level = WARN
|
|
||||||
handlers =
|
|
||||||
qualname = sqlalchemy.engine
|
|
||||||
|
|
||||||
[logger_alembic]
|
|
||||||
level = INFO
|
|
||||||
handlers =
|
|
||||||
qualname = alembic
|
|
||||||
|
|
||||||
[handler_console]
|
|
||||||
class = StreamHandler
|
|
||||||
args = (sys.stderr,)
|
|
||||||
level = NOTSET
|
|
||||||
formatter = generic
|
|
||||||
|
|
||||||
[formatter_generic]
|
|
||||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
|
||||||
datefmt = %H:%M:%S
|
|
@ -1 +0,0 @@
|
|||||||
Generic single-database configuration.
|
|
@ -1,92 +0,0 @@
|
|||||||
from logging.config import fileConfig
|
|
||||||
|
|
||||||
from sqlalchemy import engine_from_config, pool
|
|
||||||
|
|
||||||
from alembic import context
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from os.path import abspath, dirname
|
|
||||||
|
|
||||||
sys.path.insert(0, dirname(dirname(abspath(__file__))))
|
|
||||||
|
|
||||||
from mautrix.util.db import Base
|
|
||||||
from maubot.config import Config
|
|
||||||
from maubot import db
|
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
|
||||||
# access to the values within the .ini file in use.
|
|
||||||
config = context.config
|
|
||||||
|
|
||||||
maubot_config_path = context.get_x_argument(as_dictionary=True).get("config", "config.yaml")
|
|
||||||
maubot_config = Config(maubot_config_path, None)
|
|
||||||
maubot_config.load()
|
|
||||||
config.set_main_option("sqlalchemy.url", maubot_config["database"].replace("%", "%%"))
|
|
||||||
|
|
||||||
# Interpret the config file for Python logging.
|
|
||||||
# This line sets up loggers basically.
|
|
||||||
fileConfig(config.config_file_name)
|
|
||||||
|
|
||||||
# add your model's MetaData object here
|
|
||||||
# for 'autogenerate' support
|
|
||||||
# from myapp import mymodel
|
|
||||||
# target_metadata = mymodel.Base.metadata
|
|
||||||
target_metadata = Base.metadata
|
|
||||||
|
|
||||||
# other values from the config, defined by the needs of env.py,
|
|
||||||
# can be acquired:
|
|
||||||
# my_important_option = config.get_main_option("my_important_option")
|
|
||||||
# ... etc.
|
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_offline():
|
|
||||||
"""Run migrations in 'offline' mode.
|
|
||||||
|
|
||||||
This configures the context with just a URL
|
|
||||||
and not an Engine, though an Engine is acceptable
|
|
||||||
here as well. By skipping the Engine creation
|
|
||||||
we don't even need a DBAPI to be available.
|
|
||||||
|
|
||||||
Calls to context.execute() here emit the given string to the
|
|
||||||
script output.
|
|
||||||
|
|
||||||
"""
|
|
||||||
url = config.get_main_option("sqlalchemy.url")
|
|
||||||
context.configure(
|
|
||||||
url=url,
|
|
||||||
target_metadata=target_metadata,
|
|
||||||
literal_binds=True,
|
|
||||||
dialect_opts={"paramstyle": "named"},
|
|
||||||
render_as_batch=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
|
||||||
context.run_migrations()
|
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_online():
|
|
||||||
"""Run migrations in 'online' mode.
|
|
||||||
|
|
||||||
In this scenario we need to create an Engine
|
|
||||||
and associate a connection with the context.
|
|
||||||
|
|
||||||
"""
|
|
||||||
connectable = engine_from_config(
|
|
||||||
config.get_section(config.config_ini_section),
|
|
||||||
prefix="sqlalchemy.",
|
|
||||||
poolclass=pool.NullPool,
|
|
||||||
)
|
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
|
||||||
context.configure(
|
|
||||||
connection=connection, target_metadata=target_metadata,
|
|
||||||
render_as_batch=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
|
||||||
context.run_migrations()
|
|
||||||
|
|
||||||
|
|
||||||
if context.is_offline_mode():
|
|
||||||
run_migrations_offline()
|
|
||||||
else:
|
|
||||||
run_migrations_online()
|
|
@ -1,24 +0,0 @@
|
|||||||
"""${message}
|
|
||||||
|
|
||||||
Revision ID: ${up_revision}
|
|
||||||
Revises: ${down_revision | comma,n}
|
|
||||||
Create Date: ${create_date}
|
|
||||||
|
|
||||||
"""
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
${imports if imports else ""}
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision = ${repr(up_revision)}
|
|
||||||
down_revision = ${repr(down_revision)}
|
|
||||||
branch_labels = ${repr(branch_labels)}
|
|
||||||
depends_on = ${repr(depends_on)}
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
|
||||||
${upgrades if upgrades else "pass"}
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
${downgrades if downgrades else "pass"}
|
|
@ -1,32 +0,0 @@
|
|||||||
"""Add device_id to clients
|
|
||||||
|
|
||||||
Revision ID: 4b93300852aa
|
|
||||||
Revises: fccd1f95544d
|
|
||||||
Create Date: 2020-07-11 15:49:38.831459
|
|
||||||
|
|
||||||
"""
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision = '4b93300852aa'
|
|
||||||
down_revision = 'fccd1f95544d'
|
|
||||||
branch_labels = None
|
|
||||||
depends_on = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
with op.batch_alter_table('client', schema=None) as batch_op:
|
|
||||||
batch_op.add_column(sa.Column('device_id', sa.String(length=255), nullable=True))
|
|
||||||
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
with op.batch_alter_table('client', schema=None) as batch_op:
|
|
||||||
batch_op.drop_column('device_id')
|
|
||||||
|
|
||||||
# ### end Alembic commands ###
|
|
@ -1,47 +0,0 @@
|
|||||||
"""Add Matrix state store
|
|
||||||
|
|
||||||
Revision ID: 90aa88820eab
|
|
||||||
Revises: 4b93300852aa
|
|
||||||
Create Date: 2020-07-12 01:50:06.215623
|
|
||||||
|
|
||||||
"""
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from mautrix.client.state_store.sqlalchemy import SerializableType
|
|
||||||
from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision = '90aa88820eab'
|
|
||||||
down_revision = '4b93300852aa'
|
|
||||||
branch_labels = None
|
|
||||||
depends_on = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_table('mx_room_state',
|
|
||||||
sa.Column('room_id', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('is_encrypted', sa.Boolean(), nullable=True),
|
|
||||||
sa.Column('has_full_member_list', sa.Boolean(), nullable=True),
|
|
||||||
sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True),
|
|
||||||
sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True),
|
|
||||||
sa.PrimaryKeyConstraint('room_id')
|
|
||||||
)
|
|
||||||
op.create_table('mx_user_profile',
|
|
||||||
sa.Column('room_id', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('user_id', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False),
|
|
||||||
sa.Column('displayname', sa.String(), nullable=True),
|
|
||||||
sa.Column('avatar_url', sa.String(length=255), nullable=True),
|
|
||||||
sa.PrimaryKeyConstraint('room_id', 'user_id')
|
|
||||||
)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_table('mx_user_profile')
|
|
||||||
op.drop_table('mx_room_state')
|
|
||||||
# ### end Alembic commands ###
|
|
@ -1,50 +0,0 @@
|
|||||||
"""Initial revision
|
|
||||||
|
|
||||||
Revision ID: d295f8dcfa64
|
|
||||||
Revises:
|
|
||||||
Create Date: 2019-09-27 00:21:02.527915
|
|
||||||
|
|
||||||
"""
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision = 'd295f8dcfa64'
|
|
||||||
down_revision = None
|
|
||||||
branch_labels = None
|
|
||||||
depends_on = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_table('client',
|
|
||||||
sa.Column('id', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('homeserver', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('access_token', sa.Text(), nullable=False),
|
|
||||||
sa.Column('enabled', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('next_batch', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('filter_id', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('sync', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('autojoin', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('displayname', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('avatar_url', sa.String(length=255), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
op.create_table('plugin',
|
|
||||||
sa.Column('id', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('type', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('enabled', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('primary_user', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('config', sa.Text(), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['primary_user'], ['client.id'], onupdate='CASCADE', ondelete='RESTRICT'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_table('plugin')
|
|
||||||
op.drop_table('client')
|
|
||||||
# ### end Alembic commands ###
|
|
@ -1,30 +0,0 @@
|
|||||||
"""Add online field to clients
|
|
||||||
|
|
||||||
Revision ID: fccd1f95544d
|
|
||||||
Revises: d295f8dcfa64
|
|
||||||
Create Date: 2020-03-06 15:07:50.136644
|
|
||||||
|
|
||||||
"""
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision = 'fccd1f95544d'
|
|
||||||
down_revision = 'd295f8dcfa64'
|
|
||||||
branch_labels = None
|
|
||||||
depends_on = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
with op.batch_alter_table("client") as batch_op:
|
|
||||||
batch_op.add_column(sa.Column('online', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true()))
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
with op.batch_alter_table("client") as batch_op:
|
|
||||||
batch_op.drop_column('online')
|
|
||||||
# ### end Alembic commands ###
|
|
@ -1,7 +1,7 @@
|
|||||||
#!/bin/sh
|
#!/bin/sh
|
||||||
|
|
||||||
function fixperms {
|
function fixperms {
|
||||||
chown -R $UID:$GID /var/log /data /opt/maubot
|
chown -R $UID:$GID /var/log /data
|
||||||
}
|
}
|
||||||
|
|
||||||
function fixdefault {
|
function fixdefault {
|
||||||
|
@ -13,24 +13,37 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
import asyncio
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from mautrix.util.async_db import Database, DatabaseException
|
||||||
from mautrix.util.program import Program
|
from mautrix.util.program import Program
|
||||||
|
|
||||||
from .__meta__ import __version__
|
from .__meta__ import __version__
|
||||||
from .client import Client, init as init_client_class
|
from .client import Client
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .db import init as init_db
|
from .db import init as init_db, upgrade_table
|
||||||
from .instance import init as init_plugin_instance_class
|
from .instance import PluginInstance
|
||||||
from .lib.future_awaitable import FutureAwaitable
|
from .lib.future_awaitable import FutureAwaitable
|
||||||
|
from .lib.state_store import PgStateStore
|
||||||
from .loader.zip import init as init_zip_loader
|
from .loader.zip import init as init_zip_loader
|
||||||
from .management.api import init as init_mgmt_api
|
from .management.api import init as init_mgmt_api
|
||||||
from .server import MaubotServer
|
from .server import MaubotServer
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mautrix.crypto.store import PgCryptoStore
|
||||||
|
except ImportError:
|
||||||
|
PgCryptoStore = None
|
||||||
|
|
||||||
|
|
||||||
class Maubot(Program):
|
class Maubot(Program):
|
||||||
config: Config
|
config: Config
|
||||||
server: MaubotServer
|
server: MaubotServer
|
||||||
|
db: Database
|
||||||
|
crypto_db: Database | None
|
||||||
|
state_store: PgStateStore
|
||||||
|
|
||||||
config_class = Config
|
config_class = Config
|
||||||
module = "maubot"
|
module = "maubot"
|
||||||
@ -45,6 +58,19 @@ class Maubot(Program):
|
|||||||
init(self.loop)
|
init(self.loop)
|
||||||
self.add_shutdown_actions(FutureAwaitable(stop_all))
|
self.add_shutdown_actions(FutureAwaitable(stop_all))
|
||||||
|
|
||||||
|
def prepare_arg_parser(self) -> None:
|
||||||
|
super().prepare_arg_parser()
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--ignore-unsupported-database",
|
||||||
|
action="store_true",
|
||||||
|
help="Run even if the database schema is too new",
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--ignore-foreign-tables",
|
||||||
|
action="store_true",
|
||||||
|
help="Run even if the database contains tables from other programs (like Synapse)",
|
||||||
|
)
|
||||||
|
|
||||||
def prepare(self) -> None:
|
def prepare(self) -> None:
|
||||||
super().prepare()
|
super().prepare()
|
||||||
|
|
||||||
@ -52,21 +78,59 @@ class Maubot(Program):
|
|||||||
self.prepare_log_websocket()
|
self.prepare_log_websocket()
|
||||||
|
|
||||||
init_zip_loader(self.config)
|
init_zip_loader(self.config)
|
||||||
init_db(self.config)
|
self.db = Database.create(
|
||||||
clients = init_client_class(self.config, self.loop)
|
self.config["database"],
|
||||||
self.add_startup_actions(*(client.start() for client in clients))
|
upgrade_table=upgrade_table,
|
||||||
|
db_args=self.config["database_opts"],
|
||||||
|
owner_name=self.name,
|
||||||
|
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
||||||
|
)
|
||||||
|
init_db(self.db)
|
||||||
|
if self.config["crypto_database"] == "default":
|
||||||
|
self.crypto_db = self.db
|
||||||
|
else:
|
||||||
|
self.crypto_db = Database.create(
|
||||||
|
self.config["crypto_database"],
|
||||||
|
upgrade_table=PgCryptoStore.upgrade_table,
|
||||||
|
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
||||||
|
)
|
||||||
|
Client.init_cls(self)
|
||||||
|
PluginInstance.init_cls(self)
|
||||||
management_api = init_mgmt_api(self.config, self.loop)
|
management_api = init_mgmt_api(self.config, self.loop)
|
||||||
self.server = MaubotServer(management_api, self.config, self.loop)
|
self.server = MaubotServer(management_api, self.config, self.loop)
|
||||||
|
self.state_store = PgStateStore(self.db)
|
||||||
|
|
||||||
plugins = init_plugin_instance_class(self.config, self.server, self.loop)
|
async def start_db(self) -> None:
|
||||||
for plugin in plugins:
|
self.log.debug("Starting database...")
|
||||||
plugin.load()
|
ignore_unsupported = self.args.ignore_unsupported_database
|
||||||
|
self.db.upgrade_table.allow_unsupported = ignore_unsupported
|
||||||
|
self.state_store.upgrade_table.allow_unsupported = ignore_unsupported
|
||||||
|
PgCryptoStore.upgrade_table.allow_unsupported = ignore_unsupported
|
||||||
|
try:
|
||||||
|
await self.db.start()
|
||||||
|
await self.state_store.upgrade_table.upgrade(self.db)
|
||||||
|
if self.crypto_db and self.crypto_db is not self.db:
|
||||||
|
await self.crypto_db.start()
|
||||||
|
else:
|
||||||
|
await PgCryptoStore.upgrade_table.upgrade(self.db)
|
||||||
|
except DatabaseException as e:
|
||||||
|
self.log.critical("Failed to initialize database", exc_info=e)
|
||||||
|
if e.explanation:
|
||||||
|
self.log.info(e.explanation)
|
||||||
|
sys.exit(25)
|
||||||
|
|
||||||
|
async def system_exit(self) -> None:
|
||||||
|
if hasattr(self, "db"):
|
||||||
|
self.log.trace("Stopping database due to SystemExit")
|
||||||
|
await self.db.stop()
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if Client.crypto_db:
|
await self.start_db()
|
||||||
self.log.debug("Starting client crypto database")
|
await asyncio.gather(*[plugin.load() async for plugin in PluginInstance.all()])
|
||||||
await Client.crypto_db.start()
|
await asyncio.gather(*[client.start() async for client in Client.all()])
|
||||||
await super().start()
|
await super().start()
|
||||||
|
async for plugin in PluginInstance.all():
|
||||||
|
await plugin.load()
|
||||||
await self.server.start()
|
await self.server.start()
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
@ -77,6 +141,7 @@ class Maubot(Program):
|
|||||||
await asyncio.wait_for(self.server.stop(), 5)
|
await asyncio.wait_for(self.server.stop(), 5)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self.log.warning("Stopping server timed out")
|
self.log.warning("Stopping server timed out")
|
||||||
|
await self.db.stop()
|
||||||
|
|
||||||
|
|
||||||
Maubot().run()
|
Maubot().run()
|
||||||
|
@ -1 +1 @@
|
|||||||
__version__ = "0.2.1"
|
__version__ = "0.3.0+dev"
|
||||||
|
@ -38,7 +38,7 @@ def logs(server: str, tail: int) -> None:
|
|||||||
global history_count
|
global history_count
|
||||||
history_count = tail
|
history_count = tail
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
future = asyncio.ensure_future(view_logs(server, token), loop=loop)
|
future = asyncio.create_task(view_logs(server, token), loop=loop)
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(future)
|
loop.run_until_complete(future)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
385
maubot/client.py
385
maubot/client.py
@ -15,14 +15,14 @@
|
|||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, cast
|
||||||
|
from collections import defaultdict
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
from mautrix.client import InternalEventType
|
from mautrix.client import InternalEventType
|
||||||
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
|
|
||||||
from mautrix.errors import MatrixInvalidToken
|
from mautrix.errors import MatrixInvalidToken
|
||||||
from mautrix.types import (
|
from mautrix.types import (
|
||||||
ContentURI,
|
ContentURI,
|
||||||
@ -41,69 +41,110 @@ from mautrix.types import (
|
|||||||
SyncToken,
|
SyncToken,
|
||||||
UserID,
|
UserID,
|
||||||
)
|
)
|
||||||
|
from mautrix.util.async_getter_lock import async_getter_lock
|
||||||
|
from mautrix.util.logging import TraceLogger
|
||||||
|
|
||||||
from .db import DBClient
|
from .db import Client as DBClient
|
||||||
from .lib.store_proxy import SyncStoreProxy
|
|
||||||
from .matrix import MaubotMatrixClient
|
from .matrix import MaubotMatrixClient
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore
|
from mautrix.crypto import OlmMachine, PgCryptoStore
|
||||||
from mautrix.util.async_db import Database as AsyncDatabase
|
|
||||||
|
|
||||||
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
|
|
||||||
pass
|
|
||||||
|
|
||||||
crypto_import_error = None
|
crypto_import_error = None
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None
|
OlmMachine = PgCryptoStore = None
|
||||||
SQLStateStore = BaseSQLStateStore
|
|
||||||
crypto_import_error = e
|
crypto_import_error = e
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .config import Config
|
from .__main__ import Maubot
|
||||||
from .instance import PluginInstance
|
from .instance import PluginInstance
|
||||||
|
|
||||||
log = logging.getLogger("maubot.client")
|
|
||||||
|
|
||||||
|
class Client(DBClient):
|
||||||
class Client:
|
maubot: "Maubot" = None
|
||||||
log: logging.Logger = None
|
|
||||||
loop: asyncio.AbstractEventLoop = None
|
|
||||||
cache: dict[UserID, Client] = {}
|
cache: dict[UserID, Client] = {}
|
||||||
|
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
|
||||||
|
log: TraceLogger = logging.getLogger("maubot.client")
|
||||||
|
|
||||||
http_client: ClientSession = None
|
http_client: ClientSession = None
|
||||||
global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
|
|
||||||
crypto_db: AsyncDatabase | None = None
|
|
||||||
|
|
||||||
references: set[PluginInstance]
|
references: set[PluginInstance]
|
||||||
db_instance: DBClient
|
|
||||||
client: MaubotMatrixClient
|
client: MaubotMatrixClient
|
||||||
crypto: OlmMachine | None
|
crypto: OlmMachine | None
|
||||||
crypto_store: PgCryptoStore | None
|
crypto_store: PgCryptoStore | None
|
||||||
started: bool
|
started: bool
|
||||||
|
sync_ok: bool
|
||||||
|
|
||||||
remote_displayname: str | None
|
remote_displayname: str | None
|
||||||
remote_avatar_url: ContentURI | None
|
remote_avatar_url: ContentURI | None
|
||||||
|
|
||||||
def __init__(self, db_instance: DBClient) -> None:
|
def __init__(
|
||||||
self.db_instance = db_instance
|
self,
|
||||||
|
id: UserID,
|
||||||
|
homeserver: str,
|
||||||
|
access_token: str,
|
||||||
|
device_id: DeviceID,
|
||||||
|
enabled: bool = False,
|
||||||
|
next_batch: SyncToken = "",
|
||||||
|
filter_id: FilterID = "",
|
||||||
|
sync: bool = True,
|
||||||
|
autojoin: bool = True,
|
||||||
|
online: bool = True,
|
||||||
|
displayname: str = "disable",
|
||||||
|
avatar_url: str = "disable",
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
id=id,
|
||||||
|
homeserver=homeserver,
|
||||||
|
access_token=access_token,
|
||||||
|
device_id=device_id,
|
||||||
|
enabled=bool(enabled),
|
||||||
|
next_batch=next_batch,
|
||||||
|
filter_id=filter_id,
|
||||||
|
sync=bool(sync),
|
||||||
|
autojoin=bool(autojoin),
|
||||||
|
online=bool(online),
|
||||||
|
displayname=displayname,
|
||||||
|
avatar_url=avatar_url,
|
||||||
|
)
|
||||||
|
self._postinited = False
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(self.id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def init_cls(cls, maubot: "Maubot") -> None:
|
||||||
|
cls.maubot = maubot
|
||||||
|
|
||||||
|
def _make_client(
|
||||||
|
self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None
|
||||||
|
) -> MaubotMatrixClient:
|
||||||
|
return MaubotMatrixClient(
|
||||||
|
mxid=self.id,
|
||||||
|
base_url=homeserver or self.homeserver,
|
||||||
|
token=token or self.access_token,
|
||||||
|
client_session=self.http_client,
|
||||||
|
log=self.log,
|
||||||
|
crypto_log=self.log.getChild("crypto"),
|
||||||
|
loop=self.maubot.loop,
|
||||||
|
device_id=device_id or self.device_id,
|
||||||
|
sync_store=self,
|
||||||
|
state_store=self.maubot.state_store,
|
||||||
|
)
|
||||||
|
|
||||||
|
def postinit(self) -> None:
|
||||||
|
if self._postinited:
|
||||||
|
raise RuntimeError("postinit() called twice")
|
||||||
|
self._postinited = True
|
||||||
self.cache[self.id] = self
|
self.cache[self.id] = self
|
||||||
self.log = log.getChild(self.id)
|
self.log = self.log.getChild(self.id)
|
||||||
|
self.http_client = ClientSession(loop=self.maubot.loop)
|
||||||
self.references = set()
|
self.references = set()
|
||||||
self.started = False
|
self.started = False
|
||||||
self.sync_ok = True
|
self.sync_ok = True
|
||||||
self.remote_displayname = None
|
self.remote_displayname = None
|
||||||
self.remote_avatar_url = None
|
self.remote_avatar_url = None
|
||||||
self.client = MaubotMatrixClient(
|
self.client = self._make_client()
|
||||||
mxid=self.id,
|
|
||||||
base_url=self.homeserver,
|
|
||||||
token=self.access_token,
|
|
||||||
client_session=self.http_client,
|
|
||||||
log=self.log,
|
|
||||||
loop=self.loop,
|
|
||||||
device_id=self.device_id,
|
|
||||||
sync_store=SyncStoreProxy(self.db_instance),
|
|
||||||
state_store=self.global_state_store,
|
|
||||||
)
|
|
||||||
if self.enable_crypto:
|
if self.enable_crypto:
|
||||||
self._prepare_crypto()
|
self._prepare_crypto()
|
||||||
else:
|
else:
|
||||||
@ -118,6 +159,12 @@ class Client:
|
|||||||
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
||||||
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
||||||
|
|
||||||
|
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
|
||||||
|
async def handler(data: dict[str, Any]) -> None:
|
||||||
|
self.sync_ok = ok
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enable_crypto(self) -> bool:
|
def enable_crypto(self) -> bool:
|
||||||
if not self.device_id:
|
if not self.device_id:
|
||||||
@ -131,16 +178,21 @@ class Client:
|
|||||||
# Clear the stack trace after it's logged once to avoid spamming logs
|
# Clear the stack trace after it's logged once to avoid spamming logs
|
||||||
crypto_import_error = None
|
crypto_import_error = None
|
||||||
return False
|
return False
|
||||||
elif not self.crypto_db:
|
elif not self.maubot.crypto_db:
|
||||||
self.log.warning("Client has device ID, but crypto database is not prepared")
|
self.log.warning("Client has device ID, but crypto database is not prepared")
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _prepare_crypto(self) -> None:
|
def _prepare_crypto(self) -> None:
|
||||||
self.crypto_store = PgCryptoStore(
|
self.crypto_store = PgCryptoStore(
|
||||||
account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db
|
account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db
|
||||||
|
)
|
||||||
|
self.crypto = OlmMachine(
|
||||||
|
self.client,
|
||||||
|
self.crypto_store,
|
||||||
|
self.maubot.state_store,
|
||||||
|
log=self.client.crypto_log,
|
||||||
)
|
)
|
||||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
|
||||||
self.client.crypto = self.crypto
|
self.client.crypto = self.crypto
|
||||||
|
|
||||||
def _remove_crypto_event_handlers(self) -> None:
|
def _remove_crypto_event_handlers(self) -> None:
|
||||||
@ -156,12 +208,6 @@ class Client:
|
|||||||
for event_type, func in handlers:
|
for event_type, func in handlers:
|
||||||
self.client.remove_event_handler(event_type, func)
|
self.client.remove_event_handler(event_type, func)
|
||||||
|
|
||||||
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
|
|
||||||
async def handler(data: dict[str, Any]) -> None:
|
|
||||||
self.sync_ok = ok
|
|
||||||
|
|
||||||
return handler
|
|
||||||
|
|
||||||
async def start(self, try_n: int | None = 0) -> None:
|
async def start(self, try_n: int | None = 0) -> None:
|
||||||
try:
|
try:
|
||||||
if try_n > 0:
|
if try_n > 0:
|
||||||
@ -196,31 +242,34 @@ class Client:
|
|||||||
whoami = await self.client.whoami()
|
whoami = await self.client.whoami()
|
||||||
except MatrixInvalidToken as e:
|
except MatrixInvalidToken as e:
|
||||||
self.log.error(f"Invalid token: {e}. Disabling client")
|
self.log.error(f"Invalid token: {e}. Disabling client")
|
||||||
self.db_instance.enabled = False
|
self.enabled = False
|
||||||
|
await self.update()
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if try_n >= 8:
|
if try_n >= 8:
|
||||||
self.log.exception("Failed to get /account/whoami, disabling client")
|
self.log.exception("Failed to get /account/whoami, disabling client")
|
||||||
self.db_instance.enabled = False
|
self.enabled = False
|
||||||
|
await self.update()
|
||||||
else:
|
else:
|
||||||
self.log.warning(
|
self.log.warning(
|
||||||
f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}"
|
f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}"
|
||||||
)
|
)
|
||||||
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
|
_ = asyncio.create_task(self.start(try_n + 1))
|
||||||
return
|
return
|
||||||
if whoami.user_id != self.id:
|
if whoami.user_id != self.id:
|
||||||
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
|
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
|
||||||
self.db_instance.enabled = False
|
self.enabled = False
|
||||||
|
await self.update()
|
||||||
return
|
return
|
||||||
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
|
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
|
||||||
self.log.error(
|
self.log.error(
|
||||||
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
|
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
|
||||||
)
|
)
|
||||||
self.db_instance.enabled = False
|
self.enabled = False
|
||||||
|
await self.update()
|
||||||
return
|
return
|
||||||
if not self.filter_id:
|
if not self.filter_id:
|
||||||
self.db_instance.edit(
|
self.filter_id = await self.client.create_filter(
|
||||||
filter_id=await self.client.create_filter(
|
|
||||||
Filter(
|
Filter(
|
||||||
room=RoomFilter(
|
room=RoomFilter(
|
||||||
timeline=RoomEventFilter(
|
timeline=RoomEventFilter(
|
||||||
@ -236,7 +285,7 @@ class Client:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
await self.update()
|
||||||
if self.displayname != "disable":
|
if self.displayname != "disable":
|
||||||
await self.client.set_displayname(self.displayname)
|
await self.client.set_displayname(self.displayname)
|
||||||
if self.avatar_url != "disable":
|
if self.avatar_url != "disable":
|
||||||
@ -270,18 +319,13 @@ class Client:
|
|||||||
if self.crypto:
|
if self.crypto:
|
||||||
await self.crypto_store.close()
|
await self.crypto_store.close()
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
async def clear_cache(self) -> None:
|
||||||
self.stop_sync()
|
self.stop_sync()
|
||||||
self.db_instance.edit(filter_id="", next_batch="")
|
self.filter_id = FilterID("")
|
||||||
|
self.next_batch = SyncToken("")
|
||||||
|
await self.update()
|
||||||
self.start_sync()
|
self.start_sync()
|
||||||
|
|
||||||
def delete(self) -> None:
|
|
||||||
try:
|
|
||||||
del self.cache[self.id]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
self.db_instance.delete()
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
@ -304,20 +348,6 @@ class Client:
|
|||||||
"instances": [instance.to_dict() for instance in self.references],
|
"instances": [instance.to_dict() for instance in self.references],
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
|
|
||||||
try:
|
|
||||||
return cls.cache[user_id]
|
|
||||||
except KeyError:
|
|
||||||
db_instance = db_instance or DBClient.get(user_id)
|
|
||||||
if not db_instance:
|
|
||||||
return None
|
|
||||||
return Client(db_instance)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def all(cls) -> Iterable[Client]:
|
|
||||||
return (cls.get(user.id, user) for user in DBClient.all())
|
|
||||||
|
|
||||||
async def _handle_tombstone(self, evt: StateEvent) -> None:
|
async def _handle_tombstone(self, evt: StateEvent) -> None:
|
||||||
if not evt.content.replacement_room:
|
if not evt.content.replacement_room:
|
||||||
self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring")
|
self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring")
|
||||||
@ -329,7 +359,7 @@ class Client:
|
|||||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||||
await self.client.join_room(evt.room_id)
|
await self.client.join_room(evt.room_id)
|
||||||
|
|
||||||
async def update_started(self, started: bool) -> None:
|
async def update_started(self, started: bool | None) -> None:
|
||||||
if started is None or started == self.started:
|
if started is None or started == self.started:
|
||||||
return
|
return
|
||||||
if started:
|
if started:
|
||||||
@ -337,23 +367,65 @@ class Client:
|
|||||||
else:
|
else:
|
||||||
await self.stop()
|
await self.stop()
|
||||||
|
|
||||||
async def update_displayname(self, displayname: str) -> None:
|
async def update_enabled(self, enabled: bool | None, save: bool = True) -> None:
|
||||||
|
if enabled is None or enabled == self.enabled:
|
||||||
|
return
|
||||||
|
self.enabled = enabled
|
||||||
|
if save:
|
||||||
|
await self.update()
|
||||||
|
|
||||||
|
async def update_displayname(self, displayname: str | None, save: bool = True) -> None:
|
||||||
if displayname is None or displayname == self.displayname:
|
if displayname is None or displayname == self.displayname:
|
||||||
return
|
return
|
||||||
self.db_instance.displayname = displayname
|
self.displayname = displayname
|
||||||
if self.displayname != "disable":
|
if self.displayname != "disable":
|
||||||
await self.client.set_displayname(self.displayname)
|
await self.client.set_displayname(self.displayname)
|
||||||
else:
|
else:
|
||||||
await self._update_remote_profile()
|
await self._update_remote_profile()
|
||||||
|
if save:
|
||||||
|
await self.update()
|
||||||
|
|
||||||
async def update_avatar_url(self, avatar_url: ContentURI) -> None:
|
async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None:
|
||||||
if avatar_url is None or avatar_url == self.avatar_url:
|
if avatar_url is None or avatar_url == self.avatar_url:
|
||||||
return
|
return
|
||||||
self.db_instance.avatar_url = avatar_url
|
self.avatar_url = avatar_url
|
||||||
if self.avatar_url != "disable":
|
if self.avatar_url != "disable":
|
||||||
await self.client.set_avatar_url(self.avatar_url)
|
await self.client.set_avatar_url(self.avatar_url)
|
||||||
else:
|
else:
|
||||||
await self._update_remote_profile()
|
await self._update_remote_profile()
|
||||||
|
if save:
|
||||||
|
await self.update()
|
||||||
|
|
||||||
|
async def update_sync(self, sync: bool | None, save: bool = True) -> None:
|
||||||
|
if sync is None or self.sync == sync:
|
||||||
|
return
|
||||||
|
self.sync = sync
|
||||||
|
if self.started:
|
||||||
|
if sync:
|
||||||
|
self.start_sync()
|
||||||
|
else:
|
||||||
|
self.stop_sync()
|
||||||
|
if save:
|
||||||
|
await self.update()
|
||||||
|
|
||||||
|
async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None:
|
||||||
|
if autojoin is None or autojoin == self.autojoin:
|
||||||
|
return
|
||||||
|
if autojoin:
|
||||||
|
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||||
|
else:
|
||||||
|
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||||
|
self.autojoin = autojoin
|
||||||
|
if save:
|
||||||
|
await self.update()
|
||||||
|
|
||||||
|
async def update_online(self, online: bool | None, save: bool = True) -> None:
|
||||||
|
if online is None or online == self.online:
|
||||||
|
return
|
||||||
|
self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE
|
||||||
|
self.online = online
|
||||||
|
if save:
|
||||||
|
await self.update()
|
||||||
|
|
||||||
async def update_access_details(
|
async def update_access_details(
|
||||||
self,
|
self,
|
||||||
@ -373,22 +445,13 @@ class Client:
|
|||||||
and device_id == self.device_id
|
and device_id == self.device_id
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
new_client = MaubotMatrixClient(
|
new_client = self._make_client(homeserver, access_token, device_id)
|
||||||
mxid=self.id,
|
|
||||||
base_url=homeserver or self.homeserver,
|
|
||||||
token=access_token or self.access_token,
|
|
||||||
loop=self.loop,
|
|
||||||
device_id=device_id,
|
|
||||||
client_session=self.http_client,
|
|
||||||
log=self.log,
|
|
||||||
state_store=self.global_state_store,
|
|
||||||
)
|
|
||||||
whoami = await new_client.whoami()
|
whoami = await new_client.whoami()
|
||||||
if whoami.user_id != self.id:
|
if whoami.user_id != self.id:
|
||||||
raise ValueError(f"MXID mismatch: {whoami.user_id}")
|
raise ValueError(f"MXID mismatch: {whoami.user_id}")
|
||||||
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
||||||
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
|
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
|
||||||
new_client.sync_store = SyncStoreProxy(self.db_instance)
|
new_client.sync_store = self
|
||||||
self.stop_sync()
|
self.stop_sync()
|
||||||
|
|
||||||
# TODO this event handler transfer is pretty hacky
|
# TODO this event handler transfer is pretty hacky
|
||||||
@ -398,9 +461,9 @@ class Client:
|
|||||||
new_client.global_event_handlers = self.client.global_event_handlers
|
new_client.global_event_handlers = self.client.global_event_handlers
|
||||||
|
|
||||||
self.client = new_client
|
self.client = new_client
|
||||||
self.db_instance.homeserver = homeserver
|
self.homeserver = homeserver
|
||||||
self.db_instance.access_token = access_token
|
self.access_token = access_token
|
||||||
self.db_instance.device_id = device_id
|
self.device_id = device_id
|
||||||
if self.enable_crypto:
|
if self.enable_crypto:
|
||||||
self._prepare_crypto()
|
self._prepare_crypto()
|
||||||
await self._start_crypto()
|
await self._start_crypto()
|
||||||
@ -413,97 +476,53 @@ class Client:
|
|||||||
profile = await self.client.get_profile(self.id)
|
profile = await self.client.get_profile(self.id)
|
||||||
self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url
|
self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url
|
||||||
|
|
||||||
# region Properties
|
async def delete(self) -> None:
|
||||||
|
try:
|
||||||
|
del self.cache[self.id]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
await super().delete()
|
||||||
|
|
||||||
@property
|
@classmethod
|
||||||
def id(self) -> UserID:
|
@async_getter_lock
|
||||||
return self.db_instance.id
|
async def get(
|
||||||
|
cls,
|
||||||
|
user_id: UserID,
|
||||||
|
*,
|
||||||
|
homeserver: str | None = None,
|
||||||
|
access_token: str | None = None,
|
||||||
|
device_id: DeviceID | None = None,
|
||||||
|
) -> Client | None:
|
||||||
|
try:
|
||||||
|
return cls.cache[user_id]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
@property
|
user = cast(cls, await super().get(user_id))
|
||||||
def homeserver(self) -> str:
|
if user is not None:
|
||||||
return self.db_instance.homeserver
|
user.postinit()
|
||||||
|
return user
|
||||||
|
|
||||||
@property
|
if homeserver and access_token:
|
||||||
def access_token(self) -> str:
|
user = cls(
|
||||||
return self.db_instance.access_token
|
user_id,
|
||||||
|
homeserver=homeserver,
|
||||||
|
access_token=access_token,
|
||||||
|
device_id=device_id or "",
|
||||||
|
)
|
||||||
|
await user.insert()
|
||||||
|
user.postinit()
|
||||||
|
return user
|
||||||
|
|
||||||
@property
|
return None
|
||||||
def device_id(self) -> DeviceID:
|
|
||||||
return self.db_instance.device_id
|
|
||||||
|
|
||||||
@property
|
@classmethod
|
||||||
def enabled(self) -> bool:
|
async def all(cls) -> AsyncGenerator[Client, None]:
|
||||||
return self.db_instance.enabled
|
users = await super().all()
|
||||||
|
user: cls
|
||||||
@enabled.setter
|
for user in users:
|
||||||
def enabled(self, value: bool) -> None:
|
try:
|
||||||
self.db_instance.enabled = value
|
yield cls.cache[user.id]
|
||||||
|
except KeyError:
|
||||||
@property
|
user.postinit()
|
||||||
def next_batch(self) -> SyncToken:
|
yield user
|
||||||
return self.db_instance.next_batch
|
|
||||||
|
|
||||||
@property
|
|
||||||
def filter_id(self) -> FilterID:
|
|
||||||
return self.db_instance.filter_id
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sync(self) -> bool:
|
|
||||||
return self.db_instance.sync
|
|
||||||
|
|
||||||
@sync.setter
|
|
||||||
def sync(self, value: bool) -> None:
|
|
||||||
if value == self.db_instance.sync:
|
|
||||||
return
|
|
||||||
self.db_instance.sync = value
|
|
||||||
if self.started:
|
|
||||||
if value:
|
|
||||||
self.start_sync()
|
|
||||||
else:
|
|
||||||
self.stop_sync()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def autojoin(self) -> bool:
|
|
||||||
return self.db_instance.autojoin
|
|
||||||
|
|
||||||
@autojoin.setter
|
|
||||||
def autojoin(self, value: bool) -> None:
|
|
||||||
if value == self.db_instance.autojoin:
|
|
||||||
return
|
|
||||||
if value:
|
|
||||||
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
|
||||||
else:
|
|
||||||
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
|
||||||
self.db_instance.autojoin = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def online(self) -> bool:
|
|
||||||
return self.db_instance.online
|
|
||||||
|
|
||||||
@online.setter
|
|
||||||
def online(self, value: bool) -> None:
|
|
||||||
self.client.presence = PresenceState.ONLINE if value else PresenceState.OFFLINE
|
|
||||||
self.db_instance.online = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def displayname(self) -> str:
|
|
||||||
return self.db_instance.displayname
|
|
||||||
|
|
||||||
@property
|
|
||||||
def avatar_url(self) -> ContentURI:
|
|
||||||
return self.db_instance.avatar_url
|
|
||||||
|
|
||||||
# endregion
|
|
||||||
|
|
||||||
|
|
||||||
def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
|
||||||
Client.http_client = ClientSession(loop=loop)
|
|
||||||
Client.loop = loop
|
|
||||||
|
|
||||||
if OlmMachine:
|
|
||||||
db_url = config["crypto_database"]
|
|
||||||
if db_url == "default":
|
|
||||||
db_url = config["database"]
|
|
||||||
Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
|
|
||||||
|
|
||||||
return Client.all()
|
|
||||||
|
108
maubot/db.py
108
maubot/db.py
@ -1,108 +0,0 @@
|
|||||||
# maubot - A plugin-based Matrix bot system.
|
|
||||||
# Copyright (C) 2022 Tulir Asokan
|
|
||||||
#
|
|
||||||
# This program is free software: you can redistribute it and/or modify
|
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
|
||||||
# the Free Software Foundation, either version 3 of the License, or
|
|
||||||
# (at your option) any later version.
|
|
||||||
#
|
|
||||||
# This program is distributed in the hope that it will be useful,
|
|
||||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
# GNU Affero General Public License for more details.
|
|
||||||
#
|
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
from typing import Iterable, Optional
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from sqlalchemy import Boolean, Column, ForeignKey, String, Text
|
|
||||||
from sqlalchemy.engine.base import Engine
|
|
||||||
import sqlalchemy as sql
|
|
||||||
|
|
||||||
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
|
|
||||||
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
|
|
||||||
from mautrix.util.db import Base
|
|
||||||
|
|
||||||
from .config import Config
|
|
||||||
|
|
||||||
|
|
||||||
class DBPlugin(Base):
|
|
||||||
__tablename__ = "plugin"
|
|
||||||
|
|
||||||
id: str = Column(String(255), primary_key=True)
|
|
||||||
type: str = Column(String(255), nullable=False)
|
|
||||||
enabled: bool = Column(Boolean, nullable=False, default=False)
|
|
||||||
primary_user: UserID = Column(
|
|
||||||
String(255),
|
|
||||||
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
config: str = Column(Text, nullable=False, default="")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def all(cls) -> Iterable["DBPlugin"]:
|
|
||||||
return cls._select_all()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(cls, id: str) -> Optional["DBPlugin"]:
|
|
||||||
return cls._select_one_or_none(cls.c.id == id)
|
|
||||||
|
|
||||||
|
|
||||||
class DBClient(Base):
|
|
||||||
__tablename__ = "client"
|
|
||||||
|
|
||||||
id: UserID = Column(String(255), primary_key=True)
|
|
||||||
homeserver: str = Column(String(255), nullable=False)
|
|
||||||
access_token: str = Column(Text, nullable=False)
|
|
||||||
device_id: DeviceID = Column(String(255), nullable=True)
|
|
||||||
enabled: bool = Column(Boolean, nullable=False, default=False)
|
|
||||||
|
|
||||||
next_batch: SyncToken = Column(String(255), nullable=False, default="")
|
|
||||||
filter_id: FilterID = Column(String(255), nullable=False, default="")
|
|
||||||
|
|
||||||
sync: bool = Column(Boolean, nullable=False, default=True)
|
|
||||||
autojoin: bool = Column(Boolean, nullable=False, default=True)
|
|
||||||
online: bool = Column(Boolean, nullable=False, default=True)
|
|
||||||
|
|
||||||
displayname: str = Column(String(255), nullable=False, default="")
|
|
||||||
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def all(cls) -> Iterable["DBClient"]:
|
|
||||||
return cls._select_all()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(cls, id: str) -> Optional["DBClient"]:
|
|
||||||
return cls._select_one_or_none(cls.c.id == id)
|
|
||||||
|
|
||||||
|
|
||||||
def init(config: Config) -> Engine:
|
|
||||||
db = sql.create_engine(config["database"])
|
|
||||||
Base.metadata.bind = db
|
|
||||||
|
|
||||||
for table in (DBPlugin, DBClient, RoomState, UserProfile):
|
|
||||||
table.bind(db)
|
|
||||||
|
|
||||||
if not db.has_table("alembic_version"):
|
|
||||||
log = logging.getLogger("maubot.db")
|
|
||||||
|
|
||||||
if db.has_table("client") and db.has_table("plugin"):
|
|
||||||
log.warning(
|
|
||||||
"alembic_version table not found, but client and plugin tables found. "
|
|
||||||
"Assuming pre-Alembic database and inserting version."
|
|
||||||
)
|
|
||||||
db.execute(
|
|
||||||
"CREATE TABLE IF NOT EXISTS alembic_version ("
|
|
||||||
" version_num VARCHAR(32) PRIMARY KEY"
|
|
||||||
");"
|
|
||||||
)
|
|
||||||
db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');")
|
|
||||||
else:
|
|
||||||
log.critical(
|
|
||||||
"alembic_version table not found. " "Did you forget to `alembic upgrade head`?"
|
|
||||||
)
|
|
||||||
sys.exit(10)
|
|
||||||
|
|
||||||
return db
|
|
13
maubot/db/__init__.py
Normal file
13
maubot/db/__init__.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from mautrix.util.async_db import Database
|
||||||
|
|
||||||
|
from .client import Client
|
||||||
|
from .instance import Instance
|
||||||
|
from .upgrade import upgrade_table
|
||||||
|
|
||||||
|
|
||||||
|
def init(db: Database) -> None:
|
||||||
|
for table in (Client, Instance):
|
||||||
|
table.db = db
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["upgrade_table", "init", "Client", "Instance"]
|
114
maubot/db/client.py
Normal file
114
maubot/db/client.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
# maubot - A plugin-based Matrix bot system.
|
||||||
|
# Copyright (C) 2022 Tulir Asokan
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, ClassVar
|
||||||
|
|
||||||
|
from asyncpg import Record
|
||||||
|
from attr import dataclass
|
||||||
|
|
||||||
|
from mautrix.client import SyncStore
|
||||||
|
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
|
||||||
|
from mautrix.util.async_db import Database
|
||||||
|
|
||||||
|
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Client(SyncStore):
|
||||||
|
db: ClassVar[Database] = fake_db
|
||||||
|
|
||||||
|
id: UserID
|
||||||
|
homeserver: str
|
||||||
|
access_token: str
|
||||||
|
device_id: DeviceID
|
||||||
|
enabled: bool
|
||||||
|
|
||||||
|
next_batch: SyncToken
|
||||||
|
filter_id: FilterID
|
||||||
|
|
||||||
|
sync: bool
|
||||||
|
autojoin: bool
|
||||||
|
online: bool
|
||||||
|
|
||||||
|
displayname: str
|
||||||
|
avatar_url: ContentURI
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_row(cls, row: Record | None) -> Client | None:
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return cls(**row)
|
||||||
|
|
||||||
|
_columns = (
|
||||||
|
"id, homeserver, access_token, device_id, enabled, next_batch, filter_id, "
|
||||||
|
"sync, autojoin, online, displayname, avatar_url"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _values(self):
|
||||||
|
return (
|
||||||
|
self.id,
|
||||||
|
self.homeserver,
|
||||||
|
self.access_token,
|
||||||
|
self.device_id,
|
||||||
|
self.enabled,
|
||||||
|
self.next_batch,
|
||||||
|
self.filter_id,
|
||||||
|
self.sync,
|
||||||
|
self.autojoin,
|
||||||
|
self.online,
|
||||||
|
self.displayname,
|
||||||
|
self.avatar_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def all(cls) -> list[Client]:
|
||||||
|
rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client")
|
||||||
|
return [cls._from_row(row) for row in rows]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get(cls, id: str) -> Client | None:
|
||||||
|
q = f"SELECT {cls._columns} FROM client WHERE id=$1"
|
||||||
|
return cls._from_row(await cls.db.fetchrow(q, id))
|
||||||
|
|
||||||
|
async def insert(self) -> None:
|
||||||
|
q = """
|
||||||
|
INSERT INTO client (
|
||||||
|
id, homeserver, access_token, device_id, enabled, next_batch, filter_id,
|
||||||
|
sync, autojoin, online, displayname, avatar_url
|
||||||
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||||
|
"""
|
||||||
|
await self.db.execute(q, *self._values)
|
||||||
|
|
||||||
|
async def put_next_batch(self, next_batch: SyncToken) -> None:
|
||||||
|
await self.db.execute("UPDATE client SET next_batch=$1 WHERE id=$2", next_batch, self.id)
|
||||||
|
self.next_batch = next_batch
|
||||||
|
|
||||||
|
async def get_next_batch(self) -> SyncToken:
|
||||||
|
return self.next_batch
|
||||||
|
|
||||||
|
async def update(self) -> None:
|
||||||
|
q = """
|
||||||
|
UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, enabled=$5,
|
||||||
|
next_batch=$6, filter_id=$7, sync=$8, autojoin=$9, online=$10,
|
||||||
|
displayname=$11, avatar_url=$12
|
||||||
|
WHERE id=$1
|
||||||
|
"""
|
||||||
|
await self.db.execute(q, *self._values)
|
||||||
|
|
||||||
|
async def delete(self) -> None:
|
||||||
|
await self.db.execute("DELETE FROM client WHERE id=$1", self.id)
|
75
maubot/db/instance.py
Normal file
75
maubot/db/instance.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
# maubot - A plugin-based Matrix bot system.
|
||||||
|
# Copyright (C) 2022 Tulir Asokan
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, ClassVar
|
||||||
|
|
||||||
|
from asyncpg import Record
|
||||||
|
from attr import dataclass
|
||||||
|
|
||||||
|
from mautrix.types import UserID
|
||||||
|
from mautrix.util.async_db import Database
|
||||||
|
|
||||||
|
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Instance:
|
||||||
|
db: ClassVar[Database] = fake_db
|
||||||
|
|
||||||
|
id: str
|
||||||
|
type: str
|
||||||
|
enabled: bool
|
||||||
|
primary_user: UserID
|
||||||
|
config_str: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_row(cls, row: Record | None) -> Instance | None:
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return cls(**row)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def all(cls) -> list[Instance]:
|
||||||
|
rows = await cls.db.fetch("SELECT id, type, enabled, primary_user, config FROM instance")
|
||||||
|
return [cls._from_row(row) for row in rows]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get(cls, id: str) -> Instance | None:
|
||||||
|
q = "SELECT id, type, enabled, primary_user, config FROM instance WHERE id=$1"
|
||||||
|
return cls._from_row(await cls.db.fetchrow(q, id))
|
||||||
|
|
||||||
|
async def update_id(self, new_id: str) -> None:
|
||||||
|
await self.db.execute("UPDATE instance SET id=$1 WHERE id=$2", new_id, self.id)
|
||||||
|
self.id = new_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _values(self):
|
||||||
|
return self.id, self.type, self.enabled, self.primary_user, self.config_str
|
||||||
|
|
||||||
|
async def insert(self) -> None:
|
||||||
|
q = (
|
||||||
|
"INSERT INTO instance (id, type, enabled, primary_user, config) "
|
||||||
|
"VALUES ($1, $2, $3, $4, $5)"
|
||||||
|
)
|
||||||
|
await self.db.execute(q, *self._values)
|
||||||
|
|
||||||
|
async def update(self) -> None:
|
||||||
|
q = "UPDATE instance SET type=$2, enabled=$3, primary_user=$4, config=$5 WHERE id=$1"
|
||||||
|
await self.db.execute(q, *self._values)
|
||||||
|
|
||||||
|
async def delete(self) -> None:
|
||||||
|
await self.db.execute("DELETE FROM instance WHERE id=$1", self.id)
|
5
maubot/db/upgrade/__init__.py
Normal file
5
maubot/db/upgrade/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from mautrix.util.async_db import UpgradeTable
|
||||||
|
|
||||||
|
upgrade_table = UpgradeTable()
|
||||||
|
|
||||||
|
from . import v01_initial_revision
|
136
maubot/db/upgrade/v01_initial_revision.py
Normal file
136
maubot/db/upgrade/v01_initial_revision.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
# maubot - A plugin-based Matrix bot system.
|
||||||
|
# Copyright (C) 2022 Tulir Asokan
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from mautrix.util.async_db import Connection, Scheme
|
||||||
|
|
||||||
|
from . import upgrade_table
|
||||||
|
|
||||||
|
legacy_version_query = "SELECT version_num FROM alembic_version"
|
||||||
|
last_legacy_version = "90aa88820eab"
|
||||||
|
|
||||||
|
|
||||||
|
@upgrade_table.register(description="Initial asyncpg revision")
|
||||||
|
async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
|
||||||
|
if await conn.table_exists("alembic_version"):
|
||||||
|
await migrate_legacy_to_v1(conn, scheme)
|
||||||
|
else:
|
||||||
|
return await create_v1_tables(conn)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_v1_tables(conn: Connection) -> None:
|
||||||
|
await conn.execute(
|
||||||
|
"""CREATE TABLE client (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
homeserver TEXT NOT NULL,
|
||||||
|
access_token TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
enabled BOOLEAN NOT NULL,
|
||||||
|
|
||||||
|
next_batch TEXT NOT NULL,
|
||||||
|
filter_id TEXT NOT NULL,
|
||||||
|
|
||||||
|
sync BOOLEAN NOT NULL,
|
||||||
|
autojoin BOOLEAN NOT NULL,
|
||||||
|
online BOOLEAN NOT NULL,
|
||||||
|
|
||||||
|
displayname TEXT NOT NULL,
|
||||||
|
avatar_url TEXT NOT NULL
|
||||||
|
)"""
|
||||||
|
)
|
||||||
|
await conn.execute(
|
||||||
|
"""CREATE TABLE instance (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
enabled BOOLEAN NOT NULL,
|
||||||
|
primary_user TEXT NOT NULL,
|
||||||
|
config TEXT NOT NULL,
|
||||||
|
FOREIGN KEY (primary_user) REFERENCES client(id) ON DELETE RESTRICT ON UPDATE CASCADE
|
||||||
|
)"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def migrate_legacy_to_v1(conn: Connection, scheme: Scheme) -> None:
|
||||||
|
legacy_version = await conn.fetchval(legacy_version_query)
|
||||||
|
if legacy_version != last_legacy_version:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Legacy database is not on last version. "
|
||||||
|
"Please upgrade the old database with alembic or drop it completely first."
|
||||||
|
)
|
||||||
|
await conn.execute("ALTER TABLE plugin RENAME TO instance")
|
||||||
|
await update_state_store(conn, scheme)
|
||||||
|
if scheme != Scheme.SQLITE:
|
||||||
|
await varchar_to_text(conn)
|
||||||
|
await conn.execute("DROP TABLE alembic_version")
|
||||||
|
|
||||||
|
|
||||||
|
async def update_state_store(conn: Connection, scheme: Scheme) -> None:
|
||||||
|
# The Matrix state store already has more or less the correct schema, so set the version
|
||||||
|
await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)")
|
||||||
|
await conn.execute("INSERT INTO mx_version (version) VALUES (2)")
|
||||||
|
if scheme != Scheme.SQLITE:
|
||||||
|
# Remove old uppercase membership type and recreate it as lowercase
|
||||||
|
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT")
|
||||||
|
await conn.execute("DROP TYPE IF EXISTS membership")
|
||||||
|
await conn.execute(
|
||||||
|
"CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')"
|
||||||
|
)
|
||||||
|
await conn.execute(
|
||||||
|
"ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership "
|
||||||
|
"USING LOWER(membership)::membership"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Recreate table to remove CHECK constraint and lowercase everything
|
||||||
|
await conn.execute(
|
||||||
|
"""CREATE TABLE new_mx_user_profile (
|
||||||
|
room_id TEXT,
|
||||||
|
user_id TEXT,
|
||||||
|
membership TEXT NOT NULL
|
||||||
|
CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock')),
|
||||||
|
displayname TEXT,
|
||||||
|
avatar_url TEXT,
|
||||||
|
PRIMARY KEY (room_id, user_id)
|
||||||
|
)"""
|
||||||
|
)
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO new_mx_user_profile (room_id, user_id, membership, displayname, avatar_url)
|
||||||
|
SELECT room_id, user_id, LOWER(membership), displayname, avatar_url
|
||||||
|
FROM mx_user_profile
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
await conn.execute("DROP TABLE mx_user_profile")
|
||||||
|
await conn.execute("ALTER TABLE new_mx_user_profile RENAME TO mx_user_profile")
|
||||||
|
|
||||||
|
|
||||||
|
async def varchar_to_text(conn: Connection) -> None:
|
||||||
|
columns_to_adjust = {
|
||||||
|
"client": (
|
||||||
|
"id",
|
||||||
|
"homeserver",
|
||||||
|
"device_id",
|
||||||
|
"next_batch",
|
||||||
|
"filter_id",
|
||||||
|
"displayname",
|
||||||
|
"avatar_url",
|
||||||
|
),
|
||||||
|
"instance": ("id", "type", "primary_user"),
|
||||||
|
"mx_room_state": ("room_id",),
|
||||||
|
"mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"),
|
||||||
|
}
|
||||||
|
for table, columns in columns_to_adjust.items():
|
||||||
|
for column in columns:
|
||||||
|
await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT')
|
@ -6,9 +6,7 @@
|
|||||||
database: sqlite:///maubot.db
|
database: sqlite:///maubot.db
|
||||||
|
|
||||||
# Separate database URL for the crypto database. "default" means use the same database as above.
|
# Separate database URL for the crypto database. "default" means use the same database as above.
|
||||||
# Due to concurrency issues, you should use a separate file when using SQLite rather than the same as above.
|
crypto_database: default
|
||||||
# When using postgres, using the same database for both is safe.
|
|
||||||
crypto_database: sqlite:///crypto.db
|
|
||||||
|
|
||||||
plugin_directories:
|
plugin_directories:
|
||||||
# The directory where uploaded new plugins should be stored.
|
# The directory where uploaded new plugins should be stored.
|
||||||
|
@ -15,8 +15,10 @@
|
|||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Iterable
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
|
||||||
from asyncio import AbstractEventLoop
|
from collections import defaultdict
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import os.path
|
import os.path
|
||||||
@ -26,16 +28,17 @@ from ruamel.yaml.comments import CommentedMap
|
|||||||
import sqlalchemy as sql
|
import sqlalchemy as sql
|
||||||
|
|
||||||
from mautrix.types import UserID
|
from mautrix.types import UserID
|
||||||
|
from mautrix.util.async_getter_lock import async_getter_lock
|
||||||
from mautrix.util.config import BaseProxyConfig, RecursiveDict
|
from mautrix.util.config import BaseProxyConfig, RecursiveDict
|
||||||
|
|
||||||
from .client import Client
|
from .client import Client
|
||||||
from .config import Config
|
from .db import Instance as DBInstance
|
||||||
from .db import DBPlugin
|
|
||||||
from .loader import PluginLoader, ZippedPluginLoader
|
from .loader import PluginLoader, ZippedPluginLoader
|
||||||
from .plugin_base import Plugin
|
from .plugin_base import Plugin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .server import MaubotServer, PluginWebApp
|
from .__main__ import Maubot
|
||||||
|
from .server import PluginWebApp
|
||||||
|
|
||||||
log = logging.getLogger("maubot.instance")
|
log = logging.getLogger("maubot.instance")
|
||||||
|
|
||||||
@ -44,29 +47,42 @@ yaml.indent(4)
|
|||||||
yaml.width = 200
|
yaml.width = 200
|
||||||
|
|
||||||
|
|
||||||
class PluginInstance:
|
class PluginInstance(DBInstance):
|
||||||
webserver: MaubotServer = None
|
maubot: "Maubot" = None
|
||||||
mb_config: Config = None
|
|
||||||
loop: AbstractEventLoop = None
|
|
||||||
cache: dict[str, PluginInstance] = {}
|
cache: dict[str, PluginInstance] = {}
|
||||||
plugin_directories: list[str] = []
|
plugin_directories: list[str] = []
|
||||||
|
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
|
||||||
|
|
||||||
log: logging.Logger
|
log: logging.Logger
|
||||||
loader: PluginLoader
|
loader: PluginLoader | None
|
||||||
client: Client
|
client: Client | None
|
||||||
plugin: Plugin
|
plugin: Plugin | None
|
||||||
config: BaseProxyConfig
|
config: BaseProxyConfig | None
|
||||||
base_cfg: RecursiveDict[CommentedMap] | None
|
base_cfg: RecursiveDict[CommentedMap] | None
|
||||||
base_cfg_str: str | None
|
base_cfg_str: str | None
|
||||||
inst_db: sql.engine.Engine
|
inst_db: sql.engine.Engine | None
|
||||||
inst_db_tables: dict[str, sql.Table]
|
inst_db_tables: dict[str, sql.Table] | None
|
||||||
inst_webapp: PluginWebApp | None
|
inst_webapp: PluginWebApp | None
|
||||||
inst_webapp_url: str | None
|
inst_webapp_url: str | None
|
||||||
started: bool
|
started: bool
|
||||||
|
|
||||||
def __init__(self, db_instance: DBPlugin):
|
def __init__(
|
||||||
self.db_instance = db_instance
|
self, id: str, type: str, enabled: bool, primary_user: UserID, config: str = ""
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
id=id, type=type, enabled=bool(enabled), primary_user=primary_user, config_str=config
|
||||||
|
)
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(self.id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def init_cls(cls, maubot: "Maubot") -> None:
|
||||||
|
cls.maubot = maubot
|
||||||
|
|
||||||
|
def postinit(self) -> None:
|
||||||
self.log = log.getChild(self.id)
|
self.log = log.getChild(self.id)
|
||||||
|
self.cache[self.id] = self
|
||||||
self.config = None
|
self.config = None
|
||||||
self.started = False
|
self.started = False
|
||||||
self.loader = None
|
self.loader = None
|
||||||
@ -78,7 +94,6 @@ class PluginInstance:
|
|||||||
self.inst_webapp_url = None
|
self.inst_webapp_url = None
|
||||||
self.base_cfg = None
|
self.base_cfg = None
|
||||||
self.base_cfg_str = None
|
self.base_cfg_str = None
|
||||||
self.cache[self.id] = self
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
return {
|
return {
|
||||||
@ -87,10 +102,10 @@ class PluginInstance:
|
|||||||
"enabled": self.enabled,
|
"enabled": self.enabled,
|
||||||
"started": self.started,
|
"started": self.started,
|
||||||
"primary_user": self.primary_user,
|
"primary_user": self.primary_user,
|
||||||
"config": self.db_instance.config,
|
"config": self.config_str,
|
||||||
"base_config": self.base_cfg_str,
|
"base_config": self.base_cfg_str,
|
||||||
"database": (
|
"database": (
|
||||||
self.inst_db is not None and self.mb_config["api_features.instance_database"]
|
self.inst_db is not None and self.maubot.config["api_features.instance_database"]
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -101,19 +116,19 @@ class PluginInstance:
|
|||||||
self.inst_db_tables = metadata.tables
|
self.inst_db_tables = metadata.tables
|
||||||
return self.inst_db_tables
|
return self.inst_db_tables
|
||||||
|
|
||||||
def load(self) -> bool:
|
async def load(self) -> bool:
|
||||||
if not self.loader:
|
if not self.loader:
|
||||||
try:
|
try:
|
||||||
self.loader = PluginLoader.find(self.type)
|
self.loader = PluginLoader.find(self.type)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
self.log.error(f"Failed to find loader for type {self.type}")
|
self.log.error(f"Failed to find loader for type {self.type}")
|
||||||
self.db_instance.enabled = False
|
await self.update_enabled(False)
|
||||||
return False
|
return False
|
||||||
if not self.client:
|
if not self.client:
|
||||||
self.client = Client.get(self.primary_user)
|
self.client = await Client.get(self.primary_user)
|
||||||
if not self.client:
|
if not self.client:
|
||||||
self.log.error(f"Failed to get client for user {self.primary_user}")
|
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||||
self.db_instance.enabled = False
|
await self.update_enabled(False)
|
||||||
return False
|
return False
|
||||||
if self.loader.meta.database:
|
if self.loader.meta.database:
|
||||||
self.enable_database()
|
self.enable_database()
|
||||||
@ -125,18 +140,18 @@ class PluginInstance:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def enable_webapp(self) -> None:
|
def enable_webapp(self) -> None:
|
||||||
self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id)
|
self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id)
|
||||||
|
|
||||||
def disable_webapp(self) -> None:
|
def disable_webapp(self) -> None:
|
||||||
self.webserver.remove_instance_webapp(self.id)
|
self.maubot.server.remove_instance_webapp(self.id)
|
||||||
self.inst_webapp = None
|
self.inst_webapp = None
|
||||||
self.inst_webapp_url = None
|
self.inst_webapp_url = None
|
||||||
|
|
||||||
def enable_database(self) -> None:
|
def enable_database(self) -> None:
|
||||||
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
|
db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
|
||||||
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
|
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
|
||||||
|
|
||||||
def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
if self.loader is not None:
|
if self.loader is not None:
|
||||||
self.loader.references.remove(self)
|
self.loader.references.remove(self)
|
||||||
if self.client is not None:
|
if self.client is not None:
|
||||||
@ -145,23 +160,23 @@ class PluginInstance:
|
|||||||
del self.cache[self.id]
|
del self.cache[self.id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
self.db_instance.delete()
|
await super().delete()
|
||||||
if self.inst_db:
|
if self.inst_db:
|
||||||
self.inst_db.dispose()
|
self.inst_db.dispose()
|
||||||
ZippedPluginLoader.trash(
|
ZippedPluginLoader.trash(
|
||||||
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
|
os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
|
||||||
reason="deleted",
|
reason="deleted",
|
||||||
)
|
)
|
||||||
if self.inst_webapp:
|
if self.inst_webapp:
|
||||||
self.disable_webapp()
|
self.disable_webapp()
|
||||||
|
|
||||||
def load_config(self) -> CommentedMap:
|
def load_config(self) -> CommentedMap:
|
||||||
return yaml.load(self.db_instance.config)
|
return yaml.load(self.config_str)
|
||||||
|
|
||||||
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
|
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
|
||||||
buf = io.StringIO()
|
buf = io.StringIO()
|
||||||
yaml.dump(data, buf)
|
yaml.dump(data, buf)
|
||||||
self.db_instance.config = buf.getvalue()
|
self.config_str = buf.getvalue()
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self.started:
|
if self.started:
|
||||||
@ -172,7 +187,7 @@ class PluginInstance:
|
|||||||
return
|
return
|
||||||
if not self.client or not self.loader:
|
if not self.client or not self.loader:
|
||||||
self.log.warning("Missing plugin instance dependencies, attempting to load...")
|
self.log.warning("Missing plugin instance dependencies, attempting to load...")
|
||||||
if not self.load():
|
if not await self.load():
|
||||||
return
|
return
|
||||||
cls = await self.loader.load()
|
cls = await self.loader.load()
|
||||||
if self.loader.meta.webapp and self.inst_webapp is None:
|
if self.loader.meta.webapp and self.inst_webapp is None:
|
||||||
@ -205,7 +220,7 @@ class PluginInstance:
|
|||||||
self.config = config_class(self.load_config, base_cfg_func, self.save_config)
|
self.config = config_class(self.load_config, base_cfg_func, self.save_config)
|
||||||
self.plugin = cls(
|
self.plugin = cls(
|
||||||
client=self.client.client,
|
client=self.client.client,
|
||||||
loop=self.loop,
|
loop=self.maubot.loop,
|
||||||
http=self.client.http_client,
|
http=self.client.http_client,
|
||||||
instance_id=self.id,
|
instance_id=self.id,
|
||||||
log=self.log,
|
log=self.log,
|
||||||
@ -219,7 +234,7 @@ class PluginInstance:
|
|||||||
await self.plugin.internal_start()
|
await self.plugin.internal_start()
|
||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Failed to start instance")
|
self.log.exception("Failed to start instance")
|
||||||
self.db_instance.enabled = False
|
await self.update_enabled(False)
|
||||||
return
|
return
|
||||||
self.started = True
|
self.started = True
|
||||||
self.inst_db_tables = None
|
self.inst_db_tables = None
|
||||||
@ -241,60 +256,51 @@ class PluginInstance:
|
|||||||
self.plugin = None
|
self.plugin = None
|
||||||
self.inst_db_tables = None
|
self.inst_db_tables = None
|
||||||
|
|
||||||
@classmethod
|
async def update_id(self, new_id: str | None) -> None:
|
||||||
def get(cls, instance_id: str, db_instance: DBPlugin | None = None) -> PluginInstance | None:
|
if new_id is not None and new_id.lower() != self.id:
|
||||||
try:
|
await super().update_id(new_id.lower())
|
||||||
return cls.cache[instance_id]
|
|
||||||
except KeyError:
|
|
||||||
db_instance = db_instance or DBPlugin.get(instance_id)
|
|
||||||
if not db_instance:
|
|
||||||
return None
|
|
||||||
return PluginInstance(db_instance)
|
|
||||||
|
|
||||||
@classmethod
|
async def update_config(self, config: str | None) -> None:
|
||||||
def all(cls) -> Iterable[PluginInstance]:
|
if config is None or self.config_str == config:
|
||||||
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
|
|
||||||
|
|
||||||
def update_id(self, new_id: str) -> None:
|
|
||||||
if new_id is not None and new_id != self.id:
|
|
||||||
self.db_instance.id = new_id.lower()
|
|
||||||
|
|
||||||
def update_config(self, config: str) -> None:
|
|
||||||
if not config or self.db_instance.config == config:
|
|
||||||
return
|
return
|
||||||
self.db_instance.config = config
|
self.config_str = config
|
||||||
if self.started and self.plugin is not None:
|
if self.started and self.plugin is not None:
|
||||||
self.plugin.on_external_config_update()
|
res = self.plugin.on_external_config_update()
|
||||||
|
if inspect.isawaitable(res):
|
||||||
|
await res
|
||||||
|
await self.update()
|
||||||
|
|
||||||
async def update_primary_user(self, primary_user: UserID) -> bool:
|
async def update_primary_user(self, primary_user: UserID | None) -> bool:
|
||||||
if not primary_user or primary_user == self.primary_user:
|
if primary_user is None or primary_user == self.primary_user:
|
||||||
return True
|
return True
|
||||||
client = Client.get(primary_user)
|
client = await Client.get(primary_user)
|
||||||
if not client:
|
if not client:
|
||||||
return False
|
return False
|
||||||
await self.stop()
|
await self.stop()
|
||||||
self.db_instance.primary_user = client.id
|
self.primary_user = client.id
|
||||||
if self.client:
|
if self.client:
|
||||||
self.client.references.remove(self)
|
self.client.references.remove(self)
|
||||||
self.client = client
|
self.client = client
|
||||||
self.client.references.add(self)
|
self.client.references.add(self)
|
||||||
|
await self.update()
|
||||||
await self.start()
|
await self.start()
|
||||||
self.log.debug(f"Primary user switched to {self.client.id}")
|
self.log.debug(f"Primary user switched to {self.client.id}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def update_type(self, type: str) -> bool:
|
async def update_type(self, type: str | None) -> bool:
|
||||||
if not type or type == self.type:
|
if type is None or type == self.type:
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
loader = PluginLoader.find(type)
|
loader = PluginLoader.find(type)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return False
|
return False
|
||||||
await self.stop()
|
await self.stop()
|
||||||
self.db_instance.type = loader.meta.id
|
self.type = loader.meta.id
|
||||||
if self.loader:
|
if self.loader:
|
||||||
self.loader.references.remove(self)
|
self.loader.references.remove(self)
|
||||||
self.loader = loader
|
self.loader = loader
|
||||||
self.loader.references.add(self)
|
self.loader.references.add(self)
|
||||||
|
await self.update()
|
||||||
await self.start()
|
await self.start()
|
||||||
self.log.debug(f"Type switched to {self.loader.meta.id}")
|
self.log.debug(f"Type switched to {self.loader.meta.id}")
|
||||||
return True
|
return True
|
||||||
@ -303,39 +309,41 @@ class PluginInstance:
|
|||||||
if started is not None and started != self.started:
|
if started is not None and started != self.started:
|
||||||
await (self.start() if started else self.stop())
|
await (self.start() if started else self.stop())
|
||||||
|
|
||||||
def update_enabled(self, enabled: bool) -> None:
|
async def update_enabled(self, enabled: bool) -> None:
|
||||||
if enabled is not None and enabled != self.enabled:
|
if enabled is not None and enabled != self.enabled:
|
||||||
self.db_instance.enabled = enabled
|
self.enabled = enabled
|
||||||
|
await self.update()
|
||||||
|
|
||||||
# region Properties
|
@classmethod
|
||||||
|
@async_getter_lock
|
||||||
|
async def get(
|
||||||
|
cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None
|
||||||
|
) -> PluginInstance | None:
|
||||||
|
try:
|
||||||
|
return cls.cache[instance_id]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
@property
|
instance = cast(cls, await super().get(instance_id))
|
||||||
def id(self) -> str:
|
if instance is not None:
|
||||||
return self.db_instance.id
|
instance.postinit()
|
||||||
|
return instance
|
||||||
|
|
||||||
@id.setter
|
if type and primary_user:
|
||||||
def id(self, value: str) -> None:
|
instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user)
|
||||||
self.db_instance.id = value
|
await instance.insert()
|
||||||
|
instance.postinit()
|
||||||
|
return instance
|
||||||
|
|
||||||
@property
|
return None
|
||||||
def type(self) -> str:
|
|
||||||
return self.db_instance.type
|
|
||||||
|
|
||||||
@property
|
@classmethod
|
||||||
def enabled(self) -> bool:
|
async def all(cls) -> AsyncGenerator[PluginInstance, None]:
|
||||||
return self.db_instance.enabled
|
instances = await super().all()
|
||||||
|
instance: PluginInstance
|
||||||
@property
|
for instance in instances:
|
||||||
def primary_user(self) -> UserID:
|
try:
|
||||||
return self.db_instance.primary_user
|
yield cls.cache[instance.id]
|
||||||
|
except KeyError:
|
||||||
# endregion
|
instance.postinit()
|
||||||
|
yield instance
|
||||||
|
|
||||||
def init(
|
|
||||||
config: Config, webserver: MaubotServer, loop: AbstractEventLoop
|
|
||||||
) -> Iterable[PluginInstance]:
|
|
||||||
PluginInstance.mb_config = config
|
|
||||||
PluginInstance.loop = loop
|
|
||||||
PluginInstance.webserver = webserver
|
|
||||||
return PluginInstance.all()
|
|
||||||
|
@ -28,14 +28,19 @@ LOADER_COLOR = PREFIX + "36m" # blue
|
|||||||
class ColorFormatter(BaseColorFormatter):
|
class ColorFormatter(BaseColorFormatter):
|
||||||
def _color_name(self, module: str) -> str:
|
def _color_name(self, module: str) -> str:
|
||||||
client = "maubot.client"
|
client = "maubot.client"
|
||||||
if module.startswith(client):
|
if module.startswith(client + "."):
|
||||||
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}"
|
suffix = ""
|
||||||
|
if module.endswith(".crypto"):
|
||||||
|
suffix = f".{MAU_COLOR}crypto{RESET}"
|
||||||
|
module = module[: -len(".crypto")]
|
||||||
|
module = module[len(client) + 1 :]
|
||||||
|
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module}{RESET}{suffix}"
|
||||||
instance = "maubot.instance"
|
instance = "maubot.instance"
|
||||||
if module.startswith(instance):
|
if module.startswith(instance + "."):
|
||||||
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
|
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
|
||||||
loader = "maubot.loader"
|
loader = "maubot.loader"
|
||||||
if module.startswith(loader):
|
if module.startswith(loader + "."):
|
||||||
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
|
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
|
||||||
if module.startswith("maubot"):
|
if module.startswith("maubot."):
|
||||||
return f"{MAU_COLOR}{module}{RESET}"
|
return f"{MAU_COLOR}{module}{RESET}"
|
||||||
return super()._color_name(module)
|
return super()._color_name(module)
|
||||||
|
@ -13,16 +13,15 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from mautrix.client import SyncStore
|
from mautrix.client.state_store.asyncpg import PgStateStore as BasePgStateStore
|
||||||
from mautrix.types import SyncToken
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mautrix.crypto import StateStore as CryptoStateStore
|
||||||
|
|
||||||
class SyncStoreProxy(SyncStore):
|
class PgStateStore(BasePgStateStore, CryptoStateStore):
|
||||||
def __init__(self, db_instance) -> None:
|
pass
|
||||||
self.db_instance = db_instance
|
|
||||||
|
|
||||||
async def put_next_batch(self, next_batch: SyncToken) -> None:
|
except ImportError as e:
|
||||||
self.db_instance.edit(next_batch=next_batch)
|
PgStateStore = BasePgStateStore
|
||||||
|
|
||||||
async def get_next_batch(self) -> SyncToken:
|
__all__ = ["PgStateStore"]
|
||||||
return self.db_instance.next_batch
|
|
@ -13,17 +13,14 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import TYPE_CHECKING, Dict, List, Set, Type, TypeVar
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, TypeVar
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from attr import dataclass
|
|
||||||
from packaging.version import InvalidVersion, Version
|
|
||||||
|
|
||||||
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
|
|
||||||
|
|
||||||
from ..__meta__ import __version__
|
|
||||||
from ..plugin_base import Plugin
|
from ..plugin_base import Plugin
|
||||||
|
from .meta import PluginMeta
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..instance import PluginInstance
|
from ..instance import PluginInstance
|
||||||
@ -35,36 +32,6 @@ class IDConflictError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@serializer(Version)
|
|
||||||
def serialize_version(version: Version) -> str:
|
|
||||||
return str(version)
|
|
||||||
|
|
||||||
|
|
||||||
@deserializer(Version)
|
|
||||||
def deserialize_version(version: str) -> Version:
|
|
||||||
try:
|
|
||||||
return Version(version)
|
|
||||||
except InvalidVersion as e:
|
|
||||||
raise SerializerError("Invalid version") from e
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PluginMeta(SerializableAttrs):
|
|
||||||
id: str
|
|
||||||
version: Version
|
|
||||||
modules: List[str]
|
|
||||||
main_class: str
|
|
||||||
|
|
||||||
maubot: Version = Version(__version__)
|
|
||||||
database: bool = False
|
|
||||||
config: bool = False
|
|
||||||
webapp: bool = False
|
|
||||||
license: str = ""
|
|
||||||
extra_files: List[str] = []
|
|
||||||
dependencies: List[str] = []
|
|
||||||
soft_dependencies: List[str] = []
|
|
||||||
|
|
||||||
|
|
||||||
class BasePluginLoader(ABC):
|
class BasePluginLoader(ABC):
|
||||||
meta: PluginMeta
|
meta: PluginMeta
|
||||||
|
|
||||||
@ -80,25 +47,25 @@ class BasePluginLoader(ABC):
|
|||||||
async def read_file(self, path: str) -> bytes:
|
async def read_file(self, path: str) -> bytes:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sync_list_files(self, directory: str) -> List[str]:
|
def sync_list_files(self, directory: str) -> list[str]:
|
||||||
raise NotImplementedError("This loader doesn't support synchronous operations")
|
raise NotImplementedError("This loader doesn't support synchronous operations")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def list_files(self, directory: str) -> List[str]:
|
async def list_files(self, directory: str) -> list[str]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PluginLoader(BasePluginLoader, ABC):
|
class PluginLoader(BasePluginLoader, ABC):
|
||||||
id_cache: Dict[str, "PluginLoader"] = {}
|
id_cache: dict[str, PluginLoader] = {}
|
||||||
|
|
||||||
meta: PluginMeta
|
meta: PluginMeta
|
||||||
references: Set["PluginInstance"]
|
references: set[PluginInstance]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.references = set()
|
self.references = set()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find(cls, plugin_id: str) -> "PluginLoader":
|
def find(cls, plugin_id: str) -> PluginLoader:
|
||||||
return cls.id_cache[plugin_id]
|
return cls.id_cache[plugin_id]
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
@ -119,11 +86,11 @@ class PluginLoader(BasePluginLoader, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def load(self) -> Type[PluginClass]:
|
async def load(self) -> type[PluginClass]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def reload(self) -> Type[PluginClass]:
|
async def reload(self) -> type[PluginClass]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
53
maubot/loader/meta.py
Normal file
53
maubot/loader/meta.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# maubot - A plugin-based Matrix bot system.
|
||||||
|
# Copyright (C) 2022 Tulir Asokan
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from attr import dataclass
|
||||||
|
from packaging.version import InvalidVersion, Version
|
||||||
|
|
||||||
|
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
|
||||||
|
|
||||||
|
from ..__meta__ import __version__
|
||||||
|
|
||||||
|
|
||||||
|
@serializer(Version)
|
||||||
|
def serialize_version(version: Version) -> str:
|
||||||
|
return str(version)
|
||||||
|
|
||||||
|
|
||||||
|
@deserializer(Version)
|
||||||
|
def deserialize_version(version: str) -> Version:
|
||||||
|
try:
|
||||||
|
return Version(version)
|
||||||
|
except InvalidVersion as e:
|
||||||
|
raise SerializerError("Invalid version") from e
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PluginMeta(SerializableAttrs):
|
||||||
|
id: str
|
||||||
|
version: Version
|
||||||
|
modules: List[str]
|
||||||
|
main_class: str
|
||||||
|
|
||||||
|
maubot: Version = Version(__version__)
|
||||||
|
database: bool = False
|
||||||
|
config: bool = False
|
||||||
|
webapp: bool = False
|
||||||
|
license: str = ""
|
||||||
|
extra_files: List[str] = []
|
||||||
|
dependencies: List[str] = []
|
||||||
|
soft_dependencies: List[str] = []
|
@ -29,7 +29,8 @@ from mautrix.types import SerializerError
|
|||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..lib.zipimport import ZipImportError, zipimporter
|
from ..lib.zipimport import ZipImportError, zipimporter
|
||||||
from ..plugin_base import Plugin
|
from ..plugin_base import Plugin
|
||||||
from .abc import IDConflictError, PluginClass, PluginLoader, PluginMeta
|
from .abc import IDConflictError, PluginClass, PluginLoader
|
||||||
|
from .meta import PluginMeta
|
||||||
|
|
||||||
yaml = YAML()
|
yaml = YAML()
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ from aiohttp import web
|
|||||||
|
|
||||||
from ...config import Config
|
from ...config import Config
|
||||||
from .auth import check_token
|
from .auth import check_token
|
||||||
from .base import get_config, routes, set_config, set_loop
|
from .base import get_config, routes, set_config
|
||||||
from .middleware import auth, error
|
from .middleware import auth, error
|
||||||
|
|
||||||
|
|
||||||
@ -40,7 +40,6 @@ def features(request: web.Request) -> web.Response:
|
|||||||
|
|
||||||
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
|
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
|
||||||
set_config(cfg)
|
set_config(cfg)
|
||||||
set_loop(loop)
|
|
||||||
for pkg, enabled in cfg["api_features"].items():
|
for pkg, enabled in cfg["api_features"].items():
|
||||||
if enabled:
|
if enabled:
|
||||||
importlib.import_module(f"maubot.management.api.{pkg}")
|
importlib.import_module(f"maubot.management.api.{pkg}")
|
||||||
|
@ -46,7 +46,7 @@ def create_token(user: UserID) -> str:
|
|||||||
def get_token(request: web.Request) -> str:
|
def get_token(request: web.Request) -> str:
|
||||||
token = request.headers.get("Authorization", "")
|
token = request.headers.get("Authorization", "")
|
||||||
if not token or not token.startswith("Bearer "):
|
if not token or not token.startswith("Bearer "):
|
||||||
token = request.query.get("access_token", None)
|
token = request.query.get("access_token", "")
|
||||||
else:
|
else:
|
||||||
token = token[len("Bearer ") :]
|
token = token[len("Bearer ") :]
|
||||||
return token
|
return token
|
||||||
|
@ -24,7 +24,6 @@ from ...config import Config
|
|||||||
|
|
||||||
routes: web.RouteTableDef = web.RouteTableDef()
|
routes: web.RouteTableDef = web.RouteTableDef()
|
||||||
_config: Config | None = None
|
_config: Config | None = None
|
||||||
_loop: asyncio.AbstractEventLoop | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def set_config(config: Config) -> None:
|
def set_config(config: Config) -> None:
|
||||||
@ -36,15 +35,6 @@ def get_config() -> Config:
|
|||||||
return _config
|
return _config
|
||||||
|
|
||||||
|
|
||||||
def set_loop(loop: asyncio.AbstractEventLoop) -> None:
|
|
||||||
global _loop
|
|
||||||
_loop = loop
|
|
||||||
|
|
||||||
|
|
||||||
def get_loop() -> asyncio.AbstractEventLoop:
|
|
||||||
return _loop
|
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/version")
|
@routes.get("/version")
|
||||||
async def version(_: web.Request) -> web.Response:
|
async def version(_: web.Request) -> web.Response:
|
||||||
return web.json_response({"version": __version__})
|
return web.json_response({"version": __version__})
|
||||||
|
@ -24,7 +24,6 @@ from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequ
|
|||||||
from mautrix.types import FilterID, SyncToken, UserID
|
from mautrix.types import FilterID, SyncToken, UserID
|
||||||
|
|
||||||
from ...client import Client
|
from ...client import Client
|
||||||
from ...db import DBClient
|
|
||||||
from .base import routes
|
from .base import routes
|
||||||
from .responses import resp
|
from .responses import resp
|
||||||
|
|
||||||
@ -37,7 +36,7 @@ async def get_clients(_: web.Request) -> web.Response:
|
|||||||
@routes.get("/client/{id}")
|
@routes.get("/client/{id}")
|
||||||
async def get_client(request: web.Request) -> web.Response:
|
async def get_client(request: web.Request) -> web.Response:
|
||||||
user_id = request.match_info.get("id", None)
|
user_id = request.match_info.get("id", None)
|
||||||
client = Client.get(user_id, None)
|
client = await Client.get(user_id)
|
||||||
if not client:
|
if not client:
|
||||||
return resp.client_not_found
|
return resp.client_not_found
|
||||||
return resp.found(client.to_dict())
|
return resp.found(client.to_dict())
|
||||||
@ -51,7 +50,6 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
|
|||||||
mxid="@not:a.mxid",
|
mxid="@not:a.mxid",
|
||||||
base_url=homeserver,
|
base_url=homeserver,
|
||||||
token=access_token,
|
token=access_token,
|
||||||
loop=Client.loop,
|
|
||||||
client_session=Client.http_client,
|
client_session=Client.http_client,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
@ -63,29 +61,23 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
|
|||||||
except MatrixConnectionError:
|
except MatrixConnectionError:
|
||||||
return resp.bad_client_connection_details
|
return resp.bad_client_connection_details
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
existing_client = Client.get(whoami.user_id, None)
|
existing_client = await Client.get(whoami.user_id)
|
||||||
if existing_client is not None:
|
if existing_client is not None:
|
||||||
return resp.user_exists
|
return resp.user_exists
|
||||||
elif whoami.user_id != user_id:
|
elif whoami.user_id != user_id:
|
||||||
return resp.mxid_mismatch(whoami.user_id)
|
return resp.mxid_mismatch(whoami.user_id)
|
||||||
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
||||||
return resp.device_id_mismatch(whoami.device_id)
|
return resp.device_id_mismatch(whoami.device_id)
|
||||||
db_instance = DBClient(
|
client = await Client.get(
|
||||||
id=whoami.user_id,
|
whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id
|
||||||
homeserver=homeserver,
|
|
||||||
access_token=access_token,
|
|
||||||
enabled=data.get("enabled", True),
|
|
||||||
next_batch=SyncToken(""),
|
|
||||||
filter_id=FilterID(""),
|
|
||||||
sync=data.get("sync", True),
|
|
||||||
autojoin=data.get("autojoin", True),
|
|
||||||
online=data.get("online", True),
|
|
||||||
displayname=data.get("displayname", "disable"),
|
|
||||||
avatar_url=data.get("avatar_url", "disable"),
|
|
||||||
device_id=device_id,
|
|
||||||
)
|
)
|
||||||
client = Client(db_instance)
|
client.enabled = data.get("enabled", True)
|
||||||
client.db_instance.insert()
|
client.sync = data.get("sync", True)
|
||||||
|
client.autojoin = data.get("autojoin", True)
|
||||||
|
client.online = data.get("online", True)
|
||||||
|
client.displayname = data.get("displayname", "disable")
|
||||||
|
client.avatar_url = data.get("avatar_url", "disable")
|
||||||
|
await client.update()
|
||||||
await client.start()
|
await client.start()
|
||||||
return resp.created(client.to_dict())
|
return resp.created(client.to_dict())
|
||||||
|
|
||||||
@ -93,9 +85,7 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
|
|||||||
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
|
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
|
||||||
try:
|
try:
|
||||||
await client.update_access_details(
|
await client.update_access_details(
|
||||||
data.get("access_token", None),
|
data.get("access_token"), data.get("homeserver"), data.get("device_id")
|
||||||
data.get("homeserver", None),
|
|
||||||
data.get("device_id", None),
|
|
||||||
)
|
)
|
||||||
except MatrixInvalidToken:
|
except MatrixInvalidToken:
|
||||||
return resp.bad_client_access_token
|
return resp.bad_client_access_token
|
||||||
@ -109,21 +99,21 @@ async def _update_client(client: Client, data: dict, is_login: bool = False) ->
|
|||||||
return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :])
|
return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :])
|
||||||
elif str_err.startswith("Device ID mismatch"):
|
elif str_err.startswith("Device ID mismatch"):
|
||||||
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :])
|
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :])
|
||||||
with client.db_instance.edit_mode():
|
await client.update_avatar_url(data.get("avatar_url"), save=False)
|
||||||
await client.update_avatar_url(data.get("avatar_url", None))
|
await client.update_displayname(data.get("displayname"), save=False)
|
||||||
await client.update_displayname(data.get("displayname", None))
|
await client.update_started(data.get("started"))
|
||||||
await client.update_started(data.get("started", None))
|
await client.update_enabled(data.get("enabled"), save=False)
|
||||||
client.enabled = data.get("enabled", client.enabled)
|
await client.update_autojoin(data.get("autojoin"), save=False)
|
||||||
client.autojoin = data.get("autojoin", client.autojoin)
|
await client.update_online(data.get("online"), save=False)
|
||||||
client.online = data.get("online", client.online)
|
await client.update_sync(data.get("sync"), save=False)
|
||||||
client.sync = data.get("sync", client.sync)
|
await client.update()
|
||||||
return resp.updated(client.to_dict(), is_login=is_login)
|
return resp.updated(client.to_dict(), is_login=is_login)
|
||||||
|
|
||||||
|
|
||||||
async def _create_or_update_client(
|
async def _create_or_update_client(
|
||||||
user_id: UserID, data: dict, is_login: bool = False
|
user_id: UserID, data: dict, is_login: bool = False
|
||||||
) -> web.Response:
|
) -> web.Response:
|
||||||
client = Client.get(user_id, None)
|
client = await Client.get(user_id)
|
||||||
if not client:
|
if not client:
|
||||||
return await _create_client(user_id, data)
|
return await _create_client(user_id, data)
|
||||||
else:
|
else:
|
||||||
@ -141,7 +131,7 @@ async def create_client(request: web.Request) -> web.Response:
|
|||||||
|
|
||||||
@routes.put("/client/{id}")
|
@routes.put("/client/{id}")
|
||||||
async def update_client(request: web.Request) -> web.Response:
|
async def update_client(request: web.Request) -> web.Response:
|
||||||
user_id = request.match_info.get("id", None)
|
user_id = request.match_info["id"]
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
@ -151,23 +141,23 @@ async def update_client(request: web.Request) -> web.Response:
|
|||||||
|
|
||||||
@routes.delete("/client/{id}")
|
@routes.delete("/client/{id}")
|
||||||
async def delete_client(request: web.Request) -> web.Response:
|
async def delete_client(request: web.Request) -> web.Response:
|
||||||
user_id = request.match_info.get("id", None)
|
user_id = request.match_info["id"]
|
||||||
client = Client.get(user_id, None)
|
client = await Client.get(user_id)
|
||||||
if not client:
|
if not client:
|
||||||
return resp.client_not_found
|
return resp.client_not_found
|
||||||
if len(client.references) > 0:
|
if len(client.references) > 0:
|
||||||
return resp.client_in_use
|
return resp.client_in_use
|
||||||
if client.started:
|
if client.started:
|
||||||
await client.stop()
|
await client.stop()
|
||||||
client.delete()
|
await client.delete()
|
||||||
return resp.deleted
|
return resp.deleted
|
||||||
|
|
||||||
|
|
||||||
@routes.post("/client/{id}/clearcache")
|
@routes.post("/client/{id}/clearcache")
|
||||||
async def clear_client_cache(request: web.Request) -> web.Response:
|
async def clear_client_cache(request: web.Request) -> web.Response:
|
||||||
user_id = request.match_info.get("id", None)
|
user_id = request.match_info["id"]
|
||||||
client = Client.get(user_id, None)
|
client = await Client.get(user_id)
|
||||||
if not client:
|
if not client:
|
||||||
return resp.client_not_found
|
return resp.client_not_found
|
||||||
client.clear_cache()
|
await client.clear_cache()
|
||||||
return resp.ok
|
return resp.ok
|
||||||
|
@ -13,7 +13,9 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Dict, NamedTuple, Optional, Tuple
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import NamedTuple
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -30,12 +32,12 @@ from mautrix.client import ClientAPI
|
|||||||
from mautrix.errors import MatrixRequestError
|
from mautrix.errors import MatrixRequestError
|
||||||
from mautrix.types import LoginResponse, LoginType
|
from mautrix.types import LoginResponse, LoginType
|
||||||
|
|
||||||
from .base import get_config, get_loop, routes
|
from .base import get_config, routes
|
||||||
from .client import _create_client, _create_or_update_client
|
from .client import _create_client, _create_or_update_client
|
||||||
from .responses import resp
|
from .responses import resp
|
||||||
|
|
||||||
|
|
||||||
def known_homeservers() -> Dict[str, Dict[str, str]]:
|
def known_homeservers() -> dict[str, dict[str, str]]:
|
||||||
return get_config()["homeservers"]
|
return get_config()["homeservers"]
|
||||||
|
|
||||||
|
|
||||||
@ -61,7 +63,7 @@ truthy_strings = ("1", "true", "yes")
|
|||||||
|
|
||||||
async def read_client_auth_request(
|
async def read_client_auth_request(
|
||||||
request: web.Request,
|
request: web.Request,
|
||||||
) -> Tuple[Optional[AuthRequestInfo], Optional[web.Response]]:
|
) -> tuple[AuthRequestInfo | None, web.Response | None]:
|
||||||
server_name = request.match_info.get("server", None)
|
server_name = request.match_info.get("server", None)
|
||||||
server = known_homeservers().get(server_name, None)
|
server = known_homeservers().get(server_name, None)
|
||||||
if not server:
|
if not server:
|
||||||
@ -85,7 +87,7 @@ async def read_client_auth_request(
|
|||||||
return (
|
return (
|
||||||
AuthRequestInfo(
|
AuthRequestInfo(
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
client=ClientAPI(base_url=base_url, loop=get_loop()),
|
client=ClientAPI(base_url=base_url),
|
||||||
secret=server.get("secret"),
|
secret=server.get("secret"),
|
||||||
username=username,
|
username=username,
|
||||||
password=password,
|
password=password,
|
||||||
@ -189,11 +191,11 @@ async def _do_sso(req: AuthRequestInfo) -> web.Response:
|
|||||||
sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query(
|
sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query(
|
||||||
{"redirectUrl": str(public_url)}
|
{"redirectUrl": str(public_url)}
|
||||||
)
|
)
|
||||||
sso_waiters[waiter_id] = req, get_loop().create_future()
|
sso_waiters[waiter_id] = req, asyncio.get_running_loop().create_future()
|
||||||
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
|
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
|
||||||
|
|
||||||
|
|
||||||
async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response:
|
async def _do_login(req: AuthRequestInfo, login_token: str | None = None) -> web.Response:
|
||||||
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
|
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
|
||||||
device_id = f"maubot_{device_id}"
|
device_id = f"maubot_{device_id}"
|
||||||
try:
|
try:
|
||||||
@ -235,7 +237,7 @@ async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) ->
|
|||||||
return web.json_response(res.serialize())
|
return web.json_response(res.serialize())
|
||||||
|
|
||||||
|
|
||||||
sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {}
|
sso_waiters: dict[str, tuple[AuthRequestInfo, asyncio.Future]] = {}
|
||||||
|
|
||||||
|
|
||||||
@routes.post("/client/auth/{server}/sso/{id}/wait")
|
@routes.post("/client/auth/{server}/sso/{id}/wait")
|
||||||
|
@ -25,7 +25,7 @@ PROXY_CHUNK_SIZE = 32 * 1024
|
|||||||
@routes.view("/proxy/{id}/{path:_matrix/.+}")
|
@routes.view("/proxy/{id}/{path:_matrix/.+}")
|
||||||
async def proxy(request: web.Request) -> web.StreamResponse:
|
async def proxy(request: web.Request) -> web.StreamResponse:
|
||||||
user_id = request.match_info.get("id", None)
|
user_id = request.match_info.get("id", None)
|
||||||
client = Client.get(user_id, None)
|
client = await Client.get(user_id)
|
||||||
if not client:
|
if not client:
|
||||||
return resp.client_not_found
|
return resp.client_not_found
|
||||||
|
|
||||||
|
@ -18,7 +18,6 @@ from json import JSONDecodeError
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from ...client import Client
|
from ...client import Client
|
||||||
from ...db import DBPlugin
|
|
||||||
from ...instance import PluginInstance
|
from ...instance import PluginInstance
|
||||||
from ...loader import PluginLoader
|
from ...loader import PluginLoader
|
||||||
from .base import routes
|
from .base import routes
|
||||||
@ -32,56 +31,49 @@ async def get_instances(_: web.Request) -> web.Response:
|
|||||||
|
|
||||||
@routes.get("/instance/{id}")
|
@routes.get("/instance/{id}")
|
||||||
async def get_instance(request: web.Request) -> web.Response:
|
async def get_instance(request: web.Request) -> web.Response:
|
||||||
instance_id = request.match_info.get("id", "").lower()
|
instance_id = request.match_info["id"].lower()
|
||||||
instance = PluginInstance.get(instance_id, None)
|
instance = await PluginInstance.get(instance_id)
|
||||||
if not instance:
|
if not instance:
|
||||||
return resp.instance_not_found
|
return resp.instance_not_found
|
||||||
return resp.found(instance.to_dict())
|
return resp.found(instance.to_dict())
|
||||||
|
|
||||||
|
|
||||||
async def _create_instance(instance_id: str, data: dict) -> web.Response:
|
async def _create_instance(instance_id: str, data: dict) -> web.Response:
|
||||||
plugin_type = data.get("type", None)
|
plugin_type = data.get("type")
|
||||||
primary_user = data.get("primary_user", None)
|
primary_user = data.get("primary_user")
|
||||||
if not plugin_type:
|
if not plugin_type:
|
||||||
return resp.plugin_type_required
|
return resp.plugin_type_required
|
||||||
elif not primary_user:
|
elif not primary_user:
|
||||||
return resp.primary_user_required
|
return resp.primary_user_required
|
||||||
elif not Client.get(primary_user):
|
elif not await Client.get(primary_user):
|
||||||
return resp.primary_user_not_found
|
return resp.primary_user_not_found
|
||||||
try:
|
try:
|
||||||
PluginLoader.find(plugin_type)
|
PluginLoader.find(plugin_type)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return resp.plugin_type_not_found
|
return resp.plugin_type_not_found
|
||||||
db_instance = DBPlugin(
|
instance = await PluginInstance.get(instance_id, type=plugin_type, primary_user=primary_user)
|
||||||
id=instance_id,
|
instance.enabled = data.get("enabled", True)
|
||||||
type=plugin_type,
|
instance.config_str = data.get("config") or ""
|
||||||
enabled=data.get("enabled", True),
|
await instance.update()
|
||||||
primary_user=primary_user,
|
|
||||||
config=data.get("config", ""),
|
|
||||||
)
|
|
||||||
instance = PluginInstance(db_instance)
|
|
||||||
instance.load()
|
|
||||||
instance.db_instance.insert()
|
|
||||||
await instance.start()
|
await instance.start()
|
||||||
return resp.created(instance.to_dict())
|
return resp.created(instance.to_dict())
|
||||||
|
|
||||||
|
|
||||||
async def _update_instance(instance: PluginInstance, data: dict) -> web.Response:
|
async def _update_instance(instance: PluginInstance, data: dict) -> web.Response:
|
||||||
if not await instance.update_primary_user(data.get("primary_user", None)):
|
if not await instance.update_primary_user(data.get("primary_user")):
|
||||||
return resp.primary_user_not_found
|
return resp.primary_user_not_found
|
||||||
with instance.db_instance.edit_mode():
|
await instance.update_id(data.get("id"))
|
||||||
instance.update_id(data.get("id", None))
|
await instance.update_enabled(data.get("enabled"))
|
||||||
instance.update_enabled(data.get("enabled", None))
|
await instance.update_config(data.get("config"))
|
||||||
instance.update_config(data.get("config", None))
|
await instance.update_started(data.get("started"))
|
||||||
await instance.update_started(data.get("started", None))
|
await instance.update_type(data.get("type"))
|
||||||
await instance.update_type(data.get("type", None))
|
|
||||||
return resp.updated(instance.to_dict())
|
return resp.updated(instance.to_dict())
|
||||||
|
|
||||||
|
|
||||||
@routes.put("/instance/{id}")
|
@routes.put("/instance/{id}")
|
||||||
async def update_instance(request: web.Request) -> web.Response:
|
async def update_instance(request: web.Request) -> web.Response:
|
||||||
instance_id = request.match_info.get("id", "").lower()
|
instance_id = request.match_info["id"].lower()
|
||||||
instance = PluginInstance.get(instance_id, None)
|
instance = await PluginInstance.get(instance_id)
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
@ -94,11 +86,11 @@ async def update_instance(request: web.Request) -> web.Response:
|
|||||||
|
|
||||||
@routes.delete("/instance/{id}")
|
@routes.delete("/instance/{id}")
|
||||||
async def delete_instance(request: web.Request) -> web.Response:
|
async def delete_instance(request: web.Request) -> web.Response:
|
||||||
instance_id = request.match_info.get("id", "").lower()
|
instance_id = request.match_info["id"].lower()
|
||||||
instance = PluginInstance.get(instance_id)
|
instance = await PluginInstance.get(instance_id)
|
||||||
if not instance:
|
if not instance:
|
||||||
return resp.instance_not_found
|
return resp.instance_not_found
|
||||||
if instance.started:
|
if instance.started:
|
||||||
await instance.stop()
|
await instance.stop()
|
||||||
instance.delete()
|
await instance.delete()
|
||||||
return resp.deleted
|
return resp.deleted
|
||||||
|
@ -29,8 +29,8 @@ from .responses import resp
|
|||||||
|
|
||||||
@routes.get("/instance/{id}/database")
|
@routes.get("/instance/{id}/database")
|
||||||
async def get_database(request: web.Request) -> web.Response:
|
async def get_database(request: web.Request) -> web.Response:
|
||||||
instance_id = request.match_info.get("id", "")
|
instance_id = request.match_info["id"].lower()
|
||||||
instance = PluginInstance.get(instance_id, None)
|
instance = await PluginInstance.get(instance_id)
|
||||||
if not instance:
|
if not instance:
|
||||||
return resp.instance_not_found
|
return resp.instance_not_found
|
||||||
elif not instance.inst_db:
|
elif not instance.inst_db:
|
||||||
@ -65,8 +65,8 @@ def check_type(val):
|
|||||||
|
|
||||||
@routes.get("/instance/{id}/database/{table}")
|
@routes.get("/instance/{id}/database/{table}")
|
||||||
async def get_table(request: web.Request) -> web.Response:
|
async def get_table(request: web.Request) -> web.Response:
|
||||||
instance_id = request.match_info.get("id", "")
|
instance_id = request.match_info["id"].lower()
|
||||||
instance = PluginInstance.get(instance_id, None)
|
instance = await PluginInstance.get(instance_id)
|
||||||
if not instance:
|
if not instance:
|
||||||
return resp.instance_not_found
|
return resp.instance_not_found
|
||||||
elif not instance.inst_db:
|
elif not instance.inst_db:
|
||||||
@ -86,14 +86,14 @@ async def get_table(request: web.Request) -> web.Response:
|
|||||||
]
|
]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
order = []
|
order = []
|
||||||
limit = int(request.query.get("limit", 100))
|
limit = int(request.query.get("limit", "100"))
|
||||||
return execute_query(instance, table.select().order_by(*order).limit(limit))
|
return execute_query(instance, table.select().order_by(*order).limit(limit))
|
||||||
|
|
||||||
|
|
||||||
@routes.post("/instance/{id}/database/query")
|
@routes.post("/instance/{id}/database/query")
|
||||||
async def query(request: web.Request) -> web.Response:
|
async def query(request: web.Request) -> web.Response:
|
||||||
instance_id = request.match_info.get("id", "")
|
instance_id = request.match_info["id"].lower()
|
||||||
instance = PluginInstance.get(instance_id, None)
|
instance = await PluginInstance.get(instance_id)
|
||||||
if not instance:
|
if not instance:
|
||||||
return resp.instance_not_found
|
return resp.instance_not_found
|
||||||
elif not instance.inst_db:
|
elif not instance.inst_db:
|
||||||
|
@ -23,7 +23,7 @@ import logging
|
|||||||
from aiohttp import web, web_ws
|
from aiohttp import web, web_ws
|
||||||
|
|
||||||
from .auth import is_valid_token
|
from .auth import is_valid_token
|
||||||
from .base import get_loop, routes
|
from .base import routes
|
||||||
|
|
||||||
BUILTIN_ATTRS = {
|
BUILTIN_ATTRS = {
|
||||||
"args",
|
"args",
|
||||||
@ -138,12 +138,12 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse:
|
|||||||
authenticated = False
|
authenticated = False
|
||||||
|
|
||||||
async def close_if_not_authenticated():
|
async def close_if_not_authenticated():
|
||||||
await asyncio.sleep(5, loop=get_loop())
|
await asyncio.sleep(5)
|
||||||
if not authenticated:
|
if not authenticated:
|
||||||
await ws.close(code=4000)
|
await ws.close(code=4000)
|
||||||
log.debug(f"Connection from {request.remote} terminated due to no authentication")
|
log.debug(f"Connection from {request.remote} terminated due to no authentication")
|
||||||
|
|
||||||
asyncio.ensure_future(close_if_not_authenticated())
|
asyncio.create_task(close_if_not_authenticated())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
msg: web_ws.WSMessage
|
msg: web_ws.WSMessage
|
||||||
|
@ -29,8 +29,8 @@ async def get_plugins(_) -> web.Response:
|
|||||||
|
|
||||||
@routes.get("/plugin/{id}")
|
@routes.get("/plugin/{id}")
|
||||||
async def get_plugin(request: web.Request) -> web.Response:
|
async def get_plugin(request: web.Request) -> web.Response:
|
||||||
plugin_id = request.match_info.get("id", None)
|
plugin_id = request.match_info["id"]
|
||||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
plugin = PluginLoader.id_cache.get(plugin_id)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
return resp.plugin_not_found
|
return resp.plugin_not_found
|
||||||
return resp.found(plugin.to_dict())
|
return resp.found(plugin.to_dict())
|
||||||
@ -38,8 +38,8 @@ async def get_plugin(request: web.Request) -> web.Response:
|
|||||||
|
|
||||||
@routes.delete("/plugin/{id}")
|
@routes.delete("/plugin/{id}")
|
||||||
async def delete_plugin(request: web.Request) -> web.Response:
|
async def delete_plugin(request: web.Request) -> web.Response:
|
||||||
plugin_id = request.match_info.get("id", None)
|
plugin_id = request.match_info["id"]
|
||||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
plugin = PluginLoader.id_cache.get(plugin_id)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
return resp.plugin_not_found
|
return resp.plugin_not_found
|
||||||
elif len(plugin.references) > 0:
|
elif len(plugin.references) > 0:
|
||||||
@ -50,8 +50,8 @@ async def delete_plugin(request: web.Request) -> web.Response:
|
|||||||
|
|
||||||
@routes.post("/plugin/{id}/reload")
|
@routes.post("/plugin/{id}/reload")
|
||||||
async def reload_plugin(request: web.Request) -> web.Response:
|
async def reload_plugin(request: web.Request) -> web.Response:
|
||||||
plugin_id = request.match_info.get("id", None)
|
plugin_id = request.match_info["id"]
|
||||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
plugin = PluginLoader.id_cache.get(plugin_id)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
return resp.plugin_not_found
|
return resp.plugin_not_found
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ from .responses import resp
|
|||||||
|
|
||||||
@routes.put("/plugin/{id}")
|
@routes.put("/plugin/{id}")
|
||||||
async def put_plugin(request: web.Request) -> web.Response:
|
async def put_plugin(request: web.Request) -> web.Response:
|
||||||
plugin_id = request.match_info.get("id", None)
|
plugin_id = request.match_info["id"]
|
||||||
content = await request.read()
|
content = await request.read()
|
||||||
file = BytesIO(content)
|
file = BytesIO(content)
|
||||||
try:
|
try:
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Awaitable
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from asyncio import AbstractEventLoop
|
from asyncio import AbstractEventLoop
|
||||||
|
|
||||||
@ -124,6 +124,7 @@ class Plugin(ABC):
|
|||||||
def get_config_class(cls) -> type[BaseProxyConfig] | None:
|
def get_config_class(cls) -> type[BaseProxyConfig] | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_external_config_update(self) -> None:
|
def on_external_config_update(self) -> Awaitable[None] | None:
|
||||||
if self.config:
|
if self.config:
|
||||||
self.config.load_and_update()
|
self.config.load_and_update()
|
||||||
|
return None
|
||||||
|
@ -1,13 +1,10 @@
|
|||||||
# Format: #/name defines a new extras_require group called name
|
# Format: #/name defines a new extras_require group called name
|
||||||
# Uncommented lines after the group definition insert things into that group.
|
# Uncommented lines after the group definition insert things into that group.
|
||||||
|
|
||||||
#/postgres
|
#/sqlite
|
||||||
psycopg2-binary>=2,<3
|
aiosqlite>=0.16,<0.18
|
||||||
asyncpg>=0.20,<0.26
|
|
||||||
|
|
||||||
#/encryption
|
#/encryption
|
||||||
asyncpg>=0.20,<0.26
|
|
||||||
aiosqlite>=0.16,<0.18
|
|
||||||
python-olm>=3,<4
|
python-olm>=3,<4
|
||||||
pycryptodome>=3,<4
|
pycryptodome>=3,<4
|
||||||
unpaddedbase64>=1,<3
|
unpaddedbase64>=1,<3
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
mautrix>=0.15.0,<0.16
|
mautrix>=0.15.2,<0.16
|
||||||
aiohttp>=3,<4
|
aiohttp>=3,<4
|
||||||
yarl>=1,<2
|
yarl>=1,<2
|
||||||
SQLAlchemy>=1,<1.4
|
SQLAlchemy>=1,<1.4
|
||||||
|
asyncpg>=0.20,<0.26
|
||||||
alembic>=1,<2
|
alembic>=1,<2
|
||||||
commonmark>=0.9,<1
|
commonmark>=0.9,<1
|
||||||
ruamel.yaml>=0.15.35,<0.18
|
ruamel.yaml>=0.15.35,<0.18
|
||||||
|
5
setup.py
5
setup.py
@ -1,5 +1,4 @@
|
|||||||
import setuptools
|
import setuptools
|
||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
with open("requirements.txt") as reqs:
|
with open("requirements.txt") as reqs:
|
||||||
@ -57,9 +56,7 @@ setuptools.setup(
|
|||||||
mbc=maubot.cli:app
|
mbc=maubot.cli:app
|
||||||
""",
|
""",
|
||||||
data_files=[
|
data_files=[
|
||||||
(".", ["maubot/example-config.yaml", "alembic.ini"]),
|
(".", ["maubot/example-config.yaml"]),
|
||||||
("alembic", ["alembic/env.py"]),
|
|
||||||
("alembic/versions", glob.glob("alembic/versions/*.py")),
|
|
||||||
],
|
],
|
||||||
package_data={
|
package_data={
|
||||||
"maubot": [
|
"maubot": [
|
||||||
|
Loading…
Reference in New Issue
Block a user