Fix reusing management API responses and some other things

This commit is contained in:
Tulir Asokan 2018-11-02 18:45:07 +02:00
parent ec22e5eba7
commit 2736a1f47f
9 changed files with 272 additions and 186 deletions

View File

@ -134,6 +134,7 @@ class Client:
except KeyError:
pass
self.db.delete(self.db_instance)
self.db.commit()
def to_dict(self) -> dict:
return {

View File

@ -14,13 +14,14 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML
from asyncio import AbstractEventLoop
import logging
import io
from sqlalchemy.orm import Session
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.types import UserID
@ -56,6 +57,10 @@ class PluginInstance:
self.log = logging.getLogger(f"maubot.plugin.{self.id}")
self.config = None
self.started = False
self.loader = None
self.client = None
self.plugin = None
self.base_cfg = None
self.cache[self.id] = self
def to_dict(self) -> dict:
@ -94,6 +99,7 @@ class PluginInstance:
except KeyError:
pass
self.db.delete(self.db_instance)
self.db.commit()
# TODO delete plugin db
def load_config(self) -> CommentedMap:

View File

@ -192,8 +192,10 @@ class ZippedPluginLoader(PluginLoader):
for module in self.modules:
try:
importer.load_module(module)
except ZipImportError as e:
except ZipImportError:
raise MaubotZipLoadError(f"Module {module} not found in file")
except Exception:
raise MaubotZipLoadError(f"Failed to load module {module}")
try:
main_mod = sys.modules[self.main_module]
except KeyError as e:
@ -235,7 +237,7 @@ class ZippedPluginLoader(PluginLoader):
self._importer.remove_cache()
self._importer = None
self._loaded = None
os.remove(self.path)
self.trash(self.path, reason="delete")
self.id = None
self.path = None
self.version = None

View File

@ -22,7 +22,7 @@ from mautrix.types import UserID
from mautrix.util.signed_token import sign_token, verify_token
from .base import routes, get_config
from .responses import ErrBadAuth, ErrBodyNotJSON, ErrNoToken, ErrInvalidToken
from .responses import resp
def is_valid_token(token: str) -> bool:
@ -35,6 +35,7 @@ def is_valid_token(token: str) -> bool:
def create_token(user: UserID) -> str:
return sign_token(get_config()["server.unshared_secret"], {
"user_id": user,
"created_at": int(time()),
})
@ -42,17 +43,15 @@ def create_token(user: UserID) -> str:
async def ping(request: web.Request) -> web.Response:
token = request.headers.get("Authorization", "")
if not token or not token.startswith("Bearer "):
return ErrNoToken
return resp.no_token
data = verify_token(get_config()["server.unshared_secret"], token[len("Bearer "):])
if not data:
return ErrInvalidToken
return resp.invalid_token
user = data.get("user_id", None)
if not get_config().is_admin(user):
return ErrInvalidToken
return web.json_response({
"username": user,
})
return resp.invalid_token
return resp.pong(user)
@routes.post("/auth/login")
@ -60,21 +59,15 @@ async def login(request: web.Request) -> web.Response:
try:
data = await request.json()
except json.JSONDecodeError:
return ErrBodyNotJSON
return resp.body_not_json
secret = data.get("secret")
if secret and get_config()["server.unshared_secret"] == secret:
user = data.get("user") or "root"
return web.json_response({
"token": create_token(user),
"created_at": int(time()),
})
return resp.logged_in(create_token(user))
username = data.get("username")
password = data.get("password")
if get_config().check_password(username, password):
return web.json_response({
"token": create_token(username),
"created_at": int(time()),
})
return resp.logged_in(create_token(username))
return ErrBadAuth
return resp.bad_auth

View File

@ -26,14 +26,12 @@ from mautrix.client import Client as MatrixClient
from ...db import DBClient
from ...client import Client
from .base import routes
from .responses import (RespDeleted, ErrClientNotFound, ErrBodyNotJSON, ErrClientInUse,
ErrBadClientAccessToken, ErrBadClientAccessDetails, ErrMXIDMismatch,
ErrUserExists)
from .responses import resp
@routes.get("/clients")
async def get_clients(_: web.Request) -> web.Response:
return web.json_response([client.to_dict() for client in Client.cache.values()])
return resp.found([client.to_dict() for client in Client.cache.values()])
@routes.get("/client/{id}")
@ -41,8 +39,8 @@ async def get_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
if not client:
return ErrClientNotFound
return web.json_response(client.to_dict())
return resp.client_not_found
return resp.found(client.to_dict())
async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
@ -53,15 +51,15 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
try:
mxid = await new_client.whoami()
except MatrixInvalidToken:
return ErrBadClientAccessToken
return resp.bad_client_access_token
except MatrixRequestError:
return ErrBadClientAccessDetails
return resp.bad_client_access_details
if user_id is None:
existing_client = Client.get(mxid, None)
if existing_client is not None:
return ErrUserExists
return resp.user_exists
elif mxid != user_id:
return ErrMXIDMismatch
return resp.mxid_mismatch
db_instance = DBClient(id=mxid, homeserver=homeserver, access_token=access_token,
enabled=data.get("enabled", True), next_batch=SyncToken(""),
filter_id=FilterID(""), sync=data.get("sync", True),
@ -72,7 +70,7 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
Client.db.add(db_instance)
Client.db.commit()
await client.start()
return web.json_response(client.to_dict())
return resp.created(client.to_dict())
async def _update_client(client: Client, data: dict) -> web.Response:
@ -80,18 +78,18 @@ async def _update_client(client: Client, data: dict) -> web.Response:
await client.update_access_details(data.get("access_token", None),
data.get("homeserver", None))
except MatrixInvalidToken:
return ErrBadClientAccessToken
return resp.bad_client_access_token
except MatrixRequestError:
return ErrBadClientAccessDetails
return resp.bad_client_access_details
except ValueError:
return ErrMXIDMismatch
return resp.mxid_mismatch
await client.update_avatar_url(data.get("avatar_url", None))
await client.update_displayname(data.get("displayname", None))
await client.update_started(data.get("started", None))
client.enabled = data.get("enabled", client.enabled)
client.autojoin = data.get("autojoin", client.autojoin)
client.sync = data.get("sync", client.sync)
return web.json_response(client.to_dict(), status=HTTPStatus.CREATED)
return resp.updated(client.to_dict())
@routes.post("/client/new")
@ -99,7 +97,7 @@ async def create_client(request: web.Request) -> web.Response:
try:
data = await request.json()
except JSONDecodeError:
return ErrBodyNotJSON
return resp.body_not_json
return await _create_client(None, data)
@ -110,7 +108,7 @@ async def update_client(request: web.Request) -> web.Response:
try:
data = await request.json()
except JSONDecodeError:
return ErrBodyNotJSON
return resp.body_not_json
if not client:
return await _create_client(user_id, data)
else:
@ -122,10 +120,10 @@ async def delete_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
if not client:
return ErrClientNotFound
return resp.client_not_found
if len(client.references) > 0:
return ErrClientInUse
return resp.client_in_use
if client.started:
await client.stop()
client.delete()
return RespDeleted
return resp.deleted

View File

@ -23,13 +23,12 @@ from ...instance import PluginInstance
from ...loader import PluginLoader
from ...client import Client
from .base import routes
from .responses import (ErrInstanceNotFound, ErrBodyNotJSON, RespDeleted, ErrPrimaryUserNotFound,
ErrPluginTypeRequired, ErrPrimaryUserRequired, ErrPluginTypeNotFound)
from .responses import resp
@routes.get("/instances")
async def get_instances(_: web.Request) -> web.Response:
return web.json_response([instance.to_dict() for instance in PluginInstance.cache.values()])
return resp.found([instance.to_dict() for instance in PluginInstance.cache.values()])
@routes.get("/instance/{id}")
@ -37,23 +36,23 @@ async def get_instance(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "").lower()
instance = PluginInstance.get(instance_id, None)
if not instance:
return ErrInstanceNotFound
return web.json_response(instance.to_dict())
return resp.instance_not_found
return resp.found(instance.to_dict())
async def _create_instance(instance_id: str, data: dict) -> web.Response:
plugin_type = data.get("type", None)
primary_user = data.get("primary_user", None)
if not plugin_type:
return ErrPluginTypeRequired
return resp.plugin_type_required
elif not primary_user:
return ErrPrimaryUserRequired
return resp.primary_user_required
elif not Client.get(primary_user):
return ErrPrimaryUserNotFound
return resp.primary_user_not_found
try:
PluginLoader.find(plugin_type)
except KeyError:
return ErrPluginTypeNotFound
return resp.plugin_type_not_found
db_instance = DBPlugin(id=instance_id, type=plugin_type, enabled=data.get("enabled", True),
primary_user=primary_user, config=data.get("config", ""))
instance = PluginInstance(db_instance)
@ -61,18 +60,18 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
PluginInstance.db.add(db_instance)
PluginInstance.db.commit()
await instance.start()
return web.json_response(instance.to_dict(), status=HTTPStatus.CREATED)
return resp.created(instance.to_dict())
async def _update_instance(instance: PluginInstance, data: dict) -> web.Response:
if not await instance.update_primary_user(data.get("primary_user", None)):
return ErrPrimaryUserNotFound
return resp.primary_user_not_found
instance.update_id(data.get("id", None))
instance.update_enabled(data.get("enabled", None))
instance.update_config(data.get("config", None))
await instance.update_started(data.get("started", None))
instance.db.commit()
return web.json_response(instance.to_dict())
return resp.updated(instance.to_dict())
@routes.put("/instance/{id}")
@ -82,7 +81,7 @@ async def update_instance(request: web.Request) -> web.Response:
try:
data = await request.json()
except JSONDecodeError:
return ErrBodyNotJSON
return resp.body_not_json
if not instance:
return await _create_instance(instance_id, data)
else:
@ -94,8 +93,8 @@ async def delete_instance(request: web.Request) -> web.Response:
instance_id = request.match_info.get("id", "").lower()
instance = PluginInstance.get(instance_id, None)
if not instance:
return ErrInstanceNotFound
return resp.instance_not_found
if instance.started:
await instance.stop()
instance.delete()
return RespDeleted
return resp.deleted

View File

@ -14,9 +14,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Awaitable
import logging
from aiohttp import web
from .responses import ErrNoToken, ErrInvalidToken, ErrPathNotFound, ErrMethodNotAllowed
from .responses import resp
from .auth import is_valid_token
Handler = Callable[[web.Request], Awaitable[web.Response]]
@ -28,25 +30,32 @@ async def auth(request: web.Request, handler: Handler) -> web.Response:
return await handler(request)
token = request.headers.get("Authorization", "")
if not token or not token.startswith("Bearer "):
return ErrNoToken
return resp.no_token
if not is_valid_token(token[len("Bearer "):]):
return ErrInvalidToken
return resp.invalid_token
return await handler(request)
log = logging.getLogger("maubot.server")
@web.middleware
async def error(request: web.Request, handler: Handler) -> web.Response:
try:
return await handler(request)
except web.HTTPException as ex:
print(ex)
if ex.status_code == 404:
return ErrPathNotFound
return resp.path_not_found
elif ex.status_code == 405:
return ErrMethodNotAllowed
return resp.method_not_allowed
return web.json_response({
"error": f"Unhandled HTTP {ex.status}",
"errcode": f"unhandled_http_{ex.status}",
}, status=ex.status)
except Exception:
log.exception("Error in handler")
return resp.internal_server_error
req_no = 0

View File

@ -23,14 +23,13 @@ import re
from aiohttp import web
from ...loader import PluginLoader, ZippedPluginLoader, MaubotZipImportError
from .responses import (ErrPluginNotFound, ErrPluginInUse, plugin_import_error,
plugin_reload_error, RespDeleted, RespOK, ErrUnsupportedPluginLoader)
from .responses import resp
from .base import routes, get_config
@routes.get("/plugins")
async def get_plugins(_) -> web.Response:
return web.json_response([plugin.to_dict() for plugin in PluginLoader.id_cache.values()])
return resp.found([plugin.to_dict() for plugin in PluginLoader.id_cache.values()])
@routes.get("/plugin/{id}")
@ -38,8 +37,8 @@ async def get_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id, None)
if not plugin:
return ErrPluginNotFound
return web.json_response(plugin.to_dict())
return resp.plugin_not_found
return resp.found(plugin.to_dict())
@routes.delete("/plugin/{id}")
@ -47,11 +46,11 @@ async def delete_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id, None)
if not plugin:
return ErrPluginNotFound
return resp.plugin_not_found
elif len(plugin.references) > 0:
return ErrPluginInUse
return resp.plugin_in_use
await plugin.delete()
return RespDeleted
return resp.deleted
@routes.post("/plugin/{id}/reload")
@ -59,15 +58,15 @@ async def reload_plugin(request: web.Request) -> web.Response:
plugin_id = request.match_info.get("id", None)
plugin = PluginLoader.id_cache.get(plugin_id, None)
if not plugin:
return ErrPluginNotFound
return resp.plugin_not_found
await plugin.stop_instances()
try:
await plugin.reload()
except MaubotZipImportError as e:
return plugin_reload_error(str(e), traceback.format_exc())
return resp.plugin_reload_error(str(e), traceback.format_exc())
await plugin.start_instances()
return RespOK
return resp.ok
async def upload_new_plugin(content: bytes, pid: str, version: str) -> web.Response:
@ -78,8 +77,8 @@ async def upload_new_plugin(content: bytes, pid: str, version: str) -> web.Respo
plugin = ZippedPluginLoader.get(path)
except MaubotZipImportError as e:
ZippedPluginLoader.trash(path)
return plugin_import_error(str(e), traceback.format_exc())
return web.json_response(plugin.to_dict(), status=HTTPStatus.CREATED)
return resp.plugin_import_error(str(e), traceback.format_exc())
return resp.created(plugin.to_dict())
async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, new_version: str
@ -107,10 +106,10 @@ async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
await plugin.start_instances()
except MaubotZipImportError:
pass
return plugin_import_error(str(e), traceback.format_exc())
return resp.plugin_import_error(str(e), traceback.format_exc())
await plugin.start_instances()
ZippedPluginLoader.trash(old_path, reason="update")
return web.json_response(plugin.to_dict())
return resp.updated(plugin.to_dict())
@routes.post("/plugins/upload")
@ -120,11 +119,11 @@ async def upload_plugin(request: web.Request) -> web.Response:
try:
pid, version = ZippedPluginLoader.verify_meta(file)
except MaubotZipImportError as e:
return plugin_import_error(str(e), traceback.format_exc())
return resp.plugin_import_error(str(e), traceback.format_exc())
plugin = PluginLoader.id_cache.get(pid, None)
if not plugin:
return await upload_new_plugin(content, pid, version)
elif isinstance(plugin, ZippedPluginLoader):
return await upload_replacement_plugin(plugin, content, version)
else:
return ErrUnsupportedPluginLoader
return resp.unsupported_plugin_loader

View File

@ -14,132 +14,211 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from http import HTTPStatus
from aiohttp import web
ErrBodyNotJSON = web.json_response({
class _Response:
@property
def body_not_json(self) -> web.Response:
return web.json_response({
"error": "Request body is not JSON",
"errcode": "body_not_json",
}, status=HTTPStatus.BAD_REQUEST)
}, status=HTTPStatus.BAD_REQUEST)
ErrPluginTypeRequired = web.json_response({
@property
def plugin_type_required(self) -> web.Response:
return web.json_response({
"error": "Plugin type is required when creating plugin instances",
"errcode": "plugin_type_required",
}, status=HTTPStatus.BAD_REQUEST)
}, status=HTTPStatus.BAD_REQUEST)
ErrPrimaryUserRequired = web.json_response({
@property
def primary_user_required(self) -> web.Response:
return web.json_response({
"error": "Primary user is required when creating plugin instances",
"errcode": "primary_user_required",
}, status=HTTPStatus.BAD_REQUEST)
}, status=HTTPStatus.BAD_REQUEST)
ErrBadClientAccessToken = web.json_response({
@property
def bad_client_access_token(self) -> web.Response:
return web.json_response({
"error": "Invalid access token",
"errcode": "bad_client_access_token",
}, status=HTTPStatus.BAD_REQUEST)
}, status=HTTPStatus.BAD_REQUEST)
ErrBadClientAccessDetails = web.json_response({
@property
def bad_client_access_details(self) -> web.Response:
return web.json_response({
"error": "Invalid homeserver or access token",
"errcode": "bad_client_access_details"
}, status=HTTPStatus.BAD_REQUEST)
}, status=HTTPStatus.BAD_REQUEST)
ErrMXIDMismatch = web.json_response({
@property
def mxid_mismatch(self) -> web.Response:
return web.json_response({
"error": "The Matrix user ID of the client and the user ID of the access token don't match",
"errcode": "mxid_mismatch",
}, status=HTTPStatus.BAD_REQUEST)
}, status=HTTPStatus.BAD_REQUEST)
ErrBadAuth = web.json_response({
@property
def bad_auth(self) -> web.Response:
return web.json_response({
"error": "Invalid username or password",
"errcode": "invalid_auth",
}, status=HTTPStatus.UNAUTHORIZED)
}, status=HTTPStatus.UNAUTHORIZED)
ErrNoToken = web.json_response({
@property
def no_token(self) -> web.Response:
return web.json_response({
"error": "Authorization token missing",
"errcode": "auth_token_missing",
}, status=HTTPStatus.UNAUTHORIZED)
}, status=HTTPStatus.UNAUTHORIZED)
ErrInvalidToken = web.json_response({
@property
def invalid_token(self) -> web.Response:
return web.json_response({
"error": "Invalid authorization token",
"errcode": "auth_token_invalid",
}, status=HTTPStatus.UNAUTHORIZED)
}, status=HTTPStatus.UNAUTHORIZED)
ErrPluginNotFound = web.json_response({
@property
def plugin_not_found(self) -> web.Response:
return web.json_response({
"error": "Plugin not found",
"errcode": "plugin_not_found",
}, status=HTTPStatus.NOT_FOUND)
}, status=HTTPStatus.NOT_FOUND)
ErrClientNotFound = web.json_response({
@property
def client_not_found(self) -> web.Response:
return web.json_response({
"error": "Client not found",
"errcode": "client_not_found",
}, status=HTTPStatus.NOT_FOUND)
}, status=HTTPStatus.NOT_FOUND)
ErrPrimaryUserNotFound = web.json_response({
@property
def primary_user_not_found(self) -> web.Response:
return web.json_response({
"error": "Client for given primary user not found",
"errcode": "primary_user_not_found",
}, status=HTTPStatus.NOT_FOUND)
}, status=HTTPStatus.NOT_FOUND)
ErrInstanceNotFound = web.json_response({
@property
def instance_not_found(self) -> web.Response:
return web.json_response({
"error": "Plugin instance not found",
"errcode": "instance_not_found",
}, status=HTTPStatus.NOT_FOUND)
}, status=HTTPStatus.NOT_FOUND)
ErrPluginTypeNotFound = web.json_response({
@property
def plugin_type_not_found(self) -> web.Response:
return web.json_response({
"error": "Given plugin type not found",
"errcode": "plugin_type_not_found",
}, status=HTTPStatus.NOT_FOUND)
}, status=HTTPStatus.NOT_FOUND)
ErrPathNotFound = web.json_response({
@property
def path_not_found(self) -> web.Response:
return web.json_response({
"error": "Resource not found",
"errcode": "resource_not_found",
}, status=HTTPStatus.NOT_FOUND)
}, status=HTTPStatus.NOT_FOUND)
ErrMethodNotAllowed = web.json_response({
@property
def method_not_allowed(self) -> web.Response:
return web.json_response({
"error": "Method not allowed",
"errcode": "method_not_allowed",
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
ErrUserExists = web.json_response({
@property
def user_exists(self) -> web.Response:
return web.json_response({
"error": "There is already a client with the user ID of that token",
"errcode": "user_exists",
}, status=HTTPStatus.CONFLICT)
}, status=HTTPStatus.CONFLICT)
ErrPluginInUse = web.json_response({
@property
def plugin_in_use(self) -> web.Response:
return web.json_response({
"error": "Plugin instances of this type still exist",
"errcode": "plugin_in_use",
}, status=HTTPStatus.PRECONDITION_FAILED)
}, status=HTTPStatus.PRECONDITION_FAILED)
ErrClientInUse = web.json_response({
@property
def client_in_use(self) -> web.Response:
return web.json_response({
"error": "Plugin instances with this client as their primary user still exist",
"errcode": "client_in_use",
}, status=HTTPStatus.PRECONDITION_FAILED)
}, status=HTTPStatus.PRECONDITION_FAILED)
def plugin_import_error(error: str, stacktrace: str) -> web.Response:
@staticmethod
def plugin_import_error(error: str, stacktrace: str) -> web.Response:
return web.json_response({
"error": error,
"stacktrace": stacktrace,
"errcode": "plugin_invalid",
}, status=HTTPStatus.BAD_REQUEST)
def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
@staticmethod
def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
return web.json_response({
"error": error,
"stacktrace": stacktrace,
"errcode": "plugin_reload_fail",
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
@property
def internal_server_error(self) -> web.Response:
return web.json_response({
"error": "Internal server error",
"errcode": "internal_server_error",
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
ErrUnsupportedPluginLoader = web.json_response({
@property
def unsupported_plugin_loader(self) -> web.Response:
return web.json_response({
"error": "Existing plugin with same ID uses unsupported plugin loader",
"errcode": "unsupported_plugin_loader",
}, status=HTTPStatus.BAD_REQUEST)
}, status=HTTPStatus.BAD_REQUEST)
ErrNotImplemented = web.json_response({
@property
def not_implemented(self) -> web.Response:
return web.json_response({
"error": "Not implemented",
"errcode": "not_implemented",
}, status=HTTPStatus.NOT_IMPLEMENTED)
}, status=HTTPStatus.NOT_IMPLEMENTED)
RespOK = web.json_response({
@property
def ok(self) -> web.Response:
return web.json_response({
"success": True,
}, status=HTTPStatus.OK)
}, status=HTTPStatus.OK)
RespDeleted = web.Response(status=HTTPStatus.NO_CONTENT)
@property
def deleted(self) -> web.Response:
return web.Response(status=HTTPStatus.NO_CONTENT)
@staticmethod
def found(data: dict) -> web.Response:
return web.json_response(data, status=HTTPStatus.OK)
def updated(self, data: dict) -> web.Response:
return self.found(data)
def logged_in(self, token: str) -> web.Response:
return self.found({
"token": token,
})
def pong(self, user: str) -> web.Response:
return self.found({
"username": user,
})
@staticmethod
def created(data: dict) -> web.Response:
return web.json_response(data, status=HTTPStatus.CREATED)
resp = _Response()