Finish plugin API and add basic login system
This commit is contained in:
parent
d7f072aeff
commit
14fd0d6ac9
1
.gitignore
vendored
1
.gitignore
vendored
@ -13,3 +13,4 @@ __pycache__
|
||||
|
||||
logs/
|
||||
plugins/
|
||||
trash/
|
||||
|
@ -30,8 +30,10 @@ server:
|
||||
# Set to "generate" to generate and save a new token at startup.
|
||||
unshared_secret: generate
|
||||
|
||||
# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password
|
||||
# to prevent normal login. Root is a special user that can't have a password and will always exist.
|
||||
admins:
|
||||
- "@admin:example.com"
|
||||
root: ""
|
||||
|
||||
# Python logging configuration.
|
||||
#
|
||||
|
@ -14,7 +14,6 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from sqlalchemy import orm
|
||||
from time import time
|
||||
import sqlalchemy as sql
|
||||
import logging.config
|
||||
import argparse
|
||||
@ -22,7 +21,6 @@ import asyncio
|
||||
import signal
|
||||
import copy
|
||||
import sys
|
||||
import os
|
||||
|
||||
from .config import Config
|
||||
from .db import Base, init as init_db
|
||||
|
@ -15,9 +15,13 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
import random
|
||||
import string
|
||||
import bcrypt
|
||||
import re
|
||||
|
||||
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper
|
||||
|
||||
bcrypt_regex = re.compile(r"^\$2[ayb]\$.{56}$")
|
||||
|
||||
|
||||
class Config(BaseFileConfig):
|
||||
@staticmethod
|
||||
@ -27,16 +31,35 @@ class Config(BaseFileConfig):
|
||||
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
||||
base, copy, _ = helper
|
||||
copy("database")
|
||||
copy("plugin_directories")
|
||||
copy("plugin_db_directory")
|
||||
copy("plugin_directories.upload")
|
||||
copy("plugin_directories.load")
|
||||
copy("plugin_directories.trash")
|
||||
copy("plugin_directories.db")
|
||||
copy("server.hostname")
|
||||
copy("server.port")
|
||||
copy("server.listen")
|
||||
copy("server.base_path")
|
||||
shared_secret = self["server.shared_secret"]
|
||||
copy("server.appservice_base_path")
|
||||
shared_secret = self["server.unshared_secret"]
|
||||
if shared_secret is None or shared_secret == "generate":
|
||||
base["server.shared_secret"] = self._new_token()
|
||||
base["server.unshared_secret"] = self._new_token()
|
||||
else:
|
||||
base["server.shared_secret"] = shared_secret
|
||||
base["server.unshared_secret"] = shared_secret
|
||||
copy("admins")
|
||||
for username, password in base["admins"].items():
|
||||
if password and not bcrypt_regex.match(password):
|
||||
if password == "password":
|
||||
password = self._new_token()
|
||||
base["admins"][username] = bcrypt.hashpw(password.encode("utf-8"),
|
||||
bcrypt.gensalt()).decode("utf-8")
|
||||
copy("logging")
|
||||
|
||||
def is_admin(self, user: str) -> bool:
|
||||
return user == "root" or user in self["admins"]
|
||||
|
||||
def check_password(self, user: str, passwd: str) -> bool:
|
||||
if user == "root":
|
||||
return False
|
||||
passwd_hash = self["admins"].get(user, None)
|
||||
if not passwd_hash:
|
||||
return False
|
||||
return bcrypt.checkpw(passwd.encode("utf-8"), passwd_hash.encode("utf-8"))
|
||||
|
@ -87,13 +87,6 @@ class PluginInstance:
|
||||
def load_config(self) -> CommentedMap:
|
||||
return yaml.load(self.db_instance.config)
|
||||
|
||||
def load_config_base(self) -> Optional[RecursiveDict[CommentedMap]]:
|
||||
try:
|
||||
base = self.loader.read_file("base-config.yaml")
|
||||
return RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
|
||||
except (FileNotFoundError, KeyError):
|
||||
return None
|
||||
|
||||
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
|
||||
buf = io.StringIO()
|
||||
yaml.dump(data, buf)
|
||||
@ -103,14 +96,23 @@ class PluginInstance:
|
||||
if not self.enabled:
|
||||
self.log.warning(f"Plugin disabled, not starting.")
|
||||
return
|
||||
cls = self.loader.load()
|
||||
cls = await self.loader.load()
|
||||
config_class = cls.get_config_class()
|
||||
if config_class:
|
||||
self.config = config_class(self.load_config, self.load_config_base,
|
||||
self.save_config)
|
||||
try:
|
||||
base = await self.loader.read_file("base-config.yaml")
|
||||
base_file = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
|
||||
except (FileNotFoundError, KeyError):
|
||||
base_file = None
|
||||
self.config = config_class(self.load_config, lambda: base_file, self.save_config)
|
||||
self.plugin = cls(self.client.client, self.id, self.log, self.config,
|
||||
self.mb_config["plugin_db_directory"])
|
||||
await self.plugin.start()
|
||||
self.mb_config["plugin_directories.db"])
|
||||
try:
|
||||
await self.plugin.start()
|
||||
except Exception:
|
||||
self.log.exception("Failed to start instance")
|
||||
self.enabled = False
|
||||
return
|
||||
self.running = True
|
||||
self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} "
|
||||
f"with user {self.client.id}")
|
||||
|
@ -59,10 +59,12 @@ class PluginLoader(ABC):
|
||||
pass
|
||||
|
||||
async def stop_instances(self) -> None:
|
||||
await asyncio.gather([instance.stop() for instance in self.references if instance.running])
|
||||
await asyncio.gather(*[instance.stop() for instance
|
||||
in self.references if instance.running])
|
||||
|
||||
async def start_instances(self) -> None:
|
||||
await asyncio.gather([instance.start() for instance in self.references if instance.enabled])
|
||||
await asyncio.gather(*[instance.start() for instance
|
||||
in self.references if instance.enabled])
|
||||
|
||||
@abstractmethod
|
||||
async def load(self) -> Type[PluginClass]:
|
||||
|
@ -207,8 +207,10 @@ class ZippedPluginLoader(PluginLoader):
|
||||
self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}")
|
||||
return plugin
|
||||
|
||||
async def reload(self) -> Type[PluginClass]:
|
||||
async def reload(self, new_path: Optional[str] = None) -> Type[PluginClass]:
|
||||
await self.unload()
|
||||
if new_path is not None:
|
||||
self.path = new_path
|
||||
return await self.load(reset_cache=True)
|
||||
|
||||
async def unload(self) -> None:
|
||||
|
@ -27,8 +27,9 @@ config: Config = None
|
||||
|
||||
def is_valid_token(token: str) -> bool:
|
||||
data = verify_token(config["server.unshared_secret"], token)
|
||||
user_id = data.get("user_id", None)
|
||||
return user_id is not None and user_id in config["admins"]
|
||||
if not data:
|
||||
return False
|
||||
return config.is_admin(data.get("user_id", None))
|
||||
|
||||
|
||||
def create_token(user: UserID) -> str:
|
||||
@ -40,7 +41,9 @@ def create_token(user: UserID) -> str:
|
||||
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
|
||||
global config
|
||||
config = cfg
|
||||
from .middleware import auth, error, log
|
||||
app = web.Application(loop=loop, middlewares=[auth, log, error])
|
||||
from .middleware import auth, error
|
||||
from .auth import web as _
|
||||
from .plugin import web as _
|
||||
app = web.Application(loop=loop, middlewares=[auth, error])
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
43
maubot/management/api/auth.py
Normal file
43
maubot/management/api/auth.py
Normal file
@ -0,0 +1,43 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from aiohttp import web
|
||||
import json
|
||||
|
||||
from . import routes, config, create_token
|
||||
from .responses import ErrBadAuth, ErrBodyNotJSON
|
||||
|
||||
|
||||
@routes.post("/login")
|
||||
async def login(request: web.Request) -> web.Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return ErrBodyNotJSON
|
||||
secret = data.get("secret")
|
||||
if secret and config["server.unshared_secret"] == secret:
|
||||
user = data.get("user") or "root"
|
||||
return web.json_response({
|
||||
"token": create_token(user),
|
||||
})
|
||||
|
||||
username = data.get("username")
|
||||
password = data.get("password")
|
||||
if config.check_password(username, password):
|
||||
return web.json_response({
|
||||
"token": create_token(username),
|
||||
})
|
||||
|
||||
return ErrBadAuth
|
@ -15,25 +15,21 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Callable, Awaitable
|
||||
from aiohttp import web
|
||||
import logging
|
||||
|
||||
from .responses import ErrNoToken, ErrInvalidToken
|
||||
from .responses import ErrNoToken, ErrInvalidToken, ErrPathNotFound, ErrMethodNotAllowed
|
||||
from . import is_valid_token
|
||||
|
||||
Handler = Callable[[web.Request], Awaitable[web.Response]]
|
||||
|
||||
req_log = logging.getLogger("maubot.mgmt.request")
|
||||
resp_log = logging.getLogger("maubot.mgmt.response")
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def auth(request: web.Request, handler: Handler) -> web.Response:
|
||||
if request.path.endswith("/login"):
|
||||
return await handler(request)
|
||||
token = request.headers.get("Authorization", "")
|
||||
if not token or not token.startswith("Bearer "):
|
||||
req_log.debug(f"Request missing auth: {request.remote} {request.method} {request.path}")
|
||||
return ErrNoToken
|
||||
if not is_valid_token(token[len("Bearer "):]):
|
||||
req_log.debug(f"Request invalid auth: {request.remote} {request.method} {request.path}")
|
||||
return ErrInvalidToken
|
||||
return await handler(request)
|
||||
|
||||
@ -43,6 +39,10 @@ async def error(request: web.Request, handler: Handler) -> web.Response:
|
||||
try:
|
||||
return await handler(request)
|
||||
except web.HTTPException as ex:
|
||||
if ex.status_code == 404:
|
||||
return ErrPathNotFound
|
||||
elif ex.status_code == 405:
|
||||
return ErrMethodNotAllowed
|
||||
return web.json_response({
|
||||
"error": f"Unhandled HTTP {ex.status}",
|
||||
"errcode": f"unhandled_http_{ex.status}",
|
||||
@ -56,12 +56,3 @@ def get_req_no():
|
||||
global req_no
|
||||
req_no += 1
|
||||
return req_no
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def log(request: web.Request, handler: Handler) -> web.Response:
|
||||
local_req_no = get_req_no()
|
||||
req_log.info(f"Request {local_req_no}: {request.remote} {request.method} {request.path}")
|
||||
resp = await handler(request)
|
||||
resp_log.info(f"Responded to {local_req_no} from {request.remote}: {resp}")
|
||||
return resp
|
||||
|
@ -92,25 +92,24 @@ async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
|
||||
if plugin.version in filename:
|
||||
filename = filename.replace(plugin.version, new_version)
|
||||
else:
|
||||
filename = filename.rstrip(".mbp") + new_version + ".mbp"
|
||||
filename = filename.rstrip(".mbp")
|
||||
filename = f"{filename}-v{new_version}.mbp"
|
||||
path = os.path.join(dirname, filename)
|
||||
with open(path, "wb") as p:
|
||||
p.write(content)
|
||||
old_path = plugin.path
|
||||
plugin.path = path
|
||||
await plugin.stop_instances()
|
||||
try:
|
||||
await plugin.reload()
|
||||
await plugin.reload(new_path=path)
|
||||
except MaubotZipImportError as e:
|
||||
plugin.path = old_path
|
||||
try:
|
||||
await plugin.reload()
|
||||
await plugin.reload(new_path=old_path)
|
||||
await plugin.start_instances()
|
||||
except MaubotZipImportError:
|
||||
pass
|
||||
await plugin.start_instances()
|
||||
return plugin_import_error(str(e), traceback.format_exc())
|
||||
await plugin.start_instances()
|
||||
ZippedPluginLoader.trash(plugin.path, reason="update")
|
||||
ZippedPluginLoader.trash(old_path, reason="update")
|
||||
return RespOK
|
||||
|
||||
|
||||
|
@ -13,27 +13,48 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from http import HTTPStatus
|
||||
from aiohttp import web
|
||||
|
||||
ErrBadAuth = web.json_response({
|
||||
"error": "Invalid username or password",
|
||||
"errcode": "invalid_auth",
|
||||
}, status=HTTPStatus.UNAUTHORIZED)
|
||||
|
||||
ErrNoToken = web.json_response({
|
||||
"error": "Authorization token missing",
|
||||
"errcode": "auth_token_missing",
|
||||
}, status=web.HTTPUnauthorized)
|
||||
}, status=HTTPStatus.UNAUTHORIZED)
|
||||
|
||||
ErrInvalidToken = web.json_response({
|
||||
"error": "Invalid authorization token",
|
||||
"errcode": "auth_token_invalid",
|
||||
}, status=web.HTTPUnauthorized)
|
||||
}, status=HTTPStatus.UNAUTHORIZED)
|
||||
|
||||
ErrPluginNotFound = web.json_response({
|
||||
"error": "Plugin not found",
|
||||
"errcode": "plugin_not_found",
|
||||
}, status=web.HTTPNotFound)
|
||||
}, status=HTTPStatus.NOT_FOUND)
|
||||
|
||||
ErrPathNotFound = web.json_response({
|
||||
"error": "Resource not found",
|
||||
"errcode": "resource_not_found",
|
||||
}, status=HTTPStatus.NOT_FOUND)
|
||||
|
||||
ErrMethodNotAllowed = web.json_response({
|
||||
"error": "Method not allowed",
|
||||
"errcode": "method_not_allowed",
|
||||
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
|
||||
|
||||
ErrPluginInUse = web.json_response({
|
||||
"error": "Plugin instances of this type still exist",
|
||||
"errcode": "plugin_in_use",
|
||||
}, status=web.HTTPPreconditionFailed)
|
||||
}, status=HTTPStatus.PRECONDITION_FAILED)
|
||||
|
||||
ErrBodyNotJSON = web.json_response({
|
||||
"error": "Request body is not JSON",
|
||||
"errcode": "body_not_json",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
def plugin_import_error(error: str, stacktrace: str) -> web.Response:
|
||||
@ -41,7 +62,7 @@ def plugin_import_error(error: str, stacktrace: str) -> web.Response:
|
||||
"error": error,
|
||||
"stacktrace": stacktrace,
|
||||
"errcode": "plugin_invalid",
|
||||
}, status=web.HTTPBadRequest)
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
|
||||
@ -49,21 +70,21 @@ def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
|
||||
"error": error,
|
||||
"stacktrace": stacktrace,
|
||||
"errcode": "plugin_reload_fail",
|
||||
}, status=web.HTTPInternalServerError)
|
||||
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
ErrUnsupportedPluginLoader = web.json_response({
|
||||
"error": "Existing plugin with same ID uses unsupported plugin loader",
|
||||
"errcode": "unsupported_plugin_loader",
|
||||
}, status=web.HTTPBadRequest)
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
ErrNotImplemented = web.json_response({
|
||||
"error": "Not implemented",
|
||||
"errcode": "not_implemented",
|
||||
}, status=web.HTTPNotImplemented)
|
||||
}, status=HTTPStatus.NOT_IMPLEMENTED)
|
||||
|
||||
RespOK = web.json_response({
|
||||
"success": True,
|
||||
}, status=web.HTTPOk)
|
||||
}, status=HTTPStatus.OK)
|
||||
|
||||
RespDeleted = web.Response(status=web.HTTPNoContent)
|
||||
RespDeleted = web.Response(status=HTTPStatus.NO_CONTENT)
|
||||
|
@ -27,6 +27,9 @@ if TYPE_CHECKING:
|
||||
from mautrix.util.config import BaseProxyConfig
|
||||
|
||||
|
||||
DatabaseNotConfigured = ValueError("A database for this maubot instance has not been configured.")
|
||||
|
||||
|
||||
class Plugin(ABC):
|
||||
client: 'MaubotMatrixClient'
|
||||
id: str
|
||||
@ -41,7 +44,9 @@ class Plugin(ABC):
|
||||
self.config = config
|
||||
self.__db_base_path = db_base_path
|
||||
|
||||
def request_db_engine(self) -> Engine:
|
||||
def request_db_engine(self) -> Optional[Engine]:
|
||||
if not self.__db_base_path:
|
||||
raise DatabaseNotConfigured
|
||||
return sql.create_engine(f"sqlite:///{os.path.join(self.__db_base_path, self.id)}.db")
|
||||
|
||||
def set_command_spec(self, spec: 'CommandSpec') -> None:
|
||||
|
@ -14,6 +14,7 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from aiohttp import web
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from mautrix.api import PathBuilder, Method
|
||||
@ -23,6 +24,8 @@ from .__meta__ import __version__
|
||||
|
||||
|
||||
class MaubotServer:
|
||||
log: logging.Logger = logging.getLogger("maubot.server")
|
||||
|
||||
def __init__(self, config: Config, management: web.Application,
|
||||
loop: asyncio.AbstractEventLoop) -> None:
|
||||
self.loop = loop or asyncio.get_event_loop()
|
||||
@ -45,6 +48,7 @@ class MaubotServer:
|
||||
await self.runner.setup()
|
||||
site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"])
|
||||
await site.start()
|
||||
self.log.info(f"Listening on {site.name}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
await self.runner.cleanup()
|
||||
|
@ -5,3 +5,4 @@ alembic
|
||||
commonmark
|
||||
ruamel.yaml
|
||||
attrs
|
||||
bcrypt
|
||||
|
Loading…
Reference in New Issue
Block a user