Implement client API
This commit is contained in:
parent
bc87b2a02b
commit
383c9ce5ec
@ -74,18 +74,21 @@ try:
|
|||||||
log.info("Starting server")
|
log.info("Starting server")
|
||||||
loop.run_until_complete(server.start())
|
loop.run_until_complete(server.start())
|
||||||
log.info("Starting clients and plugins")
|
log.info("Starting clients and plugins")
|
||||||
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
|
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop))
|
||||||
log.info("Startup actions complete, running forever")
|
log.info("Startup actions complete, running forever")
|
||||||
periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
|
periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
|
||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
log.debug("Interrupt received, stopping HTTP clients/servers and saving database")
|
log.info("Interrupt received, stopping HTTP clients/servers and saving database")
|
||||||
if periodic_commit_task is not None:
|
if periodic_commit_task is not None:
|
||||||
periodic_commit_task.cancel()
|
periodic_commit_task.cancel()
|
||||||
for client in Client.cache.values():
|
log.debug("Stopping clients")
|
||||||
client.stop()
|
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()],
|
||||||
|
loop=loop))
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
log.debug("Stopping server")
|
||||||
loop.run_until_complete(server.stop())
|
loop.run_until_complete(server.stop())
|
||||||
|
log.debug("Closing event loop")
|
||||||
loop.close()
|
loop.close()
|
||||||
log.debug("Everything stopped, shutting down")
|
log.debug("Everything stopped, shutting down")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -92,7 +92,7 @@ class Client:
|
|||||||
self.db_instance.enabled = False
|
self.db_instance.enabled = False
|
||||||
return
|
return
|
||||||
if not self.filter_id:
|
if not self.filter_id:
|
||||||
self.filter_id = await self.client.create_filter(Filter(
|
self.db_instance.filter_id = await self.client.create_filter(Filter(
|
||||||
room=RoomFilter(
|
room=RoomFilter(
|
||||||
timeline=RoomEventFilter(
|
timeline=RoomEventFilter(
|
||||||
limit=50,
|
limit=50,
|
||||||
@ -122,10 +122,19 @@ class Client:
|
|||||||
def stop_sync(self) -> None:
|
def stop_sync(self) -> None:
|
||||||
self.client.stop()
|
self.client.stop()
|
||||||
|
|
||||||
def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
|
if self.started:
|
||||||
self.started = False
|
self.started = False
|
||||||
|
await self.stop_plugins()
|
||||||
self.stop_sync()
|
self.stop_sync()
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
try:
|
||||||
|
del self.cache[self.id]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
self.db.delete(self.db_instance)
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
@ -158,6 +167,44 @@ class Client:
|
|||||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||||
await self.client.join_room(evt.room_id)
|
await self.client.join_room(evt.room_id)
|
||||||
|
|
||||||
|
async def update_started(self, started: bool) -> None:
|
||||||
|
if started is None or started == self.started:
|
||||||
|
return
|
||||||
|
if started:
|
||||||
|
await self.start()
|
||||||
|
else:
|
||||||
|
await self.stop()
|
||||||
|
|
||||||
|
async def update_displayname(self, displayname: str) -> None:
|
||||||
|
if not displayname or displayname == self.displayname:
|
||||||
|
return
|
||||||
|
self.db_instance.displayname = displayname
|
||||||
|
await self.client.set_displayname(self.displayname)
|
||||||
|
|
||||||
|
async def update_avatar_url(self, avatar_url: ContentURI) -> None:
|
||||||
|
if not avatar_url or avatar_url == self.avatar_url:
|
||||||
|
return
|
||||||
|
self.db_instance.avatar_url = avatar_url
|
||||||
|
await self.client.set_avatar_url(self.avatar_url)
|
||||||
|
|
||||||
|
async def update_access_details(self, access_token: str, homeserver: str) -> None:
|
||||||
|
if not access_token and not homeserver:
|
||||||
|
return
|
||||||
|
elif access_token == self.access_token and homeserver == self.homeserver:
|
||||||
|
return
|
||||||
|
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
|
||||||
|
token=access_token or self.access_token, loop=self.loop,
|
||||||
|
client_session=self.http_client, log=self.log)
|
||||||
|
mxid = await new_client.whoami()
|
||||||
|
if mxid != self.id:
|
||||||
|
raise ValueError("MXID mismatch")
|
||||||
|
new_client.store = self.db_instance
|
||||||
|
self.stop_sync()
|
||||||
|
self.client = new_client
|
||||||
|
self.db_instance.homeserver = homeserver
|
||||||
|
self.db_instance.access_token = access_token
|
||||||
|
self.start_sync()
|
||||||
|
|
||||||
# region Properties
|
# region Properties
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -172,11 +219,6 @@ class Client:
|
|||||||
def access_token(self) -> str:
|
def access_token(self) -> str:
|
||||||
return self.db_instance.access_token
|
return self.db_instance.access_token
|
||||||
|
|
||||||
@access_token.setter
|
|
||||||
def access_token(self, value: str) -> None:
|
|
||||||
self.client.api.token = value
|
|
||||||
self.db_instance.access_token = value
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enabled(self) -> bool:
|
def enabled(self) -> bool:
|
||||||
return self.db_instance.enabled
|
return self.db_instance.enabled
|
||||||
@ -189,25 +231,24 @@ class Client:
|
|||||||
def next_batch(self) -> SyncToken:
|
def next_batch(self) -> SyncToken:
|
||||||
return self.db_instance.next_batch
|
return self.db_instance.next_batch
|
||||||
|
|
||||||
@next_batch.setter
|
|
||||||
def next_batch(self, value: SyncToken) -> None:
|
|
||||||
self.db_instance.next_batch = value
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def filter_id(self) -> FilterID:
|
def filter_id(self) -> FilterID:
|
||||||
return self.db_instance.filter_id
|
return self.db_instance.filter_id
|
||||||
|
|
||||||
@filter_id.setter
|
|
||||||
def filter_id(self, value: FilterID) -> None:
|
|
||||||
self.db_instance.filter_id = value
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sync(self) -> bool:
|
def sync(self) -> bool:
|
||||||
return self.db_instance.sync
|
return self.db_instance.sync
|
||||||
|
|
||||||
@sync.setter
|
@sync.setter
|
||||||
def sync(self, value: bool) -> None:
|
def sync(self, value: bool) -> None:
|
||||||
|
if value == self.db_instance.sync:
|
||||||
|
return
|
||||||
self.db_instance.sync = value
|
self.db_instance.sync = value
|
||||||
|
if self.started:
|
||||||
|
if value:
|
||||||
|
self.start_sync()
|
||||||
|
else:
|
||||||
|
self.stop_sync()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def autojoin(self) -> bool:
|
def autojoin(self) -> bool:
|
||||||
@ -227,18 +268,10 @@ class Client:
|
|||||||
def displayname(self) -> str:
|
def displayname(self) -> str:
|
||||||
return self.db_instance.displayname
|
return self.db_instance.displayname
|
||||||
|
|
||||||
@displayname.setter
|
|
||||||
def displayname(self, value: str) -> None:
|
|
||||||
self.db_instance.displayname = value
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def avatar_url(self) -> ContentURI:
|
def avatar_url(self) -> ContentURI:
|
||||||
return self.db_instance.avatar_url
|
return self.db_instance.avatar_url
|
||||||
|
|
||||||
@avatar_url.setter
|
|
||||||
def avatar_url(self, value: ContentURI) -> None:
|
|
||||||
self.db_instance.avatar_url = value
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,15 +17,20 @@ from json import JSONDecodeError
|
|||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from mautrix.types import UserID
|
from mautrix.types import UserID, SyncToken, FilterID
|
||||||
|
from mautrix.errors import MatrixRequestError, MatrixInvalidToken
|
||||||
|
from mautrix.client import Client as MatrixClient
|
||||||
|
|
||||||
|
from ...db import DBClient
|
||||||
from ...client import Client
|
from ...client import Client
|
||||||
from .base import routes
|
from .base import routes
|
||||||
from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
|
from .responses import (RespDeleted, ErrClientNotFound, ErrBodyNotJSON, ErrClientInUse,
|
||||||
|
ErrBadClientAccessToken, ErrBadClientAccessDetails, ErrMXIDMismatch,
|
||||||
|
ErrUserExists)
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/clients")
|
@routes.get("/clients")
|
||||||
async def get_clients(request: web.Request) -> web.Response:
|
async def get_clients(_: web.Request) -> web.Response:
|
||||||
return web.json_response([client.to_dict() for client in Client.cache.values()])
|
return web.json_response([client.to_dict() for client in Client.cache.values()])
|
||||||
|
|
||||||
|
|
||||||
@ -39,17 +44,59 @@ async def get_client(request: web.Request) -> web.Response:
|
|||||||
|
|
||||||
|
|
||||||
async def create_client(user_id: UserID, data: dict) -> web.Response:
|
async def create_client(user_id: UserID, data: dict) -> web.Response:
|
||||||
return ErrNotImplemented
|
homeserver = data.get("homeserver", None)
|
||||||
|
access_token = data.get("access_token", None)
|
||||||
|
new_client = MatrixClient(base_url=homeserver, token=access_token, loop=Client.loop,
|
||||||
|
client_session=Client.http_client)
|
||||||
|
try:
|
||||||
|
mxid = await new_client.whoami()
|
||||||
|
except MatrixInvalidToken:
|
||||||
|
return ErrBadClientAccessToken
|
||||||
|
except MatrixRequestError:
|
||||||
|
return ErrBadClientAccessDetails
|
||||||
|
if user_id == "new":
|
||||||
|
existing_client = Client.get(mxid, None)
|
||||||
|
if existing_client is not None:
|
||||||
|
return ErrUserExists
|
||||||
|
elif mxid != user_id:
|
||||||
|
return ErrMXIDMismatch
|
||||||
|
db_instance = DBClient(id=user_id, homeserver=homeserver, access_token=access_token,
|
||||||
|
enabled=data.get("enabled", True), next_batch=SyncToken(""),
|
||||||
|
filter_id=FilterID(""), sync=data.get("sync", True),
|
||||||
|
autojoin=data.get("autojoin", True),
|
||||||
|
displayname=data.get("displayname", ""),
|
||||||
|
avatar_url=data.get("avatar_url", ""))
|
||||||
|
client = Client(db_instance)
|
||||||
|
Client.db.add(db_instance)
|
||||||
|
Client.db.commit()
|
||||||
|
await client.start()
|
||||||
|
return web.json_response(client.to_dict())
|
||||||
|
|
||||||
|
|
||||||
async def update_client(client: Client, data: dict) -> web.Response:
|
async def update_client(client: Client, data: dict) -> web.Response:
|
||||||
return ErrNotImplemented
|
try:
|
||||||
|
await client.update_access_details(data.get("access_token", None),
|
||||||
|
data.get("homeserver", None))
|
||||||
|
except MatrixInvalidToken:
|
||||||
|
return ErrBadClientAccessToken
|
||||||
|
except MatrixRequestError:
|
||||||
|
return ErrBadClientAccessDetails
|
||||||
|
except ValueError:
|
||||||
|
return ErrMXIDMismatch
|
||||||
|
await client.update_avatar_url(data.get("avatar_url", None))
|
||||||
|
await client.update_displayname(data.get("displayname", None))
|
||||||
|
await client.update_started(data.get("started", None))
|
||||||
|
client.enabled = data.get("enabled", client.enabled)
|
||||||
|
client.autojoin = data.get("autojoin", client.autojoin)
|
||||||
|
client.sync = data.get("sync", client.sync)
|
||||||
|
return web.json_response(client.to_dict())
|
||||||
|
|
||||||
|
|
||||||
@routes.put("/client/{id}")
|
@routes.put("/client/{id}")
|
||||||
async def update_client(request: web.Request) -> web.Response:
|
async def update_client(request: web.Request) -> web.Response:
|
||||||
user_id = request.match_info.get("id", None)
|
user_id = request.match_info.get("id", None)
|
||||||
client = Client.get(user_id, None)
|
# /client/new always creates a new client
|
||||||
|
client = Client.get(user_id, None) if user_id != "new" else None
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
@ -66,4 +113,9 @@ async def delete_client(request: web.Request) -> web.Response:
|
|||||||
client = Client.get(user_id, None)
|
client = Client.get(user_id, None)
|
||||||
if not client:
|
if not client:
|
||||||
return ErrClientNotFound
|
return ErrClientNotFound
|
||||||
return ErrNotImplemented
|
if len(client.references) > 0:
|
||||||
|
return ErrClientInUse
|
||||||
|
if client.started:
|
||||||
|
await client.stop()
|
||||||
|
client.delete()
|
||||||
|
return RespDeleted
|
||||||
|
@ -16,6 +16,36 @@
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
|
ErrBodyNotJSON = web.json_response({
|
||||||
|
"error": "Request body is not JSON",
|
||||||
|
"errcode": "body_not_json",
|
||||||
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
ErrPluginTypeRequired = web.json_response({
|
||||||
|
"error": "Plugin type is required when creating plugin instances",
|
||||||
|
"errcode": "plugin_type_required",
|
||||||
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
ErrPrimaryUserRequired = web.json_response({
|
||||||
|
"error": "Primary user is required when creating plugin instances",
|
||||||
|
"errcode": "primary_user_required",
|
||||||
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
ErrBadClientAccessToken = web.json_response({
|
||||||
|
"error": "Invalid access token",
|
||||||
|
"errcode": "bad_client_access_token",
|
||||||
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
ErrBadClientAccessDetails = web.json_response({
|
||||||
|
"error": "Invalid homeserver or access token",
|
||||||
|
"errcode": "bad_client_access_details"
|
||||||
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
ErrMXIDMismatch = web.json_response({
|
||||||
|
"error": "The Matrix user ID of the client and the user ID of the access token don't match",
|
||||||
|
"errcode": "mxid_mismatch",
|
||||||
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
ErrBadAuth = web.json_response({
|
ErrBadAuth = web.json_response({
|
||||||
"error": "Invalid username or password",
|
"error": "Invalid username or password",
|
||||||
"errcode": "invalid_auth",
|
"errcode": "invalid_auth",
|
||||||
@ -56,16 +86,6 @@ ErrPluginTypeNotFound = web.json_response({
|
|||||||
"errcode": "plugin_type_not_found",
|
"errcode": "plugin_type_not_found",
|
||||||
}, status=HTTPStatus.NOT_FOUND)
|
}, status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
ErrPluginTypeRequired = web.json_response({
|
|
||||||
"error": "Plugin type is required when creating plugin instances",
|
|
||||||
"errcode": "plugin_type_required",
|
|
||||||
}, status=HTTPStatus.BAD_REQUEST)
|
|
||||||
|
|
||||||
ErrPrimaryUserRequired = web.json_response({
|
|
||||||
"error": "Primary user is required when creating plugin instances",
|
|
||||||
"errcode": "primary_user_required",
|
|
||||||
}, status=HTTPStatus.BAD_REQUEST)
|
|
||||||
|
|
||||||
ErrPathNotFound = web.json_response({
|
ErrPathNotFound = web.json_response({
|
||||||
"error": "Resource not found",
|
"error": "Resource not found",
|
||||||
"errcode": "resource_not_found",
|
"errcode": "resource_not_found",
|
||||||
@ -76,15 +96,20 @@ ErrMethodNotAllowed = web.json_response({
|
|||||||
"errcode": "method_not_allowed",
|
"errcode": "method_not_allowed",
|
||||||
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
|
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
|
||||||
|
|
||||||
|
ErrUserExists = web.json_response({
|
||||||
|
"error": "There is already a client with the user ID of that token",
|
||||||
|
"errcode": "user_exists",
|
||||||
|
}, status=HTTPStatus.CONFLICT)
|
||||||
|
|
||||||
ErrPluginInUse = web.json_response({
|
ErrPluginInUse = web.json_response({
|
||||||
"error": "Plugin instances of this type still exist",
|
"error": "Plugin instances of this type still exist",
|
||||||
"errcode": "plugin_in_use",
|
"errcode": "plugin_in_use",
|
||||||
}, status=HTTPStatus.PRECONDITION_FAILED)
|
}, status=HTTPStatus.PRECONDITION_FAILED)
|
||||||
|
|
||||||
ErrBodyNotJSON = web.json_response({
|
ErrClientInUse = web.json_response({
|
||||||
"error": "Request body is not JSON",
|
"error": "Plugin instances with this client as their primary user still exist",
|
||||||
"errcode": "body_not_json",
|
"errcode": "client_in_use",
|
||||||
}, status=HTTPStatus.BAD_REQUEST)
|
}, status=HTTPStatus.PRECONDITION_FAILED)
|
||||||
|
|
||||||
|
|
||||||
def plugin_import_error(error: str, stacktrace: str) -> web.Response:
|
def plugin_import_error(error: str, stacktrace: str) -> web.Response:
|
||||||
|
Loading…
Reference in New Issue
Block a user