Implement client API

This commit is contained in:
Tulir Asokan 2018-11-01 23:31:30 +02:00
parent bc87b2a02b
commit 383c9ce5ec
4 changed files with 163 additions and 50 deletions

View File

@ -74,18 +74,21 @@ try:
log.info("Starting server") log.info("Starting server")
loop.run_until_complete(server.start()) loop.run_until_complete(server.start())
log.info("Starting clients and plugins") log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients])) loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop))
log.info("Startup actions complete, running forever") log.info("Startup actions complete, running forever")
periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop) periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
loop.run_forever() loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
log.debug("Interrupt received, stopping HTTP clients/servers and saving database") log.info("Interrupt received, stopping HTTP clients/servers and saving database")
if periodic_commit_task is not None: if periodic_commit_task is not None:
periodic_commit_task.cancel() periodic_commit_task.cancel()
for client in Client.cache.values(): log.debug("Stopping clients")
client.stop() loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()],
loop=loop))
db_session.commit() db_session.commit()
log.debug("Stopping server")
loop.run_until_complete(server.stop()) loop.run_until_complete(server.stop())
log.debug("Closing event loop")
loop.close() loop.close()
log.debug("Everything stopped, shutting down") log.debug("Everything stopped, shutting down")
sys.exit(0) sys.exit(0)

View File

@ -92,7 +92,7 @@ class Client:
self.db_instance.enabled = False self.db_instance.enabled = False
return return
if not self.filter_id: if not self.filter_id:
self.filter_id = await self.client.create_filter(Filter( self.db_instance.filter_id = await self.client.create_filter(Filter(
room=RoomFilter( room=RoomFilter(
timeline=RoomEventFilter( timeline=RoomEventFilter(
limit=50, limit=50,
@ -122,10 +122,19 @@ class Client:
def stop_sync(self) -> None: def stop_sync(self) -> None:
self.client.stop() self.client.stop()
def stop(self) -> None: async def stop(self) -> None:
if self.started:
self.started = False self.started = False
await self.stop_plugins()
self.stop_sync() self.stop_sync()
def delete(self) -> None:
try:
del self.cache[self.id]
except KeyError:
pass
self.db.delete(self.db_instance)
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
"id": self.id, "id": self.id,
@ -158,6 +167,44 @@ class Client:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE: if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room(evt.room_id) await self.client.join_room(evt.room_id)
async def update_started(self, started: bool) -> None:
if started is None or started == self.started:
return
if started:
await self.start()
else:
await self.stop()
async def update_displayname(self, displayname: str) -> None:
if not displayname or displayname == self.displayname:
return
self.db_instance.displayname = displayname
await self.client.set_displayname(self.displayname)
async def update_avatar_url(self, avatar_url: ContentURI) -> None:
if not avatar_url or avatar_url == self.avatar_url:
return
self.db_instance.avatar_url = avatar_url
await self.client.set_avatar_url(self.avatar_url)
async def update_access_details(self, access_token: str, homeserver: str) -> None:
if not access_token and not homeserver:
return
elif access_token == self.access_token and homeserver == self.homeserver:
return
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
token=access_token or self.access_token, loop=self.loop,
client_session=self.http_client, log=self.log)
mxid = await new_client.whoami()
if mxid != self.id:
raise ValueError("MXID mismatch")
new_client.store = self.db_instance
self.stop_sync()
self.client = new_client
self.db_instance.homeserver = homeserver
self.db_instance.access_token = access_token
self.start_sync()
# region Properties # region Properties
@property @property
@ -172,11 +219,6 @@ class Client:
def access_token(self) -> str: def access_token(self) -> str:
return self.db_instance.access_token return self.db_instance.access_token
@access_token.setter
def access_token(self, value: str) -> None:
self.client.api.token = value
self.db_instance.access_token = value
@property @property
def enabled(self) -> bool: def enabled(self) -> bool:
return self.db_instance.enabled return self.db_instance.enabled
@ -189,25 +231,24 @@ class Client:
def next_batch(self) -> SyncToken: def next_batch(self) -> SyncToken:
return self.db_instance.next_batch return self.db_instance.next_batch
@next_batch.setter
def next_batch(self, value: SyncToken) -> None:
self.db_instance.next_batch = value
@property @property
def filter_id(self) -> FilterID: def filter_id(self) -> FilterID:
return self.db_instance.filter_id return self.db_instance.filter_id
@filter_id.setter
def filter_id(self, value: FilterID) -> None:
self.db_instance.filter_id = value
@property @property
def sync(self) -> bool: def sync(self) -> bool:
return self.db_instance.sync return self.db_instance.sync
@sync.setter @sync.setter
def sync(self, value: bool) -> None: def sync(self, value: bool) -> None:
if value == self.db_instance.sync:
return
self.db_instance.sync = value self.db_instance.sync = value
if self.started:
if value:
self.start_sync()
else:
self.stop_sync()
@property @property
def autojoin(self) -> bool: def autojoin(self) -> bool:
@ -227,18 +268,10 @@ class Client:
def displayname(self) -> str: def displayname(self) -> str:
return self.db_instance.displayname return self.db_instance.displayname
@displayname.setter
def displayname(self, value: str) -> None:
self.db_instance.displayname = value
@property @property
def avatar_url(self) -> ContentURI: def avatar_url(self) -> ContentURI:
return self.db_instance.avatar_url return self.db_instance.avatar_url
@avatar_url.setter
def avatar_url(self, value: ContentURI) -> None:
self.db_instance.avatar_url = value
# endregion # endregion

View File

@ -17,15 +17,20 @@ from json import JSONDecodeError
from aiohttp import web from aiohttp import web
from mautrix.types import UserID from mautrix.types import UserID, SyncToken, FilterID
from mautrix.errors import MatrixRequestError, MatrixInvalidToken
from mautrix.client import Client as MatrixClient
from ...db import DBClient
from ...client import Client from ...client import Client
from .base import routes from .base import routes
from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON from .responses import (RespDeleted, ErrClientNotFound, ErrBodyNotJSON, ErrClientInUse,
ErrBadClientAccessToken, ErrBadClientAccessDetails, ErrMXIDMismatch,
ErrUserExists)
@routes.get("/clients") @routes.get("/clients")
async def get_clients(request: web.Request) -> web.Response: async def get_clients(_: web.Request) -> web.Response:
return web.json_response([client.to_dict() for client in Client.cache.values()]) return web.json_response([client.to_dict() for client in Client.cache.values()])
@ -39,17 +44,59 @@ async def get_client(request: web.Request) -> web.Response:
async def create_client(user_id: UserID, data: dict) -> web.Response: async def create_client(user_id: UserID, data: dict) -> web.Response:
return ErrNotImplemented homeserver = data.get("homeserver", None)
access_token = data.get("access_token", None)
new_client = MatrixClient(base_url=homeserver, token=access_token, loop=Client.loop,
client_session=Client.http_client)
try:
mxid = await new_client.whoami()
except MatrixInvalidToken:
return ErrBadClientAccessToken
except MatrixRequestError:
return ErrBadClientAccessDetails
if user_id == "new":
existing_client = Client.get(mxid, None)
if existing_client is not None:
return ErrUserExists
elif mxid != user_id:
return ErrMXIDMismatch
db_instance = DBClient(id=user_id, homeserver=homeserver, access_token=access_token,
enabled=data.get("enabled", True), next_batch=SyncToken(""),
filter_id=FilterID(""), sync=data.get("sync", True),
autojoin=data.get("autojoin", True),
displayname=data.get("displayname", ""),
avatar_url=data.get("avatar_url", ""))
client = Client(db_instance)
Client.db.add(db_instance)
Client.db.commit()
await client.start()
return web.json_response(client.to_dict())
async def update_client(client: Client, data: dict) -> web.Response: async def update_client(client: Client, data: dict) -> web.Response:
return ErrNotImplemented try:
await client.update_access_details(data.get("access_token", None),
data.get("homeserver", None))
except MatrixInvalidToken:
return ErrBadClientAccessToken
except MatrixRequestError:
return ErrBadClientAccessDetails
except ValueError:
return ErrMXIDMismatch
await client.update_avatar_url(data.get("avatar_url", None))
await client.update_displayname(data.get("displayname", None))
await client.update_started(data.get("started", None))
client.enabled = data.get("enabled", client.enabled)
client.autojoin = data.get("autojoin", client.autojoin)
client.sync = data.get("sync", client.sync)
return web.json_response(client.to_dict())
@routes.put("/client/{id}") @routes.put("/client/{id}")
async def update_client(request: web.Request) -> web.Response: async def update_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None) user_id = request.match_info.get("id", None)
client = Client.get(user_id, None) # /client/new always creates a new client
client = Client.get(user_id, None) if user_id != "new" else None
try: try:
data = await request.json() data = await request.json()
except JSONDecodeError: except JSONDecodeError:
@ -66,4 +113,9 @@ async def delete_client(request: web.Request) -> web.Response:
client = Client.get(user_id, None) client = Client.get(user_id, None)
if not client: if not client:
return ErrClientNotFound return ErrClientNotFound
return ErrNotImplemented if len(client.references) > 0:
return ErrClientInUse
if client.started:
await client.stop()
client.delete()
return RespDeleted

View File

@ -16,6 +16,36 @@
from http import HTTPStatus from http import HTTPStatus
from aiohttp import web from aiohttp import web
ErrBodyNotJSON = web.json_response({
"error": "Request body is not JSON",
"errcode": "body_not_json",
}, status=HTTPStatus.BAD_REQUEST)
ErrPluginTypeRequired = web.json_response({
"error": "Plugin type is required when creating plugin instances",
"errcode": "plugin_type_required",
}, status=HTTPStatus.BAD_REQUEST)
ErrPrimaryUserRequired = web.json_response({
"error": "Primary user is required when creating plugin instances",
"errcode": "primary_user_required",
}, status=HTTPStatus.BAD_REQUEST)
ErrBadClientAccessToken = web.json_response({
"error": "Invalid access token",
"errcode": "bad_client_access_token",
}, status=HTTPStatus.BAD_REQUEST)
ErrBadClientAccessDetails = web.json_response({
"error": "Invalid homeserver or access token",
"errcode": "bad_client_access_details"
}, status=HTTPStatus.BAD_REQUEST)
ErrMXIDMismatch = web.json_response({
"error": "The Matrix user ID of the client and the user ID of the access token don't match",
"errcode": "mxid_mismatch",
}, status=HTTPStatus.BAD_REQUEST)
ErrBadAuth = web.json_response({ ErrBadAuth = web.json_response({
"error": "Invalid username or password", "error": "Invalid username or password",
"errcode": "invalid_auth", "errcode": "invalid_auth",
@ -56,16 +86,6 @@ ErrPluginTypeNotFound = web.json_response({
"errcode": "plugin_type_not_found", "errcode": "plugin_type_not_found",
}, status=HTTPStatus.NOT_FOUND) }, status=HTTPStatus.NOT_FOUND)
ErrPluginTypeRequired = web.json_response({
"error": "Plugin type is required when creating plugin instances",
"errcode": "plugin_type_required",
}, status=HTTPStatus.BAD_REQUEST)
ErrPrimaryUserRequired = web.json_response({
"error": "Primary user is required when creating plugin instances",
"errcode": "primary_user_required",
}, status=HTTPStatus.BAD_REQUEST)
ErrPathNotFound = web.json_response({ ErrPathNotFound = web.json_response({
"error": "Resource not found", "error": "Resource not found",
"errcode": "resource_not_found", "errcode": "resource_not_found",
@ -76,15 +96,20 @@ ErrMethodNotAllowed = web.json_response({
"errcode": "method_not_allowed", "errcode": "method_not_allowed",
}, status=HTTPStatus.METHOD_NOT_ALLOWED) }, status=HTTPStatus.METHOD_NOT_ALLOWED)
ErrUserExists = web.json_response({
"error": "There is already a client with the user ID of that token",
"errcode": "user_exists",
}, status=HTTPStatus.CONFLICT)
ErrPluginInUse = web.json_response({ ErrPluginInUse = web.json_response({
"error": "Plugin instances of this type still exist", "error": "Plugin instances of this type still exist",
"errcode": "plugin_in_use", "errcode": "plugin_in_use",
}, status=HTTPStatus.PRECONDITION_FAILED) }, status=HTTPStatus.PRECONDITION_FAILED)
ErrBodyNotJSON = web.json_response({ ErrClientInUse = web.json_response({
"error": "Request body is not JSON", "error": "Plugin instances with this client as their primary user still exist",
"errcode": "body_not_json", "errcode": "client_in_use",
}, status=HTTPStatus.BAD_REQUEST) }, status=HTTPStatus.PRECONDITION_FAILED)
def plugin_import_error(error: str, stacktrace: str) -> web.Response: def plugin_import_error(error: str, stacktrace: str) -> web.Response: