Refactor how plugins are started and update spec
This commit is contained in:
parent
b96d6e6a94
commit
9e066478a9
@ -57,7 +57,7 @@ Base.metadata.create_all()
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
init_db(db_session)
|
init_db(db_session)
|
||||||
init_client(loop)
|
clients = init_client(loop)
|
||||||
init_plugin_instance_class(db_session, config)
|
init_plugin_instance_class(db_session, config)
|
||||||
management_api = init_management(config, loop)
|
management_api = init_management(config, loop)
|
||||||
server = MaubotServer(config, management_api, loop)
|
server = MaubotServer(config, management_api, loop)
|
||||||
@ -84,9 +84,10 @@ async def periodic_commit():
|
|||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(asyncio.gather(
|
log.debug("Starting server")
|
||||||
server.start(),
|
loop.run_until_complete(server.start())
|
||||||
*[plugin.start() for plugin in plugins]))
|
log.debug("Starting clients and plugins")
|
||||||
|
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
|
||||||
log.debug("Startup actions complete, running forever")
|
log.debug("Startup actions complete, running forever")
|
||||||
loop.run_until_complete(periodic_commit())
|
loop.run_until_complete(periodic_commit())
|
||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
|
@ -1 +1 @@
|
|||||||
__version__ = "0.1.0.dev5"
|
__version__ = "0.1.0.dev6"
|
||||||
|
@ -18,6 +18,7 @@ from aiohttp import ClientSession
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
||||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||||
EventType, Filter, RoomFilter, RoomEventFilter)
|
EventType, Filter, RoomFilter, RoomEventFilter)
|
||||||
|
|
||||||
@ -31,6 +32,7 @@ log = logging.getLogger("maubot.client")
|
|||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
|
log: logging.Logger
|
||||||
loop: asyncio.AbstractEventLoop
|
loop: asyncio.AbstractEventLoop
|
||||||
cache: Dict[UserID, 'Client'] = {}
|
cache: Dict[UserID, 'Client'] = {}
|
||||||
http_client: ClientSession = None
|
http_client: ClientSession = None
|
||||||
@ -38,23 +40,54 @@ class Client:
|
|||||||
references: Set['PluginInstance']
|
references: Set['PluginInstance']
|
||||||
db_instance: DBClient
|
db_instance: DBClient
|
||||||
client: MaubotMatrixClient
|
client: MaubotMatrixClient
|
||||||
|
started: bool
|
||||||
|
|
||||||
def __init__(self, db_instance: DBClient) -> None:
|
def __init__(self, db_instance: DBClient) -> None:
|
||||||
self.db_instance = db_instance
|
self.db_instance = db_instance
|
||||||
self.cache[self.id] = self
|
self.cache[self.id] = self
|
||||||
self.log = log.getChild(self.id)
|
self.log = log.getChild(self.id)
|
||||||
self.references = set()
|
self.references = set()
|
||||||
|
self.started = False
|
||||||
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
||||||
token=self.access_token, client_session=self.http_client,
|
token=self.access_token, client_session=self.http_client,
|
||||||
log=self.log, loop=self.loop, store=self.db_instance)
|
log=self.log, loop=self.loop, store=self.db_instance)
|
||||||
if self.autojoin:
|
if self.autojoin:
|
||||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||||
|
|
||||||
def start(self) -> None:
|
async def start(self, try_n: Optional[int] = 0) -> None:
|
||||||
asyncio.ensure_future(self._start(), loop=self.loop)
|
|
||||||
|
|
||||||
async def _start(self) -> None:
|
|
||||||
try:
|
try:
|
||||||
|
if try_n > 0:
|
||||||
|
await asyncio.sleep(try_n * 10)
|
||||||
|
await self._start(try_n)
|
||||||
|
except Exception:
|
||||||
|
self.log.exception("Failed to start")
|
||||||
|
|
||||||
|
async def _start(self, try_n: Optional[int] = 0) -> None:
|
||||||
|
if not self.enabled:
|
||||||
|
self.log.debug("Not starting disabled client")
|
||||||
|
return
|
||||||
|
elif self.started:
|
||||||
|
self.log.warning("Ignoring start() call to started client")
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
user_id = await self.client.whoami()
|
||||||
|
except MatrixInvalidToken as e:
|
||||||
|
self.log.error(f"Invalid token: {e}. Disabling client")
|
||||||
|
self.enabled = False
|
||||||
|
return
|
||||||
|
except MatrixRequestError:
|
||||||
|
if try_n >= 5:
|
||||||
|
self.log.exception("Failed to get /account/whoami, disabling client")
|
||||||
|
self.enabled = False
|
||||||
|
else:
|
||||||
|
self.log.exception(f"Failed to get /account/whoami, "
|
||||||
|
f"retrying in {(try_n + 1) * 10}s")
|
||||||
|
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
|
||||||
|
return
|
||||||
|
if user_id != self.id:
|
||||||
|
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
|
||||||
|
self.enabled = False
|
||||||
|
return
|
||||||
if not self.filter_id:
|
if not self.filter_id:
|
||||||
self.filter_id = await self.client.create_filter(Filter(
|
self.filter_id = await self.client.create_filter(Filter(
|
||||||
room=RoomFilter(
|
room=RoomFilter(
|
||||||
@ -67,13 +100,37 @@ class Client:
|
|||||||
await self.client.set_displayname(self.displayname)
|
await self.client.set_displayname(self.displayname)
|
||||||
if self.avatar_url != "disable":
|
if self.avatar_url != "disable":
|
||||||
await self.client.set_avatar_url(self.avatar_url)
|
await self.client.set_avatar_url(self.avatar_url)
|
||||||
await self.client.start(self.filter_id)
|
if self.sync:
|
||||||
except Exception:
|
self.client.start(self.filter_id)
|
||||||
self.log.exception("starting raised exception")
|
self.started = True
|
||||||
|
self.log.info("Client started, starting plugin instances...")
|
||||||
|
await self.start_plugins()
|
||||||
|
|
||||||
|
async def start_plugins(self) -> None:
|
||||||
|
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
|
||||||
|
|
||||||
|
async def stop_plugins(self) -> None:
|
||||||
|
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running],
|
||||||
|
loop=self.loop)
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
|
self.started = False
|
||||||
self.client.stop()
|
self.client.stop()
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"homeserver": self.homeserver,
|
||||||
|
"access_token": self.access_token,
|
||||||
|
"enabled": self.enabled,
|
||||||
|
"started": self.started,
|
||||||
|
"sync": self.sync,
|
||||||
|
"autojoin": self.autojoin,
|
||||||
|
"displayname": self.displayname,
|
||||||
|
"avatar_url": self.avatar_url,
|
||||||
|
"instances": [instance.to_dict() for instance in self.references],
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
|
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
|
||||||
try:
|
try:
|
||||||
@ -111,6 +168,14 @@ class Client:
|
|||||||
self.client.api.token = value
|
self.client.api.token = value
|
||||||
self.db_instance.access_token = value
|
self.db_instance.access_token = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled(self) -> bool:
|
||||||
|
return self.db_instance.enabled
|
||||||
|
|
||||||
|
@enabled.setter
|
||||||
|
def enabled(self, value: bool) -> None:
|
||||||
|
self.db_instance.enabled = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def next_batch(self) -> SyncToken:
|
def next_batch(self) -> SyncToken:
|
||||||
return self.db_instance.next_batch
|
return self.db_instance.next_batch
|
||||||
@ -168,8 +233,7 @@ class Client:
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
def init(loop: asyncio.AbstractEventLoop) -> None:
|
def init(loop: asyncio.AbstractEventLoop) -> List[Client]:
|
||||||
Client.http_client = ClientSession(loop=loop)
|
Client.http_client = ClientSession(loop=loop)
|
||||||
Client.loop = loop
|
Client.loop = loop
|
||||||
for client in Client.all():
|
return Client.all()
|
||||||
client.start()
|
|
||||||
|
@ -42,6 +42,7 @@ class DBClient(Base):
|
|||||||
id: UserID = Column(String(255), primary_key=True)
|
id: UserID = Column(String(255), primary_key=True)
|
||||||
homeserver: str = Column(String(255), nullable=False)
|
homeserver: str = Column(String(255), nullable=False)
|
||||||
access_token: str = Column(String(255), nullable=False)
|
access_token: str = Column(String(255), nullable=False)
|
||||||
|
enabled: bool = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
next_batch: SyncToken = Column(String(255), nullable=False, default="")
|
next_batch: SyncToken = Column(String(255), nullable=False, default="")
|
||||||
filter_id: FilterID = Column(String(255), nullable=False, default="")
|
filter_id: FilterID = Column(String(255), nullable=False, default="")
|
||||||
|
@ -75,6 +75,7 @@ class PluginInstance:
|
|||||||
if not self.client:
|
if not self.client:
|
||||||
self.log.error(f"Failed to get client for user {self.primary_user}")
|
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||||
self.enabled = False
|
self.enabled = False
|
||||||
|
return
|
||||||
self.log.debug("Plugin instance dependencies loaded")
|
self.log.debug("Plugin instance dependencies loaded")
|
||||||
self.loader.references.add(self)
|
self.loader.references.add(self)
|
||||||
self.client.references.add(self)
|
self.client.references.add(self)
|
||||||
@ -93,8 +94,11 @@ class PluginInstance:
|
|||||||
self.db_instance.config = buf.getvalue()
|
self.db_instance.config = buf.getvalue()
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if not self.enabled:
|
if self.running:
|
||||||
self.log.warning(f"Plugin disabled, not starting.")
|
self.log.warning("Ignoring start() call to already started plugin")
|
||||||
|
return
|
||||||
|
elif not self.enabled:
|
||||||
|
self.log.warning("Plugin disabled, not starting.")
|
||||||
return
|
return
|
||||||
cls = await self.loader.load()
|
cls = await self.loader.load()
|
||||||
config_class = cls.get_config_class()
|
config_class = cls.get_config_class()
|
||||||
@ -118,6 +122,9 @@ class PluginInstance:
|
|||||||
f"with user {self.client.id}")
|
f"with user {self.client.id}")
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
|
if not self.running:
|
||||||
|
self.log.warning("Ignoring stop() call to non-running plugin")
|
||||||
|
return
|
||||||
self.log.debug("Stopping plugin instance...")
|
self.log.debug("Stopping plugin instance...")
|
||||||
self.running = False
|
self.running = False
|
||||||
await self.plugin.stop()
|
await self.plugin.stop()
|
||||||
|
@ -47,6 +47,7 @@ class PluginLoader(ABC):
|
|||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"version": self.version,
|
"version": self.version,
|
||||||
|
"instances": [instance.to_dict() for instance in self.references],
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -14,26 +14,56 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
from json import JSONDecodeError
|
||||||
|
|
||||||
|
from mautrix.types import UserID
|
||||||
|
|
||||||
|
from ...db import DBClient
|
||||||
|
from ...client import Client
|
||||||
from .base import routes
|
from .base import routes
|
||||||
from .responses import ErrNotImplemented
|
from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/instances")
|
@routes.get("/instances")
|
||||||
def get_instances(request: web.Request) -> web.Response:
|
async def get_instances(_: web.Request) -> web.Response:
|
||||||
return ErrNotImplemented
|
return web.json_response([client.to_dict() for client in Client.cache.values()])
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/instance/{id}")
|
@routes.get("/instance/{id}")
|
||||||
def get_instance(request: web.Request) -> web.Response:
|
async def get_instance(request: web.Request) -> web.Response:
|
||||||
|
user_id = request.match_info.get("id", None)
|
||||||
|
client = Client.get(user_id, None)
|
||||||
|
if not client:
|
||||||
|
return ErrClientNotFound
|
||||||
|
return web.json_response(client.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
async def create_instance(user_id: UserID, data: dict) -> web.Response:
|
||||||
|
return ErrNotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
async def update_instance(client: Client, data: dict) -> web.Response:
|
||||||
return ErrNotImplemented
|
return ErrNotImplemented
|
||||||
|
|
||||||
|
|
||||||
@routes.put("/instance/{id}")
|
@routes.put("/instance/{id}")
|
||||||
def update_instance(request: web.Request) -> web.Response:
|
async def update_instance(request: web.Request) -> web.Response:
|
||||||
return ErrNotImplemented
|
user_id = request.match_info.get("id", None)
|
||||||
|
client = Client.get(user_id, None)
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
except JSONDecodeError:
|
||||||
|
return ErrBodyNotJSON
|
||||||
|
if not client:
|
||||||
|
return await create_instance(user_id, data)
|
||||||
|
else:
|
||||||
|
return await update_instance(client, data)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete("/instance/{id}")
|
@routes.delete("/instance/{id}")
|
||||||
def delete_instance(request: web.Request) -> web.Response:
|
async def delete_instance(request: web.Request) -> web.Response:
|
||||||
|
user_id = request.match_info.get("id", None)
|
||||||
|
client = Client.get(user_id, None)
|
||||||
|
if not client:
|
||||||
|
return ErrClientNotFound
|
||||||
return ErrNotImplemented
|
return ErrNotImplemented
|
||||||
|
@ -24,16 +24,9 @@ from .responses import (ErrPluginNotFound, ErrPluginInUse, plugin_import_error,
|
|||||||
from .base import routes, get_config
|
from .base import routes, get_config
|
||||||
|
|
||||||
|
|
||||||
def _plugin_to_dict(plugin: PluginLoader) -> dict:
|
|
||||||
return {
|
|
||||||
**plugin.to_dict(),
|
|
||||||
"instances": [instance.to_dict() for instance in plugin.references]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/plugins")
|
@routes.get("/plugins")
|
||||||
async def get_plugins(_) -> web.Response:
|
async def get_plugins(_) -> web.Response:
|
||||||
return web.json_response([_plugin_to_dict(plugin) for plugin in PluginLoader.id_cache.values()])
|
return web.json_response([plugin.to_dict() for plugin in PluginLoader.id_cache.values()])
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/plugin/{id}")
|
@routes.get("/plugin/{id}")
|
||||||
@ -42,7 +35,7 @@ async def get_plugin(request: web.Request) -> web.Response:
|
|||||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
return ErrPluginNotFound
|
return ErrPluginNotFound
|
||||||
return web.json_response(_plugin_to_dict(plugin))
|
return web.json_response(plugin.to_dict())
|
||||||
|
|
||||||
|
|
||||||
@routes.delete("/plugin/{id}")
|
@routes.delete("/plugin/{id}")
|
||||||
@ -78,11 +71,11 @@ async def upload_new_plugin(content: bytes, pid: str, version: str) -> web.Respo
|
|||||||
with open(path, "wb") as p:
|
with open(path, "wb") as p:
|
||||||
p.write(content)
|
p.write(content)
|
||||||
try:
|
try:
|
||||||
ZippedPluginLoader.get(path)
|
plugin = ZippedPluginLoader.get(path)
|
||||||
except MaubotZipImportError as e:
|
except MaubotZipImportError as e:
|
||||||
ZippedPluginLoader.trash(path)
|
ZippedPluginLoader.trash(path)
|
||||||
return plugin_import_error(str(e), traceback.format_exc())
|
return plugin_import_error(str(e), traceback.format_exc())
|
||||||
return RespOK
|
return web.json_response(plugin.to_dict())
|
||||||
|
|
||||||
|
|
||||||
async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, new_version: str
|
async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, new_version: str
|
||||||
@ -110,7 +103,7 @@ async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
|
|||||||
return plugin_import_error(str(e), traceback.format_exc())
|
return plugin_import_error(str(e), traceback.format_exc())
|
||||||
await plugin.start_instances()
|
await plugin.start_instances()
|
||||||
ZippedPluginLoader.trash(old_path, reason="update")
|
ZippedPluginLoader.trash(old_path, reason="update")
|
||||||
return RespOK
|
return web.json_response(plugin.to_dict())
|
||||||
|
|
||||||
|
|
||||||
@routes.post("/plugins/upload")
|
@routes.post("/plugins/upload")
|
||||||
|
@ -36,6 +36,11 @@ ErrPluginNotFound = web.json_response({
|
|||||||
"errcode": "plugin_not_found",
|
"errcode": "plugin_not_found",
|
||||||
}, status=HTTPStatus.NOT_FOUND)
|
}, status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
|
ErrClientNotFound = web.json_response({
|
||||||
|
"error": "Client not found",
|
||||||
|
"errcode": "client_not_found",
|
||||||
|
}, status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
ErrPathNotFound = web.json_response({
|
ErrPathNotFound = web.json_response({
|
||||||
"error": "Resource not found",
|
"error": "Resource not found",
|
||||||
"errcode": "resource_not_found",
|
"errcode": "resource_not_found",
|
||||||
|
@ -231,12 +231,14 @@ paths:
|
|||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/MatrixClientList'
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/MatrixClient'
|
||||||
401:
|
401:
|
||||||
$ref: '#/components/responses/Unauthorized'
|
$ref: '#/components/responses/Unauthorized'
|
||||||
'/client/{user_id}':
|
'/client/{id}':
|
||||||
parameters:
|
parameters:
|
||||||
- name: user_id
|
- name: id
|
||||||
in: path
|
in: path
|
||||||
description: The Matrix user ID of the client to get
|
description: The Matrix user ID of the client to get
|
||||||
required: true
|
required: true
|
||||||
@ -338,38 +340,12 @@ components:
|
|||||||
enabled:
|
enabled:
|
||||||
type: boolean
|
type: boolean
|
||||||
example: true
|
example: true
|
||||||
|
started:
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
primary_user:
|
primary_user:
|
||||||
type: string
|
type: string
|
||||||
example: '@putkiteippi:maunium.net'
|
example: '@putkiteippi:maunium.net'
|
||||||
MatrixClientList:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
id:
|
|
||||||
type: string
|
|
||||||
example: '@putkiteippi:maunium.net'
|
|
||||||
homeserver:
|
|
||||||
type: string
|
|
||||||
example: 'https://maunium.net'
|
|
||||||
enabled:
|
|
||||||
type: boolean
|
|
||||||
example: true
|
|
||||||
sync:
|
|
||||||
type: boolean
|
|
||||||
example: true
|
|
||||||
autojoin:
|
|
||||||
type: boolean
|
|
||||||
example: true
|
|
||||||
displayname:
|
|
||||||
type: string
|
|
||||||
example: J. E. Saarinen
|
|
||||||
avatar_url:
|
|
||||||
type: string
|
|
||||||
example: 'mxc://maunium.net/FsPQQTntCCqhJMFtwArmJdaU'
|
|
||||||
instance_count:
|
|
||||||
type: integer
|
|
||||||
example: 1
|
|
||||||
MatrixClient:
|
MatrixClient:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@ -385,6 +361,9 @@ components:
|
|||||||
enabled:
|
enabled:
|
||||||
type: boolean
|
type: boolean
|
||||||
example: true
|
example: true
|
||||||
|
started:
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
sync:
|
sync:
|
||||||
type: boolean
|
type: boolean
|
||||||
example: true
|
example: true
|
||||||
|
Loading…
Reference in New Issue
Block a user