diff --git a/maubot/__main__.py b/maubot/__main__.py index 6cf22a9..5d5daec 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -57,7 +57,7 @@ Base.metadata.create_all() loop = asyncio.get_event_loop() init_db(db_session) -init_client(loop) +clients = init_client(loop) init_plugin_instance_class(db_session, config) management_api = init_management(config, loop) server = MaubotServer(config, management_api, loop) @@ -84,9 +84,10 @@ async def periodic_commit(): try: - loop.run_until_complete(asyncio.gather( - server.start(), - *[plugin.start() for plugin in plugins])) + log.debug("Starting server") + loop.run_until_complete(server.start()) + log.debug("Starting clients and plugins") + loop.run_until_complete(asyncio.gather(*[client.start() for client in clients])) log.debug("Startup actions complete, running forever") loop.run_until_complete(periodic_commit()) loop.run_forever() diff --git a/maubot/__meta__.py b/maubot/__meta__.py index 32cb26d..e305370 100644 --- a/maubot/__meta__.py +++ b/maubot/__meta__.py @@ -1 +1 @@ -__version__ = "0.1.0.dev5" +__version__ = "0.1.0.dev6" diff --git a/maubot/client.py b/maubot/client.py index e96c918..9f4fcc8 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -18,6 +18,7 @@ from aiohttp import ClientSession import asyncio import logging +from mautrix.errors import MatrixInvalidToken, MatrixRequestError from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership, EventType, Filter, RoomFilter, RoomEventFilter) @@ -31,6 +32,7 @@ log = logging.getLogger("maubot.client") class Client: + log: logging.Logger loop: asyncio.AbstractEventLoop cache: Dict[UserID, 'Client'] = {} http_client: ClientSession = None @@ -38,42 +40,97 @@ class Client: references: Set['PluginInstance'] db_instance: DBClient client: MaubotMatrixClient + started: bool def __init__(self, db_instance: DBClient) -> None: self.db_instance = db_instance self.cache[self.id] = self self.log = log.getChild(self.id) self.references = set() + self.started = False self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver, token=self.access_token, client_session=self.http_client, log=self.log, loop=self.loop, store=self.db_instance) if self.autojoin: self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER) - def start(self) -> None: - asyncio.ensure_future(self._start(), loop=self.loop) - - async def _start(self) -> None: + async def start(self, try_n: Optional[int] = 0) -> None: try: - if not self.filter_id: - self.filter_id = await self.client.create_filter(Filter( - room=RoomFilter( - timeline=RoomEventFilter( - limit=50, - ), - ), - )) - if self.displayname != "disable": - await self.client.set_displayname(self.displayname) - if self.avatar_url != "disable": - await self.client.set_avatar_url(self.avatar_url) - await self.client.start(self.filter_id) + if try_n > 0: + await asyncio.sleep(try_n * 10) + await self._start(try_n) except Exception: - self.log.exception("starting raised exception") + self.log.exception("Failed to start") + + async def _start(self, try_n: Optional[int] = 0) -> None: + if not self.enabled: + self.log.debug("Not starting disabled client") + return + elif self.started: + self.log.warning("Ignoring start() call to started client") + return + try: + user_id = await self.client.whoami() + except MatrixInvalidToken as e: + self.log.error(f"Invalid token: {e}. Disabling client") + self.enabled = False + return + except MatrixRequestError: + if try_n >= 5: + self.log.exception("Failed to get /account/whoami, disabling client") + self.enabled = False + else: + self.log.exception(f"Failed to get /account/whoami, " + f"retrying in {(try_n + 1) * 10}s") + _ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop) + return + if user_id != self.id: + self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}") + self.enabled = False + return + if not self.filter_id: + self.filter_id = await self.client.create_filter(Filter( + room=RoomFilter( + timeline=RoomEventFilter( + limit=50, + ), + ), + )) + if self.displayname != "disable": + await self.client.set_displayname(self.displayname) + if self.avatar_url != "disable": + await self.client.set_avatar_url(self.avatar_url) + if self.sync: + self.client.start(self.filter_id) + self.started = True + self.log.info("Client started, starting plugin instances...") + await self.start_plugins() + + async def start_plugins(self) -> None: + await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop) + + async def stop_plugins(self) -> None: + await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running], + loop=self.loop) def stop(self) -> None: + self.started = False self.client.stop() + def to_dict(self) -> dict: + return { + "id": self.id, + "homeserver": self.homeserver, + "access_token": self.access_token, + "enabled": self.enabled, + "started": self.started, + "sync": self.sync, + "autojoin": self.autojoin, + "displayname": self.displayname, + "avatar_url": self.avatar_url, + "instances": [instance.to_dict() for instance in self.references], + } + @classmethod def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']: try: @@ -111,6 +168,14 @@ class Client: self.client.api.token = value self.db_instance.access_token = value + @property + def enabled(self) -> bool: + return self.db_instance.enabled + + @enabled.setter + def enabled(self, value: bool) -> None: + self.db_instance.enabled = value + @property def next_batch(self) -> SyncToken: return self.db_instance.next_batch @@ -168,8 +233,7 @@ class Client: # endregion -def init(loop: asyncio.AbstractEventLoop) -> None: +def init(loop: asyncio.AbstractEventLoop) -> List[Client]: Client.http_client = ClientSession(loop=loop) Client.loop = loop - for client in Client.all(): - client.start() + return Client.all() diff --git a/maubot/db.py b/maubot/db.py index 6210bee..d53658c 100644 --- a/maubot/db.py +++ b/maubot/db.py @@ -42,6 +42,7 @@ class DBClient(Base): id: UserID = Column(String(255), primary_key=True) homeserver: str = Column(String(255), nullable=False) access_token: str = Column(String(255), nullable=False) + enabled: bool = Column(Boolean, nullable=False, default=False) next_batch: SyncToken = Column(String(255), nullable=False, default="") filter_id: FilterID = Column(String(255), nullable=False, default="") diff --git a/maubot/instance.py b/maubot/instance.py index 71098fc..83bb85c 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -75,6 +75,7 @@ class PluginInstance: if not self.client: self.log.error(f"Failed to get client for user {self.primary_user}") self.enabled = False + return self.log.debug("Plugin instance dependencies loaded") self.loader.references.add(self) self.client.references.add(self) @@ -93,8 +94,11 @@ class PluginInstance: self.db_instance.config = buf.getvalue() async def start(self) -> None: - if not self.enabled: - self.log.warning(f"Plugin disabled, not starting.") + if self.running: + self.log.warning("Ignoring start() call to already started plugin") + return + elif not self.enabled: + self.log.warning("Plugin disabled, not starting.") return cls = await self.loader.load() config_class = cls.get_config_class() @@ -118,6 +122,9 @@ class PluginInstance: f"with user {self.client.id}") async def stop(self) -> None: + if not self.running: + self.log.warning("Ignoring stop() call to non-running plugin") + return self.log.debug("Stopping plugin instance...") self.running = False await self.plugin.stop() diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py index 09d0f75..4dbb299 100644 --- a/maubot/loader/abc.py +++ b/maubot/loader/abc.py @@ -47,6 +47,7 @@ class PluginLoader(ABC): return { "id": self.id, "version": self.version, + "instances": [instance.to_dict() for instance in self.references], } @property diff --git a/maubot/management/api/instance.py b/maubot/management/api/instance.py index 8baa9f9..064c3f8 100644 --- a/maubot/management/api/instance.py +++ b/maubot/management/api/instance.py @@ -14,26 +14,56 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from aiohttp import web +from json import JSONDecodeError +from mautrix.types import UserID + +from ...db import DBClient +from ...client import Client from .base import routes -from .responses import ErrNotImplemented +from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON @routes.get("/instances") -def get_instances(request: web.Request) -> web.Response: - return ErrNotImplemented +async def get_instances(_: web.Request) -> web.Response: + return web.json_response([client.to_dict() for client in Client.cache.values()]) @routes.get("/instance/{id}") -def get_instance(request: web.Request) -> web.Response: +async def get_instance(request: web.Request) -> web.Response: + user_id = request.match_info.get("id", None) + client = Client.get(user_id, None) + if not client: + return ErrClientNotFound + return web.json_response(client.to_dict()) + + +async def create_instance(user_id: UserID, data: dict) -> web.Response: + return ErrNotImplemented + + +async def update_instance(client: Client, data: dict) -> web.Response: return ErrNotImplemented @routes.put("/instance/{id}") -def update_instance(request: web.Request) -> web.Response: - return ErrNotImplemented +async def update_instance(request: web.Request) -> web.Response: + user_id = request.match_info.get("id", None) + client = Client.get(user_id, None) + try: + data = await request.json() + except JSONDecodeError: + return ErrBodyNotJSON + if not client: + return await create_instance(user_id, data) + else: + return await update_instance(client, data) @routes.delete("/instance/{id}") -def delete_instance(request: web.Request) -> web.Response: +async def delete_instance(request: web.Request) -> web.Response: + user_id = request.match_info.get("id", None) + client = Client.get(user_id, None) + if not client: + return ErrClientNotFound return ErrNotImplemented diff --git a/maubot/management/api/plugin.py b/maubot/management/api/plugin.py index 4bdcbbd..d07030c 100644 --- a/maubot/management/api/plugin.py +++ b/maubot/management/api/plugin.py @@ -24,16 +24,9 @@ from .responses import (ErrPluginNotFound, ErrPluginInUse, plugin_import_error, from .base import routes, get_config -def _plugin_to_dict(plugin: PluginLoader) -> dict: - return { - **plugin.to_dict(), - "instances": [instance.to_dict() for instance in plugin.references] - } - - @routes.get("/plugins") async def get_plugins(_) -> web.Response: - return web.json_response([_plugin_to_dict(plugin) for plugin in PluginLoader.id_cache.values()]) + return web.json_response([plugin.to_dict() for plugin in PluginLoader.id_cache.values()]) @routes.get("/plugin/{id}") @@ -42,7 +35,7 @@ async def get_plugin(request: web.Request) -> web.Response: plugin = PluginLoader.id_cache.get(plugin_id, None) if not plugin: return ErrPluginNotFound - return web.json_response(_plugin_to_dict(plugin)) + return web.json_response(plugin.to_dict()) @routes.delete("/plugin/{id}") @@ -78,11 +71,11 @@ async def upload_new_plugin(content: bytes, pid: str, version: str) -> web.Respo with open(path, "wb") as p: p.write(content) try: - ZippedPluginLoader.get(path) + plugin = ZippedPluginLoader.get(path) except MaubotZipImportError as e: ZippedPluginLoader.trash(path) return plugin_import_error(str(e), traceback.format_exc()) - return RespOK + return web.json_response(plugin.to_dict()) async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, new_version: str @@ -110,7 +103,7 @@ async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, return plugin_import_error(str(e), traceback.format_exc()) await plugin.start_instances() ZippedPluginLoader.trash(old_path, reason="update") - return RespOK + return web.json_response(plugin.to_dict()) @routes.post("/plugins/upload") diff --git a/maubot/management/api/responses.py b/maubot/management/api/responses.py index 8b3d1c0..ad3033b 100644 --- a/maubot/management/api/responses.py +++ b/maubot/management/api/responses.py @@ -36,6 +36,11 @@ ErrPluginNotFound = web.json_response({ "errcode": "plugin_not_found", }, status=HTTPStatus.NOT_FOUND) +ErrClientNotFound = web.json_response({ + "error": "Client not found", + "errcode": "client_not_found", +}, status=HTTPStatus.NOT_FOUND) + ErrPathNotFound = web.json_response({ "error": "Resource not found", "errcode": "resource_not_found", diff --git a/maubot/management/api/spec.yaml b/maubot/management/api/spec.yaml index db9e483..8cfde21 100644 --- a/maubot/management/api/spec.yaml +++ b/maubot/management/api/spec.yaml @@ -231,12 +231,14 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/MatrixClientList' + type: array + items: + $ref: '#/components/schemas/MatrixClient' 401: $ref: '#/components/responses/Unauthorized' - '/client/{user_id}': + '/client/{id}': parameters: - - name: user_id + - name: id in: path description: The Matrix user ID of the client to get required: true @@ -338,38 +340,12 @@ components: enabled: type: boolean example: true + started: + type: boolean + example: true primary_user: type: string example: '@putkiteippi:maunium.net' - MatrixClientList: - type: array - items: - type: object - properties: - id: - type: string - example: '@putkiteippi:maunium.net' - homeserver: - type: string - example: 'https://maunium.net' - enabled: - type: boolean - example: true - sync: - type: boolean - example: true - autojoin: - type: boolean - example: true - displayname: - type: string - example: J. E. Saarinen - avatar_url: - type: string - example: 'mxc://maunium.net/FsPQQTntCCqhJMFtwArmJdaU' - instance_count: - type: integer - example: 1 MatrixClient: type: object properties: @@ -385,6 +361,9 @@ components: enabled: type: boolean example: true + started: + type: boolean + example: true sync: type: boolean example: true