Compare commits

..

1 Commits

Author SHA1 Message Date
Tulir Asokan
ab88568d08 Add db table for plugin files 2020-01-02 02:03:53 +02:00
153 changed files with 11561 additions and 11884 deletions

View File

@ -1,5 +0,0 @@
.editorconfig
.codeclimate.yml
*.png
.venv
maubot/management/frontend/node_modules

View File

@ -14,6 +14,3 @@ indent_size = 2
[spec.yaml] [spec.yaml]
indent_size = 2 indent_size = 2
[CHANGELOG.md]
max_line_length = 80

View File

@ -1,26 +0,0 @@
name: Python lint
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- uses: isort/isort-action@master
with:
sortPaths: "./maubot"
- uses: psf/black@stable
with:
src: "./maubot"
version: "24.2.0"
- name: pre-commit
run: |
pip install pre-commit
pre-commit run -av trailing-whitespace
pre-commit run -av end-of-file-fixer
pre-commit run -av check-yaml
pre-commit run -av check-added-large-files

5
.gitignore vendored
View File

@ -7,13 +7,10 @@ pip-selfcheck.json
*.pyc *.pyc
__pycache__ __pycache__
*.db* *.db
*.log
/*.yaml /*.yaml
!example-config.yaml !example-config.yaml
!.pre-commit-config.yaml
/start
logs/ logs/
plugins/ plugins/
trash/ trash/

View File

@ -1,29 +0,0 @@
image: dock.mau.dev/maubot/maubot
stages:
- build
variables:
PYTHONPATH: /opt/maubot
build:
stage: build
except:
- tags
script:
- python3 -m maubot.cli build -o xyz.maubot.$CI_PROJECT_NAME-$CI_COMMIT_REF_NAME-$CI_COMMIT_SHORT_SHA.mbp
artifacts:
paths:
- "*.mbp"
expire_in: 365 days
build tags:
stage: build
only:
- tags
script:
- python3 -m maubot.cli build -o xyz.maubot.$CI_PROJECT_NAME-$CI_COMMIT_TAG.mbp
artifacts:
paths:
- "*.mbp"
expire_in: never

View File

@ -1,75 +1,67 @@
image: docker:stable image: docker:stable
stages: stages:
- build frontend
- build - build
- manifest - push
default: default:
before_script: before_script:
- docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY - docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
build frontend: build:
image: node:20-alpine stage: build
stage: build frontend script:
before_script: [] - docker pull $CI_REGISTRY_IMAGE:latest || true
- docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA .
- docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA
push latest:
stage: push
only:
- master
variables: variables:
NODE_ENV: "production" GIT_STRATEGY: none
cache:
paths:
- maubot/management/frontend/node_modules
script: script:
- cd maubot/management/frontend - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA
- yarn --prod - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA $CI_REGISTRY_IMAGE:latest
- yarn build - docker push $CI_REGISTRY_IMAGE:latest
- mv build ../../../frontend
artifacts:
paths:
- frontend
expire_in: 1 hour
build amd64: push tag:
stage: push
variables:
GIT_STRATEGY: none
except:
- master
script:
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA
- docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME
- docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME
build standalone:
stage: build stage: build
tags:
- amd64
script:
- echo maubot/management/frontend >> .dockerignore
- docker pull $CI_REGISTRY_IMAGE:latest || true
- docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 . -f Dockerfile.ci
- docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64
- docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64
build arm64:
stage: build
tags:
- arm64
script:
- echo maubot/management/frontend >> .dockerignore
- docker pull $CI_REGISTRY_IMAGE:latest || true
- docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 . -f Dockerfile.ci
- docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64
- docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64
manifest:
stage: manifest
before_script:
- "mkdir -p $HOME/.docker && echo '{\"experimental\": \"enabled\"}' > $HOME/.docker/config.json"
- docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
script:
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64
- if [ "$CI_COMMIT_BRANCH" = "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:latest $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:latest; fi
- if [ "$CI_COMMIT_BRANCH" != "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME; fi
- docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64
build standalone amd64:
stage: build
tags:
- amd64
script: script:
- docker pull $CI_REGISTRY_IMAGE:standalone || true - docker pull $CI_REGISTRY_IMAGE:standalone || true
- docker build --pull --cache-from $CI_REGISTRY_IMAGE:standalone --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone . -f maubot/standalone/Dockerfile - docker build --pull --cache-from $CI_REGISTRY_IMAGE:standalone --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone . -f maubot/standalone/Dockerfile
- if [ "$CI_COMMIT_BRANCH" = "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:standalone && docker push $CI_REGISTRY_IMAGE:standalone; fi - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone
- if [ "$CI_COMMIT_BRANCH" != "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone && docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone; fi
- docker rmi $CI_REGISTRY_IMAGE:standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone || true push latest standalone:
stage: push
only:
- master
variables:
GIT_STRATEGY: none
script:
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone
- docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:standalone
- docker push $CI_REGISTRY_IMAGE:standalone
push tag standalone:
stage: push
variables:
GIT_STRATEGY: none
except:
- master
script:
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone
- docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone
- docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone

View File

@ -1,20 +0,0 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
language_version: python3
files: ^maubot/.*\.pyi?$
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
files: ^maubot/.*\.pyi?$

View File

@ -1,137 +0,0 @@
# v0.5.0 (2024-08-24)
* Dropped Python 3.9 support.
* Updated Docker image to Alpine 3.20.
* Updated mautrix-python to 0.20.6 to support authenticated media.
* Removed hard dependency on SQLAlchemy.
* Fixed `main_class` to default to being loaded from the last module instead of
the first if a module name is not explicitly specified.
* This was already the [documented behavior](https://docs.mau.fi/maubot/dev/reference/plugin-metadata.html),
and loading from the first module doesn't make sense due to import order.
* Added simple scheduler utility for running background tasks periodically or
after a certain delay.
* Added testing framework for plugins (thanks to [@abompard] in [#225]).
* Changed `mbc build` to ignore directories declared in `modules` that are
missing an `__init__.py` file.
* Importing the modules at runtime would fail and break the plugin.
To include non-code resources outside modules in the mbp archive,
use `extra_files` instead.
[#225]: https://github.com/maubot/maubot/issues/225
[@abompard]: https://github.com/abompard
# v0.4.2 (2023-09-20)
* Updated Pillow to 10.0.1.
* Updated Docker image to Alpine 3.18.
* Added logging for errors for /whoami errors when adding new bot accounts.
* Added support for using appservice tokens (including appservice encryption)
in standalone mode.
# v0.4.1 (2023-03-15)
* Added `in_thread` parameter to `evt.reply()` and `evt.respond()`.
* By default, responses will go to the thread if the command is in a thread.
* By setting the flag to `True` or `False`, the plugin can force the response
to either be or not be in a thread.
* Fixed static files like the frontend app manifest not being served correctly.
* Fixed `self.loader.meta` not being available to plugins in standalone mode.
* Updated to mautrix-python v0.19.6.
# v0.4.0 (2023-01-29)
* Dropped support for using a custom maubot API base path.
* The public URL can still have a path prefix, e.g. when using a reverse
proxy. Both the web interface and `mbc` CLI tool should work fine with
custom prefixes.
* Added `evt.redact()` as a shortcut for `self.client.redact(evt.room_id, evt.event_id)`.
* Fixed `mbc logs` command not working on Python 3.8+.
* Fixed saving plugin configs (broke in v0.3.0).
* Fixed SSO login using the wrong API path (probably broke in v0.3.0).
* Stopped using `cd` in the docker image's `mbc` wrapper to enable using
path-dependent commands like `mbc build` by mounting a directory.
* Updated Docker image to Alpine 3.17.
# v0.3.1 (2022-03-29)
* Added encryption dependencies to standalone dockerfile.
* Fixed running without encryption dependencies installed.
* Removed unnecessary imports that broke on SQLAlchemy 1.4+.
* Removed unused alembic dependency.
# v0.3.0 (2022-03-28)
* Dropped Python 3.7 support.
* Switched main maubot database to asyncpg/aiosqlite.
* Using the same SQLite database for crypto is now safe again.
* Added support for asyncpg/aiosqlite for plugin databases.
* There are some [basic docs](https://docs.mau.fi/maubot/dev/database/index.html)
and [a simple example](./examples/database) for the new system.
* The old SQLAlchemy system is now deprecated, but will be preserved for
backwards-compatibility until most plugins have updated.
* Started enforcing minimum maubot version in plugins.
* Trying to upload a plugin where the specified version is higher than the
running maubot version will fail.
* Fixed bug where uploading a plugin twice, deleting it and trying to upload
again would fail.
* Updated Docker image to Alpine 3.15.
* Formatted all code using [black](https://github.com/psf/black)
and [isort](https://github.com/PyCQA/isort).
# v0.2.1 (2021-11-22)
Docker-only release: added automatic moving of plugin databases from
`/data/plugins/*.db` to `/data/dbs`
# v0.2.0 (2021-11-20)
* Moved plugin databases from `/data/plugins` to `/data/dbs` in the docker image.
* v0.2.0 was missing the automatic migration of databases, it was added in v0.2.1.
* If you were using a custom path, you'll have to mount it at `/data/dbs` or
move the databases yourself.
* Removed support for pickle crypto store and added support for SQLite crypto store.
* **If you were previously using the dangerous pickle store for e2ee, you'll
have to re-login with the bots (which can now be done conveniently with
`mbc auth --update-client`).**
* Added SSO support to `mbc auth`.
* Added support for setting device ID for e2ee using the web interface.
* Added e2ee fingerprint field to the web interface.
* Added `--update-client` flag to store access token inside maubot instead of
returning it in `mbc auth`.
* This will also automatically store the device ID now.
* Updated standalone mode.
* Added e2ee and web server support.
* It's now officially supported and [somewhat documented](https://docs.mau.fi/maubot/usage/standalone.html).
* Replaced `_` with `-` when generating command name from function name.
* Replaced unmaintained PyInquirer dependency with questionary
(thanks to [@TinfoilSubmarine] in [#139]).
* Updated Docker image to Alpine 3.14.
* Fixed avatar URLs without the `mxc://` prefix appearing like they work in the
frontend, but not actually working when saved.
[@TinfoilSubmarine]: https://github.com/TinfoilSubmarine
[#139]: https://github.com/maubot/maubot/pull/139
# v0.1.2 (2021-06-12)
* Added `loader` instance property for plugins to allow reading files within
the plugin archive.
* Added support for reloading `webapp` and `database` meta flags in plugins.
Previously you had to restart maubot instead of just reloading the plugin
when enabling the webapp or database for the first time.
* Added warning log if a plugin uses `@web` decorators without enabling the
`webapp` meta flag.
* Updated frontend to latest React and dependency versions.
* Updated Docker image to Alpine 3.13.
* Fixed registering accounts with Synapse shared secret registration.
* Fixed plugins using `get_event` in encrypted rooms.
* Fixed using the `@command.new` decorator without specifying a name
(i.e. falling back to the function name).
# v0.1.1 (2021-05-02)
No changelog.
# v0.1.0 (2020-10-04)
Initial tagged release.

View File

@ -1,60 +1,37 @@
FROM node:20 AS frontend-builder FROM node:12 AS frontend-builder
COPY ./maubot/management/frontend /frontend COPY ./maubot/management/frontend /frontend
RUN cd /frontend && yarn --prod && yarn build RUN cd /frontend && yarn --prod && yarn build
FROM alpine:3.20 FROM alpine:3.10
RUN apk add --no-cache \ ENV UID=1337 \
python3 py3-pip py3-setuptools py3-wheel \ GID=1337
ca-certificates \
su-exec \
yq \
py3-aiohttp \
py3-attrs \
py3-bcrypt \
py3-cffi \
py3-ruamel.yaml \
py3-jinja2 \
py3-click \
py3-packaging \
py3-markdown \
py3-alembic \
py3-cssselect \
py3-commonmark \
py3-pygments \
py3-tz \
py3-regex \
py3-wcwidth \
# encryption
py3-cffi \
py3-olm \
py3-pycryptodome \
py3-unpaddedbase64 \
py3-future \
# plugin deps
py3-pillow \
py3-magic \
py3-feedparser \
py3-dateutil \
py3-lxml \
py3-semver
# TODO remove pillow, magic, feedparser, lxml, gitlab and semver when maubot supports installing dependencies
COPY requirements.txt /opt/maubot/requirements.txt
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
WORKDIR /opt/maubot
RUN apk add --virtual .build-deps python3-dev build-base git \
&& pip3 install --break-system-packages -r requirements.txt -r optional-requirements.txt \
dateparser langdetect python-gitlab pyquery tzlocal \
&& apk del .build-deps
# TODO also remove dateparser, langdetect and pyquery when maubot supports installing dependencies
COPY . /opt/maubot COPY . /opt/maubot
RUN cp maubot/example-config.yaml .
COPY ./docker/mbc.sh /usr/local/bin/mbc
COPY --from=frontend-builder /frontend/build /opt/maubot/frontend COPY --from=frontend-builder /frontend/build /opt/maubot/frontend
ENV UID=1337 GID=1337 XDG_CONFIG_HOME=/data WORKDIR /opt/maubot
RUN apk add --no-cache \
py3-aiohttp \
py3-sqlalchemy \
py3-attrs \
py3-bcrypt \
py3-cffi \
build-base \
python3-dev \
ca-certificates \
su-exec \
py3-pillow \
py3-magic \
py3-psycopg2 \
py3-ruamel.yaml \
py3-jinja2 \
py3-click \
py3-packaging \
py3-markdown \
&& pip3 install -r requirements.txt feedparser dateparser langdetect
# TODO remove pillow, magic and feedparser when maubot supports installing dependencies
VOLUME /data VOLUME /data
CMD ["/opt/maubot/docker/run.sh"] CMD ["/opt/maubot/docker/run.sh"]

View File

@ -1,55 +0,0 @@
FROM alpine:3.20
RUN apk add --no-cache \
python3 py3-pip py3-setuptools py3-wheel \
ca-certificates \
su-exec \
yq \
py3-aiohttp \
py3-attrs \
py3-bcrypt \
py3-cffi \
py3-ruamel.yaml \
py3-jinja2 \
py3-click \
py3-packaging \
py3-markdown \
py3-alembic \
# py3-cssselect \
py3-commonmark \
py3-pygments \
py3-tz \
# py3-tzlocal \
py3-regex \
py3-wcwidth \
# encryption
py3-cffi \
py3-olm \
py3-pycryptodome \
py3-unpaddedbase64 \
py3-future \
# plugin deps
py3-pillow \
py3-magic \
py3-feedparser \
py3-lxml
# py3-gitlab
# py3-semver
# TODO remove pillow, magic, feedparser, lxml, gitlab and semver when maubot supports installing dependencies
COPY requirements.txt /opt/maubot/requirements.txt
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
WORKDIR /opt/maubot
RUN apk add --virtual .build-deps python3-dev build-base git \
&& pip3 install --break-system-packages -r requirements.txt -r optional-requirements.txt \
dateparser langdetect python-gitlab pyquery semver tzlocal cssselect \
&& apk del .build-deps
# TODO also remove dateparser, langdetect and pyquery when maubot supports installing dependencies
COPY . /opt/maubot
RUN cp /opt/maubot/maubot/example-config.yaml /opt/maubot
COPY ./docker/mbc.sh /usr/local/bin/mbc
ENV UID=1337 GID=1337 XDG_CONFIG_HOME=/data
VOLUME /data
CMD ["/opt/maubot/docker/run.sh"]

View File

@ -1,5 +0,0 @@
include README.md
include CHANGELOG.md
include LICENSE
include requirements.txt
include optional-requirements.txt

View File

@ -1,29 +1,37 @@
# maubot # maubot
![Languages](https://img.shields.io/github/languages/top/maubot/maubot.svg)
[![License](https://img.shields.io/github/license/maubot/maubot.svg)](LICENSE)
[![Release](https://img.shields.io/github/release/maubot/maubot/all.svg)](https://github.com/maubot/maubot/releases)
[![GitLab CI](https://mau.dev/maubot/maubot/badges/master/pipeline.svg)](https://mau.dev/maubot/maubot/container_registry)
[![Code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Imports](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
A plugin-based [Matrix](https://matrix.org) bot system written in Python. A plugin-based [Matrix](https://matrix.org) bot system written in Python.
## Documentation ### [Wiki](https://github.com/maubot/maubot/wiki)
All setup and usage instructions are located on ### [Management API spec](https://github.com/maubot/maubot/blob/master/maubot/management/api/spec.md)
[docs.mau.fi](https://docs.mau.fi/maubot/index.html). Some quick links:
* [Setup](https://docs.mau.fi/maubot/usage/setup/index.html)
(or [with Docker](https://docs.mau.fi/maubot/usage/setup/docker.html))
* [Basic usage](https://docs.mau.fi/maubot/usage/basic.html)
* [Encryption](https://docs.mau.fi/maubot/usage/encryption.html)
## Discussion ## Discussion
Matrix room: [#maubot:maunium.net](https://matrix.to/#/#maubot:maunium.net) Matrix room: [#maubot:maunium.net](https://matrix.to/#/#maubot:maunium.net)
## Plugins ## Plugins
A list of plugins can be found at [plugins.mau.bot](https://plugins.mau.bot/). * [jesaribot](https://github.com/maubot/jesaribot) - A simple bot that replies with an image when you say "jesari".
* [sed](https://github.com/maubot/sed) - A bot to do sed-like replacements.
* [factorial](https://github.com/maubot/factorial) - A bot to calculate unexpected factorials.
* [media](https://github.com/maubot/media) - A bot that replies with the MXC URI of images you send it.
* [dice](https://github.com/maubot/dice) - A combined dice rolling and calculator bot.
* [karma](https://github.com/maubot/karma) - A user karma tracker bot.
* [xkcd](https://github.com/maubot/xkcd) - A bot to view xkcd comics.
* [echo](https://github.com/maubot/echo) - A bot that echoes pings and other stuff.
* [rss](https://github.com/maubot/rss) - A bot that posts RSS feed updates to Matrix.
* [reddit](https://github.com/TomCasavant/RedditMaubot) - A bot that condescendingly corrects a user when they enter an r/subreddit without providing a link to that subreddit
* [giphy](https://github.com/TomCasavant/GiphyMaubot) - A bot that generates a gif (from giphy) given search terms
* [trump](https://github.com/jeffcasavant/MaubotTrumpTweet) - A bot that generates a Trump tweet with the given content
* [poll](https://github.com/TomCasavant/PollMaubot) - A bot that will create a simple poll for users in a room
* [urban](https://github.com/dvdgsng/UrbanMaubot) - A bot that fetches definitions from [Urban Dictionary](https://www.urbandictionary.com/).
* [reminder](https://github.com/maubot/reminder) - A bot to remind you about things.
* [translate](https://github.com/maubot/translate) - A bot to translate words.
* [reactbot](https://github.com/maubot/reactbot) - A bot that responds to messages that match predefined rules.
* [exec](https://github.com/maubot/exec) - A bot that executes code.
* [commitstrip](https://github.com/maubot/commitstrip) - A bot to view CommitStrips.
* [supportportal](https://github.com/maubot/supportportal) - A bot to manage customer support on Matrix.
* [gitlab](https://github.com/maubot/gitlab) - A GitLab client and webhook receiver.
* [github](https://github.com/maubot/github) - A GitHub client and webhook receiver.
To add your plugin to the list, send a pull request to <https://github.com/maubot/plugins.maubot.xyz>. Open a pull request or join the Matrix room linked above to get your plugin listed here
The plugin wishlist lives at <https://github.com/maubot/plugin-wishlist/issues>. The plugin wishlist lives at https://github.com/maubot/plugin-wishlist/issues

83
alembic.ini Normal file
View File

@ -0,0 +1,83 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s
# timezone to use when rendering the date
# within the migration file as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; this defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path
# version_locations = %(here)s/bar %(here)s/bat alembic/versions
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks=black
# black.type=console_scripts
# black.entrypoint=black
# black.options=-l 79
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

1
alembic/README Normal file
View File

@ -0,0 +1 @@
Generic single-database configuration.

90
alembic/env.py Normal file
View File

@ -0,0 +1,90 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
import sys
from os.path import abspath, dirname
sys.path.insert(0, dirname(dirname(abspath(__file__))))
from mautrix.util.db import Base
from maubot.config import Config
from maubot import db
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
maubot_config_path = context.get_x_argument(as_dictionary=True).get("config", "config.yaml")
maubot_config = Config(maubot_config_path, None)
maubot_config.load()
config.set_main_option("sqlalchemy.url", maubot_config["database"].replace("%", "%%"))
# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline():
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online():
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

24
alembic/script.py.mako Normal file
View File

@ -0,0 +1,24 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade():
${upgrades if upgrades else "pass"}
def downgrade():
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,40 @@
"""Let plugins have multiple files
Revision ID: 6b66c1600d16
Revises: d295f8dcfa64
Create Date: 2020-01-02 01:30:51.622962
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6b66c1600d16"
down_revision = "d295f8dcfa64"
branch_labels = None
depends_on = None
def upgrade():
plugin_file: sa.Table = op.create_table(
"plugin_file",
sa.Column("plugin_id", sa.String(length=255), nullable=False),
sa.Column("file_name", sa.String(length=255), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.ForeignKeyConstraint(("plugin_id",), ["plugin.id"], onupdate="CASCADE",
ondelete="CASCADE"),
sa.PrimaryKeyConstraint("plugin_id", "file_name"))
conn: sa.engine.Connection = op.get_bind()
conn.execute(plugin_file.insert().values([{
"plugin_id": plugin_id,
"file_name": "config.yaml",
"content": config
} for plugin_id, config in conn.execute("SELECT id, config FROM plugin").fetchall()]))
op.drop_column("plugin", "config")
def downgrade():
op.add_column("plugin", sa.Column("config", sa.TEXT(), autoincrement=False, nullable=False))
op.drop_table("plugin_file")

View File

@ -0,0 +1,50 @@
"""Initial revision
Revision ID: d295f8dcfa64
Revises:
Create Date: 2019-09-27 00:21:02.527915
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd295f8dcfa64'
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('client',
sa.Column('id', sa.String(length=255), nullable=False),
sa.Column('homeserver', sa.String(length=255), nullable=False),
sa.Column('access_token', sa.Text(), nullable=False),
sa.Column('enabled', sa.Boolean(), nullable=False),
sa.Column('next_batch', sa.String(length=255), nullable=False),
sa.Column('filter_id', sa.String(length=255), nullable=False),
sa.Column('sync', sa.Boolean(), nullable=False),
sa.Column('autojoin', sa.Boolean(), nullable=False),
sa.Column('displayname', sa.String(length=255), nullable=False),
sa.Column('avatar_url', sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('plugin',
sa.Column('id', sa.String(length=255), nullable=False),
sa.Column('type', sa.String(length=255), nullable=False),
sa.Column('enabled', sa.Boolean(), nullable=False),
sa.Column('primary_user', sa.String(length=255), nullable=False),
sa.Column('config', sa.Text(), nullable=False),
sa.ForeignKeyConstraint(['primary_user'], ['client.id'], onupdate='CASCADE', ondelete='RESTRICT'),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('plugin')
op.drop_table('client')
# ### end Alembic commands ###

View File

@ -1,3 +0,0 @@
pre-commit>=2.10.1,<3
isort>=5.10.1,<6
black>=24,<25

View File

@ -0,0 +1,96 @@
# The full URI to the database. SQLite and Postgres are fully supported.
# Other DBMSes supported by SQLAlchemy may or may not work.
# Format examples:
# SQLite: sqlite:///filename.db
# Postgres: postgres://username:password@hostname/dbname
database: sqlite:////data/maubot.db
plugin_directories:
# The directory where uploaded new plugins should be stored.
upload: /data/plugins
# The directories from which plugins should be loaded.
# Duplicate plugin IDs will be moved to the trash.
load:
- /data/plugins
# The directory where old plugin versions and conflicting plugins should be moved.
# Set to "delete" to delete files immediately.
trash: /data/trash
# The directory where plugin databases should be stored.
db: /data/plugins
server:
# The IP and port to listen to.
hostname: 0.0.0.0
port: 29316
# Public base URL where the server is visible.
public_url: https://example.com
# The base management API path.
base_path: /_matrix/maubot/v1
# The base path for the UI.
ui_base_path: /_matrix/maubot
# The base path for plugin endpoints. The instance ID will be appended directly.
plugin_base_path: /_matrix/maubot/plugin/
# Override path from where to load UI resources.
# Set to false to using pkg_resources to find the path.
override_resource_path: /opt/maubot/frontend
# The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1.
appservice_base_path: /_matrix/app/v1
# The shared secret to sign API access tokens.
# Set to "generate" to generate and save a new token at startup.
unshared_secret: generate
# Shared registration secrets to allow registering new users from the management UI
registration_secrets:
example.com:
# Client-server API URL
url: https://example.com
# registration_shared_secret from synapse config
secret: synapse_shared_registration_secret
# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password
# to prevent normal login. Root is a special user that can't have a password and will always exist.
admins:
root: ""
# API feature switches.
api_features:
login: true
plugin: true
plugin_upload: true
instance: true
instance_database: true
client: true
client_proxy: true
client_auth: true
dev_open: true
log: true
# Python logging configuration.
#
# See section 16.7.2 of the Python documentation for more info:
# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema
logging:
version: 1
formatters:
precise:
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
handlers:
file:
class: logging.handlers.RotatingFileHandler
formatter: precise
filename: /var/log/maubot.log
maxBytes: 10485760
backupCount: 10
console:
class: logging.StreamHandler
formatter: precise
loggers:
maubot:
level: DEBUG
mautrix:
level: DEBUG
aiohttp:
level: INFO
root:
level: DEBUG
handlers: [file, console]

View File

@ -1,3 +0,0 @@
#!/bin/sh
export PYTHONPATH=/opt/maubot
python3 -m maubot.cli "$@"

View File

@ -1,46 +1,21 @@
#!/bin/sh #!/bin/sh
function fixperms { function fixperms {
chown -R $UID:$GID /var/log /data chown -R $UID:$GID /var/log /data /opt/maubot
}
function fixdefault {
_value=$(yq e "$1" /data/config.yaml)
if [[ "$_value" == "$2" ]]; then
yq e -i "$1 = "'"'"$3"'"' /data/config.yaml
fi
}
function fixconfig {
# Change relative default paths to absolute paths in /data
fixdefault '.database' 'sqlite:maubot.db' 'sqlite:/data/maubot.db'
fixdefault '.plugin_directories.upload' './plugins' '/data/plugins'
fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins'
fixdefault '.plugin_directories.trash' './trash' '/data/trash'
fixdefault '.plugin_databases.sqlite' './plugins' '/data/dbs'
fixdefault '.plugin_databases.sqlite' './dbs' '/data/dbs'
fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log'
# This doesn't need to be configurable
yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml
} }
cd /opt/maubot cd /opt/maubot
mkdir -p /var/log/maubot /data/plugins /data/trash /data/dbs
if [ ! -f /data/config.yaml ]; then if [ ! -f /data/config.yaml ]; then
cp example-config.yaml /data/config.yaml cp docker/example-config.yaml /data/config.yaml
mkdir -p /var/log /data/plugins /data/trash /data/dbs
echo "Config file not found. Example config copied to /data/config.yaml" echo "Config file not found. Example config copied to /data/config.yaml"
echo "Please modify the config file to your liking and restart the container." echo "Please modify the config file to your liking and restart the container."
fixperms fixperms
fixconfig
exit exit
fi fi
mkdir -p /var/log/maubot /data/plugins /data/trash /data/dbs
alembic -x config=/data/config.yaml upgrade head
fixperms fixperms
fixconfig exec su-exec $UID:$GID python3 -m maubot -c /data/config.yaml -b docker/example-config.yaml
if ls /data/plugins/*.db > /dev/null 2>&1; then
mv -n /data/plugins/*.db /data/dbs/
fi
exec su-exec $UID:$GID python3 -m maubot -c /data/config.yaml

View File

@ -1,21 +1,10 @@
# The full URI to the database. SQLite and Postgres are fully supported. # The full URI to the database. SQLite and Postgres are fully supported.
# Other DBMSes supported by SQLAlchemy may or may not work.
# Format examples: # Format examples:
# SQLite: sqlite:filename.db # SQLite: sqlite:///filename.db
# Postgres: postgresql://username:password@hostname/dbname # Postgres: postgres://username:password@hostname/dbname
database: sqlite:maubot.db database: sqlite:///maubot.db
# Separate database URL for the crypto database. "default" means use the same database as above.
crypto_database: default
# Additional arguments for asyncpg.create_pool() or sqlite3.connect()
# https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.pool.create_pool
# https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
# For sqlite, min_size is used as the connection thread pool size and max_size is ignored.
database_opts:
min_size: 1
max_size: 10
# Configuration for storing plugin .mbp files
plugin_directories: plugin_directories:
# The directory where uploaded new plugins should be stored. # The directory where uploaded new plugins should be stored.
upload: ./plugins upload: ./plugins
@ -26,27 +15,8 @@ plugin_directories:
# The directory where old plugin versions and conflicting plugins should be moved. # The directory where old plugin versions and conflicting plugins should be moved.
# Set to "delete" to delete files immediately. # Set to "delete" to delete files immediately.
trash: ./trash trash: ./trash
# The directory where plugin databases should be stored.
# Configuration for storing plugin databases db: ./plugins
plugin_databases:
# The directory where SQLite plugin databases should be stored.
sqlite: ./plugins
# The connection URL for plugin databases. If null, all plugins will get SQLite databases.
# If set, plugins using the new asyncpg interface will get a Postgres connection instead.
# Plugins using the legacy SQLAlchemy interface will always get a SQLite connection.
#
# To use the same connection pool as the default database, set to "default"
# (the default database above must be postgres to do this).
#
# When enabled, maubot will create separate Postgres schemas in the database for each plugin.
# To view schemas in psql, use `\dn`. To view enter and interact with a specific schema,
# use `SET search_path = name` (where `name` is the name found with `\dn`) and then use normal
# SQL queries/psql commands.
postgres: null
# Maximum number of connections per plugin instance.
postgres_max_conns_per_plugin: 3
# Overrides for the default database_opts when using a non-"default" postgres connection string.
postgres_opts: {}
server: server:
# The IP and port to listen to. # The IP and port to listen to.
@ -54,6 +24,8 @@ server:
port: 29316 port: 29316
# Public base URL where the server is visible. # Public base URL where the server is visible.
public_url: https://example.com public_url: https://example.com
# The base management API path.
base_path: /_matrix/maubot/v1
# The base path for the UI. # The base path for the UI.
ui_base_path: /_matrix/maubot ui_base_path: /_matrix/maubot
# The base path for plugin endpoints. The instance ID will be appended directly. # The base path for plugin endpoints. The instance ID will be appended directly.
@ -61,22 +33,19 @@ server:
# Override path from where to load UI resources. # Override path from where to load UI resources.
# Set to false to using pkg_resources to find the path. # Set to false to using pkg_resources to find the path.
override_resource_path: false override_resource_path: false
# The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1.
appservice_base_path: /_matrix/app/v1
# The shared secret to sign API access tokens. # The shared secret to sign API access tokens.
# Set to "generate" to generate and save a new token at startup. # Set to "generate" to generate and save a new token at startup.
unshared_secret: generate unshared_secret: generate
# Known homeservers. This is required for the `mbc auth` command and also allows # Shared registration secrets to allow registering new users from the management UI
# more convenient access from the management UI. This is not required to create registration_secrets:
# clients in the management UI, since you can also just type the homeserver URL example.com:
# into the box there.
homeservers:
matrix.org:
# Client-server API URL # Client-server API URL
url: https://matrix-client.matrix.org url: https://example.com
# registration_shared_secret from synapse config # registration_shared_secret from synapse config
# You can leave this empty if you don't have access to the homeserver. secret: synapse_shared_registration_secret
# When this is empty, `mbc auth --register` won't work, but `mbc auth` (login) will.
secret: null
# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password # List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password
# to prevent normal login. Root is a special user that can't have a password and will always exist. # to prevent normal login. Root is a special user that can't have a password and will always exist.
@ -121,7 +90,7 @@ logging:
loggers: loggers:
maubot: maubot:
level: DEBUG level: DEBUG
mau: mautrix:
level: DEBUG level: DEBUG
aiohttp: aiohttp:
level: INFO level: INFO

View File

@ -1,6 +1,6 @@
The MIT License (MIT) The MIT License (MIT)
Copyright (c) 2022 Tulir Asokan Copyright (c) 2018 Tulir Asokan
Permission is hereby granted, free of charge, to any person obtaining a copy of Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in this software and associated documentation files (the "Software"), to deal in

View File

@ -4,4 +4,3 @@ All examples are published under the [MIT license](LICENSE).
* [Hello World](helloworld/) - Very basic event handling bot that responds "Hello, World!" to all messages. * [Hello World](helloworld/) - Very basic event handling bot that responds "Hello, World!" to all messages.
* [Echo bot](https://github.com/maubot/echo) - Basic command handling bot with !echo and !ping commands * [Echo bot](https://github.com/maubot/echo) - Basic command handling bot with !echo and !ping commands
* [Config example](config/) - Simple example of using a config file * [Config example](config/) - Simple example of using a config file
* [Database example](database/) - Simple example of using a database

View File

@ -1,5 +1,2 @@
# Who is allowed to use the bot? # Message to send when user sends !getmessage
whitelist: message: Default configuration active
- "@user:example.com"
# The prefix for the main command without the !
command_prefix: hello-world

View File

@ -1,4 +1,5 @@
from typing import Type from typing import Type
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from maubot import Plugin, MessageEvent from maubot import Plugin, MessageEvent
from maubot.handlers import command from maubot.handlers import command
@ -6,22 +7,19 @@ from maubot.handlers import command
class Config(BaseProxyConfig): class Config(BaseProxyConfig):
def do_update(self, helper: ConfigUpdateHelper) -> None: def do_update(self, helper: ConfigUpdateHelper) -> None:
helper.copy("whitelist") helper.copy("message")
helper.copy("command_prefix")
class ConfigurableBot(Plugin): class DatabaseBot(Plugin):
async def start(self) -> None: async def start(self) -> None:
await super().start()
self.config.load_and_update() self.config.load_and_update()
def get_command_name(self) -> str:
return self.config["command_prefix"]
@command.new(name=get_command_name)
async def hmm(self, evt: MessageEvent) -> None:
if evt.sender in self.config["whitelist"]:
await evt.reply("You're whitelisted 🎉")
@classmethod @classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]: def get_config_class(cls) -> Type[BaseProxyConfig]:
return Config return Config
@command.new("getmessage")
async def handler(self, event: MessageEvent) -> None:
if event.sender != self.client.mxid:
await event.reply(self.config["message"])

View File

@ -1,12 +1,11 @@
maubot: 0.1.0 maubot: 0.1.0
id: xyz.maubot.configurablebot id: xyz.maubot.databasebot
version: 2.0.0 version: 1.0.0
license: MIT license: MIT
modules: modules:
- configurablebot - configurablebot
main_class: ConfigurableBot main_class: ConfigurableBot
database: false database: false
config: true
# Instruct the build tool to include the base config. # Instruct the build tool to include the base config.
extra_files: extra_files:

View File

@ -1,10 +0,0 @@
maubot: 0.1.0
id: xyz.maubot.storagebot
version: 2.0.0
license: MIT
modules:
- storagebot
main_class: StorageBot
database: true
database_type: asyncpg
config: false

View File

@ -1,72 +0,0 @@
from __future__ import annotations
from mautrix.util.async_db import UpgradeTable, Connection
from maubot import Plugin, MessageEvent
from maubot.handlers import command
upgrade_table = UpgradeTable()
@upgrade_table.register(description="Initial revision")
async def upgrade_v1(conn: Connection) -> None:
await conn.execute(
"""CREATE TABLE stored_data (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
)"""
)
@upgrade_table.register(description="Remember user who added value")
async def upgrade_v2(conn: Connection) -> None:
await conn.execute("ALTER TABLE stored_data ADD COLUMN creator TEXT")
class StorageBot(Plugin):
@command.new()
async def storage(self, evt: MessageEvent) -> None:
pass
@storage.subcommand(help="Store a value")
@command.argument("key")
@command.argument("value", pass_raw=True)
async def put(self, evt: MessageEvent, key: str, value: str) -> None:
q = """
INSERT INTO stored_data (key, value, creator) VALUES ($1, $2, $3)
ON CONFLICT (key) DO UPDATE SET value=excluded.value, creator=excluded.creator
"""
await self.database.execute(q, key, value, evt.sender)
await evt.reply(f"Inserted {key} into the database")
@storage.subcommand(help="Get a value from the storage")
@command.argument("key")
async def get(self, evt: MessageEvent, key: str) -> None:
q = "SELECT key, value, creator FROM stored_data WHERE LOWER(key)=LOWER($1)"
row = await self.database.fetchrow(q, key)
if row:
key = row["key"]
value = row["value"]
creator = row["creator"]
await evt.reply(f"`{key}` stored by {creator}:\n\n```\n{value}\n```")
else:
await evt.reply(f"No data stored under `{key}` :(")
@storage.subcommand(help="List keys in the storage")
@command.argument("prefix", required=False)
async def list(self, evt: MessageEvent, prefix: str | None) -> None:
q = "SELECT key, creator FROM stored_data WHERE key LIKE $1"
rows = await self.database.fetch(q, prefix + "%")
prefix_reply = f" starting with `{prefix}`" if prefix else ""
if len(rows) == 0:
await evt.reply(f"Nothing{prefix_reply} stored in database :(")
else:
formatted_data = "\n".join(
f"* `{row['key']}` stored by {row['creator']}" for row in rows
)
await evt.reply(
f"Found {len(rows)} keys{prefix_reply} in database:\n\n{formatted_data}"
)
@classmethod
def get_db_upgrade_table(cls) -> UpgradeTable | None:
return upgrade_table

View File

@ -1,4 +1,3 @@
from .__meta__ import __version__
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent
from .plugin_base import Plugin from .plugin_base import Plugin
from .plugin_server import PluginWebApp from .plugin_server import PluginWebApp
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,171 +13,82 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations import logging.config
import argparse
import asyncio import asyncio
import signal
import copy
import sys import sys
from mautrix.util.async_db import Database, DatabaseException, PostgresDatabase, Scheme
from mautrix.util.program import Program
from .__meta__ import __version__
from .client import Client
from .config import Config from .config import Config
from .db import init as init_db, upgrade_table from .db import init as init_db
from .instance import PluginInstance
from .lib.future_awaitable import FutureAwaitable
from .lib.state_store import PgStateStore
from .loader.zip import init as init_zip_loader
from .management.api import init as init_mgmt_api
from .server import MaubotServer from .server import MaubotServer
from .client import Client, init as init_client_class
from .loader.zip import init as init_zip_loader
from .instance import init as init_plugin_instance_class
from .management.api import init as init_mgmt_api
from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
prog="python -m maubot")
parser.add_argument("-c", "--config", type=str, default="config.yaml",
metavar="<path>", help="the path to your config file")
parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml",
metavar="<path>", help="the path to the example config "
"(for automatic config updates)")
args = parser.parse_args()
config = Config(args.config, args.base_config)
config.load()
config.update()
logging.config.dictConfig(copy.deepcopy(config["logging"]))
stop_log_listener = None
if config["api_features.log"]:
from .management.api.log import init as init_log_listener, stop_all as stop_log_listener
init_log_listener()
log = logging.getLogger("maubot.init")
log.info(f"Initializing maubot {__version__}")
loop = asyncio.get_event_loop()
init_zip_loader(config)
db_engine = init_db(config)
clients = init_client_class(loop)
management_api = init_mgmt_api(config, loop)
server = MaubotServer(management_api, config, loop)
plugins = init_plugin_instance_class(config, server, loop)
for plugin in plugins:
plugin.load()
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
try: try:
from mautrix.crypto.store import PgCryptoStore log.info("Starting server")
except ImportError: loop.run_until_complete(server.start())
PgCryptoStore = None log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop))
log.info("Startup actions complete, running forever")
class Maubot(Program): loop.run_forever()
config: Config except KeyboardInterrupt:
server: MaubotServer log.info("Interrupt received, stopping clients")
db: Database loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()],
crypto_db: Database | None loop=loop))
plugin_postgres_db: PostgresDatabase | None if stop_log_listener is not None:
state_store: PgStateStore log.debug("Closing websockets")
loop.run_until_complete(stop_log_listener())
config_class = Config log.debug("Stopping server")
module = "maubot" try:
name = "maubot" loop.run_until_complete(asyncio.wait_for(server.stop(), 5, loop=loop))
version = __version__ except asyncio.TimeoutError:
command = "python -m maubot" log.warning("Stopping server timed out")
description = "A plugin-based Matrix bot system." log.debug("Closing event loop")
loop.close()
def prepare_log_websocket(self) -> None: log.debug("Everything stopped, shutting down")
from .management.api.log import init, stop_all sys.exit(0)
init(self.loop)
self.add_shutdown_actions(FutureAwaitable(stop_all))
def prepare_arg_parser(self) -> None:
super().prepare_arg_parser()
self.parser.add_argument(
"--ignore-unsupported-database",
action="store_true",
help="Run even if the database schema is too new",
)
self.parser.add_argument(
"--ignore-foreign-tables",
action="store_true",
help="Run even if the database contains tables from other programs (like Synapse)",
)
def prepare_db(self) -> None:
self.db = Database.create(
self.config["database"],
upgrade_table=upgrade_table,
db_args=self.config["database_opts"],
owner_name=self.name,
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
init_db(self.db)
if PgCryptoStore:
if self.config["crypto_database"] == "default":
self.crypto_db = self.db
else:
self.crypto_db = Database.create(
self.config["crypto_database"],
upgrade_table=PgCryptoStore.upgrade_table,
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
else:
self.crypto_db = None
if self.config["plugin_databases.postgres"] == "default":
if self.db.scheme != Scheme.POSTGRES:
self.log.critical(
'Using "default" as the postgres plugin database URL is only allowed if '
"the default database is postgres."
)
sys.exit(24)
assert isinstance(self.db, PostgresDatabase)
self.plugin_postgres_db = self.db
elif self.config["plugin_databases.postgres"]:
plugin_db = Database.create(
self.config["plugin_databases.postgres"],
db_args={
**self.config["database_opts"],
**self.config["plugin_databases.postgres_opts"],
},
)
if plugin_db.scheme != Scheme.POSTGRES:
self.log.critical("The plugin postgres database URL must be a postgres database")
sys.exit(24)
assert isinstance(plugin_db, PostgresDatabase)
self.plugin_postgres_db = plugin_db
else:
self.plugin_postgres_db = None
def prepare(self) -> None:
super().prepare()
if self.config["api_features.log"]:
self.prepare_log_websocket()
init_zip_loader(self.config)
self.prepare_db()
Client.init_cls(self)
PluginInstance.init_cls(self)
management_api = init_mgmt_api(self.config, self.loop)
self.server = MaubotServer(management_api, self.config, self.loop)
self.state_store = PgStateStore(self.db)
async def start_db(self) -> None:
self.log.debug("Starting database...")
ignore_unsupported = self.args.ignore_unsupported_database
self.db.upgrade_table.allow_unsupported = ignore_unsupported
self.state_store.upgrade_table.allow_unsupported = ignore_unsupported
try:
await self.db.start()
await self.state_store.upgrade_table.upgrade(self.db)
if self.plugin_postgres_db and self.plugin_postgres_db is not self.db:
await self.plugin_postgres_db.start()
if self.crypto_db:
PgCryptoStore.upgrade_table.allow_unsupported = ignore_unsupported
if self.crypto_db is not self.db:
await self.crypto_db.start()
else:
await PgCryptoStore.upgrade_table.upgrade(self.db)
except DatabaseException as e:
self.log.critical("Failed to initialize database", exc_info=e)
if e.explanation:
self.log.info(e.explanation)
sys.exit(25)
async def system_exit(self) -> None:
if hasattr(self, "db"):
self.log.trace("Stopping database due to SystemExit")
await self.db.stop()
async def start(self) -> None:
await self.start_db()
await asyncio.gather(*[plugin.load() async for plugin in PluginInstance.all()])
await asyncio.gather(*[client.start() async for client in Client.all()])
await super().start()
async for plugin in PluginInstance.all():
await plugin.load()
await self.server.start()
async def stop(self) -> None:
self.add_shutdown_actions(*(client.stop() for client in Client.cache.values()))
await super().stop()
self.log.debug("Stopping server")
try:
await asyncio.wait_for(self.server.stop(), 5)
except asyncio.TimeoutError:
self.log.warning("Stopping server timed out")
await self.db.stop()
Maubot().run()

View File

@ -1 +1 @@
__version__ = "0.5.0" __version__ = "0.1.0.dev30"

View File

@ -1,3 +1,2 @@
from . import app from . import app
app() app()

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by

View File

@ -1,2 +1,2 @@
from .cliq import command, option from .cliq import command, option
from .validators import PathValidator, SPDXValidator, VersionValidator from .validators import SPDXValidator, VersionValidator, PathValidator

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,55 +13,20 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Any, Callable, Union, Optional
from typing import Any, Callable
import asyncio
import functools import functools
import inspect
import traceback
from colorama import Fore
from prompt_toolkit.validation import Validator from prompt_toolkit.validation import Validator
from questionary import prompt from PyInquirer import prompt
import aiohttp
import click import click
from ..base import app from ..base import app
from ..config import get_token from .validators import Required, ClickValidator
from .validators import ClickValidator, Required
def with_http(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
async with aiohttp.ClientSession() as sess:
try:
return await func(*args, sess=sess, **kwargs)
except aiohttp.ClientError as e:
print(f"{Fore.RED}Connection error: {e}{Fore.RESET}")
return wrapper
def with_authenticated_http(func):
@functools.wraps(func)
async def wrapper(*args, server: str, **kwargs):
server, token = get_token(server)
if not token:
return
async with aiohttp.ClientSession(headers={"Authorization": f"Bearer {token}"}) as sess:
try:
return await func(*args, sess=sess, server=server, **kwargs)
except aiohttp.ClientError as e:
print(f"{Fore.RED}Connection error: {e}{Fore.RESET}")
return wrapper
def command(help: str) -> Callable[[Callable], Callable]: def command(help: str) -> Callable[[Callable], Callable]:
def decorator(func) -> Callable: def decorator(func) -> Callable:
questions = getattr(func, "__inquirer_questions__", {}).copy() questions = func.__inquirer_questions__.copy()
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -70,43 +35,20 @@ def command(help: str) -> Callable[[Callable], Callable]:
continue continue
if value is not None and (questions[key]["type"] != "confirm" or value != "null"): if value is not None and (questions[key]["type"] != "confirm" or value != "null"):
questions.pop(key, None) questions.pop(key, None)
try:
required_unless = questions[key].pop("required_unless")
if isinstance(required_unless, str) and kwargs[required_unless]:
questions.pop(key)
elif isinstance(required_unless, list):
for v in required_unless:
if kwargs[v]:
questions.pop(key)
break
elif isinstance(required_unless, dict):
for k, v in required_unless.items():
if kwargs.get(v, object()) == v:
questions.pop(key)
break
except KeyError:
pass
question_list = list(questions.values()) question_list = list(questions.values())
question_list.reverse() question_list.reverse()
resp = prompt(question_list, kbi_msg="Aborted!") resp = prompt(question_list, keyboard_interrupt_msg="Aborted!")
if not resp and question_list: if not resp and question_list:
return return
kwargs = {**kwargs, **resp} kwargs = {**kwargs, **resp}
func(*args, **kwargs)
try:
res = func(*args, **kwargs)
if inspect.isawaitable(res):
asyncio.run(res)
except Exception:
print(Fore.RED + "Fatal error running command" + Fore.RESET)
traceback.print_exc()
return app.command(help=help)(wrapper) return app.command(help=help)(wrapper)
return decorator return decorator
def yesno(val: str) -> bool | None: def yesno(val: str) -> Optional[bool]:
if not val: if not val:
return None return None
elif isinstance(val, bool): elif isinstance(val, bool):
@ -120,25 +62,13 @@ def yesno(val: str) -> bool | None:
yesno.__name__ = "yes/no" yesno.__name__ = "yes/no"
def option( def option(short: str, long: str, message: str = None, help: str = None,
short: str, click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None,
long: str, validator: Validator = None, required: bool = False, default: str = None,
message: str = None, is_flag: bool = False, prompt: bool = True) -> Callable[[Callable], Callable]:
help: str = None,
click_type: str | Callable[[str], Any] = None,
inq_type: str = None,
validator: type[Validator] = None,
required: bool = False,
default: str | bool | None = None,
is_flag: bool = False,
prompt: bool = True,
required_unless: str | list | dict = None,
) -> Callable[[Callable], Callable]:
if not message: if not message:
message = long[2].upper() + long[3:] message = long[2].upper() + long[3:]
click_type = validator.click_type if isinstance(validator, ClickValidator) else click_type
if isinstance(validator, type) and issubclass(validator, ClickValidator):
click_type = validator.click_type
if is_flag: if is_flag:
click_type = yesno click_type = yesno
@ -149,20 +79,18 @@ def option(
if not hasattr(func, "__inquirer_questions__"): if not hasattr(func, "__inquirer_questions__"):
func.__inquirer_questions__ = {} func.__inquirer_questions__ = {}
q = { q = {
"type": ( "type": (inq_type if isinstance(inq_type, str)
inq_type if isinstance(inq_type, str) else ("input" if not is_flag else "confirm") else ("input" if not is_flag
), else "confirm")),
"name": long[2:], "name": long[2:],
"message": message, "message": message,
} }
if required_unless is not None:
q["required_unless"] = required_unless
if default is not None: if default is not None:
q["default"] = default q["default"] = default
if required or required_unless is not None: if required:
q["validate"] = Required(validator) q["validator"] = Required(validator)
elif validator: elif validator:
q["validate"] = validator q["validator"] = validator
func.__inquirer_questions__[long[2:]] = q func.__inquirer_questions__[long[2:]] = q
return func return func

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -16,9 +16,9 @@
from typing import Callable from typing import Callable
import os import os
from packaging.version import InvalidVersion, Version from packaging.version import Version, InvalidVersion
from prompt_toolkit.validation import Validator, ValidationError
from prompt_toolkit.document import Document from prompt_toolkit.document import Document
from prompt_toolkit.validation import ValidationError, Validator
import click import click
from ..util import spdx as spdxlib from ..util import spdx as spdxlib

View File

@ -1 +1 @@
from . import auth, build, init, login, logs, upload from . import upload, build, login, init, logs, auth

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,154 +13,53 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from urllib.request import urlopen, Request
from urllib.error import HTTPError
import json import json
import webbrowser
from colorama import Fore from colorama import Fore
from yarl import URL
import aiohttp
import click import click
from ..config import get_token
from ..cliq import cliq from ..cliq import cliq
history_count: int = 10 history_count: int = 10
friendly_errors = {
"server_not_found": (
"Registration target server not found.\n\n"
"To log in or register through maubot, you must add the server to the\n"
"homeservers section in the config. If you only want to log in,\n"
"leave the `secret` field empty."
),
"registration_no_sso": (
"The register operation is only for registering with a password.\n\n"
"To register with SSO, simply leave out the --register flag."
),
}
async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
url = URL(server) / "_matrix/maubot/v1/client/auth/servers"
async with sess.get(url) as resp:
data = await resp.json()
print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}")
for server in data.keys():
print(f"* {Fore.CYAN}{server}{Fore.RESET}")
@cliq.command(help="Log into a Matrix account via the Maubot server") @cliq.command(help="Log into a Matrix account via the Maubot server")
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list") @cliq.option("-h", "--homeserver", help="The homeserver to log into", required=True)
@cliq.option( @cliq.option("-u", "--username", help="The username to log in with", required=True)
"-u", "--username", help="The username to log in with", required_unless=["list", "sso"] @cliq.option("-p", "--password", help="The password to log in with", inq_type="password",
) required=True)
@cliq.option( @cliq.option("-s", "--server", help="The maubot instance to log in through", default="",
"-p", required=False, prompt=False)
"--password", @click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
help="The password to log in with", default=False)
inq_type="password", def auth(homeserver: str, username: str, password: str, server: str, register: bool) -> None:
required_unless=["list", "sso"], server, token = get_token(server)
) if not token:
@cliq.option(
"-s",
"--server",
help="The maubot instance to log in through",
default="",
required=False,
prompt=False,
)
@click.option(
"-r", "--register", help="Register instead of logging in", is_flag=True, default=False
)
@click.option(
"-c",
"--update-client",
help="Instead of returning the access token, create or update a client in maubot using it",
is_flag=True,
default=False,
)
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
@click.option(
"-o", "--sso", help="Use single sign-on instead of password login", is_flag=True, default=False
)
@click.option(
"-n",
"--device-name",
help="The initial e2ee device displayname (only for login)",
default="Maubot",
required=False,
)
@cliq.with_authenticated_http
async def auth(
homeserver: str,
username: str,
password: str,
server: str,
register: bool,
list: bool,
update_client: bool,
device_name: str,
sso: bool,
sess: aiohttp.ClientSession,
) -> None:
if list:
await list_servers(server, sess)
return return
endpoint = "register" if register else "login" endpoint = "register" if register else "login"
url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint req = Request(f"{server}/_matrix/maubot/v1/client/auth/{homeserver}/{endpoint}",
if update_client: headers={
url = url.update_query({"update_client": "true"}) "Authorization": f"Bearer {token}",
if sso: "Content-Type": "application/json",
url = url.update_query({"sso": "true"}) },
req_data = {"device_name": device_name} data=json.dumps({
else: "username": username,
req_data = {"username": username, "password": password, "device_name": device_name} "password": password,
}).encode("utf-8"))
async with sess.post(url, json=req_data) as resp:
if not 200 <= resp.status < 300:
await print_error(resp, is_register=register)
elif sso:
await wait_sso(resp, sess, server, homeserver)
else:
await print_response(resp, is_register=register)
async def wait_sso(
resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession, server: str, homeserver: str
) -> None:
data = await resp.json()
sso_url, reg_id = data["sso_url"], data["id"]
print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}")
webbrowser.open(sso_url, autoraise=True)
print(f"{Fore.GREEN}Waiting for login token...{Fore.RESET}")
wait_url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / "sso" / reg_id / "wait"
async with sess.post(wait_url, json={}) as resp:
await print_response(resp, is_register=False)
async def print_response(resp: aiohttp.ClientResponse, is_register: bool) -> None:
if resp.status == 200:
data = await resp.json()
action = "registered" if is_register else "logged in as"
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
elif resp.status in (201, 202):
data = await resp.json()
action = "created" if resp.status == 201 else "updated"
print(
f"{Fore.GREEN}Successfully {action} client for "
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}"
)
else:
await print_error(resp, is_register)
async def print_error(resp: aiohttp.ClientResponse, is_register: bool) -> None:
try: try:
err_data = await resp.json() with urlopen(req) as resp_data:
error = friendly_errors.get(err_data["errcode"], err_data["error"]) resp = json.load(resp_data)
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError): action = "registered" if register else "logged in as"
error = await resp.text() print(f"{Fore.GREEN}Successfully {action} "
action = "register" if is_register else "log in" f"{Fore.CYAN}{resp['user_id']}{Fore.GREEN}.")
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}") print(f"{Fore.GREEN}Access token: {Fore.CYAN}{resp['access_token']}{Fore.RESET}")
except HTTPError as e:
try:
err = json.load(e)
except json.JSONDecodeError:
err = {}
action = "register" if register else "log in"
print(f"{Fore.RED}Failed to {action}: {err.get('error', str(e))}{Fore.RESET}")

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,28 +13,21 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Optional, Union, IO
from typing import IO
from io import BytesIO from io import BytesIO
import asyncio
import glob
import os
import zipfile import zipfile
import os
from aiohttp import ClientSession from mautrix.client.api.types.util import SerializerError
from colorama import Fore
from questionary import prompt
from ruamel.yaml import YAML, YAMLError from ruamel.yaml import YAML, YAMLError
from colorama import Fore
from PyInquirer import prompt
import click import click
from mautrix.types import SerializerError
from ...loader import PluginMeta from ...loader import PluginMeta
from ..base import app
from ..cliq import cliq
from ..cliq.validators import PathValidator from ..cliq.validators import PathValidator
from ..config import get_token from ..base import app
from ..config import get_default_server, get_token
from .upload import upload_file from .upload import upload_file
yaml = YAML() yaml = YAML()
@ -46,7 +39,7 @@ def zipdir(zip, dir):
zip.write(os.path.join(root, file)) zip.write(os.path.join(root, file))
def read_meta(path: str) -> PluginMeta | None: def read_meta(path: str) -> Optional[PluginMeta]:
try: try:
with open(os.path.join(path, "maubot.yaml")) as meta_file: with open(os.path.join(path, "maubot.yaml")) as meta_file:
try: try:
@ -67,7 +60,7 @@ def read_meta(path: str) -> PluginMeta | None:
return meta return meta
def read_output_path(output: str, meta: PluginMeta) -> str | None: def read_output_path(output: str, meta: PluginMeta) -> Optional[str]:
directory = os.getcwd() directory = os.getcwd()
filename = f"{meta.id}-v{meta.version}.mbp" filename = f"{meta.id}-v{meta.version}.mbp"
if not output: if not output:
@ -75,15 +68,18 @@ def read_output_path(output: str, meta: PluginMeta) -> str | None:
elif os.path.isdir(output): elif os.path.isdir(output):
output = os.path.join(output, filename) output = os.path.join(output, filename)
elif os.path.exists(output): elif os.path.exists(output):
q = [{"type": "confirm", "name": "override", "message": f"{output} exists, override?"}] override = prompt({
override = prompt(q)["override"] "type": "confirm",
"name": "override",
"message": f"{output} exists, override?"
})["override"]
if not override: if not override:
return None return None
os.remove(output) os.remove(output)
return os.path.abspath(output) return os.path.abspath(output)
def write_plugin(meta: PluginMeta, output: str | IO) -> None: def write_plugin(meta: PluginMeta, output: Union[str, IO]) -> None:
with zipfile.ZipFile(output, "w") as zip: with zipfile.ZipFile(output, "w") as zip:
meta_dump = BytesIO() meta_dump = BytesIO()
yaml.dump(meta.serialize(), meta_dump) yaml.dump(meta.serialize(), meta_dump)
@ -93,47 +89,33 @@ def write_plugin(meta: PluginMeta, output: str | IO) -> None:
if os.path.isfile(f"{module}.py"): if os.path.isfile(f"{module}.py"):
zip.write(f"{module}.py") zip.write(f"{module}.py")
elif module is not None and os.path.isdir(module): elif module is not None and os.path.isdir(module):
if os.path.isfile(f"{module}/__init__.py"): zipdir(zip, module)
zipdir(zip, module)
else:
print(
Fore.YELLOW
+ f"Module {module} is missing __init__.py, skipping"
+ Fore.RESET
)
else: else:
print(Fore.YELLOW + f"Module {module} not found, skipping" + Fore.RESET) print(Fore.YELLOW + f"Module {module} not found, skipping" + Fore.RESET)
for pattern in meta.extra_files:
for file in glob.iglob(pattern): for file in meta.extra_files:
zip.write(file) zip.write(file)
@cliq.with_authenticated_http def upload_plugin(output: Union[str, IO], server: str) -> None:
async def upload_plugin(output: str | IO, *, server: str, sess: ClientSession) -> None:
server, token = get_token(server) server, token = get_token(server)
if not token: if not token:
return return
if isinstance(output, str): if isinstance(output, str):
with open(output, "rb") as file: with open(output, "rb") as file:
await upload_file(sess, file, server) upload_file(file, server, token)
else: else:
await upload_file(sess, output, server) upload_file(output, server, token)
@app.command( @app.command(short_help="Build a maubot plugin",
short_help="Build a maubot plugin", help="Build a maubot plugin. First parameter is the path to root of the plugin "
help=( "to build. You can also use --output to specify output file.")
"Build a maubot plugin. First parameter is the path to root of the plugin "
"to build. You can also use --output to specify output file."
),
)
@click.argument("path", default=os.getcwd()) @click.argument("path", default=os.getcwd())
@click.option( @click.option("-o", "--output", help="Path to output built plugin to",
"-o", "--output", help="Path to output built plugin to", type=PathValidator.click_type type=PathValidator.click_type)
) @click.option("-u", "--upload", help="Upload plugin to server after building", is_flag=True,
@click.option( default=False)
"-u", "--upload", help="Upload plugin to server after building", is_flag=True, default=False
)
@click.option("-s", "--server", help="Server to upload built plugin to") @click.option("-s", "--server", help="Server to upload built plugin to")
def build(path: str, output: str, upload: bool, server: str) -> None: def build(path: str, output: str, upload: bool, server: str) -> None:
meta = read_meta(path) meta = read_meta(path)
@ -152,4 +134,4 @@ def build(path: str, output: str, upload: bool, server: str) -> None:
else: else:
output.seek(0) output.seek(0)
if upload: if upload:
asyncio.run(upload_plugin(output, server=server)) upload_plugin(output, server)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,11 +13,11 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from pkg_resources import resource_string
import os import os
from jinja2 import Template
from packaging.version import Version from packaging.version import Version
from pkg_resources import resource_string from jinja2 import Template
from .. import cliq from .. import cliq
from ..cliq import SPDXValidator, VersionValidator from ..cliq import SPDXValidator, VersionValidator
@ -40,55 +40,26 @@ def load_templates():
@cliq.command(help="Initialize a new maubot plugin") @cliq.command(help="Initialize a new maubot plugin")
@cliq.option( @cliq.option("-n", "--name", help="The name of the project", required=True,
"-n", default=os.path.basename(os.getcwd()))
"--name", @cliq.option("-i", "--id", message="ID", required=True,
help="The name of the project", help="The maubot plugin ID (Java package name format)")
required=True, @cliq.option("-v", "--version", help="Initial version for project (PEP-440 format)",
default=os.path.basename(os.getcwd()), default="0.1.0", validator=VersionValidator, required=True)
) @cliq.option("-l", "--license", validator=SPDXValidator, default="AGPL-3.0-or-later",
@cliq.option( help="The license for the project (SPDX identifier)", required=False)
"-i", @cliq.option("-c", "--config", message="Should the plugin include a config?",
"--id", help="Include a config in the plugin stub", default=False, is_flag=True)
message="ID",
required=True,
help="The maubot plugin ID (Java package name format)",
)
@cliq.option(
"-v",
"--version",
help="Initial version for project (PEP-440 format)",
default="0.1.0",
validator=VersionValidator,
required=True,
)
@cliq.option(
"-l",
"--license",
validator=SPDXValidator,
default="AGPL-3.0-or-later",
help="The license for the project (SPDX identifier)",
required=False,
)
@cliq.option(
"-c",
"--config",
message="Should the plugin include a config?",
help="Include a config in the plugin stub",
default=False,
is_flag=True,
)
def init(name: str, id: str, version: Version, license: str, config: bool) -> None: def init(name: str, id: str, version: Version, license: str, config: bool) -> None:
load_templates() load_templates()
main_class = name[0].upper() + name[1:] main_class = name[0].upper() + name[1:]
meta = meta_template.render( meta = meta_template.render(id=id, version=str(version), license=license, config=config,
id=id, version=str(version), license=license, config=config, main_class=main_class main_class=main_class)
)
with open("maubot.yaml", "w") as file: with open("maubot.yaml", "w") as file:
file.write(meta) file.write(meta)
if license: if license:
with open("LICENSE", "w") as file: with open("LICENSE", "w") as file:
file.write(spdx.get(license)["licenseText"]) file.write(spdx.get(license)["text"])
if not os.path.isdir(name): if not os.path.isdir(name):
os.mkdir(name) os.mkdir(name)
mod = mod_template.render(config=config, name=main_class) mod = mod_template.render(config=config, name=main_class)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,65 +13,37 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from urllib.request import urlopen
from urllib.error import HTTPError
import json import json
import os import os
from colorama import Fore from colorama import Fore
from yarl import URL
import aiohttp
from ..config import save_config, config
from ..cliq import cliq from ..cliq import cliq
from ..config import config, save_config
@cliq.command(help="Log in to a Maubot instance") @cliq.command(help="Log in to a Maubot instance")
@cliq.option( @cliq.option("-u", "--username", help="The username of your account", default=os.environ.get("USER", None), required=True)
"-u", @cliq.option("-p", "--password", help="The password to your account", inq_type="password", required=True)
"--username", @cliq.option("-s", "--server", help="The server to log in to", default="http://localhost:29316", required=True)
help="The username of your account", def login(server, username, password) -> None:
default=os.environ.get("USER", None),
required=True,
)
@cliq.option(
"-p", "--password", help="The password to your account", inq_type="password", required=True
)
@cliq.option(
"-s",
"--server",
help="The server to log in to",
default="http://localhost:29316",
required=True,
)
@cliq.option(
"-a",
"--alias",
help="Alias to reference the server without typing the full URL",
default="",
required=False,
)
@cliq.with_http
async def login(
server: str, username: str, password: str, alias: str, sess: aiohttp.ClientSession
) -> None:
data = { data = {
"username": username, "username": username,
"password": password, "password": password,
} }
url = URL(server) / "_matrix/maubot/v1/auth/login" try:
async with sess.post(url, json=data) as resp: with urlopen(f"{server}/_matrix/maubot/v1/auth/login",
if resp.status == 200: data=json.dumps(data).encode("utf-8")) as resp_data:
data = await resp.json() resp = json.load(resp_data)
config["servers"][server] = data["token"] config["servers"][server] = resp["token"]
if not config["default_server"]: config["default_server"] = server
print(Fore.CYAN, "Setting", server, "as the default server")
config["default_server"] = server
if alias:
config["aliases"][alias] = server
save_config() save_config()
print(Fore.GREEN + "Logged in successfully") print(Fore.GREEN + "Logged in successfully")
else: except HTTPError as e:
try: try:
err = (await resp.json())["error"] err = json.load(e)
except (json.JSONDecodeError, KeyError): except json.JSONDecodeError:
err = await resp.text() err = {}
print(Fore.RED + err + Fore.RESET) print(Fore.RED + err.get("error", str(e)) + Fore.RESET)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -16,14 +16,13 @@
from datetime import datetime from datetime import datetime
import asyncio import asyncio
from aiohttp import ClientSession, WSMessage, WSMsgType
from colorama import Fore from colorama import Fore
from aiohttp import WSMsgType, WSMessage, ClientSession
from mautrix.client.api.types.util import Obj
import click import click
from mautrix.types import Obj
from ..base import app
from ..config import get_token from ..config import get_token
from ..base import app
history_count: int = 10 history_count: int = 10
@ -38,13 +37,19 @@ def logs(server: str, tail: int) -> None:
global history_count global history_count
history_count = tail history_count = tail
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(view_logs(server, token)) future = asyncio.ensure_future(view_logs(server, token), loop=loop)
try:
loop.run_until_complete(future)
except KeyboardInterrupt:
future.cancel()
loop.run_until_complete(future)
loop.close()
def parsedate(entry: Obj) -> None: def parsedate(entry: Obj) -> None:
i = entry.time.index("+") i = entry.time.index("+")
i = entry.time.index(":", i) i = entry.time.index(":", i)
entry.time = entry.time[:i] + entry.time[i + 1 :] entry.time = entry.time[:i] + entry.time[i + 1:]
entry.time = datetime.strptime(entry.time, "%Y-%m-%dT%H:%M:%S.%f%z") entry.time = datetime.strptime(entry.time, "%Y-%m-%dT%H:%M:%S.%f%z")
@ -60,16 +65,13 @@ levelcolors = {
def print_entry(entry: dict) -> None: def print_entry(entry: dict) -> None:
entry = Obj(**entry) entry = Obj(**entry)
parsedate(entry) parsedate(entry)
print( print("{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}"
"{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}".format( .format(date=entry.time.strftime("%Y-%m-%d %H:%M:%S"),
date=entry.time.strftime("%Y-%m-%d %H:%M:%S"), level=entry.levelname,
level=entry.levelname, levelcolor=levelcolors.get(entry.levelname, ""),
levelcolor=levelcolors.get(entry.levelname, ""), resetcolor=Fore.RESET,
resetcolor=Fore.RESET, logger=entry.name,
logger=entry.name, message=entry.msg))
message=entry.msg,
)
)
if entry.exc_info: if entry.exc_info:
print(entry.exc_info) print(entry.exc_info)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,46 +13,45 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from urllib.request import urlopen, Request
from urllib.error import HTTPError
from typing import IO from typing import IO
import json import json
from colorama import Fore from colorama import Fore
from yarl import URL
import aiohttp
import click import click
from ..cliq import cliq from ..base import app
from ..config import get_default_server, get_token
class UploadError(Exception): class UploadError(Exception):
pass pass
@cliq.command(help="Upload a maubot plugin") @app.command(help="Upload a maubot plugin")
@click.argument("path") @click.argument("path")
@click.option("-s", "--server", help="The maubot instance to upload the plugin to") @click.option("-s", "--server", help="The maubot instance to upload the plugin to")
@cliq.with_authenticated_http def upload(path: str, server: str) -> None:
async def upload(path: str, server: str, sess: aiohttp.ClientSession) -> None: server, token = get_token(server)
if not token:
return
with open(path, "rb") as file: with open(path, "rb") as file:
await upload_file(sess, file, server) upload_file(file, server, token)
async def upload_file(sess: aiohttp.ClientSession, file: IO, server: str) -> None: def upload_file(file: IO, server: str, token: str) -> None:
url = (URL(server) / "_matrix/maubot/v1/plugins/upload").with_query({"allow_override": "true"}) req = Request(f"{server}/_matrix/maubot/v1/plugins/upload?allow_override=true", data=file,
headers = {"Content-Type": "application/zip"} headers={"Authorization": f"Bearer {token}", "Content-Type": "application/zip"})
async with sess.post(url, data=file, headers=headers) as resp: try:
if resp.status in (200, 201): with urlopen(req) as resp_data:
data = await resp.json() resp = json.load(resp_data)
print( print(f"{Fore.GREEN}Plugin {Fore.CYAN}{resp['id']} v{resp['version']}{Fore.GREEN} "
f"{Fore.GREEN}Plugin {Fore.CYAN}{data['id']} v{data['version']}{Fore.GREEN} " f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}")
f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}" except HTTPError as e:
) try:
else: err = json.load(e)
try: except json.JSONDecodeError:
err = await resp.json() err = {}
if "stacktrace" in err: print(err.get("stacktrace", ""))
print(err["stacktrace"]) print(Fore.RED + "Failed to upload plugin: " + err.get("error", str(e)) + Fore.RESET)
err = err["error"]
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
err = await resp.text()
print(f"{Fore.RED}Failed to upload plugin: {err}{Fore.RESET}")

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,50 +13,37 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Tuple, Optional, Dict, Any
from typing import Any
import json import json
import os import os
from colorama import Fore from colorama import Fore
config: dict[str, Any] = { config: Dict[str, Any] = {
"servers": {}, "servers": {},
"aliases": {},
"default_server": None, "default_server": None,
} }
configdir = os.environ.get("XDG_CONFIG_HOME", os.path.join(os.environ.get("HOME"), ".config")) configdir = os.environ.get("XDG_CONFIG_HOME", os.path.join(os.environ.get("HOME"), ".config"))
def get_default_server() -> tuple[str | None, str | None]: def get_default_server() -> Tuple[Optional[str], Optional[str]]:
try: try:
server: str < None = config["default_server"] server: Optional[str] = config["default_server"]
except KeyError: except KeyError:
server = None server = None
if server is None: if server is None:
print(f"{Fore.RED}Default server not configured.{Fore.RESET}") print(f"{Fore.RED}Default server not configured.{Fore.RESET}")
print(f"Perhaps you forgot to {Fore.CYAN}mbc login{Fore.RESET}?")
return None, None return None, None
return server, _get_token(server) return server, _get_token(server)
def get_token(server: str) -> tuple[str | None, str | None]: def get_token(server: str) -> Tuple[Optional[str], Optional[str]]:
if not server: if not server:
return get_default_server() return get_default_server()
if server in config["aliases"]:
server = config["aliases"][server]
return server, _get_token(server) return server, _get_token(server)
def _resolve_alias(alias: str) -> str | None: def _get_token(server: str) -> Optional[str]:
try:
return config["aliases"][alias]
except KeyError:
return None
def _get_token(server: str) -> str | None:
try: try:
return config["servers"][server] return config["servers"][server]
except KeyError: except KeyError:
@ -73,8 +60,7 @@ def load_config() -> None:
try: try:
with open(f"{configdir}/maubot-cli.json") as file: with open(f"{configdir}/maubot-cli.json") as file:
loaded = json.load(file) loaded = json.load(file)
config["servers"] = loaded.get("servers", {}) config["servers"] = loaded["servers"]
config["aliases"] = loaded.get("aliases", {}) config["default_server"] = loaded["default_server"]
config["default_server"] = loaded.get("default_server", None)
except FileNotFoundError: except FileNotFoundError:
pass pass

Binary file not shown.

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,14 +13,12 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Dict
import json
import zipfile import zipfile
import pkg_resources import pkg_resources
import json
spdx_list: dict[str, dict[str, str]] | None = None spdx_list = None
def load() -> None: def load() -> None:
@ -33,13 +31,13 @@ def load() -> None:
spdx_list = json.load(file) spdx_list = json.load(file)
def get(id: str) -> dict[str, str]: def get(id: str) -> Dict[str, str]:
if not spdx_list: if not spdx_list:
load() load()
return spdx_list[id] return spdx_list[id.lower()]
def valid(id: str) -> bool: def valid(id: str) -> bool:
if not spdx_list: if not spdx_list:
load() load()
return id in spdx_list return id.lower() in spdx_list

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,203 +13,69 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, cast
from collections import defaultdict
import asyncio import asyncio
import logging import logging
from aiohttp import ClientSession from aiohttp import ClientSession
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter)
from mautrix.client import InternalEventType from mautrix.client import InternalEventType
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (
ContentURI,
DeviceID,
EventFilter,
EventType,
Filter,
FilterID,
Membership,
PresenceState,
RoomEventFilter,
RoomFilter,
StateEvent,
StateFilter,
StrippedStateEvent,
SyncToken,
UserID,
)
from mautrix.util import background_task
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.logging import TraceLogger
from .db import Client as DBClient from .lib.store_proxy import ClientStoreProxy
from .db import DBClient
from .matrix import MaubotMatrixClient from .matrix import MaubotMatrixClient
try:
from mautrix.crypto import OlmMachine, PgCryptoStore
crypto_import_error = None
except ImportError as e:
OlmMachine = PgCryptoStore = None
crypto_import_error = e
if TYPE_CHECKING: if TYPE_CHECKING:
from .__main__ import Maubot
from .instance import PluginInstance from .instance import PluginInstance
log = logging.getLogger("maubot.client")
class Client(DBClient):
maubot: "Maubot" = None
cache: dict[UserID, Client] = {}
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
log: TraceLogger = logging.getLogger("maubot.client")
class Client:
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None http_client: ClientSession = None
references: set[PluginInstance] references: Set['PluginInstance']
db_instance: DBClient
client: MaubotMatrixClient client: MaubotMatrixClient
crypto: OlmMachine | None
crypto_store: PgCryptoStore | None
started: bool started: bool
sync_ok: bool
remote_displayname: str | None remote_displayname: Optional[str]
remote_avatar_url: ContentURI | None remote_avatar_url: Optional[ContentURI]
def __init__( def __init__(self, db_instance: DBClient) -> None:
self, self.db_instance = db_instance
id: UserID,
homeserver: str,
access_token: str,
device_id: DeviceID,
enabled: bool = False,
next_batch: SyncToken = "",
filter_id: FilterID = "",
sync: bool = True,
autojoin: bool = True,
online: bool = True,
displayname: str = "disable",
avatar_url: str = "disable",
) -> None:
super().__init__(
id=id,
homeserver=homeserver,
access_token=access_token,
device_id=device_id,
enabled=bool(enabled),
next_batch=next_batch,
filter_id=filter_id,
sync=bool(sync),
autojoin=bool(autojoin),
online=bool(online),
displayname=displayname,
avatar_url=avatar_url,
)
self._postinited = False
def __hash__(self) -> int:
return hash(self.id)
@classmethod
def init_cls(cls, maubot: "Maubot") -> None:
cls.maubot = maubot
def _make_client(
self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None
) -> MaubotMatrixClient:
return MaubotMatrixClient(
mxid=self.id,
base_url=homeserver or self.homeserver,
token=token or self.access_token,
client_session=self.http_client,
log=self.log,
crypto_log=self.log.getChild("crypto"),
loop=self.maubot.loop,
device_id=device_id or self.device_id,
sync_store=self,
state_store=self.maubot.state_store,
)
def postinit(self) -> None:
if self._postinited:
raise RuntimeError("postinit() called twice")
self._postinited = True
self.cache[self.id] = self self.cache[self.id] = self
self.log = self.log.getChild(self.id) self.log = log.getChild(self.id)
self.http_client = ClientSession(loop=self.maubot.loop)
self.references = set() self.references = set()
self.started = False self.started = False
self.sync_ok = True self.sync_ok = True
self.remote_displayname = None self.remote_displayname = None
self.remote_avatar_url = None self.remote_avatar_url = None
self.client = self._make_client() self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
if self.enable_crypto: token=self.access_token, client_session=self.http_client,
self._prepare_crypto() log=self.log, loop=self.loop,
else: store=ClientStoreProxy(self.db_instance))
self.crypto_store = None
self.crypto = None
self.client.ignore_initial_sync = True self.client.ignore_initial_sync = True
self.client.ignore_first_sync = True self.client.ignore_first_sync = True
self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE
if self.autojoin: if self.autojoin:
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite) self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.client.add_event_handler(EventType.ROOM_TOMBSTONE, self._handle_tombstone) self.client.add_event_handler(EventType.ROOM_TOMBSTONE, self._handle_tombstone)
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False)) self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True)) self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]: def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
async def handler(data: dict[str, Any]) -> None: async def handler(data: Dict[str, Any]) -> None:
self.sync_ok = ok self.sync_ok = ok
return handler return handler
@property async def start(self, try_n: Optional[int] = 0) -> None:
def enable_crypto(self) -> bool:
if not self.device_id:
return False
elif not OlmMachine:
global crypto_import_error
self.log.warning(
"Client has device ID, but encryption dependencies not installed",
exc_info=crypto_import_error,
)
# Clear the stack trace after it's logged once to avoid spamming logs
crypto_import_error = None
return False
elif not self.maubot.crypto_db:
self.log.warning("Client has device ID, but crypto database is not prepared")
return False
return True
def _prepare_crypto(self) -> None:
self.crypto_store = PgCryptoStore(
account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db
)
self.crypto = OlmMachine(
self.client,
self.crypto_store,
self.maubot.state_store,
log=self.client.crypto_log,
)
self.client.crypto = self.crypto
def _remove_crypto_event_handlers(self) -> None:
if not self.crypto:
return
handlers = [
(InternalEventType.DEVICE_OTK_COUNT, self.crypto.handle_otk_count),
(InternalEventType.DEVICE_LISTS, self.crypto.handle_device_lists),
(EventType.TO_DEVICE_ENCRYPTED, self.crypto.handle_to_device_event),
(EventType.ROOM_KEY_REQUEST, self.crypto.handle_room_key_request),
(EventType.ROOM_MEMBER, self.crypto.handle_member_event),
]
for event_type, func in handlers:
self.client.remove_event_handler(event_type, func)
async def start(self, try_n: int | None = 0) -> None:
try: try:
if try_n > 0: if try_n > 0:
await asyncio.sleep(try_n * 10) await asyncio.sleep(try_n * 10)
@ -217,21 +83,7 @@ class Client(DBClient):
except Exception: except Exception:
self.log.exception("Failed to start") self.log.exception("Failed to start")
async def _start_crypto(self) -> None: async def _start(self, try_n: Optional[int] = 0) -> None:
self.log.debug("Enabling end-to-end encryption support")
await self.crypto_store.open()
crypto_device_id = await self.crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != self.device_id:
self.log.warning(
"Mismatching device ID in crypto store and main database, resetting encryption"
)
await self.crypto_store.delete()
crypto_device_id = None
await self.crypto.load()
if not crypto_device_id:
await self.crypto_store.put_device_id(self.device_id)
async def _start(self, try_n: int | None = 0) -> None:
if not self.enabled: if not self.enabled:
self.log.debug("Not starting disabled client") self.log.debug("Not starting disabled client")
return return
@ -239,60 +91,36 @@ class Client(DBClient):
self.log.warning("Ignoring start() call to started client") self.log.warning("Ignoring start() call to started client")
return return
try: try:
await self.client.versions() user_id = await self.client.whoami()
whoami = await self.client.whoami()
except MatrixInvalidToken as e: except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client") self.log.error(f"Invalid token: {e}. Disabling client")
self.enabled = False self.db_instance.enabled = False
await self.update()
return return
except Exception as e: except MatrixRequestError:
if try_n >= 8: if try_n >= 5:
self.log.exception("Failed to get /account/whoami, disabling client") self.log.exception("Failed to get /account/whoami, disabling client")
self.enabled = False self.db_instance.enabled = False
await self.update()
else: else:
self.log.warning( self.log.exception(f"Failed to get /account/whoami, "
f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}" f"retrying in {(try_n + 1) * 10}s")
) _ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
background_task.create(self.start(try_n + 1))
return return
if whoami.user_id != self.id: if user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}") self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
self.enabled = False self.db_instance.enabled = False
await self.update()
return
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
self.log.error(
f"Device ID mismatch: expected {self.device_id}, but got {whoami.device_id}"
)
self.enabled = False
await self.update()
return return
if not self.filter_id: if not self.filter_id:
self.filter_id = await self.client.create_filter( self.db_instance.edit(filter_id=await self.client.create_filter(Filter(
Filter( room=RoomFilter(
room=RoomFilter( timeline=RoomEventFilter(
timeline=RoomEventFilter( limit=50,
limit=50,
lazy_load_members=True,
),
state=StateFilter(
lazy_load_members=True,
),
), ),
presence=EventFilter( ),
not_types=[EventType.PRESENCE], )))
),
)
)
await self.update()
if self.displayname != "disable": if self.displayname != "disable":
await self.client.set_displayname(self.displayname) await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable": if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url) await self.client.set_avatar_url(self.avatar_url)
if self.crypto:
await self._start_crypto()
self.start_sync() self.start_sync()
await self._update_remote_profile() await self._update_remote_profile()
self.started = True self.started = True
@ -300,10 +128,11 @@ class Client(DBClient):
await self.start_plugins() await self.start_plugins()
async def start_plugins(self) -> None: async def start_plugins(self) -> None:
await asyncio.gather(*[plugin.start() for plugin in self.references]) await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
async def stop_plugins(self) -> None: async def stop_plugins(self) -> None:
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.started]) await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.started],
loop=self.loop)
def start_sync(self) -> None: def start_sync(self) -> None:
if self.sync: if self.sync:
@ -317,31 +146,29 @@ class Client(DBClient):
self.started = False self.started = False
await self.stop_plugins() await self.stop_plugins()
self.stop_sync() self.stop_sync()
if self.crypto:
await self.crypto_store.close()
async def clear_cache(self) -> None: def clear_cache(self) -> None:
self.stop_sync() self.stop_sync()
self.filter_id = FilterID("") self.db_instance.edit(filter_id="", next_batch="")
self.next_batch = SyncToken("")
await self.update()
self.start_sync() self.start_sync()
def delete(self) -> None:
try:
del self.cache[self.id]
except KeyError:
pass
self.db_instance.delete()
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
"id": self.id, "id": self.id,
"homeserver": self.homeserver, "homeserver": self.homeserver,
"access_token": self.access_token, "access_token": self.access_token,
"device_id": self.device_id,
"fingerprint": (
self.crypto.account.fingerprint if self.crypto and self.crypto.account else None
),
"enabled": self.enabled, "enabled": self.enabled,
"started": self.started, "started": self.started,
"sync": self.sync, "sync": self.sync,
"sync_ok": self.sync_ok, "sync_ok": self.sync_ok,
"autojoin": self.autojoin, "autojoin": self.autojoin,
"online": self.online,
"displayname": self.displayname, "displayname": self.displayname,
"avatar_url": self.avatar_url, "avatar_url": self.avatar_url,
"remote_displayname": self.remote_displayname, "remote_displayname": self.remote_displayname,
@ -349,6 +176,20 @@ class Client(DBClient):
"instances": [instance.to_dict() for instance in self.references], "instances": [instance.to_dict() for instance in self.references],
} }
@classmethod
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
try:
return cls.cache[user_id]
except KeyError:
db_instance = db_instance or DBClient.get(user_id)
if not db_instance:
return None
return Client(db_instance)
@classmethod
def all(cls) -> Iterable['Client']:
return (cls.get(user.id, user) for user in DBClient.all())
async def _handle_tombstone(self, evt: StateEvent) -> None: async def _handle_tombstone(self, evt: StateEvent) -> None:
if not evt.content.replacement_room: if not evt.content.replacement_room:
self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring") self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring")
@ -360,7 +201,7 @@ class Client(DBClient):
if evt.state_key == self.id and evt.content.membership == Membership.INVITE: if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room(evt.room_id) await self.client.join_room(evt.room_id)
async def update_started(self, started: bool | None) -> None: async def update_started(self, started: bool) -> None:
if started is None or started == self.started: if started is None or started == self.started:
return return
if started: if started:
@ -368,162 +209,117 @@ class Client(DBClient):
else: else:
await self.stop() await self.stop()
async def update_enabled(self, enabled: bool | None, save: bool = True) -> None: async def update_displayname(self, displayname: str) -> None:
if enabled is None or enabled == self.enabled:
return
self.enabled = enabled
if save:
await self.update()
async def update_displayname(self, displayname: str | None, save: bool = True) -> None:
if displayname is None or displayname == self.displayname: if displayname is None or displayname == self.displayname:
return return
self.displayname = displayname self.db_instance.displayname = displayname
if self.displayname != "disable": if self.displayname != "disable":
await self.client.set_displayname(self.displayname) await self.client.set_displayname(self.displayname)
else: else:
await self._update_remote_profile() await self._update_remote_profile()
if save:
await self.update()
async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None: async def update_avatar_url(self, avatar_url: ContentURI) -> None:
if avatar_url is None or avatar_url == self.avatar_url: if avatar_url is None or avatar_url == self.avatar_url:
return return
self.avatar_url = avatar_url self.db_instance.avatar_url = avatar_url
if self.avatar_url != "disable": if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url) await self.client.set_avatar_url(self.avatar_url)
else: else:
await self._update_remote_profile() await self._update_remote_profile()
if save:
await self.update()
async def update_sync(self, sync: bool | None, save: bool = True) -> None: async def update_access_details(self, access_token: str, homeserver: str) -> None:
if sync is None or self.sync == sync:
return
self.sync = sync
if self.started:
if sync:
self.start_sync()
else:
self.stop_sync()
if save:
await self.update()
async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None:
if autojoin is None or autojoin == self.autojoin:
return
if autojoin:
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
else:
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.autojoin = autojoin
if save:
await self.update()
async def update_online(self, online: bool | None, save: bool = True) -> None:
if online is None or online == self.online:
return
self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE
self.online = online
if save:
await self.update()
async def update_access_details(
self,
access_token: str | None,
homeserver: str | None,
device_id: str | None = None,
) -> None:
if not access_token and not homeserver: if not access_token and not homeserver:
return return
if device_id is None: elif access_token == self.access_token and homeserver == self.homeserver:
device_id = self.device_id
elif not device_id:
device_id = None
if (
access_token == self.access_token
and homeserver == self.homeserver
and device_id == self.device_id
):
return return
new_client = self._make_client(homeserver, access_token, device_id) new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
whoami = await new_client.whoami() token=access_token or self.access_token, loop=self.loop,
if whoami.user_id != self.id: client_session=self.http_client, log=self.log)
raise ValueError(f"MXID mismatch: {whoami.user_id}") mxid = await new_client.whoami()
elif whoami.device_id and device_id and whoami.device_id != device_id: if mxid != self.id:
raise ValueError(f"Device ID mismatch: {whoami.device_id}") raise ValueError(f"MXID mismatch: {mxid}")
new_client.sync_store = self new_client.store = self.db_instance
self.stop_sync() self.stop_sync()
# TODO this event handler transfer is pretty hacky
self._remove_crypto_event_handlers()
self.client.crypto = None
new_client.event_handlers = self.client.event_handlers
new_client.global_event_handlers = self.client.global_event_handlers
self.client = new_client self.client = new_client
self.homeserver = homeserver self.db_instance.homeserver = homeserver
self.access_token = access_token self.db_instance.access_token = access_token
self.device_id = device_id
if self.enable_crypto:
self._prepare_crypto()
await self._start_crypto()
else:
self.crypto_store = None
self.crypto = None
self.start_sync() self.start_sync()
async def _update_remote_profile(self) -> None: async def _update_remote_profile(self) -> None:
profile = await self.client.get_profile(self.id) profile = await self.client.get_profile(self.id)
self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url
async def delete(self) -> None: # region Properties
try:
del self.cache[self.id]
except KeyError:
pass
await super().delete()
@classmethod @property
@async_getter_lock def id(self) -> UserID:
async def get( return self.db_instance.id
cls,
user_id: UserID,
*,
homeserver: str | None = None,
access_token: str | None = None,
device_id: DeviceID | None = None,
) -> Client | None:
try:
return cls.cache[user_id]
except KeyError:
pass
user = cast(cls, await super().get(user_id)) @property
if user is not None: def homeserver(self) -> str:
user.postinit() return self.db_instance.homeserver
return user
if homeserver and access_token: @property
user = cls( def access_token(self) -> str:
user_id, return self.db_instance.access_token
homeserver=homeserver,
access_token=access_token,
device_id=device_id or "",
)
await user.insert()
user.postinit()
return user
return None @property
def enabled(self) -> bool:
return self.db_instance.enabled
@classmethod @enabled.setter
async def all(cls) -> AsyncGenerator[Client, None]: def enabled(self, value: bool) -> None:
users = await super().all() self.db_instance.enabled = value
user: cls
for user in users: @property
try: def next_batch(self) -> SyncToken:
yield cls.cache[user.id] return self.db_instance.next_batch
except KeyError:
user.postinit() @property
yield user def filter_id(self) -> FilterID:
return self.db_instance.filter_id
@property
def sync(self) -> bool:
return self.db_instance.sync
@sync.setter
def sync(self, value: bool) -> None:
if value == self.db_instance.sync:
return
self.db_instance.sync = value
if self.started:
if value:
self.start_sync()
else:
self.stop_sync()
@property
def autojoin(self) -> bool:
return self.db_instance.autojoin
@autojoin.setter
def autojoin(self, value: bool) -> None:
if value == self.db_instance.autojoin:
return
if value:
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
else:
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.db_instance.autojoin = value
@property
def displayname(self) -> str:
return self.db_instance.displayname
@property
def avatar_url(self) -> ContentURI:
return self.db_instance.avatar_url
# endregion
def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
return Client.all()

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,10 +14,9 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
import random import random
import re
import string import string
import bcrypt import bcrypt
import re
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper
@ -32,50 +31,33 @@ class Config(BaseFileConfig):
def do_update(self, helper: ConfigUpdateHelper) -> None: def do_update(self, helper: ConfigUpdateHelper) -> None:
base = helper.base base = helper.base
copy = helper.copy copy = helper.copy
copy("database")
if "database" in self and self["database"].startswith("sqlite:///"):
helper.base["database"] = self["database"].replace("sqlite:///", "sqlite:")
else:
copy("database")
copy("database_opts")
if isinstance(self["crypto_database"], dict):
if self["crypto_database.type"] == "postgres":
base["crypto_database"] = self["crypto_database.postgres_uri"]
else:
copy("crypto_database")
copy("plugin_directories.upload") copy("plugin_directories.upload")
copy("plugin_directories.load") copy("plugin_directories.load")
copy("plugin_directories.trash") copy("plugin_directories.trash")
if "plugin_directories.db" in self: copy("plugin_directories.db")
base["plugin_databases.sqlite"] = self["plugin_directories.db"]
else:
copy("plugin_databases.sqlite")
copy("plugin_databases.postgres")
copy("plugin_databases.postgres_opts")
copy("server.hostname") copy("server.hostname")
copy("server.port") copy("server.port")
copy("server.public_url") copy("server.public_url")
copy("server.listen") copy("server.listen")
copy("server.base_path")
copy("server.ui_base_path") copy("server.ui_base_path")
copy("server.plugin_base_path") copy("server.plugin_base_path")
copy("server.override_resource_path") copy("server.override_resource_path")
copy("server.appservice_base_path")
shared_secret = self["server.unshared_secret"] shared_secret = self["server.unshared_secret"]
if shared_secret is None or shared_secret == "generate": if shared_secret is None or shared_secret == "generate":
base["server.unshared_secret"] = self._new_token() base["server.unshared_secret"] = self._new_token()
else: else:
base["server.unshared_secret"] = shared_secret base["server.unshared_secret"] = shared_secret
if "registration_secrets" in self: copy("registration_secrets")
base["homeservers"] = self["registration_secrets"]
else:
copy("homeservers")
copy("admins") copy("admins")
for username, password in base["admins"].items(): for username, password in base["admins"].items():
if password and not bcrypt_regex.match(password): if password and not bcrypt_regex.match(password):
if password == "password": if password == "password":
password = self._new_token() password = self._new_token()
base["admins"][username] = bcrypt.hashpw( base["admins"][username] = bcrypt.hashpw(password.encode("utf-8"),
password.encode("utf-8"), bcrypt.gensalt() bcrypt.gensalt()).decode("utf-8")
).decode("utf-8")
copy("api_features.login") copy("api_features.login")
copy("api_features.plugin") copy("api_features.plugin")
copy("api_features.plugin_upload") copy("api_features.plugin_upload")

111
maubot/db.py Normal file
View File

@ -0,0 +1,111 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Iterable, Optional
import logging
import sys
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
from sqlalchemy.engine.base import Engine
import sqlalchemy as sql
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
from mautrix.util.db import Base
from .config import Config
class DBPlugin(Base):
__tablename__ = "plugin"
id: str = Column(String(255), primary_key=True)
type: str = Column(String(255), nullable=False)
enabled: bool = Column(Boolean, nullable=False, default=False)
primary_user: UserID = Column(String(255),
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
nullable=False)
@classmethod
def all(cls) -> Iterable['DBPlugin']:
return cls._select_all()
@classmethod
def get(cls, id: str) -> Optional['DBPlugin']:
return cls._select_one_or_none(cls.c.id == id)
class DBPluginFile(Base):
__tablename__ = "plugin_file"
plugin_id: str = Column(String(255),
ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True)
file_name: str = Column(String(255), primary_key=True)
content: str = Column(Text, nullable=False, default="")
@classmethod
def all_for_plugin(cls, id: str) -> Iterable['DBPluginFile']:
return cls._select_all(cls.c.plugin_id == id)
class DBClient(Base):
__tablename__ = "client"
id: UserID = Column(String(255), primary_key=True)
homeserver: str = Column(String(255), nullable=False)
access_token: str = Column(Text, nullable=False)
enabled: bool = Column(Boolean, nullable=False, default=False)
next_batch: SyncToken = Column(String(255), nullable=False, default="")
filter_id: FilterID = Column(String(255), nullable=False, default="")
sync: bool = Column(Boolean, nullable=False, default=True)
autojoin: bool = Column(Boolean, nullable=False, default=True)
displayname: str = Column(String(255), nullable=False, default="")
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
@classmethod
def all(cls) -> Iterable['DBClient']:
return cls._select_all()
@classmethod
def get(cls, id: str) -> Optional['DBClient']:
return cls._select_one_or_none(cls.c.id == id)
def init(config: Config) -> Engine:
db = sql.create_engine(config["database"])
Base.metadata.bind = db
for table in (DBPlugin, DBClient):
table.bind(db)
if not db.has_table("alembic_version"):
log = logging.getLogger("maubot.db")
if db.has_table("client") and db.has_table("plugin"):
log.warning("alembic_version table not found, but client and plugin tables found. "
"Assuming pre-Alembic database and inserting version.")
db.execute("CREATE TABLE IF NOT EXISTS alembic_version ("
" version_num VARCHAR(32) PRIMARY KEY"
");")
db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');")
else:
log.critical("alembic_version table not found. "
"Did you forget to `alembic upgrade head`?")
sys.exit(10)
return db

View File

@ -1,13 +0,0 @@
from mautrix.util.async_db import Database
from .client import Client
from .instance import DatabaseEngine, Instance
from .upgrade import upgrade_table
def init(db: Database) -> None:
for table in (Client, Instance):
table.db = db
__all__ = ["upgrade_table", "init", "Client", "Instance", "DatabaseEngine"]

View File

@ -1,114 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
from attr import dataclass
from mautrix.client import SyncStore
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
from mautrix.util.async_db import Database
fake_db = Database.create("") if TYPE_CHECKING else None
@dataclass
class Client(SyncStore):
db: ClassVar[Database] = fake_db
id: UserID
homeserver: str
access_token: str
device_id: DeviceID
enabled: bool
next_batch: SyncToken
filter_id: FilterID
sync: bool
autojoin: bool
online: bool
displayname: str
avatar_url: ContentURI
@classmethod
def _from_row(cls, row: Record | None) -> Client | None:
if row is None:
return None
return cls(**row)
_columns = (
"id, homeserver, access_token, device_id, enabled, next_batch, filter_id, "
"sync, autojoin, online, displayname, avatar_url"
)
@property
def _values(self):
return (
self.id,
self.homeserver,
self.access_token,
self.device_id,
self.enabled,
self.next_batch,
self.filter_id,
self.sync,
self.autojoin,
self.online,
self.displayname,
self.avatar_url,
)
@classmethod
async def all(cls) -> list[Client]:
rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client")
return [cls._from_row(row) for row in rows]
@classmethod
async def get(cls, id: str) -> Client | None:
q = f"SELECT {cls._columns} FROM client WHERE id=$1"
return cls._from_row(await cls.db.fetchrow(q, id))
async def insert(self) -> None:
q = """
INSERT INTO client (
id, homeserver, access_token, device_id, enabled, next_batch, filter_id,
sync, autojoin, online, displayname, avatar_url
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
"""
await self.db.execute(q, *self._values)
async def put_next_batch(self, next_batch: SyncToken) -> None:
await self.db.execute("UPDATE client SET next_batch=$1 WHERE id=$2", next_batch, self.id)
self.next_batch = next_batch
async def get_next_batch(self) -> SyncToken:
return self.next_batch
async def update(self) -> None:
q = """
UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, enabled=$5,
next_batch=$6, filter_id=$7, sync=$8, autojoin=$9, online=$10,
displayname=$11, avatar_url=$12
WHERE id=$1
"""
await self.db.execute(q, *self._values)
async def delete(self) -> None:
await self.db.execute("DELETE FROM client WHERE id=$1", self.id)

View File

@ -1,101 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from enum import Enum
from asyncpg import Record
from attr import dataclass
from mautrix.types import UserID
from mautrix.util.async_db import Database
fake_db = Database.create("") if TYPE_CHECKING else None
class DatabaseEngine(Enum):
SQLITE = "sqlite"
POSTGRES = "postgres"
@dataclass
class Instance:
db: ClassVar[Database] = fake_db
id: str
type: str
enabled: bool
primary_user: UserID
config_str: str
database_engine: DatabaseEngine | None
@property
def database_engine_str(self) -> str | None:
return self.database_engine.value if self.database_engine else None
@classmethod
def _from_row(cls, row: Record | None) -> Instance | None:
if row is None:
return None
data = {**row}
db_engine = data.pop("database_engine", None)
return cls(**data, database_engine=DatabaseEngine(db_engine) if db_engine else None)
_columns = "id, type, enabled, primary_user, config, database_engine"
@classmethod
async def all(cls) -> list[Instance]:
q = f"SELECT {cls._columns} FROM instance"
rows = await cls.db.fetch(q)
return [cls._from_row(row) for row in rows]
@classmethod
async def get(cls, id: str) -> Instance | None:
q = f"SELECT {cls._columns} FROM instance WHERE id=$1"
return cls._from_row(await cls.db.fetchrow(q, id))
async def update_id(self, new_id: str) -> None:
await self.db.execute("UPDATE instance SET id=$1 WHERE id=$2", new_id, self.id)
self.id = new_id
@property
def _values(self):
return (
self.id,
self.type,
self.enabled,
self.primary_user,
self.config_str,
self.database_engine_str,
)
async def insert(self) -> None:
q = (
"INSERT INTO instance (id, type, enabled, primary_user, config, database_engine) "
"VALUES ($1, $2, $3, $4, $5, $6)"
)
await self.db.execute(q, *self._values)
async def update(self) -> None:
q = """
UPDATE instance SET type=$2, enabled=$3, primary_user=$4, config=$5, database_engine=$6
WHERE id=$1
"""
await self.db.execute(q, *self._values)
async def delete(self) -> None:
await self.db.execute("DELETE FROM instance WHERE id=$1", self.id)

View File

@ -1,5 +0,0 @@
from mautrix.util.async_db import UpgradeTable
upgrade_table = UpgradeTable()
from . import v01_initial_revision, v02_instance_database_engine

View File

@ -1,136 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from mautrix.util.async_db import Connection, Scheme
from . import upgrade_table
legacy_version_query = "SELECT version_num FROM alembic_version"
last_legacy_version = "90aa88820eab"
@upgrade_table.register(description="Initial asyncpg revision")
async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
if await conn.table_exists("alembic_version"):
await migrate_legacy_to_v1(conn, scheme)
else:
return await create_v1_tables(conn)
async def create_v1_tables(conn: Connection) -> None:
await conn.execute(
"""CREATE TABLE client (
id TEXT PRIMARY KEY,
homeserver TEXT NOT NULL,
access_token TEXT NOT NULL,
device_id TEXT NOT NULL,
enabled BOOLEAN NOT NULL,
next_batch TEXT NOT NULL,
filter_id TEXT NOT NULL,
sync BOOLEAN NOT NULL,
autojoin BOOLEAN NOT NULL,
online BOOLEAN NOT NULL,
displayname TEXT NOT NULL,
avatar_url TEXT NOT NULL
)"""
)
await conn.execute(
"""CREATE TABLE instance (
id TEXT PRIMARY KEY,
type TEXT NOT NULL,
enabled BOOLEAN NOT NULL,
primary_user TEXT NOT NULL,
config TEXT NOT NULL,
FOREIGN KEY (primary_user) REFERENCES client(id) ON DELETE RESTRICT ON UPDATE CASCADE
)"""
)
async def migrate_legacy_to_v1(conn: Connection, scheme: Scheme) -> None:
legacy_version = await conn.fetchval(legacy_version_query)
if legacy_version != last_legacy_version:
raise RuntimeError(
"Legacy database is not on last version. "
"Please upgrade the old database with alembic or drop it completely first."
)
await conn.execute("ALTER TABLE plugin RENAME TO instance")
await update_state_store(conn, scheme)
if scheme != Scheme.SQLITE:
await varchar_to_text(conn)
await conn.execute("DROP TABLE alembic_version")
async def update_state_store(conn: Connection, scheme: Scheme) -> None:
# The Matrix state store already has more or less the correct schema, so set the version
await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)")
await conn.execute("INSERT INTO mx_version (version) VALUES (2)")
if scheme != Scheme.SQLITE:
# Remove old uppercase membership type and recreate it as lowercase
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT")
await conn.execute("DROP TYPE IF EXISTS membership")
await conn.execute(
"CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')"
)
await conn.execute(
"ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership "
"USING LOWER(membership)::membership"
)
else:
# Recreate table to remove CHECK constraint and lowercase everything
await conn.execute(
"""CREATE TABLE new_mx_user_profile (
room_id TEXT,
user_id TEXT,
membership TEXT NOT NULL
CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock')),
displayname TEXT,
avatar_url TEXT,
PRIMARY KEY (room_id, user_id)
)"""
)
await conn.execute(
"""
INSERT INTO new_mx_user_profile (room_id, user_id, membership, displayname, avatar_url)
SELECT room_id, user_id, LOWER(membership), displayname, avatar_url
FROM mx_user_profile
"""
)
await conn.execute("DROP TABLE mx_user_profile")
await conn.execute("ALTER TABLE new_mx_user_profile RENAME TO mx_user_profile")
async def varchar_to_text(conn: Connection) -> None:
columns_to_adjust = {
"client": (
"id",
"homeserver",
"device_id",
"next_batch",
"filter_id",
"displayname",
"avatar_url",
),
"instance": ("id", "type", "primary_user"),
"mx_room_state": ("room_id",),
"mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"),
}
for table, columns in columns_to_adjust.items():
for column in columns:
await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT')

View File

@ -1 +1 @@
from . import command, event, web from . import event, command, web

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,46 +13,28 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import ( from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
Any, Dict, Tuple, Set, Iterable)
Awaitable,
Callable,
Dict,
Iterable,
List,
NewType,
Optional,
Pattern,
Sequence,
Set,
Tuple,
Union,
)
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
import functools import functools
import inspect import inspect
import re import re
from mautrix.types import EventType, MessageType from mautrix.types import MessageType, EventType
from ..matrix import MaubotMessageEvent from ..matrix import MaubotMessageEvent
from . import event from . import event
PrefixType = Optional[Union[str, Callable[[], str], Callable[[Any], str]]] PrefixType = Optional[Union[str, Callable[[], str]]]
AliasesType = Union[ AliasesType = Union[List[str], Tuple[str, ...], Set[str], Callable[[str], bool]]
List[str], Tuple[str, ...], Set[str], Callable[[str], bool], Callable[[Any, str], bool] CommandHandlerFunc = NewType("CommandHandlerFunc",
] Callable[[MaubotMessageEvent, Any], Awaitable[Any]])
CommandHandlerFunc = NewType( CommandHandlerDecorator = NewType("CommandHandlerDecorator",
"CommandHandlerFunc", Callable[[MaubotMessageEvent, Any], Awaitable[Any]] Callable[[Union['CommandHandler', CommandHandlerFunc]],
) 'CommandHandler'])
CommandHandlerDecorator = NewType( PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator",
"CommandHandlerDecorator", Callable[[CommandHandlerFunc], CommandHandlerFunc])
Callable[[Union["CommandHandler", CommandHandlerFunc]], "CommandHandler"],
)
PassiveCommandHandlerDecorator = NewType(
"PassiveCommandHandlerDecorator", Callable[[CommandHandlerFunc], CommandHandlerFunc]
)
def _split_in_two(val: str, split_by: str) -> List[str]: def _split_in_two(val: str, split_by: str) -> List[str]:
@ -69,10 +51,9 @@ class CommandHandler:
self.__mb_get_name__: Callable[[Any], str] = lambda s: "noname" self.__mb_get_name__: Callable[[Any], str] = lambda s: "noname"
self.__mb_is_command_match__: Callable[[Any, str], bool] = self.__command_match_unset self.__mb_is_command_match__: Callable[[Any, str], bool] = self.__command_match_unset
self.__mb_require_subcommand__: bool = True self.__mb_require_subcommand__: bool = True
self.__mb_must_consume_args__: bool = True
self.__mb_arg_fallthrough__: bool = True self.__mb_arg_fallthrough__: bool = True
self.__mb_event_handler__: bool = True self.__mb_event_handler__: bool = True
self.__mb_event_types__: set[EventType] = {EventType.ROOM_MESSAGE} self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
self.__mb_msgtypes__: Iterable[MessageType] = (MessageType.TEXT,) self.__mb_msgtypes__: Iterable[MessageType] = (MessageType.TEXT,)
self.__bound_copies__: Dict[Any, CommandHandler] = {} self.__bound_copies__: Dict[Any, CommandHandler] = {}
self.__bound_instance__: Any = None self.__bound_instance__: Any = None
@ -84,27 +65,15 @@ class CommandHandler:
return self.__bound_copies__[instance] return self.__bound_copies__[instance]
except KeyError: except KeyError:
new_ch = type(self)(self.__mb_func__) new_ch = type(self)(self.__mb_func__)
keys = [ keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match",
"parent", "require_subcommand", "arg_fallthrough", "event_handler", "event_type",
"subcommands", "msgtypes"]
"arguments",
"help",
"get_name",
"is_command_match",
"require_subcommand",
"must_consume_args",
"arg_fallthrough",
"event_handler",
"event_types",
"msgtypes",
]
for key in keys: for key in keys:
key = f"__mb_{key}__" key = f"__mb_{key}__"
setattr(new_ch, key, getattr(self, key)) setattr(new_ch, key, getattr(self, key))
new_ch.__bound_instance__ = instance new_ch.__bound_instance__ = instance
new_ch.__mb_subcommands__ = [ new_ch.__mb_subcommands__ = [subcmd.__get__(instance, instancetype)
subcmd.__get__(instance, instancetype) for subcmd in self.__mb_subcommands__ for subcmd in self.__mb_subcommands__]
]
self.__bound_copies__[instance] = new_ch self.__bound_copies__[instance] = new_ch
return new_ch return new_ch
@ -112,20 +81,14 @@ class CommandHandler:
def __command_match_unset(self, val: str) -> bool: def __command_match_unset(self, val: str) -> bool:
raise NotImplementedError("Hmm") raise NotImplementedError("Hmm")
async def __call__( async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None,
self, remaining_val: str = None) -> Any:
evt: MaubotMessageEvent,
*,
_existing_args: Dict[str, Any] = None,
remaining_val: str = None,
) -> Any:
if evt.sender == evt.client.mxid or evt.content.msgtype not in self.__mb_msgtypes__: if evt.sender == evt.client.mxid or evt.content.msgtype not in self.__mb_msgtypes__:
return return
if remaining_val is None: if remaining_val is None:
if not evt.content.body or evt.content.body[0] != "!": if not evt.content.body or evt.content.body[0] != "!":
return return
command, remaining_val = _split_in_two(evt.content.body[1:], " ") command, remaining_val = _split_in_two(evt.content.body[1:], " ")
command = command.lower()
if not self.__mb_is_command_match__(self.__bound_instance__, command): if not self.__mb_is_command_match__(self.__bound_instance__, command):
return return
call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {} call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {}
@ -146,34 +109,26 @@ class CommandHandler:
await evt.reply(self.__mb_full_help__) await evt.reply(self.__mb_full_help__)
return return
if self.__mb_must_consume_args__ and remaining_val.strip():
await evt.reply(self.__mb_full_help__)
return
if self.__bound_instance__: if self.__bound_instance__:
return await self.__mb_func__(self.__bound_instance__, evt, **call_args) return await self.__mb_func__(self.__bound_instance__, evt, **call_args)
return await self.__mb_func__(evt, **call_args) return await self.__mb_func__(evt, **call_args)
async def __call_subcommand__( async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str remaining_val: str) -> Tuple[bool, Any]:
) -> Tuple[bool, Any]:
command, remaining_val = _split_in_two(remaining_val.strip(), " ") command, remaining_val = _split_in_two(remaining_val.strip(), " ")
for subcommand in self.__mb_subcommands__: for subcommand in self.__mb_subcommands__:
if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command): if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command):
return True, await subcommand( return True, await subcommand(evt, _existing_args=call_args,
evt, _existing_args=call_args, remaining_val=remaining_val remaining_val=remaining_val)
)
return False, None return False, None
async def __parse_args__( async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str remaining_val: str) -> Tuple[bool, str]:
) -> Tuple[bool, str]:
for arg in self.__mb_arguments__: for arg in self.__mb_arguments__:
try: try:
remaining_val, call_args[arg.name] = arg.match( remaining_val, call_args[arg.name] = arg.match(remaining_val.strip(), evt=evt,
remaining_val.strip(), evt=evt, instance=self.__bound_instance__ instance=self.__bound_instance__)
) if arg.required and not call_args[arg.name]:
if arg.required and call_args[arg.name] is None:
raise ValueError("Argument required") raise ValueError("Argument required")
except ArgumentSyntaxError as e: except ArgumentSyntaxError as e:
await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else "")) await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else ""))
@ -186,16 +141,13 @@ class CommandHandler:
@property @property
def __mb_full_help__(self) -> str: def __mb_full_help__(self) -> str:
usage = self.__mb_usage_without_subcommands__ + "\n\n" usage = self.__mb_usage_without_subcommands__ + "\n\n"
if not self.__mb_require_subcommand__:
usage += f"* {self.__mb_prefix__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__) usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__)
return usage return usage
@property @property
def __mb_usage_args__(self) -> str: def __mb_usage_args__(self) -> str:
arg_usage = " ".join( arg_usage = " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]"
f"<{arg.label}>" if arg.required else f"[{arg.label}]" for arg in self.__mb_arguments__ for arg in self.__mb_arguments__)
)
if self.__mb_subcommands__ and self.__mb_arg_fallthrough__: if self.__mb_subcommands__ and self.__mb_arg_fallthrough__:
arg_usage += " " + self.__mb_usage_subcommand__ arg_usage += " " + self.__mb_usage_subcommand__
return arg_usage return arg_usage
@ -211,19 +163,14 @@ class CommandHandler:
@property @property
def __mb_prefix__(self) -> str: def __mb_prefix__(self) -> str:
if self.__mb_parent__: if self.__mb_parent__:
return ( return f"{self.__mb_parent__.__mb_prefix__} {self.__mb_name__}"
f"!{self.__mb_parent__.__mb_get_name__(self.__bound_instance__)} "
f"{self.__mb_name__}"
)
return f"!{self.__mb_name__}" return f"!{self.__mb_name__}"
@property @property
def __mb_usage_inline__(self) -> str: def __mb_usage_inline__(self) -> str:
if not self.__mb_arg_fallthrough__: if not self.__mb_arg_fallthrough__:
return ( return (f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n" f"* {self.__mb_name__} {self.__mb_usage_subcommand__}")
f"* {self.__mb_name__} {self.__mb_usage_subcommand__}"
)
return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}" return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}"
@property @property
@ -233,12 +180,8 @@ class CommandHandler:
@property @property
def __mb_usage_without_subcommands__(self) -> str: def __mb_usage_without_subcommands__(self) -> str:
if not self.__mb_arg_fallthrough__: if not self.__mb_arg_fallthrough__:
if not self.__mb_arguments__: return (f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
return f"**Usage:** {self.__mb_prefix__} [subcommand] [...]" f" _OR_ {self.__mb_usage_subcommand__}")
return (
f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
f" _OR_ {self.__mb_prefix__} {self.__mb_usage_subcommand__}"
)
return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}" return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
@property @property
@ -247,25 +190,14 @@ class CommandHandler:
return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}" return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}"
return self.__mb_usage_without_subcommands__ return self.__mb_usage_without_subcommands__
def subcommand( def subcommand(self, name: PrefixType = None, *, help: str = None, aliases: AliasesType = None,
self, required_subcommand: bool = True, arg_fallthrough: bool = True,
name: PrefixType = None, ) -> CommandHandlerDecorator:
*,
help: str = None,
aliases: AliasesType = None,
required_subcommand: bool = True,
arg_fallthrough: bool = True,
) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler): if not isinstance(func, CommandHandler):
func = CommandHandler(func) func = CommandHandler(func)
new( new(name, help=help, aliases=aliases, require_subcommand=required_subcommand,
name, arg_fallthrough=arg_fallthrough)(func)
help=help,
aliases=aliases,
require_subcommand=required_subcommand,
arg_fallthrough=arg_fallthrough,
)(func)
func.__mb_parent__ = self func.__mb_parent__ = self
func.__mb_event_handler__ = False func.__mb_event_handler__ = False
self.__mb_subcommands__.append(func) self.__mb_subcommands__.append(func)
@ -274,17 +206,9 @@ class CommandHandler:
return decorator return decorator
def new( def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = None,
name: PrefixType = None, event_type: EventType = EventType.ROOM_MESSAGE, msgtypes: Iterable[MessageType] = None,
*, require_subcommand: bool = True, arg_fallthrough: bool = True) -> CommandHandlerDecorator:
help: str = None,
aliases: AliasesType = None,
event_type: EventType = EventType.ROOM_MESSAGE,
msgtypes: Iterable[MessageType] = None,
require_subcommand: bool = True,
arg_fallthrough: bool = True,
must_consume_args: bool = True,
) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler): if not isinstance(func, CommandHandler):
func = CommandHandler(func) func = CommandHandler(func)
@ -298,24 +222,22 @@ def new(
else: else:
func.__mb_get_name__ = lambda self: name func.__mb_get_name__ = lambda self: name
else: else:
func.__mb_get_name__ = lambda self: func.__mb_func__.__name__.replace("_", "-") func.__mb_get_name__ = lambda self: func.__name__
if callable(aliases): if callable(aliases):
if len(inspect.getfullargspec(aliases).args) == 1: if len(inspect.getfullargspec(aliases).args) == 1:
func.__mb_is_command_match__ = lambda self, val: aliases(val) func.__mb_is_command_match__ = lambda self, val: aliases(val)
else: else:
func.__mb_is_command_match__ = aliases func.__mb_is_command_match__ = aliases
elif isinstance(aliases, (list, set, tuple)): elif isinstance(aliases, (list, set, tuple)):
func.__mb_is_command_match__ = lambda self, val: ( func.__mb_is_command_match__ = lambda self, val: (val == func.__mb_get_name__(self)
val == func.__mb_get_name__(self) or val in aliases or val in aliases)
)
else: else:
func.__mb_is_command_match__ = lambda self, val: val == func.__mb_get_name__(self) func.__mb_is_command_match__ = lambda self, val: val == func.__mb_get_name__(self)
# Decorators are executed last to first, so we reverse the argument list. # Decorators are executed last to first, so we reverse the argument list.
func.__mb_arguments__.reverse() func.__mb_arguments__.reverse()
func.__mb_require_subcommand__ = require_subcommand func.__mb_require_subcommand__ = require_subcommand
func.__mb_arg_fallthrough__ = arg_fallthrough func.__mb_arg_fallthrough__ = arg_fallthrough
func.__mb_must_consume_args__ = must_consume_args func.__mb_event_type__ = event_type
func.__mb_event_types__ = {event_type}
if msgtypes: if msgtypes:
func.__mb_msgtypes__ = msgtypes func.__mb_msgtypes__ = msgtypes
return func return func
@ -331,9 +253,8 @@ class ArgumentSyntaxError(ValueError):
class Argument(ABC): class Argument(ABC):
def __init__( def __init__(self, name: str, label: str = None, *, required: bool = False,
self, name: str, label: str = None, *, required: bool = False, pass_raw: bool = False pass_raw: bool = False) -> None:
) -> None:
self.name = name self.name = name
self.label = label or name self.label = label or name
self.required = required self.required = required
@ -351,15 +272,8 @@ class Argument(ABC):
class RegexArgument(Argument): class RegexArgument(Argument):
def __init__( def __init__(self, name: str, label: str = None, *, required: bool = False,
self, pass_raw: bool = False, matches: str = None) -> None:
name: str,
label: str = None,
*,
required: bool = False,
pass_raw: bool = False,
matches: str = None,
) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw) super().__init__(name, label, required=required, pass_raw=pass_raw)
matches = f"^{matches}" if self.pass_raw else f"^{matches}$" matches = f"^{matches}" if self.pass_raw else f"^{matches}$"
self.regex = re.compile(matches) self.regex = re.compile(matches)
@ -370,23 +284,14 @@ class RegexArgument(Argument):
val = re.split(r"\s", val, 1)[0] val = re.split(r"\s", val, 1)[0]
match = self.regex.match(val) match = self.regex.match(val)
if match: if match:
return ( return (orig_val[:match.start()] + orig_val[match.end():],
orig_val[: match.start()] + orig_val[match.end() :], match.groups() or val[match.start():match.end()])
match.groups() or val[match.start() : match.end()],
)
return orig_val, None return orig_val, None
class CustomArgument(Argument): class CustomArgument(Argument):
def __init__( def __init__(self, name: str, label: str = None, *, required: bool = False,
self, pass_raw: bool = False, matcher: Callable[[str], Any]) -> None:
name: str,
label: str = None,
*,
required: bool = False,
pass_raw: bool = False,
matcher: Callable[[str], Any],
) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw) super().__init__(name, label, required=required, pass_raw=pass_raw)
self.matcher = matcher self.matcher = matcher
@ -396,8 +301,8 @@ class CustomArgument(Argument):
orig_val = val orig_val = val
val = re.split(r"\s", val, 1)[0] val = re.split(r"\s", val, 1)[0]
res = self.matcher(val) res = self.matcher(val)
if res is not None: if res:
return orig_val[len(val) :], res return orig_val[len(val):], res
return orig_val, None return orig_val, None
@ -406,18 +311,12 @@ class SimpleArgument(Argument):
if self.pass_raw: if self.pass_raw:
return "", val return "", val
res = re.split(r"\s", val, 1)[0] res = re.split(r"\s", val, 1)[0]
return val[len(res) :], res return val[len(res):], res
def argument( def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None,
name: str, parser: Optional[Callable[[str], Any]] = None, pass_raw: bool = False
label: str = None, ) -> CommandHandlerDecorator:
*,
required: bool = True,
matches: Optional[str] = None,
parser: Optional[Callable[[str], Any]] = None,
pass_raw: bool = False,
) -> CommandHandlerDecorator:
if matches: if matches:
return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw) return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw)
elif parser: elif parser:
@ -426,17 +325,11 @@ def argument(
return SimpleArgument(name, label, required=required, pass_raw=pass_raw) return SimpleArgument(name, label, required=required, pass_raw=pass_raw)
def passive( def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
regex: Union[str, Pattern], field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body,
*, event_type: EventType = EventType.ROOM_MESSAGE, multiple: bool = False,
msgtypes: Sequence[MessageType] = (MessageType.TEXT,), case_insensitive: bool = False, multiline: bool = False, dot_all: bool = False
field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body, ) -> PassiveCommandHandlerDecorator:
event_type: EventType = EventType.ROOM_MESSAGE,
multiple: bool = False,
case_insensitive: bool = False,
multiline: bool = False,
dot_all: bool = False,
) -> PassiveCommandHandlerDecorator:
if not isinstance(regex, Pattern): if not isinstance(regex, Pattern):
flags = re.RegexFlag.UNICODE flags = re.RegexFlag.UNICODE
if case_insensitive: if case_insensitive:
@ -465,14 +358,12 @@ def passive(
return return
data = field(evt) data = field(evt)
if multiple: if multiple:
val = [ val = [(data[match.pos:match.endpos], *match.groups())
(data[match.pos : match.endpos], *match.groups()) for match in regex.finditer(data)]
for match in regex.finditer(data)
]
else: else:
match = regex.search(data) match = regex.search(data)
if match: if match:
val = (data[match.pos : match.endpos], *match.groups()) val = (data[match.pos:match.endpos], *match.groups())
else: else:
val = None val = None
if val: if val:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,32 +13,23 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Callable, Union, NewType
from typing import Callable, NewType
from mautrix.client import EventHandler, InternalEventType
from mautrix.types import EventType from mautrix.types import EventType
from mautrix.client import EventHandler, InternalEventType
EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler]) EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler])
def on(var: EventType | InternalEventType | EventHandler) -> EventHandlerDecorator | EventHandler: def on(var: Union[EventType, InternalEventType, EventHandler]
) -> Union[EventHandlerDecorator, EventHandler]:
def decorator(func: EventHandler) -> EventHandler: def decorator(func: EventHandler) -> EventHandler:
func.__mb_event_handler__ = True func.__mb_event_handler__ = True
if isinstance(var, (EventType, InternalEventType)): if isinstance(var, (EventType, InternalEventType)):
if hasattr(func, "__mb_event_types__"): func.__mb_event_type__ = var
func.__mb_event_types__.add(var)
else:
func.__mb_event_types__ = {var}
else: else:
func.__mb_event_types__ = {EventType.ALL} func.__mb_event_type__ = EventType.ALL
return func return func
return decorator if isinstance(var, (EventType, InternalEventType)) else decorator(var) return decorator if isinstance(var, (EventType, InternalEventType)) else decorator(var)
def off(func: EventHandler) -> EventHandler:
func.__mb_event_handler__ = False
return func

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,9 +13,9 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, Awaitable, Callable from typing import Callable, Any, Awaitable
from aiohttp import hdrs, web from aiohttp import web, hdrs
WebHandler = Callable[[web.Request], Awaitable[web.StreamResponse]] WebHandler = Callable[[web.Request], Awaitable[web.StreamResponse]]
WebHandlerDecorator = Callable[[WebHandler], WebHandler] WebHandlerDecorator = Callable[[WebHandler], WebHandler]

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,92 +13,58 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Dict, List, Optional, Iterable, TYPE_CHECKING
from asyncio import AbstractEventLoop
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
from collections import defaultdict
import asyncio
import inspect
import io
import logging
import os.path import os.path
import logging
import io
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML
import sqlalchemy as sql
from mautrix.types import UserID
from mautrix.util import background_task
from mautrix.util.async_db import Database, Scheme, UpgradeTable
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.config import BaseProxyConfig, RecursiveDict from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.util.logging import TraceLogger from mautrix.types import UserID
from .db import DBPlugin
from .config import Config
from .client import Client from .client import Client
from .db import DatabaseEngine, Instance as DBInstance from .loader import PluginLoader, ZippedPluginLoader
from .lib.optionalalchemy import Engine, MetaData, create_engine
from .lib.plugin_db import ProxyPostgresDatabase
from .loader import DatabaseType, PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin from .plugin_base import Plugin
if TYPE_CHECKING: if TYPE_CHECKING:
from .__main__ import Maubot from .server import MaubotServer, PluginWebApp
from .server import PluginWebApp
log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance")) log = logging.getLogger("maubot.instance")
db_log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance_db"))
yaml = YAML() yaml = YAML()
yaml.indent(4) yaml.indent(4)
yaml.width = 200 yaml.width = 200
class PluginInstance(DBInstance): class PluginInstance:
maubot: "Maubot" = None webserver: 'MaubotServer' = None
cache: dict[str, PluginInstance] = {} mb_config: Config = None
plugin_directories: list[str] = [] loop: AbstractEventLoop = None
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) cache: Dict[str, 'PluginInstance'] = {}
plugin_directories: List[str] = []
log: logging.Logger log: logging.Logger
loader: PluginLoader | None loader: PluginLoader
client: Client | None client: Client
plugin: Plugin | None plugin: Plugin
config: BaseProxyConfig | None config: BaseProxyConfig
base_cfg: RecursiveDict[CommentedMap] | None base_cfg: Optional[RecursiveDict[CommentedMap]]
base_cfg_str: str | None base_cfg_str: Optional[str]
inst_db: sql.engine.Engine | Database | None inst_db: sql.engine.Engine
inst_db_tables: dict | None inst_db_tables: Dict[str, sql.Table]
inst_webapp: PluginWebApp | None inst_webapp: Optional['PluginWebApp']
inst_webapp_url: str | None inst_webapp_url: Optional[str]
started: bool started: bool
def __init__( def __init__(self, db_instance: DBPlugin):
self, self.db_instance = db_instance
id: str,
type: str,
enabled: bool,
primary_user: UserID,
config: str = "",
database_engine: DatabaseEngine | None = None,
) -> None:
super().__init__(
id=id,
type=type,
enabled=bool(enabled),
primary_user=primary_user,
config_str=config,
database_engine=database_engine,
)
def __hash__(self) -> int:
return hash(self.id)
@classmethod
def init_cls(cls, maubot: "Maubot") -> None:
cls.maubot = maubot
def postinit(self) -> None:
self.log = log.getChild(self.id) self.log = log.getChild(self.id)
self.cache[self.id] = self
self.config = None self.config = None
self.started = False self.started = False
self.loader = None self.loader = None
@ -110,6 +76,7 @@ class PluginInstance(DBInstance):
self.inst_webapp_url = None self.inst_webapp_url = None
self.base_cfg = None self.base_cfg = None
self.base_cfg_str = None self.base_cfg_str = None
self.cache[self.id] = self
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
@ -118,144 +85,44 @@ class PluginInstance(DBInstance):
"enabled": self.enabled, "enabled": self.enabled,
"started": self.started, "started": self.started,
"primary_user": self.primary_user, "primary_user": self.primary_user,
"config": self.config_str, "config": self.db_instance.config,
"base_config": self.base_cfg_str, "base_config": self.base_cfg_str,
"database": ( "database": (self.inst_db is not None
self.inst_db is not None and self.maubot.config["api_features.instance_database"] and self.mb_config["api_features.instance_database"]),
),
"database_interface": self.loader.meta.database_type_str if self.loader else "unknown",
"database_engine": self.database_engine_str,
} }
def _introspect_sqlalchemy(self) -> dict: def get_db_tables(self) -> Dict[str, sql.Table]:
metadata = MetaData() if not self.inst_db_tables:
metadata.reflect(self.inst_db) metadata = sql.MetaData()
return { metadata.reflect(self.inst_db)
table.name: { self.inst_db_tables = metadata.tables
"columns": {
column.name: {
"type": str(column.type),
"unique": column.unique or False,
"default": column.default,
"nullable": column.nullable,
"primary": column.primary_key,
}
for column in table.columns
},
}
for table in metadata.tables.values()
}
async def _introspect_sqlite(self) -> dict:
q = """
SELECT
m.name AS table_name,
p.cid AS col_id,
p.name AS column_name,
p.type AS data_type,
p.pk AS is_primary,
p.dflt_value AS column_default,
p.[notnull] AS is_nullable
FROM sqlite_master m
LEFT JOIN pragma_table_info((m.name)) p
WHERE m.type = 'table'
ORDER BY table_name, col_id
"""
data = await self.inst_db.fetch(q)
tables = defaultdict(lambda: {"columns": {}})
for column in data:
table_name = column["table_name"]
col_name = column["column_name"]
tables[table_name]["columns"][col_name] = {
"type": column["data_type"],
"nullable": bool(column["is_nullable"]),
"default": column["column_default"],
"primary": bool(column["is_primary"]),
# TODO uniqueness?
}
return tables
async def _introspect_postgres(self) -> dict:
assert isinstance(self.inst_db, ProxyPostgresDatabase)
q = """
SELECT col.table_name, col.column_name, col.data_type, col.is_nullable, col.column_default,
tc.constraint_type
FROM information_schema.columns col
LEFT JOIN information_schema.constraint_column_usage ccu
ON ccu.column_name=col.column_name
LEFT JOIN information_schema.table_constraints tc
ON col.table_name=tc.table_name
AND col.table_schema=tc.table_schema
AND ccu.constraint_name=tc.constraint_name
AND ccu.constraint_schema=tc.constraint_schema
AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
WHERE col.table_schema=$1
"""
data = await self.inst_db.fetch(q, self.inst_db.schema_name)
tables = defaultdict(lambda: {"columns": {}})
for column in data:
table_name = column["table_name"]
col_name = column["column_name"]
tables[table_name]["columns"].setdefault(
col_name,
{
"type": column["data_type"],
"nullable": column["is_nullable"],
"default": column["column_default"],
"primary": False,
"unique": False,
},
)
if column["constraint_type"] == "PRIMARY KEY":
tables[table_name]["columns"][col_name]["primary"] = True
elif column["constraint_type"] == "UNIQUE":
tables[table_name]["columns"][col_name]["unique"] = True
return tables
async def get_db_tables(self) -> dict:
if self.inst_db_tables is None:
if isinstance(self.inst_db, Engine):
self.inst_db_tables = self._introspect_sqlalchemy()
elif self.inst_db.scheme == Scheme.SQLITE:
self.inst_db_tables = await self._introspect_sqlite()
else:
self.inst_db_tables = await self._introspect_postgres()
return self.inst_db_tables return self.inst_db_tables
async def load(self) -> bool: def load(self) -> bool:
if not self.loader: if not self.loader:
try: try:
self.loader = PluginLoader.find(self.type) self.loader = PluginLoader.find(self.type)
except KeyError: except KeyError:
self.log.error(f"Failed to find loader for type {self.type}") self.log.error(f"Failed to find loader for type {self.type}")
await self.update_enabled(False) self.db_instance.enabled = False
return False return False
if not self.client: if not self.client:
self.client = await Client.get(self.primary_user) self.client = Client.get(self.primary_user)
if not self.client: if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}") self.log.error(f"Failed to get client for user {self.primary_user}")
await self.update_enabled(False) self.db_instance.enabled = False
return False return False
if self.loader.meta.database:
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
if self.loader.meta.webapp: if self.loader.meta.webapp:
self.enable_webapp() self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id)
self.log.debug("Plugin instance dependencies loaded") self.log.debug("Plugin instance dependencies loaded")
self.loader.references.add(self) self.loader.references.add(self)
self.client.references.add(self) self.client.references.add(self)
return True return True
def enable_webapp(self) -> None: def delete(self) -> None:
self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id)
def disable_webapp(self) -> None:
self.maubot.server.remove_instance_webapp(self.id)
self.inst_webapp = None
self.inst_webapp_url = None
@property
def _sqlite_db_path(self) -> str:
return os.path.join(self.maubot.config["plugin_databases.sqlite"], f"{self.id}.db")
async def delete(self) -> None:
if self.loader is not None: if self.loader is not None:
self.loader.references.remove(self) self.loader.references.remove(self)
if self.client is not None: if self.client is not None:
@ -264,89 +131,22 @@ class PluginInstance(DBInstance):
del self.cache[self.id] del self.cache[self.id]
except KeyError: except KeyError:
pass pass
await super().delete() self.db_instance.delete()
if self.inst_db: if self.inst_db:
await self.stop_database() self.inst_db.dispose()
await self.delete_database() ZippedPluginLoader.trash(
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
reason="deleted")
if self.inst_webapp: if self.inst_webapp:
self.disable_webapp() self.webserver.remove_instance_webapp(self.id)
def load_config(self) -> CommentedMap: def load_config(self) -> CommentedMap:
return yaml.load(self.config_str) return yaml.load(self.db_instance.config)
def save_config(self, data: RecursiveDict[CommentedMap]) -> None: def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
buf = io.StringIO() buf = io.StringIO()
yaml.dump(data, buf) yaml.dump(data, buf)
val = buf.getvalue() self.db_instance.config = buf.getvalue()
if val != self.config_str:
self.config_str = val
self.log.debug("Creating background task to save updated config")
background_task.create(self.update())
async def start_database(
self, upgrade_table: UpgradeTable | None = None, actually_start: bool = True
) -> None:
if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
if self.database_engine is None:
await self.update_db_engine(DatabaseEngine.SQLITE)
elif self.database_engine == DatabaseEngine.POSTGRES:
raise RuntimeError(
"Instance database engine is marked as Postgres, but plugin uses legacy "
"database interface, which doesn't support postgres."
)
self.inst_db = create_engine(f"sqlite:///{self._sqlite_db_path}")
elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
if self.database_engine is None:
if os.path.exists(self._sqlite_db_path) or not self.maubot.plugin_postgres_db:
await self.update_db_engine(DatabaseEngine.SQLITE)
else:
await self.update_db_engine(DatabaseEngine.POSTGRES)
instance_db_log = db_log.getChild(self.id)
if self.database_engine == DatabaseEngine.POSTGRES:
if not self.maubot.plugin_postgres_db:
raise RuntimeError(
"Instance database engine is marked as Postgres, but this maubot isn't "
"configured to support Postgres for plugin databases"
)
self.inst_db = ProxyPostgresDatabase(
pool=self.maubot.plugin_postgres_db,
instance_id=self.id,
max_conns=self.maubot.config["plugin_databases.postgres_max_conns_per_plugin"],
upgrade_table=upgrade_table,
log=instance_db_log,
)
else:
self.inst_db = Database.create(
f"sqlite:{self._sqlite_db_path}",
upgrade_table=upgrade_table,
log=instance_db_log,
)
if actually_start:
await self.inst_db.start()
else:
raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
async def stop_database(self) -> None:
if isinstance(self.inst_db, Database):
await self.inst_db.stop()
elif isinstance(self.inst_db, Engine):
self.inst_db.dispose()
else:
raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}")
async def delete_database(self) -> None:
if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
if self.inst_db is None:
await self.start_database(None, actually_start=False)
if isinstance(self.inst_db, ProxyPostgresDatabase):
await self.inst_db.delete()
else:
ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
else:
raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
self.inst_db = None
async def start(self) -> None: async def start(self) -> None:
if self.started: if self.started:
@ -357,22 +157,9 @@ class PluginInstance(DBInstance):
return return
if not self.client or not self.loader: if not self.client or not self.loader:
self.log.warning("Missing plugin instance dependencies, attempting to load...") self.log.warning("Missing plugin instance dependencies, attempting to load...")
if not await self.load(): if not self.load():
return return
cls = await self.loader.load() cls = await self.loader.load()
if self.loader.meta.webapp and self.inst_webapp is None:
self.log.debug("Enabling webapp after plugin meta reload")
self.enable_webapp()
elif not self.loader.meta.webapp and self.inst_webapp is not None:
self.log.debug("Disabling webapp after plugin meta reload")
self.disable_webapp()
if self.loader.meta.database:
try:
await self.start_database(cls.get_db_upgrade_table())
except Exception:
self.log.exception("Failed to start instance database")
await self.update_enabled(False)
return
config_class = cls.get_config_class() config_class = cls.get_config_class()
if config_class: if config_class:
try: try:
@ -387,35 +174,23 @@ class PluginInstance(DBInstance):
if self.base_cfg: if self.base_cfg:
base_cfg_func = self.base_cfg.clone base_cfg_func = self.base_cfg.clone
else: else:
def base_cfg_func() -> None: def base_cfg_func() -> None:
return None return None
self.config = config_class(self.load_config, base_cfg_func, self.save_config) self.config = config_class(self.load_config, base_cfg_func, self.save_config)
self.plugin = cls( self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client,
client=self.client.client, instance_id=self.id, log=self.log, config=self.config,
loop=self.maubot.loop, database=self.inst_db, webapp=self.inst_webapp,
http=self.client.http_client, webapp_url=self.inst_webapp_url)
instance_id=self.id,
log=self.log,
config=self.config,
database=self.inst_db,
loader=self.loader,
webapp=self.inst_webapp,
webapp_url=self.inst_webapp_url,
)
try: try:
await self.plugin.internal_start() await self.plugin.internal_start()
except Exception: except Exception:
self.log.exception("Failed to start instance") self.log.exception("Failed to start instance")
await self.update_enabled(False) self.db_instance.enabled = False
return return
self.started = True self.started = True
self.inst_db_tables = None self.inst_db_tables = None
self.log.info( self.log.info(f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} "
f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} " f"with user {self.client.id}")
f"with user {self.client.id}"
)
async def stop(self) -> None: async def stop(self) -> None:
if not self.started: if not self.started:
@ -428,58 +203,63 @@ class PluginInstance(DBInstance):
except Exception: except Exception:
self.log.exception("Failed to stop instance") self.log.exception("Failed to stop instance")
self.plugin = None self.plugin = None
if self.inst_db:
try:
await self.stop_database()
except Exception:
self.log.exception("Failed to stop instance database")
self.inst_db_tables = None self.inst_db_tables = None
async def update_id(self, new_id: str | None) -> None: @classmethod
if new_id is not None and new_id.lower() != self.id: def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None
await super().update_id(new_id.lower()) ) -> Optional['PluginInstance']:
try:
return cls.cache[instance_id]
except KeyError:
db_instance = db_instance or DBPlugin.get(instance_id)
if not db_instance:
return None
return PluginInstance(db_instance)
async def update_config(self, config: str | None) -> None: @classmethod
if config is None or self.config_str == config: def all(cls) -> Iterable['PluginInstance']:
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
def update_id(self, new_id: str) -> None:
if new_id is not None and new_id != self.id:
self.db_instance.id = new_id
def update_config(self, config: str) -> None:
if not config or self.db_instance.config == config:
return return
self.config_str = config self.db_instance.config = config
if self.started and self.plugin is not None: if self.started and self.plugin is not None:
res = self.plugin.on_external_config_update() self.plugin.on_external_config_update()
if inspect.isawaitable(res):
await res
await self.update()
async def update_primary_user(self, primary_user: UserID | None) -> bool: async def update_primary_user(self, primary_user: UserID) -> bool:
if primary_user is None or primary_user == self.primary_user: if not primary_user or primary_user == self.primary_user:
return True return True
client = await Client.get(primary_user) client = Client.get(primary_user)
if not client: if not client:
return False return False
await self.stop() await self.stop()
self.primary_user = client.id self.db_instance.primary_user = client.id
if self.client: if self.client:
self.client.references.remove(self) self.client.references.remove(self)
self.client = client self.client = client
self.client.references.add(self) self.client.references.add(self)
await self.update()
await self.start() await self.start()
self.log.debug(f"Primary user switched to {self.client.id}") self.log.debug(f"Primary user switched to {self.client.id}")
return True return True
async def update_type(self, type: str | None) -> bool: async def update_type(self, type: str) -> bool:
if type is None or type == self.type: if not type or type == self.type:
return True return True
try: try:
loader = PluginLoader.find(type) loader = PluginLoader.find(type)
except KeyError: except KeyError:
return False return False
await self.stop() await self.stop()
self.type = loader.meta.id self.db_instance.type = loader.meta.id
if self.loader: if self.loader:
self.loader.references.remove(self) self.loader.references.remove(self)
self.loader = loader self.loader = loader
self.loader.references.add(self) self.loader.references.add(self)
await self.update()
await self.start() await self.start()
self.log.debug(f"Type switched to {self.loader.meta.id}") self.log.debug(f"Type switched to {self.loader.meta.id}")
return True return True
@ -488,46 +268,38 @@ class PluginInstance(DBInstance):
if started is not None and started != self.started: if started is not None and started != self.started:
await (self.start() if started else self.stop()) await (self.start() if started else self.stop())
async def update_enabled(self, enabled: bool) -> None: def update_enabled(self, enabled: bool) -> None:
if enabled is not None and enabled != self.enabled: if enabled is not None and enabled != self.enabled:
self.enabled = enabled self.db_instance.enabled = enabled
await self.update()
async def update_db_engine(self, db_engine: DatabaseEngine | None) -> None: # region Properties
if db_engine is not None and db_engine != self.database_engine:
self.database_engine = db_engine
await self.update()
@classmethod @property
@async_getter_lock def id(self) -> str:
async def get( return self.db_instance.id
cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None
) -> PluginInstance | None:
try:
return cls.cache[instance_id]
except KeyError:
pass
instance = cast(cls, await super().get(instance_id)) @id.setter
if instance is not None: def id(self, value: str) -> None:
instance.postinit() self.db_instance.id = value
return instance
if type and primary_user: @property
instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user) def type(self) -> str:
await instance.insert() return self.db_instance.type
instance.postinit()
return instance
return None @property
def enabled(self) -> bool:
return self.db_instance.enabled
@classmethod @property
async def all(cls) -> AsyncGenerator[PluginInstance, None]: def primary_user(self) -> UserID:
instances = await super().all() return self.db_instance.primary_user
instance: PluginInstance
for instance in instances: # endregion
try:
yield cls.cache[instance.id]
except KeyError: def init(config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop
instance.postinit() ) -> Iterable[PluginInstance]:
yield instance PluginInstance.mb_config = config
PluginInstance.loop = loop
PluginInstance.webserver = webserver
return PluginInstance.all()

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,13 +13,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.logging.color import ( from mautrix.util.color_log import (ColorFormatter as BaseColorFormatter, PREFIX, MAU_COLOR,
MAU_COLOR, MXID_COLOR, RESET)
MXID_COLOR,
PREFIX,
RESET,
ColorFormatter as BaseColorFormatter,
)
INST_COLOR = PREFIX + "35m" # magenta INST_COLOR = PREFIX + "35m" # magenta
LOADER_COLOR = PREFIX + "36m" # blue LOADER_COLOR = PREFIX + "36m" # blue
@ -28,22 +23,14 @@ LOADER_COLOR = PREFIX + "36m" # blue
class ColorFormatter(BaseColorFormatter): class ColorFormatter(BaseColorFormatter):
def _color_name(self, module: str) -> str: def _color_name(self, module: str) -> str:
client = "maubot.client" client = "maubot.client"
if module.startswith(client + "."): if module.startswith(client):
suffix = "" return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}"
if module.endswith(".crypto"):
suffix = f".{MAU_COLOR}crypto{RESET}"
module = module[: -len(".crypto")]
module = module[len(client) + 1 :]
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module}{RESET}{suffix}"
instance = "maubot.instance" instance = "maubot.instance"
if module.startswith(instance + "."): if module.startswith(instance):
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}" return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
instance_db = "maubot.instance_db"
if module.startswith(instance_db + "."):
return f"{MAU_COLOR}{instance_db}{RESET}.{INST_COLOR}{module[len(instance_db) + 1:]}{RESET}"
loader = "maubot.loader" loader = "maubot.loader"
if module.startswith(loader + "."): if module.startswith(loader):
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}" return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
if module.startswith("maubot."): if module.startswith("maubot"):
return f"{MAU_COLOR}{module}{RESET}" return f"{MAU_COLOR}{module}{RESET}"
return super()._color_name(module) return super()._color_name(module)

View File

@ -1,9 +0,0 @@
from typing import Any, Awaitable, Callable, Generator
class FutureAwaitable:
def __init__(self, func: Callable[[], Awaitable[None]]) -> None:
self._func = func
def __await__(self) -> Generator[Any, None, None]:
return self._func().__await__()

View File

@ -1,19 +0,0 @@
try:
from sqlalchemy import MetaData, asc, create_engine, desc
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError, OperationalError
except ImportError:
class FakeError(Exception):
pass
class FakeType:
def __init__(self, *args, **kwargs):
raise Exception("SQLAlchemy is not installed")
def create_engine(*args, **kwargs):
raise Exception("SQLAlchemy is not installed")
MetaData = Engine = FakeType
IntegrityError = OperationalError = FakeError
asc = desc = lambda a: a

View File

@ -1,100 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from contextlib import asynccontextmanager
import asyncio
from mautrix.util.async_db import Database, PostgresDatabase, Scheme, UpgradeTable
from mautrix.util.async_db.connection import LoggingConnection
from mautrix.util.logging import TraceLogger
remove_double_quotes = str.maketrans({'"': "_"})
class ProxyPostgresDatabase(Database):
scheme = Scheme.POSTGRES
_underlying_pool: PostgresDatabase
schema_name: str
_quoted_schema: str
_default_search_path: str
_conn_sema: asyncio.Semaphore
_max_conns: int
def __init__(
self,
pool: PostgresDatabase,
instance_id: str,
max_conns: int,
upgrade_table: UpgradeTable | None,
log: TraceLogger | None = None,
) -> None:
super().__init__(pool.url, upgrade_table=upgrade_table, log=log)
self._underlying_pool = pool
# Simple accidental SQL injection prevention.
# Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway.
self.schema_name = f"mbp_{instance_id.translate(remove_double_quotes)}"
self._quoted_schema = f'"{self.schema_name}"'
self._default_search_path = '"$user", public'
self._conn_sema = asyncio.BoundedSemaphore(max_conns)
self._max_conns = max_conns
async def start(self) -> None:
async with self._underlying_pool.acquire() as conn:
self._default_search_path = await conn.fetchval("SHOW search_path")
self.log.trace(f"Found default search path: {self._default_search_path}")
await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._quoted_schema}")
await super().start()
async def stop(self) -> None:
for _ in range(self._max_conns):
try:
await asyncio.wait_for(self._conn_sema.acquire(), timeout=3)
except asyncio.TimeoutError:
self.log.warning(
"Failed to drain plugin database connection pool, "
"the plugin may be leaking database connections"
)
break
async def delete(self) -> None:
self.log.info(f"Deleting schema {self.schema_name} and all data in it")
try:
await self._underlying_pool.execute(
f"DROP SCHEMA IF EXISTS {self._quoted_schema} CASCADE"
)
except Exception:
self.log.warning("Failed to delete schema", exc_info=True)
@asynccontextmanager
async def acquire(self) -> LoggingConnection:
conn: LoggingConnection
async with self._conn_sema, self._underlying_pool.acquire() as conn:
await conn.execute(f"SET search_path = {self._quoted_schema}")
try:
yield conn
finally:
if not conn.wrapped.is_closed():
try:
await conn.execute(f"SET search_path = {self._default_search_path}")
except Exception:
self.log.exception("Error resetting search_path after use")
await conn.wrapped.close()
else:
self.log.debug("Connection was closed after use, not resetting search_path")
__all__ = ["ProxyPostgresDatabase"]

View File

@ -1,27 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.client.state_store.asyncpg import PgStateStore as BasePgStateStore
try:
from mautrix.crypto import StateStore as CryptoStateStore
class PgStateStore(BasePgStateStore, CryptoStateStore):
pass
except ImportError as e:
PgStateStore = BasePgStateStore
__all__ = ["PgStateStore"]

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,13 +13,18 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from mautrix.client import ClientStore
from mautrix.types import SyncToken
from mautrix.util.async_db import Connection
from . import upgrade_table
@upgrade_table.register(description="Store instance database engine") class ClientStoreProxy(ClientStore):
async def upgrade_v2(conn: Connection) -> None: def __init__(self, db_instance) -> None:
await conn.execute("ALTER TABLE instance ADD COLUMN database_engine TEXT") self.db_instance = db_instance
@property
def next_batch(self) -> SyncToken:
return self.db_instance.next_batch
@next_batch.setter
def next_batch(self, value: SyncToken) -> None:
self.db_instance.edit(next_batch=value)

View File

@ -18,28 +18,26 @@ used by the builtin import mechanism for sys.path items that are paths
to Zip archives. to Zip archives.
""" """
from importlib import _bootstrap # for _verbose_message
from importlib import _bootstrap_external from importlib import _bootstrap_external
from importlib import _bootstrap # for _verbose_message
import _imp # for check_hash_based_pycs
import _io # for open
import marshal # for loads import marshal # for loads
import sys # for modules import sys # for modules
import time # for mktime import time # for mktime
import _imp # for check_hash_based_pycs __all__ = ['ZipImportError', 'zipimporter']
import _io # for open
__all__ = ["ZipImportError", "zipimporter"]
def _unpack_uint32(data): def _unpack_uint32(data):
"""Convert 4 bytes in little-endian to an integer.""" """Convert 4 bytes in little-endian to an integer."""
assert len(data) == 4 assert len(data) == 4
return int.from_bytes(data, "little") return int.from_bytes(data, 'little')
def _unpack_uint16(data): def _unpack_uint16(data):
"""Convert 2 bytes in little-endian to an integer.""" """Convert 2 bytes in little-endian to an integer."""
assert len(data) == 2 assert len(data) == 2
return int.from_bytes(data, "little") return int.from_bytes(data, 'little')
path_sep = _bootstrap_external.path_sep path_sep = _bootstrap_external.path_sep
@ -49,17 +47,15 @@ alt_path_sep = _bootstrap_external.path_separators[1:]
class ZipImportError(ImportError): class ZipImportError(ImportError):
pass pass
# _read_directory() cache # _read_directory() cache
_zip_directory_cache = {} _zip_directory_cache = {}
_module_type = type(sys) _module_type = type(sys)
END_CENTRAL_DIR_SIZE = 22 END_CENTRAL_DIR_SIZE = 22
STRING_END_ARCHIVE = b"PK\x05\x06" STRING_END_ARCHIVE = b'PK\x05\x06'
MAX_COMMENT_LEN = (1 << 16) - 1 MAX_COMMENT_LEN = (1 << 16) - 1
class zipimporter: class zipimporter:
"""zipimporter(archivepath) -> zipimporter object """zipimporter(archivepath) -> zipimporter object
@ -81,10 +77,9 @@ class zipimporter:
def __init__(self, path): def __init__(self, path):
if not isinstance(path, str): if not isinstance(path, str):
import os import os
path = os.fsdecode(path) path = os.fsdecode(path)
if not path: if not path:
raise ZipImportError("archive path is empty", path=path) raise ZipImportError('archive path is empty', path=path)
if alt_path_sep: if alt_path_sep:
path = path.replace(alt_path_sep, path_sep) path = path.replace(alt_path_sep, path_sep)
@ -97,14 +92,14 @@ class zipimporter:
# Back up one path element. # Back up one path element.
dirname, basename = _bootstrap_external._path_split(path) dirname, basename = _bootstrap_external._path_split(path)
if dirname == path: if dirname == path:
raise ZipImportError("not a Zip file", path=path) raise ZipImportError('not a Zip file', path=path)
path = dirname path = dirname
prefix.append(basename) prefix.append(basename)
else: else:
# it exists # it exists
if (st.st_mode & 0o170000) != 0o100000: # stat.S_ISREG if (st.st_mode & 0o170000) != 0o100000: # stat.S_ISREG
# it's a not file # it's a not file
raise ZipImportError("not a Zip file", path=path) raise ZipImportError('not a Zip file', path=path)
break break
try: try:
@ -159,10 +154,11 @@ class zipimporter:
# This is possibly a portion of a namespace # This is possibly a portion of a namespace
# package. Return the string representing its path, # package. Return the string representing its path,
# without a trailing separator. # without a trailing separator.
return None, [f"{self.archive}{path_sep}{modpath}"] return None, [f'{self.archive}{path_sep}{modpath}']
return None, [] return None, []
# Check whether we can satisfy the import of the module named by # Check whether we can satisfy the import of the module named by
# 'fullname'. Return self if we can, None if we can't. # 'fullname'. Return self if we can, None if we can't.
def find_module(self, fullname, path=None): def find_module(self, fullname, path=None):
@ -176,6 +172,7 @@ class zipimporter:
""" """
return self.find_loader(fullname, path)[0] return self.find_loader(fullname, path)[0]
def get_code(self, fullname): def get_code(self, fullname):
"""get_code(fullname) -> code object. """get_code(fullname) -> code object.
@ -185,6 +182,7 @@ class zipimporter:
code, ispackage, modpath = _get_module_code(self, fullname) code, ispackage, modpath = _get_module_code(self, fullname)
return code return code
def get_data(self, pathname): def get_data(self, pathname):
"""get_data(pathname) -> string with file data. """get_data(pathname) -> string with file data.
@ -196,14 +194,15 @@ class zipimporter:
key = pathname key = pathname
if pathname.startswith(self.archive + path_sep): if pathname.startswith(self.archive + path_sep):
key = pathname[len(self.archive + path_sep) :] key = pathname[len(self.archive + path_sep):]
try: try:
toc_entry = self._files[key] toc_entry = self._files[key]
except KeyError: except KeyError:
raise OSError(0, "", key) raise OSError(0, '', key)
return _get_data(self.archive, toc_entry) return _get_data(self.archive, toc_entry)
# Return a string matching __file__ for the named module # Return a string matching __file__ for the named module
def get_filename(self, fullname): def get_filename(self, fullname):
"""get_filename(fullname) -> filename string. """get_filename(fullname) -> filename string.
@ -215,6 +214,7 @@ class zipimporter:
code, ispackage, modpath = _get_module_code(self, fullname) code, ispackage, modpath = _get_module_code(self, fullname)
return modpath return modpath
def get_source(self, fullname): def get_source(self, fullname):
"""get_source(fullname) -> source string. """get_source(fullname) -> source string.
@ -228,9 +228,9 @@ class zipimporter:
path = _get_module_path(self, fullname) path = _get_module_path(self, fullname)
if mi: if mi:
fullpath = _bootstrap_external._path_join(path, "__init__.py") fullpath = _bootstrap_external._path_join(path, '__init__.py')
else: else:
fullpath = f"{path}.py" fullpath = f'{path}.py'
try: try:
toc_entry = self._files[fullpath] toc_entry = self._files[fullpath]
@ -239,6 +239,7 @@ class zipimporter:
return None return None
return _get_data(self.archive, toc_entry).decode() return _get_data(self.archive, toc_entry).decode()
# Return a bool signifying whether the module is a package or not. # Return a bool signifying whether the module is a package or not.
def is_package(self, fullname): def is_package(self, fullname):
"""is_package(fullname) -> bool. """is_package(fullname) -> bool.
@ -251,6 +252,7 @@ class zipimporter:
raise ZipImportError(f"can't find module {fullname!r}", name=fullname) raise ZipImportError(f"can't find module {fullname!r}", name=fullname)
return mi return mi
# Load and return the module named by 'fullname'. # Load and return the module named by 'fullname'.
def load_module(self, fullname): def load_module(self, fullname):
"""load_module(fullname) -> module. """load_module(fullname) -> module.
@ -274,7 +276,7 @@ class zipimporter:
fullpath = _bootstrap_external._path_join(self.archive, path) fullpath = _bootstrap_external._path_join(self.archive, path)
mod.__path__ = [fullpath] mod.__path__ = [fullpath]
if not hasattr(mod, "__builtins__"): if not hasattr(mod, '__builtins__'):
mod.__builtins__ = __builtins__ mod.__builtins__ = __builtins__
_bootstrap_external._fix_up_module(mod.__dict__, fullname, modpath) _bootstrap_external._fix_up_module(mod.__dict__, fullname, modpath)
exec(code, mod.__dict__) exec(code, mod.__dict__)
@ -285,10 +287,11 @@ class zipimporter:
try: try:
mod = sys.modules[fullname] mod = sys.modules[fullname]
except KeyError: except KeyError:
raise ImportError(f"Loaded module {fullname!r} not found in sys.modules") raise ImportError(f'Loaded module {fullname!r} not found in sys.modules')
_bootstrap._verbose_message("import {} # loaded from Zip {}", fullname, modpath) _bootstrap._verbose_message('import {} # loaded from Zip {}', fullname, modpath)
return mod return mod
def get_resource_reader(self, fullname): def get_resource_reader(self, fullname):
"""Return the ResourceReader for a package in a zip file. """Return the ResourceReader for a package in a zip file.
@ -302,11 +305,11 @@ class zipimporter:
return None return None
if not _ZipImportResourceReader._registered: if not _ZipImportResourceReader._registered:
from importlib.abc import ResourceReader from importlib.abc import ResourceReader
ResourceReader.register(_ZipImportResourceReader) ResourceReader.register(_ZipImportResourceReader)
_ZipImportResourceReader._registered = True _ZipImportResourceReader._registered = True
return _ZipImportResourceReader(self, fullname) return _ZipImportResourceReader(self, fullname)
def __repr__(self): def __repr__(self):
return f'<zipimporter object "{self.archive}{path_sep}{self.prefix}">' return f'<zipimporter object "{self.archive}{path_sep}{self.prefix}">'
@ -317,18 +320,16 @@ class zipimporter:
# are swapped by initzipimport() if we run in optimized mode. Also, # are swapped by initzipimport() if we run in optimized mode. Also,
# '/' is replaced by path_sep there. # '/' is replaced by path_sep there.
_zip_searchorder = ( _zip_searchorder = (
(path_sep + "__init__.pyc", True, True), (path_sep + '__init__.pyc', True, True),
(path_sep + "__init__.py", False, True), (path_sep + '__init__.py', False, True),
(".pyc", True, False), ('.pyc', True, False),
(".py", False, False), ('.py', False, False),
) )
# Given a module name, return the potential file path in the # Given a module name, return the potential file path in the
# archive (without extension). # archive (without extension).
def _get_module_path(self, fullname): def _get_module_path(self, fullname):
return self.prefix + fullname.rpartition(".")[2] return self.prefix + fullname.rpartition('.')[2]
# Does this path represent a directory? # Does this path represent a directory?
def _is_dir(self, path): def _is_dir(self, path):
@ -339,7 +340,6 @@ def _is_dir(self, path):
# If dirpath is present in self._files, we have a directory. # If dirpath is present in self._files, we have a directory.
return dirpath in self._files return dirpath in self._files
# Return some information about a module. # Return some information about a module.
def _get_module_info(self, fullname): def _get_module_info(self, fullname):
path = _get_module_path(self, fullname) path = _get_module_path(self, fullname)
@ -352,7 +352,6 @@ def _get_module_info(self, fullname):
# implementation # implementation
# _read_directory(archive) -> files dict (new reference) # _read_directory(archive) -> files dict (new reference)
# #
# Given a path to a Zip archive, build a dict, mapping file names # Given a path to a Zip archive, build a dict, mapping file names
@ -375,7 +374,7 @@ def _get_module_info(self, fullname):
# data_size and file_offset are 0. # data_size and file_offset are 0.
def _read_directory(archive): def _read_directory(archive):
try: try:
fp = _io.open(archive, "rb") fp = _io.open(archive, 'rb')
except OSError: except OSError:
raise ZipImportError(f"can't open Zip file: {archive!r}", path=archive) raise ZipImportError(f"can't open Zip file: {archive!r}", path=archive)
@ -395,33 +394,36 @@ def _read_directory(archive):
fp.seek(0, 2) fp.seek(0, 2)
file_size = fp.tell() file_size = fp.tell()
except OSError: except OSError:
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive) raise ZipImportError(f"can't read Zip file: {archive!r}",
max_comment_start = max(file_size - MAX_COMMENT_LEN - END_CENTRAL_DIR_SIZE, 0) path=archive)
max_comment_start = max(file_size - MAX_COMMENT_LEN -
END_CENTRAL_DIR_SIZE, 0)
try: try:
fp.seek(max_comment_start) fp.seek(max_comment_start)
data = fp.read() data = fp.read()
except OSError: except OSError:
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive) raise ZipImportError(f"can't read Zip file: {archive!r}",
path=archive)
pos = data.rfind(STRING_END_ARCHIVE) pos = data.rfind(STRING_END_ARCHIVE)
if pos < 0: if pos < 0:
raise ZipImportError(f"not a Zip file: {archive!r}", path=archive) raise ZipImportError(f'not a Zip file: {archive!r}',
buffer = data[pos : pos + END_CENTRAL_DIR_SIZE] path=archive)
buffer = data[pos:pos+END_CENTRAL_DIR_SIZE]
if len(buffer) != END_CENTRAL_DIR_SIZE: if len(buffer) != END_CENTRAL_DIR_SIZE:
raise ZipImportError(f"corrupt Zip file: {archive!r}", path=archive) raise ZipImportError(f"corrupt Zip file: {archive!r}",
path=archive)
header_position = file_size - len(data) + pos header_position = file_size - len(data) + pos
header_size = _unpack_uint32(buffer[12:16]) header_size = _unpack_uint32(buffer[12:16])
header_offset = _unpack_uint32(buffer[16:20]) header_offset = _unpack_uint32(buffer[16:20])
if header_position < header_size: if header_position < header_size:
raise ZipImportError(f"bad central directory size: {archive!r}", path=archive) raise ZipImportError(f'bad central directory size: {archive!r}', path=archive)
if header_position < header_offset: if header_position < header_offset:
raise ZipImportError(f"bad central directory offset: {archive!r}", path=archive) raise ZipImportError(f'bad central directory offset: {archive!r}', path=archive)
header_position -= header_size header_position -= header_size
arc_offset = header_position - header_offset arc_offset = header_position - header_offset
if arc_offset < 0: if arc_offset < 0:
raise ZipImportError( raise ZipImportError(f'bad central directory size or offset: {archive!r}', path=archive)
f"bad central directory size or offset: {archive!r}", path=archive
)
files = {} files = {}
# Start of Central Directory # Start of Central Directory
@ -433,12 +435,12 @@ def _read_directory(archive):
while True: while True:
buffer = fp.read(46) buffer = fp.read(46)
if len(buffer) < 4: if len(buffer) < 4:
raise EOFError("EOF read where not expected") raise EOFError('EOF read where not expected')
# Start of file header # Start of file header
if buffer[:4] != b"PK\x01\x02": if buffer[:4] != b'PK\x01\x02':
break # Bad: Central Dir File Header break # Bad: Central Dir File Header
if len(buffer) != 46: if len(buffer) != 46:
raise EOFError("EOF read where not expected") raise EOFError('EOF read where not expected')
flags = _unpack_uint16(buffer[8:10]) flags = _unpack_uint16(buffer[8:10])
compress = _unpack_uint16(buffer[10:12]) compress = _unpack_uint16(buffer[10:12])
time = _unpack_uint16(buffer[12:14]) time = _unpack_uint16(buffer[12:14])
@ -452,7 +454,7 @@ def _read_directory(archive):
file_offset = _unpack_uint32(buffer[42:46]) file_offset = _unpack_uint32(buffer[42:46])
header_size = name_size + extra_size + comment_size header_size = name_size + extra_size + comment_size
if file_offset > header_offset: if file_offset > header_offset:
raise ZipImportError(f"bad local header offset: {archive!r}", path=archive) raise ZipImportError(f'bad local header offset: {archive!r}', path=archive)
file_offset += arc_offset file_offset += arc_offset
try: try:
@ -476,19 +478,18 @@ def _read_directory(archive):
else: else:
# Historical ZIP filename encoding # Historical ZIP filename encoding
try: try:
name = name.decode("ascii") name = name.decode('ascii')
except UnicodeDecodeError: except UnicodeDecodeError:
name = name.decode("latin1").translate(cp437_table) name = name.decode('latin1').translate(cp437_table)
name = name.replace("/", path_sep) name = name.replace('/', path_sep)
path = _bootstrap_external._path_join(archive, name) path = _bootstrap_external._path_join(archive, name)
t = (path, compress, data_size, file_size, file_offset, time, date, crc) t = (path, compress, data_size, file_size, file_offset, time, date, crc)
files[name] = t files[name] = t
count += 1 count += 1
_bootstrap._verbose_message("zipimport: found {} names in {!r}", count, archive) _bootstrap._verbose_message('zipimport: found {} names in {!r}', count, archive)
return files return files
# During bootstrap, we may need to load the encodings # During bootstrap, we may need to load the encodings
# package from a ZIP file. But the cp437 encoding is implemented # package from a ZIP file. But the cp437 encoding is implemented
# in Python in the encodings package. # in Python in the encodings package.
@ -497,36 +498,35 @@ def _read_directory(archive):
# the cp437 encoding. # the cp437 encoding.
cp437_table = ( cp437_table = (
# ASCII part, 8 rows x 16 chars # ASCII part, 8 rows x 16 chars
"\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f" '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f'
"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f" '\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f'
" !\"#$%&'()*+,-./" ' !"#$%&\'()*+,-./'
"0123456789:;<=>?" '0123456789:;<=>?'
"@ABCDEFGHIJKLMNO" '@ABCDEFGHIJKLMNO'
"PQRSTUVWXYZ[\\]^_" 'PQRSTUVWXYZ[\\]^_'
"`abcdefghijklmno" '`abcdefghijklmno'
"pqrstuvwxyz{|}~\x7f" 'pqrstuvwxyz{|}~\x7f'
# non-ASCII part, 16 rows x 8 chars # non-ASCII part, 16 rows x 8 chars
"\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7" '\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7'
"\xea\xeb\xe8\xef\xee\xec\xc4\xc5" '\xea\xeb\xe8\xef\xee\xec\xc4\xc5'
"\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9" '\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9'
"\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192" '\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192'
"\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba" '\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba'
"\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb" '\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb'
"\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556" '\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556'
"\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510" '\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510'
"\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f" '\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f'
"\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567" '\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567'
"\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b" '\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b'
"\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580" '\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580'
"\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4" '\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4'
"\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229" '\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229'
"\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248" '\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248'
"\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0" '\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0'
) )
_importing_zlib = False _importing_zlib = False
# Return the zlib.decompress function object, or NULL if zlib couldn't # Return the zlib.decompress function object, or NULL if zlib couldn't
# be imported. The function is cached when found, so subsequent calls # be imported. The function is cached when found, so subsequent calls
# don't import zlib again. # don't import zlib again.
@ -535,29 +535,28 @@ def _get_decompress_func():
if _importing_zlib: if _importing_zlib:
# Someone has a zlib.py[co] in their Zip file # Someone has a zlib.py[co] in their Zip file
# let's avoid a stack overflow. # let's avoid a stack overflow.
_bootstrap._verbose_message("zipimport: zlib UNAVAILABLE") _bootstrap._verbose_message('zipimport: zlib UNAVAILABLE')
raise ZipImportError("can't decompress data; zlib not available") raise ZipImportError("can't decompress data; zlib not available")
_importing_zlib = True _importing_zlib = True
try: try:
from zlib import decompress from zlib import decompress
except Exception: except Exception:
_bootstrap._verbose_message("zipimport: zlib UNAVAILABLE") _bootstrap._verbose_message('zipimport: zlib UNAVAILABLE')
raise ZipImportError("can't decompress data; zlib not available") raise ZipImportError("can't decompress data; zlib not available")
finally: finally:
_importing_zlib = False _importing_zlib = False
_bootstrap._verbose_message("zipimport: zlib available") _bootstrap._verbose_message('zipimport: zlib available')
return decompress return decompress
# Given a path to a Zip file and a toc_entry, return the (uncompressed) data. # Given a path to a Zip file and a toc_entry, return the (uncompressed) data.
def _get_data(archive, toc_entry): def _get_data(archive, toc_entry):
datapath, compress, data_size, file_size, file_offset, time, date, crc = toc_entry datapath, compress, data_size, file_size, file_offset, time, date, crc = toc_entry
if data_size < 0: if data_size < 0:
raise ZipImportError("negative data size") raise ZipImportError('negative data size')
with _io.open(archive, "rb") as fp: with _io.open(archive, 'rb') as fp:
# Check to make sure the local file header is correct # Check to make sure the local file header is correct
try: try:
fp.seek(file_offset) fp.seek(file_offset)
@ -565,11 +564,11 @@ def _get_data(archive, toc_entry):
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive) raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
buffer = fp.read(30) buffer = fp.read(30)
if len(buffer) != 30: if len(buffer) != 30:
raise EOFError("EOF read where not expected") raise EOFError('EOF read where not expected')
if buffer[:4] != b"PK\x03\x04": if buffer[:4] != b'PK\x03\x04':
# Bad: Local File Header # Bad: Local File Header
raise ZipImportError(f"bad local file header: {archive!r}", path=archive) raise ZipImportError(f'bad local file header: {archive!r}', path=archive)
name_size = _unpack_uint16(buffer[26:28]) name_size = _unpack_uint16(buffer[26:28])
extra_size = _unpack_uint16(buffer[28:30]) extra_size = _unpack_uint16(buffer[28:30])
@ -602,17 +601,16 @@ def _eq_mtime(t1, t2):
# dostime only stores even seconds, so be lenient # dostime only stores even seconds, so be lenient
return abs(t1 - t2) <= 1 return abs(t1 - t2) <= 1
# Given the contents of a .py[co] file, unmarshal the data # Given the contents of a .py[co] file, unmarshal the data
# and return the code object. Return None if it the magic word doesn't # and return the code object. Return None if it the magic word doesn't
# match (we do this instead of raising an exception as we fall back # match (we do this instead of raising an exception as we fall back
# to .py if available and we don't want to mask other errors). # to .py if available and we don't want to mask other errors).
def _unmarshal_code(pathname, data, mtime): def _unmarshal_code(pathname, data, mtime):
if len(data) < 16: if len(data) < 16:
raise ZipImportError("bad pyc data") raise ZipImportError('bad pyc data')
if data[:4] != _bootstrap_external.MAGIC_NUMBER: if data[:4] != _bootstrap_external.MAGIC_NUMBER:
_bootstrap._verbose_message("{!r} has bad magic", pathname) _bootstrap._verbose_message('{!r} has bad magic', pathname)
return None # signal caller to try alternative return None # signal caller to try alternative
flags = _unpack_uint32(data[4:8]) flags = _unpack_uint32(data[4:8])
@ -621,57 +619,47 @@ def _unmarshal_code(pathname, data, mtime):
# pycs. We could validate hash-based pycs against the source, but it # pycs. We could validate hash-based pycs against the source, but it
# seems likely that most people putting hash-based pycs in a zipfile # seems likely that most people putting hash-based pycs in a zipfile
# will use unchecked ones. # will use unchecked ones.
if _imp.check_hash_based_pycs != "never" and ( if (_imp.check_hash_based_pycs != 'never' and
flags != 0x1 or _imp.check_hash_based_pycs == "always" (flags != 0x1 or _imp.check_hash_based_pycs == 'always')):
):
return None return None
elif mtime != 0 and not _eq_mtime(_unpack_uint32(data[8:12]), mtime): elif mtime != 0 and not _eq_mtime(_unpack_uint32(data[8:12]), mtime):
_bootstrap._verbose_message("{!r} has bad mtime", pathname) _bootstrap._verbose_message('{!r} has bad mtime', pathname)
return None # signal caller to try alternative return None # signal caller to try alternative
# XXX the pyc's size field is ignored; timestamp collisions are probably # XXX the pyc's size field is ignored; timestamp collisions are probably
# unimportant with zip files. # unimportant with zip files.
code = marshal.loads(data[16:]) code = marshal.loads(data[16:])
if not isinstance(code, _code_type): if not isinstance(code, _code_type):
raise TypeError(f"compiled module {pathname!r} is not a code object") raise TypeError(f'compiled module {pathname!r} is not a code object')
return code return code
_code_type = type(_unmarshal_code.__code__) _code_type = type(_unmarshal_code.__code__)
# Replace any occurrences of '\r\n?' in the input string with '\n'. # Replace any occurrences of '\r\n?' in the input string with '\n'.
# This converts DOS and Mac line endings to Unix line endings. # This converts DOS and Mac line endings to Unix line endings.
def _normalize_line_endings(source): def _normalize_line_endings(source):
source = source.replace(b"\r\n", b"\n") source = source.replace(b'\r\n', b'\n')
source = source.replace(b"\r", b"\n") source = source.replace(b'\r', b'\n')
return source return source
# Given a string buffer containing Python source code, compile it # Given a string buffer containing Python source code, compile it
# and return a code object. # and return a code object.
def _compile_source(pathname, source): def _compile_source(pathname, source):
source = _normalize_line_endings(source) source = _normalize_line_endings(source)
return compile(source, pathname, "exec", dont_inherit=True) return compile(source, pathname, 'exec', dont_inherit=True)
# Convert the date/time values found in the Zip archive to a value # Convert the date/time values found in the Zip archive to a value
# that's compatible with the time stamp stored in .pyc files. # that's compatible with the time stamp stored in .pyc files.
def _parse_dostime(d, t): def _parse_dostime(d, t):
return time.mktime( return time.mktime((
( (d >> 9) + 1980, # bits 9..15: year
(d >> 9) + 1980, # bits 9..15: year (d >> 5) & 0xF, # bits 5..8: month
(d >> 5) & 0xF, # bits 5..8: month d & 0x1F, # bits 0..4: day
d & 0x1F, # bits 0..4: day t >> 11, # bits 11..15: hours
t >> 11, # bits 11..15: hours (t >> 5) & 0x3F, # bits 8..10: minutes
(t >> 5) & 0x3F, # bits 8..10: minutes (t & 0x1F) * 2, # bits 0..7: seconds / 2
(t & 0x1F) * 2, # bits 0..7: seconds / 2 -1, -1, -1))
-1,
-1,
-1,
)
)
# Given a path to a .pyc file in the archive, return the # Given a path to a .pyc file in the archive, return the
# modification time of the matching .py file, or 0 if no source # modification time of the matching .py file, or 0 if no source
@ -679,7 +667,7 @@ def _parse_dostime(d, t):
def _get_mtime_of_source(self, path): def _get_mtime_of_source(self, path):
try: try:
# strip 'c' or 'o' from *.py[co] # strip 'c' or 'o' from *.py[co]
assert path[-1:] in ("c", "o") assert path[-1:] in ('c', 'o')
path = path[:-1] path = path[:-1]
toc_entry = self._files[path] toc_entry = self._files[path]
# fetch the time stamp of the .py file for comparison # fetch the time stamp of the .py file for comparison
@ -690,14 +678,13 @@ def _get_mtime_of_source(self, path):
except (KeyError, IndexError, TypeError): except (KeyError, IndexError, TypeError):
return 0 return 0
# Get the code object associated with the module specified by # Get the code object associated with the module specified by
# 'fullname'. # 'fullname'.
def _get_module_code(self, fullname): def _get_module_code(self, fullname):
path = _get_module_path(self, fullname) path = _get_module_path(self, fullname)
for suffix, isbytecode, ispackage in _zip_searchorder: for suffix, isbytecode, ispackage in _zip_searchorder:
fullpath = path + suffix fullpath = path + suffix
_bootstrap._verbose_message("trying {}{}{}", self.archive, path_sep, fullpath, verbosity=2) _bootstrap._verbose_message('trying {}{}{}', self.archive, path_sep, fullpath, verbosity=2)
try: try:
toc_entry = self._files[fullpath] toc_entry = self._files[fullpath]
except KeyError: except KeyError:
@ -726,7 +713,6 @@ class _ZipImportResourceReader:
This class is allowed to reference all the innards and private parts of This class is allowed to reference all the innards and private parts of
the zipimporter. the zipimporter.
""" """
_registered = False _registered = False
def __init__(self, zipimporter, fullname): def __init__(self, zipimporter, fullname):
@ -734,10 +720,9 @@ class _ZipImportResourceReader:
self.fullname = fullname self.fullname = fullname
def open_resource(self, resource): def open_resource(self, resource):
fullname_as_path = self.fullname.replace(".", "/") fullname_as_path = self.fullname.replace('.', '/')
path = f"{fullname_as_path}/{resource}" path = f'{fullname_as_path}/{resource}'
from io import BytesIO from io import BytesIO
try: try:
return BytesIO(self.zipimporter.get_data(path)) return BytesIO(self.zipimporter.get_data(path))
except OSError: except OSError:
@ -752,8 +737,8 @@ class _ZipImportResourceReader:
def is_resource(self, name): def is_resource(self, name):
# Maybe we could do better, but if we can get the data, it's a # Maybe we could do better, but if we can get the data, it's a
# resource. Otherwise it isn't. # resource. Otherwise it isn't.
fullname_as_path = self.fullname.replace(".", "/") fullname_as_path = self.fullname.replace('.', '/')
path = f"{fullname_as_path}/{name}" path = f'{fullname_as_path}/{name}'
try: try:
self.zipimporter.get_data(path) self.zipimporter.get_data(path)
except OSError: except OSError:
@ -769,12 +754,11 @@ class _ZipImportResourceReader:
# top of the archive, and then we iterate through _files looking for # top of the archive, and then we iterate through _files looking for
# names inside that "directory". # names inside that "directory".
from pathlib import Path from pathlib import Path
fullname_path = Path(self.zipimporter.get_filename(self.fullname)) fullname_path = Path(self.zipimporter.get_filename(self.fullname))
relative_path = fullname_path.relative_to(self.zipimporter.archive) relative_path = fullname_path.relative_to(self.zipimporter.archive)
# Don't forget that fullname names a package, so its path will include # Don't forget that fullname names a package, so its path will include
# __init__.py, which we want to ignore. # __init__.py, which we want to ignore.
assert relative_path.name == "__init__.py" assert relative_path.name == '__init__.py'
package_path = relative_path.parent package_path = relative_path.parent
subdirs_seen = set() subdirs_seen = set()
for filename in self.zipimporter._files: for filename in self.zipimporter._files:

View File

@ -1,3 +1,2 @@
from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader from .abc import PluginLoader, PluginClass, IDConflictError, PluginMeta
from .meta import DatabaseType, PluginMeta from .zip import ZippedPluginLoader, MaubotZipImportError
from .zip import MaubotZipImportError, ZippedPluginLoader

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,14 +13,17 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import TypeVar, Type, Dict, Set, List, TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
from attr import dataclass
from packaging.version import Version, InvalidVersion
from mautrix.client.api.types.util import (SerializableAttrs, SerializerError, serializer,
deserializer)
from ..__meta__ import __version__
from ..plugin_base import Plugin from ..plugin_base import Plugin
from .meta import PluginMeta
if TYPE_CHECKING: if TYPE_CHECKING:
from ..instance import PluginInstance from ..instance import PluginInstance
@ -32,40 +35,47 @@ class IDConflictError(Exception):
pass pass
class BasePluginLoader(ABC): @serializer(Version)
meta: PluginMeta def serialize_version(version: Version) -> str:
return str(version)
@property
@abstractmethod
def source(self) -> str:
pass
def sync_read_file(self, path: str) -> bytes:
raise NotImplementedError("This loader doesn't support synchronous operations")
@abstractmethod
async def read_file(self, path: str) -> bytes:
pass
def sync_list_files(self, directory: str) -> list[str]:
raise NotImplementedError("This loader doesn't support synchronous operations")
@abstractmethod
async def list_files(self, directory: str) -> list[str]:
pass
class PluginLoader(BasePluginLoader, ABC): @deserializer(Version)
id_cache: dict[str, PluginLoader] = {} def deserialize_version(version: str) -> Version:
try:
return Version(version)
except InvalidVersion as e:
raise SerializerError("Invalid version") from e
@dataclass
class PluginMeta(SerializableAttrs['PluginMeta']):
id: str
version: Version
modules: List[str]
main_class: str
maubot: Version = Version(__version__)
database: bool = False
config: bool = False
webapp: bool = False
license: str = ""
extra_files: List[str] = []
dependencies: List[str] = []
soft_dependencies: List[str] = []
class PluginLoader(ABC):
id_cache: Dict[str, 'PluginLoader'] = {}
meta: PluginMeta meta: PluginMeta
references: set[PluginInstance] references: Set['PluginInstance']
def __init__(self): def __init__(self):
self.references = set() self.references = set()
@classmethod @classmethod
def find(cls, plugin_id: str) -> PluginLoader: def find(cls, plugin_id: str) -> 'PluginLoader':
return cls.id_cache[plugin_id] return cls.id_cache[plugin_id]
def to_dict(self) -> dict: def to_dict(self) -> dict:
@ -75,22 +85,33 @@ class PluginLoader(BasePluginLoader, ABC):
"instances": [instance.to_dict() for instance in self.references], "instances": [instance.to_dict() for instance in self.references],
} }
async def stop_instances(self) -> None: @property
await asyncio.gather(
*[instance.stop() for instance in self.references if instance.started]
)
async def start_instances(self) -> None:
await asyncio.gather(
*[instance.start() for instance in self.references if instance.enabled]
)
@abstractmethod @abstractmethod
async def load(self) -> type[PluginClass]: def source(self) -> str:
pass pass
@abstractmethod @abstractmethod
async def reload(self) -> type[PluginClass]: async def read_file(self, path: str) -> bytes:
pass
async def stop_instances(self) -> None:
await asyncio.gather(*[instance.stop() for instance
in self.references if instance.started])
async def start_instances(self) -> None:
await asyncio.gather(*[instance.start() for instance
in self.references if instance.enabled])
@abstractmethod
async def load(self) -> Type[PluginClass]:
pass
@abstractmethod
async def reload(self) -> Type[PluginClass]:
pass
@abstractmethod
async def unload(self) -> None:
pass pass
@abstractmethod @abstractmethod

View File

@ -1,69 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Optional
from attr import dataclass
from packaging.version import InvalidVersion, Version
from mautrix.types import (
ExtensibleEnum,
SerializableAttrs,
SerializerError,
deserializer,
serializer,
)
from ..__meta__ import __version__
@serializer(Version)
def serialize_version(version: Version) -> str:
return str(version)
@deserializer(Version)
def deserialize_version(version: str) -> Version:
try:
return Version(version)
except InvalidVersion as e:
raise SerializerError("Invalid version") from e
class DatabaseType(ExtensibleEnum):
SQLALCHEMY = "sqlalchemy"
ASYNCPG = "asyncpg"
@dataclass
class PluginMeta(SerializableAttrs):
id: str
version: Version
modules: List[str]
main_class: str
maubot: Version = Version(__version__)
database: bool = False
database_type: DatabaseType = DatabaseType.SQLALCHEMY
config: bool = False
webapp: bool = False
license: str = ""
extra_files: List[str] = []
dependencies: List[str] = []
soft_dependencies: List[str] = []
@property
def database_type_str(self) -> Optional[str]:
return self.database_type.value if self.database else None

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,27 +13,22 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Dict, List, Type, Tuple, Optional
from zipfile import ZipFile, BadZipFile
from time import time from time import time
from zipfile import BadZipFile, ZipFile
import logging import logging
import os
import sys import sys
import os
from packaging.version import Version
from ruamel.yaml import YAML, YAMLError from ruamel.yaml import YAML, YAMLError
from packaging.version import Version
from mautrix.client.api.types.util import SerializerError
from mautrix.types import SerializerError from ..lib.zipimport import zipimporter, ZipImportError
from ..__meta__ import __version__
from ..config import Config
from ..lib.zipimport import ZipImportError, zipimporter
from ..plugin_base import Plugin from ..plugin_base import Plugin
from .abc import IDConflictError, PluginClass, PluginLoader from ..config import Config
from .meta import DatabaseType, PluginMeta from .abc import PluginLoader, PluginClass, PluginMeta, IDConflictError
current_version = Version(__version__)
yaml = YAML() yaml = YAML()
@ -54,25 +49,23 @@ class MaubotZipLoadError(MaubotZipImportError):
class ZippedPluginLoader(PluginLoader): class ZippedPluginLoader(PluginLoader):
path_cache: dict[str, ZippedPluginLoader] = {} path_cache: Dict[str, 'ZippedPluginLoader'] = {}
log: logging.Logger = logging.getLogger("maubot.loader.zip") log: logging.Logger = logging.getLogger("maubot.loader.zip")
trash_path: str = "delete" trash_path: str = "delete"
directories: list[str] = [] directories: List[str] = []
path: str | None path: str
meta: PluginMeta | None meta: PluginMeta
main_class: str | None main_class: str
main_module: str | None main_module: str
_loaded: type[PluginClass] | None _loaded: Type[PluginClass]
_importer: zipimporter | None _importer: zipimporter
_file: ZipFile | None _file: ZipFile
def __init__(self, path: str) -> None: def __init__(self, path: str) -> None:
super().__init__() super().__init__()
self.path = path self.path = path
self.meta = None self.meta = None
self.main_class = None
self.main_module = None
self._loaded = None self._loaded = None
self._importer = None self._importer = None
self._file = None self._file = None
@ -81,8 +74,7 @@ class ZippedPluginLoader(PluginLoader):
try: try:
existing = self.id_cache[self.meta.id] existing = self.id_cache[self.meta.id]
raise IDConflictError( raise IDConflictError(
f"Plugin with id {self.meta.id} already loaded from {existing.source}" f"Plugin with id {self.meta.id} already loaded from {existing.source}")
)
except KeyError: except KeyError:
pass pass
self.path_cache[self.path] = self self.path_cache[self.path] = self
@ -90,10 +82,13 @@ class ZippedPluginLoader(PluginLoader):
self.log.debug(f"Preloaded plugin {self.meta.id} from {self.path}") self.log.debug(f"Preloaded plugin {self.meta.id} from {self.path}")
def to_dict(self) -> dict: def to_dict(self) -> dict:
return {**super().to_dict(), "path": self.path} return {
**super().to_dict(),
"path": self.path
}
@classmethod @classmethod
def get(cls, path: str) -> ZippedPluginLoader: def get(cls, path: str) -> 'ZippedPluginLoader':
path = os.path.abspath(path) path = os.path.abspath(path)
try: try:
return cls.path_cache[path] return cls.path_cache[path]
@ -105,32 +100,16 @@ class ZippedPluginLoader(PluginLoader):
return self.path return self.path
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return ("<ZippedPlugin "
"<ZippedPlugin " f"path='{self.path}' "
f"path='{self.path}' " f"meta={self.meta} "
f"meta={self.meta} " f"loaded={self._loaded is not None}>")
f"loaded={self._loaded is not None}>"
)
def sync_read_file(self, path: str) -> bytes:
return self._file.read(path)
async def read_file(self, path: str) -> bytes: async def read_file(self, path: str) -> bytes:
return self.sync_read_file(path) return self._file.read(path)
def sync_list_files(self, directory: str) -> list[str]:
directory = directory.rstrip("/")
return [
file.filename
for file in self._file.filelist
if os.path.dirname(file.filename) == directory
]
async def list_files(self, directory: str) -> list[str]:
return self.sync_list_files(directory)
@staticmethod @staticmethod
def _read_meta(source) -> tuple[ZipFile, PluginMeta]: def _read_meta(source) -> Tuple[ZipFile, PluginMeta]:
try: try:
file = ZipFile(source) file = ZipFile(source)
data = file.read("maubot.yaml") data = file.read("maubot.yaml")
@ -148,16 +127,12 @@ class ZippedPluginLoader(PluginLoader):
meta = PluginMeta.deserialize(meta_dict) meta = PluginMeta.deserialize(meta_dict)
except SerializerError as e: except SerializerError as e:
raise MaubotZipMetaError("Maubot plugin definition in file is invalid") from e raise MaubotZipMetaError("Maubot plugin definition in file is invalid") from e
if meta.maubot > current_version:
raise MaubotZipMetaError(
f"Plugin requires maubot {meta.maubot}, but this instance is {current_version}"
)
return file, meta return file, meta
@classmethod @classmethod
def verify_meta(cls, source) -> tuple[str, Version, DatabaseType | None]: def verify_meta(cls, source) -> Tuple[str, Version]:
_, meta = cls._read_meta(source) _, meta = cls._read_meta(source)
return meta.id, meta.version, meta.database_type if meta.database else None return meta.id, meta.version
def _load_meta(self) -> None: def _load_meta(self) -> None:
file, meta = self._read_meta(self.path) file, meta = self._read_meta(self.path)
@ -167,7 +142,7 @@ class ZippedPluginLoader(PluginLoader):
if "/" in meta.main_class: if "/" in meta.main_class:
self.main_module, self.main_class = meta.main_class.split("/")[:2] self.main_module, self.main_class = meta.main_class.split("/")[:2]
else: else:
self.main_module = meta.modules[-1] self.main_module = meta.modules[0]
self.main_class = meta.main_class self.main_class = meta.main_class
self._file = file self._file = file
@ -186,24 +161,24 @@ class ZippedPluginLoader(PluginLoader):
code = importer.get_code(self.main_module.replace(".", "/")) code = importer.get_code(self.main_module.replace(".", "/"))
if self.main_class not in code.co_names: if self.main_class not in code.co_names:
raise MaubotZipPreLoadError( raise MaubotZipPreLoadError(
f"Main class {self.main_class} not in {self.main_module}" f"Main class {self.main_class} not in {self.main_module}")
)
except ZipImportError as e: except ZipImportError as e:
raise MaubotZipPreLoadError(f"Main module {self.main_module} not found in file") from e raise MaubotZipPreLoadError(
f"Main module {self.main_module} not found in file") from e
for module in self.meta.modules: for module in self.meta.modules:
try: try:
importer.find_module(module) importer.find_module(module)
except ZipImportError as e: except ZipImportError as e:
raise MaubotZipPreLoadError(f"Module {module} not found in file") from e raise MaubotZipPreLoadError(f"Module {module} not found in file") from e
async def load(self, reset_cache: bool = False) -> type[PluginClass]: async def load(self, reset_cache: bool = False) -> Type[PluginClass]:
try: try:
return self._load(reset_cache) return self._load(reset_cache)
except MaubotZipImportError: except MaubotZipImportError:
self.log.exception(f"Failed to load {self.meta.id} v{self.meta.version}") self.log.exception(f"Failed to load {self.meta.id} v{self.meta.version}")
raise raise
def _load(self, reset_cache: bool = False) -> type[PluginClass]: def _load(self, reset_cache: bool = False) -> Type[PluginClass]:
if self._loaded is not None and not reset_cache: if self._loaded is not None and not reset_cache:
return self._loaded return self._loaded
self._load_meta() self._load_meta()
@ -232,18 +207,13 @@ class ZippedPluginLoader(PluginLoader):
self.log.debug(f"Loaded and imported plugin {self.meta.id} from {self.path}") self.log.debug(f"Loaded and imported plugin {self.meta.id} from {self.path}")
return plugin return plugin
async def reload(self, new_path: str | None = None) -> type[PluginClass]: async def reload(self, new_path: Optional[str] = None) -> Type[PluginClass]:
self._unload() await self.unload()
if new_path is not None and new_path != self.path: if new_path is not None:
try:
del self.path_cache[self.path]
except KeyError:
pass
self.path = new_path self.path = new_path
self.path_cache[self.path] = self
return await self.load(reset_cache=True) return await self.load(reset_cache=True)
def _unload(self) -> None: async def unload(self) -> None:
for name, mod in list(sys.modules.items()): for name, mod in list(sys.modules.items()):
if (getattr(mod, "__file__", "") or "").startswith(self.path): if (getattr(mod, "__file__", "") or "").startswith(self.path):
del sys.modules[name] del sys.modules[name]
@ -251,7 +221,7 @@ class ZippedPluginLoader(PluginLoader):
self.log.debug(f"Unloaded plugin {self.meta.id} at {self.path}") self.log.debug(f"Unloaded plugin {self.meta.id} at {self.path}")
async def delete(self) -> None: async def delete(self) -> None:
self._unload() await self.unload()
try: try:
del self.path_cache[self.path] del self.path_cache[self.path]
except KeyError: except KeyError:
@ -269,22 +239,12 @@ class ZippedPluginLoader(PluginLoader):
self.path = None self.path = None
@classmethod @classmethod
def trash(cls, file_path: str, new_name: str | None = None, reason: str = "error") -> None: def trash(cls, file_path: str, new_name: Optional[str] = None, reason: str = "error") -> None:
if cls.trash_path == "delete": if cls.trash_path == "delete":
try: os.remove(file_path)
os.remove(file_path)
except FileNotFoundError:
pass
else: else:
new_name = new_name or f"{int(time())}-{reason}-{os.path.basename(file_path)}" new_name = new_name or f"{int(time())}-{reason}-{os.path.basename(file_path)}"
try: os.rename(file_path, os.path.abspath(os.path.join(cls.trash_path, new_name)))
os.rename(file_path, os.path.abspath(os.path.join(cls.trash_path, new_name)))
except OSError as e:
cls.log.warning(f"Failed to rename {file_path}: {e} - trying to delete")
try:
os.remove(file_path)
except FileNotFoundError:
pass
@classmethod @classmethod
def load_all(cls): def load_all(cls):

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,14 +13,13 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
import importlib import importlib
from aiohttp import web
from ...config import Config from ...config import Config
from .base import routes, get_config, set_config, set_loop
from .auth import check_token from .auth import check_token
from .base import get_config, routes, set_config
from .middleware import auth, error from .middleware import auth, error
@ -31,15 +30,14 @@ def features(request: web.Request) -> web.Response:
if err is None: if err is None:
return web.json_response(data) return web.json_response(data)
else: else:
return web.json_response( return web.json_response({
{ "login": data["login"],
"login": data["login"], })
}
)
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application: def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
set_config(cfg) set_config(cfg)
set_loop(loop)
for pkg, enabled in cfg["api_features"].items(): for pkg, enabled in cfg["api_features"].items():
if enabled: if enabled:
importlib.import_module(f"maubot.management.api.{pkg}") importlib.import_module(f"maubot.management.api.{pkg}")

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,8 +13,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Optional
from time import time from time import time
from aiohttp import web from aiohttp import web
@ -22,7 +21,7 @@ from aiohttp import web
from mautrix.types import UserID from mautrix.types import UserID
from mautrix.util.signed_token import sign_token, verify_token from mautrix.util.signed_token import sign_token, verify_token
from .base import get_config, routes from .base import routes, get_config
from .responses import resp from .responses import resp
@ -34,25 +33,22 @@ def is_valid_token(token: str) -> bool:
def create_token(user: UserID) -> str: def create_token(user: UserID) -> str:
return sign_token( return sign_token(get_config()["server.unshared_secret"], {
get_config()["server.unshared_secret"], "user_id": user,
{ "created_at": int(time()),
"user_id": user, })
"created_at": int(time()),
},
)
def get_token(request: web.Request) -> str: def get_token(request: web.Request) -> str:
token = request.headers.get("Authorization", "") token = request.headers.get("Authorization", "")
if not token or not token.startswith("Bearer "): if not token or not token.startswith("Bearer "):
token = request.query.get("access_token", "") token = request.query.get("access_token", None)
else: else:
token = token[len("Bearer ") :] token = token[len("Bearer "):]
return token return token
def check_token(request: web.Request) -> web.Response | None: def check_token(request: web.Request) -> Optional[web.Response]:
token = get_token(request) token = get_token(request)
if not token: if not token:
return resp.no_token return resp.no_token

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,17 +13,15 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
import asyncio
from aiohttp import web from aiohttp import web
import asyncio
from ...__meta__ import __version__ from ...__meta__ import __version__
from ...config import Config from ...config import Config
routes: web.RouteTableDef = web.RouteTableDef() routes: web.RouteTableDef = web.RouteTableDef()
_config: Config | None = None _config: Config = None
_loop: asyncio.AbstractEventLoop = None
def set_config(config: Config) -> None: def set_config(config: Config) -> None:
@ -35,6 +33,17 @@ def get_config() -> Config:
return _config return _config
def set_loop(loop: asyncio.AbstractEventLoop) -> None:
global _loop
_loop = loop
def get_loop() -> asyncio.AbstractEventLoop:
return _loop
@routes.get("/version") @routes.get("/version")
async def version(_: web.Request) -> web.Response: async def version(_: web.Request) -> web.Response:
return web.json_response({"version": __version__}) return web.json_response({
"version": __version__
})

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,23 +13,20 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Optional
from json import JSONDecodeError from json import JSONDecodeError
import logging
from aiohttp import web from aiohttp import web
from mautrix.types import UserID, SyncToken, FilterID
from mautrix.errors import MatrixRequestError, MatrixConnectionError, MatrixInvalidToken
from mautrix.client import Client as MatrixClient from mautrix.client import Client as MatrixClient
from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequestError
from mautrix.types import FilterID, SyncToken, UserID
from ...db import DBClient
from ...client import Client from ...client import Client
from .base import routes from .base import routes
from .responses import resp from .responses import resp
log = logging.getLogger("maubot.server.client")
@routes.get("/clients") @routes.get("/clients")
async def get_clients(_: web.Request) -> web.Response: async def get_clients(_: web.Request) -> web.Response:
@ -39,94 +36,63 @@ async def get_clients(_: web.Request) -> web.Response:
@routes.get("/client/{id}") @routes.get("/client/{id}")
async def get_client(request: web.Request) -> web.Response: async def get_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None) user_id = request.match_info.get("id", None)
client = await Client.get(user_id) client = Client.get(user_id, None)
if not client: if not client:
return resp.client_not_found return resp.client_not_found
return resp.found(client.to_dict()) return resp.found(client.to_dict())
async def _create_client(user_id: UserID | None, data: dict) -> web.Response: async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
homeserver = data.get("homeserver", None) homeserver = data.get("homeserver", None)
access_token = data.get("access_token", None) access_token = data.get("access_token", None)
device_id = data.get("device_id", None) new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token,
new_client = MatrixClient( loop=Client.loop, client_session=Client.http_client)
mxid="@not:a.mxid",
base_url=homeserver,
token=access_token,
client_session=Client.http_client,
)
try: try:
whoami = await new_client.whoami() mxid = await new_client.whoami()
except MatrixInvalidToken as e: except MatrixInvalidToken:
return resp.bad_client_access_token return resp.bad_client_access_token
except MatrixRequestError: except MatrixRequestError:
log.warning(f"Failed to get whoami from {homeserver} for new client", exc_info=True)
return resp.bad_client_access_details return resp.bad_client_access_details
except MatrixConnectionError: except MatrixConnectionError:
log.warning(f"Failed to connect to {homeserver} for new client", exc_info=True)
return resp.bad_client_connection_details return resp.bad_client_connection_details
if user_id is None: if user_id is None:
existing_client = await Client.get(whoami.user_id) existing_client = Client.get(mxid, None)
if existing_client is not None: if existing_client is not None:
return resp.user_exists return resp.user_exists
elif whoami.user_id != user_id: elif mxid != user_id:
return resp.mxid_mismatch(whoami.user_id) return resp.mxid_mismatch(mxid)
elif whoami.device_id and device_id and whoami.device_id != device_id: db_instance = DBClient(id=mxid, homeserver=homeserver, access_token=access_token,
return resp.device_id_mismatch(whoami.device_id) enabled=data.get("enabled", True), next_batch=SyncToken(""),
client = await Client.get( filter_id=FilterID(""), sync=data.get("sync", True),
whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id autojoin=data.get("autojoin", True),
) displayname=data.get("displayname", ""),
client.enabled = data.get("enabled", True) avatar_url=data.get("avatar_url", ""))
client.sync = data.get("sync", True) client = Client(db_instance)
client.autojoin = data.get("autojoin", True) client.db_instance.insert()
client.online = data.get("online", True)
client.displayname = data.get("displayname", "disable")
client.avatar_url = data.get("avatar_url", "disable")
await client.update()
await client.start() await client.start()
return resp.created(client.to_dict()) return resp.created(client.to_dict())
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response: async def _update_client(client: Client, data: dict) -> web.Response:
try: try:
await client.update_access_details( await client.update_access_details(data.get("access_token", None),
data.get("access_token"), data.get("homeserver"), data.get("device_id") data.get("homeserver", None))
)
except MatrixInvalidToken: except MatrixInvalidToken:
return resp.bad_client_access_token return resp.bad_client_access_token
except MatrixRequestError: except MatrixRequestError:
log.warning(
f"Failed to get whoami from homeserver to update client details", exc_info=True
)
return resp.bad_client_access_details return resp.bad_client_access_details
except MatrixConnectionError: except MatrixConnectionError:
log.warning(f"Failed to connect to homeserver to update client details", exc_info=True)
return resp.bad_client_connection_details return resp.bad_client_connection_details
except ValueError as e: except ValueError as e:
str_err = str(e) return resp.mxid_mismatch(str(e)[len("MXID mismatch: "):])
if str_err.startswith("MXID mismatch"): with client.db_instance.edit_mode():
return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :]) await client.update_avatar_url(data.get("avatar_url", None))
elif str_err.startswith("Device ID mismatch"): await client.update_displayname(data.get("displayname", None))
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :]) await client.update_started(data.get("started", None))
await client.update_avatar_url(data.get("avatar_url"), save=False) client.enabled = data.get("enabled", client.enabled)
await client.update_displayname(data.get("displayname"), save=False) client.autojoin = data.get("autojoin", client.autojoin)
await client.update_started(data.get("started")) client.sync = data.get("sync", client.sync)
await client.update_enabled(data.get("enabled"), save=False) return resp.updated(client.to_dict())
await client.update_autojoin(data.get("autojoin"), save=False)
await client.update_online(data.get("online"), save=False)
await client.update_sync(data.get("sync"), save=False)
await client.update()
return resp.updated(client.to_dict(), is_login=is_login)
async def _create_or_update_client(
user_id: UserID, data: dict, is_login: bool = False
) -> web.Response:
client = await Client.get(user_id)
if not client:
return await _create_client(user_id, data)
else:
return await _update_client(client, data, is_login=is_login)
@routes.post("/client/new") @routes.post("/client/new")
@ -140,33 +106,37 @@ async def create_client(request: web.Request) -> web.Response:
@routes.put("/client/{id}") @routes.put("/client/{id}")
async def update_client(request: web.Request) -> web.Response: async def update_client(request: web.Request) -> web.Response:
user_id = request.match_info["id"] user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
try: try:
data = await request.json() data = await request.json()
except JSONDecodeError: except JSONDecodeError:
return resp.body_not_json return resp.body_not_json
return await _create_or_update_client(user_id, data) if not client:
return await _create_client(user_id, data)
else:
return await _update_client(client, data)
@routes.delete("/client/{id}") @routes.delete("/client/{id}")
async def delete_client(request: web.Request) -> web.Response: async def delete_client(request: web.Request) -> web.Response:
user_id = request.match_info["id"] user_id = request.match_info.get("id", None)
client = await Client.get(user_id) client = Client.get(user_id, None)
if not client: if not client:
return resp.client_not_found return resp.client_not_found
if len(client.references) > 0: if len(client.references) > 0:
return resp.client_in_use return resp.client_in_use
if client.started: if client.started:
await client.stop() await client.stop()
await client.delete() client.delete()
return resp.deleted return resp.deleted
@routes.post("/client/{id}/clearcache") @routes.post("/client/{id}/clearcache")
async def clear_client_cache(request: web.Request) -> web.Response: async def clear_client_cache(request: web.Request) -> web.Response:
user_id = request.match_info["id"] user_id = request.match_info.get("id", None)
client = await Client.get(user_id) client = Client.get(user_id, None)
if not client: if not client:
return resp.client_not_found return resp.client_not_found
await client.clear_cache() client.clear_cache()
return resp.ok return resp.ok

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,261 +13,118 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Dict, Tuple, NamedTuple, Optional
from typing import NamedTuple
from http import HTTPStatus
from json import JSONDecodeError from json import JSONDecodeError
import asyncio from http import HTTPStatus
import hashlib import hashlib
import hmac
import random import random
import string import string
import hmac
from aiohttp import web from aiohttp import web
from yarl import URL from mautrix.api import HTTPAPI, Path, Method
from mautrix.api import Method, Path, SynapseAdminPath
from mautrix.client import ClientAPI
from mautrix.errors import MatrixRequestError from mautrix.errors import MatrixRequestError
from mautrix.types import LoginResponse, LoginType
from .base import get_config, routes from .base import routes, get_config, get_loop
from .client import _create_client, _create_or_update_client
from .responses import resp from .responses import resp
def known_homeservers() -> dict[str, dict[str, str]]: def registration_secrets() -> Dict[str, Dict[str, str]]:
return get_config()["homeservers"] return get_config()["registration_secrets"]
def generate_mac(secret: str, nonce: str, user: str, password: str, admin: bool = False):
mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
mac.update(nonce.encode("utf-8"))
mac.update(b"\x00")
mac.update(user.encode("utf-8"))
mac.update(b"\x00")
mac.update(password.encode("utf-8"))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
return mac.hexdigest()
@routes.get("/client/auth/servers") @routes.get("/client/auth/servers")
async def get_known_servers(_: web.Request) -> web.Response: async def get_registerable_servers(_: web.Request) -> web.Response:
return web.json_response({key: value["url"] for key, value in known_homeservers().items()}) return web.json_response({key: value["url"] for key, value in registration_secrets().items()})
class AuthRequestInfo(NamedTuple): AuthRequestInfo = NamedTuple("AuthRequestInfo", api=HTTPAPI, secret=str, username=str,
server_name: str password=str, user_type=str)
client: ClientAPI
secret: str
username: str
password: str
user_type: str
device_name: str
update_client: bool
sso: bool
truthy_strings = ("1", "true", "yes") async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
Optional[web.Response]]:
async def read_client_auth_request(
request: web.Request,
) -> tuple[AuthRequestInfo | None, web.Response | None]:
server_name = request.match_info.get("server", None) server_name = request.match_info.get("server", None)
server = known_homeservers().get(server_name, None) server = registration_secrets().get(server_name, None)
if not server: if not server:
return None, resp.server_not_found return None, resp.server_not_found
try: try:
body = await request.json() body = await request.json()
except JSONDecodeError: except JSONDecodeError:
return None, resp.body_not_json return None, resp.body_not_json
sso = request.query.get("sso", "").lower() in truthy_strings
try: try:
username = body["username"] username = body["username"]
password = body["password"] password = body["password"]
except KeyError: except KeyError:
if not sso: return None, resp.username_or_password_missing
return None, resp.username_or_password_missing
username = password = None
try: try:
base_url = server["url"] base_url = server["url"]
secret = server["secret"]
except KeyError: except KeyError:
return None, resp.invalid_server return None, resp.invalid_server
return ( api = HTTPAPI(base_url, "", loop=get_loop())
AuthRequestInfo( user_type = body.get("user_type", None)
server_name=server_name, return AuthRequestInfo(api, secret, username, password, user_type), None
client=ClientAPI(base_url=base_url),
secret=server.get("secret"),
username=username,
password=password,
user_type=body.get("user_type", "bot"),
device_name=body.get("device_name", "Maubot"),
update_client=request.query.get("update_client", "").lower() in truthy_strings,
sso=sso,
),
None,
)
def generate_mac(
secret: str,
nonce: str,
username: str,
password: str,
admin: bool = False,
user_type: str = None,
) -> str:
mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
mac.update(nonce.encode("utf-8"))
mac.update(b"\x00")
mac.update(username.encode("utf-8"))
mac.update(b"\x00")
mac.update(password.encode("utf-8"))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
if user_type is not None:
mac.update(b"\x00")
mac.update(user_type.encode("utf8"))
return mac.hexdigest()
@routes.post("/client/auth/{server}/register") @routes.post("/client/auth/{server}/register")
async def register(request: web.Request) -> web.Response: async def register(request: web.Request) -> web.Response:
req, err = await read_client_auth_request(request) info, err = await read_client_auth_request(request)
if err is not None: if err is not None:
return err return err
if req.sso: api, secret, username, password, user_type = info
return resp.registration_no_sso res = await api.request(Method.GET, Path.admin.register)
elif not req.secret: nonce = res["nonce"]
return resp.registration_secret_not_found mac = generate_mac(secret, nonce, username, password)
path = SynapseAdminPath.v1.register
res = await req.client.api.request(Method.GET, path)
content = {
"nonce": res["nonce"],
"username": req.username,
"password": req.password,
"admin": False,
"user_type": req.user_type,
}
content["mac"] = generate_mac(**content, secret=req.secret)
try: try:
raw_res = await req.client.api.request(Method.POST, path, content=content) return web.json_response(await api.request(Method.POST, Path.admin.register, content={
"nonce": nonce,
"username": username,
"password": password,
"admin": False,
"mac": mac,
# Older versions of synapse will ignore this field if it is None
"user_type": user_type,
}))
except MatrixRequestError as e: except MatrixRequestError as e:
return web.json_response( return web.json_response({
{ "errcode": e.errcode,
"errcode": e.errcode, "error": e.message,
"error": e.message, "http_status": e.http_status,
"http_status": e.http_status, }, status=HTTPStatus.INTERNAL_SERVER_ERROR)
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
login_res = LoginResponse.deserialize(raw_res)
if req.update_client:
return await _create_client(
login_res.user_id,
{
"homeserver": str(req.client.api.base_url),
"access_token": login_res.access_token,
"device_id": login_res.device_id,
},
)
return web.json_response(login_res.serialize())
@routes.post("/client/auth/{server}/login") @routes.post("/client/auth/{server}/login")
async def login(request: web.Request) -> web.Response: async def login(request: web.Request) -> web.Response:
req, err = await read_client_auth_request(request) info, err = await read_client_auth_request(request)
if err is not None: if err is not None:
return err return err
if req.sso: api, _, username, password, _ = info
return await _do_sso(req) device_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8))
else:
return await _do_login(req)
async def _do_sso(req: AuthRequestInfo) -> web.Response:
flows = await req.client.get_login_flows()
if not flows.supports_type(LoginType.SSO):
return resp.sso_not_supported
waiter_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=16))
cfg = get_config()
public_url = (
URL(cfg["server.public_url"])
/ "_matrix/maubot/v1/client/auth_external_sso/complete"
/ waiter_id
)
sso_url = req.client.api.base_url.with_path(str(Path.v3.login.sso.redirect)).with_query(
{"redirectUrl": str(public_url)}
)
sso_waiters[waiter_id] = req, asyncio.get_running_loop().create_future()
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
async def _do_login(req: AuthRequestInfo, login_token: str | None = None) -> web.Response:
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
device_id = f"maubot_{device_id}"
try: try:
if req.sso: return web.json_response(await api.request(Method.POST, Path.login, content={
res = await req.client.login( "type": "m.login.password",
token=login_token, "identifier": {
login_type=LoginType.TOKEN, "type": "m.id.user",
device_id=device_id, "user": username,
store_access_token=False, },
initial_device_display_name=req.device_name, "password": password,
) "device_id": f"maubot_{device_id}",
else: }))
res = await req.client.login(
identifier=req.username,
login_type=LoginType.PASSWORD,
password=req.password,
device_id=device_id,
initial_device_display_name=req.device_name,
store_access_token=False,
)
except MatrixRequestError as e: except MatrixRequestError as e:
return web.json_response( return web.json_response({
{ "errcode": e.errcode,
"errcode": e.errcode, "error": e.message,
"error": e.message, }, status=e.http_status)
},
status=e.http_status,
)
if req.update_client:
return await _create_or_update_client(
res.user_id,
{
"homeserver": str(req.client.api.base_url),
"access_token": res.access_token,
"device_id": res.device_id,
},
is_login=True,
)
return web.json_response(res.serialize())
sso_waiters: dict[str, tuple[AuthRequestInfo, asyncio.Future]] = {}
@routes.post("/client/auth/{server}/sso/{id}/wait")
async def wait_sso(request: web.Request) -> web.Response:
waiter_id = request.match_info["id"]
req, fut = sso_waiters[waiter_id]
try:
login_token = await fut
finally:
sso_waiters.pop(waiter_id, None)
return await _do_login(req, login_token)
@routes.get("/client/auth_external_sso/complete/{id}")
async def complete_sso(request: web.Request) -> web.Response:
try:
_, fut = sso_waiters[request.match_info["id"]]
except KeyError:
return web.Response(status=404, text="Invalid session ID\n")
if fut.cancelled():
return web.Response(status=200, text="The login was cancelled from the Maubot client\n")
elif fut.done():
return web.Response(status=200, text="The login token was already received\n")
try:
fut.set_result(request.query["loginToken"])
except KeyError:
return web.Response(status=400, text="Missing loginToken query parameter\n")
except asyncio.InvalidStateError:
return web.Response(status=500, text="Invalid state\n")
return web.Response(
status=200,
text="Login token received, please return to your Maubot client. "
"This tab can be closed.\n",
)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,7 +13,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import client as http, web from aiohttp import web, client as http
from ...client import Client from ...client import Client
from .base import routes from .base import routes
@ -25,7 +25,7 @@ PROXY_CHUNK_SIZE = 32 * 1024
@routes.view("/proxy/{id}/{path:_matrix/.+}") @routes.view("/proxy/{id}/{path:_matrix/.+}")
async def proxy(request: web.Request) -> web.StreamResponse: async def proxy(request: web.Request) -> web.StreamResponse:
user_id = request.match_info.get("id", None) user_id = request.match_info.get("id", None)
client = await Client.get(user_id) client = Client.get(user_id, None)
if not client: if not client:
return resp.client_not_found return resp.client_not_found
@ -45,9 +45,8 @@ async def proxy(request: web.Request) -> web.StreamResponse:
headers["X-Forwarded-For"] = f"{host}:{port}" headers["X-Forwarded-For"] = f"{host}:{port}"
data = await request.read() data = await request.read()
async with http.request( async with http.request(request.method, f"{client.homeserver}/{path}", headers=headers,
request.method, f"{client.homeserver}/{path}", headers=headers, params=query, data=data params=query, data=data) as proxy_resp:
) as proxy_resp:
response = web.StreamResponse(status=proxy_resp.status, headers=proxy_resp.headers) response = web.StreamResponse(status=proxy_resp.status, headers=proxy_resp.headers)
await response.prepare(request) await response.prepare(request)
async for chunk in proxy_resp.content.iter_chunked(PROXY_CHUNK_SIZE): async for chunk in proxy_resp.content.iter_chunked(PROXY_CHUNK_SIZE):

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,11 +14,11 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from string import Template from string import Template
import asyncio from subprocess import run
import re import re
from aiohttp import web
from ruamel.yaml import YAML from ruamel.yaml import YAML
from aiohttp import web
from .base import routes from .base import routes
@ -27,7 +27,9 @@ enabled = False
@routes.get("/debug/open") @routes.get("/debug/open")
async def check_enabled(_: web.Request) -> web.Response: async def check_enabled(_: web.Request) -> web.Response:
return web.json_response({"enabled": enabled}) return web.json_response({
"enabled": enabled,
})
try: try:
@ -38,6 +40,7 @@ try:
editor_command = Template(cfg["editor"]) editor_command = Template(cfg["editor"])
pathmap = [(re.compile(item["find"]), item["replace"]) for item in cfg["pathmap"]] pathmap = [(re.compile(item["find"]), item["replace"]) for item in cfg["pathmap"]]
@routes.post("/debug/open") @routes.post("/debug/open")
async def open_file(request: web.Request) -> web.Response: async def open_file(request: web.Request) -> web.Response:
data = await request.json() data = await request.json()
@ -48,9 +51,13 @@ try:
cmd = editor_command.substitute(path=path, line=data["line"]) cmd = editor_command.substitute(path=path, line=data["line"])
except (KeyError, ValueError): except (KeyError, ValueError):
return web.Response(status=400) return web.Response(status=400)
res = await asyncio.create_subprocess_shell(cmd) res = run(cmd, shell=True)
stdout, stderr = await res.communicate() return web.json_response({
return web.json_response({"return": res.returncode, "stdout": stdout, "stderr": stderr}) "return": res.returncode,
"stdout": res.stdout,
"stderr": res.stderr
})
enabled = True enabled = True
except Exception: except Exception:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -17,9 +17,10 @@ from json import JSONDecodeError
from aiohttp import web from aiohttp import web
from ...client import Client from ...db import DBPlugin
from ...instance import PluginInstance from ...instance import PluginInstance
from ...loader import PluginLoader from ...loader import PluginLoader
from ...client import Client
from .base import routes from .base import routes
from .responses import resp from .responses import resp
@ -31,50 +32,51 @@ async def get_instances(_: web.Request) -> web.Response:
@routes.get("/instance/{id}") @routes.get("/instance/{id}")
async def get_instance(request: web.Request) -> web.Response: async def get_instance(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower() instance_id = request.match_info.get("id", "").lower()
instance = await PluginInstance.get(instance_id) instance = PluginInstance.get(instance_id, None)
if not instance: if not instance:
return resp.instance_not_found return resp.instance_not_found
return resp.found(instance.to_dict()) return resp.found(instance.to_dict())
async def _create_instance(instance_id: str, data: dict) -> web.Response: async def _create_instance(instance_id: str, data: dict) -> web.Response:
plugin_type = data.get("type") plugin_type = data.get("type", None)
primary_user = data.get("primary_user") primary_user = data.get("primary_user", None)
if not plugin_type: if not plugin_type:
return resp.plugin_type_required return resp.plugin_type_required
elif not primary_user: elif not primary_user:
return resp.primary_user_required return resp.primary_user_required
elif not await Client.get(primary_user): elif not Client.get(primary_user):
return resp.primary_user_not_found return resp.primary_user_not_found
try: try:
PluginLoader.find(plugin_type) PluginLoader.find(plugin_type)
except KeyError: except KeyError:
return resp.plugin_type_not_found return resp.plugin_type_not_found
instance = await PluginInstance.get(instance_id, type=plugin_type, primary_user=primary_user) db_instance = DBPlugin(id=instance_id, type=plugin_type, enabled=data.get("enabled", True),
instance.enabled = data.get("enabled", True) primary_user=primary_user, config=data.get("config", ""))
instance.config_str = data.get("config") or "" instance = PluginInstance(db_instance)
await instance.update() instance.load()
await instance.load() instance.db_instance.insert()
await instance.start() await instance.start()
return resp.created(instance.to_dict()) return resp.created(instance.to_dict())
async def _update_instance(instance: PluginInstance, data: dict) -> web.Response: async def _update_instance(instance: PluginInstance, data: dict) -> web.Response:
if not await instance.update_primary_user(data.get("primary_user")): if not await instance.update_primary_user(data.get("primary_user", None)):
return resp.primary_user_not_found return resp.primary_user_not_found
await instance.update_id(data.get("id")) with instance.db_instance.edit_mode():
await instance.update_enabled(data.get("enabled")) instance.update_id(data.get("id", None))
await instance.update_config(data.get("config")) instance.update_enabled(data.get("enabled", None))
await instance.update_started(data.get("started")) instance.update_config(data.get("config", None))
await instance.update_type(data.get("type")) await instance.update_started(data.get("started", None))
return resp.updated(instance.to_dict()) await instance.update_type(data.get("type", None))
return resp.updated(instance.to_dict())
@routes.put("/instance/{id}") @routes.put("/instance/{id}")
async def update_instance(request: web.Request) -> web.Response: async def update_instance(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower() instance_id = request.match_info.get("id", "").lower()
instance = await PluginInstance.get(instance_id) instance = PluginInstance.get(instance_id, None)
try: try:
data = await request.json() data = await request.json()
except JSONDecodeError: except JSONDecodeError:
@ -87,11 +89,11 @@ async def update_instance(request: web.Request) -> web.Response:
@routes.delete("/instance/{id}") @routes.delete("/instance/{id}")
async def delete_instance(request: web.Request) -> web.Response: async def delete_instance(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower() instance_id = request.match_info.get("id", "").lower()
instance = await PluginInstance.get(instance_id) instance = PluginInstance.get(instance_id)
if not instance: if not instance:
return resp.instance_not_found return resp.instance_not_found
if instance.started: if instance.started:
await instance.stop() await instance.stop()
await instance.delete() instance.delete()
return resp.deleted return resp.deleted

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,67 +13,80 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Union, TYPE_CHECKING
from datetime import datetime from datetime import datetime
from aiohttp import web from aiohttp import web
from asyncpg import PostgresError from sqlalchemy import Table, Column, asc, desc, exc
import aiosqlite from sqlalchemy.orm import Query
from sqlalchemy.engine.result import ResultProxy, RowProxy
from mautrix.util.async_db import Database
from ...instance import PluginInstance from ...instance import PluginInstance
from ...lib.optionalalchemy import Engine, IntegrityError, OperationalError, asc, desc
from .base import routes from .base import routes
from .responses import resp from .responses import resp
@routes.get("/instance/{id}/database") @routes.get("/instance/{id}/database")
async def get_database(request: web.Request) -> web.Response: async def get_database(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower() instance_id = request.match_info.get("id", "")
instance = await PluginInstance.get(instance_id) instance = PluginInstance.get(instance_id, None)
if not instance: if not instance:
return resp.instance_not_found return resp.instance_not_found
elif not instance.inst_db: elif not instance.inst_db:
return resp.plugin_has_no_database return resp.plugin_has_no_database
return web.json_response(await instance.get_db_tables()) if TYPE_CHECKING:
table: Table
column: Column
return web.json_response({
table.name: {
"columns": {
column.name: {
"type": str(column.type),
"unique": column.unique or False,
"default": column.default,
"nullable": column.nullable,
"primary": column.primary_key,
"autoincrement": column.autoincrement,
} for column in table.columns
},
} for table in instance.get_db_tables().values()
})
def check_type(val):
if isinstance(val, datetime):
return val.isoformat()
return val
@routes.get("/instance/{id}/database/{table}") @routes.get("/instance/{id}/database/{table}")
async def get_table(request: web.Request) -> web.Response: async def get_table(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower() instance_id = request.match_info.get("id", "")
instance = await PluginInstance.get(instance_id) instance = PluginInstance.get(instance_id, None)
if not instance: if not instance:
return resp.instance_not_found return resp.instance_not_found
elif not instance.inst_db: elif not instance.inst_db:
return resp.plugin_has_no_database return resp.plugin_has_no_database
tables = await instance.get_db_tables() tables = instance.get_db_tables()
try: try:
table = tables[request.match_info.get("table", "")] table = tables[request.match_info.get("table", "")]
except KeyError: except KeyError:
return resp.table_not_found return resp.table_not_found
try: try:
order = [tuple(order.split(":")) for order in request.query.getall("order")] order = [tuple(order.split(":")) for order in request.query.getall("order")]
order = [ order = [(asc if sort.lower() == "asc" else desc)(table.columns[column])
( if sort else table.columns[column]
(asc if sort.lower() == "asc" else desc)(table.columns[column]) for column, sort in order]
if sort
else table.columns[column]
)
for column, sort in order
]
except KeyError: except KeyError:
order = [] order = []
limit = int(request.query.get("limit", "100")) limit = int(request.query.get("limit", 100))
if isinstance(instance.inst_db, Engine): return execute_query(instance, table.select().order_by(*order).limit(limit))
return _execute_query_sqlalchemy(instance, table.select().order_by(*order).limit(limit))
@routes.post("/instance/{id}/database/query") @routes.post("/instance/{id}/database/query")
async def query(request: web.Request) -> web.Response: async def query(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower() instance_id = request.match_info.get("id", "")
instance = await PluginInstance.get(instance_id) instance = PluginInstance.get(instance_id, None)
if not instance: if not instance:
return resp.instance_not_found return resp.instance_not_found
elif not instance.inst_db: elif not instance.inst_db:
@ -83,76 +96,28 @@ async def query(request: web.Request) -> web.Response:
sql_query = data["query"] sql_query = data["query"]
except KeyError: except KeyError:
return resp.query_missing return resp.query_missing
rows_as_dict = data.get("rows_as_dict", False) return execute_query(instance, sql_query,
if isinstance(instance.inst_db, Engine): rows_as_dict=data.get("rows_as_dict", False))
return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict)
elif isinstance(instance.inst_db, Database):
try:
return await _execute_query_asyncpg(instance, sql_query, rows_as_dict)
except (PostgresError, aiosqlite.Error) as e:
return resp.sql_error(e, sql_query)
else:
return resp.unsupported_plugin_database
def check_type(val): def execute_query(instance: PluginInstance, sql_query: Union[str, Query],
if isinstance(val, datetime): rows_as_dict: bool = False) -> web.Response:
return val.isoformat()
return val
async def _execute_query_asyncpg(
instance: PluginInstance, sql_query: str, rows_as_dict: bool = False
) -> web.Response:
data = {"ok": True, "query": sql_query}
if sql_query.upper().startswith("SELECT"):
res = await instance.inst_db.fetch(sql_query)
data["rows"] = [
(
{key: check_type(value) for key, value in row.items()}
if rows_as_dict
else [check_type(value) for value in row]
)
for row in res
]
if len(res) > 0:
# TODO can we find column names when there are no rows?
data["columns"] = list(res[0].keys())
else:
res = await instance.inst_db.execute(sql_query)
if isinstance(res, str):
data["status_msg"] = res
elif isinstance(res, aiosqlite.Cursor):
data["rowcount"] = res.rowcount
# data["inserted_primary_key"] = res.lastrowid
else:
data["status_msg"] = "unknown status"
return web.json_response(data)
def _execute_query_sqlalchemy(
instance: PluginInstance, sql_query: str, rows_as_dict: bool = False
) -> web.Response:
assert isinstance(instance.inst_db, Engine)
try: try:
res = instance.inst_db.execute(sql_query) res: ResultProxy = instance.inst_db.execute(sql_query)
except IntegrityError as e: except exc.IntegrityError as e:
return resp.sql_integrity_error(e, sql_query) return resp.sql_integrity_error(e, sql_query)
except OperationalError as e: except exc.OperationalError as e:
return resp.sql_operational_error(e, sql_query) return resp.sql_operational_error(e, sql_query)
data = { data = {
"ok": True, "ok": True,
"query": str(sql_query), "query": str(sql_query),
} }
if res.returns_rows: if res.returns_rows:
data["rows"] = [ row: RowProxy
( data["rows"] = [({key: check_type(value) for key, value in row.items()}
{key: check_type(value) for key, value in row.items()} if rows_as_dict
if rows_as_dict else [check_type(value) for value in row])
else [check_type(value) for value in row] for row in res]
)
for row in res
]
data["columns"] = res.keys() data["columns"] = res.keys()
else: else:
data["rowcount"] = res.rowcount data["rowcount"] = res.rowcount

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,63 +13,31 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from typing import Deque, List
from collections import deque
from datetime import datetime from datetime import datetime
import asyncio from collections import deque
import logging import logging
import asyncio
from aiohttp import web, web_ws from aiohttp import web
from mautrix.util import background_task
from .base import routes, get_loop
from .auth import is_valid_token from .auth import is_valid_token
from .base import routes
BUILTIN_ATTRS = { BUILTIN_ATTRS = {"args", "asctime", "created", "exc_info", "exc_text", "filename", "funcName",
"args", "levelname", "levelno", "lineno", "module", "msecs", "message", "msg", "name",
"asctime", "pathname", "process", "processName", "relativeCreated", "stack_info", "thread",
"created", "threadName"}
"exc_info", INCLUDE_ATTRS = {"filename", "funcName", "levelname", "levelno", "lineno", "module", "name",
"exc_text", "pathname"}
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"message",
"msg",
"name",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"thread",
"threadName",
}
INCLUDE_ATTRS = {
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"name",
"pathname",
}
EXCLUDE_ATTRS = BUILTIN_ATTRS - INCLUDE_ATTRS EXCLUDE_ATTRS = BUILTIN_ATTRS - INCLUDE_ATTRS
MAX_LINES = 2048 MAX_LINES = 2048
class LogCollector(logging.Handler): class LogCollector(logging.Handler):
lines: deque[dict] lines: Deque[dict]
formatter: logging.Formatter formatter: logging.Formatter
listeners: list[web.WebSocketResponse] listeners: List[web.WebSocketResponse]
loop: asyncio.AbstractEventLoop
def __init__(self, level=logging.NOTSET) -> None: def __init__(self, level=logging.NOTSET) -> None:
super().__init__(level) super().__init__(level)
@ -87,7 +55,9 @@ class LogCollector(logging.Handler):
# JSON conversion based on Marsel Mavletkulov's json-log-formatter (MIT license) # JSON conversion based on Marsel Mavletkulov's json-log-formatter (MIT license)
# https://github.com/marselester/json-log-formatter # https://github.com/marselester/json-log-formatter
content = { content = {
name: value for name, value in record.__dict__.items() if name not in EXCLUDE_ATTRS name: value
for name, value in record.__dict__.items()
if name not in EXCLUDE_ATTRS
} }
content["id"] = str(record.relativeCreated) content["id"] = str(record.relativeCreated)
content["msg"] = record.getMessage() content["msg"] = record.getMessage()
@ -99,7 +69,7 @@ class LogCollector(logging.Handler):
for name, value in content.items(): for name, value in content.items():
if isinstance(value, datetime): if isinstance(value, datetime):
content[name] = value.astimezone().isoformat() content[name] = value.astimezone().isoformat()
asyncio.run_coroutine_threadsafe(self.send(content), loop=self.loop) asyncio.ensure_future(self.send(content))
self.lines.append(content) self.lines.append(content)
async def send(self, record: dict) -> None: async def send(self, record: dict) -> None:
@ -111,18 +81,17 @@ class LogCollector(logging.Handler):
handler = LogCollector() handler = LogCollector()
log_root = logging.getLogger("maubot")
log = logging.getLogger("maubot.server.websocket") log = logging.getLogger("maubot.server.websocket")
sockets = [] sockets = []
def init(loop: asyncio.AbstractEventLoop) -> None: def init() -> None:
logging.root.addHandler(handler) log_root.addHandler(handler)
handler.loop = loop
async def stop_all() -> None: async def stop_all() -> None:
log.debug("Closing log listener websockets") log_root.removeHandler(handler)
logging.root.removeHandler(handler)
for socket in sockets: for socket in sockets:
try: try:
await socket.close(code=1012) await socket.close(code=1012)
@ -139,15 +108,14 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse:
authenticated = False authenticated = False
async def close_if_not_authenticated(): async def close_if_not_authenticated():
await asyncio.sleep(5) await asyncio.sleep(5, loop=get_loop())
if not authenticated: if not authenticated:
await ws.close(code=4000) await ws.close(code=4000)
log.debug(f"Connection from {request.remote} terminated due to no authentication") log.debug(f"Connection from {request.remote} terminated due to no authentication")
background_task.create(close_if_not_authenticated()) asyncio.ensure_future(close_if_not_authenticated())
try: try:
msg: web_ws.WSMessage
async for msg in ws: async for msg in ws:
if msg.type != web.WSMsgType.TEXT: if msg.type != web.WSMsgType.TEXT:
continue continue

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -17,10 +17,9 @@ import json
from aiohttp import web from aiohttp import web
from .auth import create_token from .base import routes, get_config
from .base import get_config, routes
from .responses import resp from .responses import resp
from .auth import create_token
@routes.post("/auth/login") @routes.post("/auth/login")
async def login(request: web.Request) -> web.Response: async def login(request: web.Request) -> web.Response:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,15 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable from typing import Callable, Awaitable
import base64
import logging import logging
from aiohttp import web from aiohttp import web
from .responses import resp
from .auth import check_token from .auth import check_token
from .base import get_config from .base import get_config
from .responses import resp
Handler = Callable[[web.Request], Awaitable[web.Response]] Handler = Callable[[web.Request], Awaitable[web.Response]]
log = logging.getLogger("maubot.server") log = logging.getLogger("maubot.server")
@ -29,13 +28,8 @@ log = logging.getLogger("maubot.server")
@web.middleware @web.middleware
async def auth(request: web.Request, handler: Handler) -> web.Response: async def auth(request: web.Request, handler: Handler) -> web.Response:
subpath = request.path[len("/_matrix/maubot/v1") :] subpath = request.path[len(get_config()["server.base_path"]):]
if ( if subpath.startswith("/auth/") or subpath == "/features" or subpath == "/logs":
subpath.startswith("/auth/")
or subpath.startswith("/client/auth_external_sso/complete/")
or subpath == "/features"
or subpath == "/logs"
):
return await handler(request) return await handler(request)
err = check_token(request) err = check_token(request)
if err is not None: if err is not None:
@ -52,18 +46,10 @@ async def error(request: web.Request, handler: Handler) -> web.Response:
return resp.path_not_found return resp.path_not_found
elif ex.status_code == 405: elif ex.status_code == 405:
return resp.method_not_allowed return resp.method_not_allowed
return web.json_response( return web.json_response({
{ "error": f"Unhandled HTTP {ex.status}",
"httpexception": { "errcode": f"unhandled_http_{ex.status}",
"headers": {key: value for key, value in ex.headers.items()}, }, status=ex.status)
"class": type(ex).__name__,
"body": ex.text or base64.b64encode(ex.body),
},
"error": f"Unhandled HTTP {ex.status}: {ex.text[:128] or 'non-text response'}",
"errcode": f"unhandled_http_{ex.status}",
},
status=ex.status,
)
except Exception: except Exception:
log.exception("Error in handler") log.exception("Error in handler")
return resp.internal_server_error return resp.internal_server_error

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -17,9 +17,9 @@ import traceback
from aiohttp import web from aiohttp import web
from ...loader import MaubotZipImportError, PluginLoader from ...loader import PluginLoader, MaubotZipImportError
from .base import routes
from .responses import resp from .responses import resp
from .base import routes
@routes.get("/plugins") @routes.get("/plugins")
@ -29,8 +29,8 @@ async def get_plugins(_) -> web.Response:
@routes.get("/plugin/{id}") @routes.get("/plugin/{id}")
async def get_plugin(request: web.Request) -> web.Response: async def get_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info["id"] plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id) plugin = PluginLoader.id_cache.get(plugin_id, None)
if not plugin: if not plugin:
return resp.plugin_not_found return resp.plugin_not_found
return resp.found(plugin.to_dict()) return resp.found(plugin.to_dict())
@ -38,8 +38,8 @@ async def get_plugin(request: web.Request) -> web.Response:
@routes.delete("/plugin/{id}") @routes.delete("/plugin/{id}")
async def delete_plugin(request: web.Request) -> web.Response: async def delete_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info["id"] plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id) plugin = PluginLoader.id_cache.get(plugin_id, None)
if not plugin: if not plugin:
return resp.plugin_not_found return resp.plugin_not_found
elif len(plugin.references) > 0: elif len(plugin.references) > 0:
@ -50,8 +50,8 @@ async def delete_plugin(request: web.Request) -> web.Response:
@routes.post("/plugin/{id}/reload") @routes.post("/plugin/{id}/reload")
async def reload_plugin(request: web.Request) -> web.Response: async def reload_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info["id"] plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id) plugin = PluginLoader.id_cache.get(plugin_id, None)
if not plugin: if not plugin:
return resp.plugin_not_found return resp.plugin_not_found

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -15,39 +15,27 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from io import BytesIO from io import BytesIO
from time import time from time import time
import logging import traceback
import os.path import os.path
import re import re
import traceback
from aiohttp import web from aiohttp import web
from packaging.version import Version from packaging.version import Version
from ...loader import DatabaseType, MaubotZipImportError, PluginLoader, ZippedPluginLoader from ...loader import PluginLoader, ZippedPluginLoader, MaubotZipImportError
from .base import get_config, routes
from .responses import resp from .responses import resp
from .base import routes, get_config
try:
import sqlalchemy
has_alchemy = True
except ImportError:
has_alchemy = False
log = logging.getLogger("maubot.server.upload")
@routes.put("/plugin/{id}") @routes.put("/plugin/{id}")
async def put_plugin(request: web.Request) -> web.Response: async def put_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info["id"] plugin_id = request.match_info.get("id", None)
content = await request.read() content = await request.read()
file = BytesIO(content) file = BytesIO(content)
try: try:
pid, version, db_type = ZippedPluginLoader.verify_meta(file) pid, version = ZippedPluginLoader.verify_meta(file)
except MaubotZipImportError as e: except MaubotZipImportError as e:
return resp.plugin_import_error(str(e), traceback.format_exc()) return resp.plugin_import_error(str(e), traceback.format_exc())
if db_type == DatabaseType.SQLALCHEMY and not has_alchemy:
return resp.sqlalchemy_not_installed
if pid != plugin_id: if pid != plugin_id:
return resp.pid_mismatch return resp.pid_mismatch
plugin = PluginLoader.id_cache.get(plugin_id, None) plugin = PluginLoader.id_cache.get(plugin_id, None)
@ -64,11 +52,9 @@ async def upload_plugin(request: web.Request) -> web.Response:
content = await request.read() content = await request.read()
file = BytesIO(content) file = BytesIO(content)
try: try:
pid, version, db_type = ZippedPluginLoader.verify_meta(file) pid, version = ZippedPluginLoader.verify_meta(file)
except MaubotZipImportError as e: except MaubotZipImportError as e:
return resp.plugin_import_error(str(e), traceback.format_exc()) return resp.plugin_import_error(str(e), traceback.format_exc())
if db_type == DatabaseType.SQLALCHEMY and not has_alchemy:
return resp.sqlalchemy_not_installed
plugin = PluginLoader.id_cache.get(pid, None) plugin = PluginLoader.id_cache.get(pid, None)
if not plugin: if not plugin:
return await upload_new_plugin(content, pid, version) return await upload_new_plugin(content, pid, version)
@ -92,20 +78,15 @@ async def upload_new_plugin(content: bytes, pid: str, version: Version) -> web.R
return resp.created(plugin.to_dict()) return resp.created(plugin.to_dict())
async def upload_replacement_plugin( async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
plugin: ZippedPluginLoader, content: bytes, new_version: Version new_version: Version) -> web.Response:
) -> web.Response:
dirname = os.path.dirname(plugin.path) dirname = os.path.dirname(plugin.path)
old_filename = os.path.basename(plugin.path) old_filename = os.path.basename(plugin.path)
if str(plugin.meta.version) in old_filename: if str(plugin.meta.version) in old_filename:
replacement = ( replacement = (str(new_version) if plugin.meta.version != new_version
str(new_version) else f"{new_version}-ts{int(time())}")
if plugin.meta.version != new_version filename = re.sub(f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?",
else f"{new_version}-ts{int(time() * 1000)}" replacement, old_filename)
)
filename = re.sub(
f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?", replacement, old_filename
)
else: else:
filename = old_filename.rstrip(".mbp") filename = old_filename.rstrip(".mbp")
filename = f"{filename}-v{new_version}.mbp" filename = f"{filename}-v{new_version}.mbp"
@ -117,29 +98,12 @@ async def upload_replacement_plugin(
try: try:
await plugin.reload(new_path=path) await plugin.reload(new_path=path)
except MaubotZipImportError as e: except MaubotZipImportError as e:
log.exception(f"Error loading updated version of {plugin.meta.id}, rolling back")
try: try:
await plugin.reload(new_path=old_path) await plugin.reload(new_path=old_path)
await plugin.start_instances() await plugin.start_instances()
except MaubotZipImportError: except MaubotZipImportError:
log.warning(f"Failed to roll back update of {plugin.meta.id}", exc_info=True) pass
finally:
ZippedPluginLoader.trash(path, reason="failed_update")
return resp.plugin_import_error(str(e), traceback.format_exc()) return resp.plugin_import_error(str(e), traceback.format_exc())
try: await plugin.start_instances()
await plugin.start_instances()
except Exception as e:
log.exception(f"Error starting {plugin.meta.id} instances after update, rolling back")
try:
await plugin.stop_instances()
await plugin.reload(new_path=old_path)
await plugin.start_instances()
except Exception:
log.warning(f"Failed to roll back update of {plugin.meta.id}", exc_info=True)
finally:
ZippedPluginLoader.trash(path, reason="failed_update")
return resp.plugin_reload_error(str(e), traceback.format_exc())
log.debug(f"Successfully updated {plugin.meta.id}, moving old version to trash")
ZippedPluginLoader.trash(old_path, reason="update") ZippedPluginLoader.trash(old_path, reason="update")
return resp.updated(plugin.to_dict()) return resp.updated(plugin.to_dict())

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan # Copyright (C) 2019 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,457 +13,271 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING
from http import HTTPStatus from http import HTTPStatus
from aiohttp import web from aiohttp import web
from asyncpg import PostgresError from sqlalchemy.exc import OperationalError, IntegrityError
import aiosqlite
if TYPE_CHECKING:
from sqlalchemy.exc import IntegrityError, OperationalError
class _Response: class _Response:
@property @property
def body_not_json(self) -> web.Response: def body_not_json(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Request body is not JSON",
"error": "Request body is not JSON", "errcode": "body_not_json",
"errcode": "body_not_json", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def plugin_type_required(self) -> web.Response: def plugin_type_required(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Plugin type is required when creating plugin instances",
"error": "Plugin type is required when creating plugin instances", "errcode": "plugin_type_required",
"errcode": "plugin_type_required", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def primary_user_required(self) -> web.Response: def primary_user_required(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Primary user is required when creating plugin instances",
"error": "Primary user is required when creating plugin instances", "errcode": "primary_user_required",
"errcode": "primary_user_required", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def bad_client_access_token(self) -> web.Response: def bad_client_access_token(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Invalid access token",
"error": "Invalid access token", "errcode": "bad_client_access_token",
"errcode": "bad_client_access_token", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def bad_client_access_details(self) -> web.Response: def bad_client_access_details(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Invalid homeserver or access token",
"error": "Invalid homeserver or access token", "errcode": "bad_client_access_details"
"errcode": "bad_client_access_details", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def bad_client_connection_details(self) -> web.Response: def bad_client_connection_details(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Could not connect to homeserver",
"error": "Could not connect to homeserver", "errcode": "bad_client_connection_details"
"errcode": "bad_client_connection_details", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
def mxid_mismatch(self, found: str) -> web.Response: def mxid_mismatch(self, found: str) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "The Matrix user ID of the client and the user ID of the access token don't "
"error": ( f"match. Access token is for user {found}",
"The Matrix user ID of the client and the user ID of the access token don't " "errcode": "mxid_mismatch",
f"match. Access token is for user {found}" }, status=HTTPStatus.BAD_REQUEST)
),
"errcode": "mxid_mismatch",
},
status=HTTPStatus.BAD_REQUEST,
)
def device_id_mismatch(self, found: str) -> web.Response:
return web.json_response(
{
"error": (
"The Matrix device ID of the client and the device ID of the access token "
f"don't match. Access token is for device {found}"
),
"errcode": "mxid_mismatch",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def pid_mismatch(self) -> web.Response: def pid_mismatch(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "The ID in the path does not match the ID of the uploaded plugin",
"error": "The ID in the path does not match the ID of the uploaded plugin", "errcode": "pid_mismatch",
"errcode": "pid_mismatch", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def username_or_password_missing(self) -> web.Response: def username_or_password_missing(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Username or password missing",
"error": "Username or password missing", "errcode": "username_or_password_missing",
"errcode": "username_or_password_missing", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def query_missing(self) -> web.Response: def query_missing(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Query missing",
"error": "Query missing", "errcode": "query_missing",
"errcode": "query_missing", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod
def sql_error(error: PostgresError | aiosqlite.Error, query: str) -> web.Response:
return web.json_response(
{
"ok": False,
"query": query,
"error": str(error),
"errcode": "sql_error",
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod @staticmethod
def sql_operational_error(error: OperationalError, query: str) -> web.Response: def sql_operational_error(error: OperationalError, query: str) -> web.Response:
return web.json_response( return web.json_response({
{ "ok": False,
"ok": False, "query": query,
"query": query, "error": str(error.orig),
"error": str(error.orig), "full_error": str(error),
"full_error": str(error), "errcode": "sql_operational_error",
"errcode": "sql_operational_error", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod @staticmethod
def sql_integrity_error(error: IntegrityError, query: str) -> web.Response: def sql_integrity_error(error: IntegrityError, query: str) -> web.Response:
return web.json_response( return web.json_response({
{ "ok": False,
"ok": False, "query": query,
"query": query, "error": str(error.orig),
"error": str(error.orig), "full_error": str(error),
"full_error": str(error), "errcode": "sql_integrity_error",
"errcode": "sql_integrity_error", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def bad_auth(self) -> web.Response: def bad_auth(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Invalid username or password",
"error": "Invalid username or password", "errcode": "invalid_auth",
"errcode": "invalid_auth", }, status=HTTPStatus.UNAUTHORIZED)
},
status=HTTPStatus.UNAUTHORIZED,
)
@property @property
def no_token(self) -> web.Response: def no_token(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Authorization token missing",
"error": "Authorization token missing", "errcode": "auth_token_missing",
"errcode": "auth_token_missing", }, status=HTTPStatus.UNAUTHORIZED)
},
status=HTTPStatus.UNAUTHORIZED,
)
@property @property
def invalid_token(self) -> web.Response: def invalid_token(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Invalid authorization token",
"error": "Invalid authorization token", "errcode": "auth_token_invalid",
"errcode": "auth_token_invalid", }, status=HTTPStatus.UNAUTHORIZED)
},
status=HTTPStatus.UNAUTHORIZED,
)
@property @property
def plugin_not_found(self) -> web.Response: def plugin_not_found(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Plugin not found",
"error": "Plugin not found", "errcode": "plugin_not_found",
"errcode": "plugin_not_found", }, status=HTTPStatus.NOT_FOUND)
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def client_not_found(self) -> web.Response: def client_not_found(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Client not found",
"error": "Client not found", "errcode": "client_not_found",
"errcode": "client_not_found", }, status=HTTPStatus.NOT_FOUND)
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def primary_user_not_found(self) -> web.Response: def primary_user_not_found(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Client for given primary user not found",
"error": "Client for given primary user not found", "errcode": "primary_user_not_found",
"errcode": "primary_user_not_found", }, status=HTTPStatus.NOT_FOUND)
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def instance_not_found(self) -> web.Response: def instance_not_found(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Plugin instance not found",
"error": "Plugin instance not found", "errcode": "instance_not_found",
"errcode": "instance_not_found", }, status=HTTPStatus.NOT_FOUND)
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def plugin_type_not_found(self) -> web.Response: def plugin_type_not_found(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Given plugin type not found",
"error": "Given plugin type not found", "errcode": "plugin_type_not_found",
"errcode": "plugin_type_not_found", }, status=HTTPStatus.NOT_FOUND)
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def path_not_found(self) -> web.Response: def path_not_found(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Resource not found",
"error": "Resource not found", "errcode": "resource_not_found",
"errcode": "resource_not_found", }, status=HTTPStatus.NOT_FOUND)
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def server_not_found(self) -> web.Response: def server_not_found(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Registration target server not found",
"error": "Registration target server not found", "errcode": "server_not_found",
"errcode": "server_not_found", }, status=HTTPStatus.NOT_FOUND)
},
status=HTTPStatus.NOT_FOUND,
)
@property
def registration_secret_not_found(self) -> web.Response:
return web.json_response(
{
"error": "Config does not have a registration secret for that server",
"errcode": "registration_secret_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property
def registration_no_sso(self) -> web.Response:
return web.json_response(
{
"error": "The register operation is only for registering with a password",
"errcode": "registration_no_sso",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def sso_not_supported(self) -> web.Response:
return web.json_response(
{
"error": "That server does not seem to support single sign-on",
"errcode": "sso_not_supported",
},
status=HTTPStatus.FORBIDDEN,
)
@property @property
def plugin_has_no_database(self) -> web.Response: def plugin_has_no_database(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Given plugin does not have a database",
"error": "Given plugin does not have a database", "errcode": "plugin_has_no_database",
"errcode": "plugin_has_no_database", })
}
)
@property
def unsupported_plugin_database(self) -> web.Response:
return web.json_response(
{
"error": "The database type is not supported by this API",
"errcode": "unsupported_plugin_database",
}
)
@property
def sqlalchemy_not_installed(self) -> web.Response:
return web.json_response(
{
"error": "This plugin requires a legacy database, but SQLAlchemy is not installed",
"errcode": "unsupported_plugin_database",
},
status=HTTPStatus.NOT_IMPLEMENTED,
)
@property @property
def table_not_found(self) -> web.Response: def table_not_found(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Given table not found in plugin database",
"error": "Given table not found in plugin database", "errcode": "table_not_found",
"errcode": "table_not_found", })
}
)
@property @property
def method_not_allowed(self) -> web.Response: def method_not_allowed(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Method not allowed",
"error": "Method not allowed", "errcode": "method_not_allowed",
"errcode": "method_not_allowed", }, status=HTTPStatus.METHOD_NOT_ALLOWED)
},
status=HTTPStatus.METHOD_NOT_ALLOWED,
)
@property @property
def user_exists(self) -> web.Response: def user_exists(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "There is already a client with the user ID of that token",
"error": "There is already a client with the user ID of that token", "errcode": "user_exists",
"errcode": "user_exists", }, status=HTTPStatus.CONFLICT)
},
status=HTTPStatus.CONFLICT,
)
@property @property
def plugin_exists(self) -> web.Response: def plugin_exists(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "A plugin with the same ID as the uploaded plugin already exists",
"error": "A plugin with the same ID as the uploaded plugin already exists", "errcode": "plugin_exists"
"errcode": "plugin_exists", }, status=HTTPStatus.CONFLICT)
},
status=HTTPStatus.CONFLICT,
)
@property @property
def plugin_in_use(self) -> web.Response: def plugin_in_use(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Plugin instances of this type still exist",
"error": "Plugin instances of this type still exist", "errcode": "plugin_in_use",
"errcode": "plugin_in_use", }, status=HTTPStatus.PRECONDITION_FAILED)
},
status=HTTPStatus.PRECONDITION_FAILED,
)
@property @property
def client_in_use(self) -> web.Response: def client_in_use(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Plugin instances with this client as their primary user still exist",
"error": "Plugin instances with this client as their primary user still exist", "errcode": "client_in_use",
"errcode": "client_in_use", }, status=HTTPStatus.PRECONDITION_FAILED)
},
status=HTTPStatus.PRECONDITION_FAILED,
)
@staticmethod @staticmethod
def plugin_import_error(error: str, stacktrace: str) -> web.Response: def plugin_import_error(error: str, stacktrace: str) -> web.Response:
return web.json_response( return web.json_response({
{ "error": error,
"error": error, "stacktrace": stacktrace,
"stacktrace": stacktrace, "errcode": "plugin_invalid",
"errcode": "plugin_invalid", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod @staticmethod
def plugin_reload_error(error: str, stacktrace: str) -> web.Response: def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
return web.json_response( return web.json_response({
{ "error": error,
"error": error, "stacktrace": stacktrace,
"stacktrace": stacktrace, "errcode": "plugin_reload_fail",
"errcode": "plugin_reload_fail", }, status=HTTPStatus.INTERNAL_SERVER_ERROR)
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
@property @property
def internal_server_error(self) -> web.Response: def internal_server_error(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Internal server error",
"error": "Internal server error", "errcode": "internal_server_error",
"errcode": "internal_server_error", }, status=HTTPStatus.INTERNAL_SERVER_ERROR)
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
@property @property
def invalid_server(self) -> web.Response: def invalid_server(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Invalid registration server object in maubot configuration",
"error": "Invalid registration server object in maubot configuration", "errcode": "invalid_server",
"errcode": "invalid_server", }, status=HTTPStatus.INTERNAL_SERVER_ERROR)
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
@property @property
def unsupported_plugin_loader(self) -> web.Response: def unsupported_plugin_loader(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Existing plugin with same ID uses unsupported plugin loader",
"error": "Existing plugin with same ID uses unsupported plugin loader", "errcode": "unsupported_plugin_loader",
"errcode": "unsupported_plugin_loader", }, status=HTTPStatus.BAD_REQUEST)
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def not_implemented(self) -> web.Response: def not_implemented(self) -> web.Response:
return web.json_response( return web.json_response({
{ "error": "Not implemented",
"error": "Not implemented", "errcode": "not_implemented",
"errcode": "not_implemented", }, status=HTTPStatus.NOT_IMPLEMENTED)
},
status=HTTPStatus.NOT_IMPLEMENTED,
)
@property @property
def ok(self) -> web.Response: def ok(self) -> web.Response:
return web.json_response( return web.json_response({
{"success": True}, "success": True,
status=HTTPStatus.OK, }, status=HTTPStatus.OK)
)
@property @property
def deleted(self) -> web.Response: def deleted(self) -> web.Response:
@ -473,15 +287,19 @@ class _Response:
def found(data: dict) -> web.Response: def found(data: dict) -> web.Response:
return web.json_response(data, status=HTTPStatus.OK) return web.json_response(data, status=HTTPStatus.OK)
@staticmethod def updated(self, data: dict) -> web.Response:
def updated(data: dict, is_login: bool = False) -> web.Response: return self.found(data)
return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK)
def logged_in(self, token: str) -> web.Response: def logged_in(self, token: str) -> web.Response:
return self.found({"token": token}) return self.found({
"token": token,
})
def pong(self, user: str, features: dict) -> web.Response: def pong(self, user: str, features: dict) -> web.Response:
return self.found({"username": user, "features": features}) return self.found({
"username": user,
"features": features,
})
@staticmethod @staticmethod
def created(data: dict) -> web.Response: def created(data: dict) -> web.Response:

View File

@ -1,2 +1,93 @@
# Maubot Management API # Maubot Management API
This document has been moved to docs.mau.fi: <https://docs.mau.fi/maubot/management-api.html> Most of the API is simple HTTP+JSON and has OpenAPI documentation (see
[spec.yaml](spec.yaml), [rendered](https://maubot.xyz/spec/)). However,
some parts of the API aren't documented in the OpenAPI document.
## Matrix API proxy
The full Matrix API can be accessed for each client with a request to
`/_matrix/maubot/v1/proxy/<user>/<path>`. `<user>` is the Matrix user
ID of the user to access the API as and `<path>` is the whole API
path to access (e.g. `/_matrix/client/r0/whoami`).
The body, headers, query parameters, etc are sent to the Matrix server
as-is, with a few exceptions:
* The `Authorization` header will be replaced with the access token
for the Matrix user from the maubot database.
* The `access_token` query parameter will be removed.
## Log viewing
1. Open websocket to `/_matrix/maubot/v1/logs`.
2. Send authentication token as a plain string.
3. Server will respond with `{"auth_success": true}` and then with
`{"history": ...}` where `...` is a list of log entries.
4. Server will send new log entries as JSON.
### Log entry object format
Log entries are a JSON-serialized form of Python log records.
Log entries will always have:
* `id` - A string that should uniquely identify the row. Currently
uses the `relativeCreated` field of Python logging records.
* `msg` - The text in the entry.
* `time` - The ISO date when the log entry was created.
Log entries should also always have:
* `levelname` - The log level (e.g. `DEBUG` or `ERROR`).
* `levelno` - The integer log level.
* `name` - The name of the logger. Common values:
* `maubot.client.<mxid>` - Client loggers (Matrix HTTP requests)
* `maubot.instance.<id>` - Plugin instance loggers
* `maubot.loader.zip` - The zip plugin loader (plugins don't
have their own logs)
* `module` - The Python module name where the log call happened.
* `pathname` - The full path of the file where the log call happened.
* `filename` - The file name portion of `pathname`
* `lineno` - The line in code where the log call happened.
* `funcName` - The name of the function where the log call happened.
Log entries might have:
* `exc_info` - The formatted exception info if an exception was logged.
* `matrix_http_request` - The info about a Matrix HTTP request. Subfields:
* `method` - The HTTP method used.
* `path` - The API path used.
* `content` - The content sent.
* `user` - The appservice user who the request was ran as.
## Debug file open
For debug and development purposes, the API and frontend support
clicking on lines in stack traces to open that line in the selected
editor.
### Configuration
First, the directory where maubot is run from must have a
`.dev-open-cfg.yaml` file. The file should contain the following
fields:
* `editor` - The command to run to open a file.
* `$path` is replaced with the full file path.
* `$line` is replaced with the line number.
* `pathmap` - A list of find-and-replaces to execute on paths.
These are needed to map things like `.mbp` files to the extracted
sources on disk. Each pathmap entry should have:
* `find` - The regex to match.
* `replace` - The replacement. May insert capture groups with Python
syntax (e.g. `\1`)
Example file:
```yaml
editor: pycharm --line $line $path
pathmap:
- find: "maubot/plugins/xyz\\.maubot\\.(.+)-v.+(?:-ts[0-9]+)?.mbp"
replace: "mbplugins/\\1"
- find: "maubot/.venv/lib/python3.6/site-packages/mautrix"
replace: "mautrix-python/mautrix"
```
### API
Clients can `GET /_matrix/maubot/v1/debug/open` to check if the file
open endpoint has been set up. The response is a JSON object with a
single field `enabled`. If the value is true, the endpoint can be used.
To open files, clients can `POST /_matrix/maubot/v1/debug/open` with
a JSON body containing
* `path` - The full file path to open
* `line` - The line number to open

View File

@ -366,7 +366,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/MatrixClient' $ref: '#/components/schemas/MatrixClient'
responses: responses:
202: 200:
description: Client updated description: Client updated
content: content:
application/json: application/json:
@ -454,12 +454,6 @@ paths:
required: true required: true
schema: schema:
type: string type: string
- name: update_client
in: query
description: Should maubot store the access details in a Client instead of returning them?
required: false
schema:
type: boolean
post: post:
operationId: client_auth_register operationId: client_auth_register
summary: | summary: |
@ -481,29 +475,18 @@ paths:
properties: properties:
access_token: access_token:
type: string type: string
example: syt_123_456_789 example: token_here
user_id: user_id:
type: string type: string
example: '@putkiteippi:maunium.net' example: '@putkiteippi:maunium.net'
home_server:
type: string
example: maunium.net
device_id: device_id:
type: string type: string
example: maubot_F00BAR12 example: device_id_here
201:
description: Client created (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
401: 401:
$ref: '#/components/responses/Unauthorized' $ref: '#/components/responses/Unauthorized'
409:
description: |
There is already a client with the user ID of that token.
This should usually not happen, because the user ID was just created.
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
500: 500:
$ref: '#/components/responses/MatrixServerError' $ref: '#/components/responses/MatrixServerError'
'/client/auth/{server}/login': '/client/auth/{server}/login':
@ -514,12 +497,6 @@ paths:
required: true required: true
schema: schema:
type: string type: string
- name: update_client
in: query
description: Should maubot store the access details in a Client instead of returning them?
required: false
schema:
type: boolean
post: post:
operationId: client_auth_login operationId: client_auth_login
summary: Log in to the given Matrix server via the maubot server summary: Log in to the given Matrix server via the maubot server
@ -542,22 +519,10 @@ paths:
example: '@putkiteippi:maunium.net' example: '@putkiteippi:maunium.net'
access_token: access_token:
type: string type: string
example: syt_123_456_789 example: token_here
device_id: device_id:
type: string type: string
example: maubot_F00BAR12 example: device_id_here
201:
description: Client created (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
202:
description: Client updated (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
401: 401:
$ref: '#/components/responses/Unauthorized' $ref: '#/components/responses/Unauthorized'
500: 500:
@ -676,12 +641,6 @@ components:
access_token: access_token:
type: string type: string
description: The Matrix access token for this client. description: The Matrix access token for this client.
device_id:
type: string
description: The Matrix device ID corresponding to the access token.
fingerprint:
type: string
description: The encryption device fingerprint for verification.
enabled: enabled:
type: boolean type: boolean
example: true example: true

View File

@ -1,6 +1,6 @@
{ {
"name": "maubot-manager", "name": "maubot-manager",
"version": "0.1.1", "version": "0.1.0",
"private": true, "private": true,
"author": "Tulir Asokan", "author": "Tulir Asokan",
"license": "AGPL-3.0-or-later", "license": "AGPL-3.0-or-later",
@ -13,15 +13,15 @@
}, },
"homepage": ".", "homepage": ".",
"dependencies": { "dependencies": {
"react": "^17.0.2", "node-sass": "^4.12.0",
"react-ace": "^9.4.1", "react": "^16.8.6",
"react-contextmenu": "^2.14.0", "react-ace": "^8.0.0",
"react-dom": "^17.0.2", "react-contextmenu": "^2.11.0",
"react-json-tree": "^0.16.1", "react-dom": "^16.8.6",
"react-router-dom": "^5.3.0", "react-json-tree": "^0.11.2",
"react-scripts": "5.0.0", "react-router-dom": "^5.0.1",
"react-select": "^5.2.1", "react-scripts": "3.3.0",
"sass": "^1.34.1" "react-select": "^3.0.4"
}, },
"scripts": { "scripts": {
"start": "react-scripts start", "start": "react-scripts start",
@ -30,11 +30,15 @@
"eject": "react-scripts eject" "eject": "react-scripts eject"
}, },
"browserslist": [ "browserslist": [
"last 2 firefox versions", "last 5 firefox versions",
"last 2 and_ff versions", "last 3 and_ff versions",
"last 2 chrome versions", "last 5 chrome versions",
"last 2 and_chr versions", "last 3 and_chr versions",
"last 1 safari versions", "last 2 safari versions",
"last 1 ios_saf versions" "last 2 ios_saf versions"
] ],
"devDependencies": {
"sass-lint": "^1.13.1",
"sass-lint-auto-fix": "^0.17.0"
}
} }

View File

@ -1,6 +1,6 @@
<!-- <!--
maubot - A plugin-based Matrix bot system. maubot - A plugin-based Matrix bot system.
Copyright (C) 2022 Tulir Asokan Copyright (C) 2019 Tulir Asokan
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2019 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -14,11 +14,7 @@
// You should have received a copy of the GNU Affero General Public License // You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>. // along with this program. If not, see <https://www.gnu.org/licenses/>.
let BASE_PATH = "/_matrix/maubot/v1" export const BASE_PATH = "/_matrix/maubot/v1"
export function setBasePath(basePath) {
BASE_PATH = basePath
}
function getHeaders(contentType = "application/json") { function getHeaders(contentType = "application/json") {
return { return {
@ -214,10 +210,10 @@ export async function uploadAvatar(id, data, mime) {
} }
export function getAvatarURL({ id, avatar_url }) { export function getAvatarURL({ id, avatar_url }) {
if (!avatar_url?.startsWith("mxc://")) { avatar_url = avatar_url || ""
return null if (avatar_url.startsWith("mxc://")) {
avatar_url = avatar_url.substr("mxc://".length)
} }
avatar_url = avatar_url.substr("mxc://".length)
return `${BASE_PATH}/proxy/${id}/_matrix/media/r0/download/${avatar_url}?access_token=${ return `${BASE_PATH}/proxy/${id}/_matrix/media/r0/download/${avatar_url}?access_token=${
localStorage.accessToken}` localStorage.accessToken}`
} }
@ -244,9 +240,9 @@ export async function doClientAuth(server, type, username, password) {
return await resp.json() return await resp.json()
} }
// eslint-disable-next-line import/no-anonymous-default-export
export default { export default {
login, ping, setBasePath, getFeatures, remoteGetFeatures, BASE_PATH,
login, ping, getFeatures, remoteGetFeatures,
openLogSocket, openLogSocket,
debugOpenFile, debugOpenFileEnabled, updateDebugOpenFileEnabled, debugOpenFile, debugOpenFileEnabled, updateDebugOpenFileEnabled,
getInstances, getInstance, putInstance, deleteInstance, getInstances, getInstance, putInstance, deleteInstance,

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2019 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -57,9 +57,7 @@ export const PrefSwitch = ({ rowName, active, origActive, fullWidth = false, ...
</PrefRow> </PrefRow>
) )
export const PrefSelect = ({ export const PrefSelect = ({ rowName, value, origValue, fullWidth = false, creatable = false, ...args }) => (
rowName, value, origValue, fullWidth = false, creatable = false, ...args
}) => (
<PrefRow name={rowName} fullWidth={fullWidth} labelFor={rowName} <PrefRow name={rowName} fullWidth={fullWidth} labelFor={rowName}
changed={origValue !== undefined && value.id !== origValue}> changed={origValue !== undefined && value.id !== origValue}>
{creatable {creatable

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2019 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2019 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -23,12 +23,10 @@ class Switch extends Component {
} }
} }
componentDidUpdate(prevProps) { componentWillReceiveProps(nextProps) {
if (prevProps.active !== this.props.active) { this.setState({
this.setState({ active: nextProps.active,
active: this.props.active, })
})
}
} }
toggle = () => { toggle = () => {

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2019 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2019 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2019 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -30,8 +30,7 @@ class Main extends Component {
} }
} }
async componentDidMount() { async componentWillMount() {
await this.getBasePath()
if (localStorage.accessToken) { if (localStorage.accessToken) {
await this.ping() await this.ping()
} else { } else {
@ -40,19 +39,6 @@ class Main extends Component {
this.setState({ pinged: true }) this.setState({ pinged: true })
} }
async getBasePath() {
try {
const resp = await fetch(process.env.PUBLIC_URL + "/paths.json", {
headers: { "Content-Type": "application/json" },
})
const apiPaths = await resp.json()
api.setBasePath(apiPaths.api_path)
} catch (err) {
console.error("Failed to get API path:", err)
}
}
async ping() { async ping() {
try { try {
const username = await api.ping() const username = await api.ping()

Some files were not shown because too many files have changed in this diff Show More