Refactor how plugins are started and update spec

This commit is contained in:
Tulir Asokan 2018-11-01 01:51:54 +02:00
parent b96d6e6a94
commit 9e066478a9
10 changed files with 160 additions and 79 deletions

View File

@ -57,7 +57,7 @@ Base.metadata.create_all()
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
init_db(db_session) init_db(db_session)
init_client(loop) clients = init_client(loop)
init_plugin_instance_class(db_session, config) init_plugin_instance_class(db_session, config)
management_api = init_management(config, loop) management_api = init_management(config, loop)
server = MaubotServer(config, management_api, loop) server = MaubotServer(config, management_api, loop)
@ -84,9 +84,10 @@ async def periodic_commit():
try: try:
loop.run_until_complete(asyncio.gather( log.debug("Starting server")
server.start(), loop.run_until_complete(server.start())
*[plugin.start() for plugin in plugins])) log.debug("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
log.debug("Startup actions complete, running forever") log.debug("Startup actions complete, running forever")
loop.run_until_complete(periodic_commit()) loop.run_until_complete(periodic_commit())
loop.run_forever() loop.run_forever()

View File

@ -1 +1 @@
__version__ = "0.1.0.dev5" __version__ = "0.1.0.dev6"

View File

@ -18,6 +18,7 @@ from aiohttp import ClientSession
import asyncio import asyncio
import logging import logging
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership, from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
EventType, Filter, RoomFilter, RoomEventFilter) EventType, Filter, RoomFilter, RoomEventFilter)
@ -31,6 +32,7 @@ log = logging.getLogger("maubot.client")
class Client: class Client:
log: logging.Logger
loop: asyncio.AbstractEventLoop loop: asyncio.AbstractEventLoop
cache: Dict[UserID, 'Client'] = {} cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None http_client: ClientSession = None
@ -38,23 +40,54 @@ class Client:
references: Set['PluginInstance'] references: Set['PluginInstance']
db_instance: DBClient db_instance: DBClient
client: MaubotMatrixClient client: MaubotMatrixClient
started: bool
def __init__(self, db_instance: DBClient) -> None: def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance self.db_instance = db_instance
self.cache[self.id] = self self.cache[self.id] = self
self.log = log.getChild(self.id) self.log = log.getChild(self.id)
self.references = set() self.references = set()
self.started = False
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver, self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client, token=self.access_token, client_session=self.http_client,
log=self.log, loop=self.loop, store=self.db_instance) log=self.log, loop=self.loop, store=self.db_instance)
if self.autojoin: if self.autojoin:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER) self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
def start(self) -> None: async def start(self, try_n: Optional[int] = 0) -> None:
asyncio.ensure_future(self._start(), loop=self.loop)
async def _start(self) -> None:
try: try:
if try_n > 0:
await asyncio.sleep(try_n * 10)
await self._start(try_n)
except Exception:
self.log.exception("Failed to start")
async def _start(self, try_n: Optional[int] = 0) -> None:
if not self.enabled:
self.log.debug("Not starting disabled client")
return
elif self.started:
self.log.warning("Ignoring start() call to started client")
return
try:
user_id = await self.client.whoami()
except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client")
self.enabled = False
return
except MatrixRequestError:
if try_n >= 5:
self.log.exception("Failed to get /account/whoami, disabling client")
self.enabled = False
else:
self.log.exception(f"Failed to get /account/whoami, "
f"retrying in {(try_n + 1) * 10}s")
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
return
if user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
self.enabled = False
return
if not self.filter_id: if not self.filter_id:
self.filter_id = await self.client.create_filter(Filter( self.filter_id = await self.client.create_filter(Filter(
room=RoomFilter( room=RoomFilter(
@ -67,13 +100,37 @@ class Client:
await self.client.set_displayname(self.displayname) await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable": if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url) await self.client.set_avatar_url(self.avatar_url)
await self.client.start(self.filter_id) if self.sync:
except Exception: self.client.start(self.filter_id)
self.log.exception("starting raised exception") self.started = True
self.log.info("Client started, starting plugin instances...")
await self.start_plugins()
async def start_plugins(self) -> None:
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
async def stop_plugins(self) -> None:
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running],
loop=self.loop)
def stop(self) -> None: def stop(self) -> None:
self.started = False
self.client.stop() self.client.stop()
def to_dict(self) -> dict:
return {
"id": self.id,
"homeserver": self.homeserver,
"access_token": self.access_token,
"enabled": self.enabled,
"started": self.started,
"sync": self.sync,
"autojoin": self.autojoin,
"displayname": self.displayname,
"avatar_url": self.avatar_url,
"instances": [instance.to_dict() for instance in self.references],
}
@classmethod @classmethod
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']: def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
try: try:
@ -111,6 +168,14 @@ class Client:
self.client.api.token = value self.client.api.token = value
self.db_instance.access_token = value self.db_instance.access_token = value
@property
def enabled(self) -> bool:
return self.db_instance.enabled
@enabled.setter
def enabled(self, value: bool) -> None:
self.db_instance.enabled = value
@property @property
def next_batch(self) -> SyncToken: def next_batch(self) -> SyncToken:
return self.db_instance.next_batch return self.db_instance.next_batch
@ -168,8 +233,7 @@ class Client:
# endregion # endregion
def init(loop: asyncio.AbstractEventLoop) -> None: def init(loop: asyncio.AbstractEventLoop) -> List[Client]:
Client.http_client = ClientSession(loop=loop) Client.http_client = ClientSession(loop=loop)
Client.loop = loop Client.loop = loop
for client in Client.all(): return Client.all()
client.start()

View File

@ -42,6 +42,7 @@ class DBClient(Base):
id: UserID = Column(String(255), primary_key=True) id: UserID = Column(String(255), primary_key=True)
homeserver: str = Column(String(255), nullable=False) homeserver: str = Column(String(255), nullable=False)
access_token: str = Column(String(255), nullable=False) access_token: str = Column(String(255), nullable=False)
enabled: bool = Column(Boolean, nullable=False, default=False)
next_batch: SyncToken = Column(String(255), nullable=False, default="") next_batch: SyncToken = Column(String(255), nullable=False, default="")
filter_id: FilterID = Column(String(255), nullable=False, default="") filter_id: FilterID = Column(String(255), nullable=False, default="")

View File

@ -75,6 +75,7 @@ class PluginInstance:
if not self.client: if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}") self.log.error(f"Failed to get client for user {self.primary_user}")
self.enabled = False self.enabled = False
return
self.log.debug("Plugin instance dependencies loaded") self.log.debug("Plugin instance dependencies loaded")
self.loader.references.add(self) self.loader.references.add(self)
self.client.references.add(self) self.client.references.add(self)
@ -93,8 +94,11 @@ class PluginInstance:
self.db_instance.config = buf.getvalue() self.db_instance.config = buf.getvalue()
async def start(self) -> None: async def start(self) -> None:
if not self.enabled: if self.running:
self.log.warning(f"Plugin disabled, not starting.") self.log.warning("Ignoring start() call to already started plugin")
return
elif not self.enabled:
self.log.warning("Plugin disabled, not starting.")
return return
cls = await self.loader.load() cls = await self.loader.load()
config_class = cls.get_config_class() config_class = cls.get_config_class()
@ -118,6 +122,9 @@ class PluginInstance:
f"with user {self.client.id}") f"with user {self.client.id}")
async def stop(self) -> None: async def stop(self) -> None:
if not self.running:
self.log.warning("Ignoring stop() call to non-running plugin")
return
self.log.debug("Stopping plugin instance...") self.log.debug("Stopping plugin instance...")
self.running = False self.running = False
await self.plugin.stop() await self.plugin.stop()

View File

@ -47,6 +47,7 @@ class PluginLoader(ABC):
return { return {
"id": self.id, "id": self.id,
"version": self.version, "version": self.version,
"instances": [instance.to_dict() for instance in self.references],
} }
@property @property

View File

@ -14,26 +14,56 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web from aiohttp import web
from json import JSONDecodeError
from mautrix.types import UserID
from ...db import DBClient
from ...client import Client
from .base import routes from .base import routes
from .responses import ErrNotImplemented from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
@routes.get("/instances") @routes.get("/instances")
def get_instances(request: web.Request) -> web.Response: async def get_instances(_: web.Request) -> web.Response:
return ErrNotImplemented return web.json_response([client.to_dict() for client in Client.cache.values()])
@routes.get("/instance/{id}") @routes.get("/instance/{id}")
def get_instance(request: web.Request) -> web.Response: async def get_instance(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
if not client:
return ErrClientNotFound
return web.json_response(client.to_dict())
async def create_instance(user_id: UserID, data: dict) -> web.Response:
return ErrNotImplemented
async def update_instance(client: Client, data: dict) -> web.Response:
return ErrNotImplemented return ErrNotImplemented
@routes.put("/instance/{id}") @routes.put("/instance/{id}")
def update_instance(request: web.Request) -> web.Response: async def update_instance(request: web.Request) -> web.Response:
return ErrNotImplemented user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
try:
data = await request.json()
except JSONDecodeError:
return ErrBodyNotJSON
if not client:
return await create_instance(user_id, data)
else:
return await update_instance(client, data)
@routes.delete("/instance/{id}") @routes.delete("/instance/{id}")
def delete_instance(request: web.Request) -> web.Response: async def delete_instance(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
if not client:
return ErrClientNotFound
return ErrNotImplemented return ErrNotImplemented

View File

@ -24,16 +24,9 @@ from .responses import (ErrPluginNotFound, ErrPluginInUse, plugin_import_error,
from .base import routes, get_config from .base import routes, get_config
def _plugin_to_dict(plugin: PluginLoader) -> dict:
return {
**plugin.to_dict(),
"instances": [instance.to_dict() for instance in plugin.references]
}
@routes.get("/plugins") @routes.get("/plugins")
async def get_plugins(_) -> web.Response: async def get_plugins(_) -> web.Response:
return web.json_response([_plugin_to_dict(plugin) for plugin in PluginLoader.id_cache.values()]) return web.json_response([plugin.to_dict() for plugin in PluginLoader.id_cache.values()])
@routes.get("/plugin/{id}") @routes.get("/plugin/{id}")
@ -42,7 +35,7 @@ async def get_plugin(request: web.Request) -> web.Response:
plugin = PluginLoader.id_cache.get(plugin_id, None) plugin = PluginLoader.id_cache.get(plugin_id, None)
if not plugin: if not plugin:
return ErrPluginNotFound return ErrPluginNotFound
return web.json_response(_plugin_to_dict(plugin)) return web.json_response(plugin.to_dict())
@routes.delete("/plugin/{id}") @routes.delete("/plugin/{id}")
@ -78,11 +71,11 @@ async def upload_new_plugin(content: bytes, pid: str, version: str) -> web.Respo
with open(path, "wb") as p: with open(path, "wb") as p:
p.write(content) p.write(content)
try: try:
ZippedPluginLoader.get(path) plugin = ZippedPluginLoader.get(path)
except MaubotZipImportError as e: except MaubotZipImportError as e:
ZippedPluginLoader.trash(path) ZippedPluginLoader.trash(path)
return plugin_import_error(str(e), traceback.format_exc()) return plugin_import_error(str(e), traceback.format_exc())
return RespOK return web.json_response(plugin.to_dict())
async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, new_version: str async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, new_version: str
@ -110,7 +103,7 @@ async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
return plugin_import_error(str(e), traceback.format_exc()) return plugin_import_error(str(e), traceback.format_exc())
await plugin.start_instances() await plugin.start_instances()
ZippedPluginLoader.trash(old_path, reason="update") ZippedPluginLoader.trash(old_path, reason="update")
return RespOK return web.json_response(plugin.to_dict())
@routes.post("/plugins/upload") @routes.post("/plugins/upload")

View File

@ -36,6 +36,11 @@ ErrPluginNotFound = web.json_response({
"errcode": "plugin_not_found", "errcode": "plugin_not_found",
}, status=HTTPStatus.NOT_FOUND) }, status=HTTPStatus.NOT_FOUND)
ErrClientNotFound = web.json_response({
"error": "Client not found",
"errcode": "client_not_found",
}, status=HTTPStatus.NOT_FOUND)
ErrPathNotFound = web.json_response({ ErrPathNotFound = web.json_response({
"error": "Resource not found", "error": "Resource not found",
"errcode": "resource_not_found", "errcode": "resource_not_found",

View File

@ -231,12 +231,14 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/MatrixClientList' type: array
items:
$ref: '#/components/schemas/MatrixClient'
401: 401:
$ref: '#/components/responses/Unauthorized' $ref: '#/components/responses/Unauthorized'
'/client/{user_id}': '/client/{id}':
parameters: parameters:
- name: user_id - name: id
in: path in: path
description: The Matrix user ID of the client to get description: The Matrix user ID of the client to get
required: true required: true
@ -338,38 +340,12 @@ components:
enabled: enabled:
type: boolean type: boolean
example: true example: true
started:
type: boolean
example: true
primary_user: primary_user:
type: string type: string
example: '@putkiteippi:maunium.net' example: '@putkiteippi:maunium.net'
MatrixClientList:
type: array
items:
type: object
properties:
id:
type: string
example: '@putkiteippi:maunium.net'
homeserver:
type: string
example: 'https://maunium.net'
enabled:
type: boolean
example: true
sync:
type: boolean
example: true
autojoin:
type: boolean
example: true
displayname:
type: string
example: J. E. Saarinen
avatar_url:
type: string
example: 'mxc://maunium.net/FsPQQTntCCqhJMFtwArmJdaU'
instance_count:
type: integer
example: 1
MatrixClient: MatrixClient:
type: object type: object
properties: properties:
@ -385,6 +361,9 @@ components:
enabled: enabled:
type: boolean type: boolean
example: true example: true
started:
type: boolean
example: true
sync: sync:
type: boolean type: boolean
example: true example: true