From 383c9ce5ec6c090112f28717ec59ad0058b3aa26 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Nov 2018 23:31:30 +0200 Subject: [PATCH] Implement client API --- maubot/__main__.py | 11 ++-- maubot/client.py | 83 +++++++++++++++++++++--------- maubot/management/api/client.py | 66 +++++++++++++++++++++--- maubot/management/api/responses.py | 53 ++++++++++++++----- 4 files changed, 163 insertions(+), 50 deletions(-) diff --git a/maubot/__main__.py b/maubot/__main__.py index 3e18a68..3929095 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -74,18 +74,21 @@ try: log.info("Starting server") loop.run_until_complete(server.start()) log.info("Starting clients and plugins") - loop.run_until_complete(asyncio.gather(*[client.start() for client in clients])) + loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop)) log.info("Startup actions complete, running forever") periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop) loop.run_forever() except KeyboardInterrupt: - log.debug("Interrupt received, stopping HTTP clients/servers and saving database") + log.info("Interrupt received, stopping HTTP clients/servers and saving database") if periodic_commit_task is not None: periodic_commit_task.cancel() - for client in Client.cache.values(): - client.stop() + log.debug("Stopping clients") + loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()], + loop=loop)) db_session.commit() + log.debug("Stopping server") loop.run_until_complete(server.stop()) + log.debug("Closing event loop") loop.close() log.debug("Everything stopped, shutting down") sys.exit(0) diff --git a/maubot/client.py b/maubot/client.py index 684eaea..91fcb0c 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -92,7 +92,7 @@ class Client: self.db_instance.enabled = False return if not self.filter_id: - self.filter_id = await self.client.create_filter(Filter( + self.db_instance.filter_id = await self.client.create_filter(Filter( room=RoomFilter( timeline=RoomEventFilter( limit=50, @@ -122,9 +122,18 @@ class Client: def stop_sync(self) -> None: self.client.stop() - def stop(self) -> None: - self.started = False - self.stop_sync() + async def stop(self) -> None: + if self.started: + self.started = False + await self.stop_plugins() + self.stop_sync() + + def delete(self) -> None: + try: + del self.cache[self.id] + except KeyError: + pass + self.db.delete(self.db_instance) def to_dict(self) -> dict: return { @@ -158,6 +167,44 @@ class Client: if evt.state_key == self.id and evt.content.membership == Membership.INVITE: await self.client.join_room(evt.room_id) + async def update_started(self, started: bool) -> None: + if started is None or started == self.started: + return + if started: + await self.start() + else: + await self.stop() + + async def update_displayname(self, displayname: str) -> None: + if not displayname or displayname == self.displayname: + return + self.db_instance.displayname = displayname + await self.client.set_displayname(self.displayname) + + async def update_avatar_url(self, avatar_url: ContentURI) -> None: + if not avatar_url or avatar_url == self.avatar_url: + return + self.db_instance.avatar_url = avatar_url + await self.client.set_avatar_url(self.avatar_url) + + async def update_access_details(self, access_token: str, homeserver: str) -> None: + if not access_token and not homeserver: + return + elif access_token == self.access_token and homeserver == self.homeserver: + return + new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver, + token=access_token or self.access_token, loop=self.loop, + client_session=self.http_client, log=self.log) + mxid = await new_client.whoami() + if mxid != self.id: + raise ValueError("MXID mismatch") + new_client.store = self.db_instance + self.stop_sync() + self.client = new_client + self.db_instance.homeserver = homeserver + self.db_instance.access_token = access_token + self.start_sync() + # region Properties @property @@ -172,11 +219,6 @@ class Client: def access_token(self) -> str: return self.db_instance.access_token - @access_token.setter - def access_token(self, value: str) -> None: - self.client.api.token = value - self.db_instance.access_token = value - @property def enabled(self) -> bool: return self.db_instance.enabled @@ -189,25 +231,24 @@ class Client: def next_batch(self) -> SyncToken: return self.db_instance.next_batch - @next_batch.setter - def next_batch(self, value: SyncToken) -> None: - self.db_instance.next_batch = value - @property def filter_id(self) -> FilterID: return self.db_instance.filter_id - @filter_id.setter - def filter_id(self, value: FilterID) -> None: - self.db_instance.filter_id = value - @property def sync(self) -> bool: return self.db_instance.sync @sync.setter def sync(self, value: bool) -> None: + if value == self.db_instance.sync: + return self.db_instance.sync = value + if self.started: + if value: + self.start_sync() + else: + self.stop_sync() @property def autojoin(self) -> bool: @@ -227,18 +268,10 @@ class Client: def displayname(self) -> str: return self.db_instance.displayname - @displayname.setter - def displayname(self, value: str) -> None: - self.db_instance.displayname = value - @property def avatar_url(self) -> ContentURI: return self.db_instance.avatar_url - @avatar_url.setter - def avatar_url(self, value: ContentURI) -> None: - self.db_instance.avatar_url = value - # endregion diff --git a/maubot/management/api/client.py b/maubot/management/api/client.py index 8f5eed0..84d5b37 100644 --- a/maubot/management/api/client.py +++ b/maubot/management/api/client.py @@ -17,15 +17,20 @@ from json import JSONDecodeError from aiohttp import web -from mautrix.types import UserID +from mautrix.types import UserID, SyncToken, FilterID +from mautrix.errors import MatrixRequestError, MatrixInvalidToken +from mautrix.client import Client as MatrixClient +from ...db import DBClient from ...client import Client from .base import routes -from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON +from .responses import (RespDeleted, ErrClientNotFound, ErrBodyNotJSON, ErrClientInUse, + ErrBadClientAccessToken, ErrBadClientAccessDetails, ErrMXIDMismatch, + ErrUserExists) @routes.get("/clients") -async def get_clients(request: web.Request) -> web.Response: +async def get_clients(_: web.Request) -> web.Response: return web.json_response([client.to_dict() for client in Client.cache.values()]) @@ -39,17 +44,59 @@ async def get_client(request: web.Request) -> web.Response: async def create_client(user_id: UserID, data: dict) -> web.Response: - return ErrNotImplemented + homeserver = data.get("homeserver", None) + access_token = data.get("access_token", None) + new_client = MatrixClient(base_url=homeserver, token=access_token, loop=Client.loop, + client_session=Client.http_client) + try: + mxid = await new_client.whoami() + except MatrixInvalidToken: + return ErrBadClientAccessToken + except MatrixRequestError: + return ErrBadClientAccessDetails + if user_id == "new": + existing_client = Client.get(mxid, None) + if existing_client is not None: + return ErrUserExists + elif mxid != user_id: + return ErrMXIDMismatch + db_instance = DBClient(id=user_id, homeserver=homeserver, access_token=access_token, + enabled=data.get("enabled", True), next_batch=SyncToken(""), + filter_id=FilterID(""), sync=data.get("sync", True), + autojoin=data.get("autojoin", True), + displayname=data.get("displayname", ""), + avatar_url=data.get("avatar_url", "")) + client = Client(db_instance) + Client.db.add(db_instance) + Client.db.commit() + await client.start() + return web.json_response(client.to_dict()) async def update_client(client: Client, data: dict) -> web.Response: - return ErrNotImplemented + try: + await client.update_access_details(data.get("access_token", None), + data.get("homeserver", None)) + except MatrixInvalidToken: + return ErrBadClientAccessToken + except MatrixRequestError: + return ErrBadClientAccessDetails + except ValueError: + return ErrMXIDMismatch + await client.update_avatar_url(data.get("avatar_url", None)) + await client.update_displayname(data.get("displayname", None)) + await client.update_started(data.get("started", None)) + client.enabled = data.get("enabled", client.enabled) + client.autojoin = data.get("autojoin", client.autojoin) + client.sync = data.get("sync", client.sync) + return web.json_response(client.to_dict()) @routes.put("/client/{id}") async def update_client(request: web.Request) -> web.Response: user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) + # /client/new always creates a new client + client = Client.get(user_id, None) if user_id != "new" else None try: data = await request.json() except JSONDecodeError: @@ -66,4 +113,9 @@ async def delete_client(request: web.Request) -> web.Response: client = Client.get(user_id, None) if not client: return ErrClientNotFound - return ErrNotImplemented + if len(client.references) > 0: + return ErrClientInUse + if client.started: + await client.stop() + client.delete() + return RespDeleted diff --git a/maubot/management/api/responses.py b/maubot/management/api/responses.py index 16efecb..9fd4c40 100644 --- a/maubot/management/api/responses.py +++ b/maubot/management/api/responses.py @@ -16,6 +16,36 @@ from http import HTTPStatus from aiohttp import web +ErrBodyNotJSON = web.json_response({ + "error": "Request body is not JSON", + "errcode": "body_not_json", +}, status=HTTPStatus.BAD_REQUEST) + +ErrPluginTypeRequired = web.json_response({ + "error": "Plugin type is required when creating plugin instances", + "errcode": "plugin_type_required", +}, status=HTTPStatus.BAD_REQUEST) + +ErrPrimaryUserRequired = web.json_response({ + "error": "Primary user is required when creating plugin instances", + "errcode": "primary_user_required", +}, status=HTTPStatus.BAD_REQUEST) + +ErrBadClientAccessToken = web.json_response({ + "error": "Invalid access token", + "errcode": "bad_client_access_token", +}, status=HTTPStatus.BAD_REQUEST) + +ErrBadClientAccessDetails = web.json_response({ + "error": "Invalid homeserver or access token", + "errcode": "bad_client_access_details" +}, status=HTTPStatus.BAD_REQUEST) + +ErrMXIDMismatch = web.json_response({ + "error": "The Matrix user ID of the client and the user ID of the access token don't match", + "errcode": "mxid_mismatch", +}, status=HTTPStatus.BAD_REQUEST) + ErrBadAuth = web.json_response({ "error": "Invalid username or password", "errcode": "invalid_auth", @@ -56,16 +86,6 @@ ErrPluginTypeNotFound = web.json_response({ "errcode": "plugin_type_not_found", }, status=HTTPStatus.NOT_FOUND) -ErrPluginTypeRequired = web.json_response({ - "error": "Plugin type is required when creating plugin instances", - "errcode": "plugin_type_required", -}, status=HTTPStatus.BAD_REQUEST) - -ErrPrimaryUserRequired = web.json_response({ - "error": "Primary user is required when creating plugin instances", - "errcode": "primary_user_required", -}, status=HTTPStatus.BAD_REQUEST) - ErrPathNotFound = web.json_response({ "error": "Resource not found", "errcode": "resource_not_found", @@ -76,15 +96,20 @@ ErrMethodNotAllowed = web.json_response({ "errcode": "method_not_allowed", }, status=HTTPStatus.METHOD_NOT_ALLOWED) +ErrUserExists = web.json_response({ + "error": "There is already a client with the user ID of that token", + "errcode": "user_exists", +}, status=HTTPStatus.CONFLICT) + ErrPluginInUse = web.json_response({ "error": "Plugin instances of this type still exist", "errcode": "plugin_in_use", }, status=HTTPStatus.PRECONDITION_FAILED) -ErrBodyNotJSON = web.json_response({ - "error": "Request body is not JSON", - "errcode": "body_not_json", -}, status=HTTPStatus.BAD_REQUEST) +ErrClientInUse = web.json_response({ + "error": "Plugin instances with this client as their primary user still exist", + "errcode": "client_in_use", +}, status=HTTPStatus.PRECONDITION_FAILED) def plugin_import_error(error: str, stacktrace: str) -> web.Response: