Refactor how plugins are started and update spec

This commit is contained in:
Tulir Asokan 2018-11-01 01:51:54 +02:00
parent b96d6e6a94
commit 9e066478a9
10 changed files with 160 additions and 79 deletions

View File

@ -57,7 +57,7 @@ Base.metadata.create_all()
loop = asyncio.get_event_loop()
init_db(db_session)
init_client(loop)
clients = init_client(loop)
init_plugin_instance_class(db_session, config)
management_api = init_management(config, loop)
server = MaubotServer(config, management_api, loop)
@ -84,9 +84,10 @@ async def periodic_commit():
try:
loop.run_until_complete(asyncio.gather(
server.start(),
*[plugin.start() for plugin in plugins]))
log.debug("Starting server")
loop.run_until_complete(server.start())
log.debug("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
log.debug("Startup actions complete, running forever")
loop.run_until_complete(periodic_commit())
loop.run_forever()

View File

@ -1 +1 @@
__version__ = "0.1.0.dev5"
__version__ = "0.1.0.dev6"

View File

@ -18,6 +18,7 @@ from aiohttp import ClientSession
import asyncio
import logging
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
EventType, Filter, RoomFilter, RoomEventFilter)
@ -31,6 +32,7 @@ log = logging.getLogger("maubot.client")
class Client:
log: logging.Logger
loop: asyncio.AbstractEventLoop
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
@ -38,23 +40,54 @@ class Client:
references: Set['PluginInstance']
db_instance: DBClient
client: MaubotMatrixClient
started: bool
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
self.cache[self.id] = self
self.log = log.getChild(self.id)
self.references = set()
self.started = False
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client,
log=self.log, loop=self.loop, store=self.db_instance)
if self.autojoin:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
def start(self) -> None:
asyncio.ensure_future(self._start(), loop=self.loop)
async def _start(self) -> None:
async def start(self, try_n: Optional[int] = 0) -> None:
try:
if try_n > 0:
await asyncio.sleep(try_n * 10)
await self._start(try_n)
except Exception:
self.log.exception("Failed to start")
async def _start(self, try_n: Optional[int] = 0) -> None:
if not self.enabled:
self.log.debug("Not starting disabled client")
return
elif self.started:
self.log.warning("Ignoring start() call to started client")
return
try:
user_id = await self.client.whoami()
except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client")
self.enabled = False
return
except MatrixRequestError:
if try_n >= 5:
self.log.exception("Failed to get /account/whoami, disabling client")
self.enabled = False
else:
self.log.exception(f"Failed to get /account/whoami, "
f"retrying in {(try_n + 1) * 10}s")
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
return
if user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
self.enabled = False
return
if not self.filter_id:
self.filter_id = await self.client.create_filter(Filter(
room=RoomFilter(
@ -67,13 +100,37 @@ class Client:
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
await self.client.start(self.filter_id)
except Exception:
self.log.exception("starting raised exception")
if self.sync:
self.client.start(self.filter_id)
self.started = True
self.log.info("Client started, starting plugin instances...")
await self.start_plugins()
async def start_plugins(self) -> None:
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
async def stop_plugins(self) -> None:
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running],
loop=self.loop)
def stop(self) -> None:
self.started = False
self.client.stop()
def to_dict(self) -> dict:
return {
"id": self.id,
"homeserver": self.homeserver,
"access_token": self.access_token,
"enabled": self.enabled,
"started": self.started,
"sync": self.sync,
"autojoin": self.autojoin,
"displayname": self.displayname,
"avatar_url": self.avatar_url,
"instances": [instance.to_dict() for instance in self.references],
}
@classmethod
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
try:
@ -111,6 +168,14 @@ class Client:
self.client.api.token = value
self.db_instance.access_token = value
@property
def enabled(self) -> bool:
return self.db_instance.enabled
@enabled.setter
def enabled(self, value: bool) -> None:
self.db_instance.enabled = value
@property
def next_batch(self) -> SyncToken:
return self.db_instance.next_batch
@ -168,8 +233,7 @@ class Client:
# endregion
def init(loop: asyncio.AbstractEventLoop) -> None:
def init(loop: asyncio.AbstractEventLoop) -> List[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
for client in Client.all():
client.start()
return Client.all()

View File

@ -42,6 +42,7 @@ class DBClient(Base):
id: UserID = Column(String(255), primary_key=True)
homeserver: str = Column(String(255), nullable=False)
access_token: str = Column(String(255), nullable=False)
enabled: bool = Column(Boolean, nullable=False, default=False)
next_batch: SyncToken = Column(String(255), nullable=False, default="")
filter_id: FilterID = Column(String(255), nullable=False, default="")

View File

@ -75,6 +75,7 @@ class PluginInstance:
if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}")
self.enabled = False
return
self.log.debug("Plugin instance dependencies loaded")
self.loader.references.add(self)
self.client.references.add(self)
@ -93,8 +94,11 @@ class PluginInstance:
self.db_instance.config = buf.getvalue()
async def start(self) -> None:
if not self.enabled:
self.log.warning(f"Plugin disabled, not starting.")
if self.running:
self.log.warning("Ignoring start() call to already started plugin")
return
elif not self.enabled:
self.log.warning("Plugin disabled, not starting.")
return
cls = await self.loader.load()
config_class = cls.get_config_class()
@ -118,6 +122,9 @@ class PluginInstance:
f"with user {self.client.id}")
async def stop(self) -> None:
if not self.running:
self.log.warning("Ignoring stop() call to non-running plugin")
return
self.log.debug("Stopping plugin instance...")
self.running = False
await self.plugin.stop()

View File

@ -47,6 +47,7 @@ class PluginLoader(ABC):
return {
"id": self.id,
"version": self.version,
"instances": [instance.to_dict() for instance in self.references],
}
@property

View File

@ -14,26 +14,56 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from json import JSONDecodeError
from mautrix.types import UserID
from ...db import DBClient
from ...client import Client
from .base import routes
from .responses import ErrNotImplemented
from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
@routes.get("/instances")
def get_instances(request: web.Request) -> web.Response:
return ErrNotImplemented
async def get_instances(_: web.Request) -> web.Response:
return web.json_response([client.to_dict() for client in Client.cache.values()])
@routes.get("/instance/{id}")
def get_instance(request: web.Request) -> web.Response:
async def get_instance(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
if not client:
return ErrClientNotFound
return web.json_response(client.to_dict())
async def create_instance(user_id: UserID, data: dict) -> web.Response:
return ErrNotImplemented
async def update_instance(client: Client, data: dict) -> web.Response:
return ErrNotImplemented
@routes.put("/instance/{id}")
def update_instance(request: web.Request) -> web.Response:
return ErrNotImplemented
async def update_instance(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
try:
data = await request.json()
except JSONDecodeError:
return ErrBodyNotJSON
if not client:
return await create_instance(user_id, data)
else:
return await update_instance(client, data)
@routes.delete("/instance/{id}")
def delete_instance(request: web.Request) -> web.Response:
async def delete_instance(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
if not client:
return ErrClientNotFound
return ErrNotImplemented

View File

@ -24,16 +24,9 @@ from .responses import (ErrPluginNotFound, ErrPluginInUse, plugin_import_error,
from .base import routes, get_config
def _plugin_to_dict(plugin: PluginLoader) -> dict:
return {
**plugin.to_dict(),
"instances": [instance.to_dict() for instance in plugin.references]
}
@routes.get("/plugins")
async def get_plugins(_) -> web.Response:
return web.json_response([_plugin_to_dict(plugin) for plugin in PluginLoader.id_cache.values()])
return web.json_response([plugin.to_dict() for plugin in PluginLoader.id_cache.values()])
@routes.get("/plugin/{id}")
@ -42,7 +35,7 @@ async def get_plugin(request: web.Request) -> web.Response:
plugin = PluginLoader.id_cache.get(plugin_id, None)
if not plugin:
return ErrPluginNotFound
return web.json_response(_plugin_to_dict(plugin))
return web.json_response(plugin.to_dict())
@routes.delete("/plugin/{id}")
@ -78,11 +71,11 @@ async def upload_new_plugin(content: bytes, pid: str, version: str) -> web.Respo
with open(path, "wb") as p:
p.write(content)
try:
ZippedPluginLoader.get(path)
plugin = ZippedPluginLoader.get(path)
except MaubotZipImportError as e:
ZippedPluginLoader.trash(path)
return plugin_import_error(str(e), traceback.format_exc())
return RespOK
return web.json_response(plugin.to_dict())
async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, new_version: str
@ -110,7 +103,7 @@ async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
return plugin_import_error(str(e), traceback.format_exc())
await plugin.start_instances()
ZippedPluginLoader.trash(old_path, reason="update")
return RespOK
return web.json_response(plugin.to_dict())
@routes.post("/plugins/upload")

View File

@ -36,6 +36,11 @@ ErrPluginNotFound = web.json_response({
"errcode": "plugin_not_found",
}, status=HTTPStatus.NOT_FOUND)
ErrClientNotFound = web.json_response({
"error": "Client not found",
"errcode": "client_not_found",
}, status=HTTPStatus.NOT_FOUND)
ErrPathNotFound = web.json_response({
"error": "Resource not found",
"errcode": "resource_not_found",

View File

@ -231,12 +231,14 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClientList'
type: array
items:
$ref: '#/components/schemas/MatrixClient'
401:
$ref: '#/components/responses/Unauthorized'
'/client/{user_id}':
'/client/{id}':
parameters:
- name: user_id
- name: id
in: path
description: The Matrix user ID of the client to get
required: true
@ -338,38 +340,12 @@ components:
enabled:
type: boolean
example: true
started:
type: boolean
example: true
primary_user:
type: string
example: '@putkiteippi:maunium.net'
MatrixClientList:
type: array
items:
type: object
properties:
id:
type: string
example: '@putkiteippi:maunium.net'
homeserver:
type: string
example: 'https://maunium.net'
enabled:
type: boolean
example: true
sync:
type: boolean
example: true
autojoin:
type: boolean
example: true
displayname:
type: string
example: J. E. Saarinen
avatar_url:
type: string
example: 'mxc://maunium.net/FsPQQTntCCqhJMFtwArmJdaU'
instance_count:
type: integer
example: 1
MatrixClient:
type: object
properties:
@ -385,6 +361,9 @@ components:
enabled:
type: boolean
example: true
started:
type: boolean
example: true
sync:
type: boolean
example: true