Finish plugin API and add basic login system

This commit is contained in:
Tulir Asokan 2018-10-31 02:03:27 +02:00
parent d7f072aeff
commit 14fd0d6ac9
16 changed files with 160 additions and 62 deletions

1
.gitignore vendored
View File

@ -13,3 +13,4 @@ __pycache__
logs/
plugins/
trash/

View File

@ -30,8 +30,10 @@ server:
# Set to "generate" to generate and save a new token at startup.
unshared_secret: generate
# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password
# to prevent normal login. Root is a special user that can't have a password and will always exist.
admins:
- "@admin:example.com"
root: ""
# Python logging configuration.
#

View File

@ -14,7 +14,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import orm
from time import time
import sqlalchemy as sql
import logging.config
import argparse
@ -22,7 +21,6 @@ import asyncio
import signal
import copy
import sys
import os
from .config import Config
from .db import Base, init as init_db

View File

@ -15,9 +15,13 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import random
import string
import bcrypt
import re
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper
bcrypt_regex = re.compile(r"^\$2[ayb]\$.{56}$")
class Config(BaseFileConfig):
@staticmethod
@ -27,16 +31,35 @@ class Config(BaseFileConfig):
def do_update(self, helper: ConfigUpdateHelper) -> None:
base, copy, _ = helper
copy("database")
copy("plugin_directories")
copy("plugin_db_directory")
copy("plugin_directories.upload")
copy("plugin_directories.load")
copy("plugin_directories.trash")
copy("plugin_directories.db")
copy("server.hostname")
copy("server.port")
copy("server.listen")
copy("server.base_path")
shared_secret = self["server.shared_secret"]
copy("server.appservice_base_path")
shared_secret = self["server.unshared_secret"]
if shared_secret is None or shared_secret == "generate":
base["server.shared_secret"] = self._new_token()
base["server.unshared_secret"] = self._new_token()
else:
base["server.shared_secret"] = shared_secret
base["server.unshared_secret"] = shared_secret
copy("admins")
for username, password in base["admins"].items():
if password and not bcrypt_regex.match(password):
if password == "password":
password = self._new_token()
base["admins"][username] = bcrypt.hashpw(password.encode("utf-8"),
bcrypt.gensalt()).decode("utf-8")
copy("logging")
def is_admin(self, user: str) -> bool:
return user == "root" or user in self["admins"]
def check_password(self, user: str, passwd: str) -> bool:
if user == "root":
return False
passwd_hash = self["admins"].get(user, None)
if not passwd_hash:
return False
return bcrypt.checkpw(passwd.encode("utf-8"), passwd_hash.encode("utf-8"))

View File

@ -87,13 +87,6 @@ class PluginInstance:
def load_config(self) -> CommentedMap:
return yaml.load(self.db_instance.config)
def load_config_base(self) -> Optional[RecursiveDict[CommentedMap]]:
try:
base = self.loader.read_file("base-config.yaml")
return RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
except (FileNotFoundError, KeyError):
return None
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
buf = io.StringIO()
yaml.dump(data, buf)
@ -103,14 +96,23 @@ class PluginInstance:
if not self.enabled:
self.log.warning(f"Plugin disabled, not starting.")
return
cls = self.loader.load()
cls = await self.loader.load()
config_class = cls.get_config_class()
if config_class:
self.config = config_class(self.load_config, self.load_config_base,
self.save_config)
try:
base = await self.loader.read_file("base-config.yaml")
base_file = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
except (FileNotFoundError, KeyError):
base_file = None
self.config = config_class(self.load_config, lambda: base_file, self.save_config)
self.plugin = cls(self.client.client, self.id, self.log, self.config,
self.mb_config["plugin_db_directory"])
self.mb_config["plugin_directories.db"])
try:
await self.plugin.start()
except Exception:
self.log.exception("Failed to start instance")
self.enabled = False
return
self.running = True
self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} "
f"with user {self.client.id}")

View File

@ -59,10 +59,12 @@ class PluginLoader(ABC):
pass
async def stop_instances(self) -> None:
await asyncio.gather([instance.stop() for instance in self.references if instance.running])
await asyncio.gather(*[instance.stop() for instance
in self.references if instance.running])
async def start_instances(self) -> None:
await asyncio.gather([instance.start() for instance in self.references if instance.enabled])
await asyncio.gather(*[instance.start() for instance
in self.references if instance.enabled])
@abstractmethod
async def load(self) -> Type[PluginClass]:

View File

@ -207,8 +207,10 @@ class ZippedPluginLoader(PluginLoader):
self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}")
return plugin
async def reload(self) -> Type[PluginClass]:
async def reload(self, new_path: Optional[str] = None) -> Type[PluginClass]:
await self.unload()
if new_path is not None:
self.path = new_path
return await self.load(reset_cache=True)
async def unload(self) -> None:

View File

@ -27,8 +27,9 @@ config: Config = None
def is_valid_token(token: str) -> bool:
data = verify_token(config["server.unshared_secret"], token)
user_id = data.get("user_id", None)
return user_id is not None and user_id in config["admins"]
if not data:
return False
return config.is_admin(data.get("user_id", None))
def create_token(user: UserID) -> str:
@ -40,7 +41,9 @@ def create_token(user: UserID) -> str:
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
global config
config = cfg
from .middleware import auth, error, log
app = web.Application(loop=loop, middlewares=[auth, log, error])
from .middleware import auth, error
from .auth import web as _
from .plugin import web as _
app = web.Application(loop=loop, middlewares=[auth, error])
app.add_routes(routes)
return app

View File

@ -0,0 +1,43 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
import json
from . import routes, config, create_token
from .responses import ErrBadAuth, ErrBodyNotJSON
@routes.post("/login")
async def login(request: web.Request) -> web.Response:
try:
data = await request.json()
except json.JSONDecodeError:
return ErrBodyNotJSON
secret = data.get("secret")
if secret and config["server.unshared_secret"] == secret:
user = data.get("user") or "root"
return web.json_response({
"token": create_token(user),
})
username = data.get("username")
password = data.get("password")
if config.check_password(username, password):
return web.json_response({
"token": create_token(username),
})
return ErrBadAuth

View File

@ -15,25 +15,21 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Awaitable
from aiohttp import web
import logging
from .responses import ErrNoToken, ErrInvalidToken
from .responses import ErrNoToken, ErrInvalidToken, ErrPathNotFound, ErrMethodNotAllowed
from . import is_valid_token
Handler = Callable[[web.Request], Awaitable[web.Response]]
req_log = logging.getLogger("maubot.mgmt.request")
resp_log = logging.getLogger("maubot.mgmt.response")
@web.middleware
async def auth(request: web.Request, handler: Handler) -> web.Response:
if request.path.endswith("/login"):
return await handler(request)
token = request.headers.get("Authorization", "")
if not token or not token.startswith("Bearer "):
req_log.debug(f"Request missing auth: {request.remote} {request.method} {request.path}")
return ErrNoToken
if not is_valid_token(token[len("Bearer "):]):
req_log.debug(f"Request invalid auth: {request.remote} {request.method} {request.path}")
return ErrInvalidToken
return await handler(request)
@ -43,6 +39,10 @@ async def error(request: web.Request, handler: Handler) -> web.Response:
try:
return await handler(request)
except web.HTTPException as ex:
if ex.status_code == 404:
return ErrPathNotFound
elif ex.status_code == 405:
return ErrMethodNotAllowed
return web.json_response({
"error": f"Unhandled HTTP {ex.status}",
"errcode": f"unhandled_http_{ex.status}",
@ -56,12 +56,3 @@ def get_req_no():
global req_no
req_no += 1
return req_no
@web.middleware
async def log(request: web.Request, handler: Handler) -> web.Response:
local_req_no = get_req_no()
req_log.info(f"Request {local_req_no}: {request.remote} {request.method} {request.path}")
resp = await handler(request)
resp_log.info(f"Responded to {local_req_no} from {request.remote}: {resp}")
return resp

View File

@ -92,25 +92,24 @@ async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
if plugin.version in filename:
filename = filename.replace(plugin.version, new_version)
else:
filename = filename.rstrip(".mbp") + new_version + ".mbp"
filename = filename.rstrip(".mbp")
filename = f"{filename}-v{new_version}.mbp"
path = os.path.join(dirname, filename)
with open(path, "wb") as p:
p.write(content)
old_path = plugin.path
plugin.path = path
await plugin.stop_instances()
try:
await plugin.reload()
await plugin.reload(new_path=path)
except MaubotZipImportError as e:
plugin.path = old_path
try:
await plugin.reload()
await plugin.reload(new_path=old_path)
await plugin.start_instances()
except MaubotZipImportError:
pass
await plugin.start_instances()
return plugin_import_error(str(e), traceback.format_exc())
await plugin.start_instances()
ZippedPluginLoader.trash(plugin.path, reason="update")
ZippedPluginLoader.trash(old_path, reason="update")
return RespOK

View File

@ -13,27 +13,48 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from http import HTTPStatus
from aiohttp import web
ErrBadAuth = web.json_response({
"error": "Invalid username or password",
"errcode": "invalid_auth",
}, status=HTTPStatus.UNAUTHORIZED)
ErrNoToken = web.json_response({
"error": "Authorization token missing",
"errcode": "auth_token_missing",
}, status=web.HTTPUnauthorized)
}, status=HTTPStatus.UNAUTHORIZED)
ErrInvalidToken = web.json_response({
"error": "Invalid authorization token",
"errcode": "auth_token_invalid",
}, status=web.HTTPUnauthorized)
}, status=HTTPStatus.UNAUTHORIZED)
ErrPluginNotFound = web.json_response({
"error": "Plugin not found",
"errcode": "plugin_not_found",
}, status=web.HTTPNotFound)
}, status=HTTPStatus.NOT_FOUND)
ErrPathNotFound = web.json_response({
"error": "Resource not found",
"errcode": "resource_not_found",
}, status=HTTPStatus.NOT_FOUND)
ErrMethodNotAllowed = web.json_response({
"error": "Method not allowed",
"errcode": "method_not_allowed",
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
ErrPluginInUse = web.json_response({
"error": "Plugin instances of this type still exist",
"errcode": "plugin_in_use",
}, status=web.HTTPPreconditionFailed)
}, status=HTTPStatus.PRECONDITION_FAILED)
ErrBodyNotJSON = web.json_response({
"error": "Request body is not JSON",
"errcode": "body_not_json",
}, status=HTTPStatus.BAD_REQUEST)
def plugin_import_error(error: str, stacktrace: str) -> web.Response:
@ -41,7 +62,7 @@ def plugin_import_error(error: str, stacktrace: str) -> web.Response:
"error": error,
"stacktrace": stacktrace,
"errcode": "plugin_invalid",
}, status=web.HTTPBadRequest)
}, status=HTTPStatus.BAD_REQUEST)
def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
@ -49,21 +70,21 @@ def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
"error": error,
"stacktrace": stacktrace,
"errcode": "plugin_reload_fail",
}, status=web.HTTPInternalServerError)
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
ErrUnsupportedPluginLoader = web.json_response({
"error": "Existing plugin with same ID uses unsupported plugin loader",
"errcode": "unsupported_plugin_loader",
}, status=web.HTTPBadRequest)
}, status=HTTPStatus.BAD_REQUEST)
ErrNotImplemented = web.json_response({
"error": "Not implemented",
"errcode": "not_implemented",
}, status=web.HTTPNotImplemented)
}, status=HTTPStatus.NOT_IMPLEMENTED)
RespOK = web.json_response({
"success": True,
}, status=web.HTTPOk)
}, status=HTTPStatus.OK)
RespDeleted = web.Response(status=web.HTTPNoContent)
RespDeleted = web.Response(status=HTTPStatus.NO_CONTENT)

View File

@ -27,6 +27,9 @@ if TYPE_CHECKING:
from mautrix.util.config import BaseProxyConfig
DatabaseNotConfigured = ValueError("A database for this maubot instance has not been configured.")
class Plugin(ABC):
client: 'MaubotMatrixClient'
id: str
@ -41,7 +44,9 @@ class Plugin(ABC):
self.config = config
self.__db_base_path = db_base_path
def request_db_engine(self) -> Engine:
def request_db_engine(self) -> Optional[Engine]:
if not self.__db_base_path:
raise DatabaseNotConfigured
return sql.create_engine(f"sqlite:///{os.path.join(self.__db_base_path, self.id)}.db")
def set_command_spec(self, spec: 'CommandSpec') -> None:

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
import logging
import asyncio
from mautrix.api import PathBuilder, Method
@ -23,6 +24,8 @@ from .__meta__ import __version__
class MaubotServer:
log: logging.Logger = logging.getLogger("maubot.server")
def __init__(self, config: Config, management: web.Application,
loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop or asyncio.get_event_loop()
@ -45,6 +48,7 @@ class MaubotServer:
await self.runner.setup()
site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"])
await site.start()
self.log.info(f"Listening on {site.name}")
async def stop(self) -> None:
await self.runner.cleanup()

View File

@ -5,3 +5,4 @@ alembic
commonmark
ruamel.yaml
attrs
bcrypt

View File

@ -28,6 +28,7 @@ setuptools.setup(
"commonmark>=0.8.1,<1",
"ruamel.yaml>=0.15.35,<0.16",
"attrs>=18.1.0,<19",
"bcrypt>=3.1.4,<4",
],
classifiers=[