Compare commits

..

1 Commits

Author SHA1 Message Date
Tulir Asokan
92c9072257 Add WIP example of database usage in bot 2019-01-16 11:40:51 +02:00
156 changed files with 11183 additions and 13718 deletions

View File

@ -1,5 +0,0 @@
.editorconfig
.codeclimate.yml
*.png
.venv
maubot/management/frontend/node_modules

View File

@ -14,6 +14,3 @@ indent_size = 2
[spec.yaml]
indent_size = 2
[CHANGELOG.md]
max_line_length = 80

View File

@ -1,26 +0,0 @@
name: Python lint
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- uses: isort/isort-action@master
with:
sortPaths: "./maubot"
- uses: psf/black@stable
with:
src: "./maubot"
version: "24.2.0"
- name: pre-commit
run: |
pip install pre-commit
pre-commit run -av trailing-whitespace
pre-commit run -av end-of-file-fixer
pre-commit run -av check-yaml
pre-commit run -av check-added-large-files

5
.gitignore vendored
View File

@ -7,13 +7,10 @@ pip-selfcheck.json
*.pyc
__pycache__
*.db*
*.log
*.db
/*.yaml
!example-config.yaml
!.pre-commit-config.yaml
/start
logs/
plugins/
trash/

View File

@ -1,29 +0,0 @@
image: dock.mau.dev/maubot/maubot
stages:
- build
variables:
PYTHONPATH: /opt/maubot
build:
stage: build
except:
- tags
script:
- python3 -m maubot.cli build -o xyz.maubot.$CI_PROJECT_NAME-$CI_COMMIT_REF_NAME-$CI_COMMIT_SHORT_SHA.mbp
artifacts:
paths:
- "*.mbp"
expire_in: 365 days
build tags:
stage: build
only:
- tags
script:
- python3 -m maubot.cli build -o xyz.maubot.$CI_PROJECT_NAME-$CI_COMMIT_TAG.mbp
artifacts:
paths:
- "*.mbp"
expire_in: never

View File

@ -1,75 +0,0 @@
image: docker:stable
stages:
- build frontend
- build
- manifest
default:
before_script:
- docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
build frontend:
image: node:20-alpine
stage: build frontend
before_script: []
variables:
NODE_ENV: "production"
cache:
paths:
- maubot/management/frontend/node_modules
script:
- cd maubot/management/frontend
- yarn --prod
- yarn build
- mv build ../../../frontend
artifacts:
paths:
- frontend
expire_in: 1 hour
build amd64:
stage: build
tags:
- amd64
script:
- echo maubot/management/frontend >> .dockerignore
- docker pull $CI_REGISTRY_IMAGE:latest || true
- docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 . -f Dockerfile.ci
- docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64
- docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64
build arm64:
stage: build
tags:
- arm64
script:
- echo maubot/management/frontend >> .dockerignore
- docker pull $CI_REGISTRY_IMAGE:latest || true
- docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 . -f Dockerfile.ci
- docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64
- docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64
manifest:
stage: manifest
before_script:
- "mkdir -p $HOME/.docker && echo '{\"experimental\": \"enabled\"}' > $HOME/.docker/config.json"
- docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
script:
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64
- if [ "$CI_COMMIT_BRANCH" = "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:latest $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:latest; fi
- if [ "$CI_COMMIT_BRANCH" != "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME; fi
- docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64
build standalone amd64:
stage: build
tags:
- amd64
script:
- docker pull $CI_REGISTRY_IMAGE:standalone || true
- docker build --pull --cache-from $CI_REGISTRY_IMAGE:standalone --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone . -f maubot/standalone/Dockerfile
- if [ "$CI_COMMIT_BRANCH" = "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:standalone && docker push $CI_REGISTRY_IMAGE:standalone; fi
- if [ "$CI_COMMIT_BRANCH" != "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone && docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone; fi
- docker rmi $CI_REGISTRY_IMAGE:standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone || true

View File

@ -1,20 +0,0 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
language_version: python3
files: ^maubot/.*\.pyi?$
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
files: ^maubot/.*\.pyi?$

View File

@ -1,137 +0,0 @@
# v0.5.0 (2024-08-24)
* Dropped Python 3.9 support.
* Updated Docker image to Alpine 3.20.
* Updated mautrix-python to 0.20.6 to support authenticated media.
* Removed hard dependency on SQLAlchemy.
* Fixed `main_class` to default to being loaded from the last module instead of
the first if a module name is not explicitly specified.
* This was already the [documented behavior](https://docs.mau.fi/maubot/dev/reference/plugin-metadata.html),
and loading from the first module doesn't make sense due to import order.
* Added simple scheduler utility for running background tasks periodically or
after a certain delay.
* Added testing framework for plugins (thanks to [@abompard] in [#225]).
* Changed `mbc build` to ignore directories declared in `modules` that are
missing an `__init__.py` file.
* Importing the modules at runtime would fail and break the plugin.
To include non-code resources outside modules in the mbp archive,
use `extra_files` instead.
[#225]: https://github.com/maubot/maubot/issues/225
[@abompard]: https://github.com/abompard
# v0.4.2 (2023-09-20)
* Updated Pillow to 10.0.1.
* Updated Docker image to Alpine 3.18.
* Added logging for errors for /whoami errors when adding new bot accounts.
* Added support for using appservice tokens (including appservice encryption)
in standalone mode.
# v0.4.1 (2023-03-15)
* Added `in_thread` parameter to `evt.reply()` and `evt.respond()`.
* By default, responses will go to the thread if the command is in a thread.
* By setting the flag to `True` or `False`, the plugin can force the response
to either be or not be in a thread.
* Fixed static files like the frontend app manifest not being served correctly.
* Fixed `self.loader.meta` not being available to plugins in standalone mode.
* Updated to mautrix-python v0.19.6.
# v0.4.0 (2023-01-29)
* Dropped support for using a custom maubot API base path.
* The public URL can still have a path prefix, e.g. when using a reverse
proxy. Both the web interface and `mbc` CLI tool should work fine with
custom prefixes.
* Added `evt.redact()` as a shortcut for `self.client.redact(evt.room_id, evt.event_id)`.
* Fixed `mbc logs` command not working on Python 3.8+.
* Fixed saving plugin configs (broke in v0.3.0).
* Fixed SSO login using the wrong API path (probably broke in v0.3.0).
* Stopped using `cd` in the docker image's `mbc` wrapper to enable using
path-dependent commands like `mbc build` by mounting a directory.
* Updated Docker image to Alpine 3.17.
# v0.3.1 (2022-03-29)
* Added encryption dependencies to standalone dockerfile.
* Fixed running without encryption dependencies installed.
* Removed unnecessary imports that broke on SQLAlchemy 1.4+.
* Removed unused alembic dependency.
# v0.3.0 (2022-03-28)
* Dropped Python 3.7 support.
* Switched main maubot database to asyncpg/aiosqlite.
* Using the same SQLite database for crypto is now safe again.
* Added support for asyncpg/aiosqlite for plugin databases.
* There are some [basic docs](https://docs.mau.fi/maubot/dev/database/index.html)
and [a simple example](./examples/database) for the new system.
* The old SQLAlchemy system is now deprecated, but will be preserved for
backwards-compatibility until most plugins have updated.
* Started enforcing minimum maubot version in plugins.
* Trying to upload a plugin where the specified version is higher than the
running maubot version will fail.
* Fixed bug where uploading a plugin twice, deleting it and trying to upload
again would fail.
* Updated Docker image to Alpine 3.15.
* Formatted all code using [black](https://github.com/psf/black)
and [isort](https://github.com/PyCQA/isort).
# v0.2.1 (2021-11-22)
Docker-only release: added automatic moving of plugin databases from
`/data/plugins/*.db` to `/data/dbs`
# v0.2.0 (2021-11-20)
* Moved plugin databases from `/data/plugins` to `/data/dbs` in the docker image.
* v0.2.0 was missing the automatic migration of databases, it was added in v0.2.1.
* If you were using a custom path, you'll have to mount it at `/data/dbs` or
move the databases yourself.
* Removed support for pickle crypto store and added support for SQLite crypto store.
* **If you were previously using the dangerous pickle store for e2ee, you'll
have to re-login with the bots (which can now be done conveniently with
`mbc auth --update-client`).**
* Added SSO support to `mbc auth`.
* Added support for setting device ID for e2ee using the web interface.
* Added e2ee fingerprint field to the web interface.
* Added `--update-client` flag to store access token inside maubot instead of
returning it in `mbc auth`.
* This will also automatically store the device ID now.
* Updated standalone mode.
* Added e2ee and web server support.
* It's now officially supported and [somewhat documented](https://docs.mau.fi/maubot/usage/standalone.html).
* Replaced `_` with `-` when generating command name from function name.
* Replaced unmaintained PyInquirer dependency with questionary
(thanks to [@TinfoilSubmarine] in [#139]).
* Updated Docker image to Alpine 3.14.
* Fixed avatar URLs without the `mxc://` prefix appearing like they work in the
frontend, but not actually working when saved.
[@TinfoilSubmarine]: https://github.com/TinfoilSubmarine
[#139]: https://github.com/maubot/maubot/pull/139
# v0.1.2 (2021-06-12)
* Added `loader` instance property for plugins to allow reading files within
the plugin archive.
* Added support for reloading `webapp` and `database` meta flags in plugins.
Previously you had to restart maubot instead of just reloading the plugin
when enabling the webapp or database for the first time.
* Added warning log if a plugin uses `@web` decorators without enabling the
`webapp` meta flag.
* Updated frontend to latest React and dependency versions.
* Updated Docker image to Alpine 3.13.
* Fixed registering accounts with Synapse shared secret registration.
* Fixed plugins using `get_event` in encrypted rooms.
* Fixed using the `@command.new` decorator without specifying a name
(i.e. falling back to the function name).
# v0.1.1 (2021-05-02)
No changelog.
# v0.1.0 (2020-10-04)
Initial tagged release.

View File

@ -1,60 +1,28 @@
FROM node:20 AS frontend-builder
FROM node:10 AS frontend-builder
COPY ./maubot/management/frontend /frontend
RUN cd /frontend && yarn --prod && yarn build
FROM alpine:3.20
FROM alpine:3.8
ENV UID=1337 \
GID=1337
COPY . /opt/maubot
COPY --from=frontend-builder /frontend/build /opt/maubot/frontend
WORKDIR /opt/maubot
RUN apk add --no-cache \
python3 py3-pip py3-setuptools py3-wheel \
ca-certificates \
su-exec \
yq \
py3-aiohttp \
py3-sqlalchemy \
py3-attrs \
py3-bcrypt \
py3-cffi \
py3-ruamel.yaml \
py3-jinja2 \
py3-click \
py3-packaging \
py3-markdown \
py3-alembic \
py3-cssselect \
py3-commonmark \
py3-pygments \
py3-tz \
py3-regex \
py3-wcwidth \
# encryption
py3-cffi \
py3-olm \
py3-pycryptodome \
py3-unpaddedbase64 \
py3-future \
# plugin deps
py3-pillow \
py3-magic \
py3-feedparser \
py3-dateutil \
py3-lxml \
py3-semver
# TODO remove pillow, magic, feedparser, lxml, gitlab and semver when maubot supports installing dependencies
build-base \
python3-dev \
ca-certificates \
su-exec \
&& pip3 install -r requirements.txt
COPY requirements.txt /opt/maubot/requirements.txt
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
WORKDIR /opt/maubot
RUN apk add --virtual .build-deps python3-dev build-base git \
&& pip3 install --break-system-packages -r requirements.txt -r optional-requirements.txt \
dateparser langdetect python-gitlab pyquery tzlocal \
&& apk del .build-deps
# TODO also remove dateparser, langdetect and pyquery when maubot supports installing dependencies
COPY . /opt/maubot
RUN cp maubot/example-config.yaml .
COPY ./docker/mbc.sh /usr/local/bin/mbc
COPY --from=frontend-builder /frontend/build /opt/maubot/frontend
ENV UID=1337 GID=1337 XDG_CONFIG_HOME=/data
VOLUME /data
CMD ["/opt/maubot/docker/run.sh"]

View File

@ -1,55 +0,0 @@
FROM alpine:3.20
RUN apk add --no-cache \
python3 py3-pip py3-setuptools py3-wheel \
ca-certificates \
su-exec \
yq \
py3-aiohttp \
py3-attrs \
py3-bcrypt \
py3-cffi \
py3-ruamel.yaml \
py3-jinja2 \
py3-click \
py3-packaging \
py3-markdown \
py3-alembic \
# py3-cssselect \
py3-commonmark \
py3-pygments \
py3-tz \
# py3-tzlocal \
py3-regex \
py3-wcwidth \
# encryption
py3-cffi \
py3-olm \
py3-pycryptodome \
py3-unpaddedbase64 \
py3-future \
# plugin deps
py3-pillow \
py3-magic \
py3-feedparser \
py3-lxml
# py3-gitlab
# py3-semver
# TODO remove pillow, magic, feedparser, lxml, gitlab and semver when maubot supports installing dependencies
COPY requirements.txt /opt/maubot/requirements.txt
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
WORKDIR /opt/maubot
RUN apk add --virtual .build-deps python3-dev build-base git \
&& pip3 install --break-system-packages -r requirements.txt -r optional-requirements.txt \
dateparser langdetect python-gitlab pyquery semver tzlocal cssselect \
&& apk del .build-deps
# TODO also remove dateparser, langdetect and pyquery when maubot supports installing dependencies
COPY . /opt/maubot
RUN cp /opt/maubot/maubot/example-config.yaml /opt/maubot
COPY ./docker/mbc.sh /usr/local/bin/mbc
ENV UID=1337 GID=1337 XDG_CONFIG_HOME=/data
VOLUME /data
CMD ["/opt/maubot/docker/run.sh"]

View File

@ -1,5 +0,0 @@
include README.md
include CHANGELOG.md
include LICENSE
include requirements.txt
include optional-requirements.txt

View File

@ -1,29 +1,28 @@
# maubot
![Languages](https://img.shields.io/github/languages/top/maubot/maubot.svg)
[![License](https://img.shields.io/github/license/maubot/maubot.svg)](LICENSE)
[![Release](https://img.shields.io/github/release/maubot/maubot/all.svg)](https://github.com/maubot/maubot/releases)
[![GitLab CI](https://mau.dev/maubot/maubot/badges/master/pipeline.svg)](https://mau.dev/maubot/maubot/container_registry)
[![Code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Imports](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
A plugin-based [Matrix](https://matrix.org) bot system written in Python.
## Documentation
### [Wiki](https://github.com/maubot/maubot/wiki)
All setup and usage instructions are located on
[docs.mau.fi](https://docs.mau.fi/maubot/index.html). Some quick links:
* [Setup](https://docs.mau.fi/maubot/usage/setup/index.html)
(or [with Docker](https://docs.mau.fi/maubot/usage/setup/docker.html))
* [Basic usage](https://docs.mau.fi/maubot/usage/basic.html)
* [Encryption](https://docs.mau.fi/maubot/usage/encryption.html)
### [Management API spec](https://github.com/maubot/maubot/blob/master/maubot/management/api/spec.md)
## Discussion
Matrix room: [#maubot:maunium.net](https://matrix.to/#/#maubot:maunium.net)
## Plugins
A list of plugins can be found at [plugins.mau.bot](https://plugins.mau.bot/).
* [jesaribot](https://github.com/maubot/jesaribot) - A simple bot that replies with an image when you say "jesari".
* [sed](https://github.com/maubot/sed) - A bot to do sed-like replacements.
* [factorial](https://github.com/maubot/factorial) - A bot to calculate unexpected factorials.
* [media](https://github.com/maubot/media) - A bot that replies with the MXC URI of images you send it.
* [dice](https://github.com/maubot/dice) - A combined dice rolling and calculator bot.
* [karma](https://github.com/maubot/karma) - A user karma tracker bot.
* [xkcd](https://github.com/maubot/xkcd) - A bot to view xkcd comics.
* [echo](https://github.com/maubot/echo) - A bot that echoes pings and other stuff.
* [rss](https://github.com/maubot/rss) - A bot that posts RSS feed updates to Matrix.
To add your plugin to the list, send a pull request to <https://github.com/maubot/plugins.maubot.xyz>.
The plugin wishlist lives at <https://github.com/maubot/plugin-wishlist/issues>.
### Upcoming
* dictionary - A bot to get the dictionary definitions of words.
* poll - A simple poll bot.
* reminder - A bot to ping you about something after a certain amount of time.
* github - A GitHub client and webhook receiver bot.
* wolfram - A Wolfram Alpha bot
* gitlab - A GitLab client and webhook receiver bot.

View File

@ -1,3 +0,0 @@
pre-commit>=2.10.1,<3
isort>=5.10.1,<6
black>=24,<25

View File

@ -0,0 +1,92 @@
# The full URI to the database. SQLite and Postgres are fully supported.
# Other DBMSes supported by SQLAlchemy may or may not work.
# Format examples:
# SQLite: sqlite:///filename.db
# Postgres: postgres://username:password@hostname/dbname
database: sqlite:////data/maubot.db
plugin_directories:
# The directory where uploaded new plugins should be stored.
upload: /data/plugins
# The directories from which plugins should be loaded.
# Duplicate plugin IDs will be moved to the trash.
load:
- /data/plugins
# The directory where old plugin versions and conflicting plugins should be moved.
# Set to "delete" to delete files immediately.
trash: /data/trash
# The directory where plugin databases should be stored.
db: /data/plugins
server:
# The IP and port to listen to.
hostname: 0.0.0.0
port: 29316
# The base management API path.
base_path: /_matrix/maubot/v1
# The base path for the UI.
ui_base_path: /_matrix/maubot
# Override path from where to load UI resources.
# Set to false to using pkg_resources to find the path.
override_resource_path: /opt/maubot/frontend
# The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1.
appservice_base_path: /_matrix/app/v1
# The shared secret to sign API access tokens.
# Set to "generate" to generate and save a new token at startup.
unshared_secret: generate
# Shared registration secrets to allow registering new users from the management UI
registration_secrets:
example.com:
# Client-server API URL
url: https://example.com
# registration_shared_secret from synapse config
secret: synapse_shared_registration_secret
# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password
# to prevent normal login. Root is a special user that can't have a password and will always exist.
admins:
root: ""
# API feature switches.
api_features:
login: true
plugin: true
plugin_upload: true
instance: true
instance_database: true
client: true
client_proxy: true
client_auth: true
dev_open: true
log: true
# Python logging configuration.
#
# See section 16.7.2 of the Python documentation for more info:
# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema
logging:
version: 1
formatters:
precise:
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
handlers:
file:
class: logging.handlers.RotatingFileHandler
formatter: precise
filename: /var/log/maubot.log
maxBytes: 10485760
backupCount: 10
console:
class: logging.StreamHandler
formatter: precise
loggers:
maubot:
level: DEBUG
mautrix:
level: DEBUG
aiohttp:
level: INFO
root:
level: DEBUG
handlers: [file, console]

View File

@ -1,3 +0,0 @@
#!/bin/sh
export PYTHONPATH=/opt/maubot
python3 -m maubot.cli "$@"

View File

@ -1,46 +1,21 @@
#!/bin/sh
function fixperms {
chown -R $UID:$GID /var/log /data
}
function fixdefault {
_value=$(yq e "$1" /data/config.yaml)
if [[ "$_value" == "$2" ]]; then
yq e -i "$1 = "'"'"$3"'"' /data/config.yaml
fi
}
function fixconfig {
# Change relative default paths to absolute paths in /data
fixdefault '.database' 'sqlite:maubot.db' 'sqlite:/data/maubot.db'
fixdefault '.plugin_directories.upload' './plugins' '/data/plugins'
fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins'
fixdefault '.plugin_directories.trash' './trash' '/data/trash'
fixdefault '.plugin_databases.sqlite' './plugins' '/data/dbs'
fixdefault '.plugin_databases.sqlite' './dbs' '/data/dbs'
fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log'
# This doesn't need to be configurable
yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml
chown -R $UID:$GID /var/log /data /opt/maubot
}
cd /opt/maubot
mkdir -p /var/log/maubot /data/plugins /data/trash /data/dbs
if [ ! -f /data/config.yaml ]; then
cp example-config.yaml /data/config.yaml
cp docker/example-config.yaml /data/config.yaml
mkdir -p /var/log /data/plugins /data/trash /data/dbs
echo "Config file not found. Example config copied to /data/config.yaml"
echo "Please modify the config file to your liking and restart the container."
fixperms
fixconfig
exit
fi
mkdir -p /var/log/maubot /data/plugins /data/trash /data/dbs
#alembic -x config=/data/config.yaml upgrade head
fixperms
fixconfig
if ls /data/plugins/*.db > /dev/null 2>&1; then
mv -n /data/plugins/*.db /data/dbs/
fi
exec su-exec $UID:$GID python3 -m maubot -c /data/config.yaml
exec su-exec $UID:$GID python3 -m maubot -c /data/config.yaml -b docker/example-config.yaml

92
example-config.yaml Normal file
View File

@ -0,0 +1,92 @@
# The full URI to the database. SQLite and Postgres are fully supported.
# Other DBMSes supported by SQLAlchemy may or may not work.
# Format examples:
# SQLite: sqlite:///filename.db
# Postgres: postgres://username:password@hostname/dbname
database: sqlite:///maubot.db
plugin_directories:
# The directory where uploaded new plugins should be stored.
upload: ./plugins
# The directories from which plugins should be loaded.
# Duplicate plugin IDs will be moved to the trash.
load:
- ./plugins
# The directory where old plugin versions and conflicting plugins should be moved.
# Set to "delete" to delete files immediately.
trash: ./trash
# The directory where plugin databases should be stored.
db: ./plugins
server:
# The IP and port to listen to.
hostname: 0.0.0.0
port: 29316
# The base management API path.
base_path: /_matrix/maubot/v1
# The base path for the UI.
ui_base_path: /_matrix/maubot
# Override path from where to load UI resources.
# Set to false to using pkg_resources to find the path.
override_resource_path: false
# The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1.
appservice_base_path: /_matrix/app/v1
# The shared secret to sign API access tokens.
# Set to "generate" to generate and save a new token at startup.
unshared_secret: generate
# Shared registration secrets to allow registering new users from the management UI
registration_secrets:
example.com:
# Client-server API URL
url: https://example.com
# registration_shared_secret from synapse config
secret: synapse_shared_registration_secret
# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password
# to prevent normal login. Root is a special user that can't have a password and will always exist.
admins:
root: ""
# API feature switches.
api_features:
login: true
plugin: true
plugin_upload: true
instance: true
instance_database: true
client: true
client_proxy: true
client_auth: true
dev_open: true
log: true
# Python logging configuration.
#
# See section 16.7.2 of the Python documentation for more info:
# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema
logging:
version: 1
formatters:
precise:
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
handlers:
file:
class: logging.handlers.RotatingFileHandler
formatter: precise
filename: ./logs/maubot.log
maxBytes: 10485760
backupCount: 10
console:
class: logging.StreamHandler
formatter: precise
loggers:
maubot:
level: DEBUG
mautrix:
level: DEBUG
aiohttp:
level: INFO
root:
level: DEBUG
handlers: [file, console]

View File

@ -1,6 +1,6 @@
The MIT License (MIT)
Copyright (c) 2022 Tulir Asokan
Copyright (c) 2018 Tulir Asokan
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in

View File

@ -1,5 +1,2 @@
# Who is allowed to use the bot?
whitelist:
- "@user:example.com"
# The prefix for the main command without the !
command_prefix: hello-world
# Message to send when user sends !getmessage
message: Default configuration active

View File

@ -1,4 +1,5 @@
from typing import Type
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from maubot import Plugin, MessageEvent
from maubot.handlers import command
@ -6,22 +7,19 @@ from maubot.handlers import command
class Config(BaseProxyConfig):
def do_update(self, helper: ConfigUpdateHelper) -> None:
helper.copy("whitelist")
helper.copy("command_prefix")
helper.copy("message")
class ConfigurableBot(Plugin):
class DatabaseBot(Plugin):
async def start(self) -> None:
await super().start()
self.config.load_and_update()
def get_command_name(self) -> str:
return self.config["command_prefix"]
@command.new(name=get_command_name)
async def hmm(self, evt: MessageEvent) -> None:
if evt.sender in self.config["whitelist"]:
await evt.reply("You're whitelisted 🎉")
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]:
return Config
@command.new("getmessage")
async def handler(self, event: MessageEvent) -> None:
if event.sender != self.client.mxid:
await event.reply(self.config["message"])

View File

@ -1,12 +1,11 @@
maubot: 0.1.0
id: xyz.maubot.configurablebot
version: 2.0.0
id: xyz.maubot.databasebot
version: 1.0.0
license: MIT
modules:
- configurablebot
main_class: ConfigurableBot
database: false
config: true
# Instruct the build tool to include the base config.
extra_files:

View File

@ -0,0 +1,39 @@
from typing import Type
from sqlalchemy import Column, String, Text, Table, MetaData, orm, func
from mautrix.types import EventType
from maubot import Plugin, MessageEvent
from maubot.handlers import event, command
class DatabaseBot(Plugin):
db: orm.Session
events: Type[Table]
async def start(self) -> None:
await super().start()
db_factory = orm.sessionmaker(bind=self.database)
self.db = orm.scoped_session(db_factory)
table_meta = MetaData(bind=self.db)
self.events = Table("event", table_meta,
Column("room_id", String(255), primary_key=True),
Column("event_id", String(255), primary_key=True),
Column("sender", String(255)),
Column("body", Text))
# In the future, there will be a proper way to include Alembic upgrades in plugins.
table_meta.create_all()
@event.on(EventType.ROOM_MESSAGE)
async def handler(self, event: MessageEvent) -> None:
self.db.add(self.events(room_id=event.room_id, event_id=event.event_id,
sender=event.sender, body=event.content.body))
@command.new("stats")
async def find(self, _: MessageEvent) -> None:
res = (self.db
.query(func.sum(self.events.event_id))
.group_by(self.events.room_id, self.events.sender)
.all())
print(res)

View File

@ -1,10 +1,10 @@
maubot: 0.1.0
id: xyz.maubot.storagebot
version: 2.0.0
id: xyz.maubot.databasebot
version: 1.0.0
license: MIT
modules:
- storagebot
main_class: StorageBot
- databasebot
main_class: DatabaseBot
# This is required for a database to be available.
database: true
database_type: asyncpg
config: false

View File

@ -1,72 +0,0 @@
from __future__ import annotations
from mautrix.util.async_db import UpgradeTable, Connection
from maubot import Plugin, MessageEvent
from maubot.handlers import command
upgrade_table = UpgradeTable()
@upgrade_table.register(description="Initial revision")
async def upgrade_v1(conn: Connection) -> None:
await conn.execute(
"""CREATE TABLE stored_data (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
)"""
)
@upgrade_table.register(description="Remember user who added value")
async def upgrade_v2(conn: Connection) -> None:
await conn.execute("ALTER TABLE stored_data ADD COLUMN creator TEXT")
class StorageBot(Plugin):
@command.new()
async def storage(self, evt: MessageEvent) -> None:
pass
@storage.subcommand(help="Store a value")
@command.argument("key")
@command.argument("value", pass_raw=True)
async def put(self, evt: MessageEvent, key: str, value: str) -> None:
q = """
INSERT INTO stored_data (key, value, creator) VALUES ($1, $2, $3)
ON CONFLICT (key) DO UPDATE SET value=excluded.value, creator=excluded.creator
"""
await self.database.execute(q, key, value, evt.sender)
await evt.reply(f"Inserted {key} into the database")
@storage.subcommand(help="Get a value from the storage")
@command.argument("key")
async def get(self, evt: MessageEvent, key: str) -> None:
q = "SELECT key, value, creator FROM stored_data WHERE LOWER(key)=LOWER($1)"
row = await self.database.fetchrow(q, key)
if row:
key = row["key"]
value = row["value"]
creator = row["creator"]
await evt.reply(f"`{key}` stored by {creator}:\n\n```\n{value}\n```")
else:
await evt.reply(f"No data stored under `{key}` :(")
@storage.subcommand(help="List keys in the storage")
@command.argument("prefix", required=False)
async def list(self, evt: MessageEvent, prefix: str | None) -> None:
q = "SELECT key, creator FROM stored_data WHERE key LIKE $1"
rows = await self.database.fetch(q, prefix + "%")
prefix_reply = f" starting with `{prefix}`" if prefix else ""
if len(rows) == 0:
await evt.reply(f"Nothing{prefix_reply} stored in database :(")
else:
formatted_data = "\n".join(
f"* `{row['key']}` stored by {row['creator']}" for row in rows
)
await evt.reply(
f"Found {len(rows)} keys{prefix_reply} in database:\n\n{formatted_data}"
)
@classmethod
def get_db_upgrade_table(cls) -> UpgradeTable | None:
return upgrade_table

View File

@ -1,4 +1,2 @@
from .__meta__ import __version__
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent
from .plugin_base import Plugin
from .plugin_server import PluginWebApp
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,171 +13,96 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
import logging.config
import argparse
import asyncio
import signal
import copy
import sys
from mautrix.util.async_db import Database, DatabaseException, PostgresDatabase, Scheme
from mautrix.util.program import Program
from .__meta__ import __version__
from .client import Client
from .config import Config
from .db import init as init_db, upgrade_table
from .instance import PluginInstance
from .lib.future_awaitable import FutureAwaitable
from .lib.state_store import PgStateStore
from .loader.zip import init as init_zip_loader
from .management.api import init as init_mgmt_api
from .db import init as init_db
from .server import MaubotServer
from .client import Client, init as init_client_class
from .loader.zip import init as init_zip_loader
from .instance import init as init_plugin_instance_class
from .management.api import init as init_mgmt_api
from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
prog="python -m maubot")
parser.add_argument("-c", "--config", type=str, default="config.yaml",
metavar="<path>", help="the path to your config file")
parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml",
metavar="<path>", help="the path to the example config "
"(for automatic config updates)")
args = parser.parse_args()
config = Config(args.config, args.base_config)
config.load()
config.update()
logging.config.dictConfig(copy.deepcopy(config["logging"]))
stop_log_listener = None
if config["api_features.log"]:
from .management.api.log import init as init_log_listener, stop_all as stop_log_listener
init_log_listener()
log = logging.getLogger("maubot.init")
log.info(f"Initializing maubot {__version__}")
loop = asyncio.get_event_loop()
init_zip_loader(config)
db_session = init_db(config)
clients = init_client_class(db_session, loop)
plugins = init_plugin_instance_class(db_session, config, loop)
management_api = init_mgmt_api(config, loop)
server = MaubotServer(config, loop)
server.app.add_subapp(config["server.base_path"], management_api)
for plugin in plugins:
plugin.load()
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
async def periodic_commit():
while True:
await asyncio.sleep(60)
db_session.commit()
periodic_commit_task: asyncio.Future = None
try:
from mautrix.crypto.store import PgCryptoStore
except ImportError:
PgCryptoStore = None
class Maubot(Program):
config: Config
server: MaubotServer
db: Database
crypto_db: Database | None
plugin_postgres_db: PostgresDatabase | None
state_store: PgStateStore
config_class = Config
module = "maubot"
name = "maubot"
version = __version__
command = "python -m maubot"
description = "A plugin-based Matrix bot system."
def prepare_log_websocket(self) -> None:
from .management.api.log import init, stop_all
init(self.loop)
self.add_shutdown_actions(FutureAwaitable(stop_all))
def prepare_arg_parser(self) -> None:
super().prepare_arg_parser()
self.parser.add_argument(
"--ignore-unsupported-database",
action="store_true",
help="Run even if the database schema is too new",
)
self.parser.add_argument(
"--ignore-foreign-tables",
action="store_true",
help="Run even if the database contains tables from other programs (like Synapse)",
)
def prepare_db(self) -> None:
self.db = Database.create(
self.config["database"],
upgrade_table=upgrade_table,
db_args=self.config["database_opts"],
owner_name=self.name,
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
init_db(self.db)
if PgCryptoStore:
if self.config["crypto_database"] == "default":
self.crypto_db = self.db
else:
self.crypto_db = Database.create(
self.config["crypto_database"],
upgrade_table=PgCryptoStore.upgrade_table,
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
else:
self.crypto_db = None
if self.config["plugin_databases.postgres"] == "default":
if self.db.scheme != Scheme.POSTGRES:
self.log.critical(
'Using "default" as the postgres plugin database URL is only allowed if '
"the default database is postgres."
)
sys.exit(24)
assert isinstance(self.db, PostgresDatabase)
self.plugin_postgres_db = self.db
elif self.config["plugin_databases.postgres"]:
plugin_db = Database.create(
self.config["plugin_databases.postgres"],
db_args={
**self.config["database_opts"],
**self.config["plugin_databases.postgres_opts"],
},
)
if plugin_db.scheme != Scheme.POSTGRES:
self.log.critical("The plugin postgres database URL must be a postgres database")
sys.exit(24)
assert isinstance(plugin_db, PostgresDatabase)
self.plugin_postgres_db = plugin_db
else:
self.plugin_postgres_db = None
def prepare(self) -> None:
super().prepare()
if self.config["api_features.log"]:
self.prepare_log_websocket()
init_zip_loader(self.config)
self.prepare_db()
Client.init_cls(self)
PluginInstance.init_cls(self)
management_api = init_mgmt_api(self.config, self.loop)
self.server = MaubotServer(management_api, self.config, self.loop)
self.state_store = PgStateStore(self.db)
async def start_db(self) -> None:
self.log.debug("Starting database...")
ignore_unsupported = self.args.ignore_unsupported_database
self.db.upgrade_table.allow_unsupported = ignore_unsupported
self.state_store.upgrade_table.allow_unsupported = ignore_unsupported
log.info("Starting server")
loop.run_until_complete(server.start())
log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop))
log.info("Startup actions complete, running forever")
periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
loop.run_forever()
except KeyboardInterrupt:
log.info("Interrupt received, stopping HTTP clients/servers and saving database")
if periodic_commit_task is not None:
periodic_commit_task.cancel()
log.debug("Stopping clients")
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()],
loop=loop))
db_session.commit()
if stop_log_listener is not None:
log.debug("Closing websockets")
loop.run_until_complete(stop_log_listener())
log.debug("Stopping server")
try:
await self.db.start()
await self.state_store.upgrade_table.upgrade(self.db)
if self.plugin_postgres_db and self.plugin_postgres_db is not self.db:
await self.plugin_postgres_db.start()
if self.crypto_db:
PgCryptoStore.upgrade_table.allow_unsupported = ignore_unsupported
if self.crypto_db is not self.db:
await self.crypto_db.start()
else:
await PgCryptoStore.upgrade_table.upgrade(self.db)
except DatabaseException as e:
self.log.critical("Failed to initialize database", exc_info=e)
if e.explanation:
self.log.info(e.explanation)
sys.exit(25)
async def system_exit(self) -> None:
if hasattr(self, "db"):
self.log.trace("Stopping database due to SystemExit")
await self.db.stop()
async def start(self) -> None:
await self.start_db()
await asyncio.gather(*[plugin.load() async for plugin in PluginInstance.all()])
await asyncio.gather(*[client.start() async for client in Client.all()])
await super().start()
async for plugin in PluginInstance.all():
await plugin.load()
await self.server.start()
async def stop(self) -> None:
self.add_shutdown_actions(*(client.stop() for client in Client.cache.values()))
await super().stop()
self.log.debug("Stopping server")
try:
await asyncio.wait_for(self.server.stop(), 5)
loop.run_until_complete(asyncio.wait_for(server.stop(), 5, loop=loop))
except asyncio.TimeoutError:
self.log.warning("Stopping server timed out")
await self.db.stop()
Maubot().run()
log.warning("Stopping server timed out")
log.debug("Closing event loop")
loop.close()
log.debug("Everything stopped, shutting down")
sys.exit(0)

View File

@ -1 +1 @@
__version__ = "0.5.0"
__version__ = "0.1.0.dev15"

View File

@ -1,3 +0,0 @@
from . import app
app()

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by

View File

@ -1,2 +1,2 @@
from .cliq import command, option
from .validators import PathValidator, SPDXValidator, VersionValidator
from .validators import SPDXValidator, VersionValidator, PathValidator

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,104 +13,42 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Any, Callable
import asyncio
from typing import Any, Callable, Union, Optional
import functools
import inspect
import traceback
from colorama import Fore
from prompt_toolkit.validation import Validator
from questionary import prompt
import aiohttp
from PyInquirer import prompt
import click
from ..base import app
from ..config import get_token
from .validators import ClickValidator, Required
def with_http(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
async with aiohttp.ClientSession() as sess:
try:
return await func(*args, sess=sess, **kwargs)
except aiohttp.ClientError as e:
print(f"{Fore.RED}Connection error: {e}{Fore.RESET}")
return wrapper
def with_authenticated_http(func):
@functools.wraps(func)
async def wrapper(*args, server: str, **kwargs):
server, token = get_token(server)
if not token:
return
async with aiohttp.ClientSession(headers={"Authorization": f"Bearer {token}"}) as sess:
try:
return await func(*args, sess=sess, server=server, **kwargs)
except aiohttp.ClientError as e:
print(f"{Fore.RED}Connection error: {e}{Fore.RESET}")
return wrapper
from .validators import Required, ClickValidator
def command(help: str) -> Callable[[Callable], Callable]:
def decorator(func) -> Callable:
questions = getattr(func, "__inquirer_questions__", {}).copy()
questions = func.__inquirer_questions__.copy()
@functools.wraps(func)
def wrapper(*args, **kwargs):
for key, value in kwargs.items():
if key not in questions:
continue
if value is not None and (questions[key]["type"] != "confirm" or value != "null"):
questions.pop(key, None)
try:
required_unless = questions[key].pop("required_unless")
if isinstance(required_unless, str) and kwargs[required_unless]:
questions.pop(key)
elif isinstance(required_unless, list):
for v in required_unless:
if kwargs[v]:
questions.pop(key)
break
elif isinstance(required_unless, dict):
for k, v in required_unless.items():
if kwargs.get(v, object()) == v:
questions.pop(key)
break
except KeyError:
pass
question_list = list(questions.values())
question_list.reverse()
resp = prompt(question_list, kbi_msg="Aborted!")
resp = prompt(question_list, keyboard_interrupt_msg="Aborted!")
if not resp and question_list:
return
kwargs = {**kwargs, **resp}
try:
res = func(*args, **kwargs)
if inspect.isawaitable(res):
asyncio.run(res)
except Exception:
print(Fore.RED + "Fatal error running command" + Fore.RESET)
traceback.print_exc()
func(*args, **kwargs)
return app.command(help=help)(wrapper)
return decorator
def yesno(val: str) -> bool | None:
def yesno(val: str) -> Optional[bool]:
if not val:
return None
elif isinstance(val, bool):
return val
elif val.lower() in ("true", "t", "yes", "y"):
return True
elif val.lower() in ("false", "f", "no", "n"):
@ -120,49 +58,33 @@ def yesno(val: str) -> bool | None:
yesno.__name__ = "yes/no"
def option(
short: str,
long: str,
message: str = None,
help: str = None,
click_type: str | Callable[[str], Any] = None,
inq_type: str = None,
validator: type[Validator] = None,
required: bool = False,
default: str | bool | None = None,
is_flag: bool = False,
prompt: bool = True,
required_unless: str | list | dict = None,
) -> Callable[[Callable], Callable]:
def option(short: str, long: str, message: str = None, help: str = None,
click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None,
validator: Validator = None, required: bool = False, default: str = None,
is_flag: bool = False) -> Callable[[Callable], Callable]:
if not message:
message = long[2].upper() + long[3:]
if isinstance(validator, type) and issubclass(validator, ClickValidator):
click_type = validator.click_type
click_type = validator.click_type if isinstance(validator, ClickValidator) else click_type
if is_flag:
click_type = yesno
def decorator(func) -> Callable:
click.option(short, long, help=help, type=click_type)(func)
if not prompt:
return func
if not hasattr(func, "__inquirer_questions__"):
func.__inquirer_questions__ = {}
q = {
"type": (
inq_type if isinstance(inq_type, str) else ("input" if not is_flag else "confirm")
),
"type": (inq_type if isinstance(inq_type, str)
else ("input" if not is_flag
else "confirm")),
"name": long[2:],
"message": message,
}
if required_unless is not None:
q["required_unless"] = required_unless
if default is not None:
q["default"] = default
if required or required_unless is not None:
q["validate"] = Required(validator)
if required:
q["validator"] = Required(validator)
elif validator:
q["validate"] = validator
q["validator"] = validator
func.__inquirer_questions__[long[2:]] = q
return func

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -16,9 +16,9 @@
from typing import Callable
import os
from packaging.version import InvalidVersion, Version
from packaging.version import Version, InvalidVersion
from prompt_toolkit.validation import Validator, ValidationError
from prompt_toolkit.document import Document
from prompt_toolkit.validation import ValidationError, Validator
import click
from ..util import spdx as spdxlib
@ -76,7 +76,7 @@ class VersionValidator(ClickValidator):
def spdx(val: str) -> str:
if not spdxlib.valid(val):
if spdxlib.valid(val):
raise click.BadParameter(f"{val} is not a valid SPDX license identifier")
return val

View File

@ -1 +1 @@
from . import auth, build, init, login, logs, upload
from . import upload, build, login, init, logs

View File

@ -1,166 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import json
import webbrowser
from colorama import Fore
from yarl import URL
import aiohttp
import click
from ..cliq import cliq
history_count: int = 10
friendly_errors = {
"server_not_found": (
"Registration target server not found.\n\n"
"To log in or register through maubot, you must add the server to the\n"
"homeservers section in the config. If you only want to log in,\n"
"leave the `secret` field empty."
),
"registration_no_sso": (
"The register operation is only for registering with a password.\n\n"
"To register with SSO, simply leave out the --register flag."
),
}
async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
url = URL(server) / "_matrix/maubot/v1/client/auth/servers"
async with sess.get(url) as resp:
data = await resp.json()
print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}")
for server in data.keys():
print(f"* {Fore.CYAN}{server}{Fore.RESET}")
@cliq.command(help="Log into a Matrix account via the Maubot server")
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
@cliq.option(
"-u", "--username", help="The username to log in with", required_unless=["list", "sso"]
)
@cliq.option(
"-p",
"--password",
help="The password to log in with",
inq_type="password",
required_unless=["list", "sso"],
)
@cliq.option(
"-s",
"--server",
help="The maubot instance to log in through",
default="",
required=False,
prompt=False,
)
@click.option(
"-r", "--register", help="Register instead of logging in", is_flag=True, default=False
)
@click.option(
"-c",
"--update-client",
help="Instead of returning the access token, create or update a client in maubot using it",
is_flag=True,
default=False,
)
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
@click.option(
"-o", "--sso", help="Use single sign-on instead of password login", is_flag=True, default=False
)
@click.option(
"-n",
"--device-name",
help="The initial e2ee device displayname (only for login)",
default="Maubot",
required=False,
)
@cliq.with_authenticated_http
async def auth(
homeserver: str,
username: str,
password: str,
server: str,
register: bool,
list: bool,
update_client: bool,
device_name: str,
sso: bool,
sess: aiohttp.ClientSession,
) -> None:
if list:
await list_servers(server, sess)
return
endpoint = "register" if register else "login"
url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint
if update_client:
url = url.update_query({"update_client": "true"})
if sso:
url = url.update_query({"sso": "true"})
req_data = {"device_name": device_name}
else:
req_data = {"username": username, "password": password, "device_name": device_name}
async with sess.post(url, json=req_data) as resp:
if not 200 <= resp.status < 300:
await print_error(resp, is_register=register)
elif sso:
await wait_sso(resp, sess, server, homeserver)
else:
await print_response(resp, is_register=register)
async def wait_sso(
resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession, server: str, homeserver: str
) -> None:
data = await resp.json()
sso_url, reg_id = data["sso_url"], data["id"]
print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}")
webbrowser.open(sso_url, autoraise=True)
print(f"{Fore.GREEN}Waiting for login token...{Fore.RESET}")
wait_url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / "sso" / reg_id / "wait"
async with sess.post(wait_url, json={}) as resp:
await print_response(resp, is_register=False)
async def print_response(resp: aiohttp.ClientResponse, is_register: bool) -> None:
if resp.status == 200:
data = await resp.json()
action = "registered" if is_register else "logged in as"
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
elif resp.status in (201, 202):
data = await resp.json()
action = "created" if resp.status == 201 else "updated"
print(
f"{Fore.GREEN}Successfully {action} client for "
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}"
)
else:
await print_error(resp, is_register)
async def print_error(resp: aiohttp.ClientResponse, is_register: bool) -> None:
try:
err_data = await resp.json()
error = friendly_errors.get(err_data["errcode"], err_data["error"])
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
error = await resp.text()
action = "register" if is_register else "log in"
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,28 +13,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import IO
from typing import Optional, Union, IO
from io import BytesIO
import asyncio
import glob
import os
import zipfile
import os
from aiohttp import ClientSession
from colorama import Fore
from questionary import prompt
from mautrix.client.api.types.util import SerializerError
from ruamel.yaml import YAML, YAMLError
from colorama import Fore
from PyInquirer import prompt
import click
from mautrix.types import SerializerError
from ...loader import PluginMeta
from ..base import app
from ..cliq import cliq
from ..cliq.validators import PathValidator
from ..config import get_token
from ..base import app
from ..config import get_default_server, get_token
from .upload import upload_file
yaml = YAML()
@ -46,7 +39,7 @@ def zipdir(zip, dir):
zip.write(os.path.join(root, file))
def read_meta(path: str) -> PluginMeta | None:
def read_meta(path: str) -> Optional[PluginMeta]:
try:
with open(os.path.join(path, "maubot.yaml")) as meta_file:
try:
@ -67,7 +60,7 @@ def read_meta(path: str) -> PluginMeta | None:
return meta
def read_output_path(output: str, meta: PluginMeta) -> str | None:
def read_output_path(output: str, meta: PluginMeta) -> Optional[str]:
directory = os.getcwd()
filename = f"{meta.id}-v{meta.version}.mbp"
if not output:
@ -75,15 +68,18 @@ def read_output_path(output: str, meta: PluginMeta) -> str | None:
elif os.path.isdir(output):
output = os.path.join(output, filename)
elif os.path.exists(output):
q = [{"type": "confirm", "name": "override", "message": f"{output} exists, override?"}]
override = prompt(q)["override"]
override = prompt({
"type": "confirm",
"name": "override",
"message": f"{output} exists, override?"
})["override"]
if not override:
return None
os.remove(output)
return os.path.abspath(output)
def write_plugin(meta: PluginMeta, output: str | IO) -> None:
def write_plugin(meta: PluginMeta, output: Union[str, IO]) -> None:
with zipfile.ZipFile(output, "w") as zip:
meta_dump = BytesIO()
yaml.dump(meta.serialize(), meta_dump)
@ -92,48 +88,37 @@ def write_plugin(meta: PluginMeta, output: str | IO) -> None:
for module in meta.modules:
if os.path.isfile(f"{module}.py"):
zip.write(f"{module}.py")
elif module is not None and os.path.isdir(module):
if os.path.isfile(f"{module}/__init__.py"):
elif os.path.isdir(module):
zipdir(zip, module)
else:
print(
Fore.YELLOW
+ f"Module {module} is missing __init__.py, skipping"
+ Fore.RESET
)
else:
print(Fore.YELLOW + f"Module {module} not found, skipping" + Fore.RESET)
for pattern in meta.extra_files:
for file in glob.iglob(pattern):
for file in meta.extra_files:
zip.write(file)
@cliq.with_authenticated_http
async def upload_plugin(output: str | IO, *, server: str, sess: ClientSession) -> None:
server, token = get_token(server)
def upload_plugin(output: Union[str, IO], server: str) -> None:
if not server:
server, token = get_default_server()
else:
token = get_token(server)
if not token:
return
if isinstance(output, str):
with open(output, "rb") as file:
await upload_file(sess, file, server)
upload_file(file, server, token)
else:
await upload_file(sess, output, server)
upload_file(output, server, token)
@app.command(
short_help="Build a maubot plugin",
help=(
"Build a maubot plugin. First parameter is the path to root of the plugin "
"to build. You can also use --output to specify output file."
),
)
@app.command(short_help="Build a maubot plugin",
help="Build a maubot plugin. First parameter is the path to root of the plugin "
"to build. You can also use --output to specify output file.")
@click.argument("path", default=os.getcwd())
@click.option(
"-o", "--output", help="Path to output built plugin to", type=PathValidator.click_type
)
@click.option(
"-u", "--upload", help="Upload plugin to server after building", is_flag=True, default=False
)
@click.option("-o", "--output", help="Path to output built plugin to",
type=PathValidator.click_type)
@click.option("-u", "--upload", help="Upload plugin to server after building", is_flag=True,
default=False)
@click.option("-s", "--server", help="Server to upload built plugin to")
def build(path: str, output: str, upload: bool, server: str) -> None:
meta = read_meta(path)
@ -152,4 +137,4 @@ def build(path: str, output: str, upload: bool, server: str) -> None:
else:
output.seek(0)
if upload:
asyncio.run(upload_plugin(output, server=server))
upload_plugin(output, server)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,11 +13,11 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from pkg_resources import resource_string
import os
from jinja2 import Template
from packaging.version import Version
from pkg_resources import resource_string
from jinja2 import Template
from .. import cliq
from ..cliq import SPDXValidator, VersionValidator
@ -40,55 +40,25 @@ def load_templates():
@cliq.command(help="Initialize a new maubot plugin")
@cliq.option(
"-n",
"--name",
help="The name of the project",
required=True,
default=os.path.basename(os.getcwd()),
)
@cliq.option(
"-i",
"--id",
message="ID",
required=True,
help="The maubot plugin ID (Java package name format)",
)
@cliq.option(
"-v",
"--version",
help="Initial version for project (PEP-440 format)",
default="0.1.0",
validator=VersionValidator,
required=True,
)
@cliq.option(
"-l",
"--license",
validator=SPDXValidator,
default="AGPL-3.0-or-later",
help="The license for the project (SPDX identifier)",
required=False,
)
@cliq.option(
"-c",
"--config",
message="Should the plugin include a config?",
help="Include a config in the plugin stub",
default=False,
is_flag=True,
)
@cliq.option("-n", "--name", help="The name of the project", required=True,
default=os.path.basename(os.getcwd()))
@cliq.option("-i", "--id", message="ID", required=True,
help="The maubot plugin ID (Java package name format)")
@cliq.option("-v", "--version", help="Initial version for project (PEP-440 format)",
default="0.1.0", validator=VersionValidator, required=True)
@cliq.option("-l", "--license", validator=SPDXValidator, default="AGPL-3.0-or-later",
help="The license for the project (SPDX identifier)", required=False)
@cliq.option("-c", "--config", message="Should the plugin include a config?",
help="Include a config in the plugin stub", default=False, is_flag=True)
def init(name: str, id: str, version: Version, license: str, config: bool) -> None:
load_templates()
main_class = name[0].upper() + name[1:]
meta = meta_template.render(
id=id, version=str(version), license=license, config=config, main_class=main_class
)
meta = meta_template.render(id=id, version=str(version), license=license, config=config,
main_class=main_class)
with open("maubot.yaml", "w") as file:
file.write(meta)
if license:
with open("LICENSE", "w") as file:
file.write(spdx.get(license)["licenseText"])
file.write(spdx.get(license)["text"])
if not os.path.isdir(name):
os.mkdir(name)
mod = mod_template.render(config=config, name=main_class)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,65 +13,37 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from urllib.request import urlopen
from urllib.error import HTTPError
import json
import os
from colorama import Fore
from yarl import URL
import aiohttp
from ..config import save_config, config
from ..cliq import cliq
from ..config import config, save_config
@cliq.command(help="Log in to a Maubot instance")
@cliq.option(
"-u",
"--username",
help="The username of your account",
default=os.environ.get("USER", None),
required=True,
)
@cliq.option(
"-p", "--password", help="The password to your account", inq_type="password", required=True
)
@cliq.option(
"-s",
"--server",
help="The server to log in to",
default="http://localhost:29316",
required=True,
)
@cliq.option(
"-a",
"--alias",
help="Alias to reference the server without typing the full URL",
default="",
required=False,
)
@cliq.with_http
async def login(
server: str, username: str, password: str, alias: str, sess: aiohttp.ClientSession
) -> None:
@cliq.option("-u", "--username", help="The username of your account", default=os.environ.get("USER", None), required=True)
@cliq.option("-p", "--password", help="The password to your account", inq_type="password", required=True)
@cliq.option("-s", "--server", help="The server to log in to", default="http://localhost:29316", required=True)
def login(server, username, password) -> None:
data = {
"username": username,
"password": password,
}
url = URL(server) / "_matrix/maubot/v1/auth/login"
async with sess.post(url, json=data) as resp:
if resp.status == 200:
data = await resp.json()
config["servers"][server] = data["token"]
if not config["default_server"]:
print(Fore.CYAN, "Setting", server, "as the default server")
try:
with urlopen(f"{server}/_matrix/maubot/v1/auth/login",
data=json.dumps(data).encode("utf-8")) as resp_data:
resp = json.load(resp_data)
config["servers"][server] = resp["token"]
config["default_server"] = server
if alias:
config["aliases"][alias] = server
save_config()
print(Fore.GREEN + "Logged in successfully")
else:
except HTTPError as e:
try:
err = (await resp.json())["error"]
except (json.JSONDecodeError, KeyError):
err = await resp.text()
print(Fore.RED + err + Fore.RESET)
err = json.load(e)
except json.JSONDecodeError:
err = {}
print(Fore.RED + err.get("error", str(e)) + Fore.RESET)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -16,14 +16,13 @@
from datetime import datetime
import asyncio
from aiohttp import ClientSession, WSMessage, WSMsgType
from colorama import Fore
from aiohttp import WSMsgType, WSMessage, ClientSession
from mautrix.client.api.types.util import Obj
import click
from mautrix.types import Obj
from ..config import get_token, get_default_server
from ..base import app
from ..config import get_token
history_count: int = 10
@ -32,19 +31,28 @@ history_count: int = 10
@click.argument("server", required=False)
@click.option("-t", "--tail", default=10, help="Maximum number of old log lines to display")
def logs(server: str, tail: int) -> None:
server, token = get_token(server)
if not server:
server, token = get_default_server()
else:
token = get_token(server)
if not token:
return
global history_count
history_count = tail
loop = asyncio.get_event_loop()
loop.run_until_complete(view_logs(server, token))
future = asyncio.ensure_future(view_logs(server, token), loop=loop)
try:
loop.run_until_complete(future)
except KeyboardInterrupt:
future.cancel()
loop.run_until_complete(future)
loop.close()
def parsedate(entry: Obj) -> None:
i = entry.time.index("+")
i = entry.time.index(":", i)
entry.time = entry.time[:i] + entry.time[i + 1 :]
entry.time = entry.time[:i] + entry.time[i + 1:]
entry.time = datetime.strptime(entry.time, "%Y-%m-%dT%H:%M:%S.%f%z")
@ -60,16 +68,13 @@ levelcolors = {
def print_entry(entry: dict) -> None:
entry = Obj(**entry)
parsedate(entry)
print(
"{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}".format(
date=entry.time.strftime("%Y-%m-%d %H:%M:%S"),
print("{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}"
.format(date=entry.time.strftime("%Y-%m-%d %H:%M:%S"),
level=entry.levelname,
levelcolor=levelcolors.get(entry.levelname, ""),
resetcolor=Fore.RESET,
logger=entry.name,
message=entry.msg,
)
)
message=entry.msg))
if entry.exc_info:
print(entry.exc_info)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,46 +13,48 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from urllib.request import urlopen, Request
from urllib.error import HTTPError
from typing import IO
import json
from colorama import Fore
from yarl import URL
import aiohttp
import click
from ..cliq import cliq
from ..base import app
from ..config import get_default_server, get_token
class UploadError(Exception):
pass
@cliq.command(help="Upload a maubot plugin")
@app.command(help="Upload a maubot plugin")
@click.argument("path")
@click.option("-s", "--server", help="The maubot instance to upload the plugin to")
@cliq.with_authenticated_http
async def upload(path: str, server: str, sess: aiohttp.ClientSession) -> None:
with open(path, "rb") as file:
await upload_file(sess, file, server)
async def upload_file(sess: aiohttp.ClientSession, file: IO, server: str) -> None:
url = (URL(server) / "_matrix/maubot/v1/plugins/upload").with_query({"allow_override": "true"})
headers = {"Content-Type": "application/zip"}
async with sess.post(url, data=file, headers=headers) as resp:
if resp.status in (200, 201):
data = await resp.json()
print(
f"{Fore.GREEN}Plugin {Fore.CYAN}{data['id']} v{data['version']}{Fore.GREEN} "
f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}"
)
def upload(path: str, server: str) -> None:
if not server:
server, token = get_default_server()
else:
token = get_token(server)
if not token:
return
with open(path, "rb") as file:
upload_file(file, server, token)
def upload_file(file: IO, server: str, token: str) -> None:
req = Request(f"{server}/_matrix/maubot/v1/plugins/upload?allow_override=true", data=file,
headers={"Authorization": f"Bearer {token}", "Content-Type": "application/zip"})
try:
err = await resp.json()
if "stacktrace" in err:
print(err["stacktrace"])
err = err["error"]
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
err = await resp.text()
print(f"{Fore.RED}Failed to upload plugin: {err}{Fore.RESET}")
with urlopen(req) as resp_data:
resp = json.load(resp_data)
print(f"{Fore.GREEN}Plugin {Fore.CYAN}{resp['id']} v{resp['version']}{Fore.GREEN} "
f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}")
except HTTPError as e:
try:
err = json.load(e)
except json.JSONDecodeError:
err = {}
print(err.get("stacktrace", ""))
print(Fore.RED + "Failed to upload plugin: " + err.get("error", str(e)) + Fore.RESET)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,50 +13,31 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Any
from typing import Tuple, Optional
import json
import os
from colorama import Fore
config: dict[str, Any] = {
config = {
"servers": {},
"aliases": {},
"default_server": None,
}
configdir = os.environ.get("XDG_CONFIG_HOME", os.path.join(os.environ.get("HOME"), ".config"))
def get_default_server() -> tuple[str | None, str | None]:
def get_default_server() -> Tuple[Optional[str], Optional[str]]:
try:
server: str < None = config["default_server"]
server: str = config["default_server"]
except KeyError:
server = None
if server is None:
print(f"{Fore.RED}Default server not configured.{Fore.RESET}")
print(f"Perhaps you forgot to {Fore.CYAN}mbc login{Fore.RESET}?")
return None, None
return server, _get_token(server)
return server, get_token(server)
def get_token(server: str) -> tuple[str | None, str | None]:
if not server:
return get_default_server()
if server in config["aliases"]:
server = config["aliases"][server]
return server, _get_token(server)
def _resolve_alias(alias: str) -> str | None:
try:
return config["aliases"][alias]
except KeyError:
return None
def _get_token(server: str) -> str | None:
def get_token(server: str) -> Optional[str]:
try:
return config["servers"][server]
except KeyError:
@ -73,8 +54,7 @@ def load_config() -> None:
try:
with open(f"{configdir}/maubot-cli.json") as file:
loaded = json.load(file)
config["servers"] = loaded.get("servers", {})
config["aliases"] = loaded.get("aliases", {})
config["default_server"] = loaded.get("default_server", None)
config["servers"] = loaded["servers"]
config["default_server"] = loaded["default_server"]
except FileNotFoundError:
pass

View File

@ -1,9 +1,6 @@
from typing import Type
{% if config %}
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
{% endif %}
from maubot import Plugin
{% if config %}
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
class Config(BaseProxyConfig):
def do_update(self, helper: ConfigUpdateHelper) -> None:
@ -12,14 +9,18 @@ class Config(BaseProxyConfig):
helper.copy("example_2.value")
{% endif %}
class {{ name }}(Plugin):
async def start(self) -> None:{% if config %}
class {{ name }}:
async def start() -> None:
{% if config %}
self.config.load_and_update()
self.log.debug("Loaded %s from config example 2", self.config["example_2.value"]){% else %}
pass{% endif %}
async def stop(self) -> None:
self.log.debug("Loaded %s from config example 2", self.config["example_2.value"])
{% else %}
pass
{% endif %}
async def stop() -> None:
pass
{% if config %}
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]:

Binary file not shown.

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,14 +13,12 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
import json
from typing import Dict
import zipfile
import pkg_resources
import json
spdx_list: dict[str, dict[str, str]] | None = None
spdx_list = None
def load() -> None:
@ -33,13 +31,13 @@ def load() -> None:
spdx_list = json.load(file)
def get(id: str) -> dict[str, str]:
def get(id: str) -> Dict[str, str]:
if not spdx_list:
load()
return spdx_list[id]
return spdx_list[id.lower()]
def valid(id: str) -> bool:
if not spdx_list:
load()
return id in spdx_list
return id.lower() in spdx_list

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,203 +13,51 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, cast
from collections import defaultdict
from typing import Dict, List, Optional, Set, TYPE_CHECKING
import asyncio
import logging
from sqlalchemy.orm import Session
from aiohttp import ClientSession
from mautrix.client import InternalEventType
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (
ContentURI,
DeviceID,
EventFilter,
EventType,
Filter,
FilterID,
Membership,
PresenceState,
RoomEventFilter,
RoomFilter,
StateEvent,
StateFilter,
StrippedStateEvent,
SyncToken,
UserID,
)
from mautrix.util import background_task
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.logging import TraceLogger
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
EventType, Filter, RoomFilter, RoomEventFilter)
from .db import Client as DBClient
from .db import DBClient
from .matrix import MaubotMatrixClient
try:
from mautrix.crypto import OlmMachine, PgCryptoStore
crypto_import_error = None
except ImportError as e:
OlmMachine = PgCryptoStore = None
crypto_import_error = e
if TYPE_CHECKING:
from .__main__ import Maubot
from .instance import PluginInstance
log = logging.getLogger("maubot.client")
class Client(DBClient):
maubot: "Maubot" = None
cache: dict[UserID, Client] = {}
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
log: TraceLogger = logging.getLogger("maubot.client")
class Client:
db: Session = None
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
references: set[PluginInstance]
references: Set['PluginInstance']
db_instance: DBClient
client: MaubotMatrixClient
crypto: OlmMachine | None
crypto_store: PgCryptoStore | None
started: bool
sync_ok: bool
remote_displayname: str | None
remote_avatar_url: ContentURI | None
def __init__(
self,
id: UserID,
homeserver: str,
access_token: str,
device_id: DeviceID,
enabled: bool = False,
next_batch: SyncToken = "",
filter_id: FilterID = "",
sync: bool = True,
autojoin: bool = True,
online: bool = True,
displayname: str = "disable",
avatar_url: str = "disable",
) -> None:
super().__init__(
id=id,
homeserver=homeserver,
access_token=access_token,
device_id=device_id,
enabled=bool(enabled),
next_batch=next_batch,
filter_id=filter_id,
sync=bool(sync),
autojoin=bool(autojoin),
online=bool(online),
displayname=displayname,
avatar_url=avatar_url,
)
self._postinited = False
def __hash__(self) -> int:
return hash(self.id)
@classmethod
def init_cls(cls, maubot: "Maubot") -> None:
cls.maubot = maubot
def _make_client(
self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None
) -> MaubotMatrixClient:
return MaubotMatrixClient(
mxid=self.id,
base_url=homeserver or self.homeserver,
token=token or self.access_token,
client_session=self.http_client,
log=self.log,
crypto_log=self.log.getChild("crypto"),
loop=self.maubot.loop,
device_id=device_id or self.device_id,
sync_store=self,
state_store=self.maubot.state_store,
)
def postinit(self) -> None:
if self._postinited:
raise RuntimeError("postinit() called twice")
self._postinited = True
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
self.cache[self.id] = self
self.log = self.log.getChild(self.id)
self.http_client = ClientSession(loop=self.maubot.loop)
self.log = log.getChild(self.id)
self.references = set()
self.started = False
self.sync_ok = True
self.remote_displayname = None
self.remote_avatar_url = None
self.client = self._make_client()
if self.enable_crypto:
self._prepare_crypto()
else:
self.crypto_store = None
self.crypto = None
self.client.ignore_initial_sync = True
self.client.ignore_first_sync = True
self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client,
log=self.log, loop=self.loop, store=self.db_instance)
if self.autojoin:
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.client.add_event_handler(EventType.ROOM_TOMBSTONE, self._handle_tombstone)
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
async def handler(data: dict[str, Any]) -> None:
self.sync_ok = ok
return handler
@property
def enable_crypto(self) -> bool:
if not self.device_id:
return False
elif not OlmMachine:
global crypto_import_error
self.log.warning(
"Client has device ID, but encryption dependencies not installed",
exc_info=crypto_import_error,
)
# Clear the stack trace after it's logged once to avoid spamming logs
crypto_import_error = None
return False
elif not self.maubot.crypto_db:
self.log.warning("Client has device ID, but crypto database is not prepared")
return False
return True
def _prepare_crypto(self) -> None:
self.crypto_store = PgCryptoStore(
account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db
)
self.crypto = OlmMachine(
self.client,
self.crypto_store,
self.maubot.state_store,
log=self.client.crypto_log,
)
self.client.crypto = self.crypto
def _remove_crypto_event_handlers(self) -> None:
if not self.crypto:
return
handlers = [
(InternalEventType.DEVICE_OTK_COUNT, self.crypto.handle_otk_count),
(InternalEventType.DEVICE_LISTS, self.crypto.handle_device_lists),
(EventType.TO_DEVICE_ENCRYPTED, self.crypto.handle_to_device_event),
(EventType.ROOM_KEY_REQUEST, self.crypto.handle_room_key_request),
(EventType.ROOM_MEMBER, self.crypto.handle_member_event),
]
for event_type, func in handlers:
self.client.remove_event_handler(event_type, func)
async def start(self, try_n: int | None = 0) -> None:
async def start(self, try_n: Optional[int] = 0) -> None:
try:
if try_n > 0:
await asyncio.sleep(try_n * 10)
@ -217,21 +65,7 @@ class Client(DBClient):
except Exception:
self.log.exception("Failed to start")
async def _start_crypto(self) -> None:
self.log.debug("Enabling end-to-end encryption support")
await self.crypto_store.open()
crypto_device_id = await self.crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != self.device_id:
self.log.warning(
"Mismatching device ID in crypto store and main database, resetting encryption"
)
await self.crypto_store.delete()
crypto_device_id = None
await self.crypto.load()
if not crypto_device_id:
await self.crypto_store.put_device_id(self.device_id)
async def _start(self, try_n: int | None = 0) -> None:
async def _start(self, try_n: Optional[int] = 0) -> None:
if not self.enabled:
self.log.debug("Not starting disabled client")
return
@ -239,71 +73,47 @@ class Client(DBClient):
self.log.warning("Ignoring start() call to started client")
return
try:
await self.client.versions()
whoami = await self.client.whoami()
user_id = await self.client.whoami()
except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client")
self.enabled = False
await self.update()
self.db_instance.enabled = False
return
except Exception as e:
if try_n >= 8:
except MatrixRequestError:
if try_n >= 5:
self.log.exception("Failed to get /account/whoami, disabling client")
self.enabled = False
await self.update()
self.db_instance.enabled = False
else:
self.log.warning(
f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}"
)
background_task.create(self.start(try_n + 1))
self.log.exception(f"Failed to get /account/whoami, "
f"retrying in {(try_n + 1) * 10}s")
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
return
if whoami.user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
self.enabled = False
await self.update()
return
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
self.log.error(
f"Device ID mismatch: expected {self.device_id}, but got {whoami.device_id}"
)
self.enabled = False
await self.update()
if user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
self.db_instance.enabled = False
return
if not self.filter_id:
self.filter_id = await self.client.create_filter(
Filter(
self.db_instance.filter_id = await self.client.create_filter(Filter(
room=RoomFilter(
timeline=RoomEventFilter(
limit=50,
lazy_load_members=True,
),
state=StateFilter(
lazy_load_members=True,
),
),
presence=EventFilter(
not_types=[EventType.PRESENCE],
),
)
)
await self.update()
))
if self.displayname != "disable":
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
if self.crypto:
await self._start_crypto()
self.start_sync()
await self._update_remote_profile()
self.started = True
self.log.info("Client started, starting plugin instances...")
await self.start_plugins()
async def start_plugins(self) -> None:
await asyncio.gather(*[plugin.start() for plugin in self.references])
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
async def stop_plugins(self) -> None:
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.started])
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.started],
loop=self.loop)
def start_sync(self) -> None:
if self.sync:
@ -317,50 +127,48 @@ class Client(DBClient):
self.started = False
await self.stop_plugins()
self.stop_sync()
if self.crypto:
await self.crypto_store.close()
async def clear_cache(self) -> None:
self.stop_sync()
self.filter_id = FilterID("")
self.next_batch = SyncToken("")
await self.update()
self.start_sync()
def delete(self) -> None:
try:
del self.cache[self.id]
except KeyError:
pass
self.db.delete(self.db_instance)
self.db.commit()
def to_dict(self) -> dict:
return {
"id": self.id,
"homeserver": self.homeserver,
"access_token": self.access_token,
"device_id": self.device_id,
"fingerprint": (
self.crypto.account.fingerprint if self.crypto and self.crypto.account else None
),
"enabled": self.enabled,
"started": self.started,
"sync": self.sync,
"sync_ok": self.sync_ok,
"autojoin": self.autojoin,
"online": self.online,
"displayname": self.displayname,
"avatar_url": self.avatar_url,
"remote_displayname": self.remote_displayname,
"remote_avatar_url": self.remote_avatar_url,
"instances": [instance.to_dict() for instance in self.references],
}
async def _handle_tombstone(self, evt: StateEvent) -> None:
if not evt.content.replacement_room:
self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring")
return
_, server = self.client.parse_user_id(evt.sender)
await self.client.join_room(evt.content.replacement_room, servers=[server])
@classmethod
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
try:
return cls.cache[user_id]
except KeyError:
db_instance = db_instance or DBClient.query.get(user_id)
if not db_instance:
return None
return Client(db_instance)
@classmethod
def all(cls) -> List['Client']:
return [cls.get(user.id, user) for user in DBClient.query.all()]
async def _handle_invite(self, evt: StrippedStateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room(evt.room_id)
async def update_started(self, started: bool | None) -> None:
async def update_started(self, started: bool) -> None:
if started is None or started == self.started:
return
if started:
@ -368,162 +176,108 @@ class Client(DBClient):
else:
await self.stop()
async def update_enabled(self, enabled: bool | None, save: bool = True) -> None:
if enabled is None or enabled == self.enabled:
return
self.enabled = enabled
if save:
await self.update()
async def update_displayname(self, displayname: str | None, save: bool = True) -> None:
async def update_displayname(self, displayname: str) -> None:
if displayname is None or displayname == self.displayname:
return
self.displayname = displayname
if self.displayname != "disable":
self.db_instance.displayname = displayname
await self.client.set_displayname(self.displayname)
else:
await self._update_remote_profile()
if save:
await self.update()
async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None:
async def update_avatar_url(self, avatar_url: ContentURI) -> None:
if avatar_url is None or avatar_url == self.avatar_url:
return
self.avatar_url = avatar_url
if self.avatar_url != "disable":
self.db_instance.avatar_url = avatar_url
await self.client.set_avatar_url(self.avatar_url)
else:
await self._update_remote_profile()
if save:
await self.update()
async def update_sync(self, sync: bool | None, save: bool = True) -> None:
if sync is None or self.sync == sync:
async def update_access_details(self, access_token: str, homeserver: str) -> None:
if not access_token and not homeserver:
return
self.sync = sync
elif access_token == self.access_token and homeserver == self.homeserver:
return
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
token=access_token or self.access_token, loop=self.loop,
client_session=self.http_client, log=self.log)
mxid = await new_client.whoami()
if mxid != self.id:
raise ValueError(f"MXID mismatch: {mxid}")
new_client.store = self.db_instance
self.stop_sync()
self.client = new_client
self.db_instance.homeserver = homeserver
self.db_instance.access_token = access_token
self.start_sync()
# region Properties
@property
def id(self) -> UserID:
return self.db_instance.id
@property
def homeserver(self) -> str:
return self.db_instance.homeserver
@property
def access_token(self) -> str:
return self.db_instance.access_token
@property
def enabled(self) -> bool:
return self.db_instance.enabled
@enabled.setter
def enabled(self, value: bool) -> None:
self.db_instance.enabled = value
@property
def next_batch(self) -> SyncToken:
return self.db_instance.next_batch
@property
def filter_id(self) -> FilterID:
return self.db_instance.filter_id
@property
def sync(self) -> bool:
return self.db_instance.sync
@sync.setter
def sync(self, value: bool) -> None:
if value == self.db_instance.sync:
return
self.db_instance.sync = value
if self.started:
if sync:
if value:
self.start_sync()
else:
self.stop_sync()
if save:
await self.update()
async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None:
if autojoin is None or autojoin == self.autojoin:
@property
def autojoin(self) -> bool:
return self.db_instance.autojoin
@autojoin.setter
def autojoin(self, value: bool) -> None:
if value == self.db_instance.autojoin:
return
if autojoin:
if value:
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
else:
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.autojoin = autojoin
if save:
await self.update()
self.db_instance.autojoin = value
async def update_online(self, online: bool | None, save: bool = True) -> None:
if online is None or online == self.online:
return
self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE
self.online = online
if save:
await self.update()
@property
def displayname(self) -> str:
return self.db_instance.displayname
async def update_access_details(
self,
access_token: str | None,
homeserver: str | None,
device_id: str | None = None,
) -> None:
if not access_token and not homeserver:
return
if device_id is None:
device_id = self.device_id
elif not device_id:
device_id = None
if (
access_token == self.access_token
and homeserver == self.homeserver
and device_id == self.device_id
):
return
new_client = self._make_client(homeserver, access_token, device_id)
whoami = await new_client.whoami()
if whoami.user_id != self.id:
raise ValueError(f"MXID mismatch: {whoami.user_id}")
elif whoami.device_id and device_id and whoami.device_id != device_id:
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
new_client.sync_store = self
self.stop_sync()
@property
def avatar_url(self) -> ContentURI:
return self.db_instance.avatar_url
# TODO this event handler transfer is pretty hacky
self._remove_crypto_event_handlers()
self.client.crypto = None
new_client.event_handlers = self.client.event_handlers
new_client.global_event_handlers = self.client.global_event_handlers
# endregion
self.client = new_client
self.homeserver = homeserver
self.access_token = access_token
self.device_id = device_id
if self.enable_crypto:
self._prepare_crypto()
await self._start_crypto()
else:
self.crypto_store = None
self.crypto = None
self.start_sync()
async def _update_remote_profile(self) -> None:
profile = await self.client.get_profile(self.id)
self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url
async def delete(self) -> None:
try:
del self.cache[self.id]
except KeyError:
pass
await super().delete()
@classmethod
@async_getter_lock
async def get(
cls,
user_id: UserID,
*,
homeserver: str | None = None,
access_token: str | None = None,
device_id: DeviceID | None = None,
) -> Client | None:
try:
return cls.cache[user_id]
except KeyError:
pass
user = cast(cls, await super().get(user_id))
if user is not None:
user.postinit()
return user
if homeserver and access_token:
user = cls(
user_id,
homeserver=homeserver,
access_token=access_token,
device_id=device_id or "",
)
await user.insert()
user.postinit()
return user
return None
@classmethod
async def all(cls) -> AsyncGenerator[Client, None]:
users = await super().all()
user: cls
for user in users:
try:
yield cls.cache[user.id]
except KeyError:
user.postinit()
yield user
def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]:
Client.db = db
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
return Client.all()

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,10 +14,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import random
import re
import string
import bcrypt
import re
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper
@ -27,55 +26,36 @@ bcrypt_regex = re.compile(r"^\$2[ayb]\$.{56}$")
class Config(BaseFileConfig):
@staticmethod
def _new_token() -> str:
return "".join(random.choices(string.ascii_lowercase + string.digits, k=64))
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
def do_update(self, helper: ConfigUpdateHelper) -> None:
base = helper.base
copy = helper.copy
if "database" in self and self["database"].startswith("sqlite:///"):
helper.base["database"] = self["database"].replace("sqlite:///", "sqlite:")
else:
copy("database")
copy("database_opts")
if isinstance(self["crypto_database"], dict):
if self["crypto_database.type"] == "postgres":
base["crypto_database"] = self["crypto_database.postgres_uri"]
else:
copy("crypto_database")
copy("plugin_directories.upload")
copy("plugin_directories.load")
copy("plugin_directories.trash")
if "plugin_directories.db" in self:
base["plugin_databases.sqlite"] = self["plugin_directories.db"]
else:
copy("plugin_databases.sqlite")
copy("plugin_databases.postgres")
copy("plugin_databases.postgres_opts")
copy("plugin_directories.db")
copy("server.hostname")
copy("server.port")
copy("server.public_url")
copy("server.listen")
copy("server.base_path")
copy("server.ui_base_path")
copy("server.plugin_base_path")
copy("server.override_resource_path")
copy("server.appservice_base_path")
shared_secret = self["server.unshared_secret"]
if shared_secret is None or shared_secret == "generate":
base["server.unshared_secret"] = self._new_token()
else:
base["server.unshared_secret"] = shared_secret
if "registration_secrets" in self:
base["homeservers"] = self["registration_secrets"]
else:
copy("homeservers")
copy("registration_secrets")
copy("admins")
for username, password in base["admins"].items():
if password and not bcrypt_regex.match(password):
if password == "password":
password = self._new_token()
base["admins"][username] = bcrypt.hashpw(
password.encode("utf-8"), bcrypt.gensalt()
).decode("utf-8")
base["admins"][username] = bcrypt.hashpw(password.encode("utf-8"),
bcrypt.gensalt()).decode("utf-8")
copy("api_features.login")
copy("api_features.plugin")
copy("api_features.plugin_upload")

72
maubot/db.py Normal file
View File

@ -0,0 +1,72 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import cast
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
import sqlalchemy as sql
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
from .config import Config
Base: declarative_base = declarative_base()
class DBPlugin(Base):
query: Query
__tablename__ = "plugin"
id: str = Column(String(255), primary_key=True)
type: str = Column(String(255), nullable=False)
enabled: bool = Column(Boolean, nullable=False, default=False)
primary_user: UserID = Column(String(255),
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
nullable=False)
config: str = Column(Text, nullable=False, default='')
class DBClient(Base):
query: Query
__tablename__ = "client"
id: UserID = Column(String(255), primary_key=True)
homeserver: str = Column(String(255), nullable=False)
access_token: str = Column(String(255), nullable=False)
enabled: bool = Column(Boolean, nullable=False, default=False)
next_batch: SyncToken = Column(String(255), nullable=False, default="")
filter_id: FilterID = Column(String(255), nullable=False, default="")
sync: bool = Column(Boolean, nullable=False, default=True)
autojoin: bool = Column(Boolean, nullable=False, default=True)
displayname: str = Column(String(255), nullable=False, default="")
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
def init(config: Config) -> Session:
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
db_factory = sessionmaker(bind=db_engine)
db_session = scoped_session(db_factory)
Base.metadata.bind = db_engine
Base.metadata.create_all()
DBPlugin.query = db_session.query_property()
DBClient.query = db_session.query_property()
return cast(Session, db_session)

View File

@ -1,13 +0,0 @@
from mautrix.util.async_db import Database
from .client import Client
from .instance import DatabaseEngine, Instance
from .upgrade import upgrade_table
def init(db: Database) -> None:
for table in (Client, Instance):
table.db = db
__all__ = ["upgrade_table", "init", "Client", "Instance", "DatabaseEngine"]

View File

@ -1,114 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from asyncpg import Record
from attr import dataclass
from mautrix.client import SyncStore
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
from mautrix.util.async_db import Database
fake_db = Database.create("") if TYPE_CHECKING else None
@dataclass
class Client(SyncStore):
db: ClassVar[Database] = fake_db
id: UserID
homeserver: str
access_token: str
device_id: DeviceID
enabled: bool
next_batch: SyncToken
filter_id: FilterID
sync: bool
autojoin: bool
online: bool
displayname: str
avatar_url: ContentURI
@classmethod
def _from_row(cls, row: Record | None) -> Client | None:
if row is None:
return None
return cls(**row)
_columns = (
"id, homeserver, access_token, device_id, enabled, next_batch, filter_id, "
"sync, autojoin, online, displayname, avatar_url"
)
@property
def _values(self):
return (
self.id,
self.homeserver,
self.access_token,
self.device_id,
self.enabled,
self.next_batch,
self.filter_id,
self.sync,
self.autojoin,
self.online,
self.displayname,
self.avatar_url,
)
@classmethod
async def all(cls) -> list[Client]:
rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client")
return [cls._from_row(row) for row in rows]
@classmethod
async def get(cls, id: str) -> Client | None:
q = f"SELECT {cls._columns} FROM client WHERE id=$1"
return cls._from_row(await cls.db.fetchrow(q, id))
async def insert(self) -> None:
q = """
INSERT INTO client (
id, homeserver, access_token, device_id, enabled, next_batch, filter_id,
sync, autojoin, online, displayname, avatar_url
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
"""
await self.db.execute(q, *self._values)
async def put_next_batch(self, next_batch: SyncToken) -> None:
await self.db.execute("UPDATE client SET next_batch=$1 WHERE id=$2", next_batch, self.id)
self.next_batch = next_batch
async def get_next_batch(self) -> SyncToken:
return self.next_batch
async def update(self) -> None:
q = """
UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, enabled=$5,
next_batch=$6, filter_id=$7, sync=$8, autojoin=$9, online=$10,
displayname=$11, avatar_url=$12
WHERE id=$1
"""
await self.db.execute(q, *self._values)
async def delete(self) -> None:
await self.db.execute("DELETE FROM client WHERE id=$1", self.id)

View File

@ -1,101 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from enum import Enum
from asyncpg import Record
from attr import dataclass
from mautrix.types import UserID
from mautrix.util.async_db import Database
fake_db = Database.create("") if TYPE_CHECKING else None
class DatabaseEngine(Enum):
SQLITE = "sqlite"
POSTGRES = "postgres"
@dataclass
class Instance:
db: ClassVar[Database] = fake_db
id: str
type: str
enabled: bool
primary_user: UserID
config_str: str
database_engine: DatabaseEngine | None
@property
def database_engine_str(self) -> str | None:
return self.database_engine.value if self.database_engine else None
@classmethod
def _from_row(cls, row: Record | None) -> Instance | None:
if row is None:
return None
data = {**row}
db_engine = data.pop("database_engine", None)
return cls(**data, database_engine=DatabaseEngine(db_engine) if db_engine else None)
_columns = "id, type, enabled, primary_user, config, database_engine"
@classmethod
async def all(cls) -> list[Instance]:
q = f"SELECT {cls._columns} FROM instance"
rows = await cls.db.fetch(q)
return [cls._from_row(row) for row in rows]
@classmethod
async def get(cls, id: str) -> Instance | None:
q = f"SELECT {cls._columns} FROM instance WHERE id=$1"
return cls._from_row(await cls.db.fetchrow(q, id))
async def update_id(self, new_id: str) -> None:
await self.db.execute("UPDATE instance SET id=$1 WHERE id=$2", new_id, self.id)
self.id = new_id
@property
def _values(self):
return (
self.id,
self.type,
self.enabled,
self.primary_user,
self.config_str,
self.database_engine_str,
)
async def insert(self) -> None:
q = (
"INSERT INTO instance (id, type, enabled, primary_user, config, database_engine) "
"VALUES ($1, $2, $3, $4, $5, $6)"
)
await self.db.execute(q, *self._values)
async def update(self) -> None:
q = """
UPDATE instance SET type=$2, enabled=$3, primary_user=$4, config=$5, database_engine=$6
WHERE id=$1
"""
await self.db.execute(q, *self._values)
async def delete(self) -> None:
await self.db.execute("DELETE FROM instance WHERE id=$1", self.id)

View File

@ -1,5 +0,0 @@
from mautrix.util.async_db import UpgradeTable
upgrade_table = UpgradeTable()
from . import v01_initial_revision, v02_instance_database_engine

View File

@ -1,136 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from mautrix.util.async_db import Connection, Scheme
from . import upgrade_table
legacy_version_query = "SELECT version_num FROM alembic_version"
last_legacy_version = "90aa88820eab"
@upgrade_table.register(description="Initial asyncpg revision")
async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
if await conn.table_exists("alembic_version"):
await migrate_legacy_to_v1(conn, scheme)
else:
return await create_v1_tables(conn)
async def create_v1_tables(conn: Connection) -> None:
await conn.execute(
"""CREATE TABLE client (
id TEXT PRIMARY KEY,
homeserver TEXT NOT NULL,
access_token TEXT NOT NULL,
device_id TEXT NOT NULL,
enabled BOOLEAN NOT NULL,
next_batch TEXT NOT NULL,
filter_id TEXT NOT NULL,
sync BOOLEAN NOT NULL,
autojoin BOOLEAN NOT NULL,
online BOOLEAN NOT NULL,
displayname TEXT NOT NULL,
avatar_url TEXT NOT NULL
)"""
)
await conn.execute(
"""CREATE TABLE instance (
id TEXT PRIMARY KEY,
type TEXT NOT NULL,
enabled BOOLEAN NOT NULL,
primary_user TEXT NOT NULL,
config TEXT NOT NULL,
FOREIGN KEY (primary_user) REFERENCES client(id) ON DELETE RESTRICT ON UPDATE CASCADE
)"""
)
async def migrate_legacy_to_v1(conn: Connection, scheme: Scheme) -> None:
legacy_version = await conn.fetchval(legacy_version_query)
if legacy_version != last_legacy_version:
raise RuntimeError(
"Legacy database is not on last version. "
"Please upgrade the old database with alembic or drop it completely first."
)
await conn.execute("ALTER TABLE plugin RENAME TO instance")
await update_state_store(conn, scheme)
if scheme != Scheme.SQLITE:
await varchar_to_text(conn)
await conn.execute("DROP TABLE alembic_version")
async def update_state_store(conn: Connection, scheme: Scheme) -> None:
# The Matrix state store already has more or less the correct schema, so set the version
await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)")
await conn.execute("INSERT INTO mx_version (version) VALUES (2)")
if scheme != Scheme.SQLITE:
# Remove old uppercase membership type and recreate it as lowercase
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT")
await conn.execute("DROP TYPE IF EXISTS membership")
await conn.execute(
"CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')"
)
await conn.execute(
"ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership "
"USING LOWER(membership)::membership"
)
else:
# Recreate table to remove CHECK constraint and lowercase everything
await conn.execute(
"""CREATE TABLE new_mx_user_profile (
room_id TEXT,
user_id TEXT,
membership TEXT NOT NULL
CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock')),
displayname TEXT,
avatar_url TEXT,
PRIMARY KEY (room_id, user_id)
)"""
)
await conn.execute(
"""
INSERT INTO new_mx_user_profile (room_id, user_id, membership, displayname, avatar_url)
SELECT room_id, user_id, LOWER(membership), displayname, avatar_url
FROM mx_user_profile
"""
)
await conn.execute("DROP TABLE mx_user_profile")
await conn.execute("ALTER TABLE new_mx_user_profile RENAME TO mx_user_profile")
async def varchar_to_text(conn: Connection) -> None:
columns_to_adjust = {
"client": (
"id",
"homeserver",
"device_id",
"next_batch",
"filter_id",
"displayname",
"avatar_url",
),
"instance": ("id", "type", "primary_user"),
"mx_room_state": ("room_id",),
"mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"),
}
for table, columns in columns_to_adjust.items():
for column in columns:
await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT')

View File

@ -1,25 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from mautrix.util.async_db import Connection
from . import upgrade_table
@upgrade_table.register(description="Store instance database engine")
async def upgrade_v2(conn: Connection) -> None:
await conn.execute("ALTER TABLE instance ADD COLUMN database_engine TEXT")

View File

@ -1,130 +0,0 @@
# The full URI to the database. SQLite and Postgres are fully supported.
# Format examples:
# SQLite: sqlite:filename.db
# Postgres: postgresql://username:password@hostname/dbname
database: sqlite:maubot.db
# Separate database URL for the crypto database. "default" means use the same database as above.
crypto_database: default
# Additional arguments for asyncpg.create_pool() or sqlite3.connect()
# https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.pool.create_pool
# https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
# For sqlite, min_size is used as the connection thread pool size and max_size is ignored.
database_opts:
min_size: 1
max_size: 10
# Configuration for storing plugin .mbp files
plugin_directories:
# The directory where uploaded new plugins should be stored.
upload: ./plugins
# The directories from which plugins should be loaded.
# Duplicate plugin IDs will be moved to the trash.
load:
- ./plugins
# The directory where old plugin versions and conflicting plugins should be moved.
# Set to "delete" to delete files immediately.
trash: ./trash
# Configuration for storing plugin databases
plugin_databases:
# The directory where SQLite plugin databases should be stored.
sqlite: ./plugins
# The connection URL for plugin databases. If null, all plugins will get SQLite databases.
# If set, plugins using the new asyncpg interface will get a Postgres connection instead.
# Plugins using the legacy SQLAlchemy interface will always get a SQLite connection.
#
# To use the same connection pool as the default database, set to "default"
# (the default database above must be postgres to do this).
#
# When enabled, maubot will create separate Postgres schemas in the database for each plugin.
# To view schemas in psql, use `\dn`. To view enter and interact with a specific schema,
# use `SET search_path = name` (where `name` is the name found with `\dn`) and then use normal
# SQL queries/psql commands.
postgres: null
# Maximum number of connections per plugin instance.
postgres_max_conns_per_plugin: 3
# Overrides for the default database_opts when using a non-"default" postgres connection string.
postgres_opts: {}
server:
# The IP and port to listen to.
hostname: 0.0.0.0
port: 29316
# Public base URL where the server is visible.
public_url: https://example.com
# The base path for the UI.
ui_base_path: /_matrix/maubot
# The base path for plugin endpoints. The instance ID will be appended directly.
plugin_base_path: /_matrix/maubot/plugin/
# Override path from where to load UI resources.
# Set to false to using pkg_resources to find the path.
override_resource_path: false
# The shared secret to sign API access tokens.
# Set to "generate" to generate and save a new token at startup.
unshared_secret: generate
# Known homeservers. This is required for the `mbc auth` command and also allows
# more convenient access from the management UI. This is not required to create
# clients in the management UI, since you can also just type the homeserver URL
# into the box there.
homeservers:
matrix.org:
# Client-server API URL
url: https://matrix-client.matrix.org
# registration_shared_secret from synapse config
# You can leave this empty if you don't have access to the homeserver.
# When this is empty, `mbc auth --register` won't work, but `mbc auth` (login) will.
secret: null
# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password
# to prevent normal login. Root is a special user that can't have a password and will always exist.
admins:
root: ""
# API feature switches.
api_features:
login: true
plugin: true
plugin_upload: true
instance: true
instance_database: true
client: true
client_proxy: true
client_auth: true
dev_open: true
log: true
# Python logging configuration.
#
# See section 16.7.2 of the Python documentation for more info:
# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema
logging:
version: 1
formatters:
colored:
(): maubot.lib.color_log.ColorFormatter
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
normal:
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
handlers:
file:
class: logging.handlers.RotatingFileHandler
formatter: normal
filename: ./maubot.log
maxBytes: 10485760
backupCount: 10
console:
class: logging.StreamHandler
formatter: colored
loggers:
maubot:
level: DEBUG
mau:
level: DEBUG
aiohttp:
level: INFO
root:
level: DEBUG
handlers: [file, console]

View File

@ -1 +1 @@
from . import command, event, web
from . import event, command

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,46 +13,28 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
NewType,
Optional,
Pattern,
Sequence,
Set,
Tuple,
Union,
)
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
Dict, Tuple, Set)
from abc import ABC, abstractmethod
import asyncio
import functools
import inspect
import re
from mautrix.types import EventType, MessageType
from mautrix.types import MessageType, EventType
from ..matrix import MaubotMessageEvent
from . import event
PrefixType = Optional[Union[str, Callable[[], str], Callable[[Any], str]]]
AliasesType = Union[
List[str], Tuple[str, ...], Set[str], Callable[[str], bool], Callable[[Any, str], bool]
]
CommandHandlerFunc = NewType(
"CommandHandlerFunc", Callable[[MaubotMessageEvent, Any], Awaitable[Any]]
)
CommandHandlerDecorator = NewType(
"CommandHandlerDecorator",
Callable[[Union["CommandHandler", CommandHandlerFunc]], "CommandHandler"],
)
PassiveCommandHandlerDecorator = NewType(
"PassiveCommandHandlerDecorator", Callable[[CommandHandlerFunc], CommandHandlerFunc]
)
PrefixType = Optional[Union[str, Callable[[], str]]]
AliasesType = Union[List[str], Tuple[str, ...], Set[str], Callable[[str], bool]]
CommandHandlerFunc = NewType("CommandHandlerFunc",
Callable[[MaubotMessageEvent, Any], Awaitable[Any]])
CommandHandlerDecorator = NewType("CommandHandlerDecorator",
Callable[[Union['CommandHandler', CommandHandlerFunc]],
'CommandHandler'])
PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator",
Callable[[CommandHandlerFunc], CommandHandlerFunc])
def _split_in_two(val: str, split_by: str) -> List[str]:
@ -62,71 +44,32 @@ def _split_in_two(val: str, split_by: str) -> List[str]:
class CommandHandler:
def __init__(self, func: CommandHandlerFunc) -> None:
self.__mb_func__: CommandHandlerFunc = func
self.__mb_parent__: Optional[CommandHandler] = None
self.__mb_parent__: CommandHandler = None
self.__mb_subcommands__: List[CommandHandler] = []
self.__mb_arguments__: List[Argument] = []
self.__mb_help__: Optional[str] = None
self.__mb_get_name__: Callable[[Any], str] = lambda s: "noname"
self.__mb_help__: str = None
self.__mb_get_name__: Callable[[], str] = None
self.__mb_is_command_match__: Callable[[Any, str], bool] = self.__command_match_unset
self.__mb_require_subcommand__: bool = True
self.__mb_must_consume_args__: bool = True
self.__mb_arg_fallthrough__: bool = True
self.__mb_event_handler__: bool = True
self.__mb_event_types__: set[EventType] = {EventType.ROOM_MESSAGE}
self.__mb_msgtypes__: Iterable[MessageType] = (MessageType.TEXT,)
self.__bound_copies__: Dict[Any, CommandHandler] = {}
self.__bound_instance__: Any = None
def __get__(self, instance, instancetype):
if not instance or self.__bound_instance__:
return self
try:
return self.__bound_copies__[instance]
except KeyError:
new_ch = type(self)(self.__mb_func__)
keys = [
"parent",
"subcommands",
"arguments",
"help",
"get_name",
"is_command_match",
"require_subcommand",
"must_consume_args",
"arg_fallthrough",
"event_handler",
"event_types",
"msgtypes",
]
for key in keys:
key = f"__mb_{key}__"
setattr(new_ch, key, getattr(self, key))
new_ch.__bound_instance__ = instance
new_ch.__mb_subcommands__ = [
subcmd.__get__(instance, instancetype) for subcmd in self.__mb_subcommands__
]
self.__bound_copies__[instance] = new_ch
return new_ch
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
self.__mb_msgtypes__: List[MessageType] = (MessageType.TEXT,)
self.__class_instance: Any = None
@staticmethod
def __command_match_unset(self, val: str) -> bool:
def __command_match_unset(self, val: str) -> str:
raise NotImplementedError("Hmm")
async def __call__(
self,
evt: MaubotMessageEvent,
*,
_existing_args: Dict[str, Any] = None,
remaining_val: str = None,
) -> Any:
async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None,
remaining_val: str = None) -> Any:
if evt.sender == evt.client.mxid or evt.content.msgtype not in self.__mb_msgtypes__:
return
if remaining_val is None:
if not evt.content.body or evt.content.body[0] != "!":
return
command, remaining_val = _split_in_two(evt.content.body[1:], " ")
command = command.lower()
if not self.__mb_is_command_match__(self.__bound_instance__, command):
if not self.__mb_is_command_match__(self, command):
return
call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {}
@ -146,56 +89,48 @@ class CommandHandler:
await evt.reply(self.__mb_full_help__)
return
if self.__mb_must_consume_args__ and remaining_val.strip():
await evt.reply(self.__mb_full_help__)
return
if self.__bound_instance__:
return await self.__mb_func__(self.__bound_instance__, evt, **call_args)
if self.__class_instance:
return await self.__mb_func__(self.__class_instance, evt, **call_args)
return await self.__mb_func__(evt, **call_args)
async def __call_subcommand__(
self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str
) -> Tuple[bool, Any]:
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, Any]:
command, remaining_val = _split_in_two(remaining_val.strip(), " ")
for subcommand in self.__mb_subcommands__:
if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command):
return True, await subcommand(
evt, _existing_args=call_args, remaining_val=remaining_val
)
if subcommand.__mb_is_command_match__(subcommand.__class_instance, command):
return True, await subcommand(evt, _existing_args=call_args,
remaining_val=remaining_val)
return False, None
async def __parse_args__(
self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str
) -> Tuple[bool, str]:
async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, str]:
for arg in self.__mb_arguments__:
try:
remaining_val, call_args[arg.name] = arg.match(
remaining_val.strip(), evt=evt, instance=self.__bound_instance__
)
if arg.required and call_args[arg.name] is None:
remaining_val, call_args[arg.name] = arg.match(remaining_val.strip())
if arg.required and not call_args[arg.name]:
raise ValueError("Argument required")
except ArgumentSyntaxError as e:
await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else ""))
return False, remaining_val
except ValueError:
except ValueError as e:
await evt.reply(self.__mb_usage__)
return False, remaining_val
return True, remaining_val
def __get__(self, instance, instancetype):
self.__class_instance = instance
return self
@property
def __mb_full_help__(self) -> str:
usage = self.__mb_usage_without_subcommands__ + "\n\n"
if not self.__mb_require_subcommand__:
usage += f"* {self.__mb_prefix__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__)
return usage
@property
def __mb_usage_args__(self) -> str:
arg_usage = " ".join(
f"<{arg.label}>" if arg.required else f"[{arg.label}]" for arg in self.__mb_arguments__
)
arg_usage = " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]"
for arg in self.__mb_arguments__)
if self.__mb_subcommands__ and self.__mb_arg_fallthrough__:
arg_usage += " " + self.__mb_usage_subcommand__
return arg_usage
@ -206,24 +141,19 @@ class CommandHandler:
@property
def __mb_name__(self) -> str:
return self.__mb_get_name__(self.__bound_instance__)
return self.__mb_get_name__(self.__class_instance)
@property
def __mb_prefix__(self) -> str:
if self.__mb_parent__:
return (
f"!{self.__mb_parent__.__mb_get_name__(self.__bound_instance__)} "
f"{self.__mb_name__}"
)
return f"{self.__mb_parent__.__mb_prefix__} {self.__mb_name__}"
return f"!{self.__mb_name__}"
@property
def __mb_usage_inline__(self) -> str:
if not self.__mb_arg_fallthrough__:
return (
f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
f"* {self.__mb_name__} {self.__mb_usage_subcommand__}"
)
return (f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
f"* {self.__mb_name__} {self.__mb_usage_subcommand__}")
return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}"
@property
@ -233,12 +163,8 @@ class CommandHandler:
@property
def __mb_usage_without_subcommands__(self) -> str:
if not self.__mb_arg_fallthrough__:
if not self.__mb_arguments__:
return f"**Usage:** {self.__mb_prefix__} [subcommand] [...]"
return (
f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
f" _OR_ {self.__mb_prefix__} {self.__mb_usage_subcommand__}"
)
return (f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
f" _OR_ {self.__mb_usage_subcommand__}")
return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
@property
@ -247,25 +173,14 @@ class CommandHandler:
return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}"
return self.__mb_usage_without_subcommands__
def subcommand(
self,
name: PrefixType = None,
*,
help: str = None,
aliases: AliasesType = None,
required_subcommand: bool = True,
arg_fallthrough: bool = True,
def subcommand(self, name: PrefixType = None, *, help: str = None, aliases: AliasesType = None,
required_subcommand: bool = True, arg_fallthrough: bool = True,
) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
func = CommandHandler(func)
new(
name,
help=help,
aliases=aliases,
require_subcommand=required_subcommand,
arg_fallthrough=arg_fallthrough,
)(func)
new(name, help=help, aliases=aliases, require_subcommand=required_subcommand,
arg_fallthrough=arg_fallthrough)(func)
func.__mb_parent__ = self
func.__mb_event_handler__ = False
self.__mb_subcommands__.append(func)
@ -274,17 +189,9 @@ class CommandHandler:
return decorator
def new(
name: PrefixType = None,
*,
help: str = None,
aliases: AliasesType = None,
event_type: EventType = EventType.ROOM_MESSAGE,
msgtypes: Iterable[MessageType] = None,
require_subcommand: bool = True,
arg_fallthrough: bool = True,
must_consume_args: bool = True,
) -> CommandHandlerDecorator:
def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = None,
event_type: EventType = EventType.ROOM_MESSAGE, msgtypes: List[MessageType] = None,
require_subcommand: bool = True, arg_fallthrough: bool = True) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
func = CommandHandler(func)
@ -298,24 +205,22 @@ def new(
else:
func.__mb_get_name__ = lambda self: name
else:
func.__mb_get_name__ = lambda self: func.__mb_func__.__name__.replace("_", "-")
func.__mb_get_name__ = lambda self: func.__name__
if callable(aliases):
if len(inspect.getfullargspec(aliases).args) == 1:
func.__mb_is_command_match__ = lambda self, val: aliases(val)
else:
func.__mb_is_command_match__ = aliases
elif isinstance(aliases, (list, set, tuple)):
func.__mb_is_command_match__ = lambda self, val: (
val == func.__mb_get_name__(self) or val in aliases
)
func.__mb_is_command_match__ = lambda self, val: (val == func.__mb_name__
or val in aliases)
else:
func.__mb_is_command_match__ = lambda self, val: val == func.__mb_get_name__(self)
func.__mb_is_command_match__ = lambda self, val: val == func.__mb_name__
# Decorators are executed last to first, so we reverse the argument list.
func.__mb_arguments__.reverse()
func.__mb_require_subcommand__ = require_subcommand
func.__mb_arg_fallthrough__ = arg_fallthrough
func.__mb_must_consume_args__ = must_consume_args
func.__mb_event_types__ = {event_type}
func.__mb_event_type__ = event_type
if msgtypes:
func.__mb_msgtypes__ = msgtypes
return func
@ -331,16 +236,15 @@ class ArgumentSyntaxError(ValueError):
class Argument(ABC):
def __init__(
self, name: str, label: str = None, *, required: bool = False, pass_raw: bool = False
) -> None:
def __init__(self, name: str, label: str = None, *, required: bool = False,
pass_raw: bool = False) -> None:
self.name = name
self.label = label or name
self.required = required
self.pass_raw = pass_raw
@abstractmethod
def match(self, val: str, **kwargs) -> Tuple[str, Any]:
def match(self, val: str) -> Tuple[str, Any]:
pass
def __call__(self, func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
@ -351,73 +255,51 @@ class Argument(ABC):
class RegexArgument(Argument):
def __init__(
self,
name: str,
label: str = None,
*,
required: bool = False,
pass_raw: bool = False,
matches: str = None,
) -> None:
def __init__(self, name: str, label: str = None, *, required: bool = False,
pass_raw: bool = False, matches: str = None) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw)
matches = f"^{matches}" if self.pass_raw else f"^{matches}$"
self.regex = re.compile(matches)
def match(self, val: str, **kwargs) -> Tuple[str, Any]:
def match(self, val: str) -> Tuple[str, Any]:
orig_val = val
if not self.pass_raw:
val = re.split(r"\s", val, 1)[0]
val = val.split(" ")[0]
match = self.regex.match(val)
if match:
return (
orig_val[: match.start()] + orig_val[match.end() :],
match.groups() or val[match.start() : match.end()],
)
return (orig_val[:match.pos] + orig_val[match.endpos:],
match.groups() or val[match.pos:match.endpos])
return orig_val, None
class CustomArgument(Argument):
def __init__(
self,
name: str,
label: str = None,
*,
required: bool = False,
pass_raw: bool = False,
matcher: Callable[[str], Any],
) -> None:
def __init__(self, name: str, label: str = None, *, required: bool = False,
pass_raw: bool = False, matcher: Callable[[str], Any]) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw)
self.matcher = matcher
def match(self, val: str, **kwargs) -> Tuple[str, Any]:
def match(self, val: str) -> Tuple[str, Any]:
if self.pass_raw:
return self.matcher(val)
orig_val = val
val = re.split(r"\s", val, 1)[0]
val = val.split(" ")[0]
res = self.matcher(val)
if res is not None:
return orig_val[len(val) :], res
if res:
return orig_val[len(val):], res
return orig_val, None
class SimpleArgument(Argument):
def match(self, val: str, **kwargs) -> Tuple[str, Any]:
def match(self, val: str) -> Tuple[str, Any]:
if self.pass_raw:
return "", val
res = re.split(r"\s", val, 1)[0]
return val[len(res) :], res
res = val.split(" ")[0]
return val[len(res):], res
def argument(
name: str,
label: str = None,
*,
required: bool = True,
matches: Optional[str] = None,
parser: Optional[Callable[[str], Any]] = None,
pass_raw: bool = False,
) -> CommandHandlerDecorator:
def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None,
parser: Optional[Callable[[str], Any]] = None, pass_raw: bool = False
) -> CommandHandlerDecorator:
if matches:
return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw)
elif parser:
@ -426,26 +308,12 @@ def argument(
return SimpleArgument(name, label, required=required, pass_raw=pass_raw)
def passive(
regex: Union[str, Pattern],
*,
msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body,
event_type: EventType = EventType.ROOM_MESSAGE,
multiple: bool = False,
case_insensitive: bool = False,
multiline: bool = False,
dot_all: bool = False,
) -> PassiveCommandHandlerDecorator:
event_type: EventType = EventType.ROOM_MESSAGE, multiple: bool = False
) -> PassiveCommandHandlerDecorator:
if not isinstance(regex, Pattern):
flags = re.RegexFlag.UNICODE
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dot_all:
flags |= re.DOTALL
regex = re.compile(regex, flags=flags)
regex = re.compile(regex)
def decorator(func: CommandHandlerFunc) -> CommandHandlerFunc:
combine = None
@ -465,14 +333,12 @@ def passive(
return
data = field(evt)
if multiple:
val = [
(data[match.pos : match.endpos], *match.groups())
for match in regex.finditer(data)
]
val = [(data[match.pos:match.endpos], *match.groups())
for match in regex.finditer(data)]
else:
match = regex.search(data)
match = regex.match(data)
if match:
val = (data[match.pos : match.endpos], *match.groups())
val = (data[match.pos:match.endpos], *match.groups())
else:
val = None
if val:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,32 +13,22 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Callable, Union, NewType
from typing import Callable, NewType
from mautrix.client import EventHandler, InternalEventType
from mautrix.types import EventType
from mautrix.client import EventHandler
EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler])
def on(var: EventType | InternalEventType | EventHandler) -> EventHandlerDecorator | EventHandler:
def on(var: Union[EventType, EventHandler]) -> Union[EventHandlerDecorator, EventHandler]:
def decorator(func: EventHandler) -> EventHandler:
func.__mb_event_handler__ = True
if isinstance(var, (EventType, InternalEventType)):
if hasattr(func, "__mb_event_types__"):
func.__mb_event_types__.add(var)
if isinstance(var, EventType):
func.__mb_event_type__ = var
else:
func.__mb_event_types__ = {var}
else:
func.__mb_event_types__ = {EventType.ALL}
func.__mb_event_type__ = EventType.ALL
return func
return decorator if isinstance(var, (EventType, InternalEventType)) else decorator(var)
def off(func: EventHandler) -> EventHandler:
func.__mb_event_handler__ = False
return func
return decorator if isinstance(var, EventType) else decorator(var)

View File

@ -1,66 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, Awaitable, Callable
from aiohttp import hdrs, web
WebHandler = Callable[[web.Request], Awaitable[web.StreamResponse]]
WebHandlerDecorator = Callable[[WebHandler], WebHandler]
def head(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_HEAD, path, **kwargs)
def options(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_OPTIONS, path, **kwargs)
def get(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_GET, path, **kwargs)
def post(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_POST, path, **kwargs)
def put(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_PUT, path, **kwargs)
def patch(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_PATCH, path, **kwargs)
def delete(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_DELETE, path, **kwargs)
def view(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_ANY, path, **kwargs)
def handle(method: str, path: str, **kwargs) -> WebHandlerDecorator:
def decorator(handler: WebHandler) -> WebHandler:
try:
handlers = getattr(handler, "__mb_web_handler__")
except AttributeError:
handlers = []
setattr(handler, "__mb_web_handler__", handlers)
handlers.append((method, path, kwargs))
return handler
return decorator

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,92 +13,52 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
from collections import defaultdict
import asyncio
import inspect
import io
import logging
from typing import Dict, List, Optional
from asyncio import AbstractEventLoop
import os.path
import logging
import io
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML
from sqlalchemy.orm import Session
import sqlalchemy as sql
from mautrix.types import UserID
from mautrix.util import background_task
from mautrix.util.async_db import Database, Scheme, UpgradeTable
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.util.logging import TraceLogger
from mautrix.types import UserID
from .db import DBPlugin
from .config import Config
from .client import Client
from .db import DatabaseEngine, Instance as DBInstance
from .lib.optionalalchemy import Engine, MetaData, create_engine
from .lib.plugin_db import ProxyPostgresDatabase
from .loader import DatabaseType, PluginLoader, ZippedPluginLoader
from .loader import PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin
if TYPE_CHECKING:
from .__main__ import Maubot
from .server import PluginWebApp
log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance"))
db_log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance_db"))
log = logging.getLogger("maubot.instance")
yaml = YAML()
yaml.indent(4)
yaml.width = 200
class PluginInstance(DBInstance):
maubot: "Maubot" = None
cache: dict[str, PluginInstance] = {}
plugin_directories: list[str] = []
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
class PluginInstance:
db: Session = None
mb_config: Config = None
loop: AbstractEventLoop = None
cache: Dict[str, 'PluginInstance'] = {}
plugin_directories: List[str] = []
log: logging.Logger
loader: PluginLoader | None
client: Client | None
plugin: Plugin | None
config: BaseProxyConfig | None
base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: str | None
inst_db: sql.engine.Engine | Database | None
inst_db_tables: dict | None
inst_webapp: PluginWebApp | None
inst_webapp_url: str | None
loader: PluginLoader
client: Client
plugin: Plugin
config: BaseProxyConfig
base_cfg: RecursiveDict[CommentedMap]
inst_db: sql.engine.Engine
inst_db_tables: Dict[str, sql.Table]
started: bool
def __init__(
self,
id: str,
type: str,
enabled: bool,
primary_user: UserID,
config: str = "",
database_engine: DatabaseEngine | None = None,
) -> None:
super().__init__(
id=id,
type=type,
enabled=bool(enabled),
primary_user=primary_user,
config_str=config,
database_engine=database_engine,
)
def __hash__(self) -> int:
return hash(self.id)
@classmethod
def init_cls(cls, maubot: "Maubot") -> None:
cls.maubot = maubot
def postinit(self) -> None:
def __init__(self, db_instance: DBPlugin):
self.db_instance = db_instance
self.log = log.getChild(self.id)
self.cache[self.id] = self
self.config = None
self.started = False
self.loader = None
@ -106,10 +66,8 @@ class PluginInstance(DBInstance):
self.plugin = None
self.inst_db = None
self.inst_db_tables = None
self.inst_webapp = None
self.inst_webapp_url = None
self.base_cfg = None
self.base_cfg_str = None
self.cache[self.id] = self
def to_dict(self) -> dict:
return {
@ -118,144 +76,41 @@ class PluginInstance(DBInstance):
"enabled": self.enabled,
"started": self.started,
"primary_user": self.primary_user,
"config": self.config_str,
"base_config": self.base_cfg_str,
"database": (
self.inst_db is not None and self.maubot.config["api_features.instance_database"]
),
"database_interface": self.loader.meta.database_type_str if self.loader else "unknown",
"database_engine": self.database_engine_str,
"config": self.db_instance.config,
"database": (self.inst_db is not None
and self.mb_config["api_features.instance_database"]),
}
def _introspect_sqlalchemy(self) -> dict:
metadata = MetaData()
def get_db_tables(self) -> Dict[str, sql.Table]:
if not self.inst_db_tables:
metadata = sql.MetaData()
metadata.reflect(self.inst_db)
return {
table.name: {
"columns": {
column.name: {
"type": str(column.type),
"unique": column.unique or False,
"default": column.default,
"nullable": column.nullable,
"primary": column.primary_key,
}
for column in table.columns
},
}
for table in metadata.tables.values()
}
async def _introspect_sqlite(self) -> dict:
q = """
SELECT
m.name AS table_name,
p.cid AS col_id,
p.name AS column_name,
p.type AS data_type,
p.pk AS is_primary,
p.dflt_value AS column_default,
p.[notnull] AS is_nullable
FROM sqlite_master m
LEFT JOIN pragma_table_info((m.name)) p
WHERE m.type = 'table'
ORDER BY table_name, col_id
"""
data = await self.inst_db.fetch(q)
tables = defaultdict(lambda: {"columns": {}})
for column in data:
table_name = column["table_name"]
col_name = column["column_name"]
tables[table_name]["columns"][col_name] = {
"type": column["data_type"],
"nullable": bool(column["is_nullable"]),
"default": column["column_default"],
"primary": bool(column["is_primary"]),
# TODO uniqueness?
}
return tables
async def _introspect_postgres(self) -> dict:
assert isinstance(self.inst_db, ProxyPostgresDatabase)
q = """
SELECT col.table_name, col.column_name, col.data_type, col.is_nullable, col.column_default,
tc.constraint_type
FROM information_schema.columns col
LEFT JOIN information_schema.constraint_column_usage ccu
ON ccu.column_name=col.column_name
LEFT JOIN information_schema.table_constraints tc
ON col.table_name=tc.table_name
AND col.table_schema=tc.table_schema
AND ccu.constraint_name=tc.constraint_name
AND ccu.constraint_schema=tc.constraint_schema
AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
WHERE col.table_schema=$1
"""
data = await self.inst_db.fetch(q, self.inst_db.schema_name)
tables = defaultdict(lambda: {"columns": {}})
for column in data:
table_name = column["table_name"]
col_name = column["column_name"]
tables[table_name]["columns"].setdefault(
col_name,
{
"type": column["data_type"],
"nullable": column["is_nullable"],
"default": column["column_default"],
"primary": False,
"unique": False,
},
)
if column["constraint_type"] == "PRIMARY KEY":
tables[table_name]["columns"][col_name]["primary"] = True
elif column["constraint_type"] == "UNIQUE":
tables[table_name]["columns"][col_name]["unique"] = True
return tables
async def get_db_tables(self) -> dict:
if self.inst_db_tables is None:
if isinstance(self.inst_db, Engine):
self.inst_db_tables = self._introspect_sqlalchemy()
elif self.inst_db.scheme == Scheme.SQLITE:
self.inst_db_tables = await self._introspect_sqlite()
else:
self.inst_db_tables = await self._introspect_postgres()
self.inst_db_tables = metadata.tables
return self.inst_db_tables
async def load(self) -> bool:
def load(self) -> bool:
if not self.loader:
try:
self.loader = PluginLoader.find(self.type)
except KeyError:
self.log.error(f"Failed to find loader for type {self.type}")
await self.update_enabled(False)
self.db_instance.enabled = False
return False
if not self.client:
self.client = await Client.get(self.primary_user)
self.client = Client.get(self.primary_user)
if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}")
await self.update_enabled(False)
self.db_instance.enabled = False
return False
if self.loader.meta.webapp:
self.enable_webapp()
if self.loader.meta.database:
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
self.log.debug("Plugin instance dependencies loaded")
self.loader.references.add(self)
self.client.references.add(self)
return True
def enable_webapp(self) -> None:
self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id)
def disable_webapp(self) -> None:
self.maubot.server.remove_instance_webapp(self.id)
self.inst_webapp = None
self.inst_webapp_url = None
@property
def _sqlite_db_path(self) -> str:
return os.path.join(self.maubot.config["plugin_databases.sqlite"], f"{self.id}.db")
async def delete(self) -> None:
def delete(self) -> None:
if self.loader is not None:
self.loader.references.remove(self)
if self.client is not None:
@ -264,89 +119,21 @@ class PluginInstance(DBInstance):
del self.cache[self.id]
except KeyError:
pass
await super().delete()
self.db.delete(self.db_instance)
self.db.commit()
if self.inst_db:
await self.stop_database()
await self.delete_database()
if self.inst_webapp:
self.disable_webapp()
self.inst_db.dispose()
ZippedPluginLoader.trash(
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
reason="deleted")
def load_config(self) -> CommentedMap:
return yaml.load(self.config_str)
return yaml.load(self.db_instance.config)
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
buf = io.StringIO()
yaml.dump(data, buf)
val = buf.getvalue()
if val != self.config_str:
self.config_str = val
self.log.debug("Creating background task to save updated config")
background_task.create(self.update())
async def start_database(
self, upgrade_table: UpgradeTable | None = None, actually_start: bool = True
) -> None:
if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
if self.database_engine is None:
await self.update_db_engine(DatabaseEngine.SQLITE)
elif self.database_engine == DatabaseEngine.POSTGRES:
raise RuntimeError(
"Instance database engine is marked as Postgres, but plugin uses legacy "
"database interface, which doesn't support postgres."
)
self.inst_db = create_engine(f"sqlite:///{self._sqlite_db_path}")
elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
if self.database_engine is None:
if os.path.exists(self._sqlite_db_path) or not self.maubot.plugin_postgres_db:
await self.update_db_engine(DatabaseEngine.SQLITE)
else:
await self.update_db_engine(DatabaseEngine.POSTGRES)
instance_db_log = db_log.getChild(self.id)
if self.database_engine == DatabaseEngine.POSTGRES:
if not self.maubot.plugin_postgres_db:
raise RuntimeError(
"Instance database engine is marked as Postgres, but this maubot isn't "
"configured to support Postgres for plugin databases"
)
self.inst_db = ProxyPostgresDatabase(
pool=self.maubot.plugin_postgres_db,
instance_id=self.id,
max_conns=self.maubot.config["plugin_databases.postgres_max_conns_per_plugin"],
upgrade_table=upgrade_table,
log=instance_db_log,
)
else:
self.inst_db = Database.create(
f"sqlite:{self._sqlite_db_path}",
upgrade_table=upgrade_table,
log=instance_db_log,
)
if actually_start:
await self.inst_db.start()
else:
raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
async def stop_database(self) -> None:
if isinstance(self.inst_db, Database):
await self.inst_db.stop()
elif isinstance(self.inst_db, Engine):
self.inst_db.dispose()
else:
raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}")
async def delete_database(self) -> None:
if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
if self.inst_db is None:
await self.start_database(None, actually_start=False)
if isinstance(self.inst_db, ProxyPostgresDatabase):
await self.inst_db.delete()
else:
ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
else:
raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
self.inst_db = None
self.db_instance.config = buf.getvalue()
async def start(self) -> None:
if self.started:
@ -357,65 +144,30 @@ class PluginInstance(DBInstance):
return
if not self.client or not self.loader:
self.log.warning("Missing plugin instance dependencies, attempting to load...")
if not await self.load():
if not self.load():
return
cls = await self.loader.load()
if self.loader.meta.webapp and self.inst_webapp is None:
self.log.debug("Enabling webapp after plugin meta reload")
self.enable_webapp()
elif not self.loader.meta.webapp and self.inst_webapp is not None:
self.log.debug("Disabling webapp after plugin meta reload")
self.disable_webapp()
if self.loader.meta.database:
try:
await self.start_database(cls.get_db_upgrade_table())
except Exception:
self.log.exception("Failed to start instance database")
await self.update_enabled(False)
return
config_class = cls.get_config_class()
if config_class:
try:
base = await self.loader.read_file("base-config.yaml")
self.base_cfg = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
buf = io.StringIO()
yaml.dump(self.base_cfg._data, buf)
self.base_cfg_str = buf.getvalue()
except (FileNotFoundError, KeyError):
self.base_cfg = None
self.base_cfg_str = None
if self.base_cfg:
base_cfg_func = self.base_cfg.clone
else:
def base_cfg_func() -> None:
return None
self.config = config_class(self.load_config, base_cfg_func, self.save_config)
self.plugin = cls(
client=self.client.client,
loop=self.maubot.loop,
http=self.client.http_client,
instance_id=self.id,
log=self.log,
config=self.config,
database=self.inst_db,
loader=self.loader,
webapp=self.inst_webapp,
webapp_url=self.inst_webapp_url,
)
self.config = config_class(self.load_config, lambda: self.base_cfg, self.save_config)
self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client,
instance_id=self.id, log=self.log, config=self.config,
database=self.inst_db)
try:
await self.plugin.internal_start()
await self.plugin.start()
except Exception:
self.log.exception("Failed to start instance")
await self.update_enabled(False)
self.db_instance.enabled = False
return
self.started = True
self.inst_db_tables = None
self.log.info(
f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} "
f"with user {self.client.id}"
)
self.log.info(f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} "
f"with user {self.client.id}")
async def stop(self) -> None:
if not self.started:
@ -424,62 +176,67 @@ class PluginInstance(DBInstance):
self.log.debug("Stopping plugin instance...")
self.started = False
try:
await self.plugin.internal_stop()
await self.plugin.stop()
except Exception:
self.log.exception("Failed to stop instance")
self.plugin = None
if self.inst_db:
try:
await self.stop_database()
except Exception:
self.log.exception("Failed to stop instance database")
self.inst_db_tables = None
async def update_id(self, new_id: str | None) -> None:
if new_id is not None and new_id.lower() != self.id:
await super().update_id(new_id.lower())
@classmethod
def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None
) -> Optional['PluginInstance']:
try:
return cls.cache[instance_id]
except KeyError:
db_instance = db_instance or DBPlugin.query.get(instance_id)
if not db_instance:
return None
return PluginInstance(db_instance)
async def update_config(self, config: str | None) -> None:
if config is None or self.config_str == config:
@classmethod
def all(cls) -> List['PluginInstance']:
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
def update_id(self, new_id: str) -> None:
if new_id is not None and new_id != self.id:
self.db_instance.id = new_id
def update_config(self, config: str) -> None:
if not config or self.db_instance.config == config:
return
self.config_str = config
self.db_instance.config = config
if self.started and self.plugin is not None:
res = self.plugin.on_external_config_update()
if inspect.isawaitable(res):
await res
await self.update()
self.plugin.on_external_config_update()
async def update_primary_user(self, primary_user: UserID | None) -> bool:
if primary_user is None or primary_user == self.primary_user:
async def update_primary_user(self, primary_user: UserID) -> bool:
if not primary_user or primary_user == self.primary_user:
return True
client = await Client.get(primary_user)
client = Client.get(primary_user)
if not client:
return False
await self.stop()
self.primary_user = client.id
self.db_instance.primary_user = client.id
if self.client:
self.client.references.remove(self)
self.client = client
self.client.references.add(self)
await self.update()
await self.start()
self.log.debug(f"Primary user switched to {self.client.id}")
return True
async def update_type(self, type: str | None) -> bool:
if type is None or type == self.type:
async def update_type(self, type: str) -> bool:
if not type or type == self.type:
return True
try:
loader = PluginLoader.find(type)
except KeyError:
return False
await self.stop()
self.type = loader.meta.id
self.db_instance.type = loader.meta.id
if self.loader:
self.loader.references.remove(self)
self.loader = loader
self.loader.references.add(self)
await self.update()
await self.start()
self.log.debug(f"Type switched to {self.loader.meta.id}")
return True
@ -488,46 +245,37 @@ class PluginInstance(DBInstance):
if started is not None and started != self.started:
await (self.start() if started else self.stop())
async def update_enabled(self, enabled: bool) -> None:
def update_enabled(self, enabled: bool) -> None:
if enabled is not None and enabled != self.enabled:
self.enabled = enabled
await self.update()
self.db_instance.enabled = enabled
async def update_db_engine(self, db_engine: DatabaseEngine | None) -> None:
if db_engine is not None and db_engine != self.database_engine:
self.database_engine = db_engine
await self.update()
# region Properties
@classmethod
@async_getter_lock
async def get(
cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None
) -> PluginInstance | None:
try:
return cls.cache[instance_id]
except KeyError:
pass
@property
def id(self) -> str:
return self.db_instance.id
instance = cast(cls, await super().get(instance_id))
if instance is not None:
instance.postinit()
return instance
@id.setter
def id(self, value: str) -> None:
self.db_instance.id = value
if type and primary_user:
instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user)
await instance.insert()
instance.postinit()
return instance
@property
def type(self) -> str:
return self.db_instance.type
return None
@property
def enabled(self) -> bool:
return self.db_instance.enabled
@classmethod
async def all(cls) -> AsyncGenerator[PluginInstance, None]:
instances = await super().all()
instance: PluginInstance
for instance in instances:
try:
yield cls.cache[instance.id]
except KeyError:
instance.postinit()
yield instance
@property
def primary_user(self) -> UserID:
return self.db_instance.primary_user
# endregion
def init(db: Session, config: Config, loop: AbstractEventLoop) -> List[PluginInstance]:
PluginInstance.db = db
PluginInstance.mb_config = config
PluginInstance.loop = loop
return PluginInstance.all()

View File

@ -1,49 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.logging.color import (
MAU_COLOR,
MXID_COLOR,
PREFIX,
RESET,
ColorFormatter as BaseColorFormatter,
)
INST_COLOR = PREFIX + "35m" # magenta
LOADER_COLOR = PREFIX + "36m" # blue
class ColorFormatter(BaseColorFormatter):
def _color_name(self, module: str) -> str:
client = "maubot.client"
if module.startswith(client + "."):
suffix = ""
if module.endswith(".crypto"):
suffix = f".{MAU_COLOR}crypto{RESET}"
module = module[: -len(".crypto")]
module = module[len(client) + 1 :]
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module}{RESET}{suffix}"
instance = "maubot.instance"
if module.startswith(instance + "."):
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
instance_db = "maubot.instance_db"
if module.startswith(instance_db + "."):
return f"{MAU_COLOR}{instance_db}{RESET}.{INST_COLOR}{module[len(instance_db) + 1:]}{RESET}"
loader = "maubot.loader"
if module.startswith(loader + "."):
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
if module.startswith("maubot."):
return f"{MAU_COLOR}{module}{RESET}"
return super()._color_name(module)

View File

@ -1,9 +0,0 @@
from typing import Any, Awaitable, Callable, Generator
class FutureAwaitable:
def __init__(self, func: Callable[[], Awaitable[None]]) -> None:
self._func = func
def __await__(self) -> Generator[Any, None, None]:
return self._func().__await__()

View File

@ -1,19 +0,0 @@
try:
from sqlalchemy import MetaData, asc, create_engine, desc
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError, OperationalError
except ImportError:
class FakeError(Exception):
pass
class FakeType:
def __init__(self, *args, **kwargs):
raise Exception("SQLAlchemy is not installed")
def create_engine(*args, **kwargs):
raise Exception("SQLAlchemy is not installed")
MetaData = Engine = FakeType
IntegrityError = OperationalError = FakeError
asc = desc = lambda a: a

View File

@ -1,100 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from contextlib import asynccontextmanager
import asyncio
from mautrix.util.async_db import Database, PostgresDatabase, Scheme, UpgradeTable
from mautrix.util.async_db.connection import LoggingConnection
from mautrix.util.logging import TraceLogger
remove_double_quotes = str.maketrans({'"': "_"})
class ProxyPostgresDatabase(Database):
scheme = Scheme.POSTGRES
_underlying_pool: PostgresDatabase
schema_name: str
_quoted_schema: str
_default_search_path: str
_conn_sema: asyncio.Semaphore
_max_conns: int
def __init__(
self,
pool: PostgresDatabase,
instance_id: str,
max_conns: int,
upgrade_table: UpgradeTable | None,
log: TraceLogger | None = None,
) -> None:
super().__init__(pool.url, upgrade_table=upgrade_table, log=log)
self._underlying_pool = pool
# Simple accidental SQL injection prevention.
# Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway.
self.schema_name = f"mbp_{instance_id.translate(remove_double_quotes)}"
self._quoted_schema = f'"{self.schema_name}"'
self._default_search_path = '"$user", public'
self._conn_sema = asyncio.BoundedSemaphore(max_conns)
self._max_conns = max_conns
async def start(self) -> None:
async with self._underlying_pool.acquire() as conn:
self._default_search_path = await conn.fetchval("SHOW search_path")
self.log.trace(f"Found default search path: {self._default_search_path}")
await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._quoted_schema}")
await super().start()
async def stop(self) -> None:
for _ in range(self._max_conns):
try:
await asyncio.wait_for(self._conn_sema.acquire(), timeout=3)
except asyncio.TimeoutError:
self.log.warning(
"Failed to drain plugin database connection pool, "
"the plugin may be leaking database connections"
)
break
async def delete(self) -> None:
self.log.info(f"Deleting schema {self.schema_name} and all data in it")
try:
await self._underlying_pool.execute(
f"DROP SCHEMA IF EXISTS {self._quoted_schema} CASCADE"
)
except Exception:
self.log.warning("Failed to delete schema", exc_info=True)
@asynccontextmanager
async def acquire(self) -> LoggingConnection:
conn: LoggingConnection
async with self._conn_sema, self._underlying_pool.acquire() as conn:
await conn.execute(f"SET search_path = {self._quoted_schema}")
try:
yield conn
finally:
if not conn.wrapped.is_closed():
try:
await conn.execute(f"SET search_path = {self._default_search_path}")
except Exception:
self.log.exception("Error resetting search_path after use")
await conn.wrapped.close()
else:
self.log.debug("Connection was closed after use, not resetting search_path")
__all__ = ["ProxyPostgresDatabase"]

View File

@ -1,27 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.client.state_store.asyncpg import PgStateStore as BasePgStateStore
try:
from mautrix.crypto import StateStore as CryptoStateStore
class PgStateStore(BasePgStateStore, CryptoStateStore):
pass
except ImportError as e:
PgStateStore = BasePgStateStore
__all__ = ["PgStateStore"]

View File

@ -18,28 +18,26 @@ used by the builtin import mechanism for sys.path items that are paths
to Zip archives.
"""
from importlib import _bootstrap # for _verbose_message
from importlib import _bootstrap_external
from importlib import _bootstrap # for _verbose_message
import _imp # for check_hash_based_pycs
import _io # for open
import marshal # for loads
import sys # for modules
import time # for mktime
import _imp # for check_hash_based_pycs
import _io # for open
__all__ = ["ZipImportError", "zipimporter"]
__all__ = ['ZipImportError', 'zipimporter']
def _unpack_uint32(data):
"""Convert 4 bytes in little-endian to an integer."""
assert len(data) == 4
return int.from_bytes(data, "little")
return int.from_bytes(data, 'little')
def _unpack_uint16(data):
"""Convert 2 bytes in little-endian to an integer."""
assert len(data) == 2
return int.from_bytes(data, "little")
return int.from_bytes(data, 'little')
path_sep = _bootstrap_external.path_sep
@ -49,17 +47,15 @@ alt_path_sep = _bootstrap_external.path_separators[1:]
class ZipImportError(ImportError):
pass
# _read_directory() cache
_zip_directory_cache = {}
_module_type = type(sys)
END_CENTRAL_DIR_SIZE = 22
STRING_END_ARCHIVE = b"PK\x05\x06"
STRING_END_ARCHIVE = b'PK\x05\x06'
MAX_COMMENT_LEN = (1 << 16) - 1
class zipimporter:
"""zipimporter(archivepath) -> zipimporter object
@ -81,10 +77,9 @@ class zipimporter:
def __init__(self, path):
if not isinstance(path, str):
import os
path = os.fsdecode(path)
if not path:
raise ZipImportError("archive path is empty", path=path)
raise ZipImportError('archive path is empty', path=path)
if alt_path_sep:
path = path.replace(alt_path_sep, path_sep)
@ -97,14 +92,14 @@ class zipimporter:
# Back up one path element.
dirname, basename = _bootstrap_external._path_split(path)
if dirname == path:
raise ZipImportError("not a Zip file", path=path)
raise ZipImportError('not a Zip file', path=path)
path = dirname
prefix.append(basename)
else:
# it exists
if (st.st_mode & 0o170000) != 0o100000: # stat.S_ISREG
# it's a not file
raise ZipImportError("not a Zip file", path=path)
raise ZipImportError('not a Zip file', path=path)
break
try:
@ -159,10 +154,11 @@ class zipimporter:
# This is possibly a portion of a namespace
# package. Return the string representing its path,
# without a trailing separator.
return None, [f"{self.archive}{path_sep}{modpath}"]
return None, [f'{self.archive}{path_sep}{modpath}']
return None, []
# Check whether we can satisfy the import of the module named by
# 'fullname'. Return self if we can, None if we can't.
def find_module(self, fullname, path=None):
@ -176,6 +172,7 @@ class zipimporter:
"""
return self.find_loader(fullname, path)[0]
def get_code(self, fullname):
"""get_code(fullname) -> code object.
@ -185,6 +182,7 @@ class zipimporter:
code, ispackage, modpath = _get_module_code(self, fullname)
return code
def get_data(self, pathname):
"""get_data(pathname) -> string with file data.
@ -196,14 +194,15 @@ class zipimporter:
key = pathname
if pathname.startswith(self.archive + path_sep):
key = pathname[len(self.archive + path_sep) :]
key = pathname[len(self.archive + path_sep):]
try:
toc_entry = self._files[key]
except KeyError:
raise OSError(0, "", key)
raise OSError(0, '', key)
return _get_data(self.archive, toc_entry)
# Return a string matching __file__ for the named module
def get_filename(self, fullname):
"""get_filename(fullname) -> filename string.
@ -215,6 +214,7 @@ class zipimporter:
code, ispackage, modpath = _get_module_code(self, fullname)
return modpath
def get_source(self, fullname):
"""get_source(fullname) -> source string.
@ -228,9 +228,9 @@ class zipimporter:
path = _get_module_path(self, fullname)
if mi:
fullpath = _bootstrap_external._path_join(path, "__init__.py")
fullpath = _bootstrap_external._path_join(path, '__init__.py')
else:
fullpath = f"{path}.py"
fullpath = f'{path}.py'
try:
toc_entry = self._files[fullpath]
@ -239,6 +239,7 @@ class zipimporter:
return None
return _get_data(self.archive, toc_entry).decode()
# Return a bool signifying whether the module is a package or not.
def is_package(self, fullname):
"""is_package(fullname) -> bool.
@ -251,6 +252,7 @@ class zipimporter:
raise ZipImportError(f"can't find module {fullname!r}", name=fullname)
return mi
# Load and return the module named by 'fullname'.
def load_module(self, fullname):
"""load_module(fullname) -> module.
@ -274,7 +276,7 @@ class zipimporter:
fullpath = _bootstrap_external._path_join(self.archive, path)
mod.__path__ = [fullpath]
if not hasattr(mod, "__builtins__"):
if not hasattr(mod, '__builtins__'):
mod.__builtins__ = __builtins__
_bootstrap_external._fix_up_module(mod.__dict__, fullname, modpath)
exec(code, mod.__dict__)
@ -285,10 +287,11 @@ class zipimporter:
try:
mod = sys.modules[fullname]
except KeyError:
raise ImportError(f"Loaded module {fullname!r} not found in sys.modules")
_bootstrap._verbose_message("import {} # loaded from Zip {}", fullname, modpath)
raise ImportError(f'Loaded module {fullname!r} not found in sys.modules')
_bootstrap._verbose_message('import {} # loaded from Zip {}', fullname, modpath)
return mod
def get_resource_reader(self, fullname):
"""Return the ResourceReader for a package in a zip file.
@ -302,11 +305,11 @@ class zipimporter:
return None
if not _ZipImportResourceReader._registered:
from importlib.abc import ResourceReader
ResourceReader.register(_ZipImportResourceReader)
_ZipImportResourceReader._registered = True
return _ZipImportResourceReader(self, fullname)
def __repr__(self):
return f'<zipimporter object "{self.archive}{path_sep}{self.prefix}">'
@ -317,18 +320,16 @@ class zipimporter:
# are swapped by initzipimport() if we run in optimized mode. Also,
# '/' is replaced by path_sep there.
_zip_searchorder = (
(path_sep + "__init__.pyc", True, True),
(path_sep + "__init__.py", False, True),
(".pyc", True, False),
(".py", False, False),
(path_sep + '__init__.pyc', True, True),
(path_sep + '__init__.py', False, True),
('.pyc', True, False),
('.py', False, False),
)
# Given a module name, return the potential file path in the
# archive (without extension).
def _get_module_path(self, fullname):
return self.prefix + fullname.rpartition(".")[2]
return self.prefix + fullname.rpartition('.')[2]
# Does this path represent a directory?
def _is_dir(self, path):
@ -339,7 +340,6 @@ def _is_dir(self, path):
# If dirpath is present in self._files, we have a directory.
return dirpath in self._files
# Return some information about a module.
def _get_module_info(self, fullname):
path = _get_module_path(self, fullname)
@ -352,7 +352,6 @@ def _get_module_info(self, fullname):
# implementation
# _read_directory(archive) -> files dict (new reference)
#
# Given a path to a Zip archive, build a dict, mapping file names
@ -375,7 +374,7 @@ def _get_module_info(self, fullname):
# data_size and file_offset are 0.
def _read_directory(archive):
try:
fp = _io.open(archive, "rb")
fp = _io.open(archive, 'rb')
except OSError:
raise ZipImportError(f"can't open Zip file: {archive!r}", path=archive)
@ -395,33 +394,36 @@ def _read_directory(archive):
fp.seek(0, 2)
file_size = fp.tell()
except OSError:
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
max_comment_start = max(file_size - MAX_COMMENT_LEN - END_CENTRAL_DIR_SIZE, 0)
raise ZipImportError(f"can't read Zip file: {archive!r}",
path=archive)
max_comment_start = max(file_size - MAX_COMMENT_LEN -
END_CENTRAL_DIR_SIZE, 0)
try:
fp.seek(max_comment_start)
data = fp.read()
except OSError:
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
raise ZipImportError(f"can't read Zip file: {archive!r}",
path=archive)
pos = data.rfind(STRING_END_ARCHIVE)
if pos < 0:
raise ZipImportError(f"not a Zip file: {archive!r}", path=archive)
buffer = data[pos : pos + END_CENTRAL_DIR_SIZE]
raise ZipImportError(f'not a Zip file: {archive!r}',
path=archive)
buffer = data[pos:pos+END_CENTRAL_DIR_SIZE]
if len(buffer) != END_CENTRAL_DIR_SIZE:
raise ZipImportError(f"corrupt Zip file: {archive!r}", path=archive)
raise ZipImportError(f"corrupt Zip file: {archive!r}",
path=archive)
header_position = file_size - len(data) + pos
header_size = _unpack_uint32(buffer[12:16])
header_offset = _unpack_uint32(buffer[16:20])
if header_position < header_size:
raise ZipImportError(f"bad central directory size: {archive!r}", path=archive)
raise ZipImportError(f'bad central directory size: {archive!r}', path=archive)
if header_position < header_offset:
raise ZipImportError(f"bad central directory offset: {archive!r}", path=archive)
raise ZipImportError(f'bad central directory offset: {archive!r}', path=archive)
header_position -= header_size
arc_offset = header_position - header_offset
if arc_offset < 0:
raise ZipImportError(
f"bad central directory size or offset: {archive!r}", path=archive
)
raise ZipImportError(f'bad central directory size or offset: {archive!r}', path=archive)
files = {}
# Start of Central Directory
@ -433,12 +435,12 @@ def _read_directory(archive):
while True:
buffer = fp.read(46)
if len(buffer) < 4:
raise EOFError("EOF read where not expected")
raise EOFError('EOF read where not expected')
# Start of file header
if buffer[:4] != b"PK\x01\x02":
if buffer[:4] != b'PK\x01\x02':
break # Bad: Central Dir File Header
if len(buffer) != 46:
raise EOFError("EOF read where not expected")
raise EOFError('EOF read where not expected')
flags = _unpack_uint16(buffer[8:10])
compress = _unpack_uint16(buffer[10:12])
time = _unpack_uint16(buffer[12:14])
@ -452,7 +454,7 @@ def _read_directory(archive):
file_offset = _unpack_uint32(buffer[42:46])
header_size = name_size + extra_size + comment_size
if file_offset > header_offset:
raise ZipImportError(f"bad local header offset: {archive!r}", path=archive)
raise ZipImportError(f'bad local header offset: {archive!r}', path=archive)
file_offset += arc_offset
try:
@ -476,19 +478,18 @@ def _read_directory(archive):
else:
# Historical ZIP filename encoding
try:
name = name.decode("ascii")
name = name.decode('ascii')
except UnicodeDecodeError:
name = name.decode("latin1").translate(cp437_table)
name = name.decode('latin1').translate(cp437_table)
name = name.replace("/", path_sep)
name = name.replace('/', path_sep)
path = _bootstrap_external._path_join(archive, name)
t = (path, compress, data_size, file_size, file_offset, time, date, crc)
files[name] = t
count += 1
_bootstrap._verbose_message("zipimport: found {} names in {!r}", count, archive)
_bootstrap._verbose_message('zipimport: found {} names in {!r}', count, archive)
return files
# During bootstrap, we may need to load the encodings
# package from a ZIP file. But the cp437 encoding is implemented
# in Python in the encodings package.
@ -497,36 +498,35 @@ def _read_directory(archive):
# the cp437 encoding.
cp437_table = (
# ASCII part, 8 rows x 16 chars
"\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f"
"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"
" !\"#$%&'()*+,-./"
"0123456789:;<=>?"
"@ABCDEFGHIJKLMNO"
"PQRSTUVWXYZ[\\]^_"
"`abcdefghijklmno"
"pqrstuvwxyz{|}~\x7f"
'\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f'
'\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f'
' !"#$%&\'()*+,-./'
'0123456789:;<=>?'
'@ABCDEFGHIJKLMNO'
'PQRSTUVWXYZ[\\]^_'
'`abcdefghijklmno'
'pqrstuvwxyz{|}~\x7f'
# non-ASCII part, 16 rows x 8 chars
"\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7"
"\xea\xeb\xe8\xef\xee\xec\xc4\xc5"
"\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9"
"\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192"
"\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba"
"\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb"
"\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556"
"\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510"
"\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f"
"\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567"
"\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b"
"\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580"
"\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4"
"\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229"
"\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248"
"\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0"
'\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7'
'\xea\xeb\xe8\xef\xee\xec\xc4\xc5'
'\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9'
'\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192'
'\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba'
'\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb'
'\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556'
'\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510'
'\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f'
'\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567'
'\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b'
'\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580'
'\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4'
'\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229'
'\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248'
'\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0'
)
_importing_zlib = False
# Return the zlib.decompress function object, or NULL if zlib couldn't
# be imported. The function is cached when found, so subsequent calls
# don't import zlib again.
@ -535,29 +535,28 @@ def _get_decompress_func():
if _importing_zlib:
# Someone has a zlib.py[co] in their Zip file
# let's avoid a stack overflow.
_bootstrap._verbose_message("zipimport: zlib UNAVAILABLE")
_bootstrap._verbose_message('zipimport: zlib UNAVAILABLE')
raise ZipImportError("can't decompress data; zlib not available")
_importing_zlib = True
try:
from zlib import decompress
except Exception:
_bootstrap._verbose_message("zipimport: zlib UNAVAILABLE")
_bootstrap._verbose_message('zipimport: zlib UNAVAILABLE')
raise ZipImportError("can't decompress data; zlib not available")
finally:
_importing_zlib = False
_bootstrap._verbose_message("zipimport: zlib available")
_bootstrap._verbose_message('zipimport: zlib available')
return decompress
# Given a path to a Zip file and a toc_entry, return the (uncompressed) data.
def _get_data(archive, toc_entry):
datapath, compress, data_size, file_size, file_offset, time, date, crc = toc_entry
if data_size < 0:
raise ZipImportError("negative data size")
raise ZipImportError('negative data size')
with _io.open(archive, "rb") as fp:
with _io.open(archive, 'rb') as fp:
# Check to make sure the local file header is correct
try:
fp.seek(file_offset)
@ -565,11 +564,11 @@ def _get_data(archive, toc_entry):
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
buffer = fp.read(30)
if len(buffer) != 30:
raise EOFError("EOF read where not expected")
raise EOFError('EOF read where not expected')
if buffer[:4] != b"PK\x03\x04":
if buffer[:4] != b'PK\x03\x04':
# Bad: Local File Header
raise ZipImportError(f"bad local file header: {archive!r}", path=archive)
raise ZipImportError(f'bad local file header: {archive!r}', path=archive)
name_size = _unpack_uint16(buffer[26:28])
extra_size = _unpack_uint16(buffer[28:30])
@ -602,17 +601,16 @@ def _eq_mtime(t1, t2):
# dostime only stores even seconds, so be lenient
return abs(t1 - t2) <= 1
# Given the contents of a .py[co] file, unmarshal the data
# and return the code object. Return None if it the magic word doesn't
# match (we do this instead of raising an exception as we fall back
# to .py if available and we don't want to mask other errors).
def _unmarshal_code(pathname, data, mtime):
if len(data) < 16:
raise ZipImportError("bad pyc data")
raise ZipImportError('bad pyc data')
if data[:4] != _bootstrap_external.MAGIC_NUMBER:
_bootstrap._verbose_message("{!r} has bad magic", pathname)
_bootstrap._verbose_message('{!r} has bad magic', pathname)
return None # signal caller to try alternative
flags = _unpack_uint32(data[4:8])
@ -621,57 +619,47 @@ def _unmarshal_code(pathname, data, mtime):
# pycs. We could validate hash-based pycs against the source, but it
# seems likely that most people putting hash-based pycs in a zipfile
# will use unchecked ones.
if _imp.check_hash_based_pycs != "never" and (
flags != 0x1 or _imp.check_hash_based_pycs == "always"
):
if (_imp.check_hash_based_pycs != 'never' and
(flags != 0x1 or _imp.check_hash_based_pycs == 'always')):
return None
elif mtime != 0 and not _eq_mtime(_unpack_uint32(data[8:12]), mtime):
_bootstrap._verbose_message("{!r} has bad mtime", pathname)
_bootstrap._verbose_message('{!r} has bad mtime', pathname)
return None # signal caller to try alternative
# XXX the pyc's size field is ignored; timestamp collisions are probably
# unimportant with zip files.
code = marshal.loads(data[16:])
if not isinstance(code, _code_type):
raise TypeError(f"compiled module {pathname!r} is not a code object")
raise TypeError(f'compiled module {pathname!r} is not a code object')
return code
_code_type = type(_unmarshal_code.__code__)
# Replace any occurrences of '\r\n?' in the input string with '\n'.
# This converts DOS and Mac line endings to Unix line endings.
def _normalize_line_endings(source):
source = source.replace(b"\r\n", b"\n")
source = source.replace(b"\r", b"\n")
source = source.replace(b'\r\n', b'\n')
source = source.replace(b'\r', b'\n')
return source
# Given a string buffer containing Python source code, compile it
# and return a code object.
def _compile_source(pathname, source):
source = _normalize_line_endings(source)
return compile(source, pathname, "exec", dont_inherit=True)
return compile(source, pathname, 'exec', dont_inherit=True)
# Convert the date/time values found in the Zip archive to a value
# that's compatible with the time stamp stored in .pyc files.
def _parse_dostime(d, t):
return time.mktime(
(
return time.mktime((
(d >> 9) + 1980, # bits 9..15: year
(d >> 5) & 0xF, # bits 5..8: month
d & 0x1F, # bits 0..4: day
t >> 11, # bits 11..15: hours
(t >> 5) & 0x3F, # bits 8..10: minutes
(t & 0x1F) * 2, # bits 0..7: seconds / 2
-1,
-1,
-1,
)
)
-1, -1, -1))
# Given a path to a .pyc file in the archive, return the
# modification time of the matching .py file, or 0 if no source
@ -679,7 +667,7 @@ def _parse_dostime(d, t):
def _get_mtime_of_source(self, path):
try:
# strip 'c' or 'o' from *.py[co]
assert path[-1:] in ("c", "o")
assert path[-1:] in ('c', 'o')
path = path[:-1]
toc_entry = self._files[path]
# fetch the time stamp of the .py file for comparison
@ -690,14 +678,13 @@ def _get_mtime_of_source(self, path):
except (KeyError, IndexError, TypeError):
return 0
# Get the code object associated with the module specified by
# 'fullname'.
def _get_module_code(self, fullname):
path = _get_module_path(self, fullname)
for suffix, isbytecode, ispackage in _zip_searchorder:
fullpath = path + suffix
_bootstrap._verbose_message("trying {}{}{}", self.archive, path_sep, fullpath, verbosity=2)
_bootstrap._verbose_message('trying {}{}{}', self.archive, path_sep, fullpath, verbosity=2)
try:
toc_entry = self._files[fullpath]
except KeyError:
@ -726,7 +713,6 @@ class _ZipImportResourceReader:
This class is allowed to reference all the innards and private parts of
the zipimporter.
"""
_registered = False
def __init__(self, zipimporter, fullname):
@ -734,10 +720,9 @@ class _ZipImportResourceReader:
self.fullname = fullname
def open_resource(self, resource):
fullname_as_path = self.fullname.replace(".", "/")
path = f"{fullname_as_path}/{resource}"
fullname_as_path = self.fullname.replace('.', '/')
path = f'{fullname_as_path}/{resource}'
from io import BytesIO
try:
return BytesIO(self.zipimporter.get_data(path))
except OSError:
@ -752,8 +737,8 @@ class _ZipImportResourceReader:
def is_resource(self, name):
# Maybe we could do better, but if we can get the data, it's a
# resource. Otherwise it isn't.
fullname_as_path = self.fullname.replace(".", "/")
path = f"{fullname_as_path}/{name}"
fullname_as_path = self.fullname.replace('.', '/')
path = f'{fullname_as_path}/{name}'
try:
self.zipimporter.get_data(path)
except OSError:
@ -769,12 +754,11 @@ class _ZipImportResourceReader:
# top of the archive, and then we iterate through _files looking for
# names inside that "directory".
from pathlib import Path
fullname_path = Path(self.zipimporter.get_filename(self.fullname))
relative_path = fullname_path.relative_to(self.zipimporter.archive)
# Don't forget that fullname names a package, so its path will include
# __init__.py, which we want to ignore.
assert relative_path.name == "__init__.py"
assert relative_path.name == '__init__.py'
package_path = relative_path.parent
subdirs_seen = set()
for filename in self.zipimporter._files:

View File

@ -1,3 +1,2 @@
from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader
from .meta import DatabaseType, PluginMeta
from .zip import MaubotZipImportError, ZippedPluginLoader
from .abc import PluginLoader, PluginClass, IDConflictError, PluginMeta
from .zip import ZippedPluginLoader, MaubotZipImportError

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,14 +13,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar
from typing import TypeVar, Type, Dict, Set, List, TYPE_CHECKING
from abc import ABC, abstractmethod
import asyncio
from attr import dataclass
from packaging.version import Version, InvalidVersion
from mautrix.client.api.types.util import (SerializableAttrs, SerializerError, serializer,
deserializer)
from ..__meta__ import __version__
from ..plugin_base import Plugin
from .meta import PluginMeta
if TYPE_CHECKING:
from ..instance import PluginInstance
@ -32,40 +35,45 @@ class IDConflictError(Exception):
pass
class BasePluginLoader(ABC):
meta: PluginMeta
@property
@abstractmethod
def source(self) -> str:
pass
def sync_read_file(self, path: str) -> bytes:
raise NotImplementedError("This loader doesn't support synchronous operations")
@abstractmethod
async def read_file(self, path: str) -> bytes:
pass
def sync_list_files(self, directory: str) -> list[str]:
raise NotImplementedError("This loader doesn't support synchronous operations")
@abstractmethod
async def list_files(self, directory: str) -> list[str]:
pass
@serializer(Version)
def serialize_version(version: Version) -> str:
return str(version)
class PluginLoader(BasePluginLoader, ABC):
id_cache: dict[str, PluginLoader] = {}
@deserializer(Version)
def deserialize_version(version: str) -> Version:
try:
return Version(version)
except InvalidVersion as e:
raise SerializerError("Invalid version") from e
@dataclass
class PluginMeta(SerializableAttrs['PluginMeta']):
id: str
version: Version
modules: List[str]
main_class: str
maubot: Version = Version(__version__)
database: bool = False
license: str = ""
extra_files: List[str] = []
dependencies: List[str] = []
soft_dependencies: List[str] = []
class PluginLoader(ABC):
id_cache: Dict[str, 'PluginLoader'] = {}
meta: PluginMeta
references: set[PluginInstance]
references: Set['PluginInstance']
def __init__(self):
self.references = set()
@classmethod
def find(cls, plugin_id: str) -> PluginLoader:
def find(cls, plugin_id: str) -> 'PluginLoader':
return cls.id_cache[plugin_id]
def to_dict(self) -> dict:
@ -75,22 +83,33 @@ class PluginLoader(BasePluginLoader, ABC):
"instances": [instance.to_dict() for instance in self.references],
}
async def stop_instances(self) -> None:
await asyncio.gather(
*[instance.stop() for instance in self.references if instance.started]
)
async def start_instances(self) -> None:
await asyncio.gather(
*[instance.start() for instance in self.references if instance.enabled]
)
@property
@abstractmethod
async def load(self) -> type[PluginClass]:
def source(self) -> str:
pass
@abstractmethod
async def reload(self) -> type[PluginClass]:
async def read_file(self, path: str) -> bytes:
pass
async def stop_instances(self) -> None:
await asyncio.gather(*[instance.stop() for instance
in self.references if instance.started])
async def start_instances(self) -> None:
await asyncio.gather(*[instance.start() for instance
in self.references if instance.enabled])
@abstractmethod
async def load(self) -> Type[PluginClass]:
pass
@abstractmethod
async def reload(self) -> Type[PluginClass]:
pass
@abstractmethod
async def unload(self) -> None:
pass
@abstractmethod

View File

@ -1,69 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Optional
from attr import dataclass
from packaging.version import InvalidVersion, Version
from mautrix.types import (
ExtensibleEnum,
SerializableAttrs,
SerializerError,
deserializer,
serializer,
)
from ..__meta__ import __version__
@serializer(Version)
def serialize_version(version: Version) -> str:
return str(version)
@deserializer(Version)
def deserialize_version(version: str) -> Version:
try:
return Version(version)
except InvalidVersion as e:
raise SerializerError("Invalid version") from e
class DatabaseType(ExtensibleEnum):
SQLALCHEMY = "sqlalchemy"
ASYNCPG = "asyncpg"
@dataclass
class PluginMeta(SerializableAttrs):
id: str
version: Version
modules: List[str]
main_class: str
maubot: Version = Version(__version__)
database: bool = False
database_type: DatabaseType = DatabaseType.SQLALCHEMY
config: bool = False
webapp: bool = False
license: str = ""
extra_files: List[str] = []
dependencies: List[str] = []
soft_dependencies: List[str] = []
@property
def database_type_str(self) -> Optional[str]:
return self.database_type.value if self.database else None

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,27 +13,22 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Dict, List, Type, Tuple, Optional
from zipfile import ZipFile, BadZipFile
from time import time
from zipfile import BadZipFile, ZipFile
import logging
import os
import sys
import os
from packaging.version import Version
from ruamel.yaml import YAML, YAMLError
from packaging.version import Version
from mautrix.client.api.types.util import SerializerError
from mautrix.types import SerializerError
from ..__meta__ import __version__
from ..config import Config
from ..lib.zipimport import ZipImportError, zipimporter
from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin
from .abc import IDConflictError, PluginClass, PluginLoader
from .meta import DatabaseType, PluginMeta
from ..config import Config
from .abc import PluginLoader, PluginClass, PluginMeta, IDConflictError
current_version = Version(__version__)
yaml = YAML()
@ -54,25 +49,23 @@ class MaubotZipLoadError(MaubotZipImportError):
class ZippedPluginLoader(PluginLoader):
path_cache: dict[str, ZippedPluginLoader] = {}
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
log: logging.Logger = logging.getLogger("maubot.loader.zip")
trash_path: str = "delete"
directories: list[str] = []
directories: List[str] = []
path: str | None
meta: PluginMeta | None
main_class: str | None
main_module: str | None
_loaded: type[PluginClass] | None
_importer: zipimporter | None
_file: ZipFile | None
path: str
meta: PluginMeta
main_class: str
main_module: str
_loaded: Type[PluginClass]
_importer: zipimporter
_file: ZipFile
def __init__(self, path: str) -> None:
super().__init__()
self.path = path
self.meta = None
self.main_class = None
self.main_module = None
self._loaded = None
self._importer = None
self._file = None
@ -81,8 +74,7 @@ class ZippedPluginLoader(PluginLoader):
try:
existing = self.id_cache[self.meta.id]
raise IDConflictError(
f"Plugin with id {self.meta.id} already loaded from {existing.source}"
)
f"Plugin with id {self.meta.id} already loaded from {existing.source}")
except KeyError:
pass
self.path_cache[self.path] = self
@ -90,10 +82,13 @@ class ZippedPluginLoader(PluginLoader):
self.log.debug(f"Preloaded plugin {self.meta.id} from {self.path}")
def to_dict(self) -> dict:
return {**super().to_dict(), "path": self.path}
return {
**super().to_dict(),
"path": self.path
}
@classmethod
def get(cls, path: str) -> ZippedPluginLoader:
def get(cls, path: str) -> 'ZippedPluginLoader':
path = os.path.abspath(path)
try:
return cls.path_cache[path]
@ -105,32 +100,16 @@ class ZippedPluginLoader(PluginLoader):
return self.path
def __repr__(self) -> str:
return (
"<ZippedPlugin "
return ("<ZippedPlugin "
f"path='{self.path}' "
f"meta={self.meta} "
f"loaded={self._loaded is not None}>"
)
def sync_read_file(self, path: str) -> bytes:
return self._file.read(path)
f"loaded={self._loaded is not None}>")
async def read_file(self, path: str) -> bytes:
return self.sync_read_file(path)
def sync_list_files(self, directory: str) -> list[str]:
directory = directory.rstrip("/")
return [
file.filename
for file in self._file.filelist
if os.path.dirname(file.filename) == directory
]
async def list_files(self, directory: str) -> list[str]:
return self.sync_list_files(directory)
return self._file.read(path)
@staticmethod
def _read_meta(source) -> tuple[ZipFile, PluginMeta]:
def _read_meta(source) -> Tuple[ZipFile, PluginMeta]:
try:
file = ZipFile(source)
data = file.read("maubot.yaml")
@ -148,16 +127,12 @@ class ZippedPluginLoader(PluginLoader):
meta = PluginMeta.deserialize(meta_dict)
except SerializerError as e:
raise MaubotZipMetaError("Maubot plugin definition in file is invalid") from e
if meta.maubot > current_version:
raise MaubotZipMetaError(
f"Plugin requires maubot {meta.maubot}, but this instance is {current_version}"
)
return file, meta
@classmethod
def verify_meta(cls, source) -> tuple[str, Version, DatabaseType | None]:
def verify_meta(cls, source) -> Tuple[str, Version]:
_, meta = cls._read_meta(source)
return meta.id, meta.version, meta.database_type if meta.database else None
return meta.id, meta.version
def _load_meta(self) -> None:
file, meta = self._read_meta(self.path)
@ -167,7 +142,7 @@ class ZippedPluginLoader(PluginLoader):
if "/" in meta.main_class:
self.main_module, self.main_class = meta.main_class.split("/")[:2]
else:
self.main_module = meta.modules[-1]
self.main_module = meta.modules[0]
self.main_class = meta.main_class
self._file = file
@ -186,24 +161,24 @@ class ZippedPluginLoader(PluginLoader):
code = importer.get_code(self.main_module.replace(".", "/"))
if self.main_class not in code.co_names:
raise MaubotZipPreLoadError(
f"Main class {self.main_class} not in {self.main_module}"
)
f"Main class {self.main_class} not in {self.main_module}")
except ZipImportError as e:
raise MaubotZipPreLoadError(f"Main module {self.main_module} not found in file") from e
raise MaubotZipPreLoadError(
f"Main module {self.main_module} not found in file") from e
for module in self.meta.modules:
try:
importer.find_module(module)
except ZipImportError as e:
raise MaubotZipPreLoadError(f"Module {module} not found in file") from e
async def load(self, reset_cache: bool = False) -> type[PluginClass]:
async def load(self, reset_cache: bool = False) -> Type[PluginClass]:
try:
return self._load(reset_cache)
except MaubotZipImportError:
self.log.exception(f"Failed to load {self.meta.id} v{self.meta.version}")
raise
def _load(self, reset_cache: bool = False) -> type[PluginClass]:
def _load(self, reset_cache: bool = False) -> Type[PluginClass]:
if self._loaded is not None and not reset_cache:
return self._loaded
self._load_meta()
@ -232,26 +207,21 @@ class ZippedPluginLoader(PluginLoader):
self.log.debug(f"Loaded and imported plugin {self.meta.id} from {self.path}")
return plugin
async def reload(self, new_path: str | None = None) -> type[PluginClass]:
self._unload()
if new_path is not None and new_path != self.path:
try:
del self.path_cache[self.path]
except KeyError:
pass
async def reload(self, new_path: Optional[str] = None) -> Type[PluginClass]:
await self.unload()
if new_path is not None:
self.path = new_path
self.path_cache[self.path] = self
return await self.load(reset_cache=True)
def _unload(self) -> None:
async def unload(self) -> None:
for name, mod in list(sys.modules.items()):
if (getattr(mod, "__file__", "") or "").startswith(self.path):
if getattr(mod, "__file__", "").startswith(self.path):
del sys.modules[name]
self._loaded = None
self.log.debug(f"Unloaded plugin {self.meta.id} at {self.path}")
async def delete(self) -> None:
self._unload()
await self.unload()
try:
del self.path_cache[self.path]
except KeyError:
@ -269,22 +239,12 @@ class ZippedPluginLoader(PluginLoader):
self.path = None
@classmethod
def trash(cls, file_path: str, new_name: str | None = None, reason: str = "error") -> None:
def trash(cls, file_path: str, new_name: Optional[str] = None, reason: str = "error") -> None:
if cls.trash_path == "delete":
try:
os.remove(file_path)
except FileNotFoundError:
pass
else:
new_name = new_name or f"{int(time())}-{reason}-{os.path.basename(file_path)}"
try:
os.rename(file_path, os.path.abspath(os.path.join(cls.trash_path, new_name)))
except OSError as e:
cls.log.warning(f"Failed to rename {file_path}: {e} - trying to delete")
try:
os.remove(file_path)
except FileNotFoundError:
pass
@classmethod
def load_all(cls):

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,14 +13,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from asyncio import AbstractEventLoop
import importlib
from aiohttp import web
from ...config import Config
from .base import routes, get_config, set_config, set_loop
from .auth import check_token
from .base import get_config, routes, set_config
from .middleware import auth, error
@ -31,15 +30,14 @@ def features(request: web.Request) -> web.Response:
if err is None:
return web.json_response(data)
else:
return web.json_response(
{
return web.json_response({
"login": data["login"],
}
)
})
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
set_config(cfg)
set_loop(loop)
for pkg, enabled in cfg["api_features"].items():
if enabled:
importlib.import_module(f"maubot.management.api.{pkg}")

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,8 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Optional
from time import time
from aiohttp import web
@ -22,7 +21,7 @@ from aiohttp import web
from mautrix.types import UserID
from mautrix.util.signed_token import sign_token, verify_token
from .base import get_config, routes
from .base import routes, get_config
from .responses import resp
@ -34,25 +33,22 @@ def is_valid_token(token: str) -> bool:
def create_token(user: UserID) -> str:
return sign_token(
get_config()["server.unshared_secret"],
{
return sign_token(get_config()["server.unshared_secret"], {
"user_id": user,
"created_at": int(time()),
},
)
})
def get_token(request: web.Request) -> str:
token = request.headers.get("Authorization", "")
if not token or not token.startswith("Bearer "):
token = request.query.get("access_token", "")
token = request.query.get("access_token", None)
else:
token = token[len("Bearer ") :]
token = token[len("Bearer "):]
return token
def check_token(request: web.Request) -> web.Response | None:
def check_token(request: web.Request) -> Optional[web.Response]:
token = get_token(request)
if not token:
return resp.no_token

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,17 +13,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
import asyncio
from aiohttp import web
import asyncio
from ...__meta__ import __version__
from ...config import Config
routes: web.RouteTableDef = web.RouteTableDef()
_config: Config | None = None
_config: Config = None
_loop: asyncio.AbstractEventLoop = None
def set_config(config: Config) -> None:
@ -35,6 +33,17 @@ def get_config() -> Config:
return _config
def set_loop(loop: asyncio.AbstractEventLoop) -> None:
global _loop
_loop = loop
def get_loop() -> asyncio.AbstractEventLoop:
return _loop
@routes.get("/version")
async def version(_: web.Request) -> web.Response:
return web.json_response({"version": __version__})
return web.json_response({
"version": __version__
})

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,23 +13,20 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Optional
from json import JSONDecodeError
import logging
from aiohttp import web
from mautrix.types import UserID, SyncToken, FilterID
from mautrix.errors import MatrixRequestError, MatrixConnectionError, MatrixInvalidToken
from mautrix.client import Client as MatrixClient
from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequestError
from mautrix.types import FilterID, SyncToken, UserID
from ...db import DBClient
from ...client import Client
from .base import routes
from .responses import resp
log = logging.getLogger("maubot.server.client")
@routes.get("/clients")
async def get_clients(_: web.Request) -> web.Response:
@ -39,94 +36,63 @@ async def get_clients(_: web.Request) -> web.Response:
@routes.get("/client/{id}")
async def get_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = await Client.get(user_id)
client = Client.get(user_id, None)
if not client:
return resp.client_not_found
return resp.found(client.to_dict())
async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
homeserver = data.get("homeserver", None)
access_token = data.get("access_token", None)
device_id = data.get("device_id", None)
new_client = MatrixClient(
mxid="@not:a.mxid",
base_url=homeserver,
token=access_token,
client_session=Client.http_client,
)
new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token,
loop=Client.loop, client_session=Client.http_client)
try:
whoami = await new_client.whoami()
except MatrixInvalidToken as e:
mxid = await new_client.whoami()
except MatrixInvalidToken:
return resp.bad_client_access_token
except MatrixRequestError:
log.warning(f"Failed to get whoami from {homeserver} for new client", exc_info=True)
return resp.bad_client_access_details
except MatrixConnectionError:
log.warning(f"Failed to connect to {homeserver} for new client", exc_info=True)
return resp.bad_client_connection_details
if user_id is None:
existing_client = await Client.get(whoami.user_id)
existing_client = Client.get(mxid, None)
if existing_client is not None:
return resp.user_exists
elif whoami.user_id != user_id:
return resp.mxid_mismatch(whoami.user_id)
elif whoami.device_id and device_id and whoami.device_id != device_id:
return resp.device_id_mismatch(whoami.device_id)
client = await Client.get(
whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id
)
client.enabled = data.get("enabled", True)
client.sync = data.get("sync", True)
client.autojoin = data.get("autojoin", True)
client.online = data.get("online", True)
client.displayname = data.get("displayname", "disable")
client.avatar_url = data.get("avatar_url", "disable")
await client.update()
elif mxid != user_id:
return resp.mxid_mismatch(mxid)
db_instance = DBClient(id=mxid, homeserver=homeserver, access_token=access_token,
enabled=data.get("enabled", True), next_batch=SyncToken(""),
filter_id=FilterID(""), sync=data.get("sync", True),
autojoin=data.get("autojoin", True),
displayname=data.get("displayname", ""),
avatar_url=data.get("avatar_url", ""))
client = Client(db_instance)
Client.db.add(db_instance)
Client.db.commit()
await client.start()
return resp.created(client.to_dict())
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
async def _update_client(client: Client, data: dict) -> web.Response:
try:
await client.update_access_details(
data.get("access_token"), data.get("homeserver"), data.get("device_id")
)
await client.update_access_details(data.get("access_token", None),
data.get("homeserver", None))
except MatrixInvalidToken:
return resp.bad_client_access_token
except MatrixRequestError:
log.warning(
f"Failed to get whoami from homeserver to update client details", exc_info=True
)
return resp.bad_client_access_details
except MatrixConnectionError:
log.warning(f"Failed to connect to homeserver to update client details", exc_info=True)
return resp.bad_client_connection_details
except ValueError as e:
str_err = str(e)
if str_err.startswith("MXID mismatch"):
return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :])
elif str_err.startswith("Device ID mismatch"):
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :])
await client.update_avatar_url(data.get("avatar_url"), save=False)
await client.update_displayname(data.get("displayname"), save=False)
await client.update_started(data.get("started"))
await client.update_enabled(data.get("enabled"), save=False)
await client.update_autojoin(data.get("autojoin"), save=False)
await client.update_online(data.get("online"), save=False)
await client.update_sync(data.get("sync"), save=False)
await client.update()
return resp.updated(client.to_dict(), is_login=is_login)
async def _create_or_update_client(
user_id: UserID, data: dict, is_login: bool = False
) -> web.Response:
client = await Client.get(user_id)
if not client:
return await _create_client(user_id, data)
else:
return await _update_client(client, data, is_login=is_login)
return resp.mxid_mismatch(str(e)[len("MXID mismatch: "):])
await client.update_avatar_url(data.get("avatar_url", None))
await client.update_displayname(data.get("displayname", None))
await client.update_started(data.get("started", None))
client.enabled = data.get("enabled", client.enabled)
client.autojoin = data.get("autojoin", client.autojoin)
client.sync = data.get("sync", client.sync)
return resp.updated(client.to_dict())
@routes.post("/client/new")
@ -140,33 +106,27 @@ async def create_client(request: web.Request) -> web.Response:
@routes.put("/client/{id}")
async def update_client(request: web.Request) -> web.Response:
user_id = request.match_info["id"]
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
try:
data = await request.json()
except JSONDecodeError:
return resp.body_not_json
return await _create_or_update_client(user_id, data)
if not client:
return await _create_client(user_id, data)
else:
return await _update_client(client, data)
@routes.delete("/client/{id}")
async def delete_client(request: web.Request) -> web.Response:
user_id = request.match_info["id"]
client = await Client.get(user_id)
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
if not client:
return resp.client_not_found
if len(client.references) > 0:
return resp.client_in_use
if client.started:
await client.stop()
await client.delete()
client.delete()
return resp.deleted
@routes.post("/client/{id}/clearcache")
async def clear_client_cache(request: web.Request) -> web.Response:
user_id = request.match_info["id"]
client = await Client.get(user_id)
if not client:
return resp.client_not_found
await client.clear_cache()
return resp.ok

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,261 +13,109 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import NamedTuple
from http import HTTPStatus
from typing import Dict, Tuple, NamedTuple, Optional
from json import JSONDecodeError
import asyncio
import hashlib
import hmac
import random
import string
import hashlib
from aiohttp import web
from yarl import URL
from mautrix.api import Method, Path, SynapseAdminPath
from mautrix.client import ClientAPI
from mautrix.api import HTTPAPI, Path, Method
from mautrix.errors import MatrixRequestError
from mautrix.types import LoginResponse, LoginType
from .base import get_config, routes
from .client import _create_client, _create_or_update_client
from .base import routes, get_config, get_loop
from .responses import resp
def known_homeservers() -> dict[str, dict[str, str]]:
return get_config()["homeservers"]
def registration_secrets() -> Dict[str, Dict[str, str]]:
return get_config()["registration_secrets"]
def generate_mac(secret: str, nonce: str, user: str, password: str, admin: bool = False):
mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
mac.update(nonce.encode("utf-8"))
mac.update(b"\x00")
mac.update(user.encode("utf-8"))
mac.update(b"\x00")
mac.update(password.encode("utf-8"))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
return mac.hexdigest()
@routes.get("/client/auth/servers")
async def get_known_servers(_: web.Request) -> web.Response:
return web.json_response({key: value["url"] for key, value in known_homeservers().items()})
async def get_registerable_servers(_: web.Request) -> web.Response:
return web.json_response(list(registration_secrets().keys()))
class AuthRequestInfo(NamedTuple):
server_name: str
client: ClientAPI
secret: str
username: str
password: str
user_type: str
device_name: str
update_client: bool
sso: bool
AuthRequestInfo = NamedTuple("AuthRequestInfo", api=HTTPAPI, secret=str, username=str, password=str)
truthy_strings = ("1", "true", "yes")
async def read_client_auth_request(
request: web.Request,
) -> tuple[AuthRequestInfo | None, web.Response | None]:
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
Optional[web.Response]]:
server_name = request.match_info.get("server", None)
server = known_homeservers().get(server_name, None)
server = registration_secrets().get(server_name, None)
if not server:
return None, resp.server_not_found
try:
body = await request.json()
except JSONDecodeError:
return None, resp.body_not_json
sso = request.query.get("sso", "").lower() in truthy_strings
try:
username = body["username"]
password = body["password"]
except KeyError:
if not sso:
return None, resp.username_or_password_missing
username = password = None
try:
base_url = server["url"]
secret = server["secret"]
except KeyError:
return None, resp.invalid_server
return (
AuthRequestInfo(
server_name=server_name,
client=ClientAPI(base_url=base_url),
secret=server.get("secret"),
username=username,
password=password,
user_type=body.get("user_type", "bot"),
device_name=body.get("device_name", "Maubot"),
update_client=request.query.get("update_client", "").lower() in truthy_strings,
sso=sso,
),
None,
)
def generate_mac(
secret: str,
nonce: str,
username: str,
password: str,
admin: bool = False,
user_type: str = None,
) -> str:
mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
mac.update(nonce.encode("utf-8"))
mac.update(b"\x00")
mac.update(username.encode("utf-8"))
mac.update(b"\x00")
mac.update(password.encode("utf-8"))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
if user_type is not None:
mac.update(b"\x00")
mac.update(user_type.encode("utf8"))
return mac.hexdigest()
api = HTTPAPI(base_url, "", loop=get_loop())
return (api, secret, username, password), None
@routes.post("/client/auth/{server}/register")
async def register(request: web.Request) -> web.Response:
req, err = await read_client_auth_request(request)
info, err = await read_client_auth_request(request)
if err is not None:
return err
if req.sso:
return resp.registration_no_sso
elif not req.secret:
return resp.registration_secret_not_found
path = SynapseAdminPath.v1.register
res = await req.client.api.request(Method.GET, path)
content = {
"nonce": res["nonce"],
"username": req.username,
"password": req.password,
"admin": False,
"user_type": req.user_type,
}
content["mac"] = generate_mac(**content, secret=req.secret)
api, secret, username, password = info
res = await api.request(Method.GET, Path.admin.register)
nonce = res["nonce"]
mac = generate_mac(secret, nonce, username, password)
try:
raw_res = await req.client.api.request(Method.POST, path, content=content)
return web.json_response(await api.request(Method.POST, Path.admin.register, content={
"nonce": nonce,
"username": username,
"password": password,
"admin": False,
"mac": mac,
}))
except MatrixRequestError as e:
return web.json_response(
{
return web.json_response({
"errcode": e.errcode,
"error": e.message,
"http_status": e.http_status,
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
login_res = LoginResponse.deserialize(raw_res)
if req.update_client:
return await _create_client(
login_res.user_id,
{
"homeserver": str(req.client.api.base_url),
"access_token": login_res.access_token,
"device_id": login_res.device_id,
},
)
return web.json_response(login_res.serialize())
}, status=e.http_status)
@routes.post("/client/auth/{server}/login")
async def login(request: web.Request) -> web.Response:
req, err = await read_client_auth_request(request)
info, err = await read_client_auth_request(request)
if err is not None:
return err
if req.sso:
return await _do_sso(req)
else:
return await _do_login(req)
async def _do_sso(req: AuthRequestInfo) -> web.Response:
flows = await req.client.get_login_flows()
if not flows.supports_type(LoginType.SSO):
return resp.sso_not_supported
waiter_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=16))
cfg = get_config()
public_url = (
URL(cfg["server.public_url"])
/ "_matrix/maubot/v1/client/auth_external_sso/complete"
/ waiter_id
)
sso_url = req.client.api.base_url.with_path(str(Path.v3.login.sso.redirect)).with_query(
{"redirectUrl": str(public_url)}
)
sso_waiters[waiter_id] = req, asyncio.get_running_loop().create_future()
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
async def _do_login(req: AuthRequestInfo, login_token: str | None = None) -> web.Response:
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
device_id = f"maubot_{device_id}"
api, _, username, password = info
try:
if req.sso:
res = await req.client.login(
token=login_token,
login_type=LoginType.TOKEN,
device_id=device_id,
store_access_token=False,
initial_device_display_name=req.device_name,
)
else:
res = await req.client.login(
identifier=req.username,
login_type=LoginType.PASSWORD,
password=req.password,
device_id=device_id,
initial_device_display_name=req.device_name,
store_access_token=False,
)
return web.json_response(await api.request(Method.POST, Path.login, content={
"type": "m.login.password",
"identifier": {
"type": "m.id.user",
"user": username,
},
"password": password,
"device_id": "maubot",
}))
except MatrixRequestError as e:
return web.json_response(
{
return web.json_response({
"errcode": e.errcode,
"error": e.message,
},
status=e.http_status,
)
if req.update_client:
return await _create_or_update_client(
res.user_id,
{
"homeserver": str(req.client.api.base_url),
"access_token": res.access_token,
"device_id": res.device_id,
},
is_login=True,
)
return web.json_response(res.serialize())
sso_waiters: dict[str, tuple[AuthRequestInfo, asyncio.Future]] = {}
@routes.post("/client/auth/{server}/sso/{id}/wait")
async def wait_sso(request: web.Request) -> web.Response:
waiter_id = request.match_info["id"]
req, fut = sso_waiters[waiter_id]
try:
login_token = await fut
finally:
sso_waiters.pop(waiter_id, None)
return await _do_login(req, login_token)
@routes.get("/client/auth_external_sso/complete/{id}")
async def complete_sso(request: web.Request) -> web.Response:
try:
_, fut = sso_waiters[request.match_info["id"]]
except KeyError:
return web.Response(status=404, text="Invalid session ID\n")
if fut.cancelled():
return web.Response(status=200, text="The login was cancelled from the Maubot client\n")
elif fut.done():
return web.Response(status=200, text="The login token was already received\n")
try:
fut.set_result(request.query["loginToken"])
except KeyError:
return web.Response(status=400, text="Missing loginToken query parameter\n")
except asyncio.InvalidStateError:
return web.Response(status=500, text="Invalid state\n")
return web.Response(
status=200,
text="Login token received, please return to your Maubot client. "
"This tab can be closed.\n",
)
}, status=e.http_status)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import client as http, web
from aiohttp import web, client as http
from ...client import Client
from .base import routes
@ -25,7 +25,7 @@ PROXY_CHUNK_SIZE = 32 * 1024
@routes.view("/proxy/{id}/{path:_matrix/.+}")
async def proxy(request: web.Request) -> web.StreamResponse:
user_id = request.match_info.get("id", None)
client = await Client.get(user_id)
client = Client.get(user_id, None)
if not client:
return resp.client_not_found
@ -36,7 +36,6 @@ async def proxy(request: web.Request) -> web.StreamResponse:
except KeyError:
pass
headers = request.headers.copy()
del headers["Host"]
headers["Authorization"] = f"Bearer {client.access_token}"
if "X-Forwarded-For" not in headers:
peer = request.transport.get_extra_info("peername")
@ -45,9 +44,8 @@ async def proxy(request: web.Request) -> web.StreamResponse:
headers["X-Forwarded-For"] = f"{host}:{port}"
data = await request.read()
async with http.request(
request.method, f"{client.homeserver}/{path}", headers=headers, params=query, data=data
) as proxy_resp:
async with http.request(request.method, f"{client.homeserver}/{path}", headers=headers,
params=query, data=data) as proxy_resp:
response = web.StreamResponse(status=proxy_resp.status, headers=proxy_resp.headers)
await response.prepare(request)
async for chunk in proxy_resp.content.iter_chunked(PROXY_CHUNK_SIZE):

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,11 +14,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from string import Template
import asyncio
from subprocess import run
import re
from aiohttp import web
from ruamel.yaml import YAML
from aiohttp import web
from .base import routes
@ -27,7 +27,9 @@ enabled = False
@routes.get("/debug/open")
async def check_enabled(_: web.Request) -> web.Response:
return web.json_response({"enabled": enabled})
return web.json_response({
"enabled": enabled,
})
try:
@ -38,6 +40,7 @@ try:
editor_command = Template(cfg["editor"])
pathmap = [(re.compile(item["find"]), item["replace"]) for item in cfg["pathmap"]]
@routes.post("/debug/open")
async def open_file(request: web.Request) -> web.Response:
data = await request.json()
@ -48,9 +51,13 @@ try:
cmd = editor_command.substitute(path=path, line=data["line"])
except (KeyError, ValueError):
return web.Response(status=400)
res = await asyncio.create_subprocess_shell(cmd)
stdout, stderr = await res.communicate()
return web.json_response({"return": res.returncode, "stdout": stdout, "stderr": stderr})
res = run(cmd, shell=True)
return web.json_response({
"return": res.returncode,
"stdout": res.stdout,
"stderr": res.stderr
})
enabled = True
except Exception:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -17,9 +17,10 @@ from json import JSONDecodeError
from aiohttp import web
from ...client import Client
from ...db import DBPlugin
from ...instance import PluginInstance
from ...loader import PluginLoader
from ...client import Client
from .base import routes
from .responses import resp
@ -31,50 +32,52 @@ async def get_instances(_: web.Request) -> web.Response:
@routes.get("/instance/{id}")
async def get_instance(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
instance_id = request.match_info.get("id", "").lower()
instance = PluginInstance.get(instance_id, None)
if not instance:
return resp.instance_not_found
return resp.found(instance.to_dict())
async def _create_instance(instance_id: str, data: dict) -> web.Response:
plugin_type = data.get("type")
primary_user = data.get("primary_user")
plugin_type = data.get("type", None)
primary_user = data.get("primary_user", None)
if not plugin_type:
return resp.plugin_type_required
elif not primary_user:
return resp.primary_user_required
elif not await Client.get(primary_user):
elif not Client.get(primary_user):
return resp.primary_user_not_found
try:
PluginLoader.find(plugin_type)
except KeyError:
return resp.plugin_type_not_found
instance = await PluginInstance.get(instance_id, type=plugin_type, primary_user=primary_user)
instance.enabled = data.get("enabled", True)
instance.config_str = data.get("config") or ""
await instance.update()
await instance.load()
db_instance = DBPlugin(id=instance_id, type=plugin_type, enabled=data.get("enabled", True),
primary_user=primary_user, config=data.get("config", ""))
instance = PluginInstance(db_instance)
instance.load()
PluginInstance.db.add(db_instance)
PluginInstance.db.commit()
await instance.start()
return resp.created(instance.to_dict())
async def _update_instance(instance: PluginInstance, data: dict) -> web.Response:
if not await instance.update_primary_user(data.get("primary_user")):
if not await instance.update_primary_user(data.get("primary_user", None)):
return resp.primary_user_not_found
await instance.update_id(data.get("id"))
await instance.update_enabled(data.get("enabled"))
await instance.update_config(data.get("config"))
await instance.update_started(data.get("started"))
await instance.update_type(data.get("type"))
instance.update_id(data.get("id", None))
instance.update_enabled(data.get("enabled", None))
instance.update_config(data.get("config", None))
await instance.update_started(data.get("started", None))
await instance.update_type(data.get("type", None))
instance.db.commit()
return resp.updated(instance.to_dict())
@routes.put("/instance/{id}")
async def update_instance(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
instance_id = request.match_info.get("id", "").lower()
instance = PluginInstance.get(instance_id, None)
try:
data = await request.json()
except JSONDecodeError:
@ -87,11 +90,11 @@ async def update_instance(request: web.Request) -> web.Response:
@routes.delete("/instance/{id}")
async def delete_instance(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
instance_id = request.match_info.get("id", "").lower()
instance = PluginInstance.get(instance_id, None)
if not instance:
return resp.instance_not_found
if instance.started:
await instance.stop()
await instance.delete()
instance.delete()
return resp.deleted

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,67 +13,80 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Union, TYPE_CHECKING
from datetime import datetime
from aiohttp import web
from asyncpg import PostgresError
import aiosqlite
from mautrix.util.async_db import Database
from sqlalchemy import Table, Column, asc, desc, exc
from sqlalchemy.orm import Query
from sqlalchemy.engine.result import ResultProxy, RowProxy
from ...instance import PluginInstance
from ...lib.optionalalchemy import Engine, IntegrityError, OperationalError, asc, desc
from .base import routes
from .responses import resp
@routes.get("/instance/{id}/database")
async def get_database(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
instance_id = request.match_info.get("id", "")
instance = PluginInstance.get(instance_id, None)
if not instance:
return resp.instance_not_found
elif not instance.inst_db:
return resp.plugin_has_no_database
return web.json_response(await instance.get_db_tables())
if TYPE_CHECKING:
table: Table
column: Column
return web.json_response({
table.name: {
"columns": {
column.name: {
"type": str(column.type),
"unique": column.unique or False,
"default": column.default,
"nullable": column.nullable,
"primary": column.primary_key,
"autoincrement": column.autoincrement,
} for column in table.columns
},
} for table in instance.get_db_tables().values()
})
def check_type(val):
if isinstance(val, datetime):
return val.isoformat()
return val
@routes.get("/instance/{id}/database/{table}")
async def get_table(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
instance_id = request.match_info.get("id", "")
instance = PluginInstance.get(instance_id, None)
if not instance:
return resp.instance_not_found
elif not instance.inst_db:
return resp.plugin_has_no_database
tables = await instance.get_db_tables()
tables = instance.get_db_tables()
try:
table = tables[request.match_info.get("table", "")]
except KeyError:
return resp.table_not_found
try:
order = [tuple(order.split(":")) for order in request.query.getall("order")]
order = [
(
(asc if sort.lower() == "asc" else desc)(table.columns[column])
if sort
else table.columns[column]
)
for column, sort in order
]
order = [(asc if sort.lower() == "asc" else desc)(table.columns[column])
if sort else table.columns[column]
for column, sort in order]
except KeyError:
order = []
limit = int(request.query.get("limit", "100"))
if isinstance(instance.inst_db, Engine):
return _execute_query_sqlalchemy(instance, table.select().order_by(*order).limit(limit))
limit = int(request.query.get("limit", 100))
return execute_query(instance, table.select().order_by(*order).limit(limit))
@routes.post("/instance/{id}/database/query")
async def query(request: web.Request) -> web.Response:
instance_id = request.match_info["id"].lower()
instance = await PluginInstance.get(instance_id)
instance_id = request.match_info.get("id", "")
instance = PluginInstance.get(instance_id, None)
if not instance:
return resp.instance_not_found
elif not instance.inst_db:
@ -83,76 +96,28 @@ async def query(request: web.Request) -> web.Response:
sql_query = data["query"]
except KeyError:
return resp.query_missing
rows_as_dict = data.get("rows_as_dict", False)
if isinstance(instance.inst_db, Engine):
return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict)
elif isinstance(instance.inst_db, Database):
return execute_query(instance, sql_query,
rows_as_dict=data.get("rows_as_dict", False))
def execute_query(instance: PluginInstance, sql_query: Union[str, Query],
rows_as_dict: bool = False) -> web.Response:
try:
return await _execute_query_asyncpg(instance, sql_query, rows_as_dict)
except (PostgresError, aiosqlite.Error) as e:
return resp.sql_error(e, sql_query)
else:
return resp.unsupported_plugin_database
def check_type(val):
if isinstance(val, datetime):
return val.isoformat()
return val
async def _execute_query_asyncpg(
instance: PluginInstance, sql_query: str, rows_as_dict: bool = False
) -> web.Response:
data = {"ok": True, "query": sql_query}
if sql_query.upper().startswith("SELECT"):
res = await instance.inst_db.fetch(sql_query)
data["rows"] = [
(
{key: check_type(value) for key, value in row.items()}
if rows_as_dict
else [check_type(value) for value in row]
)
for row in res
]
if len(res) > 0:
# TODO can we find column names when there are no rows?
data["columns"] = list(res[0].keys())
else:
res = await instance.inst_db.execute(sql_query)
if isinstance(res, str):
data["status_msg"] = res
elif isinstance(res, aiosqlite.Cursor):
data["rowcount"] = res.rowcount
# data["inserted_primary_key"] = res.lastrowid
else:
data["status_msg"] = "unknown status"
return web.json_response(data)
def _execute_query_sqlalchemy(
instance: PluginInstance, sql_query: str, rows_as_dict: bool = False
) -> web.Response:
assert isinstance(instance.inst_db, Engine)
try:
res = instance.inst_db.execute(sql_query)
except IntegrityError as e:
res: ResultProxy = instance.inst_db.execute(sql_query)
except exc.IntegrityError as e:
return resp.sql_integrity_error(e, sql_query)
except OperationalError as e:
except exc.OperationalError as e:
return resp.sql_operational_error(e, sql_query)
data = {
"ok": True,
"query": str(sql_query),
}
if res.returns_rows:
data["rows"] = [
(
{key: check_type(value) for key, value in row.items()}
row: RowProxy
data["rows"] = [({key: check_type(value) for key, value in row.items()}
if rows_as_dict
else [check_type(value) for value in row]
)
for row in res
]
else [check_type(value) for value in row])
for row in res]
data["columns"] = res.keys()
else:
data["rowcount"] = res.rowcount

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,63 +13,31 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from collections import deque
from typing import Deque, List
from datetime import datetime
import asyncio
from collections import deque
import logging
import asyncio
from aiohttp import web, web_ws
from mautrix.util import background_task
from aiohttp import web
from .base import routes, get_loop
from .auth import is_valid_token
from .base import routes
BUILTIN_ATTRS = {
"args",
"asctime",
"created",
"exc_info",
"exc_text",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"message",
"msg",
"name",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"thread",
"threadName",
}
INCLUDE_ATTRS = {
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"name",
"pathname",
}
BUILTIN_ATTRS = {"args", "asctime", "created", "exc_info", "exc_text", "filename", "funcName",
"levelname", "levelno", "lineno", "module", "msecs", "message", "msg", "name",
"pathname", "process", "processName", "relativeCreated", "stack_info", "thread",
"threadName"}
INCLUDE_ATTRS = {"filename", "funcName", "levelname", "levelno", "lineno", "module", "name",
"pathname"}
EXCLUDE_ATTRS = BUILTIN_ATTRS - INCLUDE_ATTRS
MAX_LINES = 2048
class LogCollector(logging.Handler):
lines: deque[dict]
lines: Deque[dict]
formatter: logging.Formatter
listeners: list[web.WebSocketResponse]
loop: asyncio.AbstractEventLoop
listeners: List[web.WebSocketResponse]
def __init__(self, level=logging.NOTSET) -> None:
super().__init__(level)
@ -87,7 +55,9 @@ class LogCollector(logging.Handler):
# JSON conversion based on Marsel Mavletkulov's json-log-formatter (MIT license)
# https://github.com/marselester/json-log-formatter
content = {
name: value for name, value in record.__dict__.items() if name not in EXCLUDE_ATTRS
name: value
for name, value in record.__dict__.items()
if name not in EXCLUDE_ATTRS
}
content["id"] = str(record.relativeCreated)
content["msg"] = record.getMessage()
@ -99,7 +69,7 @@ class LogCollector(logging.Handler):
for name, value in content.items():
if isinstance(value, datetime):
content[name] = value.astimezone().isoformat()
asyncio.run_coroutine_threadsafe(self.send(content), loop=self.loop)
asyncio.ensure_future(self.send(content))
self.lines.append(content)
async def send(self, record: dict) -> None:
@ -111,18 +81,17 @@ class LogCollector(logging.Handler):
handler = LogCollector()
log_root = logging.getLogger("maubot")
log = logging.getLogger("maubot.server.websocket")
sockets = []
def init(loop: asyncio.AbstractEventLoop) -> None:
logging.root.addHandler(handler)
handler.loop = loop
def init() -> None:
log_root.addHandler(handler)
async def stop_all() -> None:
log.debug("Closing log listener websockets")
logging.root.removeHandler(handler)
log_root.removeHandler(handler)
for socket in sockets:
try:
await socket.close(code=1012)
@ -139,15 +108,14 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse:
authenticated = False
async def close_if_not_authenticated():
await asyncio.sleep(5)
await asyncio.sleep(5, loop=get_loop())
if not authenticated:
await ws.close(code=4000)
log.debug(f"Connection from {request.remote} terminated due to no authentication")
background_task.create(close_if_not_authenticated())
asyncio.ensure_future(close_if_not_authenticated())
try:
msg: web_ws.WSMessage
async for msg in ws:
if msg.type != web.WSMsgType.TEXT:
continue

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -17,10 +17,9 @@ import json
from aiohttp import web
from .auth import create_token
from .base import get_config, routes
from .base import routes, get_config
from .responses import resp
from .auth import create_token
@routes.post("/auth/login")
async def login(request: web.Request) -> web.Response:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,15 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable
import base64
from typing import Callable, Awaitable
import logging
from aiohttp import web
from .responses import resp
from .auth import check_token
from .base import get_config
from .responses import resp
Handler = Callable[[web.Request], Awaitable[web.Response]]
log = logging.getLogger("maubot.server")
@ -29,13 +28,8 @@ log = logging.getLogger("maubot.server")
@web.middleware
async def auth(request: web.Request, handler: Handler) -> web.Response:
subpath = request.path[len("/_matrix/maubot/v1") :]
if (
subpath.startswith("/auth/")
or subpath.startswith("/client/auth_external_sso/complete/")
or subpath == "/features"
or subpath == "/logs"
):
subpath = request.path[len(get_config()["server.base_path"]):]
if subpath.startswith("/auth/") or subpath == "/features" or subpath == "/logs":
return await handler(request)
err = check_token(request)
if err is not None:
@ -52,18 +46,10 @@ async def error(request: web.Request, handler: Handler) -> web.Response:
return resp.path_not_found
elif ex.status_code == 405:
return resp.method_not_allowed
return web.json_response(
{
"httpexception": {
"headers": {key: value for key, value in ex.headers.items()},
"class": type(ex).__name__,
"body": ex.text or base64.b64encode(ex.body),
},
"error": f"Unhandled HTTP {ex.status}: {ex.text[:128] or 'non-text response'}",
return web.json_response({
"error": f"Unhandled HTTP {ex.status}",
"errcode": f"unhandled_http_{ex.status}",
},
status=ex.status,
)
}, status=ex.status)
except Exception:
log.exception("Error in handler")
return resp.internal_server_error

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -17,9 +17,9 @@ import traceback
from aiohttp import web
from ...loader import MaubotZipImportError, PluginLoader
from .base import routes
from ...loader import PluginLoader, MaubotZipImportError
from .responses import resp
from .base import routes
@routes.get("/plugins")
@ -29,8 +29,8 @@ async def get_plugins(_) -> web.Response:
@routes.get("/plugin/{id}")
async def get_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info["id"]
plugin = PluginLoader.id_cache.get(plugin_id)
plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id, None)
if not plugin:
return resp.plugin_not_found
return resp.found(plugin.to_dict())
@ -38,8 +38,8 @@ async def get_plugin(request: web.Request) -> web.Response:
@routes.delete("/plugin/{id}")
async def delete_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info["id"]
plugin = PluginLoader.id_cache.get(plugin_id)
plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id, None)
if not plugin:
return resp.plugin_not_found
elif len(plugin.references) > 0:
@ -50,8 +50,8 @@ async def delete_plugin(request: web.Request) -> web.Response:
@routes.post("/plugin/{id}/reload")
async def reload_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info["id"]
plugin = PluginLoader.id_cache.get(plugin_id)
plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id, None)
if not plugin:
return resp.plugin_not_found

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -15,39 +15,27 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from io import BytesIO
from time import time
import logging
import traceback
import os.path
import re
import traceback
from aiohttp import web
from packaging.version import Version
from ...loader import DatabaseType, MaubotZipImportError, PluginLoader, ZippedPluginLoader
from .base import get_config, routes
from ...loader import PluginLoader, ZippedPluginLoader, MaubotZipImportError
from .responses import resp
try:
import sqlalchemy
has_alchemy = True
except ImportError:
has_alchemy = False
log = logging.getLogger("maubot.server.upload")
from .base import routes, get_config
@routes.put("/plugin/{id}")
async def put_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info["id"]
plugin_id = request.match_info.get("id", None)
content = await request.read()
file = BytesIO(content)
try:
pid, version, db_type = ZippedPluginLoader.verify_meta(file)
pid, version = ZippedPluginLoader.verify_meta(file)
except MaubotZipImportError as e:
return resp.plugin_import_error(str(e), traceback.format_exc())
if db_type == DatabaseType.SQLALCHEMY and not has_alchemy:
return resp.sqlalchemy_not_installed
if pid != plugin_id:
return resp.pid_mismatch
plugin = PluginLoader.id_cache.get(plugin_id, None)
@ -64,11 +52,9 @@ async def upload_plugin(request: web.Request) -> web.Response:
content = await request.read()
file = BytesIO(content)
try:
pid, version, db_type = ZippedPluginLoader.verify_meta(file)
pid, version = ZippedPluginLoader.verify_meta(file)
except MaubotZipImportError as e:
return resp.plugin_import_error(str(e), traceback.format_exc())
if db_type == DatabaseType.SQLALCHEMY and not has_alchemy:
return resp.sqlalchemy_not_installed
plugin = PluginLoader.id_cache.get(pid, None)
if not plugin:
return await upload_new_plugin(content, pid, version)
@ -92,20 +78,15 @@ async def upload_new_plugin(content: bytes, pid: str, version: Version) -> web.R
return resp.created(plugin.to_dict())
async def upload_replacement_plugin(
plugin: ZippedPluginLoader, content: bytes, new_version: Version
) -> web.Response:
async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
new_version: Version) -> web.Response:
dirname = os.path.dirname(plugin.path)
old_filename = os.path.basename(plugin.path)
if str(plugin.meta.version) in old_filename:
replacement = (
str(new_version)
if plugin.meta.version != new_version
else f"{new_version}-ts{int(time() * 1000)}"
)
filename = re.sub(
f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?", replacement, old_filename
)
replacement = (new_version if plugin.meta.version != new_version
else f"{new_version}-ts{int(time())}")
filename = re.sub(f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?",
replacement, old_filename)
else:
filename = old_filename.rstrip(".mbp")
filename = f"{filename}-v{new_version}.mbp"
@ -117,29 +98,12 @@ async def upload_replacement_plugin(
try:
await plugin.reload(new_path=path)
except MaubotZipImportError as e:
log.exception(f"Error loading updated version of {plugin.meta.id}, rolling back")
try:
await plugin.reload(new_path=old_path)
await plugin.start_instances()
except MaubotZipImportError:
log.warning(f"Failed to roll back update of {plugin.meta.id}", exc_info=True)
finally:
ZippedPluginLoader.trash(path, reason="failed_update")
pass
return resp.plugin_import_error(str(e), traceback.format_exc())
try:
await plugin.start_instances()
except Exception as e:
log.exception(f"Error starting {plugin.meta.id} instances after update, rolling back")
try:
await plugin.stop_instances()
await plugin.reload(new_path=old_path)
await plugin.start_instances()
except Exception:
log.warning(f"Failed to roll back update of {plugin.meta.id}", exc_info=True)
finally:
ZippedPluginLoader.trash(path, reason="failed_update")
return resp.plugin_reload_error(str(e), traceback.format_exc())
log.debug(f"Successfully updated {plugin.meta.id}, moving old version to trash")
ZippedPluginLoader.trash(old_path, reason="update")
return resp.updated(plugin.to_dict())

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,457 +13,270 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING
from http import HTTPStatus
from aiohttp import web
from asyncpg import PostgresError
import aiosqlite
if TYPE_CHECKING:
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.exc import OperationalError, IntegrityError
class _Response:
@property
def body_not_json(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Request body is not JSON",
"errcode": "body_not_json",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@property
def plugin_type_required(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Plugin type is required when creating plugin instances",
"errcode": "plugin_type_required",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@property
def primary_user_required(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Primary user is required when creating plugin instances",
"errcode": "primary_user_required",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@property
def bad_client_access_token(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Invalid access token",
"errcode": "bad_client_access_token",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@property
def bad_client_access_details(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Invalid homeserver or access token",
"errcode": "bad_client_access_details",
},
status=HTTPStatus.BAD_REQUEST,
)
"errcode": "bad_client_access_details"
}, status=HTTPStatus.BAD_REQUEST)
@property
def bad_client_connection_details(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Could not connect to homeserver",
"errcode": "bad_client_connection_details",
},
status=HTTPStatus.BAD_REQUEST,
)
"errcode": "bad_client_connection_details"
}, status=HTTPStatus.BAD_REQUEST)
def mxid_mismatch(self, found: str) -> web.Response:
return web.json_response(
{
"error": (
"The Matrix user ID of the client and the user ID of the access token don't "
f"match. Access token is for user {found}"
),
return web.json_response({
"error": "The Matrix user ID of the client and the user ID of the access token don't "
f"match. Access token is for user {found}",
"errcode": "mxid_mismatch",
},
status=HTTPStatus.BAD_REQUEST,
)
def device_id_mismatch(self, found: str) -> web.Response:
return web.json_response(
{
"error": (
"The Matrix device ID of the client and the device ID of the access token "
f"don't match. Access token is for device {found}"
),
"errcode": "mxid_mismatch",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@property
def pid_mismatch(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "The ID in the path does not match the ID of the uploaded plugin",
"errcode": "pid_mismatch",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@property
def username_or_password_missing(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Username or password missing",
"errcode": "username_or_password_missing",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@property
def query_missing(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Query missing",
"errcode": "query_missing",
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod
def sql_error(error: PostgresError | aiosqlite.Error, query: str) -> web.Response:
return web.json_response(
{
"ok": False,
"query": query,
"error": str(error),
"errcode": "sql_error",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@staticmethod
def sql_operational_error(error: OperationalError, query: str) -> web.Response:
return web.json_response(
{
return web.json_response({
"ok": False,
"query": query,
"error": str(error.orig),
"full_error": str(error),
"errcode": "sql_operational_error",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@staticmethod
def sql_integrity_error(error: IntegrityError, query: str) -> web.Response:
return web.json_response(
{
return web.json_response({
"ok": False,
"query": query,
"error": str(error.orig),
"full_error": str(error),
"errcode": "sql_integrity_error",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@property
def bad_auth(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Invalid username or password",
"errcode": "invalid_auth",
},
status=HTTPStatus.UNAUTHORIZED,
)
}, status=HTTPStatus.UNAUTHORIZED)
@property
def no_token(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Authorization token missing",
"errcode": "auth_token_missing",
},
status=HTTPStatus.UNAUTHORIZED,
)
}, status=HTTPStatus.UNAUTHORIZED)
@property
def invalid_token(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Invalid authorization token",
"errcode": "auth_token_invalid",
},
status=HTTPStatus.UNAUTHORIZED,
)
}, status=HTTPStatus.UNAUTHORIZED)
@property
def plugin_not_found(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Plugin not found",
"errcode": "plugin_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
}, status=HTTPStatus.NOT_FOUND)
@property
def client_not_found(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Client not found",
"errcode": "client_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
}, status=HTTPStatus.NOT_FOUND)
@property
def primary_user_not_found(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Client for given primary user not found",
"errcode": "primary_user_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
}, status=HTTPStatus.NOT_FOUND)
@property
def instance_not_found(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Plugin instance not found",
"errcode": "instance_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
}, status=HTTPStatus.NOT_FOUND)
@property
def plugin_type_not_found(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Given plugin type not found",
"errcode": "plugin_type_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
}, status=HTTPStatus.NOT_FOUND)
@property
def path_not_found(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Resource not found",
"errcode": "resource_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
}, status=HTTPStatus.NOT_FOUND)
@property
def server_not_found(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Registration target server not found",
"errcode": "server_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property
def registration_secret_not_found(self) -> web.Response:
return web.json_response(
{
"error": "Config does not have a registration secret for that server",
"errcode": "registration_secret_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property
def registration_no_sso(self) -> web.Response:
return web.json_response(
{
"error": "The register operation is only for registering with a password",
"errcode": "registration_no_sso",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def sso_not_supported(self) -> web.Response:
return web.json_response(
{
"error": "That server does not seem to support single sign-on",
"errcode": "sso_not_supported",
},
status=HTTPStatus.FORBIDDEN,
)
}, status=HTTPStatus.NOT_FOUND)
@property
def plugin_has_no_database(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Given plugin does not have a database",
"errcode": "plugin_has_no_database",
}
)
@property
def unsupported_plugin_database(self) -> web.Response:
return web.json_response(
{
"error": "The database type is not supported by this API",
"errcode": "unsupported_plugin_database",
}
)
@property
def sqlalchemy_not_installed(self) -> web.Response:
return web.json_response(
{
"error": "This plugin requires a legacy database, but SQLAlchemy is not installed",
"errcode": "unsupported_plugin_database",
},
status=HTTPStatus.NOT_IMPLEMENTED,
)
})
@property
def table_not_found(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Given table not found in plugin database",
"errcode": "table_not_found",
}
)
})
@property
def method_not_allowed(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Method not allowed",
"errcode": "method_not_allowed",
},
status=HTTPStatus.METHOD_NOT_ALLOWED,
)
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
@property
def user_exists(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "There is already a client with the user ID of that token",
"errcode": "user_exists",
},
status=HTTPStatus.CONFLICT,
)
}, status=HTTPStatus.CONFLICT)
@property
def plugin_exists(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "A plugin with the same ID as the uploaded plugin already exists",
"errcode": "plugin_exists",
},
status=HTTPStatus.CONFLICT,
)
"errcode": "plugin_exists"
}, status=HTTPStatus.CONFLICT)
@property
def plugin_in_use(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Plugin instances of this type still exist",
"errcode": "plugin_in_use",
},
status=HTTPStatus.PRECONDITION_FAILED,
)
}, status=HTTPStatus.PRECONDITION_FAILED)
@property
def client_in_use(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Plugin instances with this client as their primary user still exist",
"errcode": "client_in_use",
},
status=HTTPStatus.PRECONDITION_FAILED,
)
}, status=HTTPStatus.PRECONDITION_FAILED)
@staticmethod
def plugin_import_error(error: str, stacktrace: str) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": error,
"stacktrace": stacktrace,
"errcode": "plugin_invalid",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@staticmethod
def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": error,
"stacktrace": stacktrace,
"errcode": "plugin_reload_fail",
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
@property
def internal_server_error(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Internal server error",
"errcode": "internal_server_error",
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
@property
def invalid_server(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Invalid registration server object in maubot configuration",
"errcode": "invalid_server",
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
@property
def unsupported_plugin_loader(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Existing plugin with same ID uses unsupported plugin loader",
"errcode": "unsupported_plugin_loader",
},
status=HTTPStatus.BAD_REQUEST,
)
}, status=HTTPStatus.BAD_REQUEST)
@property
def not_implemented(self) -> web.Response:
return web.json_response(
{
return web.json_response({
"error": "Not implemented",
"errcode": "not_implemented",
},
status=HTTPStatus.NOT_IMPLEMENTED,
)
}, status=HTTPStatus.NOT_IMPLEMENTED)
@property
def ok(self) -> web.Response:
return web.json_response(
{"success": True},
status=HTTPStatus.OK,
)
return web.json_response({
"success": True,
}, status=HTTPStatus.OK)
@property
def deleted(self) -> web.Response:
@ -473,15 +286,19 @@ class _Response:
def found(data: dict) -> web.Response:
return web.json_response(data, status=HTTPStatus.OK)
@staticmethod
def updated(data: dict, is_login: bool = False) -> web.Response:
return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK)
def updated(self, data: dict) -> web.Response:
return self.found(data)
def logged_in(self, token: str) -> web.Response:
return self.found({"token": token})
return self.found({
"token": token,
})
def pong(self, user: str, features: dict) -> web.Response:
return self.found({"username": user, "features": features})
return self.found({
"username": user,
"features": features,
})
@staticmethod
def created(data: dict) -> web.Response:

View File

@ -1,2 +1,93 @@
# Maubot Management API
This document has been moved to docs.mau.fi: <https://docs.mau.fi/maubot/management-api.html>
Most of the API is simple HTTP+JSON and has OpenAPI documentation (see
[spec.yaml](spec.yaml), [rendered](https://maubot.xyz/spec/)). However,
some parts of the API aren't documented in the OpenAPI document.
## Matrix API proxy
The full Matrix API can be accessed for each client with a request to
`/_matrix/maubot/v1/proxy/<user>/<path>`. `<user>` is the Matrix user
ID of the user to access the API as and `<path>` is the whole API
path to access (e.g. `/_matrix/client/r0/whoami`).
The body, headers, query parameters, etc are sent to the Matrix server
as-is, with a few exceptions:
* The `Authorization` header will be replaced with the access token
for the Matrix user from the maubot database.
* The `access_token` query parameter will be removed.
## Log viewing
1. Open websocket to `/_matrix/maubot/v1/logs`.
2. Send authentication token as a plain string.
3. Server will respond with `{"auth_success": true}` and then with
`{"history": ...}` where `...` is a list of log entries.
4. Server will send new log entries as JSON.
### Log entry object format
Log entries are a JSON-serialized form of Python log records.
Log entries will always have:
* `id` - A string that should uniquely identify the row. Currently
uses the `relativeCreated` field of Python logging records.
* `msg` - The text in the entry.
* `time` - The ISO date when the log entry was created.
Log entries should also always have:
* `levelname` - The log level (e.g. `DEBUG` or `ERROR`).
* `levelno` - The integer log level.
* `name` - The name of the logger. Common values:
* `maubot.client.<mxid>` - Client loggers (Matrix HTTP requests)
* `maubot.instance.<id>` - Plugin instance loggers
* `maubot.loader.zip` - The zip plugin loader (plugins don't
have their own logs)
* `module` - The Python module name where the log call happened.
* `pathname` - The full path of the file where the log call happened.
* `filename` - The file name portion of `pathname`
* `lineno` - The line in code where the log call happened.
* `funcName` - The name of the function where the log call happened.
Log entries might have:
* `exc_info` - The formatted exception info if an exception was logged.
* `matrix_http_request` - The info about a Matrix HTTP request. Subfields:
* `method` - The HTTP method used.
* `path` - The API path used.
* `content` - The content sent.
* `user` - The appservice user who the request was ran as.
## Debug file open
For debug and development purposes, the API and frontend support
clicking on lines in stack traces to open that line in the selected
editor.
### Configuration
First, the directory where maubot is run from must have a
`.dev-open-cfg.yaml` file. The file should contain the following
fields:
* `editor` - The command to run to open a file.
* `$path` is replaced with the full file path.
* `$line` is replaced with the line number.
* `pathmap` - A list of find-and-replaces to execute on paths.
These are needed to map things like `.mbp` files to the extracted
sources on disk. Each pathmap entry should have:
* `find` - The regex to match.
* `replace` - The replacement. May insert capture groups with Python
syntax (e.g. `\1`)
Example file:
```yaml
editor: pycharm --line $line $path
pathmap:
- find: "maubot/plugins/xyz\\.maubot\\.(.+)-v.+(?:-ts[0-9]+)?.mbp"
replace: "mbplugins/\\1"
- find: "maubot/.venv/lib/python3.6/site-packages/mautrix"
replace: "mautrix-python/mautrix"
```
### API
Clients can `GET /_matrix/maubot/v1/debug/open` to check if the file
open endpoint has been set up. The response is a JSON object with a
single field `enabled`. If the value is true, the endpoint can be used.
To open files, clients can `POST /_matrix/maubot/v1/debug/open` with
a JSON body containing
* `path` - The full file path to open
* `line` - The line number to open

View File

@ -366,7 +366,7 @@ paths:
schema:
$ref: '#/components/schemas/MatrixClient'
responses:
202:
200:
description: Client updated
content:
application/json:
@ -399,169 +399,6 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/Error'
'/client/{id}/clearcache':
parameters:
- name: id
in: path
description: The Matrix user ID of the client to change
required: true
schema:
type: string
put:
operationId: clear_client_cache
summary: Clear the sync/state cache of a Matrix client
tags: [Clients]
responses:
200:
description: Cache cleared
content:
application/json:
schema:
type: object
properties:
success:
type: boolean
401:
$ref: '#/components/responses/Unauthorized'
404:
$ref: '#/components/responses/ClientNotFound'
/client/auth/servers:
get:
operationId: get_client_auth_servers
summary: Get the list of servers you can register or log in on via the maubot server
tags: [Clients]
responses:
200:
description: OK
content:
application/json:
schema:
type: object
description: Key-value map from server name to homeserver URL
additionalProperties:
type: string
description: The homeserver URL
example:
maunium.net: https://maunium.net
example.com: https://matrix.example.org
401:
$ref: '#/components/responses/Unauthorized'
'/client/auth/{server}/register':
parameters:
- name: server
in: path
description: The server name to register the account on.
required: true
schema:
type: string
- name: update_client
in: query
description: Should maubot store the access details in a Client instead of returning them?
required: false
schema:
type: boolean
post:
operationId: client_auth_register
summary: |
Register a new account on the given Matrix server using the shared registration
secret configured into the maubot server.
tags: [Clients]
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixAuthentication'
responses:
200:
description: Registration successful
content:
application/json:
schema:
type: object
properties:
access_token:
type: string
example: syt_123_456_789
user_id:
type: string
example: '@putkiteippi:maunium.net'
device_id:
type: string
example: maubot_F00BAR12
201:
description: Client created (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
401:
$ref: '#/components/responses/Unauthorized'
409:
description: |
There is already a client with the user ID of that token.
This should usually not happen, because the user ID was just created.
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
500:
$ref: '#/components/responses/MatrixServerError'
'/client/auth/{server}/login':
parameters:
- name: server
in: path
description: The server name to log in to.
required: true
schema:
type: string
- name: update_client
in: query
description: Should maubot store the access details in a Client instead of returning them?
required: false
schema:
type: boolean
post:
operationId: client_auth_login
summary: Log in to the given Matrix server via the maubot server
tags: [Clients]
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixAuthentication'
responses:
200:
description: Login successful
content:
application/json:
schema:
type: object
properties:
user_id:
type: string
example: '@putkiteippi:maunium.net'
access_token:
type: string
example: syt_123_456_789
device_id:
type: string
example: maubot_F00BAR12
201:
description: Client created (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
202:
description: Client updated (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
401:
$ref: '#/components/responses/Unauthorized'
500:
$ref: '#/components/responses/MatrixServerError'
components:
responses:
@ -595,23 +432,6 @@ components:
application/json:
schema:
$ref: '#/components/schemas/Error'
MatrixServerError:
description: The Matrix server returned an error
content:
application/json:
schema:
type: object
properties:
errcode:
type: string
description: The `errcode` returned by the server.
error:
type: string
description: The human-readable error returned by the server.
http_status:
type: integer
description: The HTTP status returned by the server.
securitySchemes:
bearer:
type: http
@ -668,61 +488,31 @@ components:
type: string
example: '@putkiteippi:maunium.net'
readOnly: true
description: The Matrix user ID of this client.
homeserver:
type: string
example: 'https://maunium.net'
description: The homeserver URL for this client.
access_token:
type: string
description: The Matrix access token for this client.
device_id:
type: string
description: The Matrix device ID corresponding to the access token.
fingerprint:
type: string
description: The encryption device fingerprint for verification.
enabled:
type: boolean
example: true
description: Whether or not this client is enabled.
started:
type: boolean
example: true
description: Whether or not this client and its instances have been started.
sync:
type: boolean
example: true
description: Whether or not syncing is enabled on this client.
sync_ok:
type: boolean
example: true
description: Whether or not the previous sync was successful on this client.
autojoin:
type: boolean
example: true
description: Whether or not this client should automatically join rooms when invited.
displayname:
type: string
example: J. E. Saarinen
description: The display name for this client.
avatar_url:
type: string
example: 'mxc://maunium.net/FsPQQTntCCqhJMFtwArmJdaU'
description: The content URI of the avatar for this client.
instances:
type: array
readOnly: true
items:
$ref: '#/components/schemas/PluginInstance'
MatrixAuthentication:
type: object
properties:
username:
type: string
example: putkiteippi
description: The user ID localpart to register/log in as.
password:
type: string
example: p455w0rd
description: The password for/of the user.

View File

@ -1,27 +1,17 @@
{
"name": "maubot-manager",
"version": "0.1.1",
"version": "0.1.0",
"private": true,
"author": "Tulir Asokan",
"license": "AGPL-3.0-or-later",
"repository": {
"type": "git",
"url": "git+https://github.com/maubot/maubot.git"
},
"bugs": {
"url": "https://github.com/maubot/maubot/issues"
},
"homepage": ".",
"dependencies": {
"react": "^17.0.2",
"react-ace": "^9.4.1",
"react-contextmenu": "^2.14.0",
"react-dom": "^17.0.2",
"react-json-tree": "^0.16.1",
"react-router-dom": "^5.3.0",
"react-scripts": "5.0.0",
"react-select": "^5.2.1",
"sass": "^1.34.1"
"node-sass": "^4.9.4",
"react": "^16.6.0",
"react-ace": "^6.2.0",
"react-contextmenu": "^2.10.0",
"react-dom": "^16.6.0",
"react-json-tree": "^0.11.0",
"react-router-dom": "^4.3.1",
"react-scripts": "2.0.5",
"react-select": "^2.1.1"
},
"scripts": {
"start": "react-scripts start",
@ -30,11 +20,16 @@
"eject": "react-scripts eject"
},
"browserslist": [
"last 2 firefox versions",
"last 2 and_ff versions",
"last 2 chrome versions",
"last 2 and_chr versions",
"last 1 safari versions",
"last 1 ios_saf versions"
]
"last 5 firefox versions",
"last 3 and_ff versions",
"last 5 chrome versions",
"last 3 and_chr versions",
"last 2 safari versions",
"last 2 ios_saf versions"
],
"homepage": ".",
"devDependencies": {
"sass-lint": "^1.12.1",
"sass-lint-auto-fix": "^0.15.0"
}
}

View File

@ -1,6 +1,6 @@
<!--
maubot - A plugin-based Matrix bot system.
Copyright (C) 2022 Tulir Asokan
Copyright (C) 2018 Tulir Asokan
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
@ -20,11 +20,12 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
<head>
<meta charset="utf-8">
<link rel="shortcut icon" href="%PUBLIC_URL%/favicon.png">
<link rel="stylesheet" type="text/css"
href="https://fonts.googleapis.com/css?family=Raleway:300,400,700">
<link rel="stylesheet" type="text/css"
href="https://cdn.jsdelivr.net/gh/tonsky/FiraCode@1.206/distr/fira_code.css">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<meta name="theme-color" content="#50D367">
<meta property="og:title" content="Maubot Manager"/>
<meta property="og:description" content="Maubot management interface"/>
<meta property="og:image" content="%PUBLIC_URL%/favicon.png"/>
<link rel="manifest" href="%PUBLIC_URL%/manifest.json">
<title>Maubot Manager</title>
</head>

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2018 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -14,11 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
let BASE_PATH = "/_matrix/maubot/v1"
export function setBasePath(basePath) {
BASE_PATH = basePath
}
export const BASE_PATH = "/_matrix/maubot/v1"
function getHeaders(contentType = "application/json") {
return {
@ -40,8 +36,8 @@ async function defaultDelete(type, id) {
return await resp.json()
}
async function defaultPut(type, entry, id = undefined, suffix = undefined) {
const resp = await fetch(`${BASE_PATH}/${type}/${id || entry.id}${suffix || ""}`, {
async function defaultPut(type, entry, id = undefined) {
const resp = await fetch(`${BASE_PATH}/${type}/${id || entry.id}`, {
headers: getHeaders(),
body: JSON.stringify(entry),
method: "PUT",
@ -214,10 +210,10 @@ export async function uploadAvatar(id, data, mime) {
}
export function getAvatarURL({ id, avatar_url }) {
if (!avatar_url?.startsWith("mxc://")) {
return null
}
avatar_url = avatar_url || ""
if (avatar_url.startsWith("mxc://")) {
avatar_url = avatar_url.substr("mxc://".length)
}
return `${BASE_PATH}/proxy/${id}/_matrix/media/r0/download/${avatar_url}?access_token=${
localStorage.accessToken}`
}
@ -225,33 +221,13 @@ export function getAvatarURL({ id, avatar_url }) {
export const putClient = client => defaultPut("client", client)
export const deleteClient = id => defaultDelete("client", id)
export async function clearClientCache(id) {
const resp = await fetch(`${BASE_PATH}/client/${id}/clearcache`, {
headers: getHeaders(),
method: "POST",
})
return await resp.json()
}
export const getClientAuthServers = () => defaultGet("/client/auth/servers")
export async function doClientAuth(server, type, username, password) {
const resp = await fetch(`${BASE_PATH}/client/auth/${server}/${type}`, {
headers: getHeaders(),
body: JSON.stringify({ username, password }),
method: "POST",
})
return await resp.json()
}
// eslint-disable-next-line import/no-anonymous-default-export
export default {
login, ping, setBasePath, getFeatures, remoteGetFeatures,
BASE_PATH,
login, ping, getFeatures, remoteGetFeatures,
openLogSocket,
debugOpenFile, debugOpenFileEnabled, updateDebugOpenFileEnabled,
getInstances, getInstance, putInstance, deleteInstance,
getInstanceDatabase, queryInstanceDatabase,
getPlugins, getPlugin, uploadPlugin, deletePlugin,
getClients, getClient, uploadAvatar, getAvatarURL, putClient, deleteClient, clearClientCache,
getClientAuthServers, doClientAuth,
getClients, getClient, uploadAvatar, getAvatarURL, putClient, deleteClient,
}

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2018 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -15,7 +15,6 @@
// along with this program. If not, see <https://www.gnu.org/licenses/>.
import React from "react"
import Select from "react-select"
import CreatableSelect from "react-select/creatable"
import Switch from "./Switch"
export const PrefTable = ({ children, wrapperClass }) => {
@ -57,14 +56,10 @@ export const PrefSwitch = ({ rowName, active, origActive, fullWidth = false, ...
</PrefRow>
)
export const PrefSelect = ({
rowName, value, origValue, fullWidth = false, creatable = false, ...args
}) => (
export const PrefSelect = ({ rowName, value, origValue, fullWidth = false, ...args }) => (
<PrefRow name={rowName} fullWidth={fullWidth} labelFor={rowName}
changed={origValue !== undefined && value.id !== origValue}>
{creatable
? <CreatableSelect className="select" {...args} id={rowName} value={value}/>
: <Select className="select" {...args} id={rowName} value={value}/>}
<Select className="select" {...args} id={rowName} value={value}/>
</PrefRow>
)

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2018 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2018 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -23,13 +23,11 @@ class Switch extends Component {
}
}
componentDidUpdate(prevProps) {
if (prevProps.active !== this.props.active) {
componentWillReceiveProps(nextProps) {
this.setState({
active: this.props.active,
active: nextProps.active,
})
}
}
toggle = () => {
if (this.props.onToggle) {

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2018 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2018 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -30,8 +30,7 @@ class Login extends Component {
inputChanged = event => this.setState({ [event.target.name]: event.target.value })
login = async evt => {
evt.preventDefault()
login = async () => {
this.setState({ loading: true })
const resp = await api.login(this.state.username, this.state.password)
if (resp.token) {
@ -54,17 +53,17 @@ class Login extends Component {
</div>
}
return <div className="login-wrapper">
<form className={`login ${this.state.error && "errored"}`} onSubmit={this.login}>
<div className={`login ${this.state.error && "errored"}`}>
<h1>Maubot Manager</h1>
<input type="text" placeholder="Username" value={this.state.username}
name="username" onChange={this.inputChanged}/>
<input type="password" placeholder="Password" value={this.state.password}
name="password" onChange={this.inputChanged}/>
<button onClick={this.login} type="submit">
<button onClick={this.login}>
{this.state.loading ? <Spinner/> : "Log in"}
</button>
{this.state.error && <div className="error">{this.state.error}</div>}
</form>
</div>
</div>
}
}

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2018 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -30,8 +30,7 @@ class Main extends Component {
}
}
async componentDidMount() {
await this.getBasePath()
async componentWillMount() {
if (localStorage.accessToken) {
await this.ping()
} else {
@ -40,19 +39,6 @@ class Main extends Component {
this.setState({ pinged: true })
}
async getBasePath() {
try {
const resp = await fetch(process.env.PUBLIC_URL + "/paths.json", {
headers: { "Content-Type": "application/json" },
})
const apiPaths = await resp.json()
api.setBasePath(apiPaths.api_path)
} catch (err) {
console.error("Failed to get API path:", err)
}
}
async ping() {
try {
const username = await api.ping()

View File

@ -1,18 +1,3 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
import React, { Component } from "react"
import { Link } from "react-router-dom"
import api from "../../api"
@ -23,7 +8,7 @@ class BaseMainView extends Component {
this.state = Object.assign(this.initialState, props.entry)
}
UNSAFE_componentWillReceiveProps(nextProps) {
componentWillReceiveProps(nextProps) {
const newState = Object.assign(this.initialState, nextProps.entry)
for (const key of this.entryKeys) {
if (this.props.entry[key] === nextProps.entry[key]) {

Some files were not shown because too many files have changed in this diff Show More