Implement client API

This commit is contained in:
Tulir Asokan 2018-11-01 23:31:30 +02:00
parent bc87b2a02b
commit 383c9ce5ec
4 changed files with 163 additions and 50 deletions

View File

@ -74,18 +74,21 @@ try:
log.info("Starting server")
loop.run_until_complete(server.start())
log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop))
log.info("Startup actions complete, running forever")
periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
loop.run_forever()
except KeyboardInterrupt:
log.debug("Interrupt received, stopping HTTP clients/servers and saving database")
log.info("Interrupt received, stopping HTTP clients/servers and saving database")
if periodic_commit_task is not None:
periodic_commit_task.cancel()
for client in Client.cache.values():
client.stop()
log.debug("Stopping clients")
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()],
loop=loop))
db_session.commit()
log.debug("Stopping server")
loop.run_until_complete(server.stop())
log.debug("Closing event loop")
loop.close()
log.debug("Everything stopped, shutting down")
sys.exit(0)

View File

@ -92,7 +92,7 @@ class Client:
self.db_instance.enabled = False
return
if not self.filter_id:
self.filter_id = await self.client.create_filter(Filter(
self.db_instance.filter_id = await self.client.create_filter(Filter(
room=RoomFilter(
timeline=RoomEventFilter(
limit=50,
@ -122,10 +122,19 @@ class Client:
def stop_sync(self) -> None:
self.client.stop()
def stop(self) -> None:
async def stop(self) -> None:
if self.started:
self.started = False
await self.stop_plugins()
self.stop_sync()
def delete(self) -> None:
try:
del self.cache[self.id]
except KeyError:
pass
self.db.delete(self.db_instance)
def to_dict(self) -> dict:
return {
"id": self.id,
@ -158,6 +167,44 @@ class Client:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room(evt.room_id)
async def update_started(self, started: bool) -> None:
if started is None or started == self.started:
return
if started:
await self.start()
else:
await self.stop()
async def update_displayname(self, displayname: str) -> None:
if not displayname or displayname == self.displayname:
return
self.db_instance.displayname = displayname
await self.client.set_displayname(self.displayname)
async def update_avatar_url(self, avatar_url: ContentURI) -> None:
if not avatar_url or avatar_url == self.avatar_url:
return
self.db_instance.avatar_url = avatar_url
await self.client.set_avatar_url(self.avatar_url)
async def update_access_details(self, access_token: str, homeserver: str) -> None:
if not access_token and not homeserver:
return
elif access_token == self.access_token and homeserver == self.homeserver:
return
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
token=access_token or self.access_token, loop=self.loop,
client_session=self.http_client, log=self.log)
mxid = await new_client.whoami()
if mxid != self.id:
raise ValueError("MXID mismatch")
new_client.store = self.db_instance
self.stop_sync()
self.client = new_client
self.db_instance.homeserver = homeserver
self.db_instance.access_token = access_token
self.start_sync()
# region Properties
@property
@ -172,11 +219,6 @@ class Client:
def access_token(self) -> str:
return self.db_instance.access_token
@access_token.setter
def access_token(self, value: str) -> None:
self.client.api.token = value
self.db_instance.access_token = value
@property
def enabled(self) -> bool:
return self.db_instance.enabled
@ -189,25 +231,24 @@ class Client:
def next_batch(self) -> SyncToken:
return self.db_instance.next_batch
@next_batch.setter
def next_batch(self, value: SyncToken) -> None:
self.db_instance.next_batch = value
@property
def filter_id(self) -> FilterID:
return self.db_instance.filter_id
@filter_id.setter
def filter_id(self, value: FilterID) -> None:
self.db_instance.filter_id = value
@property
def sync(self) -> bool:
return self.db_instance.sync
@sync.setter
def sync(self, value: bool) -> None:
if value == self.db_instance.sync:
return
self.db_instance.sync = value
if self.started:
if value:
self.start_sync()
else:
self.stop_sync()
@property
def autojoin(self) -> bool:
@ -227,18 +268,10 @@ class Client:
def displayname(self) -> str:
return self.db_instance.displayname
@displayname.setter
def displayname(self, value: str) -> None:
self.db_instance.displayname = value
@property
def avatar_url(self) -> ContentURI:
return self.db_instance.avatar_url
@avatar_url.setter
def avatar_url(self, value: ContentURI) -> None:
self.db_instance.avatar_url = value
# endregion

View File

@ -17,15 +17,20 @@ from json import JSONDecodeError
from aiohttp import web
from mautrix.types import UserID
from mautrix.types import UserID, SyncToken, FilterID
from mautrix.errors import MatrixRequestError, MatrixInvalidToken
from mautrix.client import Client as MatrixClient
from ...db import DBClient
from ...client import Client
from .base import routes
from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
from .responses import (RespDeleted, ErrClientNotFound, ErrBodyNotJSON, ErrClientInUse,
ErrBadClientAccessToken, ErrBadClientAccessDetails, ErrMXIDMismatch,
ErrUserExists)
@routes.get("/clients")
async def get_clients(request: web.Request) -> web.Response:
async def get_clients(_: web.Request) -> web.Response:
return web.json_response([client.to_dict() for client in Client.cache.values()])
@ -39,17 +44,59 @@ async def get_client(request: web.Request) -> web.Response:
async def create_client(user_id: UserID, data: dict) -> web.Response:
return ErrNotImplemented
homeserver = data.get("homeserver", None)
access_token = data.get("access_token", None)
new_client = MatrixClient(base_url=homeserver, token=access_token, loop=Client.loop,
client_session=Client.http_client)
try:
mxid = await new_client.whoami()
except MatrixInvalidToken:
return ErrBadClientAccessToken
except MatrixRequestError:
return ErrBadClientAccessDetails
if user_id == "new":
existing_client = Client.get(mxid, None)
if existing_client is not None:
return ErrUserExists
elif mxid != user_id:
return ErrMXIDMismatch
db_instance = DBClient(id=user_id, homeserver=homeserver, access_token=access_token,
enabled=data.get("enabled", True), next_batch=SyncToken(""),
filter_id=FilterID(""), sync=data.get("sync", True),
autojoin=data.get("autojoin", True),
displayname=data.get("displayname", ""),
avatar_url=data.get("avatar_url", ""))
client = Client(db_instance)
Client.db.add(db_instance)
Client.db.commit()
await client.start()
return web.json_response(client.to_dict())
async def update_client(client: Client, data: dict) -> web.Response:
return ErrNotImplemented
try:
await client.update_access_details(data.get("access_token", None),
data.get("homeserver", None))
except MatrixInvalidToken:
return ErrBadClientAccessToken
except MatrixRequestError:
return ErrBadClientAccessDetails
except ValueError:
return ErrMXIDMismatch
await client.update_avatar_url(data.get("avatar_url", None))
await client.update_displayname(data.get("displayname", None))
await client.update_started(data.get("started", None))
client.enabled = data.get("enabled", client.enabled)
client.autojoin = data.get("autojoin", client.autojoin)
client.sync = data.get("sync", client.sync)
return web.json_response(client.to_dict())
@routes.put("/client/{id}")
async def update_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
# /client/new always creates a new client
client = Client.get(user_id, None) if user_id != "new" else None
try:
data = await request.json()
except JSONDecodeError:
@ -66,4 +113,9 @@ async def delete_client(request: web.Request) -> web.Response:
client = Client.get(user_id, None)
if not client:
return ErrClientNotFound
return ErrNotImplemented
if len(client.references) > 0:
return ErrClientInUse
if client.started:
await client.stop()
client.delete()
return RespDeleted

View File

@ -16,6 +16,36 @@
from http import HTTPStatus
from aiohttp import web
ErrBodyNotJSON = web.json_response({
"error": "Request body is not JSON",
"errcode": "body_not_json",
}, status=HTTPStatus.BAD_REQUEST)
ErrPluginTypeRequired = web.json_response({
"error": "Plugin type is required when creating plugin instances",
"errcode": "plugin_type_required",
}, status=HTTPStatus.BAD_REQUEST)
ErrPrimaryUserRequired = web.json_response({
"error": "Primary user is required when creating plugin instances",
"errcode": "primary_user_required",
}, status=HTTPStatus.BAD_REQUEST)
ErrBadClientAccessToken = web.json_response({
"error": "Invalid access token",
"errcode": "bad_client_access_token",
}, status=HTTPStatus.BAD_REQUEST)
ErrBadClientAccessDetails = web.json_response({
"error": "Invalid homeserver or access token",
"errcode": "bad_client_access_details"
}, status=HTTPStatus.BAD_REQUEST)
ErrMXIDMismatch = web.json_response({
"error": "The Matrix user ID of the client and the user ID of the access token don't match",
"errcode": "mxid_mismatch",
}, status=HTTPStatus.BAD_REQUEST)
ErrBadAuth = web.json_response({
"error": "Invalid username or password",
"errcode": "invalid_auth",
@ -56,16 +86,6 @@ ErrPluginTypeNotFound = web.json_response({
"errcode": "plugin_type_not_found",
}, status=HTTPStatus.NOT_FOUND)
ErrPluginTypeRequired = web.json_response({
"error": "Plugin type is required when creating plugin instances",
"errcode": "plugin_type_required",
}, status=HTTPStatus.BAD_REQUEST)
ErrPrimaryUserRequired = web.json_response({
"error": "Primary user is required when creating plugin instances",
"errcode": "primary_user_required",
}, status=HTTPStatus.BAD_REQUEST)
ErrPathNotFound = web.json_response({
"error": "Resource not found",
"errcode": "resource_not_found",
@ -76,15 +96,20 @@ ErrMethodNotAllowed = web.json_response({
"errcode": "method_not_allowed",
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
ErrUserExists = web.json_response({
"error": "There is already a client with the user ID of that token",
"errcode": "user_exists",
}, status=HTTPStatus.CONFLICT)
ErrPluginInUse = web.json_response({
"error": "Plugin instances of this type still exist",
"errcode": "plugin_in_use",
}, status=HTTPStatus.PRECONDITION_FAILED)
ErrBodyNotJSON = web.json_response({
"error": "Request body is not JSON",
"errcode": "body_not_json",
}, status=HTTPStatus.BAD_REQUEST)
ErrClientInUse = web.json_response({
"error": "Plugin instances with this client as their primary user still exist",
"errcode": "client_in_use",
}, status=HTTPStatus.PRECONDITION_FAILED)
def plugin_import_error(error: str, stacktrace: str) -> web.Response: