Blacken and isort code

This commit is contained in:
Tulir Asokan 2022-03-25 14:22:37 +02:00
parent 6257979e7c
commit 068e268c63
97 changed files with 1781 additions and 1086 deletions

26
.github/workflows/python-lint.yml vendored Normal file
View File

@ -0,0 +1,26 @@
name: Python lint
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: "3.10"
- uses: isort/isort-action@master
with:
sortPaths: "./maubot"
- uses: psf/black@stable
with:
src: "./maubot"
version: "22.1.0"
- name: pre-commit
run: |
pip install pre-commit
pre-commit run -av trailing-whitespace
pre-commit run -av end-of-file-fixer
pre-commit run -av check-yaml
pre-commit run -av check-added-large-files

View File

@ -1,4 +1,11 @@
# maubot # maubot
![Languages](https://img.shields.io/github/languages/top/maubot/maubot.svg)
[![License](https://img.shields.io/github/license/maubot/maubot.svg)](LICENSE)
[![Release](https://img.shields.io/github/release/maubot/maubot/all.svg)](https://github.com/maubot/maubot/releases)
[![GitLab CI](https://mau.dev/maubot/maubot/badges/master/pipeline.svg)](https://mau.dev/maubot/maubot/container_registry)
[![Code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Imports](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
A plugin-based [Matrix](https://matrix.org) bot system written in Python. A plugin-based [Matrix](https://matrix.org) bot system written in Python.
## Documentation ## Documentation

3
dev-requirements.txt Normal file
View File

@ -0,0 +1,3 @@
pre-commit>=2.10.1,<3
isort>=5.10.1,<6
black==22.1.0

View File

@ -1,4 +1,4 @@
from .__meta__ import __version__
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent
from .plugin_base import Plugin from .plugin_base import Plugin
from .plugin_server import PluginWebApp from .plugin_server import PluginWebApp
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent
from .__meta__ import __version__

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -17,15 +17,15 @@ import asyncio
from mautrix.util.program import Program from mautrix.util.program import Program
from .__meta__ import __version__
from .client import Client, init as init_client_class
from .config import Config from .config import Config
from .db import init as init_db from .db import init as init_db
from .server import MaubotServer
from .client import Client, init as init_client_class
from .loader.zip import init as init_zip_loader
from .instance import init as init_plugin_instance_class from .instance import init as init_plugin_instance_class
from .management.api import init as init_mgmt_api
from .lib.future_awaitable import FutureAwaitable from .lib.future_awaitable import FutureAwaitable
from .__meta__ import __version__ from .loader.zip import init as init_zip_loader
from .management.api import init as init_mgmt_api
from .server import MaubotServer
class Maubot(Program): class Maubot(Program):
@ -41,6 +41,7 @@ class Maubot(Program):
def prepare_log_websocket(self) -> None: def prepare_log_websocket(self) -> None:
from .management.api.log import init, stop_all from .management.api.log import init, stop_all
init(self.loop) init(self.loop)
self.add_shutdown_actions(FutureAwaitable(stop_all)) self.add_shutdown_actions(FutureAwaitable(stop_all))

View File

@ -1,2 +1,3 @@
from . import app from . import app
app() app()

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by

View File

@ -1,2 +1,2 @@
from .cliq import command, option from .cliq import command, option
from .validators import SPDXValidator, VersionValidator, PathValidator from .validators import PathValidator, SPDXValidator, VersionValidator

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,22 +13,23 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, Callable, Union, Optional, Type from __future__ import annotations
import functools
import traceback from typing import Any, Callable
import inspect
import asyncio import asyncio
import functools
import inspect
import traceback
import aiohttp from colorama import Fore
from prompt_toolkit.validation import Validator from prompt_toolkit.validation import Validator
from questionary import prompt from questionary import prompt
from colorama import Fore import aiohttp
import click import click
from ..base import app from ..base import app
from ..config import get_token from ..config import get_token
from .validators import Required, ClickValidator from .validators import ClickValidator, Required
def with_http(func): def with_http(func):
@ -105,7 +106,7 @@ def command(help: str) -> Callable[[Callable], Callable]:
return decorator return decorator
def yesno(val: str) -> Optional[bool]: def yesno(val: str) -> bool | None:
if not val: if not val:
return None return None
elif isinstance(val, bool): elif isinstance(val, bool):
@ -119,11 +120,20 @@ def yesno(val: str) -> Optional[bool]:
yesno.__name__ = "yes/no" yesno.__name__ = "yes/no"
def option(short: str, long: str, message: str = None, help: str = None, def option(
click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None, short: str,
validator: Type[Validator] = None, required: bool = False, long: str,
default: Union[str, bool, None] = None, is_flag: bool = False, prompt: bool = True, message: str = None,
required_unless: Union[str, list, dict] = None) -> Callable[[Callable], Callable]: help: str = None,
click_type: str | Callable[[str], Any] = None,
inq_type: str = None,
validator: type[Validator] = None,
required: bool = False,
default: str | bool | None = None,
is_flag: bool = False,
prompt: bool = True,
required_unless: str | list | dict = None,
) -> Callable[[Callable], Callable]:
if not message: if not message:
message = long[2].upper() + long[3:] message = long[2].upper() + long[3:]
@ -139,9 +149,9 @@ def option(short: str, long: str, message: str = None, help: str = None,
if not hasattr(func, "__inquirer_questions__"): if not hasattr(func, "__inquirer_questions__"):
func.__inquirer_questions__ = {} func.__inquirer_questions__ = {}
q = { q = {
"type": (inq_type if isinstance(inq_type, str) "type": (
else ("input" if not is_flag inq_type if isinstance(inq_type, str) else ("input" if not is_flag else "confirm")
else "confirm")), ),
"name": long[2:], "name": long[2:],
"message": message, "message": message,
} }

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -16,9 +16,9 @@
from typing import Callable from typing import Callable
import os import os
from packaging.version import Version, InvalidVersion from packaging.version import InvalidVersion, Version
from prompt_toolkit.validation import Validator, ValidationError
from prompt_toolkit.document import Document from prompt_toolkit.document import Document
from prompt_toolkit.validation import ValidationError, Validator
import click import click
from ..util import spdx as spdxlib from ..util import spdx as spdxlib

View File

@ -1 +1 @@
from . import upload, build, login, init, logs, auth from . import auth, build, init, login, logs, upload

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,8 +13,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
import webbrowser
import json import json
import webbrowser
from colorama import Fore from colorama import Fore
from yarl import URL from yarl import URL
@ -26,12 +26,16 @@ from ..cliq import cliq
history_count: int = 10 history_count: int = 10
friendly_errors = { friendly_errors = {
"server_not_found": "Registration target server not found.\n\n" "server_not_found": (
"To log in or register through maubot, you must add the server to the\n" "Registration target server not found.\n\n"
"homeservers section in the config. If you only want to log in,\n" "To log in or register through maubot, you must add the server to the\n"
"leave the `secret` field empty.", "homeservers section in the config. If you only want to log in,\n"
"registration_no_sso": "The register operation is only for registering with a password.\n\n" "leave the `secret` field empty."
"To register with SSO, simply leave out the --register flag.", ),
"registration_no_sso": (
"The register operation is only for registering with a password.\n\n"
"To register with SSO, simply leave out the --register flag."
),
} }
@ -46,26 +50,58 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
@cliq.command(help="Log into a Matrix account via the Maubot server") @cliq.command(help="Log into a Matrix account via the Maubot server")
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list") @cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
@cliq.option("-u", "--username", help="The username to log in with", @cliq.option(
required_unless=["list", "sso"]) "-u", "--username", help="The username to log in with", required_unless=["list", "sso"]
@cliq.option("-p", "--password", help="The password to log in with", inq_type="password", )
required_unless=["list", "sso"]) @cliq.option(
@cliq.option("-s", "--server", help="The maubot instance to log in through", default="", "-p",
required=False, prompt=False) "--password",
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True, help="The password to log in with",
default=False) inq_type="password",
@click.option("-c", "--update-client", help="Instead of returning the access token, " required_unless=["list", "sso"],
"create or update a client in maubot using it", )
is_flag=True, default=False) @cliq.option(
"-s",
"--server",
help="The maubot instance to log in through",
default="",
required=False,
prompt=False,
)
@click.option(
"-r", "--register", help="Register instead of logging in", is_flag=True, default=False
)
@click.option(
"-c",
"--update-client",
help="Instead of returning the access token, " "create or update a client in maubot using it",
is_flag=True,
default=False,
)
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False) @click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
@click.option("-o", "--sso", help="Use single sign-on instead of password login", @click.option(
is_flag=True, default=False) "-o", "--sso", help="Use single sign-on instead of password login", is_flag=True, default=False
@click.option("-n", "--device-name", help="The initial e2ee device displayname (only for login)", )
default="Maubot", required=False) @click.option(
"-n",
"--device-name",
help="The initial e2ee device displayname (only for login)",
default="Maubot",
required=False,
)
@cliq.with_authenticated_http @cliq.with_authenticated_http
async def auth(homeserver: str, username: str, password: str, server: str, register: bool, async def auth(
list: bool, update_client: bool, device_name: str, sso: bool, homeserver: str,
sess: aiohttp.ClientSession) -> None: username: str,
password: str,
server: str,
register: bool,
list: bool,
update_client: bool,
device_name: str,
sso: bool,
sess: aiohttp.ClientSession,
) -> None:
if list: if list:
await list_servers(server, sess) await list_servers(server, sess)
return return
@ -88,8 +124,9 @@ async def auth(homeserver: str, username: str, password: str, server: str, regis
await print_response(resp, is_register=register) await print_response(resp, is_register=register)
async def wait_sso(resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession, async def wait_sso(
server: str, homeserver: str) -> None: resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession, server: str, homeserver: str
) -> None:
data = await resp.json() data = await resp.json()
sso_url, reg_id = data["sso_url"], data["id"] sso_url, reg_id = data["sso_url"], data["id"]
print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}") print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}")
@ -110,9 +147,11 @@ async def print_response(resp: aiohttp.ClientResponse, is_register: bool) -> Non
elif resp.status in (201, 202): elif resp.status in (201, 202):
data = await resp.json() data = await resp.json()
action = "created" if resp.status == 201 else "updated" action = "created" if resp.status == 201 else "updated"
print(f"{Fore.GREEN}Successfully {action} client for " print(
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / " f"{Fore.GREEN}Successfully {action} client for "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}") f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}"
)
else: else:
await print_error(resp, is_register) await print_error(resp, is_register)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,26 +13,28 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Union, IO from __future__ import annotations
from typing import IO
from io import BytesIO from io import BytesIO
import zipfile
import asyncio import asyncio
import glob import glob
import os import os
import zipfile
from ruamel.yaml import YAML, YAMLError
from aiohttp import ClientSession from aiohttp import ClientSession
from questionary import prompt
from colorama import Fore from colorama import Fore
from questionary import prompt
from ruamel.yaml import YAML, YAMLError
import click import click
from mautrix.types import SerializerError from mautrix.types import SerializerError
from ...loader import PluginMeta from ...loader import PluginMeta
from ..cliq.validators import PathValidator
from ..base import app from ..base import app
from ..config import get_token
from ..cliq import cliq from ..cliq import cliq
from ..cliq.validators import PathValidator
from ..config import get_token
from .upload import upload_file from .upload import upload_file
yaml = YAML() yaml = YAML()
@ -44,7 +46,7 @@ def zipdir(zip, dir):
zip.write(os.path.join(root, file)) zip.write(os.path.join(root, file))
def read_meta(path: str) -> Optional[PluginMeta]: def read_meta(path: str) -> PluginMeta | None:
try: try:
with open(os.path.join(path, "maubot.yaml")) as meta_file: with open(os.path.join(path, "maubot.yaml")) as meta_file:
try: try:
@ -65,7 +67,7 @@ def read_meta(path: str) -> Optional[PluginMeta]:
return meta return meta
def read_output_path(output: str, meta: PluginMeta) -> Optional[str]: def read_output_path(output: str, meta: PluginMeta) -> str | None:
directory = os.getcwd() directory = os.getcwd()
filename = f"{meta.id}-v{meta.version}.mbp" filename = f"{meta.id}-v{meta.version}.mbp"
if not output: if not output:
@ -73,18 +75,15 @@ def read_output_path(output: str, meta: PluginMeta) -> Optional[str]:
elif os.path.isdir(output): elif os.path.isdir(output):
output = os.path.join(output, filename) output = os.path.join(output, filename)
elif os.path.exists(output): elif os.path.exists(output):
override = prompt({ q = [{"type": "confirm", "name": "override", "message": f"{output} exists, override?"}]
"type": "confirm", override = prompt(q)["override"]
"name": "override",
"message": f"{output} exists, override?"
})["override"]
if not override: if not override:
return None return None
os.remove(output) os.remove(output)
return os.path.abspath(output) return os.path.abspath(output)
def write_plugin(meta: PluginMeta, output: Union[str, IO]) -> None: def write_plugin(meta: PluginMeta, output: str | IO) -> None:
with zipfile.ZipFile(output, "w") as zip: with zipfile.ZipFile(output, "w") as zip:
meta_dump = BytesIO() meta_dump = BytesIO()
yaml.dump(meta.serialize(), meta_dump) yaml.dump(meta.serialize(), meta_dump)
@ -104,7 +103,7 @@ def write_plugin(meta: PluginMeta, output: Union[str, IO]) -> None:
@cliq.with_authenticated_http @cliq.with_authenticated_http
async def upload_plugin(output: Union[str, IO], *, server: str, sess: ClientSession) -> None: async def upload_plugin(output: str | IO, *, server: str, sess: ClientSession) -> None:
server, token = get_token(server) server, token = get_token(server)
if not token: if not token:
return return
@ -115,14 +114,20 @@ async def upload_plugin(output: Union[str, IO], *, server: str, sess: ClientSess
await upload_file(sess, output, server) await upload_file(sess, output, server)
@app.command(short_help="Build a maubot plugin", @app.command(
help="Build a maubot plugin. First parameter is the path to root of the plugin " short_help="Build a maubot plugin",
"to build. You can also use --output to specify output file.") help=(
"Build a maubot plugin. First parameter is the path to root of the plugin "
"to build. You can also use --output to specify output file."
),
)
@click.argument("path", default=os.getcwd()) @click.argument("path", default=os.getcwd())
@click.option("-o", "--output", help="Path to output built plugin to", @click.option(
type=PathValidator.click_type) "-o", "--output", help="Path to output built plugin to", type=PathValidator.click_type
@click.option("-u", "--upload", help="Upload plugin to server after building", is_flag=True, )
default=False) @click.option(
"-u", "--upload", help="Upload plugin to server after building", is_flag=True, default=False
)
@click.option("-s", "--server", help="Server to upload built plugin to") @click.option("-s", "--server", help="Server to upload built plugin to")
def build(path: str, output: str, upload: bool, server: str) -> None: def build(path: str, output: str, upload: bool, server: str) -> None:
meta = read_meta(path) meta = read_meta(path)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,11 +13,11 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from pkg_resources import resource_string
import os import os
from packaging.version import Version
from jinja2 import Template from jinja2 import Template
from packaging.version import Version
from pkg_resources import resource_string
from .. import cliq from .. import cliq
from ..cliq import SPDXValidator, VersionValidator from ..cliq import SPDXValidator, VersionValidator
@ -40,26 +40,55 @@ def load_templates():
@cliq.command(help="Initialize a new maubot plugin") @cliq.command(help="Initialize a new maubot plugin")
@cliq.option("-n", "--name", help="The name of the project", required=True, @cliq.option(
default=os.path.basename(os.getcwd())) "-n",
@cliq.option("-i", "--id", message="ID", required=True, "--name",
help="The maubot plugin ID (Java package name format)") help="The name of the project",
@cliq.option("-v", "--version", help="Initial version for project (PEP-440 format)", required=True,
default="0.1.0", validator=VersionValidator, required=True) default=os.path.basename(os.getcwd()),
@cliq.option("-l", "--license", validator=SPDXValidator, default="AGPL-3.0-or-later", )
help="The license for the project (SPDX identifier)", required=False) @cliq.option(
@cliq.option("-c", "--config", message="Should the plugin include a config?", "-i",
help="Include a config in the plugin stub", default=False, is_flag=True) "--id",
message="ID",
required=True,
help="The maubot plugin ID (Java package name format)",
)
@cliq.option(
"-v",
"--version",
help="Initial version for project (PEP-440 format)",
default="0.1.0",
validator=VersionValidator,
required=True,
)
@cliq.option(
"-l",
"--license",
validator=SPDXValidator,
default="AGPL-3.0-or-later",
help="The license for the project (SPDX identifier)",
required=False,
)
@cliq.option(
"-c",
"--config",
message="Should the plugin include a config?",
help="Include a config in the plugin stub",
default=False,
is_flag=True,
)
def init(name: str, id: str, version: Version, license: str, config: bool) -> None: def init(name: str, id: str, version: Version, license: str, config: bool) -> None:
load_templates() load_templates()
main_class = name[0].upper() + name[1:] main_class = name[0].upper() + name[1:]
meta = meta_template.render(id=id, version=str(version), license=license, config=config, meta = meta_template.render(
main_class=main_class) id=id, version=str(version), license=license, config=config, main_class=main_class
)
with open("maubot.yaml", "w") as file: with open("maubot.yaml", "w") as file:
file.write(meta) file.write(meta)
if license: if license:
with open("LICENSE", "w") as file: with open("LICENSE", "w") as file:
file.write(spdx.get(license)["text"]) file.write(spdx.get(license)["licenseText"])
if not os.path.isdir(name): if not os.path.isdir(name):
os.mkdir(name) os.mkdir(name)
mod = mod_template.render(config=config, name=main_class) mod = mod_template.render(config=config, name=main_class)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -20,17 +20,39 @@ from colorama import Fore
from yarl import URL from yarl import URL
import aiohttp import aiohttp
from ..config import save_config, config
from ..cliq import cliq from ..cliq import cliq
from ..config import config, save_config
@cliq.command(help="Log in to a Maubot instance") @cliq.command(help="Log in to a Maubot instance")
@cliq.option("-u", "--username", help="The username of your account", default=os.environ.get("USER", None), required=True) @cliq.option(
@cliq.option("-p", "--password", help="The password to your account", inq_type="password", required=True) "-u",
@cliq.option("-s", "--server", help="The server to log in to", default="http://localhost:29316", required=True) "--username",
@cliq.option("-a", "--alias", help="Alias to reference the server without typing the full URL", default="", required=False) help="The username of your account",
default=os.environ.get("USER", None),
required=True,
)
@cliq.option(
"-p", "--password", help="The password to your account", inq_type="password", required=True
)
@cliq.option(
"-s",
"--server",
help="The server to log in to",
default="http://localhost:29316",
required=True,
)
@cliq.option(
"-a",
"--alias",
help="Alias to reference the server without typing the full URL",
default="",
required=False,
)
@cliq.with_http @cliq.with_http
async def login(server: str, username: str, password: str, alias: str, sess: aiohttp.ClientSession) -> None: async def login(
server: str, username: str, password: str, alias: str, sess: aiohttp.ClientSession
) -> None:
data = { data = {
"username": username, "username": username,
"password": password, "password": password,

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -16,14 +16,14 @@
from datetime import datetime from datetime import datetime
import asyncio import asyncio
from aiohttp import ClientSession, WSMessage, WSMsgType
from colorama import Fore from colorama import Fore
from aiohttp import WSMsgType, WSMessage, ClientSession
import click import click
from mautrix.types import Obj from mautrix.types import Obj
from ..config import get_token
from ..base import app from ..base import app
from ..config import get_token
history_count: int = 10 history_count: int = 10
@ -50,7 +50,7 @@ def logs(server: str, tail: int) -> None:
def parsedate(entry: Obj) -> None: def parsedate(entry: Obj) -> None:
i = entry.time.index("+") i = entry.time.index("+")
i = entry.time.index(":", i) i = entry.time.index(":", i)
entry.time = entry.time[:i] + entry.time[i + 1:] entry.time = entry.time[:i] + entry.time[i + 1 :]
entry.time = datetime.strptime(entry.time, "%Y-%m-%dT%H:%M:%S.%f%z") entry.time = datetime.strptime(entry.time, "%Y-%m-%dT%H:%M:%S.%f%z")
@ -66,13 +66,16 @@ levelcolors = {
def print_entry(entry: dict) -> None: def print_entry(entry: dict) -> None:
entry = Obj(**entry) entry = Obj(**entry)
parsedate(entry) parsedate(entry)
print("{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}" print(
.format(date=entry.time.strftime("%Y-%m-%d %H:%M:%S"), "{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}".format(
level=entry.levelname, date=entry.time.strftime("%Y-%m-%d %H:%M:%S"),
levelcolor=levelcolors.get(entry.levelname, ""), level=entry.levelname,
resetcolor=Fore.RESET, levelcolor=levelcolors.get(entry.levelname, ""),
logger=entry.name, resetcolor=Fore.RESET,
message=entry.msg)) logger=entry.name,
message=entry.msg,
)
)
if entry.exc_info: if entry.exc_info:
print(entry.exc_info) print(entry.exc_info)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -43,8 +43,10 @@ async def upload_file(sess: aiohttp.ClientSession, file: IO, server: str) -> Non
async with sess.post(url, data=file, headers=headers) as resp: async with sess.post(url, data=file, headers=headers) as resp:
if resp.status in (200, 201): if resp.status in (200, 201):
data = await resp.json() data = await resp.json()
print(f"{Fore.GREEN}Plugin {Fore.CYAN}{data['id']} v{data['version']}{Fore.GREEN} " print(
f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}") f"{Fore.GREEN}Plugin {Fore.CYAN}{data['id']} v{data['version']}{Fore.GREEN} "
f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}"
)
else: else:
try: try:
err = await resp.json() err = await resp.json()

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,13 +13,15 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Optional, Dict, Any from __future__ import annotations
from typing import Any
import json import json
import os import os
from colorama import Fore from colorama import Fore
config: Dict[str, Any] = { config: dict[str, Any] = {
"servers": {}, "servers": {},
"aliases": {}, "aliases": {},
"default_server": None, "default_server": None,
@ -27,9 +29,9 @@ config: Dict[str, Any] = {
configdir = os.environ.get("XDG_CONFIG_HOME", os.path.join(os.environ.get("HOME"), ".config")) configdir = os.environ.get("XDG_CONFIG_HOME", os.path.join(os.environ.get("HOME"), ".config"))
def get_default_server() -> Tuple[Optional[str], Optional[str]]: def get_default_server() -> tuple[str | None, str | None]:
try: try:
server: Optional[str] = config["default_server"] server: str < None = config["default_server"]
except KeyError: except KeyError:
server = None server = None
if server is None: if server is None:
@ -38,7 +40,7 @@ def get_default_server() -> Tuple[Optional[str], Optional[str]]:
return server, _get_token(server) return server, _get_token(server)
def get_token(server: str) -> Tuple[Optional[str], Optional[str]]: def get_token(server: str) -> tuple[str | None, str | None]:
if not server: if not server:
return get_default_server() return get_default_server()
if server in config["aliases"]: if server in config["aliases"]:
@ -46,14 +48,14 @@ def get_token(server: str) -> Tuple[Optional[str], Optional[str]]:
return server, _get_token(server) return server, _get_token(server)
def _resolve_alias(alias: str) -> Optional[str]: def _resolve_alias(alias: str) -> str | None:
try: try:
return config["aliases"][alias] return config["aliases"][alias]
except KeyError: except KeyError:
return None return None
def _get_token(server: str) -> Optional[str]: def _get_token(server: str) -> str | None:
try: try:
return config["servers"][server] return config["servers"][server]
except KeyError: except KeyError:

Binary file not shown.

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,12 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Optional from __future__ import annotations
import zipfile
import pkg_resources
import json
spdx_list: Optional[Dict[str, Dict[str, str]]] = None import json
import zipfile
import pkg_resources
spdx_list: dict[str, dict[str, str]] | None = None
def load() -> None: def load() -> None:
@ -31,7 +33,7 @@ def load() -> None:
spdx_list = json.load(file) spdx_list = json.load(file)
def get(id: str) -> Dict[str, str]: def get(id: str) -> dict[str, str]:
if not spdx_list: if not spdx_list:
load() load()
return spdx_list[id.lower()] return spdx_list[id.lower()]

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,32 +13,46 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
import asyncio import asyncio
import logging import logging
from aiohttp import ClientSession from aiohttp import ClientSession
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
PresenceState, StateFilter, DeviceID)
from mautrix.client import InternalEventType from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (
ContentURI,
DeviceID,
EventFilter,
EventType,
Filter,
FilterID,
Membership,
PresenceState,
RoomEventFilter,
RoomFilter,
StateEvent,
StateFilter,
StrippedStateEvent,
SyncToken,
UserID,
)
from .lib.store_proxy import SyncStoreProxy
from .db import DBClient from .db import DBClient
from .lib.store_proxy import SyncStoreProxy
from .matrix import MaubotMatrixClient from .matrix import MaubotMatrixClient
try: try:
from mautrix.crypto import OlmMachine, StateStore as CryptoStateStore, PgCryptoStore from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore
from mautrix.util.async_db import Database as AsyncDatabase from mautrix.util.async_db import Database as AsyncDatabase
class SQLStateStore(BaseSQLStateStore, CryptoStateStore): class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass pass
crypto_import_error = None crypto_import_error = None
except ImportError as e: except ImportError as e:
OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None
@ -46,8 +60,8 @@ except ImportError as e:
crypto_import_error = e crypto_import_error = e
if TYPE_CHECKING: if TYPE_CHECKING:
from .instance import PluginInstance
from .config import Config from .config import Config
from .instance import PluginInstance
log = logging.getLogger("maubot.client") log = logging.getLogger("maubot.client")
@ -55,20 +69,20 @@ log = logging.getLogger("maubot.client")
class Client: class Client:
log: logging.Logger = None log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {} cache: dict[UserID, Client] = {}
http_client: ClientSession = None http_client: ClientSession = None
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore() global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
crypto_db: Optional['AsyncDatabase'] = None crypto_db: AsyncDatabase | None = None
references: Set['PluginInstance'] references: set[PluginInstance]
db_instance: DBClient db_instance: DBClient
client: MaubotMatrixClient client: MaubotMatrixClient
crypto: Optional['OlmMachine'] crypto: OlmMachine | None
crypto_store: Optional['PgCryptoStore'] crypto_store: PgCryptoStore | None
started: bool started: bool
remote_displayname: Optional[str] remote_displayname: str | None
remote_avatar_url: Optional[ContentURI] remote_avatar_url: ContentURI | None
def __init__(self, db_instance: DBClient) -> None: def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance self.db_instance = db_instance
@ -79,11 +93,17 @@ class Client:
self.sync_ok = True self.sync_ok = True
self.remote_displayname = None self.remote_displayname = None
self.remote_avatar_url = None self.remote_avatar_url = None
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver, self.client = MaubotMatrixClient(
token=self.access_token, client_session=self.http_client, mxid=self.id,
log=self.log, loop=self.loop, device_id=self.device_id, base_url=self.homeserver,
sync_store=SyncStoreProxy(self.db_instance), token=self.access_token,
state_store=self.global_state_store) client_session=self.http_client,
log=self.log,
loop=self.loop,
device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store,
)
if self.enable_crypto: if self.enable_crypto:
self._prepare_crypto() self._prepare_crypto()
else: else:
@ -104,8 +124,10 @@ class Client:
return False return False
elif not OlmMachine: elif not OlmMachine:
global crypto_import_error global crypto_import_error
self.log.warning("Client has device ID, but encryption dependencies not installed", self.log.warning(
exc_info=crypto_import_error) "Client has device ID, but encryption dependencies not installed",
exc_info=crypto_import_error,
)
# Clear the stack trace after it's logged once to avoid spamming logs # Clear the stack trace after it's logged once to avoid spamming logs
crypto_import_error = None crypto_import_error = None
return False return False
@ -115,8 +137,9 @@ class Client:
return True return True
def _prepare_crypto(self) -> None: def _prepare_crypto(self) -> None:
self.crypto_store = PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", self.crypto_store = PgCryptoStore(
db=self.crypto_db) account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db
)
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store) self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto self.client.crypto = self.crypto
@ -133,13 +156,13 @@ class Client:
for event_type, func in handlers: for event_type, func in handlers:
self.client.remove_event_handler(event_type, func) self.client.remove_event_handler(event_type, func)
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]: def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
async def handler(data: Dict[str, Any]) -> None: async def handler(data: dict[str, Any]) -> None:
self.sync_ok = ok self.sync_ok = ok
return handler return handler
async def start(self, try_n: Optional[int] = 0) -> None: async def start(self, try_n: int | None = 0) -> None:
try: try:
if try_n > 0: if try_n > 0:
await asyncio.sleep(try_n * 10) await asyncio.sleep(try_n * 10)
@ -152,15 +175,16 @@ class Client:
await self.crypto_store.open() await self.crypto_store.open()
crypto_device_id = await self.crypto_store.get_device_id() crypto_device_id = await self.crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != self.device_id: if crypto_device_id and crypto_device_id != self.device_id:
self.log.warning("Mismatching device ID in crypto store and main database, " self.log.warning(
"resetting encryption") "Mismatching device ID in crypto store and main database, " "resetting encryption"
)
await self.crypto_store.delete() await self.crypto_store.delete()
crypto_device_id = None crypto_device_id = None
await self.crypto.load() await self.crypto.load()
if not crypto_device_id: if not crypto_device_id:
await self.crypto_store.put_device_id(self.device_id) await self.crypto_store.put_device_id(self.device_id)
async def _start(self, try_n: Optional[int] = 0) -> None: async def _start(self, try_n: int | None = 0) -> None:
if not self.enabled: if not self.enabled:
self.log.debug("Not starting disabled client") self.log.debug("Not starting disabled client")
return return
@ -179,8 +203,9 @@ class Client:
self.log.exception("Failed to get /account/whoami, disabling client") self.log.exception("Failed to get /account/whoami, disabling client")
self.db_instance.enabled = False self.db_instance.enabled = False
else: else:
self.log.warning(f"Failed to get /account/whoami, " self.log.warning(
f"retrying in {(try_n + 1) * 10}s: {e}") f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}"
)
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop) _ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
return return
if whoami.user_id != self.id: if whoami.user_id != self.id:
@ -188,25 +213,30 @@ class Client:
self.db_instance.enabled = False self.db_instance.enabled = False
return return
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id: elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
self.log.error(f"Device ID mismatch: expected {self.device_id}, " self.log.error(
f"but got {whoami.device_id}") f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
)
self.db_instance.enabled = False self.db_instance.enabled = False
return return
if not self.filter_id: if not self.filter_id:
self.db_instance.edit(filter_id=await self.client.create_filter(Filter( self.db_instance.edit(
room=RoomFilter( filter_id=await self.client.create_filter(
timeline=RoomEventFilter( Filter(
limit=50, room=RoomFilter(
lazy_load_members=True, timeline=RoomEventFilter(
), limit=50,
state=StateFilter( lazy_load_members=True,
lazy_load_members=True, ),
state=StateFilter(
lazy_load_members=True,
),
),
presence=EventFilter(
not_types=[EventType.PRESENCE],
),
) )
), )
presence=EventFilter( )
not_types=[EventType.PRESENCE],
),
)))
if self.displayname != "disable": if self.displayname != "disable":
await self.client.set_displayname(self.displayname) await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable": if self.avatar_url != "disable":
@ -258,8 +288,9 @@ class Client:
"homeserver": self.homeserver, "homeserver": self.homeserver,
"access_token": self.access_token, "access_token": self.access_token,
"device_id": self.device_id, "device_id": self.device_id,
"fingerprint": (self.crypto.account.fingerprint if self.crypto and self.crypto.account "fingerprint": (
else None), self.crypto.account.fingerprint if self.crypto and self.crypto.account else None
),
"enabled": self.enabled, "enabled": self.enabled,
"started": self.started, "started": self.started,
"sync": self.sync, "sync": self.sync,
@ -274,7 +305,7 @@ class Client:
} }
@classmethod @classmethod
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']: def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
try: try:
return cls.cache[user_id] return cls.cache[user_id]
except KeyError: except KeyError:
@ -284,7 +315,7 @@ class Client:
return Client(db_instance) return Client(db_instance)
@classmethod @classmethod
def all(cls) -> Iterable['Client']: def all(cls) -> Iterable[Client]:
return (cls.get(user.id, user) for user in DBClient.all()) return (cls.get(user.id, user) for user in DBClient.all())
async def _handle_tombstone(self, evt: StateEvent) -> None: async def _handle_tombstone(self, evt: StateEvent) -> None:
@ -324,8 +355,12 @@ class Client:
else: else:
await self._update_remote_profile() await self._update_remote_profile()
async def update_access_details(self, access_token: Optional[str], homeserver: Optional[str], async def update_access_details(
device_id: Optional[str] = None) -> None: self,
access_token: str | None,
homeserver: str | None,
device_id: str | None = None,
) -> None:
if not access_token and not homeserver: if not access_token and not homeserver:
return return
if device_id is None: if device_id is None:
@ -338,10 +373,16 @@ class Client:
and device_id == self.device_id and device_id == self.device_id
): ):
return return
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver, new_client = MaubotMatrixClient(
token=access_token or self.access_token, loop=self.loop, mxid=self.id,
device_id=device_id, client_session=self.http_client, base_url=homeserver or self.homeserver,
log=self.log, state_store=self.global_state_store) token=access_token or self.access_token,
loop=self.loop,
device_id=device_id,
client_session=self.http_client,
log=self.log,
state_store=self.global_state_store,
)
whoami = await new_client.whoami() whoami = await new_client.whoami()
if whoami.user_id != self.id: if whoami.user_id != self.id:
raise ValueError(f"MXID mismatch: {whoami.user_id}") raise ValueError(f"MXID mismatch: {whoami.user_id}")
@ -455,7 +496,7 @@ class Client:
# endregion # endregion
def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]: def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop) Client.http_client = ClientSession(loop=loop)
Client.loop = loop Client.loop = loop

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,9 +14,10 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
import random import random
import string
import bcrypt
import re import re
import string
import bcrypt
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper
@ -64,8 +65,9 @@ class Config(BaseFileConfig):
if password and not bcrypt_regex.match(password): if password and not bcrypt_regex.match(password):
if password == "password": if password == "password":
password = self._new_token() password = self._new_token()
base["admins"][username] = bcrypt.hashpw(password.encode("utf-8"), base["admins"][username] = bcrypt.hashpw(
bcrypt.gensalt()).decode("utf-8") password.encode("utf-8"), bcrypt.gensalt()
).decode("utf-8")
copy("api_features.login") copy("api_features.login")
copy("api_features.plugin") copy("api_features.plugin")
copy("api_features.plugin_upload") copy("api_features.plugin_upload")

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -17,13 +17,13 @@ from typing import Iterable, Optional
import logging import logging
import sys import sys
from sqlalchemy import Column, String, Boolean, ForeignKey, Text from sqlalchemy import Boolean, Column, ForeignKey, String, Text
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
import sqlalchemy as sql import sqlalchemy as sql
from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI
from mautrix.util.db import Base
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
from mautrix.util.db import Base
from .config import Config from .config import Config
@ -34,17 +34,19 @@ class DBPlugin(Base):
id: str = Column(String(255), primary_key=True) id: str = Column(String(255), primary_key=True)
type: str = Column(String(255), nullable=False) type: str = Column(String(255), nullable=False)
enabled: bool = Column(Boolean, nullable=False, default=False) enabled: bool = Column(Boolean, nullable=False, default=False)
primary_user: UserID = Column(String(255), primary_user: UserID = Column(
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"), String(255),
nullable=False) ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
config: str = Column(Text, nullable=False, default='') nullable=False,
)
config: str = Column(Text, nullable=False, default="")
@classmethod @classmethod
def all(cls) -> Iterable['DBPlugin']: def all(cls) -> Iterable["DBPlugin"]:
return cls._select_all() return cls._select_all()
@classmethod @classmethod
def get(cls, id: str) -> Optional['DBPlugin']: def get(cls, id: str) -> Optional["DBPlugin"]:
return cls._select_one_or_none(cls.c.id == id) return cls._select_one_or_none(cls.c.id == id)
@ -68,11 +70,11 @@ class DBClient(Base):
avatar_url: ContentURI = Column(String(255), nullable=False, default="") avatar_url: ContentURI = Column(String(255), nullable=False, default="")
@classmethod @classmethod
def all(cls) -> Iterable['DBClient']: def all(cls) -> Iterable["DBClient"]:
return cls._select_all() return cls._select_all()
@classmethod @classmethod
def get(cls, id: str) -> Optional['DBClient']: def get(cls, id: str) -> Optional["DBClient"]:
return cls._select_one_or_none(cls.c.id == id) return cls._select_one_or_none(cls.c.id == id)
@ -87,15 +89,20 @@ def init(config: Config) -> Engine:
log = logging.getLogger("maubot.db") log = logging.getLogger("maubot.db")
if db.has_table("client") and db.has_table("plugin"): if db.has_table("client") and db.has_table("plugin"):
log.warning("alembic_version table not found, but client and plugin tables found. " log.warning(
"Assuming pre-Alembic database and inserting version.") "alembic_version table not found, but client and plugin tables found. "
db.execute("CREATE TABLE IF NOT EXISTS alembic_version (" "Assuming pre-Alembic database and inserting version."
" version_num VARCHAR(32) PRIMARY KEY" )
");") db.execute(
"CREATE TABLE IF NOT EXISTS alembic_version ("
" version_num VARCHAR(32) PRIMARY KEY"
");"
)
db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');") db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');")
else: else:
log.critical("alembic_version table not found. " log.critical(
"Did you forget to `alembic upgrade head`?") "alembic_version table not found. " "Did you forget to `alembic upgrade head`?"
)
sys.exit(10) sys.exit(10)
return db return db

View File

@ -1 +1 @@
from . import event, command, web from . import command, event, web

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,29 +13,46 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, from typing import (
Dict, Tuple, Set, Iterable) Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
NewType,
Optional,
Pattern,
Sequence,
Set,
Tuple,
Union,
)
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
import functools import functools
import inspect import inspect
import re import re
from mautrix.types import MessageType, EventType from mautrix.types import EventType, MessageType
from ..matrix import MaubotMessageEvent from ..matrix import MaubotMessageEvent
from . import event from . import event
PrefixType = Optional[Union[str, Callable[[], str], Callable[[Any], str]]] PrefixType = Optional[Union[str, Callable[[], str], Callable[[Any], str]]]
AliasesType = Union[List[str], Tuple[str, ...], Set[str], Callable[[str], bool], AliasesType = Union[
Callable[[Any, str], bool]] List[str], Tuple[str, ...], Set[str], Callable[[str], bool], Callable[[Any, str], bool]
CommandHandlerFunc = NewType("CommandHandlerFunc", ]
Callable[[MaubotMessageEvent, Any], Awaitable[Any]]) CommandHandlerFunc = NewType(
CommandHandlerDecorator = NewType("CommandHandlerDecorator", "CommandHandlerFunc", Callable[[MaubotMessageEvent, Any], Awaitable[Any]]
Callable[[Union['CommandHandler', CommandHandlerFunc]], )
'CommandHandler']) CommandHandlerDecorator = NewType(
PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator", "CommandHandlerDecorator",
Callable[[CommandHandlerFunc], CommandHandlerFunc]) Callable[[Union["CommandHandler", CommandHandlerFunc]], "CommandHandler"],
)
PassiveCommandHandlerDecorator = NewType(
"PassiveCommandHandlerDecorator", Callable[[CommandHandlerFunc], CommandHandlerFunc]
)
def _split_in_two(val: str, split_by: str) -> List[str]: def _split_in_two(val: str, split_by: str) -> List[str]:
@ -67,15 +84,26 @@ class CommandHandler:
return self.__bound_copies__[instance] return self.__bound_copies__[instance]
except KeyError: except KeyError:
new_ch = type(self)(self.__mb_func__) new_ch = type(self)(self.__mb_func__)
keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match", keys = [
"require_subcommand", "arg_fallthrough", "event_handler", "event_type", "parent",
"msgtypes"] "subcommands",
"arguments",
"help",
"get_name",
"is_command_match",
"require_subcommand",
"arg_fallthrough",
"event_handler",
"event_type",
"msgtypes",
]
for key in keys: for key in keys:
key = f"__mb_{key}__" key = f"__mb_{key}__"
setattr(new_ch, key, getattr(self, key)) setattr(new_ch, key, getattr(self, key))
new_ch.__bound_instance__ = instance new_ch.__bound_instance__ = instance
new_ch.__mb_subcommands__ = [subcmd.__get__(instance, instancetype) new_ch.__mb_subcommands__ = [
for subcmd in self.__mb_subcommands__] subcmd.__get__(instance, instancetype) for subcmd in self.__mb_subcommands__
]
self.__bound_copies__[instance] = new_ch self.__bound_copies__[instance] = new_ch
return new_ch return new_ch
@ -83,8 +111,13 @@ class CommandHandler:
def __command_match_unset(self, val: str) -> bool: def __command_match_unset(self, val: str) -> bool:
raise NotImplementedError("Hmm") raise NotImplementedError("Hmm")
async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None, async def __call__(
remaining_val: str = None) -> Any: self,
evt: MaubotMessageEvent,
*,
_existing_args: Dict[str, Any] = None,
remaining_val: str = None,
) -> Any:
if evt.sender == evt.client.mxid or evt.content.msgtype not in self.__mb_msgtypes__: if evt.sender == evt.client.mxid or evt.content.msgtype not in self.__mb_msgtypes__:
return return
if remaining_val is None: if remaining_val is None:
@ -120,21 +153,25 @@ class CommandHandler:
return await self.__mb_func__(self.__bound_instance__, evt, **call_args) return await self.__mb_func__(self.__bound_instance__, evt, **call_args)
return await self.__mb_func__(evt, **call_args) return await self.__mb_func__(evt, **call_args)
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any], async def __call_subcommand__(
remaining_val: str) -> Tuple[bool, Any]: self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str
) -> Tuple[bool, Any]:
command, remaining_val = _split_in_two(remaining_val.strip(), " ") command, remaining_val = _split_in_two(remaining_val.strip(), " ")
for subcommand in self.__mb_subcommands__: for subcommand in self.__mb_subcommands__:
if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command): if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command):
return True, await subcommand(evt, _existing_args=call_args, return True, await subcommand(
remaining_val=remaining_val) evt, _existing_args=call_args, remaining_val=remaining_val
)
return False, None return False, None
async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any], async def __parse_args__(
remaining_val: str) -> Tuple[bool, str]: self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str
) -> Tuple[bool, str]:
for arg in self.__mb_arguments__: for arg in self.__mb_arguments__:
try: try:
remaining_val, call_args[arg.name] = arg.match(remaining_val.strip(), evt=evt, remaining_val, call_args[arg.name] = arg.match(
instance=self.__bound_instance__) remaining_val.strip(), evt=evt, instance=self.__bound_instance__
)
if arg.required and call_args[arg.name] is None: if arg.required and call_args[arg.name] is None:
raise ValueError("Argument required") raise ValueError("Argument required")
except ArgumentSyntaxError as e: except ArgumentSyntaxError as e:
@ -155,8 +192,9 @@ class CommandHandler:
@property @property
def __mb_usage_args__(self) -> str: def __mb_usage_args__(self) -> str:
arg_usage = " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]" arg_usage = " ".join(
for arg in self.__mb_arguments__) f"<{arg.label}>" if arg.required else f"[{arg.label}]" for arg in self.__mb_arguments__
)
if self.__mb_subcommands__ and self.__mb_arg_fallthrough__: if self.__mb_subcommands__ and self.__mb_arg_fallthrough__:
arg_usage += " " + self.__mb_usage_subcommand__ arg_usage += " " + self.__mb_usage_subcommand__
return arg_usage return arg_usage
@ -172,15 +210,19 @@ class CommandHandler:
@property @property
def __mb_prefix__(self) -> str: def __mb_prefix__(self) -> str:
if self.__mb_parent__: if self.__mb_parent__:
return (f"!{self.__mb_parent__.__mb_get_name__(self.__bound_instance__)} " return (
f"{self.__mb_name__}") f"!{self.__mb_parent__.__mb_get_name__(self.__bound_instance__)} "
f"{self.__mb_name__}"
)
return f"!{self.__mb_name__}" return f"!{self.__mb_name__}"
@property @property
def __mb_usage_inline__(self) -> str: def __mb_usage_inline__(self) -> str:
if not self.__mb_arg_fallthrough__: if not self.__mb_arg_fallthrough__:
return (f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n" return (
f"* {self.__mb_name__} {self.__mb_usage_subcommand__}") f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
f"* {self.__mb_name__} {self.__mb_usage_subcommand__}"
)
return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}" return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}"
@property @property
@ -192,8 +234,10 @@ class CommandHandler:
if not self.__mb_arg_fallthrough__: if not self.__mb_arg_fallthrough__:
if not self.__mb_arguments__: if not self.__mb_arguments__:
return f"**Usage:** {self.__mb_prefix__} [subcommand] [...]" return f"**Usage:** {self.__mb_prefix__} [subcommand] [...]"
return (f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}" return (
f" _OR_ {self.__mb_prefix__} {self.__mb_usage_subcommand__}") f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
f" _OR_ {self.__mb_prefix__} {self.__mb_usage_subcommand__}"
)
return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}" return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
@property @property
@ -202,14 +246,25 @@ class CommandHandler:
return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}" return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}"
return self.__mb_usage_without_subcommands__ return self.__mb_usage_without_subcommands__
def subcommand(self, name: PrefixType = None, *, help: str = None, aliases: AliasesType = None, def subcommand(
required_subcommand: bool = True, arg_fallthrough: bool = True, self,
) -> CommandHandlerDecorator: name: PrefixType = None,
*,
help: str = None,
aliases: AliasesType = None,
required_subcommand: bool = True,
arg_fallthrough: bool = True,
) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler): if not isinstance(func, CommandHandler):
func = CommandHandler(func) func = CommandHandler(func)
new(name, help=help, aliases=aliases, require_subcommand=required_subcommand, new(
arg_fallthrough=arg_fallthrough)(func) name,
help=help,
aliases=aliases,
require_subcommand=required_subcommand,
arg_fallthrough=arg_fallthrough,
)(func)
func.__mb_parent__ = self func.__mb_parent__ = self
func.__mb_event_handler__ = False func.__mb_event_handler__ = False
self.__mb_subcommands__.append(func) self.__mb_subcommands__.append(func)
@ -218,10 +273,17 @@ class CommandHandler:
return decorator return decorator
def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = None, def new(
event_type: EventType = EventType.ROOM_MESSAGE, msgtypes: Iterable[MessageType] = None, name: PrefixType = None,
require_subcommand: bool = True, arg_fallthrough: bool = True, *,
must_consume_args: bool = True) -> CommandHandlerDecorator: help: str = None,
aliases: AliasesType = None,
event_type: EventType = EventType.ROOM_MESSAGE,
msgtypes: Iterable[MessageType] = None,
require_subcommand: bool = True,
arg_fallthrough: bool = True,
must_consume_args: bool = True,
) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler): if not isinstance(func, CommandHandler):
func = CommandHandler(func) func = CommandHandler(func)
@ -242,8 +304,9 @@ def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = Non
else: else:
func.__mb_is_command_match__ = aliases func.__mb_is_command_match__ = aliases
elif isinstance(aliases, (list, set, tuple)): elif isinstance(aliases, (list, set, tuple)):
func.__mb_is_command_match__ = lambda self, val: (val == func.__mb_get_name__(self) func.__mb_is_command_match__ = lambda self, val: (
or val in aliases) val == func.__mb_get_name__(self) or val in aliases
)
else: else:
func.__mb_is_command_match__ = lambda self, val: val == func.__mb_get_name__(self) func.__mb_is_command_match__ = lambda self, val: val == func.__mb_get_name__(self)
# Decorators are executed last to first, so we reverse the argument list. # Decorators are executed last to first, so we reverse the argument list.
@ -267,8 +330,9 @@ class ArgumentSyntaxError(ValueError):
class Argument(ABC): class Argument(ABC):
def __init__(self, name: str, label: str = None, *, required: bool = False, def __init__(
pass_raw: bool = False) -> None: self, name: str, label: str = None, *, required: bool = False, pass_raw: bool = False
) -> None:
self.name = name self.name = name
self.label = label or name self.label = label or name
self.required = required self.required = required
@ -286,8 +350,15 @@ class Argument(ABC):
class RegexArgument(Argument): class RegexArgument(Argument):
def __init__(self, name: str, label: str = None, *, required: bool = False, def __init__(
pass_raw: bool = False, matches: str = None) -> None: self,
name: str,
label: str = None,
*,
required: bool = False,
pass_raw: bool = False,
matches: str = None,
) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw) super().__init__(name, label, required=required, pass_raw=pass_raw)
matches = f"^{matches}" if self.pass_raw else f"^{matches}$" matches = f"^{matches}" if self.pass_raw else f"^{matches}$"
self.regex = re.compile(matches) self.regex = re.compile(matches)
@ -298,14 +369,23 @@ class RegexArgument(Argument):
val = re.split(r"\s", val, 1)[0] val = re.split(r"\s", val, 1)[0]
match = self.regex.match(val) match = self.regex.match(val)
if match: if match:
return (orig_val[:match.start()] + orig_val[match.end():], return (
match.groups() or val[match.start():match.end()]) orig_val[: match.start()] + orig_val[match.end() :],
match.groups() or val[match.start() : match.end()],
)
return orig_val, None return orig_val, None
class CustomArgument(Argument): class CustomArgument(Argument):
def __init__(self, name: str, label: str = None, *, required: bool = False, def __init__(
pass_raw: bool = False, matcher: Callable[[str], Any]) -> None: self,
name: str,
label: str = None,
*,
required: bool = False,
pass_raw: bool = False,
matcher: Callable[[str], Any],
) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw) super().__init__(name, label, required=required, pass_raw=pass_raw)
self.matcher = matcher self.matcher = matcher
@ -316,7 +396,7 @@ class CustomArgument(Argument):
val = re.split(r"\s", val, 1)[0] val = re.split(r"\s", val, 1)[0]
res = self.matcher(val) res = self.matcher(val)
if res is not None: if res is not None:
return orig_val[len(val):], res return orig_val[len(val) :], res
return orig_val, None return orig_val, None
@ -325,12 +405,18 @@ class SimpleArgument(Argument):
if self.pass_raw: if self.pass_raw:
return "", val return "", val
res = re.split(r"\s", val, 1)[0] res = re.split(r"\s", val, 1)[0]
return val[len(res):], res return val[len(res) :], res
def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None, def argument(
parser: Optional[Callable[[str], Any]] = None, pass_raw: bool = False name: str,
) -> CommandHandlerDecorator: label: str = None,
*,
required: bool = True,
matches: Optional[str] = None,
parser: Optional[Callable[[str], Any]] = None,
pass_raw: bool = False,
) -> CommandHandlerDecorator:
if matches: if matches:
return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw) return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw)
elif parser: elif parser:
@ -339,11 +425,17 @@ def argument(name: str, label: str = None, *, required: bool = True, matches: Op
return SimpleArgument(name, label, required=required, pass_raw=pass_raw) return SimpleArgument(name, label, required=required, pass_raw=pass_raw)
def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (MessageType.TEXT,), def passive(
field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body, regex: Union[str, Pattern],
event_type: EventType = EventType.ROOM_MESSAGE, multiple: bool = False, *,
case_insensitive: bool = False, multiline: bool = False, dot_all: bool = False msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
) -> PassiveCommandHandlerDecorator: field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body,
event_type: EventType = EventType.ROOM_MESSAGE,
multiple: bool = False,
case_insensitive: bool = False,
multiline: bool = False,
dot_all: bool = False,
) -> PassiveCommandHandlerDecorator:
if not isinstance(regex, Pattern): if not isinstance(regex, Pattern):
flags = re.RegexFlag.UNICODE flags = re.RegexFlag.UNICODE
if case_insensitive: if case_insensitive:
@ -372,12 +464,14 @@ def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (Me
return return
data = field(evt) data = field(evt)
if multiple: if multiple:
val = [(data[match.pos:match.endpos], *match.groups()) val = [
for match in regex.finditer(data)] (data[match.pos : match.endpos], *match.groups())
for match in regex.finditer(data)
]
else: else:
match = regex.search(data) match = regex.search(data)
if match: if match:
val = (data[match.pos:match.endpos], *match.groups()) val = (data[match.pos : match.endpos], *match.groups())
else: else:
val = None val = None
if val: if val:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,16 +13,17 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Union, NewType from __future__ import annotations
from typing import Callable, NewType
from mautrix.types import EventType
from mautrix.client import EventHandler, InternalEventType from mautrix.client import EventHandler, InternalEventType
from mautrix.types import EventType
EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler]) EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler])
def on(var: Union[EventType, InternalEventType, EventHandler] def on(var: EventType | InternalEventType | EventHandler) -> EventHandlerDecorator | EventHandler:
) -> Union[EventHandlerDecorator, EventHandler]:
def decorator(func: EventHandler) -> EventHandler: def decorator(func: EventHandler) -> EventHandler:
func.__mb_event_handler__ = True func.__mb_event_handler__ = True
if isinstance(var, (EventType, InternalEventType)): if isinstance(var, (EventType, InternalEventType)):

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,9 +13,9 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Any, Awaitable from typing import Any, Awaitable, Callable
from aiohttp import web, hdrs from aiohttp import hdrs, web
WebHandler = Callable[[web.Request], Awaitable[web.StreamResponse]] WebHandler = Callable[[web.Request], Awaitable[web.StreamResponse]]
WebHandlerDecorator = Callable[[WebHandler], WebHandler] WebHandlerDecorator = Callable[[WebHandler], WebHandler]

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,22 +13,24 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Iterable, TYPE_CHECKING from __future__ import annotations
from asyncio import AbstractEventLoop
import os.path from typing import TYPE_CHECKING, Iterable
import logging from asyncio import AbstractEventLoop
import io import io
import logging
import os.path
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql import sqlalchemy as sql
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.types import UserID from mautrix.types import UserID
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from .db import DBPlugin
from .config import Config
from .client import Client from .client import Client
from .config import Config
from .db import DBPlugin
from .loader import PluginLoader, ZippedPluginLoader from .loader import PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin from .plugin_base import Plugin
@ -43,23 +45,23 @@ yaml.width = 200
class PluginInstance: class PluginInstance:
webserver: 'MaubotServer' = None webserver: MaubotServer = None
mb_config: Config = None mb_config: Config = None
loop: AbstractEventLoop = None loop: AbstractEventLoop = None
cache: Dict[str, 'PluginInstance'] = {} cache: dict[str, PluginInstance] = {}
plugin_directories: List[str] = [] plugin_directories: list[str] = []
log: logging.Logger log: logging.Logger
loader: PluginLoader loader: PluginLoader
client: Client client: Client
plugin: Plugin plugin: Plugin
config: BaseProxyConfig config: BaseProxyConfig
base_cfg: Optional[RecursiveDict[CommentedMap]] base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: Optional[str] base_cfg_str: str | None
inst_db: sql.engine.Engine inst_db: sql.engine.Engine
inst_db_tables: Dict[str, sql.Table] inst_db_tables: dict[str, sql.Table]
inst_webapp: Optional['PluginWebApp'] inst_webapp: PluginWebApp | None
inst_webapp_url: Optional[str] inst_webapp_url: str | None
started: bool started: bool
def __init__(self, db_instance: DBPlugin): def __init__(self, db_instance: DBPlugin):
@ -87,11 +89,12 @@ class PluginInstance:
"primary_user": self.primary_user, "primary_user": self.primary_user,
"config": self.db_instance.config, "config": self.db_instance.config,
"base_config": self.base_cfg_str, "base_config": self.base_cfg_str,
"database": (self.inst_db is not None "database": (
and self.mb_config["api_features.instance_database"]), self.inst_db is not None and self.mb_config["api_features.instance_database"]
),
} }
def get_db_tables(self) -> Dict[str, sql.Table]: def get_db_tables(self) -> dict[str, sql.Table]:
if not self.inst_db_tables: if not self.inst_db_tables:
metadata = sql.MetaData() metadata = sql.MetaData()
metadata.reflect(self.inst_db) metadata.reflect(self.inst_db)
@ -147,7 +150,8 @@ class PluginInstance:
self.inst_db.dispose() self.inst_db.dispose()
ZippedPluginLoader.trash( ZippedPluginLoader.trash(
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"), os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
reason="deleted") reason="deleted",
)
if self.inst_webapp: if self.inst_webapp:
self.disable_webapp() self.disable_webapp()
@ -194,13 +198,23 @@ class PluginInstance:
if self.base_cfg: if self.base_cfg:
base_cfg_func = self.base_cfg.clone base_cfg_func = self.base_cfg.clone
else: else:
def base_cfg_func() -> None: def base_cfg_func() -> None:
return None return None
self.config = config_class(self.load_config, base_cfg_func, self.save_config) self.config = config_class(self.load_config, base_cfg_func, self.save_config)
self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client, self.plugin = cls(
instance_id=self.id, log=self.log, config=self.config, client=self.client.client,
database=self.inst_db, loader=self.loader, webapp=self.inst_webapp, loop=self.loop,
webapp_url=self.inst_webapp_url) http=self.client.http_client,
instance_id=self.id,
log=self.log,
config=self.config,
database=self.inst_db,
loader=self.loader,
webapp=self.inst_webapp,
webapp_url=self.inst_webapp_url,
)
try: try:
await self.plugin.internal_start() await self.plugin.internal_start()
except Exception: except Exception:
@ -209,8 +223,10 @@ class PluginInstance:
return return
self.started = True self.started = True
self.inst_db_tables = None self.inst_db_tables = None
self.log.info(f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} " self.log.info(
f"with user {self.client.id}") f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} "
f"with user {self.client.id}"
)
async def stop(self) -> None: async def stop(self) -> None:
if not self.started: if not self.started:
@ -226,8 +242,7 @@ class PluginInstance:
self.inst_db_tables = None self.inst_db_tables = None
@classmethod @classmethod
def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None def get(cls, instance_id: str, db_instance: DBPlugin | None = None) -> PluginInstance | None:
) -> Optional['PluginInstance']:
try: try:
return cls.cache[instance_id] return cls.cache[instance_id]
except KeyError: except KeyError:
@ -237,7 +252,7 @@ class PluginInstance:
return PluginInstance(db_instance) return PluginInstance(db_instance)
@classmethod @classmethod
def all(cls) -> Iterable['PluginInstance']: def all(cls) -> Iterable[PluginInstance]:
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all()) return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
def update_id(self, new_id: str) -> None: def update_id(self, new_id: str) -> None:
@ -317,8 +332,9 @@ class PluginInstance:
# endregion # endregion
def init(config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop def init(
) -> Iterable[PluginInstance]: config: Config, webserver: MaubotServer, loop: AbstractEventLoop
) -> Iterable[PluginInstance]:
PluginInstance.mb_config = config PluginInstance.mb_config = config
PluginInstance.loop = loop PluginInstance.loop = loop
PluginInstance.webserver = webserver PluginInstance.webserver = webserver

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2020 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,8 +13,13 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.logging.color import (ColorFormatter as BaseColorFormatter, PREFIX, MAU_COLOR, from mautrix.util.logging.color import (
MXID_COLOR, RESET) MAU_COLOR,
MXID_COLOR,
PREFIX,
RESET,
ColorFormatter as BaseColorFormatter,
)
INST_COLOR = PREFIX + "35m" # magenta INST_COLOR = PREFIX + "35m" # magenta
LOADER_COLOR = PREFIX + "36m" # blue LOADER_COLOR = PREFIX + "36m" # blue

View File

@ -1,4 +1,5 @@
from typing import Callable, Awaitable, Generator, Any from typing import Any, Awaitable, Callable, Generator
class FutureAwaitable: class FutureAwaitable:
def __init__(self, func: Callable[[], Awaitable[None]]) -> None: def __init__(self, func: Callable[[], Awaitable[None]]) -> None:
@ -6,4 +7,3 @@ class FutureAwaitable:
def __await__(self) -> Generator[Any, None, None]: def __await__(self) -> Generator[Any, None, None]:
return self._func().__await__() return self._func().__await__()

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by

View File

@ -18,26 +18,28 @@ used by the builtin import mechanism for sys.path items that are paths
to Zip archives. to Zip archives.
""" """
from importlib import _bootstrap_external
from importlib import _bootstrap # for _verbose_message from importlib import _bootstrap # for _verbose_message
import _imp # for check_hash_based_pycs from importlib import _bootstrap_external
import _io # for open
import marshal # for loads import marshal # for loads
import sys # for modules import sys # for modules
import time # for mktime import time # for mktime
__all__ = ['ZipImportError', 'zipimporter'] import _imp # for check_hash_based_pycs
import _io # for open
__all__ = ["ZipImportError", "zipimporter"]
def _unpack_uint32(data): def _unpack_uint32(data):
"""Convert 4 bytes in little-endian to an integer.""" """Convert 4 bytes in little-endian to an integer."""
assert len(data) == 4 assert len(data) == 4
return int.from_bytes(data, 'little') return int.from_bytes(data, "little")
def _unpack_uint16(data): def _unpack_uint16(data):
"""Convert 2 bytes in little-endian to an integer.""" """Convert 2 bytes in little-endian to an integer."""
assert len(data) == 2 assert len(data) == 2
return int.from_bytes(data, 'little') return int.from_bytes(data, "little")
path_sep = _bootstrap_external.path_sep path_sep = _bootstrap_external.path_sep
@ -47,15 +49,17 @@ alt_path_sep = _bootstrap_external.path_separators[1:]
class ZipImportError(ImportError): class ZipImportError(ImportError):
pass pass
# _read_directory() cache # _read_directory() cache
_zip_directory_cache = {} _zip_directory_cache = {}
_module_type = type(sys) _module_type = type(sys)
END_CENTRAL_DIR_SIZE = 22 END_CENTRAL_DIR_SIZE = 22
STRING_END_ARCHIVE = b'PK\x05\x06' STRING_END_ARCHIVE = b"PK\x05\x06"
MAX_COMMENT_LEN = (1 << 16) - 1 MAX_COMMENT_LEN = (1 << 16) - 1
class zipimporter: class zipimporter:
"""zipimporter(archivepath) -> zipimporter object """zipimporter(archivepath) -> zipimporter object
@ -77,9 +81,10 @@ class zipimporter:
def __init__(self, path): def __init__(self, path):
if not isinstance(path, str): if not isinstance(path, str):
import os import os
path = os.fsdecode(path) path = os.fsdecode(path)
if not path: if not path:
raise ZipImportError('archive path is empty', path=path) raise ZipImportError("archive path is empty", path=path)
if alt_path_sep: if alt_path_sep:
path = path.replace(alt_path_sep, path_sep) path = path.replace(alt_path_sep, path_sep)
@ -92,14 +97,14 @@ class zipimporter:
# Back up one path element. # Back up one path element.
dirname, basename = _bootstrap_external._path_split(path) dirname, basename = _bootstrap_external._path_split(path)
if dirname == path: if dirname == path:
raise ZipImportError('not a Zip file', path=path) raise ZipImportError("not a Zip file", path=path)
path = dirname path = dirname
prefix.append(basename) prefix.append(basename)
else: else:
# it exists # it exists
if (st.st_mode & 0o170000) != 0o100000: # stat.S_ISREG if (st.st_mode & 0o170000) != 0o100000: # stat.S_ISREG
# it's a not file # it's a not file
raise ZipImportError('not a Zip file', path=path) raise ZipImportError("not a Zip file", path=path)
break break
try: try:
@ -154,11 +159,10 @@ class zipimporter:
# This is possibly a portion of a namespace # This is possibly a portion of a namespace
# package. Return the string representing its path, # package. Return the string representing its path,
# without a trailing separator. # without a trailing separator.
return None, [f'{self.archive}{path_sep}{modpath}'] return None, [f"{self.archive}{path_sep}{modpath}"]
return None, [] return None, []
# Check whether we can satisfy the import of the module named by # Check whether we can satisfy the import of the module named by
# 'fullname'. Return self if we can, None if we can't. # 'fullname'. Return self if we can, None if we can't.
def find_module(self, fullname, path=None): def find_module(self, fullname, path=None):
@ -172,7 +176,6 @@ class zipimporter:
""" """
return self.find_loader(fullname, path)[0] return self.find_loader(fullname, path)[0]
def get_code(self, fullname): def get_code(self, fullname):
"""get_code(fullname) -> code object. """get_code(fullname) -> code object.
@ -182,7 +185,6 @@ class zipimporter:
code, ispackage, modpath = _get_module_code(self, fullname) code, ispackage, modpath = _get_module_code(self, fullname)
return code return code
def get_data(self, pathname): def get_data(self, pathname):
"""get_data(pathname) -> string with file data. """get_data(pathname) -> string with file data.
@ -194,15 +196,14 @@ class zipimporter:
key = pathname key = pathname
if pathname.startswith(self.archive + path_sep): if pathname.startswith(self.archive + path_sep):
key = pathname[len(self.archive + path_sep):] key = pathname[len(self.archive + path_sep) :]
try: try:
toc_entry = self._files[key] toc_entry = self._files[key]
except KeyError: except KeyError:
raise OSError(0, '', key) raise OSError(0, "", key)
return _get_data(self.archive, toc_entry) return _get_data(self.archive, toc_entry)
# Return a string matching __file__ for the named module # Return a string matching __file__ for the named module
def get_filename(self, fullname): def get_filename(self, fullname):
"""get_filename(fullname) -> filename string. """get_filename(fullname) -> filename string.
@ -214,7 +215,6 @@ class zipimporter:
code, ispackage, modpath = _get_module_code(self, fullname) code, ispackage, modpath = _get_module_code(self, fullname)
return modpath return modpath
def get_source(self, fullname): def get_source(self, fullname):
"""get_source(fullname) -> source string. """get_source(fullname) -> source string.
@ -228,9 +228,9 @@ class zipimporter:
path = _get_module_path(self, fullname) path = _get_module_path(self, fullname)
if mi: if mi:
fullpath = _bootstrap_external._path_join(path, '__init__.py') fullpath = _bootstrap_external._path_join(path, "__init__.py")
else: else:
fullpath = f'{path}.py' fullpath = f"{path}.py"
try: try:
toc_entry = self._files[fullpath] toc_entry = self._files[fullpath]
@ -239,7 +239,6 @@ class zipimporter:
return None return None
return _get_data(self.archive, toc_entry).decode() return _get_data(self.archive, toc_entry).decode()
# Return a bool signifying whether the module is a package or not. # Return a bool signifying whether the module is a package or not.
def is_package(self, fullname): def is_package(self, fullname):
"""is_package(fullname) -> bool. """is_package(fullname) -> bool.
@ -252,7 +251,6 @@ class zipimporter:
raise ZipImportError(f"can't find module {fullname!r}", name=fullname) raise ZipImportError(f"can't find module {fullname!r}", name=fullname)
return mi return mi
# Load and return the module named by 'fullname'. # Load and return the module named by 'fullname'.
def load_module(self, fullname): def load_module(self, fullname):
"""load_module(fullname) -> module. """load_module(fullname) -> module.
@ -276,7 +274,7 @@ class zipimporter:
fullpath = _bootstrap_external._path_join(self.archive, path) fullpath = _bootstrap_external._path_join(self.archive, path)
mod.__path__ = [fullpath] mod.__path__ = [fullpath]
if not hasattr(mod, '__builtins__'): if not hasattr(mod, "__builtins__"):
mod.__builtins__ = __builtins__ mod.__builtins__ = __builtins__
_bootstrap_external._fix_up_module(mod.__dict__, fullname, modpath) _bootstrap_external._fix_up_module(mod.__dict__, fullname, modpath)
exec(code, mod.__dict__) exec(code, mod.__dict__)
@ -287,11 +285,10 @@ class zipimporter:
try: try:
mod = sys.modules[fullname] mod = sys.modules[fullname]
except KeyError: except KeyError:
raise ImportError(f'Loaded module {fullname!r} not found in sys.modules') raise ImportError(f"Loaded module {fullname!r} not found in sys.modules")
_bootstrap._verbose_message('import {} # loaded from Zip {}', fullname, modpath) _bootstrap._verbose_message("import {} # loaded from Zip {}", fullname, modpath)
return mod return mod
def get_resource_reader(self, fullname): def get_resource_reader(self, fullname):
"""Return the ResourceReader for a package in a zip file. """Return the ResourceReader for a package in a zip file.
@ -305,11 +302,11 @@ class zipimporter:
return None return None
if not _ZipImportResourceReader._registered: if not _ZipImportResourceReader._registered:
from importlib.abc import ResourceReader from importlib.abc import ResourceReader
ResourceReader.register(_ZipImportResourceReader) ResourceReader.register(_ZipImportResourceReader)
_ZipImportResourceReader._registered = True _ZipImportResourceReader._registered = True
return _ZipImportResourceReader(self, fullname) return _ZipImportResourceReader(self, fullname)
def __repr__(self): def __repr__(self):
return f'<zipimporter object "{self.archive}{path_sep}{self.prefix}">' return f'<zipimporter object "{self.archive}{path_sep}{self.prefix}">'
@ -320,16 +317,17 @@ class zipimporter:
# are swapped by initzipimport() if we run in optimized mode. Also, # are swapped by initzipimport() if we run in optimized mode. Also,
# '/' is replaced by path_sep there. # '/' is replaced by path_sep there.
_zip_searchorder = ( _zip_searchorder = (
(path_sep + '__init__.pyc', True, True), (path_sep + "__init__.pyc", True, True),
(path_sep + '__init__.py', False, True), (path_sep + "__init__.py", False, True),
('.pyc', True, False), (".pyc", True, False),
('.py', False, False), (".py", False, False),
) )
# Given a module name, return the potential file path in the # Given a module name, return the potential file path in the
# archive (without extension). # archive (without extension).
def _get_module_path(self, fullname): def _get_module_path(self, fullname):
return self.prefix + fullname.rpartition('.')[2] return self.prefix + fullname.rpartition(".")[2]
# Does this path represent a directory? # Does this path represent a directory?
def _is_dir(self, path): def _is_dir(self, path):
@ -340,6 +338,7 @@ def _is_dir(self, path):
# If dirpath is present in self._files, we have a directory. # If dirpath is present in self._files, we have a directory.
return dirpath in self._files return dirpath in self._files
# Return some information about a module. # Return some information about a module.
def _get_module_info(self, fullname): def _get_module_info(self, fullname):
path = _get_module_path(self, fullname) path = _get_module_path(self, fullname)
@ -374,7 +373,7 @@ def _get_module_info(self, fullname):
# data_size and file_offset are 0. # data_size and file_offset are 0.
def _read_directory(archive): def _read_directory(archive):
try: try:
fp = _io.open(archive, 'rb') fp = _io.open(archive, "rb")
except OSError: except OSError:
raise ZipImportError(f"can't open Zip file: {archive!r}", path=archive) raise ZipImportError(f"can't open Zip file: {archive!r}", path=archive)
@ -394,36 +393,33 @@ def _read_directory(archive):
fp.seek(0, 2) fp.seek(0, 2)
file_size = fp.tell() file_size = fp.tell()
except OSError: except OSError:
raise ZipImportError(f"can't read Zip file: {archive!r}", raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
path=archive) max_comment_start = max(file_size - MAX_COMMENT_LEN - END_CENTRAL_DIR_SIZE, 0)
max_comment_start = max(file_size - MAX_COMMENT_LEN -
END_CENTRAL_DIR_SIZE, 0)
try: try:
fp.seek(max_comment_start) fp.seek(max_comment_start)
data = fp.read() data = fp.read()
except OSError: except OSError:
raise ZipImportError(f"can't read Zip file: {archive!r}", raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
path=archive)
pos = data.rfind(STRING_END_ARCHIVE) pos = data.rfind(STRING_END_ARCHIVE)
if pos < 0: if pos < 0:
raise ZipImportError(f'not a Zip file: {archive!r}', raise ZipImportError(f"not a Zip file: {archive!r}", path=archive)
path=archive) buffer = data[pos : pos + END_CENTRAL_DIR_SIZE]
buffer = data[pos:pos+END_CENTRAL_DIR_SIZE]
if len(buffer) != END_CENTRAL_DIR_SIZE: if len(buffer) != END_CENTRAL_DIR_SIZE:
raise ZipImportError(f"corrupt Zip file: {archive!r}", raise ZipImportError(f"corrupt Zip file: {archive!r}", path=archive)
path=archive)
header_position = file_size - len(data) + pos header_position = file_size - len(data) + pos
header_size = _unpack_uint32(buffer[12:16]) header_size = _unpack_uint32(buffer[12:16])
header_offset = _unpack_uint32(buffer[16:20]) header_offset = _unpack_uint32(buffer[16:20])
if header_position < header_size: if header_position < header_size:
raise ZipImportError(f'bad central directory size: {archive!r}', path=archive) raise ZipImportError(f"bad central directory size: {archive!r}", path=archive)
if header_position < header_offset: if header_position < header_offset:
raise ZipImportError(f'bad central directory offset: {archive!r}', path=archive) raise ZipImportError(f"bad central directory offset: {archive!r}", path=archive)
header_position -= header_size header_position -= header_size
arc_offset = header_position - header_offset arc_offset = header_position - header_offset
if arc_offset < 0: if arc_offset < 0:
raise ZipImportError(f'bad central directory size or offset: {archive!r}', path=archive) raise ZipImportError(
f"bad central directory size or offset: {archive!r}", path=archive
)
files = {} files = {}
# Start of Central Directory # Start of Central Directory
@ -435,12 +431,12 @@ def _read_directory(archive):
while True: while True:
buffer = fp.read(46) buffer = fp.read(46)
if len(buffer) < 4: if len(buffer) < 4:
raise EOFError('EOF read where not expected') raise EOFError("EOF read where not expected")
# Start of file header # Start of file header
if buffer[:4] != b'PK\x01\x02': if buffer[:4] != b"PK\x01\x02":
break # Bad: Central Dir File Header break # Bad: Central Dir File Header
if len(buffer) != 46: if len(buffer) != 46:
raise EOFError('EOF read where not expected') raise EOFError("EOF read where not expected")
flags = _unpack_uint16(buffer[8:10]) flags = _unpack_uint16(buffer[8:10])
compress = _unpack_uint16(buffer[10:12]) compress = _unpack_uint16(buffer[10:12])
time = _unpack_uint16(buffer[12:14]) time = _unpack_uint16(buffer[12:14])
@ -454,7 +450,7 @@ def _read_directory(archive):
file_offset = _unpack_uint32(buffer[42:46]) file_offset = _unpack_uint32(buffer[42:46])
header_size = name_size + extra_size + comment_size header_size = name_size + extra_size + comment_size
if file_offset > header_offset: if file_offset > header_offset:
raise ZipImportError(f'bad local header offset: {archive!r}', path=archive) raise ZipImportError(f"bad local header offset: {archive!r}", path=archive)
file_offset += arc_offset file_offset += arc_offset
try: try:
@ -478,18 +474,19 @@ def _read_directory(archive):
else: else:
# Historical ZIP filename encoding # Historical ZIP filename encoding
try: try:
name = name.decode('ascii') name = name.decode("ascii")
except UnicodeDecodeError: except UnicodeDecodeError:
name = name.decode('latin1').translate(cp437_table) name = name.decode("latin1").translate(cp437_table)
name = name.replace('/', path_sep) name = name.replace("/", path_sep)
path = _bootstrap_external._path_join(archive, name) path = _bootstrap_external._path_join(archive, name)
t = (path, compress, data_size, file_size, file_offset, time, date, crc) t = (path, compress, data_size, file_size, file_offset, time, date, crc)
files[name] = t files[name] = t
count += 1 count += 1
_bootstrap._verbose_message('zipimport: found {} names in {!r}', count, archive) _bootstrap._verbose_message("zipimport: found {} names in {!r}", count, archive)
return files return files
# During bootstrap, we may need to load the encodings # During bootstrap, we may need to load the encodings
# package from a ZIP file. But the cp437 encoding is implemented # package from a ZIP file. But the cp437 encoding is implemented
# in Python in the encodings package. # in Python in the encodings package.
@ -498,31 +495,31 @@ def _read_directory(archive):
# the cp437 encoding. # the cp437 encoding.
cp437_table = ( cp437_table = (
# ASCII part, 8 rows x 16 chars # ASCII part, 8 rows x 16 chars
'\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f' "\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f"
'\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f' "\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"
' !"#$%&\'()*+,-./' " !\"#$%&'()*+,-./"
'0123456789:;<=>?' "0123456789:;<=>?"
'@ABCDEFGHIJKLMNO' "@ABCDEFGHIJKLMNO"
'PQRSTUVWXYZ[\\]^_' "PQRSTUVWXYZ[\\]^_"
'`abcdefghijklmno' "`abcdefghijklmno"
'pqrstuvwxyz{|}~\x7f' "pqrstuvwxyz{|}~\x7f"
# non-ASCII part, 16 rows x 8 chars # non-ASCII part, 16 rows x 8 chars
'\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7' "\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7"
'\xea\xeb\xe8\xef\xee\xec\xc4\xc5' "\xea\xeb\xe8\xef\xee\xec\xc4\xc5"
'\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9' "\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9"
'\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192' "\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192"
'\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba' "\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba"
'\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb' "\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb"
'\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556' "\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556"
'\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510' "\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510"
'\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f' "\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f"
'\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567' "\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567"
'\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b' "\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b"
'\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580' "\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580"
'\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4' "\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4"
'\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229' "\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229"
'\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248' "\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248"
'\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0' "\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0"
) )
_importing_zlib = False _importing_zlib = False
@ -535,28 +532,29 @@ def _get_decompress_func():
if _importing_zlib: if _importing_zlib:
# Someone has a zlib.py[co] in their Zip file # Someone has a zlib.py[co] in their Zip file
# let's avoid a stack overflow. # let's avoid a stack overflow.
_bootstrap._verbose_message('zipimport: zlib UNAVAILABLE') _bootstrap._verbose_message("zipimport: zlib UNAVAILABLE")
raise ZipImportError("can't decompress data; zlib not available") raise ZipImportError("can't decompress data; zlib not available")
_importing_zlib = True _importing_zlib = True
try: try:
from zlib import decompress from zlib import decompress
except Exception: except Exception:
_bootstrap._verbose_message('zipimport: zlib UNAVAILABLE') _bootstrap._verbose_message("zipimport: zlib UNAVAILABLE")
raise ZipImportError("can't decompress data; zlib not available") raise ZipImportError("can't decompress data; zlib not available")
finally: finally:
_importing_zlib = False _importing_zlib = False
_bootstrap._verbose_message('zipimport: zlib available') _bootstrap._verbose_message("zipimport: zlib available")
return decompress return decompress
# Given a path to a Zip file and a toc_entry, return the (uncompressed) data. # Given a path to a Zip file and a toc_entry, return the (uncompressed) data.
def _get_data(archive, toc_entry): def _get_data(archive, toc_entry):
datapath, compress, data_size, file_size, file_offset, time, date, crc = toc_entry datapath, compress, data_size, file_size, file_offset, time, date, crc = toc_entry
if data_size < 0: if data_size < 0:
raise ZipImportError('negative data size') raise ZipImportError("negative data size")
with _io.open(archive, 'rb') as fp: with _io.open(archive, "rb") as fp:
# Check to make sure the local file header is correct # Check to make sure the local file header is correct
try: try:
fp.seek(file_offset) fp.seek(file_offset)
@ -564,11 +562,11 @@ def _get_data(archive, toc_entry):
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive) raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
buffer = fp.read(30) buffer = fp.read(30)
if len(buffer) != 30: if len(buffer) != 30:
raise EOFError('EOF read where not expected') raise EOFError("EOF read where not expected")
if buffer[:4] != b'PK\x03\x04': if buffer[:4] != b"PK\x03\x04":
# Bad: Local File Header # Bad: Local File Header
raise ZipImportError(f'bad local file header: {archive!r}', path=archive) raise ZipImportError(f"bad local file header: {archive!r}", path=archive)
name_size = _unpack_uint16(buffer[26:28]) name_size = _unpack_uint16(buffer[26:28])
extra_size = _unpack_uint16(buffer[28:30]) extra_size = _unpack_uint16(buffer[28:30])
@ -601,16 +599,17 @@ def _eq_mtime(t1, t2):
# dostime only stores even seconds, so be lenient # dostime only stores even seconds, so be lenient
return abs(t1 - t2) <= 1 return abs(t1 - t2) <= 1
# Given the contents of a .py[co] file, unmarshal the data # Given the contents of a .py[co] file, unmarshal the data
# and return the code object. Return None if it the magic word doesn't # and return the code object. Return None if it the magic word doesn't
# match (we do this instead of raising an exception as we fall back # match (we do this instead of raising an exception as we fall back
# to .py if available and we don't want to mask other errors). # to .py if available and we don't want to mask other errors).
def _unmarshal_code(pathname, data, mtime): def _unmarshal_code(pathname, data, mtime):
if len(data) < 16: if len(data) < 16:
raise ZipImportError('bad pyc data') raise ZipImportError("bad pyc data")
if data[:4] != _bootstrap_external.MAGIC_NUMBER: if data[:4] != _bootstrap_external.MAGIC_NUMBER:
_bootstrap._verbose_message('{!r} has bad magic', pathname) _bootstrap._verbose_message("{!r} has bad magic", pathname)
return None # signal caller to try alternative return None # signal caller to try alternative
flags = _unpack_uint32(data[4:8]) flags = _unpack_uint32(data[4:8])
@ -619,47 +618,57 @@ def _unmarshal_code(pathname, data, mtime):
# pycs. We could validate hash-based pycs against the source, but it # pycs. We could validate hash-based pycs against the source, but it
# seems likely that most people putting hash-based pycs in a zipfile # seems likely that most people putting hash-based pycs in a zipfile
# will use unchecked ones. # will use unchecked ones.
if (_imp.check_hash_based_pycs != 'never' and if _imp.check_hash_based_pycs != "never" and (
(flags != 0x1 or _imp.check_hash_based_pycs == 'always')): flags != 0x1 or _imp.check_hash_based_pycs == "always"
):
return None return None
elif mtime != 0 and not _eq_mtime(_unpack_uint32(data[8:12]), mtime): elif mtime != 0 and not _eq_mtime(_unpack_uint32(data[8:12]), mtime):
_bootstrap._verbose_message('{!r} has bad mtime', pathname) _bootstrap._verbose_message("{!r} has bad mtime", pathname)
return None # signal caller to try alternative return None # signal caller to try alternative
# XXX the pyc's size field is ignored; timestamp collisions are probably # XXX the pyc's size field is ignored; timestamp collisions are probably
# unimportant with zip files. # unimportant with zip files.
code = marshal.loads(data[16:]) code = marshal.loads(data[16:])
if not isinstance(code, _code_type): if not isinstance(code, _code_type):
raise TypeError(f'compiled module {pathname!r} is not a code object') raise TypeError(f"compiled module {pathname!r} is not a code object")
return code return code
_code_type = type(_unmarshal_code.__code__) _code_type = type(_unmarshal_code.__code__)
# Replace any occurrences of '\r\n?' in the input string with '\n'. # Replace any occurrences of '\r\n?' in the input string with '\n'.
# This converts DOS and Mac line endings to Unix line endings. # This converts DOS and Mac line endings to Unix line endings.
def _normalize_line_endings(source): def _normalize_line_endings(source):
source = source.replace(b'\r\n', b'\n') source = source.replace(b"\r\n", b"\n")
source = source.replace(b'\r', b'\n') source = source.replace(b"\r", b"\n")
return source return source
# Given a string buffer containing Python source code, compile it # Given a string buffer containing Python source code, compile it
# and return a code object. # and return a code object.
def _compile_source(pathname, source): def _compile_source(pathname, source):
source = _normalize_line_endings(source) source = _normalize_line_endings(source)
return compile(source, pathname, 'exec', dont_inherit=True) return compile(source, pathname, "exec", dont_inherit=True)
# Convert the date/time values found in the Zip archive to a value # Convert the date/time values found in the Zip archive to a value
# that's compatible with the time stamp stored in .pyc files. # that's compatible with the time stamp stored in .pyc files.
def _parse_dostime(d, t): def _parse_dostime(d, t):
return time.mktime(( return time.mktime(
(d >> 9) + 1980, # bits 9..15: year (
(d >> 5) & 0xF, # bits 5..8: month (d >> 9) + 1980, # bits 9..15: year
d & 0x1F, # bits 0..4: day (d >> 5) & 0xF, # bits 5..8: month
t >> 11, # bits 11..15: hours d & 0x1F, # bits 0..4: day
(t >> 5) & 0x3F, # bits 8..10: minutes t >> 11, # bits 11..15: hours
(t & 0x1F) * 2, # bits 0..7: seconds / 2 (t >> 5) & 0x3F, # bits 8..10: minutes
-1, -1, -1)) (t & 0x1F) * 2, # bits 0..7: seconds / 2
-1,
-1,
-1,
)
)
# Given a path to a .pyc file in the archive, return the # Given a path to a .pyc file in the archive, return the
# modification time of the matching .py file, or 0 if no source # modification time of the matching .py file, or 0 if no source
@ -667,7 +676,7 @@ def _parse_dostime(d, t):
def _get_mtime_of_source(self, path): def _get_mtime_of_source(self, path):
try: try:
# strip 'c' or 'o' from *.py[co] # strip 'c' or 'o' from *.py[co]
assert path[-1:] in ('c', 'o') assert path[-1:] in ("c", "o")
path = path[:-1] path = path[:-1]
toc_entry = self._files[path] toc_entry = self._files[path]
# fetch the time stamp of the .py file for comparison # fetch the time stamp of the .py file for comparison
@ -678,13 +687,14 @@ def _get_mtime_of_source(self, path):
except (KeyError, IndexError, TypeError): except (KeyError, IndexError, TypeError):
return 0 return 0
# Get the code object associated with the module specified by # Get the code object associated with the module specified by
# 'fullname'. # 'fullname'.
def _get_module_code(self, fullname): def _get_module_code(self, fullname):
path = _get_module_path(self, fullname) path = _get_module_path(self, fullname)
for suffix, isbytecode, ispackage in _zip_searchorder: for suffix, isbytecode, ispackage in _zip_searchorder:
fullpath = path + suffix fullpath = path + suffix
_bootstrap._verbose_message('trying {}{}{}', self.archive, path_sep, fullpath, verbosity=2) _bootstrap._verbose_message("trying {}{}{}", self.archive, path_sep, fullpath, verbosity=2)
try: try:
toc_entry = self._files[fullpath] toc_entry = self._files[fullpath]
except KeyError: except KeyError:
@ -713,6 +723,7 @@ class _ZipImportResourceReader:
This class is allowed to reference all the innards and private parts of This class is allowed to reference all the innards and private parts of
the zipimporter. the zipimporter.
""" """
_registered = False _registered = False
def __init__(self, zipimporter, fullname): def __init__(self, zipimporter, fullname):
@ -720,9 +731,10 @@ class _ZipImportResourceReader:
self.fullname = fullname self.fullname = fullname
def open_resource(self, resource): def open_resource(self, resource):
fullname_as_path = self.fullname.replace('.', '/') fullname_as_path = self.fullname.replace(".", "/")
path = f'{fullname_as_path}/{resource}' path = f"{fullname_as_path}/{resource}"
from io import BytesIO from io import BytesIO
try: try:
return BytesIO(self.zipimporter.get_data(path)) return BytesIO(self.zipimporter.get_data(path))
except OSError: except OSError:
@ -737,8 +749,8 @@ class _ZipImportResourceReader:
def is_resource(self, name): def is_resource(self, name):
# Maybe we could do better, but if we can get the data, it's a # Maybe we could do better, but if we can get the data, it's a
# resource. Otherwise it isn't. # resource. Otherwise it isn't.
fullname_as_path = self.fullname.replace('.', '/') fullname_as_path = self.fullname.replace(".", "/")
path = f'{fullname_as_path}/{name}' path = f"{fullname_as_path}/{name}"
try: try:
self.zipimporter.get_data(path) self.zipimporter.get_data(path)
except OSError: except OSError:
@ -754,11 +766,12 @@ class _ZipImportResourceReader:
# top of the archive, and then we iterate through _files looking for # top of the archive, and then we iterate through _files looking for
# names inside that "directory". # names inside that "directory".
from pathlib import Path from pathlib import Path
fullname_path = Path(self.zipimporter.get_filename(self.fullname)) fullname_path = Path(self.zipimporter.get_filename(self.fullname))
relative_path = fullname_path.relative_to(self.zipimporter.archive) relative_path = fullname_path.relative_to(self.zipimporter.archive)
# Don't forget that fullname names a package, so its path will include # Don't forget that fullname names a package, so its path will include
# __init__.py, which we want to ignore. # __init__.py, which we want to ignore.
assert relative_path.name == '__init__.py' assert relative_path.name == "__init__.py"
package_path = relative_path.parent package_path = relative_path.parent
subdirs_seen = set() subdirs_seen = set()
for filename in self.zipimporter._files: for filename in self.zipimporter._files:

View File

@ -1,2 +1,2 @@
from .abc import BasePluginLoader, PluginLoader, PluginClass, IDConflictError, PluginMeta from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader, PluginMeta
from .zip import ZippedPluginLoader, MaubotZipImportError from .zip import MaubotZipImportError, ZippedPluginLoader

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,14 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TypeVar, Type, Dict, Set, List, TYPE_CHECKING from typing import TYPE_CHECKING, Dict, List, Set, Type, TypeVar
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
from attr import dataclass from attr import dataclass
from packaging.version import Version, InvalidVersion from packaging.version import InvalidVersion, Version
from mautrix.types import SerializableAttrs, SerializerError, serializer, deserializer from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
from ..__meta__ import __version__ from ..__meta__ import __version__
from ..plugin_base import Plugin from ..plugin_base import Plugin
@ -89,16 +89,16 @@ class BasePluginLoader(ABC):
class PluginLoader(BasePluginLoader, ABC): class PluginLoader(BasePluginLoader, ABC):
id_cache: Dict[str, 'PluginLoader'] = {} id_cache: Dict[str, "PluginLoader"] = {}
meta: PluginMeta meta: PluginMeta
references: Set['PluginInstance'] references: Set["PluginInstance"]
def __init__(self): def __init__(self):
self.references = set() self.references = set()
@classmethod @classmethod
def find(cls, plugin_id: str) -> 'PluginLoader': def find(cls, plugin_id: str) -> "PluginLoader":
return cls.id_cache[plugin_id] return cls.id_cache[plugin_id]
def to_dict(self) -> dict: def to_dict(self) -> dict:
@ -109,12 +109,14 @@ class PluginLoader(BasePluginLoader, ABC):
} }
async def stop_instances(self) -> None: async def stop_instances(self) -> None:
await asyncio.gather(*[instance.stop() for instance await asyncio.gather(
in self.references if instance.started]) *[instance.stop() for instance in self.references if instance.started]
)
async def start_instances(self) -> None: async def start_instances(self) -> None:
await asyncio.gather(*[instance.start() for instance await asyncio.gather(
in self.references if instance.enabled]) *[instance.start() for instance in self.references if instance.enabled]
)
@abstractmethod @abstractmethod
async def load(self) -> Type[PluginClass]: async def load(self) -> Type[PluginClass]:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,22 +13,23 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Type, Tuple, Optional from __future__ import annotations
from zipfile import ZipFile, BadZipFile
from time import time from time import time
import logging from zipfile import BadZipFile, ZipFile
import sys import logging
import os import os
import sys
from ruamel.yaml import YAML, YAMLError
from packaging.version import Version from packaging.version import Version
from ruamel.yaml import YAML, YAMLError
from mautrix.types import SerializerError from mautrix.types import SerializerError
from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin
from ..config import Config from ..config import Config
from .abc import PluginLoader, PluginClass, PluginMeta, IDConflictError from ..lib.zipimport import ZipImportError, zipimporter
from ..plugin_base import Plugin
from .abc import IDConflictError, PluginClass, PluginLoader, PluginMeta
yaml = YAML() yaml = YAML()
@ -50,23 +51,25 @@ class MaubotZipLoadError(MaubotZipImportError):
class ZippedPluginLoader(PluginLoader): class ZippedPluginLoader(PluginLoader):
path_cache: Dict[str, 'ZippedPluginLoader'] = {} path_cache: dict[str, ZippedPluginLoader] = {}
log: logging.Logger = logging.getLogger("maubot.loader.zip") log: logging.Logger = logging.getLogger("maubot.loader.zip")
trash_path: str = "delete" trash_path: str = "delete"
directories: List[str] = [] directories: list[str] = []
path: str path: str | None
meta: PluginMeta meta: PluginMeta | None
main_class: str main_class: str | None
main_module: str main_module: str | None
_loaded: Type[PluginClass] _loaded: type[PluginClass] | None
_importer: zipimporter _importer: zipimporter | None
_file: ZipFile _file: ZipFile | None
def __init__(self, path: str) -> None: def __init__(self, path: str) -> None:
super().__init__() super().__init__()
self.path = path self.path = path
self.meta = None self.meta = None
self.main_class = None
self.main_module = None
self._loaded = None self._loaded = None
self._importer = None self._importer = None
self._file = None self._file = None
@ -75,7 +78,8 @@ class ZippedPluginLoader(PluginLoader):
try: try:
existing = self.id_cache[self.meta.id] existing = self.id_cache[self.meta.id]
raise IDConflictError( raise IDConflictError(
f"Plugin with id {self.meta.id} already loaded from {existing.source}") f"Plugin with id {self.meta.id} already loaded from {existing.source}"
)
except KeyError: except KeyError:
pass pass
self.path_cache[self.path] = self self.path_cache[self.path] = self
@ -83,13 +87,10 @@ class ZippedPluginLoader(PluginLoader):
self.log.debug(f"Preloaded plugin {self.meta.id} from {self.path}") self.log.debug(f"Preloaded plugin {self.meta.id} from {self.path}")
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {**super().to_dict(), "path": self.path}
**super().to_dict(),
"path": self.path
}
@classmethod @classmethod
def get(cls, path: str) -> 'ZippedPluginLoader': def get(cls, path: str) -> ZippedPluginLoader:
path = os.path.abspath(path) path = os.path.abspath(path)
try: try:
return cls.path_cache[path] return cls.path_cache[path]
@ -101,10 +102,12 @@ class ZippedPluginLoader(PluginLoader):
return self.path return self.path
def __repr__(self) -> str: def __repr__(self) -> str:
return ("<ZippedPlugin " return (
f"path='{self.path}' " "<ZippedPlugin "
f"meta={self.meta} " f"path='{self.path}' "
f"loaded={self._loaded is not None}>") f"meta={self.meta} "
f"loaded={self._loaded is not None}>"
)
def sync_read_file(self, path: str) -> bytes: def sync_read_file(self, path: str) -> bytes:
return self._file.read(path) return self._file.read(path)
@ -112,16 +115,19 @@ class ZippedPluginLoader(PluginLoader):
async def read_file(self, path: str) -> bytes: async def read_file(self, path: str) -> bytes:
return self.sync_read_file(path) return self.sync_read_file(path)
def sync_list_files(self, directory: str) -> List[str]: def sync_list_files(self, directory: str) -> list[str]:
directory = directory.rstrip("/") directory = directory.rstrip("/")
return [file.filename for file in self._file.filelist return [
if os.path.dirname(file.filename) == directory] file.filename
for file in self._file.filelist
if os.path.dirname(file.filename) == directory
]
async def list_files(self, directory: str) -> List[str]: async def list_files(self, directory: str) -> list[str]:
return self.sync_list_files(directory) return self.sync_list_files(directory)
@staticmethod @staticmethod
def _read_meta(source) -> Tuple[ZipFile, PluginMeta]: def _read_meta(source) -> tuple[ZipFile, PluginMeta]:
try: try:
file = ZipFile(source) file = ZipFile(source)
data = file.read("maubot.yaml") data = file.read("maubot.yaml")
@ -142,7 +148,7 @@ class ZippedPluginLoader(PluginLoader):
return file, meta return file, meta
@classmethod @classmethod
def verify_meta(cls, source) -> Tuple[str, Version]: def verify_meta(cls, source) -> tuple[str, Version]:
_, meta = cls._read_meta(source) _, meta = cls._read_meta(source)
return meta.id, meta.version return meta.id, meta.version
@ -173,24 +179,24 @@ class ZippedPluginLoader(PluginLoader):
code = importer.get_code(self.main_module.replace(".", "/")) code = importer.get_code(self.main_module.replace(".", "/"))
if self.main_class not in code.co_names: if self.main_class not in code.co_names:
raise MaubotZipPreLoadError( raise MaubotZipPreLoadError(
f"Main class {self.main_class} not in {self.main_module}") f"Main class {self.main_class} not in {self.main_module}"
)
except ZipImportError as e: except ZipImportError as e:
raise MaubotZipPreLoadError( raise MaubotZipPreLoadError(f"Main module {self.main_module} not found in file") from e
f"Main module {self.main_module} not found in file") from e
for module in self.meta.modules: for module in self.meta.modules:
try: try:
importer.find_module(module) importer.find_module(module)
except ZipImportError as e: except ZipImportError as e:
raise MaubotZipPreLoadError(f"Module {module} not found in file") from e raise MaubotZipPreLoadError(f"Module {module} not found in file") from e
async def load(self, reset_cache: bool = False) -> Type[PluginClass]: async def load(self, reset_cache: bool = False) -> type[PluginClass]:
try: try:
return self._load(reset_cache) return self._load(reset_cache)
except MaubotZipImportError: except MaubotZipImportError:
self.log.exception(f"Failed to load {self.meta.id} v{self.meta.version}") self.log.exception(f"Failed to load {self.meta.id} v{self.meta.version}")
raise raise
def _load(self, reset_cache: bool = False) -> Type[PluginClass]: def _load(self, reset_cache: bool = False) -> type[PluginClass]:
if self._loaded is not None and not reset_cache: if self._loaded is not None and not reset_cache:
return self._loaded return self._loaded
self._load_meta() self._load_meta()
@ -219,7 +225,7 @@ class ZippedPluginLoader(PluginLoader):
self.log.debug(f"Loaded and imported plugin {self.meta.id} from {self.path}") self.log.debug(f"Loaded and imported plugin {self.meta.id} from {self.path}")
return plugin return plugin
async def reload(self, new_path: Optional[str] = None) -> Type[PluginClass]: async def reload(self, new_path: str | None = None) -> type[PluginClass]:
await self.unload() await self.unload()
if new_path is not None: if new_path is not None:
self.path = new_path self.path = new_path
@ -251,7 +257,7 @@ class ZippedPluginLoader(PluginLoader):
self.path = None self.path = None
@classmethod @classmethod
def trash(cls, file_path: str, new_name: Optional[str] = None, reason: str = "error") -> None: def trash(cls, file_path: str, new_name: str | None = None, reason: str = "error") -> None:
if cls.trash_path == "delete": if cls.trash_path == "delete":
os.remove(file_path) os.remove(file_path)
else: else:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,13 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
import importlib import importlib
from aiohttp import web
from ...config import Config from ...config import Config
from .base import routes, get_config, set_config, set_loop
from .auth import check_token from .auth import check_token
from .base import get_config, routes, set_config, set_loop
from .middleware import auth, error from .middleware import auth, error
@ -30,9 +31,11 @@ def features(request: web.Request) -> web.Response:
if err is None: if err is None:
return web.json_response(data) return web.json_response(data)
else: else:
return web.json_response({ return web.json_response(
"login": data["login"], {
}) "login": data["login"],
}
)
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application: def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,7 +13,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional from __future__ import annotations
from time import time from time import time
from aiohttp import web from aiohttp import web
@ -21,7 +22,7 @@ from aiohttp import web
from mautrix.types import UserID from mautrix.types import UserID
from mautrix.util.signed_token import sign_token, verify_token from mautrix.util.signed_token import sign_token, verify_token
from .base import routes, get_config from .base import get_config, routes
from .responses import resp from .responses import resp
@ -33,10 +34,13 @@ def is_valid_token(token: str) -> bool:
def create_token(user: UserID) -> str: def create_token(user: UserID) -> str:
return sign_token(get_config()["server.unshared_secret"], { return sign_token(
"user_id": user, get_config()["server.unshared_secret"],
"created_at": int(time()), {
}) "user_id": user,
"created_at": int(time()),
},
)
def get_token(request: web.Request) -> str: def get_token(request: web.Request) -> str:
@ -44,11 +48,11 @@ def get_token(request: web.Request) -> str:
if not token or not token.startswith("Bearer "): if not token or not token.startswith("Bearer "):
token = request.query.get("access_token", None) token = request.query.get("access_token", None)
else: else:
token = token[len("Bearer "):] token = token[len("Bearer ") :]
return token return token
def check_token(request: web.Request) -> Optional[web.Response]: def check_token(request: web.Request) -> web.Response | None:
token = get_token(request) token = get_token(request)
if not token: if not token:
return resp.no_token return resp.no_token

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,15 +13,18 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web from __future__ import annotations
import asyncio import asyncio
from aiohttp import web
from ...__meta__ import __version__ from ...__meta__ import __version__
from ...config import Config from ...config import Config
routes: web.RouteTableDef = web.RouteTableDef() routes: web.RouteTableDef = web.RouteTableDef()
_config: Config = None _config: Config | None = None
_loop: asyncio.AbstractEventLoop = None _loop: asyncio.AbstractEventLoop | None = None
def set_config(config: Config) -> None: def set_config(config: Config) -> None:
@ -44,6 +47,4 @@ def get_loop() -> asyncio.AbstractEventLoop:
@routes.get("/version") @routes.get("/version")
async def version(_: web.Request) -> web.Response: async def version(_: web.Request) -> web.Response:
return web.json_response({ return web.json_response({"version": __version__})
"version": __version__
})

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,17 +13,18 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional from __future__ import annotations
from json import JSONDecodeError from json import JSONDecodeError
from aiohttp import web from aiohttp import web
from mautrix.types import UserID, SyncToken, FilterID
from mautrix.errors import MatrixRequestError, MatrixConnectionError, MatrixInvalidToken
from mautrix.client import Client as MatrixClient from mautrix.client import Client as MatrixClient
from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequestError
from mautrix.types import FilterID, SyncToken, UserID
from ...db import DBClient
from ...client import Client from ...client import Client
from ...db import DBClient
from .base import routes from .base import routes
from .responses import resp from .responses import resp
@ -42,12 +43,17 @@ async def get_client(request: web.Request) -> web.Response:
return resp.found(client.to_dict()) return resp.found(client.to_dict())
async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response: async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
homeserver = data.get("homeserver", None) homeserver = data.get("homeserver", None)
access_token = data.get("access_token", None) access_token = data.get("access_token", None)
device_id = data.get("device_id", None) device_id = data.get("device_id", None)
new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token, new_client = MatrixClient(
loop=Client.loop, client_session=Client.http_client) mxid="@not:a.mxid",
base_url=homeserver,
token=access_token,
loop=Client.loop,
client_session=Client.http_client,
)
try: try:
whoami = await new_client.whoami() whoami = await new_client.whoami()
except MatrixInvalidToken: except MatrixInvalidToken:
@ -64,13 +70,20 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
return resp.mxid_mismatch(whoami.user_id) return resp.mxid_mismatch(whoami.user_id)
elif whoami.device_id and device_id and whoami.device_id != device_id: elif whoami.device_id and device_id and whoami.device_id != device_id:
return resp.device_id_mismatch(whoami.device_id) return resp.device_id_mismatch(whoami.device_id)
db_instance = DBClient(id=whoami.user_id, homeserver=homeserver, access_token=access_token, db_instance = DBClient(
enabled=data.get("enabled", True), next_batch=SyncToken(""), id=whoami.user_id,
filter_id=FilterID(""), sync=data.get("sync", True), homeserver=homeserver,
autojoin=data.get("autojoin", True), online=data.get("online", True), access_token=access_token,
displayname=data.get("displayname", "disable"), enabled=data.get("enabled", True),
avatar_url=data.get("avatar_url", "disable"), next_batch=SyncToken(""),
device_id=device_id) filter_id=FilterID(""),
sync=data.get("sync", True),
autojoin=data.get("autojoin", True),
online=data.get("online", True),
displayname=data.get("displayname", "disable"),
avatar_url=data.get("avatar_url", "disable"),
device_id=device_id,
)
client = Client(db_instance) client = Client(db_instance)
client.db_instance.insert() client.db_instance.insert()
await client.start() await client.start()
@ -79,9 +92,11 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response: async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
try: try:
await client.update_access_details(data.get("access_token", None), await client.update_access_details(
data.get("homeserver", None), data.get("access_token", None),
data.get("device_id", None)) data.get("homeserver", None),
data.get("device_id", None),
)
except MatrixInvalidToken: except MatrixInvalidToken:
return resp.bad_client_access_token return resp.bad_client_access_token
except MatrixRequestError: except MatrixRequestError:
@ -91,9 +106,9 @@ async def _update_client(client: Client, data: dict, is_login: bool = False) ->
except ValueError as e: except ValueError as e:
str_err = str(e) str_err = str(e)
if str_err.startswith("MXID mismatch"): if str_err.startswith("MXID mismatch"):
return resp.mxid_mismatch(str(e)[len("MXID mismatch: "):]) return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :])
elif str_err.startswith("Device ID mismatch"): elif str_err.startswith("Device ID mismatch"):
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: "):]) return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :])
with client.db_instance.edit_mode(): with client.db_instance.edit_mode():
await client.update_avatar_url(data.get("avatar_url", None)) await client.update_avatar_url(data.get("avatar_url", None))
await client.update_displayname(data.get("displayname", None)) await client.update_displayname(data.get("displayname", None))
@ -105,8 +120,9 @@ async def _update_client(client: Client, data: dict, is_login: bool = False) ->
return resp.updated(client.to_dict(), is_login=is_login) return resp.updated(client.to_dict(), is_login=is_login)
async def _create_or_update_client(user_id: UserID, data: dict, is_login: bool = False async def _create_or_update_client(
) -> web.Response: user_id: UserID, data: dict, is_login: bool = False
) -> web.Response:
client = Client.get(user_id, None) client = Client.get(user_id, None)
if not client: if not client:
return await _create_client(user_id, data) return await _create_client(user_id, data)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,26 +13,26 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Tuple, NamedTuple, Optional from typing import Dict, NamedTuple, Optional, Tuple
from json import JSONDecodeError
from http import HTTPStatus from http import HTTPStatus
import hashlib from json import JSONDecodeError
import asyncio import asyncio
import hashlib
import hmac
import random import random
import string import string
import hmac
from aiohttp import web from aiohttp import web
from yarl import URL from yarl import URL
from mautrix.api import SynapseAdminPath, Method, Path from mautrix.api import Method, Path, SynapseAdminPath
from mautrix.errors import MatrixRequestError
from mautrix.client import ClientAPI from mautrix.client import ClientAPI
from mautrix.types import LoginType, LoginResponse from mautrix.errors import MatrixRequestError
from mautrix.types import LoginResponse, LoginType
from .base import routes, get_config, get_loop from .base import get_config, get_loop, routes
from .client import _create_client, _create_or_update_client
from .responses import resp from .responses import resp
from .client import _create_or_update_client, _create_client
def known_homeservers() -> Dict[str, Dict[str, str]]: def known_homeservers() -> Dict[str, Dict[str, str]]:
@ -59,8 +59,9 @@ class AuthRequestInfo(NamedTuple):
truthy_strings = ("1", "true", "yes") truthy_strings = ("1", "true", "yes")
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo], async def read_client_auth_request(
Optional[web.Response]]: request: web.Request,
) -> Tuple[Optional[AuthRequestInfo], Optional[web.Response]]:
server_name = request.match_info.get("server", None) server_name = request.match_info.get("server", None)
server = known_homeservers().get(server_name, None) server = known_homeservers().get(server_name, None)
if not server: if not server:
@ -81,21 +82,30 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR
base_url = server["url"] base_url = server["url"]
except KeyError: except KeyError:
return None, resp.invalid_server return None, resp.invalid_server
return AuthRequestInfo( return (
server_name=server_name, AuthRequestInfo(
client=ClientAPI(base_url=base_url, loop=get_loop()), server_name=server_name,
secret=server.get("secret"), client=ClientAPI(base_url=base_url, loop=get_loop()),
username=username, secret=server.get("secret"),
password=password, username=username,
user_type=body.get("user_type", "bot"), password=password,
device_name=body.get("device_name", "Maubot"), user_type=body.get("user_type", "bot"),
update_client=request.query.get("update_client", "").lower() in truthy_strings, device_name=body.get("device_name", "Maubot"),
sso=sso, update_client=request.query.get("update_client", "").lower() in truthy_strings,
), None sso=sso,
),
None,
)
def generate_mac(secret: str, nonce: str, username: str, password: str, admin: bool = False, def generate_mac(
user_type: str = None) -> str: secret: str,
nonce: str,
username: str,
password: str,
admin: bool = False,
user_type: str = None,
) -> str:
mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1) mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
mac.update(nonce.encode("utf-8")) mac.update(nonce.encode("utf-8"))
mac.update(b"\x00") mac.update(b"\x00")
@ -132,18 +142,24 @@ async def register(request: web.Request) -> web.Response:
try: try:
raw_res = await req.client.api.request(Method.POST, path, content=content) raw_res = await req.client.api.request(Method.POST, path, content=content)
except MatrixRequestError as e: except MatrixRequestError as e:
return web.json_response({ return web.json_response(
"errcode": e.errcode, {
"error": e.message, "errcode": e.errcode,
"http_status": e.http_status, "error": e.message,
}, status=HTTPStatus.INTERNAL_SERVER_ERROR) "http_status": e.http_status,
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
login_res = LoginResponse.deserialize(raw_res) login_res = LoginResponse.deserialize(raw_res)
if req.update_client: if req.update_client:
return await _create_client(login_res.user_id, { return await _create_client(
"homeserver": str(req.client.api.base_url), login_res.user_id,
"access_token": login_res.access_token, {
"device_id": login_res.device_id, "homeserver": str(req.client.api.base_url),
}) "access_token": login_res.access_token,
"device_id": login_res.device_id,
},
)
return web.json_response(login_res.serialize()) return web.json_response(login_res.serialize())
@ -162,13 +178,17 @@ async def _do_sso(req: AuthRequestInfo) -> web.Response:
flows = await req.client.get_login_flows() flows = await req.client.get_login_flows()
if not flows.supports_type(LoginType.SSO): if not flows.supports_type(LoginType.SSO):
return resp.sso_not_supported return resp.sso_not_supported
waiter_id = ''.join(random.choices(string.ascii_lowercase + string.digits, k=16)) waiter_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=16))
cfg = get_config() cfg = get_config()
public_url = (URL(cfg["server.public_url"]) / cfg["server.base_path"].lstrip("/") public_url = (
/ "client/auth_external_sso/complete" / waiter_id) URL(cfg["server.public_url"])
sso_url = (req.client.api.base_url / cfg["server.base_path"].lstrip("/")
.with_path(str(Path.login.sso.redirect)) / "client/auth_external_sso/complete"
.with_query({"redirectUrl": str(public_url)})) / waiter_id
)
sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query(
{"redirectUrl": str(public_url)}
)
sso_waiters[waiter_id] = req, get_loop().create_future() sso_waiters[waiter_id] = req, get_loop().create_future()
return web.json_response({"sso_url": str(sso_url), "id": waiter_id}) return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
@ -178,25 +198,40 @@ async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) ->
device_id = f"maubot_{device_id}" device_id = f"maubot_{device_id}"
try: try:
if req.sso: if req.sso:
res = await req.client.login(token=login_token, login_type=LoginType.TOKEN, res = await req.client.login(
device_id=device_id, store_access_token=False, token=login_token,
initial_device_display_name=req.device_name) login_type=LoginType.TOKEN,
device_id=device_id,
store_access_token=False,
initial_device_display_name=req.device_name,
)
else: else:
res = await req.client.login(identifier=req.username, login_type=LoginType.PASSWORD, res = await req.client.login(
password=req.password, device_id=device_id, identifier=req.username,
initial_device_display_name=req.device_name, login_type=LoginType.PASSWORD,
store_access_token=False) password=req.password,
device_id=device_id,
initial_device_display_name=req.device_name,
store_access_token=False,
)
except MatrixRequestError as e: except MatrixRequestError as e:
return web.json_response({ return web.json_response(
"errcode": e.errcode, {
"error": e.message, "errcode": e.errcode,
}, status=e.http_status) "error": e.message,
},
status=e.http_status,
)
if req.update_client: if req.update_client:
return await _create_or_update_client(res.user_id, { return await _create_or_update_client(
"homeserver": str(req.client.api.base_url), res.user_id,
"access_token": res.access_token, {
"device_id": res.device_id, "homeserver": str(req.client.api.base_url),
}, is_login=True) "access_token": res.access_token,
"device_id": res.device_id,
},
is_login=True,
)
return web.json_response(res.serialize()) return web.json_response(res.serialize())
@ -230,6 +265,8 @@ async def complete_sso(request: web.Request) -> web.Response:
return web.Response(status=400, text="Missing loginToken query parameter\n") return web.Response(status=400, text="Missing loginToken query parameter\n")
except asyncio.InvalidStateError: except asyncio.InvalidStateError:
return web.Response(status=500, text="Invalid state\n") return web.Response(status=500, text="Invalid state\n")
return web.Response(status=200, return web.Response(
text="Login token received, please return to your Maubot client. " status=200,
"This tab can be closed.\n") text="Login token received, please return to your Maubot client. "
"This tab can be closed.\n",
)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,7 +13,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web, client as http from aiohttp import client as http, web
from ...client import Client from ...client import Client
from .base import routes from .base import routes
@ -45,8 +45,9 @@ async def proxy(request: web.Request) -> web.StreamResponse:
headers["X-Forwarded-For"] = f"{host}:{port}" headers["X-Forwarded-For"] = f"{host}:{port}"
data = await request.read() data = await request.read()
async with http.request(request.method, f"{client.homeserver}/{path}", headers=headers, async with http.request(
params=query, data=data) as proxy_resp: request.method, f"{client.homeserver}/{path}", headers=headers, params=query, data=data
) as proxy_resp:
response = web.StreamResponse(status=proxy_resp.status, headers=proxy_resp.headers) response = web.StreamResponse(status=proxy_resp.status, headers=proxy_resp.headers)
await response.prepare(request) await response.prepare(request)
async for chunk in proxy_resp.content.iter_chunked(PROXY_CHUNK_SIZE): async for chunk in proxy_resp.content.iter_chunked(PROXY_CHUNK_SIZE):

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,11 +14,11 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from string import Template from string import Template
from subprocess import run import asyncio
import re import re
from ruamel.yaml import YAML
from aiohttp import web from aiohttp import web
from ruamel.yaml import YAML
from .base import routes from .base import routes
@ -27,9 +27,7 @@ enabled = False
@routes.get("/debug/open") @routes.get("/debug/open")
async def check_enabled(_: web.Request) -> web.Response: async def check_enabled(_: web.Request) -> web.Response:
return web.json_response({ return web.json_response({"enabled": enabled})
"enabled": enabled,
})
try: try:
@ -40,7 +38,6 @@ try:
editor_command = Template(cfg["editor"]) editor_command = Template(cfg["editor"])
pathmap = [(re.compile(item["find"]), item["replace"]) for item in cfg["pathmap"]] pathmap = [(re.compile(item["find"]), item["replace"]) for item in cfg["pathmap"]]
@routes.post("/debug/open") @routes.post("/debug/open")
async def open_file(request: web.Request) -> web.Response: async def open_file(request: web.Request) -> web.Response:
data = await request.json() data = await request.json()
@ -51,13 +48,9 @@ try:
cmd = editor_command.substitute(path=path, line=data["line"]) cmd = editor_command.substitute(path=path, line=data["line"])
except (KeyError, ValueError): except (KeyError, ValueError):
return web.Response(status=400) return web.Response(status=400)
res = run(cmd, shell=True) res = await asyncio.create_subprocess_shell(cmd)
return web.json_response({ stdout, stderr = await res.communicate()
"return": res.returncode, return web.json_response({"return": res.returncode, "stdout": stdout, "stderr": stderr})
"stdout": res.stdout,
"stderr": res.stderr
})
enabled = True enabled = True
except Exception: except Exception:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -17,10 +17,10 @@ from json import JSONDecodeError
from aiohttp import web from aiohttp import web
from ...client import Client
from ...db import DBPlugin from ...db import DBPlugin
from ...instance import PluginInstance from ...instance import PluginInstance
from ...loader import PluginLoader from ...loader import PluginLoader
from ...client import Client
from .base import routes from .base import routes
from .responses import resp from .responses import resp
@ -52,8 +52,13 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
PluginLoader.find(plugin_type) PluginLoader.find(plugin_type)
except KeyError: except KeyError:
return resp.plugin_type_not_found return resp.plugin_type_not_found
db_instance = DBPlugin(id=instance_id, type=plugin_type, enabled=data.get("enabled", True), db_instance = DBPlugin(
primary_user=primary_user, config=data.get("config", "")) id=instance_id,
type=plugin_type,
enabled=data.get("enabled", True),
primary_user=primary_user,
config=data.get("config", ""),
)
instance = PluginInstance(db_instance) instance = PluginInstance(db_instance)
instance.load() instance.load()
instance.db_instance.insert() instance.db_instance.insert()

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,13 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union, TYPE_CHECKING from __future__ import annotations
from datetime import datetime from datetime import datetime
from aiohttp import web from aiohttp import web
from sqlalchemy import Table, Column, asc, desc, exc from sqlalchemy import Column, Table, asc, desc, exc
from sqlalchemy.orm import Query
from sqlalchemy.engine.result import ResultProxy, RowProxy from sqlalchemy.engine.result import ResultProxy, RowProxy
from sqlalchemy.orm import Query
from ...instance import PluginInstance from ...instance import PluginInstance
from .base import routes from .base import routes
@ -34,23 +35,26 @@ async def get_database(request: web.Request) -> web.Response:
return resp.instance_not_found return resp.instance_not_found
elif not instance.inst_db: elif not instance.inst_db:
return resp.plugin_has_no_database return resp.plugin_has_no_database
if TYPE_CHECKING: table: Table
table: Table column: Column
column: Column return web.json_response(
return web.json_response({ {
table.name: { table.name: {
"columns": { "columns": {
column.name: { column.name: {
"type": str(column.type), "type": str(column.type),
"unique": column.unique or False, "unique": column.unique or False,
"default": column.default, "default": column.default,
"nullable": column.nullable, "nullable": column.nullable,
"primary": column.primary_key, "primary": column.primary_key,
"autoincrement": column.autoincrement, "autoincrement": column.autoincrement,
} for column in table.columns }
}, for column in table.columns
} for table in instance.get_db_tables().values() },
}) }
for table in instance.get_db_tables().values()
}
)
def check_type(val): def check_type(val):
@ -74,9 +78,12 @@ async def get_table(request: web.Request) -> web.Response:
return resp.table_not_found return resp.table_not_found
try: try:
order = [tuple(order.split(":")) for order in request.query.getall("order")] order = [tuple(order.split(":")) for order in request.query.getall("order")]
order = [(asc if sort.lower() == "asc" else desc)(table.columns[column]) order = [
if sort else table.columns[column] (asc if sort.lower() == "asc" else desc)(table.columns[column])
for column, sort in order] if sort
else table.columns[column]
for column, sort in order
]
except KeyError: except KeyError:
order = [] order = []
limit = int(request.query.get("limit", 100)) limit = int(request.query.get("limit", 100))
@ -96,12 +103,12 @@ async def query(request: web.Request) -> web.Response:
sql_query = data["query"] sql_query = data["query"]
except KeyError: except KeyError:
return resp.query_missing return resp.query_missing
return execute_query(instance, sql_query, return execute_query(instance, sql_query, rows_as_dict=data.get("rows_as_dict", False))
rows_as_dict=data.get("rows_as_dict", False))
def execute_query(instance: PluginInstance, sql_query: Union[str, Query], def execute_query(
rows_as_dict: bool = False) -> web.Response: instance: PluginInstance, sql_query: str | Query, rows_as_dict: bool = False
) -> web.Response:
try: try:
res: ResultProxy = instance.inst_db.execute(sql_query) res: ResultProxy = instance.inst_db.execute(sql_query)
except exc.IntegrityError as e: except exc.IntegrityError as e:
@ -114,10 +121,14 @@ def execute_query(instance: PluginInstance, sql_query: Union[str, Query],
} }
if res.returns_rows: if res.returns_rows:
row: RowProxy row: RowProxy
data["rows"] = [({key: check_type(value) for key, value in row.items()} data["rows"] = [
if rows_as_dict (
else [check_type(value) for value in row]) {key: check_type(value) for key, value in row.items()}
for row in res] if rows_as_dict
else [check_type(value) for value in row]
)
for row in res
]
data["columns"] = res.keys() data["columns"] = res.keys()
else: else:
data["rowcount"] = res.rowcount data["rowcount"] = res.rowcount

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,31 +13,60 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Deque, List from __future__ import annotations
from datetime import datetime
from collections import deque from collections import deque
import logging from datetime import datetime
import asyncio import asyncio
import logging
from aiohttp import web from aiohttp import web, web_ws
from .base import routes, get_loop
from .auth import is_valid_token from .auth import is_valid_token
from .base import get_loop, routes
BUILTIN_ATTRS = {"args", "asctime", "created", "exc_info", "exc_text", "filename", "funcName", BUILTIN_ATTRS = {
"levelname", "levelno", "lineno", "module", "msecs", "message", "msg", "name", "args",
"pathname", "process", "processName", "relativeCreated", "stack_info", "thread", "asctime",
"threadName"} "created",
INCLUDE_ATTRS = {"filename", "funcName", "levelname", "levelno", "lineno", "module", "name", "exc_info",
"pathname"} "exc_text",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"message",
"msg",
"name",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"thread",
"threadName",
}
INCLUDE_ATTRS = {
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"name",
"pathname",
}
EXCLUDE_ATTRS = BUILTIN_ATTRS - INCLUDE_ATTRS EXCLUDE_ATTRS = BUILTIN_ATTRS - INCLUDE_ATTRS
MAX_LINES = 2048 MAX_LINES = 2048
class LogCollector(logging.Handler): class LogCollector(logging.Handler):
lines: Deque[dict] lines: deque[dict]
formatter: logging.Formatter formatter: logging.Formatter
listeners: List[web.WebSocketResponse] listeners: list[web.WebSocketResponse]
loop: asyncio.AbstractEventLoop loop: asyncio.AbstractEventLoop
def __init__(self, level=logging.NOTSET) -> None: def __init__(self, level=logging.NOTSET) -> None:
@ -56,9 +85,7 @@ class LogCollector(logging.Handler):
# JSON conversion based on Marsel Mavletkulov's json-log-formatter (MIT license) # JSON conversion based on Marsel Mavletkulov's json-log-formatter (MIT license)
# https://github.com/marselester/json-log-formatter # https://github.com/marselester/json-log-formatter
content = { content = {
name: value name: value for name, value in record.__dict__.items() if name not in EXCLUDE_ATTRS
for name, value in record.__dict__.items()
if name not in EXCLUDE_ATTRS
} }
content["id"] = str(record.relativeCreated) content["id"] = str(record.relativeCreated)
content["msg"] = record.getMessage() content["msg"] = record.getMessage()
@ -119,6 +146,7 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse:
asyncio.ensure_future(close_if_not_authenticated()) asyncio.ensure_future(close_if_not_authenticated())
try: try:
msg: web_ws.WSMessage
async for msg in ws: async for msg in ws:
if msg.type != web.WSMsgType.TEXT: if msg.type != web.WSMsgType.TEXT:
continue continue

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -17,9 +17,10 @@ import json
from aiohttp import web from aiohttp import web
from .base import routes, get_config
from .responses import resp
from .auth import create_token from .auth import create_token
from .base import get_config, routes
from .responses import resp
@routes.post("/auth/login") @routes.post("/auth/login")
async def login(request: web.Request) -> web.Response: async def login(request: web.Request) -> web.Response:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,15 +13,15 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Awaitable from typing import Awaitable, Callable
import logging
import base64 import base64
import logging
from aiohttp import web from aiohttp import web
from .responses import resp
from .auth import check_token from .auth import check_token
from .base import get_config from .base import get_config
from .responses import resp
Handler = Callable[[web.Request], Awaitable[web.Response]] Handler = Callable[[web.Request], Awaitable[web.Response]]
log = logging.getLogger("maubot.server") log = logging.getLogger("maubot.server")
@ -29,7 +29,7 @@ log = logging.getLogger("maubot.server")
@web.middleware @web.middleware
async def auth(request: web.Request, handler: Handler) -> web.Response: async def auth(request: web.Request, handler: Handler) -> web.Response:
subpath = request.path[len(get_config()["server.base_path"]):] subpath = request.path[len(get_config()["server.base_path"]) :]
if ( if (
subpath.startswith("/auth/") subpath.startswith("/auth/")
or subpath.startswith("/client/auth_external_sso/complete/") or subpath.startswith("/client/auth_external_sso/complete/")
@ -52,15 +52,18 @@ async def error(request: web.Request, handler: Handler) -> web.Response:
return resp.path_not_found return resp.path_not_found
elif ex.status_code == 405: elif ex.status_code == 405:
return resp.method_not_allowed return resp.method_not_allowed
return web.json_response({ return web.json_response(
"httpexception": { {
"headers": {key: value for key, value in ex.headers.items()}, "httpexception": {
"class": type(ex).__name__, "headers": {key: value for key, value in ex.headers.items()},
"body": ex.text or base64.b64encode(ex.body) "class": type(ex).__name__,
"body": ex.text or base64.b64encode(ex.body),
},
"error": f"Unhandled HTTP {ex.status}: {ex.text[:128] or 'non-text response'}",
"errcode": f"unhandled_http_{ex.status}",
}, },
"error": f"Unhandled HTTP {ex.status}: {ex.text[:128] or 'non-text response'}", status=ex.status,
"errcode": f"unhandled_http_{ex.status}", )
}, status=ex.status)
except Exception: except Exception:
log.exception("Error in handler") log.exception("Error in handler")
return resp.internal_server_error return resp.internal_server_error

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -17,9 +17,9 @@ import traceback
from aiohttp import web from aiohttp import web
from ...loader import PluginLoader, MaubotZipImportError from ...loader import MaubotZipImportError, PluginLoader
from .responses import resp
from .base import routes from .base import routes
from .responses import resp
@routes.get("/plugins") @routes.get("/plugins")

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -15,16 +15,16 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from io import BytesIO from io import BytesIO
from time import time from time import time
import traceback
import os.path import os.path
import re import re
import traceback
from aiohttp import web from aiohttp import web
from packaging.version import Version from packaging.version import Version
from ...loader import PluginLoader, ZippedPluginLoader, MaubotZipImportError from ...loader import MaubotZipImportError, PluginLoader, ZippedPluginLoader
from .base import get_config, routes
from .responses import resp from .responses import resp
from .base import routes, get_config
@routes.put("/plugin/{id}") @routes.put("/plugin/{id}")
@ -78,15 +78,20 @@ async def upload_new_plugin(content: bytes, pid: str, version: Version) -> web.R
return resp.created(plugin.to_dict()) return resp.created(plugin.to_dict())
async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, async def upload_replacement_plugin(
new_version: Version) -> web.Response: plugin: ZippedPluginLoader, content: bytes, new_version: Version
) -> web.Response:
dirname = os.path.dirname(plugin.path) dirname = os.path.dirname(plugin.path)
old_filename = os.path.basename(plugin.path) old_filename = os.path.basename(plugin.path)
if str(plugin.meta.version) in old_filename: if str(plugin.meta.version) in old_filename:
replacement = (str(new_version) if plugin.meta.version != new_version replacement = (
else f"{new_version}-ts{int(time())}") str(new_version)
filename = re.sub(f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?", if plugin.meta.version != new_version
replacement, old_filename) else f"{new_version}-ts{int(time())}"
)
filename = re.sub(
f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?", replacement, old_filename
)
else: else:
filename = old_filename.rstrip(".mbp") filename = old_filename.rstrip(".mbp")
filename = f"{filename}-v{new_version}.mbp" filename = f"{filename}-v{new_version}.mbp"

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -16,296 +16,416 @@
from http import HTTPStatus from http import HTTPStatus
from aiohttp import web from aiohttp import web
from sqlalchemy.exc import OperationalError, IntegrityError from sqlalchemy.exc import IntegrityError, OperationalError
class _Response: class _Response:
@property @property
def body_not_json(self) -> web.Response: def body_not_json(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Request body is not JSON", {
"errcode": "body_not_json", "error": "Request body is not JSON",
}, status=HTTPStatus.BAD_REQUEST) "errcode": "body_not_json",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def plugin_type_required(self) -> web.Response: def plugin_type_required(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Plugin type is required when creating plugin instances", {
"errcode": "plugin_type_required", "error": "Plugin type is required when creating plugin instances",
}, status=HTTPStatus.BAD_REQUEST) "errcode": "plugin_type_required",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def primary_user_required(self) -> web.Response: def primary_user_required(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Primary user is required when creating plugin instances", {
"errcode": "primary_user_required", "error": "Primary user is required when creating plugin instances",
}, status=HTTPStatus.BAD_REQUEST) "errcode": "primary_user_required",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def bad_client_access_token(self) -> web.Response: def bad_client_access_token(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Invalid access token", {
"errcode": "bad_client_access_token", "error": "Invalid access token",
}, status=HTTPStatus.BAD_REQUEST) "errcode": "bad_client_access_token",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def bad_client_access_details(self) -> web.Response: def bad_client_access_details(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Invalid homeserver or access token", {
"errcode": "bad_client_access_details" "error": "Invalid homeserver or access token",
}, status=HTTPStatus.BAD_REQUEST) "errcode": "bad_client_access_details",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def bad_client_connection_details(self) -> web.Response: def bad_client_connection_details(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Could not connect to homeserver", {
"errcode": "bad_client_connection_details" "error": "Could not connect to homeserver",
}, status=HTTPStatus.BAD_REQUEST) "errcode": "bad_client_connection_details",
},
status=HTTPStatus.BAD_REQUEST,
)
def mxid_mismatch(self, found: str) -> web.Response: def mxid_mismatch(self, found: str) -> web.Response:
return web.json_response({ return web.json_response(
"error": "The Matrix user ID of the client and the user ID of the access token don't " {
f"match. Access token is for user {found}", "error": (
"errcode": "mxid_mismatch", "The Matrix user ID of the client and the user ID of the access token don't "
}, status=HTTPStatus.BAD_REQUEST) f"match. Access token is for user {found}"
),
"errcode": "mxid_mismatch",
},
status=HTTPStatus.BAD_REQUEST,
)
def device_id_mismatch(self, found: str) -> web.Response: def device_id_mismatch(self, found: str) -> web.Response:
return web.json_response({ return web.json_response(
"error": "The Matrix device ID of the client and the device ID of the access token " {
f"don't match. Access token is for device {found}", "error": (
"errcode": "mxid_mismatch", "The Matrix device ID of the client and the device ID of the access token "
}, status=HTTPStatus.BAD_REQUEST) f"don't match. Access token is for device {found}"
),
"errcode": "mxid_mismatch",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def pid_mismatch(self) -> web.Response: def pid_mismatch(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "The ID in the path does not match the ID of the uploaded plugin", {
"errcode": "pid_mismatch", "error": "The ID in the path does not match the ID of the uploaded plugin",
}, status=HTTPStatus.BAD_REQUEST) "errcode": "pid_mismatch",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def username_or_password_missing(self) -> web.Response: def username_or_password_missing(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Username or password missing", {
"errcode": "username_or_password_missing", "error": "Username or password missing",
}, status=HTTPStatus.BAD_REQUEST) "errcode": "username_or_password_missing",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def query_missing(self) -> web.Response: def query_missing(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Query missing", {
"errcode": "query_missing", "error": "Query missing",
}, status=HTTPStatus.BAD_REQUEST) "errcode": "query_missing",
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod @staticmethod
def sql_operational_error(error: OperationalError, query: str) -> web.Response: def sql_operational_error(error: OperationalError, query: str) -> web.Response:
return web.json_response({ return web.json_response(
"ok": False, {
"query": query, "ok": False,
"error": str(error.orig), "query": query,
"full_error": str(error), "error": str(error.orig),
"errcode": "sql_operational_error", "full_error": str(error),
}, status=HTTPStatus.BAD_REQUEST) "errcode": "sql_operational_error",
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod @staticmethod
def sql_integrity_error(error: IntegrityError, query: str) -> web.Response: def sql_integrity_error(error: IntegrityError, query: str) -> web.Response:
return web.json_response({ return web.json_response(
"ok": False, {
"query": query, "ok": False,
"error": str(error.orig), "query": query,
"full_error": str(error), "error": str(error.orig),
"errcode": "sql_integrity_error", "full_error": str(error),
}, status=HTTPStatus.BAD_REQUEST) "errcode": "sql_integrity_error",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def bad_auth(self) -> web.Response: def bad_auth(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Invalid username or password", {
"errcode": "invalid_auth", "error": "Invalid username or password",
}, status=HTTPStatus.UNAUTHORIZED) "errcode": "invalid_auth",
},
status=HTTPStatus.UNAUTHORIZED,
)
@property @property
def no_token(self) -> web.Response: def no_token(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Authorization token missing", {
"errcode": "auth_token_missing", "error": "Authorization token missing",
}, status=HTTPStatus.UNAUTHORIZED) "errcode": "auth_token_missing",
},
status=HTTPStatus.UNAUTHORIZED,
)
@property @property
def invalid_token(self) -> web.Response: def invalid_token(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Invalid authorization token", {
"errcode": "auth_token_invalid", "error": "Invalid authorization token",
}, status=HTTPStatus.UNAUTHORIZED) "errcode": "auth_token_invalid",
},
status=HTTPStatus.UNAUTHORIZED,
)
@property @property
def plugin_not_found(self) -> web.Response: def plugin_not_found(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Plugin not found", {
"errcode": "plugin_not_found", "error": "Plugin not found",
}, status=HTTPStatus.NOT_FOUND) "errcode": "plugin_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def client_not_found(self) -> web.Response: def client_not_found(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Client not found", {
"errcode": "client_not_found", "error": "Client not found",
}, status=HTTPStatus.NOT_FOUND) "errcode": "client_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def primary_user_not_found(self) -> web.Response: def primary_user_not_found(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Client for given primary user not found", {
"errcode": "primary_user_not_found", "error": "Client for given primary user not found",
}, status=HTTPStatus.NOT_FOUND) "errcode": "primary_user_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def instance_not_found(self) -> web.Response: def instance_not_found(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Plugin instance not found", {
"errcode": "instance_not_found", "error": "Plugin instance not found",
}, status=HTTPStatus.NOT_FOUND) "errcode": "instance_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def plugin_type_not_found(self) -> web.Response: def plugin_type_not_found(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Given plugin type not found", {
"errcode": "plugin_type_not_found", "error": "Given plugin type not found",
}, status=HTTPStatus.NOT_FOUND) "errcode": "plugin_type_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def path_not_found(self) -> web.Response: def path_not_found(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Resource not found", {
"errcode": "resource_not_found", "error": "Resource not found",
}, status=HTTPStatus.NOT_FOUND) "errcode": "resource_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def server_not_found(self) -> web.Response: def server_not_found(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Registration target server not found", {
"errcode": "server_not_found", "error": "Registration target server not found",
}, status=HTTPStatus.NOT_FOUND) "errcode": "server_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def registration_secret_not_found(self) -> web.Response: def registration_secret_not_found(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Config does not have a registration secret for that server", {
"errcode": "registration_secret_not_found", "error": "Config does not have a registration secret for that server",
}, status=HTTPStatus.NOT_FOUND) "errcode": "registration_secret_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property @property
def registration_no_sso(self) -> web.Response: def registration_no_sso(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "The register operation is only for registering with a password", {
"errcode": "registration_no_sso", "error": "The register operation is only for registering with a password",
}, status=HTTPStatus.BAD_REQUEST) "errcode": "registration_no_sso",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def sso_not_supported(self) -> web.Response: def sso_not_supported(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "That server does not seem to support single sign-on", {
"errcode": "sso_not_supported", "error": "That server does not seem to support single sign-on",
}, status=HTTPStatus.FORBIDDEN) "errcode": "sso_not_supported",
},
status=HTTPStatus.FORBIDDEN,
)
@property @property
def plugin_has_no_database(self) -> web.Response: def plugin_has_no_database(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Given plugin does not have a database", {
"errcode": "plugin_has_no_database", "error": "Given plugin does not have a database",
}) "errcode": "plugin_has_no_database",
}
)
@property @property
def table_not_found(self) -> web.Response: def table_not_found(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Given table not found in plugin database", {
"errcode": "table_not_found", "error": "Given table not found in plugin database",
}) "errcode": "table_not_found",
}
)
@property @property
def method_not_allowed(self) -> web.Response: def method_not_allowed(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Method not allowed", {
"errcode": "method_not_allowed", "error": "Method not allowed",
}, status=HTTPStatus.METHOD_NOT_ALLOWED) "errcode": "method_not_allowed",
},
status=HTTPStatus.METHOD_NOT_ALLOWED,
)
@property @property
def user_exists(self) -> web.Response: def user_exists(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "There is already a client with the user ID of that token", {
"errcode": "user_exists", "error": "There is already a client with the user ID of that token",
}, status=HTTPStatus.CONFLICT) "errcode": "user_exists",
},
status=HTTPStatus.CONFLICT,
)
@property @property
def plugin_exists(self) -> web.Response: def plugin_exists(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "A plugin with the same ID as the uploaded plugin already exists", {
"errcode": "plugin_exists" "error": "A plugin with the same ID as the uploaded plugin already exists",
}, status=HTTPStatus.CONFLICT) "errcode": "plugin_exists",
},
status=HTTPStatus.CONFLICT,
)
@property @property
def plugin_in_use(self) -> web.Response: def plugin_in_use(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Plugin instances of this type still exist", {
"errcode": "plugin_in_use", "error": "Plugin instances of this type still exist",
}, status=HTTPStatus.PRECONDITION_FAILED) "errcode": "plugin_in_use",
},
status=HTTPStatus.PRECONDITION_FAILED,
)
@property @property
def client_in_use(self) -> web.Response: def client_in_use(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Plugin instances with this client as their primary user still exist", {
"errcode": "client_in_use", "error": "Plugin instances with this client as their primary user still exist",
}, status=HTTPStatus.PRECONDITION_FAILED) "errcode": "client_in_use",
},
status=HTTPStatus.PRECONDITION_FAILED,
)
@staticmethod @staticmethod
def plugin_import_error(error: str, stacktrace: str) -> web.Response: def plugin_import_error(error: str, stacktrace: str) -> web.Response:
return web.json_response({ return web.json_response(
"error": error, {
"stacktrace": stacktrace, "error": error,
"errcode": "plugin_invalid", "stacktrace": stacktrace,
}, status=HTTPStatus.BAD_REQUEST) "errcode": "plugin_invalid",
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod @staticmethod
def plugin_reload_error(error: str, stacktrace: str) -> web.Response: def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
return web.json_response({ return web.json_response(
"error": error, {
"stacktrace": stacktrace, "error": error,
"errcode": "plugin_reload_fail", "stacktrace": stacktrace,
}, status=HTTPStatus.INTERNAL_SERVER_ERROR) "errcode": "plugin_reload_fail",
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
@property @property
def internal_server_error(self) -> web.Response: def internal_server_error(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Internal server error", {
"errcode": "internal_server_error", "error": "Internal server error",
}, status=HTTPStatus.INTERNAL_SERVER_ERROR) "errcode": "internal_server_error",
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
@property @property
def invalid_server(self) -> web.Response: def invalid_server(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Invalid registration server object in maubot configuration", {
"errcode": "invalid_server", "error": "Invalid registration server object in maubot configuration",
}, status=HTTPStatus.INTERNAL_SERVER_ERROR) "errcode": "invalid_server",
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
@property @property
def unsupported_plugin_loader(self) -> web.Response: def unsupported_plugin_loader(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Existing plugin with same ID uses unsupported plugin loader", {
"errcode": "unsupported_plugin_loader", "error": "Existing plugin with same ID uses unsupported plugin loader",
}, status=HTTPStatus.BAD_REQUEST) "errcode": "unsupported_plugin_loader",
},
status=HTTPStatus.BAD_REQUEST,
)
@property @property
def not_implemented(self) -> web.Response: def not_implemented(self) -> web.Response:
return web.json_response({ return web.json_response(
"error": "Not implemented", {
"errcode": "not_implemented", "error": "Not implemented",
}, status=HTTPStatus.NOT_IMPLEMENTED) "errcode": "not_implemented",
},
status=HTTPStatus.NOT_IMPLEMENTED,
)
@property @property
def ok(self) -> web.Response: def ok(self) -> web.Response:
return web.json_response({ return web.json_response(
"success": True, {"success": True},
}, status=HTTPStatus.OK) status=HTTPStatus.OK,
)
@property @property
def deleted(self) -> web.Response: def deleted(self) -> web.Response:
@ -320,15 +440,10 @@ class _Response:
return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK) return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK)
def logged_in(self, token: str) -> web.Response: def logged_in(self, token: str) -> web.Response:
return self.found({ return self.found({"token": token})
"token": token,
})
def pong(self, user: str, features: dict) -> web.Response: def pong(self, user: str, features: dict) -> web.Response:
return self.found({ return self.found({"username": user, "features": features})
"username": user,
"features": features,
})
@staticmethod @staticmethod
def created(data: dict) -> web.Response: def created(data: dict) -> web.Response:

View File

@ -1,6 +1,6 @@
<!-- <!--
maubot - A plugin-based Matrix bot system. maubot - A plugin-based Matrix bot system.
Copyright (C) 2019 Tulir Asokan Copyright (C) 2022 Tulir Asokan
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by // it under the terms of the GNU General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system. // maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2022 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,23 +13,36 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union, Awaitable, Optional, Tuple, List from __future__ import annotations
from typing import Awaitable
from html import escape from html import escape
import asyncio import asyncio
import attr import attr
from mautrix.client import Client as MatrixClient, SyncStream from mautrix.client import Client as MatrixClient, SyncStream
from mautrix.util.formatter import MatrixParser, MarkdownString, EntityType
from mautrix.util import markdown
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
MessageType, TextMessageEventContent, Format, RelatesTo, EncryptedEvent)
from mautrix.errors import DecryptionError from mautrix.errors import DecryptionError
from mautrix.types import (
EncryptedEvent,
Event,
EventID,
EventType,
Format,
MessageEvent,
MessageEventContent,
MessageType,
RelatesTo,
RoomID,
TextMessageEventContent,
)
from mautrix.util import markdown
from mautrix.util.formatter import EntityType, MarkdownString, MatrixParser
class HumanReadableString(MarkdownString): class HumanReadableString(MarkdownString):
def format(self, entity_type: EntityType, **kwargs) -> 'MarkdownString': def format(self, entity_type: EntityType, **kwargs) -> MarkdownString:
if entity_type == EntityType.URL and kwargs['url'] != self.text: if entity_type == EntityType.URL and kwargs["url"] != self.text:
self.text = f"{self.text} ({kwargs['url']})" self.text = f"{self.text} ({kwargs['url']})"
return self return self
return super(HumanReadableString, self).format(entity_type, **kwargs) return super(HumanReadableString, self).format(entity_type, **kwargs)
@ -39,8 +52,9 @@ class MaubotHTMLParser(MatrixParser[HumanReadableString]):
fs = HumanReadableString fs = HumanReadableString
async def parse_formatted(message: str, allow_html: bool = False, render_markdown: bool = True async def parse_formatted(
) -> Tuple[str, str]: message: str, allow_html: bool = False, render_markdown: bool = True
) -> tuple[str, str]:
if render_markdown: if render_markdown:
html = markdown.render(message, allow_html=allow_html) html = markdown.render(message, allow_html=allow_html)
elif allow_html: elif allow_html:
@ -51,19 +65,25 @@ async def parse_formatted(message: str, allow_html: bool = False, render_markdow
class MaubotMessageEvent(MessageEvent): class MaubotMessageEvent(MessageEvent):
client: 'MaubotMatrixClient' client: MaubotMatrixClient
disable_reply: bool disable_reply: bool
def __init__(self, base: MessageEvent, client: 'MaubotMatrixClient'): def __init__(self, base: MessageEvent, client: MaubotMatrixClient):
super().__init__(**{a.name.lstrip("_"): getattr(base, a.name) super().__init__(
for a in attr.fields(MessageEvent)}) **{a.name.lstrip("_"): getattr(base, a.name) for a in attr.fields(MessageEvent)}
)
self.client = client self.client = client
self.disable_reply = client.disable_replies self.disable_reply = client.disable_replies
async def respond(self, content: Union[str, MessageEventContent], async def respond(
event_type: EventType = EventType.ROOM_MESSAGE, markdown: bool = True, self,
allow_html: bool = False, reply: Union[bool, str] = False, content: str | MessageEventContent,
edits: Optional[Union[EventID, MessageEvent]] = None) -> EventID: event_type: EventType = EventType.ROOM_MESSAGE,
markdown: bool = True,
allow_html: bool = False,
reply: bool | str = False,
edits: EventID | MessageEvent | None = None,
) -> EventID:
if isinstance(content, str): if isinstance(content, str):
content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=content) content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=content)
if allow_html or markdown: if allow_html or markdown:
@ -77,18 +97,25 @@ class MaubotMessageEvent(MessageEvent):
if reply != "force" and self.disable_reply: if reply != "force" and self.disable_reply:
content.body = f"{self.sender}: {content.body}" content.body = f"{self.sender}: {content.body}"
fmt_body = content.formatted_body or escape(content.body).replace("\n", "<br>") fmt_body = content.formatted_body or escape(content.body).replace("\n", "<br>")
content.formatted_body = (f'<a href="https://matrix.to/#/{self.sender}">' content.formatted_body = (
f'{self.sender}' f'<a href="https://matrix.to/#/{self.sender}">'
f'</a>: {fmt_body}') f"{self.sender}"
f"</a>: {fmt_body}"
)
else: else:
content.set_reply(self) content.set_reply(self)
return await self.client.send_message_event(self.room_id, event_type, content) return await self.client.send_message_event(self.room_id, event_type, content)
def reply(self, content: Union[str, MessageEventContent], def reply(
event_type: EventType = EventType.ROOM_MESSAGE, markdown: bool = True, self,
allow_html: bool = False) -> Awaitable[EventID]: content: str | MessageEventContent,
return self.respond(content, event_type, markdown=markdown, reply=True, event_type: EventType = EventType.ROOM_MESSAGE,
allow_html=allow_html) markdown: bool = True,
allow_html: bool = False,
) -> Awaitable[EventID]:
return self.respond(
content, event_type, markdown=markdown, reply=True, allow_html=allow_html
)
def mark_read(self) -> Awaitable[None]: def mark_read(self) -> Awaitable[None]:
return self.client.send_receipt(self.room_id, self.event_id, "m.read") return self.client.send_receipt(self.room_id, self.event_id, "m.read")
@ -96,11 +123,16 @@ class MaubotMessageEvent(MessageEvent):
def react(self, key: str) -> Awaitable[EventID]: def react(self, key: str) -> Awaitable[EventID]:
return self.client.react(self.room_id, self.event_id, key) return self.client.react(self.room_id, self.event_id, key)
def edit(self, content: Union[str, MessageEventContent], def edit(
event_type: EventType = EventType.ROOM_MESSAGE, markdown: bool = True, self,
allow_html: bool = False) -> Awaitable[EventID]: content: str | MessageEventContent,
return self.respond(content, event_type, markdown=markdown, edits=self, event_type: EventType = EventType.ROOM_MESSAGE,
allow_html=allow_html) markdown: bool = True,
allow_html: bool = False,
) -> Awaitable[EventID]:
return self.respond(
content, event_type, markdown=markdown, edits=self, allow_html=allow_html
)
class MaubotMatrixClient(MatrixClient): class MaubotMatrixClient(MatrixClient):
@ -110,11 +142,17 @@ class MaubotMatrixClient(MatrixClient):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.disable_replies = False self.disable_replies = False
async def send_markdown(self, room_id: RoomID, markdown: str, *, allow_html: bool = False, async def send_markdown(
msgtype: MessageType = MessageType.TEXT, self,
edits: Optional[Union[EventID, MessageEvent]] = None, room_id: RoomID,
relates_to: Optional[RelatesTo] = None, **kwargs markdown: str,
) -> EventID: *,
allow_html: bool = False,
msgtype: MessageType = MessageType.TEXT,
edits: EventID | MessageEvent | None = None,
relates_to: RelatesTo | None = None,
**kwargs,
) -> EventID:
content = TextMessageEventContent(msgtype=msgtype, format=Format.HTML) content = TextMessageEventContent(msgtype=msgtype, format=Format.HTML)
content.body, content.formatted_body = await parse_formatted( content.body, content.formatted_body = await parse_formatted(
markdown, allow_html=allow_html markdown, allow_html=allow_html
@ -127,7 +165,7 @@ class MaubotMatrixClient(MatrixClient):
content.set_edit(edits) content.set_edit(edits)
return await self.send_message(room_id, content, **kwargs) return await self.send_message(room_id, content, **kwargs)
def dispatch_event(self, event: Event, source: SyncStream) -> List[asyncio.Task]: def dispatch_event(self, event: Event, source: SyncStream) -> list[asyncio.Task]:
if isinstance(event, MessageEvent) and not isinstance(event, MaubotMessageEvent): if isinstance(event, MessageEvent) and not isinstance(event, MaubotMessageEvent):
event = MaubotMessageEvent(event, self) event = MaubotMessageEvent(event, self)
elif source != SyncStream.INTERNAL: elif source != SyncStream.INTERNAL:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,38 +13,50 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, Optional, TYPE_CHECKING from __future__ import annotations
from typing import TYPE_CHECKING
from abc import ABC from abc import ABC
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from sqlalchemy.engine.base import Engine
from aiohttp import ClientSession from aiohttp import ClientSession
from sqlalchemy.engine.base import Engine
from yarl import URL from yarl import URL
if TYPE_CHECKING: if TYPE_CHECKING:
from mautrix.util.logging import TraceLogger
from mautrix.util.config import BaseProxyConfig from mautrix.util.config import BaseProxyConfig
from mautrix.util.logging import TraceLogger
from .client import MaubotMatrixClient from .client import MaubotMatrixClient
from .plugin_server import PluginWebApp
from .loader import BasePluginLoader from .loader import BasePluginLoader
from .plugin_server import PluginWebApp
class Plugin(ABC): class Plugin(ABC):
client: 'MaubotMatrixClient' client: MaubotMatrixClient
http: ClientSession http: ClientSession
id: str id: str
log: 'TraceLogger' log: TraceLogger
loop: AbstractEventLoop loop: AbstractEventLoop
loader: 'BasePluginLoader' loader: BasePluginLoader
config: Optional['BaseProxyConfig'] config: BaseProxyConfig | None
database: Optional[Engine] database: Engine | None
webapp: Optional['PluginWebApp'] webapp: PluginWebApp | None
webapp_url: Optional[URL] webapp_url: URL | None
def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession, def __init__(
instance_id: str, log: 'TraceLogger', config: Optional['BaseProxyConfig'], self,
database: Optional[Engine], webapp: Optional['PluginWebApp'], client: MaubotMatrixClient,
webapp_url: Optional[str], loader: 'BasePluginLoader') -> None: loop: AbstractEventLoop,
http: ClientSession,
instance_id: str,
log: TraceLogger,
config: BaseProxyConfig | None,
database: Engine | None,
webapp: PluginWebApp | None,
webapp_url: str | None,
loader: BasePluginLoader,
) -> None:
self.client = client self.client = client
self.loop = loop self.loop = loop
self.http = http self.http = http
@ -74,8 +86,10 @@ class Plugin(ABC):
else: else:
if len(web_handlers) > 0 and self.webapp is None: if len(web_handlers) > 0 and self.webapp is None:
if not warned_webapp: if not warned_webapp:
self.log.warning(f"{type(obj).__name__} has web handlers, but the webapp" self.log.warning(
" feature isn't enabled in the plugin's maubot.yaml") f"{type(obj).__name__} has web handlers, but the webapp"
" feature isn't enabled in the plugin's maubot.yaml"
)
warned_webapp = True warned_webapp = True
continue continue
for method, path, kwargs in web_handlers: for method, path, kwargs in web_handlers:
@ -107,7 +121,7 @@ class Plugin(ABC):
pass pass
@classmethod @classmethod
def get_config_class(cls) -> Optional[Type['BaseProxyConfig']]: def get_config_class(cls) -> type[BaseProxyConfig] | None:
return None return None
def on_external_config_update(self) -> None: def on_external_config_update(self) -> None:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,10 +13,12 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Callable, Awaitable from __future__ import annotations
from typing import Awaitable, Callable
from functools import partial from functools import partial
from aiohttp import web, hdrs from aiohttp import hdrs, web
from yarl import URL from yarl import URL
Handler = Callable[[web.Request], Awaitable[web.Response]] Handler = Callable[[web.Request], Awaitable[web.Response]]
@ -26,7 +28,7 @@ Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]]
class PluginWebApp(web.UrlDispatcher): class PluginWebApp(web.UrlDispatcher):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._middleware: List[Middleware] = [] self._middleware: list[Middleware] = []
def add_middleware(self, middleware: Middleware) -> None: def add_middleware(self, middleware: Middleware) -> None:
self._middleware.append(middleware) self._middleware.append(middleware)
@ -58,8 +60,8 @@ class PluginWebApp(web.UrlDispatcher):
class PrefixResource(web.Resource): class PrefixResource(web.Resource):
def __init__(self, prefix, *, name=None): def __init__(self, prefix, *, name=None):
assert not prefix or prefix.startswith('/'), prefix assert not prefix or prefix.startswith("/"), prefix
assert prefix in ('', '/') or not prefix.endswith('/'), prefix assert prefix in ("", "/") or not prefix.endswith("/"), prefix
super().__init__(name=name) super().__init__(name=name)
self._prefix = URL.build(path=prefix).raw_path self._prefix = URL.build(path=prefix).raw_path
@ -68,14 +70,14 @@ class PrefixResource(web.Resource):
return self._prefix return self._prefix
def get_info(self): def get_info(self):
return {'path': self._prefix} return {"path": self._prefix}
def url_for(self): def url_for(self):
return URL.build(path=self._prefix, encoded=True) return URL.build(path=self._prefix, encoded=True)
def add_prefix(self, prefix): def add_prefix(self, prefix):
assert prefix.startswith('/') assert prefix.startswith("/")
assert not prefix.endswith('/') assert not prefix.endswith("/")
assert len(prefix) > 1 assert len(prefix) > 1
self._prefix = prefix + self._prefix self._prefix = prefix + self._prefix
@ -84,4 +86,3 @@ class PrefixResource(web.Resource):
def raw_match(self, path: str) -> bool: def raw_match(self, path: str) -> bool:
return path and path.startswith(self._prefix) return path and path.startswith(self._prefix)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,36 +13,40 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Dict from __future__ import annotations
import logging
import asyncio import asyncio
import json import json
from yarl import URL import logging
from aiohttp import web, hdrs from aiohttp import hdrs, web
from aiohttp.abc import AbstractAccessLogger from aiohttp.abc import AbstractAccessLogger
from yarl import URL
import pkg_resources import pkg_resources
from mautrix.api import PathBuilder, Method from mautrix.api import Method, PathBuilder
from .config import Config
from .plugin_server import PrefixResource, PluginWebApp
from .__meta__ import __version__ from .__meta__ import __version__
from .config import Config
from .plugin_server import PluginWebApp, PrefixResource
class AccessLogger(AbstractAccessLogger): class AccessLogger(AbstractAccessLogger):
def log(self, request: web.Request, response: web.Response, time: int): def log(self, request: web.Request, response: web.Response, time: int):
self.logger.info(f'{request.remote} "{request.method} {request.path} ' self.logger.info(
f'{response.status} {response.body_length} ' f'{request.remote} "{request.method} {request.path} '
f'in {round(time, 4)}s"') f"{response.status} {response.body_length} "
f'in {round(time, 4)}s"'
)
class MaubotServer: class MaubotServer:
log: logging.Logger = logging.getLogger("maubot.server") log: logging.Logger = logging.getLogger("maubot.server")
plugin_routes: Dict[str, PluginWebApp] plugin_routes: dict[str, PluginWebApp]
def __init__(self, management_api: web.Application, config: Config, def __init__(
loop: asyncio.AbstractEventLoop) -> None: self, management_api: web.Application, config: Config, loop: asyncio.AbstractEventLoop
) -> None:
self.loop = loop or asyncio.get_event_loop() self.loop = loop or asyncio.get_event_loop()
self.app = web.Application(loop=self.loop, client_max_size=100 * 1024 * 1024) self.app = web.Application(loop=self.loop, client_max_size=100 * 1024 * 1024)
self.config = config self.config = config
@ -57,13 +61,15 @@ class MaubotServer:
async def handle_plugin_path(self, request: web.Request) -> web.StreamResponse: async def handle_plugin_path(self, request: web.Request) -> web.StreamResponse:
for path, app in self.plugin_routes.items(): for path, app in self.plugin_routes.items():
if request.path.startswith(path): if request.path.startswith(path):
request = request.clone(rel_url=request.rel_url request = request.clone(
.with_path(request.rel_url.path[len(path):]) rel_url=request.rel_url.with_path(
.with_query(request.query_string)) request.rel_url.path[len(path) :]
).with_query(request.query_string)
)
return await app.handle(request) return await app.handle(request)
return web.Response(status=404) return web.Response(status=404)
def get_instance_subapp(self, instance_id: str) -> Tuple[PluginWebApp, str]: def get_instance_subapp(self, instance_id: str) -> tuple[PluginWebApp, str]:
subpath = self.config["server.plugin_base_path"] + instance_id subpath = self.config["server.plugin_base_path"] + instance_id
url = self.config["server.public_url"] + subpath url = self.config["server.public_url"] + subpath
try: try:
@ -94,8 +100,9 @@ class MaubotServer:
ui_base = self.config["server.ui_base_path"] ui_base = self.config["server.ui_base_path"]
if ui_base == "/": if ui_base == "/":
ui_base = "" ui_base = ""
directory = (self.config["server.override_resource_path"] directory = self.config[
or pkg_resources.resource_filename("maubot", "management/frontend/build")) "server.override_resource_path"
] or pkg_resources.resource_filename("maubot", "management/frontend/build")
self.app.router.add_static(f"{ui_base}/static", f"{directory}/static") self.app.router.add_static(f"{ui_base}/static", f"{directory}/static")
self.setup_static_root_files(directory, ui_base) self.setup_static_root_files(directory, ui_base)
@ -115,8 +122,9 @@ class MaubotServer:
raise web.HTTPFound(f"{ui_base}/") raise web.HTTPFound(f"{ui_base}/")
self.app.middlewares.append(frontend_404_middleware) self.app.middlewares.append(frontend_404_middleware)
self.app.router.add_get(f"{ui_base}/", lambda _: web.Response(body=index_html, self.app.router.add_get(
content_type="text/html")) f"{ui_base}/", lambda _: web.Response(body=index_html, content_type="text/html")
)
self.app.router.add_get(ui_base, ui_base_redirect) self.app.router.add_get(ui_base, ui_base_redirect)
def setup_static_root_files(self, directory: str, ui_base: str) -> None: def setup_static_root_files(self, directory: str, ui_base: str) -> None:
@ -128,8 +136,9 @@ class MaubotServer:
for file, mime in files.items(): for file, mime in files.items():
with open(f"{directory}/{file}", "rb") as stream: with open(f"{directory}/{file}", "rb") as stream:
data = stream.read() data = stream.read()
self.app.router.add_get(f"{ui_base}/{file}", lambda _: web.Response(body=data, self.app.router.add_get(
content_type=mime)) f"{ui_base}/{file}", lambda _: web.Response(body=data, content_type=mime)
)
# also set up a resource path for the public url path prefix config # also set up a resource path for the public url path prefix config
# cut the prefix path from public_url # cut the prefix path from public_url
@ -143,8 +152,12 @@ class MaubotServer:
api_path = f"{public_url_path}{base_path}" api_path = f"{public_url_path}{base_path}"
path_prefix_response_body = json.dumps({"api_path": api_path.rstrip("/")}) path_prefix_response_body = json.dumps({"api_path": api_path.rstrip("/")})
self.app.router.add_get(f"{ui_base}/paths.json", lambda _: web.Response(body=path_prefix_response_body, self.app.router.add_get(
content_type="application/json")) f"{ui_base}/paths.json",
lambda _: web.Response(
body=path_prefix_response_body, content_type="application/json"
),
)
def add_route(self, method: Method, path: PathBuilder, handler) -> None: def add_route(self, method: Method, path: PathBuilder, handler) -> None:
self.app.router.add_route(method.value, str(path), handler) self.app.router.add_route(method.value, str(path), handler)
@ -161,9 +174,7 @@ class MaubotServer:
@staticmethod @staticmethod
async def version(_: web.Request) -> web.Response: async def version(_: web.Request) -> web.Response:
return web.json_response({ return web.json_response({"version": __version__})
"version": __version__
})
async def handle_transaction(self, request: web.Request) -> web.Response: async def handle_transaction(self, request: web.Request) -> web.Response:
return web.Response(status=501) return web.Response(status=501)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,43 +13,53 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Type, cast from __future__ import annotations
import logging.config
import importlib from typing import cast
import argparse import argparse
import asyncio import asyncio
import copy
import importlib
import logging.config
import os.path import os.path
import signal import signal
import copy
import sys import sys
from aiohttp import ClientSession, hdrs, web
from ruamel.yaml import YAML from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql
from aiohttp import web, hdrs, ClientSession
from yarl import URL from yarl import URL
import sqlalchemy as sql
from mautrix.util.config import RecursiveDict, BaseMissingError from mautrix.types import (
EventType,
Filter,
FilterID,
Membership,
RoomEventFilter,
RoomFilter,
StrippedStateEvent,
SyncToken,
)
from mautrix.util.config import BaseMissingError, RecursiveDict
from mautrix.util.db import Base from mautrix.util.db import Base
from mautrix.util.logging import TraceLogger from mautrix.util.logging import TraceLogger
from mautrix.types import (Filter, RoomFilter, RoomEventFilter, StrippedStateEvent,
EventType, Membership, FilterID, SyncToken)
from ..__meta__ import __version__
from ..lib.store_proxy import SyncStoreProxy
from ..loader import PluginMeta
from ..matrix import MaubotMatrixClient
from ..plugin_base import Plugin from ..plugin_base import Plugin
from ..plugin_server import PluginWebApp, PrefixResource from ..plugin_server import PluginWebApp, PrefixResource
from ..loader import PluginMeta
from ..server import AccessLogger from ..server import AccessLogger
from ..matrix import MaubotMatrixClient
from ..lib.store_proxy import SyncStoreProxy
from ..__meta__ import __version__
from .config import Config from .config import Config
from .loader import FileSystemLoader
from .database import NextBatch from .database import NextBatch
from .loader import FileSystemLoader
crypto_import_error = None crypto_import_error = None
try: try:
from mautrix.crypto import OlmMachine, PgCryptoStore, PgCryptoStateStore from mautrix.crypto import OlmMachine, PgCryptoStateStore, PgCryptoStore
from mautrix.util.async_db import Database as AsyncDatabase from mautrix.util.async_db import Database as AsyncDatabase
except ImportError as err: except ImportError as err:
crypto_import_error = err crypto_import_error = err
@ -57,15 +67,32 @@ except ImportError as err:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="A plugin-based Matrix bot system -- standalone mode.", description="A plugin-based Matrix bot system -- standalone mode.",
prog="python -m maubot.standalone") prog="python -m maubot.standalone",
parser.add_argument("-c", "--config", type=str, default="config.yaml", )
metavar="<path>", help="the path to your config file") parser.add_argument(
parser.add_argument("-b", "--base-config", type=str, "-c",
default="pkg://maubot.standalone/example-config.yaml", "--config",
metavar="<path>", help="the path to the example config " type=str,
"(for automatic config updates)") default="config.yaml",
parser.add_argument("-m", "--meta", type=str, default="maubot.yaml", metavar="<path>",
metavar="<path>", help="the path to your plugin metadata file") help="the path to your config file",
)
parser.add_argument(
"-b",
"--base-config",
type=str,
default="pkg://maubot.standalone/example-config.yaml",
metavar="<path>",
help="the path to the example config " "(for automatic config updates)",
)
parser.add_argument(
"-m",
"--meta",
type=str,
default="maubot.yaml",
metavar="<path>",
help="the path to your plugin metadata file",
)
args = parser.parse_args() args = parser.parse_args()
config = Config(args.config, args.base_config) config = Config(args.config, args.base_config)
@ -92,7 +119,7 @@ else:
module = meta.modules[0] module = meta.modules[0]
main_class = meta.main_class main_class = meta.main_class
bot_module = importlib.import_module(module) bot_module = importlib.import_module(module)
plugin: Type[Plugin] = getattr(bot_module, main_class) plugin: type[Plugin] = getattr(bot_module, main_class)
loader = FileSystemLoader(os.path.dirname(args.meta)) loader = FileSystemLoader(os.path.dirname(args.meta))
log.info(f"Initializing standalone {meta.id} v{meta.version} on maubot {__version__}") log.info(f"Initializing standalone {meta.id} v{meta.version} on maubot {__version__}")
@ -110,8 +137,10 @@ access_token = config["user.credentials.access_token"]
crypto_store = crypto_db = state_store = None crypto_store = crypto_db = state_store = None
if device_id and not OlmMachine: if device_id and not OlmMachine:
log.warning("device_id set in config, but encryption dependencies not installed", log.warning(
exc_info=crypto_import_error) "device_id set in config, but encryption dependencies not installed",
exc_info=crypto_import_error,
)
elif device_id: elif device_id:
crypto_db = AsyncDatabase.create(config["database"], upgrade_table=PgCryptoStore.upgrade_table) crypto_db = AsyncDatabase.create(config["database"], upgrade_table=PgCryptoStore.upgrade_table)
crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=crypto_db) crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=crypto_db)
@ -124,27 +153,25 @@ if not nb:
bot_config = None bot_config = None
if not meta.config and "base-config.yaml" in meta.extra_files: if not meta.config and "base-config.yaml" in meta.extra_files:
log.warning("base-config.yaml in extra files, but config is not set to true. " log.warning(
"Assuming legacy plugin and loading config.") "base-config.yaml in extra files, but config is not set to true. "
"Assuming legacy plugin and loading config."
)
meta.config = True meta.config = True
if meta.config: if meta.config:
log.debug("Loading config") log.debug("Loading config")
config_class = plugin.get_config_class() config_class = plugin.get_config_class()
def load() -> CommentedMap: def load() -> CommentedMap:
return config["plugin_config"] return config["plugin_config"]
def load_base() -> RecursiveDict[CommentedMap]: def load_base() -> RecursiveDict[CommentedMap]:
return RecursiveDict(config.load_base()["plugin_config"], CommentedMap) return RecursiveDict(config.load_base()["plugin_config"], CommentedMap)
def save(data: RecursiveDict[CommentedMap]) -> None: def save(data: RecursiveDict[CommentedMap]) -> None:
config["plugin_config"] = data config["plugin_config"] = data
config.save() config.save()
try: try:
bot_config = config_class(load=load, load_base=load_base, save=save) bot_config = config_class(load=load, load_base=load_base, save=save)
bot_config.load_and_update() bot_config.load_and_update()
@ -161,9 +188,11 @@ if meta.webapp:
async def _handle_plugin_request(req: web.Request) -> web.StreamResponse: async def _handle_plugin_request(req: web.Request) -> web.StreamResponse:
if req.path.startswith(web_base_path): if req.path.startswith(web_base_path):
req = req.clone(rel_url=req.rel_url req = req.clone(
.with_path(req.rel_url.path[len(web_base_path):]) rel_url=req.rel_url.with_path(req.rel_url.path[len(web_base_path) :]).with_query(
.with_query(req.query_string)) req.query_string
)
)
return await plugin_webapp.handle(req) return await plugin_webapp.handle(req)
return web.Response(status=404) return web.Response(status=404)
@ -175,8 +204,8 @@ else:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
client: Optional[MaubotMatrixClient] = None client: MaubotMatrixClient | None = None
bot: Optional[Plugin] = None bot: Plugin | None = None
async def main(): async def main():
@ -185,10 +214,17 @@ async def main():
global client, bot global client, bot
client_log = logging.getLogger("maubot.client").getChild(user_id) client_log = logging.getLogger("maubot.client").getChild(user_id)
client = MaubotMatrixClient(mxid=user_id, base_url=homeserver, token=access_token, client = MaubotMatrixClient(
client_session=http_client, loop=loop, log=client_log, mxid=user_id,
sync_store=SyncStoreProxy(nb), state_store=state_store, base_url=homeserver,
device_id=device_id) token=access_token,
client_session=http_client,
loop=loop,
log=client_log,
sync_store=SyncStoreProxy(nb),
state_store=state_store,
device_id=device_id,
)
client.ignore_first_sync = config["user.ignore_first_sync"] client.ignore_first_sync = config["user.ignore_first_sync"]
client.ignore_initial_sync = config["user.ignore_initial_sync"] client.ignore_initial_sync = config["user.ignore_initial_sync"]
if crypto_store: if crypto_store:
@ -199,8 +235,10 @@ async def main():
client.crypto = OlmMachine(client, crypto_store, state_store) client.crypto = OlmMachine(client, crypto_store, state_store)
crypto_device_id = await crypto_store.get_device_id() crypto_device_id = await crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != device_id: if crypto_device_id and crypto_device_id != device_id:
log.fatal("Mismatching device ID in crypto store and config " log.fatal(
f"(store: {crypto_device_id}, config: {device_id})") "Mismatching device ID in crypto store and config "
f"(store: {crypto_device_id}, config: {device_id})"
)
sys.exit(10) sys.exit(10)
await client.crypto.load() await client.crypto.load()
if not crypto_device_id: if not crypto_device_id:
@ -224,17 +262,23 @@ async def main():
log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami.user_id}") log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami.user_id}")
sys.exit(11) sys.exit(11)
elif whoami.device_id and device_id and whoami.device_id != device_id: elif whoami.device_id and device_id and whoami.device_id != device_id:
log.fatal(f"Device ID mismatch: configured {device_id}, " log.fatal(
f"but server said {whoami.device_id}") f"Device ID mismatch: configured {device_id}, "
f"but server said {whoami.device_id}"
)
sys.exit(12) sys.exit(12)
log.debug(f"Confirmed connection as {whoami.user_id} / {whoami.device_id}") log.debug(f"Confirmed connection as {whoami.user_id} / {whoami.device_id}")
break break
if config["user.sync"]: if config["user.sync"]:
if not nb.filter_id: if not nb.filter_id:
nb.edit(filter_id=await client.create_filter(Filter( nb.edit(
room=RoomFilter(timeline=RoomEventFilter(limit=50)), filter_id=await client.create_filter(
))) Filter(
room=RoomFilter(timeline=RoomEventFilter(limit=50)),
)
)
)
client.start(nb.filter_id) client.start(nb.filter_id)
if config["user.autojoin"]: if config["user.autojoin"]:
@ -252,9 +296,18 @@ async def main():
await client.set_displayname(displayname) await client.set_displayname(displayname)
plugin_log = cast(TraceLogger, logging.getLogger("maubot.instance.__main__")) plugin_log = cast(TraceLogger, logging.getLogger("maubot.instance.__main__"))
bot = plugin(client=client, loop=loop, http=http_client, instance_id="__main__", bot = plugin(
log=plugin_log, config=bot_config, database=db if meta.database else None, client=client,
webapp=plugin_webapp, webapp_url=public_url, loader=loader) loop=loop,
http=http_client,
instance_id="__main__",
log=plugin_log,
config=bot_config,
database=db if meta.database else None,
webapp=plugin_webapp,
webapp_url=public_url,
loader=loader,
)
await bot.internal_start() await bot.internal_start()

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,12 +13,12 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional from __future__ import annotations
import sqlalchemy as sql import sqlalchemy as sql
from mautrix.types import FilterID, SyncToken, UserID
from mautrix.util.db import Base from mautrix.util.db import Base
from mautrix.types import UserID, SyncToken, FilterID
class NextBatch(Base): class NextBatch(Base):
@ -29,5 +29,5 @@ class NextBatch(Base):
filter_id: FilterID = sql.Column(sql.String(255)) filter_id: FilterID = sql.Column(sql.String(255))
@classmethod @classmethod
def get(cls, user_id: UserID) -> Optional['NextBatch']: def get(cls, user_id: UserID) -> NextBatch | None:
return cls._select_one_or_none(cls.c.user_id == user_id) return cls._select_one_or_none(cls.c.user_id == user_id)

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan # Copyright (C) 2022 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -13,12 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List from __future__ import annotations
import os.path
import os import os
import os.path
from ..loader import BasePluginLoader from ..loader import BasePluginLoader
class FileSystemLoader(BasePluginLoader): class FileSystemLoader(BasePluginLoader):
def __init__(self, path: str) -> None: def __init__(self, path: str) -> None:
self.path = path self.path = path
@ -34,8 +36,8 @@ class FileSystemLoader(BasePluginLoader):
async def read_file(self, path: str) -> bytes: async def read_file(self, path: str) -> bytes:
return self.sync_read_file(path) return self.sync_read_file(path)
def sync_list_files(self, directory: str) -> List[str]: def sync_list_files(self, directory: str) -> list[str]:
return os.listdir(os.path.join(self.path, directory)) return os.listdir(os.path.join(self.path, directory))
async def list_files(self, directory: str) -> List[str]: async def list_files(self, directory: str) -> list[str]:
return self.sync_list_files(directory) return self.sync_list_files(directory)

14
pyproject.toml Normal file
View File

@ -0,0 +1,14 @@
[tool.isort]
profile = "black"
force_to_top = "typing"
from_first = true
combine_as_imports = true
known_first_party = "mautrix"
line_length = 99
skip = ["maubot/management/frontend"]
[tool.black]
line-length = 99
target-version = ["py38"]
required-version = "22.1.0"
force-exclude = "maubot/management/frontend"

View File

@ -1,4 +1,4 @@
mautrix==0.15.0rc4 mautrix>=0.15.0,<0.16
aiohttp>=3,<4 aiohttp>=3,<4
yarl>=1,<2 yarl>=1,<2
SQLAlchemy>=1,<1.4 SQLAlchemy>=1,<1.4