Refactor things and implement instance API
This commit is contained in:
parent
cbeff0c0cb
commit
bc87b2a02b
@ -13,8 +13,6 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from sqlalchemy import orm
|
|
||||||
import sqlalchemy as sql
|
|
||||||
import logging.config
|
import logging.config
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -23,11 +21,11 @@ import copy
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .db import Base, init as init_db
|
from .db import init as init_db
|
||||||
from .server import MaubotServer
|
from .server import MaubotServer
|
||||||
from .client import Client, init as init_client
|
from .client import Client, init as init_client_class
|
||||||
from .loader import ZippedPluginLoader
|
from .loader.zip import init as init_zip_loader
|
||||||
from .instance import PluginInstance, init as init_plugin_instance_class
|
from .instance import init as init_plugin_instance_class
|
||||||
from .management.api import init as init_management
|
from .management.api import init as init_management
|
||||||
from .__meta__ import __version__
|
from .__meta__ import __version__
|
||||||
|
|
||||||
@ -46,57 +44,48 @@ config.update()
|
|||||||
|
|
||||||
logging.config.dictConfig(copy.deepcopy(config["logging"]))
|
logging.config.dictConfig(copy.deepcopy(config["logging"]))
|
||||||
log = logging.getLogger("maubot.init")
|
log = logging.getLogger("maubot.init")
|
||||||
log.debug(f"Initializing maubot {__version__}")
|
log.info(f"Initializing maubot {__version__}")
|
||||||
|
|
||||||
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
|
|
||||||
db_factory = orm.sessionmaker(bind=db_engine)
|
|
||||||
db_session = orm.scoping.scoped_session(db_factory)
|
|
||||||
Base.metadata.bind = db_engine
|
|
||||||
Base.metadata.create_all()
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
init_db(db_session)
|
init_zip_loader(config)
|
||||||
clients = init_client(loop)
|
db_session = init_db(config)
|
||||||
init_plugin_instance_class(db_session, config, loop)
|
clients = init_client_class(db_session, loop)
|
||||||
|
plugins = init_plugin_instance_class(db_session, config, loop)
|
||||||
management_api = init_management(config, loop)
|
management_api = init_management(config, loop)
|
||||||
server = MaubotServer(config, management_api, loop)
|
server = MaubotServer(config, management_api, loop)
|
||||||
|
|
||||||
ZippedPluginLoader.trash_path = config["plugin_directories.trash"]
|
|
||||||
ZippedPluginLoader.directories = config["plugin_directories.load"]
|
|
||||||
ZippedPluginLoader.load_all()
|
|
||||||
|
|
||||||
plugins = PluginInstance.all()
|
|
||||||
|
|
||||||
for plugin in plugins:
|
for plugin in plugins:
|
||||||
plugin.load()
|
plugin.load()
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, signal.default_int_handler)
|
signal.signal(signal.SIGINT, signal.default_int_handler)
|
||||||
signal.signal(signal.SIGTERM, signal.default_int_handler)
|
signal.signal(signal.SIGTERM, signal.default_int_handler)
|
||||||
|
|
||||||
stop = False
|
|
||||||
|
|
||||||
|
|
||||||
async def periodic_commit():
|
async def periodic_commit():
|
||||||
while not stop:
|
while True:
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
periodic_commit_task: asyncio.Future = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
log.debug("Starting server")
|
log.info("Starting server")
|
||||||
loop.run_until_complete(server.start())
|
loop.run_until_complete(server.start())
|
||||||
log.debug("Starting clients and plugins")
|
log.info("Starting clients and plugins")
|
||||||
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
|
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
|
||||||
log.debug("Startup actions complete, running forever")
|
log.info("Startup actions complete, running forever")
|
||||||
loop.run_until_complete(periodic_commit())
|
periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
|
||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
log.debug("Interrupt received, stopping HTTP clients/servers and saving database")
|
log.debug("Interrupt received, stopping HTTP clients/servers and saving database")
|
||||||
stop = True
|
if periodic_commit_task is not None:
|
||||||
|
periodic_commit_task.cancel()
|
||||||
for client in Client.cache.values():
|
for client in Client.cache.values():
|
||||||
client.stop()
|
client.stop()
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
loop.run_until_complete(server.stop())
|
loop.run_until_complete(server.stop())
|
||||||
|
loop.close()
|
||||||
log.debug("Everything stopped, shutting down")
|
log.debug("Everything stopped, shutting down")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -14,10 +14,12 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Dict, List, Optional, Set, TYPE_CHECKING
|
from typing import Dict, List, Optional, Set, TYPE_CHECKING
|
||||||
from aiohttp import ClientSession
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
||||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||||
EventType, Filter, RoomFilter, RoomEventFilter)
|
EventType, Filter, RoomFilter, RoomEventFilter)
|
||||||
@ -32,6 +34,7 @@ log = logging.getLogger("maubot.client")
|
|||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
|
db: Session = None
|
||||||
log: logging.Logger = None
|
log: logging.Logger = None
|
||||||
loop: asyncio.AbstractEventLoop = None
|
loop: asyncio.AbstractEventLoop = None
|
||||||
cache: Dict[UserID, 'Client'] = {}
|
cache: Dict[UserID, 'Client'] = {}
|
||||||
@ -73,12 +76,12 @@ class Client:
|
|||||||
user_id = await self.client.whoami()
|
user_id = await self.client.whoami()
|
||||||
except MatrixInvalidToken as e:
|
except MatrixInvalidToken as e:
|
||||||
self.log.error(f"Invalid token: {e}. Disabling client")
|
self.log.error(f"Invalid token: {e}. Disabling client")
|
||||||
self.enabled = False
|
self.db_instance.enabled = False
|
||||||
return
|
return
|
||||||
except MatrixRequestError:
|
except MatrixRequestError:
|
||||||
if try_n >= 5:
|
if try_n >= 5:
|
||||||
self.log.exception("Failed to get /account/whoami, disabling client")
|
self.log.exception("Failed to get /account/whoami, disabling client")
|
||||||
self.enabled = False
|
self.db_instance.enabled = False
|
||||||
else:
|
else:
|
||||||
self.log.exception(f"Failed to get /account/whoami, "
|
self.log.exception(f"Failed to get /account/whoami, "
|
||||||
f"retrying in {(try_n + 1) * 10}s")
|
f"retrying in {(try_n + 1) * 10}s")
|
||||||
@ -86,7 +89,7 @@ class Client:
|
|||||||
return
|
return
|
||||||
if user_id != self.id:
|
if user_id != self.id:
|
||||||
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
|
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
|
||||||
self.enabled = False
|
self.db_instance.enabled = False
|
||||||
return
|
return
|
||||||
if not self.filter_id:
|
if not self.filter_id:
|
||||||
self.filter_id = await self.client.create_filter(Filter(
|
self.filter_id = await self.client.create_filter(Filter(
|
||||||
@ -100,8 +103,7 @@ class Client:
|
|||||||
await self.client.set_displayname(self.displayname)
|
await self.client.set_displayname(self.displayname)
|
||||||
if self.avatar_url != "disable":
|
if self.avatar_url != "disable":
|
||||||
await self.client.set_avatar_url(self.avatar_url)
|
await self.client.set_avatar_url(self.avatar_url)
|
||||||
if self.sync:
|
self.start_sync()
|
||||||
self.client.start(self.filter_id)
|
|
||||||
self.started = True
|
self.started = True
|
||||||
self.log.info("Client started, starting plugin instances...")
|
self.log.info("Client started, starting plugin instances...")
|
||||||
await self.start_plugins()
|
await self.start_plugins()
|
||||||
@ -110,12 +112,19 @@ class Client:
|
|||||||
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
|
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
|
||||||
|
|
||||||
async def stop_plugins(self) -> None:
|
async def stop_plugins(self) -> None:
|
||||||
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running],
|
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.started],
|
||||||
loop=self.loop)
|
loop=self.loop)
|
||||||
|
|
||||||
|
def start_sync(self) -> None:
|
||||||
|
if self.sync:
|
||||||
|
self.client.start(self.filter_id)
|
||||||
|
|
||||||
|
def stop_sync(self) -> None:
|
||||||
|
self.client.stop()
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
self.started = False
|
self.started = False
|
||||||
self.client.stop()
|
self.stop_sync()
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
return {
|
return {
|
||||||
@ -233,7 +242,8 @@ class Client:
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
def init(loop: asyncio.AbstractEventLoop) -> List[Client]:
|
def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]:
|
||||||
|
Client.db = db
|
||||||
Client.http_client = ClientSession(loop=loop)
|
Client.http_client = ClientSession(loop=loop)
|
||||||
Client.loop = loop
|
Client.loop = loop
|
||||||
return Client.all()
|
return Client.all()
|
||||||
|
21
maubot/db.py
21
maubot/db.py
@ -13,12 +13,17 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
|
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
|
||||||
from sqlalchemy.orm import Query, scoped_session
|
from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
import sqlalchemy as sql
|
||||||
|
|
||||||
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
|
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
|
||||||
|
|
||||||
|
from .config import Config
|
||||||
|
|
||||||
Base: declarative_base = declarative_base()
|
Base: declarative_base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
@ -54,6 +59,14 @@ class DBClient(Base):
|
|||||||
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
|
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
|
||||||
|
|
||||||
|
|
||||||
def init(session: scoped_session) -> None:
|
def init(config: Config) -> Session:
|
||||||
DBPlugin.query = session.query_property()
|
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
|
||||||
DBClient.query = session.query_property()
|
db_factory = sessionmaker(bind=db_engine)
|
||||||
|
db_session = scoped_session(db_factory)
|
||||||
|
Base.metadata.bind = db_engine
|
||||||
|
Base.metadata.create_all()
|
||||||
|
|
||||||
|
DBPlugin.query = db_session.query_property()
|
||||||
|
DBClient.query = db_session.query_property()
|
||||||
|
|
||||||
|
return cast(Session, db_session)
|
||||||
|
@ -48,13 +48,14 @@ class PluginInstance:
|
|||||||
client: Client
|
client: Client
|
||||||
plugin: Plugin
|
plugin: Plugin
|
||||||
config: BaseProxyConfig
|
config: BaseProxyConfig
|
||||||
running: bool
|
base_cfg: RecursiveDict[CommentedMap]
|
||||||
|
started: bool
|
||||||
|
|
||||||
def __init__(self, db_instance: DBPlugin):
|
def __init__(self, db_instance: DBPlugin):
|
||||||
self.db_instance = db_instance
|
self.db_instance = db_instance
|
||||||
self.log = logging.getLogger(f"maubot.plugin.{self.id}")
|
self.log = logging.getLogger(f"maubot.plugin.{self.id}")
|
||||||
self.config = None
|
self.config = None
|
||||||
self.running = False
|
self.started = False
|
||||||
self.cache[self.id] = self
|
self.cache[self.id] = self
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
@ -62,7 +63,7 @@ class PluginInstance:
|
|||||||
"id": self.id,
|
"id": self.id,
|
||||||
"type": self.type,
|
"type": self.type,
|
||||||
"enabled": self.enabled,
|
"enabled": self.enabled,
|
||||||
"running": self.running,
|
"started": self.started,
|
||||||
"primary_user": self.primary_user,
|
"primary_user": self.primary_user,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,19 +72,26 @@ class PluginInstance:
|
|||||||
self.loader = PluginLoader.find(self.type)
|
self.loader = PluginLoader.find(self.type)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
self.log.error(f"Failed to find loader for type {self.type}")
|
self.log.error(f"Failed to find loader for type {self.type}")
|
||||||
self.enabled = False
|
self.db_instance.enabled = False
|
||||||
return
|
return
|
||||||
self.client = Client.get(self.primary_user)
|
self.client = Client.get(self.primary_user)
|
||||||
if not self.client:
|
if not self.client:
|
||||||
self.log.error(f"Failed to get client for user {self.primary_user}")
|
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||||
self.enabled = False
|
self.db_instance.enabled = False
|
||||||
return
|
return
|
||||||
self.log.debug("Plugin instance dependencies loaded")
|
self.log.debug("Plugin instance dependencies loaded")
|
||||||
self.loader.references.add(self)
|
self.loader.references.add(self)
|
||||||
self.client.references.add(self)
|
self.client.references.add(self)
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self) -> None:
|
||||||
|
if self.loader is not None:
|
||||||
self.loader.references.remove(self)
|
self.loader.references.remove(self)
|
||||||
|
if self.client is not None:
|
||||||
|
self.client.references.remove(self)
|
||||||
|
try:
|
||||||
|
del self.cache[self.id]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
self.db.delete(self.db_instance)
|
self.db.delete(self.db_instance)
|
||||||
# TODO delete plugin db
|
# TODO delete plugin db
|
||||||
|
|
||||||
@ -96,7 +104,7 @@ class PluginInstance:
|
|||||||
self.db_instance.config = buf.getvalue()
|
self.db_instance.config = buf.getvalue()
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self.running:
|
if self.started:
|
||||||
self.log.warning("Ignoring start() call to already started plugin")
|
self.log.warning("Ignoring start() call to already started plugin")
|
||||||
return
|
return
|
||||||
elif not self.enabled:
|
elif not self.enabled:
|
||||||
@ -107,28 +115,28 @@ class PluginInstance:
|
|||||||
if config_class:
|
if config_class:
|
||||||
try:
|
try:
|
||||||
base = await self.loader.read_file("base-config.yaml")
|
base = await self.loader.read_file("base-config.yaml")
|
||||||
base_file = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
|
self.base_cfg = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
|
||||||
except (FileNotFoundError, KeyError):
|
except (FileNotFoundError, KeyError):
|
||||||
base_file = None
|
self.base_cfg = None
|
||||||
self.config = config_class(self.load_config, lambda: base_file, self.save_config)
|
self.config = config_class(self.load_config, lambda: self.base_cfg, self.save_config)
|
||||||
self.plugin = cls(self.client.client, self.loop, self.client.http_client, self.id,
|
self.plugin = cls(self.client.client, self.loop, self.client.http_client, self.id,
|
||||||
self.log, self.config, self.mb_config["plugin_directories.db"])
|
self.log, self.config, self.mb_config["plugin_directories.db"])
|
||||||
try:
|
try:
|
||||||
await self.plugin.start()
|
await self.plugin.start()
|
||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Failed to start instance")
|
self.log.exception("Failed to start instance")
|
||||||
self.enabled = False
|
self.db_instance.enabled = False
|
||||||
return
|
return
|
||||||
self.running = True
|
self.started = True
|
||||||
self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} "
|
self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} "
|
||||||
f"with user {self.client.id}")
|
f"with user {self.client.id}")
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
if not self.running:
|
if not self.started:
|
||||||
self.log.warning("Ignoring stop() call to non-running plugin")
|
self.log.warning("Ignoring stop() call to non-running plugin")
|
||||||
return
|
return
|
||||||
self.log.debug("Stopping plugin instance...")
|
self.log.debug("Stopping plugin instance...")
|
||||||
self.running = False
|
self.started = False
|
||||||
try:
|
try:
|
||||||
await self.plugin.stop()
|
await self.plugin.stop()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -150,6 +158,37 @@ class PluginInstance:
|
|||||||
def all(cls) -> List['PluginInstance']:
|
def all(cls) -> List['PluginInstance']:
|
||||||
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
|
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
|
||||||
|
|
||||||
|
def update_id(self, new_id: str) -> None:
|
||||||
|
if new_id is not None and new_id != self.id:
|
||||||
|
self.db_instance.id = new_id
|
||||||
|
|
||||||
|
def update_config(self, config: str) -> None:
|
||||||
|
if not config or self.db_instance.config == config:
|
||||||
|
return
|
||||||
|
self.db_instance.config = config
|
||||||
|
if self.started and self.plugin is not None:
|
||||||
|
self.plugin.on_external_config_update()
|
||||||
|
|
||||||
|
async def update_primary_user(self, primary_user: UserID) -> bool:
|
||||||
|
client = Client.get(primary_user)
|
||||||
|
if not client:
|
||||||
|
return False
|
||||||
|
await self.stop()
|
||||||
|
self.db_instance.primary_user = client.id
|
||||||
|
self.client.references.remove(self)
|
||||||
|
self.client = client
|
||||||
|
await self.start()
|
||||||
|
self.log.debug(f"Primary user switched to {self.client.id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def update_started(self, started: bool) -> None:
|
||||||
|
if started is not None and started != self.started:
|
||||||
|
await (self.start() if started else self.stop())
|
||||||
|
|
||||||
|
def update_enabled(self, enabled: bool) -> None:
|
||||||
|
if enabled is not None and enabled != self.enabled:
|
||||||
|
self.db_instance.enabled = enabled
|
||||||
|
|
||||||
# region Properties
|
# region Properties
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -168,22 +207,15 @@ class PluginInstance:
|
|||||||
def enabled(self) -> bool:
|
def enabled(self) -> bool:
|
||||||
return self.db_instance.enabled
|
return self.db_instance.enabled
|
||||||
|
|
||||||
@enabled.setter
|
|
||||||
def enabled(self, value: bool) -> None:
|
|
||||||
self.db_instance.enabled = value
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def primary_user(self) -> UserID:
|
def primary_user(self) -> UserID:
|
||||||
return self.db_instance.primary_user
|
return self.db_instance.primary_user
|
||||||
|
|
||||||
@primary_user.setter
|
|
||||||
def primary_user(self, value: UserID) -> None:
|
|
||||||
self.db_instance.primary_user = value
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
def init(db: Session, config: Config, loop: AbstractEventLoop):
|
def init(db: Session, config: Config, loop: AbstractEventLoop) -> List[PluginInstance]:
|
||||||
PluginInstance.db = db
|
PluginInstance.db = db
|
||||||
PluginInstance.mb_config = config
|
PluginInstance.mb_config = config
|
||||||
PluginInstance.loop = loop
|
PluginInstance.loop = loop
|
||||||
|
return PluginInstance.all()
|
||||||
|
@ -61,7 +61,7 @@ class PluginLoader(ABC):
|
|||||||
|
|
||||||
async def stop_instances(self) -> None:
|
async def stop_instances(self) -> None:
|
||||||
await asyncio.gather(*[instance.stop() for instance
|
await asyncio.gather(*[instance.stop() for instance
|
||||||
in self.references if instance.running])
|
in self.references if instance.started])
|
||||||
|
|
||||||
async def start_instances(self) -> None:
|
async def start_instances(self) -> None:
|
||||||
await asyncio.gather(*[instance.start() for instance
|
await asyncio.gather(*[instance.start() for instance
|
||||||
|
@ -23,6 +23,7 @@ import os
|
|||||||
|
|
||||||
from ..lib.zipimport import zipimporter, ZipImportError
|
from ..lib.zipimport import zipimporter, ZipImportError
|
||||||
from ..plugin_base import Plugin
|
from ..plugin_base import Plugin
|
||||||
|
from ..config import Config
|
||||||
from .abc import PluginLoader, PluginClass, IDConflictError
|
from .abc import PluginLoader, PluginClass, IDConflictError
|
||||||
|
|
||||||
|
|
||||||
@ -264,3 +265,9 @@ class ZippedPluginLoader(PluginLoader):
|
|||||||
except IDConflictError:
|
except IDConflictError:
|
||||||
cls.log.error(f"Duplicate plugin ID at {path}, trashing...")
|
cls.log.error(f"Duplicate plugin ID at {path}, trashing...")
|
||||||
cls.trash(path)
|
cls.trash(path)
|
||||||
|
|
||||||
|
|
||||||
|
def init(config: Config) -> None:
|
||||||
|
ZippedPluginLoader.trash_path = config["plugin_directories.trash"]
|
||||||
|
ZippedPluginLoader.directories = config["plugin_directories.load"]
|
||||||
|
ZippedPluginLoader.load_all()
|
||||||
|
@ -21,6 +21,8 @@ from .base import routes, set_config
|
|||||||
from .middleware import auth, error
|
from .middleware import auth, error
|
||||||
from .auth import web as _
|
from .auth import web as _
|
||||||
from .plugin import web as _
|
from .plugin import web as _
|
||||||
|
from .instance import web as _
|
||||||
|
from .client import web as _
|
||||||
|
|
||||||
|
|
||||||
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
|
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
|
||||||
|
@ -13,27 +13,57 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from json import JSONDecodeError
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
|
from mautrix.types import UserID
|
||||||
|
|
||||||
|
from ...client import Client
|
||||||
from .base import routes
|
from .base import routes
|
||||||
from .responses import ErrNotImplemented
|
from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/clients")
|
@routes.get("/clients")
|
||||||
def get_clients(request: web.Request) -> web.Response:
|
async def get_clients(request: web.Request) -> web.Response:
|
||||||
return ErrNotImplemented
|
return web.json_response([client.to_dict() for client in Client.cache.values()])
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/client/{id}")
|
@routes.get("/client/{id}")
|
||||||
def get_client(request: web.Request) -> web.Response:
|
async def get_client(request: web.Request) -> web.Response:
|
||||||
|
user_id = request.match_info.get("id", None)
|
||||||
|
client = Client.get(user_id, None)
|
||||||
|
if not client:
|
||||||
|
return ErrClientNotFound
|
||||||
|
return web.json_response(client.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
async def create_client(user_id: UserID, data: dict) -> web.Response:
|
||||||
|
return ErrNotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
async def update_client(client: Client, data: dict) -> web.Response:
|
||||||
return ErrNotImplemented
|
return ErrNotImplemented
|
||||||
|
|
||||||
|
|
||||||
@routes.put("/client/{id}")
|
@routes.put("/client/{id}")
|
||||||
def update_client(request: web.Request) -> web.Response:
|
async def update_client(request: web.Request) -> web.Response:
|
||||||
return ErrNotImplemented
|
user_id = request.match_info.get("id", None)
|
||||||
|
client = Client.get(user_id, None)
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
except JSONDecodeError:
|
||||||
|
return ErrBodyNotJSON
|
||||||
|
if not client:
|
||||||
|
return await create_client(user_id, data)
|
||||||
|
else:
|
||||||
|
return await update_client(client, data)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete("/client/{id}")
|
@routes.delete("/client/{id}")
|
||||||
def delete_client(request: web.Request) -> web.Response:
|
async def delete_client(request: web.Request) -> web.Response:
|
||||||
|
user_id = request.match_info.get("id", None)
|
||||||
|
client = Client.get(user_id, None)
|
||||||
|
if not client:
|
||||||
|
return ErrClientNotFound
|
||||||
return ErrNotImplemented
|
return ErrNotImplemented
|
||||||
|
@ -13,57 +13,88 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from aiohttp import web
|
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
|
|
||||||
from mautrix.types import UserID
|
from aiohttp import web
|
||||||
|
|
||||||
from ...db import DBClient
|
from ...db import DBPlugin
|
||||||
|
from ...instance import PluginInstance
|
||||||
|
from ...loader import PluginLoader
|
||||||
from ...client import Client
|
from ...client import Client
|
||||||
from .base import routes
|
from .base import routes
|
||||||
from .responses import ErrNotImplemented, ErrClientNotFound, ErrBodyNotJSON
|
from .responses import (ErrInstanceNotFound, ErrBodyNotJSON, RespDeleted, ErrPrimaryUserNotFound,
|
||||||
|
ErrPluginTypeRequired, ErrPrimaryUserRequired, ErrPluginTypeNotFound)
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/instances")
|
@routes.get("/instances")
|
||||||
async def get_instances(_: web.Request) -> web.Response:
|
async def get_instances(_: web.Request) -> web.Response:
|
||||||
return web.json_response([client.to_dict() for client in Client.cache.values()])
|
return web.json_response([instance.to_dict() for instance in PluginInstance.cache.values()])
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/instance/{id}")
|
@routes.get("/instance/{id}")
|
||||||
async def get_instance(request: web.Request) -> web.Response:
|
async def get_instance(request: web.Request) -> web.Response:
|
||||||
user_id = request.match_info.get("id", None)
|
instance_id = request.match_info.get("id", "").lower()
|
||||||
client = Client.get(user_id, None)
|
instance = PluginInstance.get(instance_id, None)
|
||||||
if not client:
|
if not instance:
|
||||||
return ErrClientNotFound
|
return ErrInstanceNotFound
|
||||||
return web.json_response(client.to_dict())
|
return web.json_response(instance.to_dict())
|
||||||
|
|
||||||
|
|
||||||
async def create_instance(user_id: UserID, data: dict) -> web.Response:
|
async def create_instance(instance_id: str, data: dict) -> web.Response:
|
||||||
return ErrNotImplemented
|
plugin_type = data.get("type", None)
|
||||||
|
primary_user = data.get("primary_user", None)
|
||||||
|
if not plugin_type:
|
||||||
|
return ErrPluginTypeRequired
|
||||||
|
elif not primary_user:
|
||||||
|
return ErrPrimaryUserRequired
|
||||||
|
elif not Client.get(primary_user):
|
||||||
|
return ErrPrimaryUserNotFound
|
||||||
|
try:
|
||||||
|
PluginLoader.find(plugin_type)
|
||||||
|
except KeyError:
|
||||||
|
return ErrPluginTypeNotFound
|
||||||
|
db_instance = DBPlugin(id=instance_id, type=plugin_type, enabled=data.get("enabled", True),
|
||||||
|
primary_user=primary_user, config=data.get("config", ""))
|
||||||
|
instance = PluginInstance(db_instance)
|
||||||
|
instance.load()
|
||||||
|
PluginInstance.db.add(db_instance)
|
||||||
|
PluginInstance.db.commit()
|
||||||
|
await instance.start()
|
||||||
|
return web.json_response(instance.to_dict())
|
||||||
|
|
||||||
|
|
||||||
async def update_instance(client: Client, data: dict) -> web.Response:
|
async def update_instance(instance: PluginInstance, data: dict) -> web.Response:
|
||||||
return ErrNotImplemented
|
if not await instance.update_primary_user(data.get("primary_user")):
|
||||||
|
return ErrPrimaryUserNotFound
|
||||||
|
instance.update_id(data.get("id", None))
|
||||||
|
instance.update_enabled(data.get("enabled", None))
|
||||||
|
instance.update_config(data.get("config", None))
|
||||||
|
await instance.update_started(data.get("started", None))
|
||||||
|
instance.db.commit()
|
||||||
|
return web.json_response(instance.to_dict())
|
||||||
|
|
||||||
|
|
||||||
@routes.put("/instance/{id}")
|
@routes.put("/instance/{id}")
|
||||||
async def update_instance(request: web.Request) -> web.Response:
|
async def update_instance(request: web.Request) -> web.Response:
|
||||||
user_id = request.match_info.get("id", None)
|
instance_id = request.match_info.get("id", "").lower()
|
||||||
client = Client.get(user_id, None)
|
instance = PluginInstance.get(instance_id, None)
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
return ErrBodyNotJSON
|
return ErrBodyNotJSON
|
||||||
if not client:
|
if not instance:
|
||||||
return await create_instance(user_id, data)
|
return await create_instance(instance_id, data)
|
||||||
else:
|
else:
|
||||||
return await update_instance(client, data)
|
return await update_instance(instance, data)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete("/instance/{id}")
|
@routes.delete("/instance/{id}")
|
||||||
async def delete_instance(request: web.Request) -> web.Response:
|
async def delete_instance(request: web.Request) -> web.Response:
|
||||||
user_id = request.match_info.get("id", None)
|
instance_id = request.match_info.get("id", "").lower()
|
||||||
client = Client.get(user_id, None)
|
instance = PluginInstance.get(instance_id, None)
|
||||||
if not client:
|
if not instance:
|
||||||
return ErrClientNotFound
|
return ErrInstanceNotFound
|
||||||
return ErrNotImplemented
|
if instance.started:
|
||||||
|
await instance.stop()
|
||||||
|
instance.delete()
|
||||||
|
return RespDeleted
|
||||||
|
@ -41,6 +41,31 @@ ErrClientNotFound = web.json_response({
|
|||||||
"errcode": "client_not_found",
|
"errcode": "client_not_found",
|
||||||
}, status=HTTPStatus.NOT_FOUND)
|
}, status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
|
ErrPrimaryUserNotFound = web.json_response({
|
||||||
|
"error": "Client for given primary user not found",
|
||||||
|
"errcode": "primary_user_not_found",
|
||||||
|
}, status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
|
ErrInstanceNotFound = web.json_response({
|
||||||
|
"error": "Plugin instance not found",
|
||||||
|
"errcode": "instance_not_found",
|
||||||
|
}, status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
|
ErrPluginTypeNotFound = web.json_response({
|
||||||
|
"error": "Given plugin type not found",
|
||||||
|
"errcode": "plugin_type_not_found",
|
||||||
|
}, status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
|
ErrPluginTypeRequired = web.json_response({
|
||||||
|
"error": "Plugin type is required when creating plugin instances",
|
||||||
|
"errcode": "plugin_type_required",
|
||||||
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
ErrPrimaryUserRequired = web.json_response({
|
||||||
|
"error": "Primary user is required when creating plugin instances",
|
||||||
|
"errcode": "primary_user_required",
|
||||||
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
ErrPathNotFound = web.json_response({
|
ErrPathNotFound = web.json_response({
|
||||||
"error": "Resource not found",
|
"error": "Resource not found",
|
||||||
"errcode": "resource_not_found",
|
"errcode": "resource_not_found",
|
||||||
|
@ -346,6 +346,9 @@ components:
|
|||||||
primary_user:
|
primary_user:
|
||||||
type: string
|
type: string
|
||||||
example: '@putkiteippi:maunium.net'
|
example: '@putkiteippi:maunium.net'
|
||||||
|
config:
|
||||||
|
type: string
|
||||||
|
example: "YAML"
|
||||||
MatrixClient:
|
MatrixClient:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
import React from "react"
|
import React from "react"
|
||||||
import ReactDOM from "react-dom"
|
import ReactDOM from "react-dom"
|
||||||
import "./style/base"
|
import "./style/index.sass"
|
||||||
import MaubotManager from "./MaubotManager"
|
import MaubotManager from "./MaubotManager"
|
||||||
|
|
||||||
ReactDOM.render(<MaubotManager/>, document.getElementById("root"))
|
ReactDOM.render(<MaubotManager/>, document.getElementById("root"))
|
||||||
|
@ -28,7 +28,6 @@ if TYPE_CHECKING:
|
|||||||
from .command_spec import CommandSpec
|
from .command_spec import CommandSpec
|
||||||
from mautrix.util.config import BaseProxyConfig
|
from mautrix.util.config import BaseProxyConfig
|
||||||
|
|
||||||
|
|
||||||
DatabaseNotConfigured = ValueError("A database for this maubot instance has not been configured.")
|
DatabaseNotConfigured = ValueError("A database for this maubot instance has not been configured.")
|
||||||
|
|
||||||
|
|
||||||
@ -69,3 +68,7 @@ class Plugin(ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_config_class(cls) -> Optional[Type['BaseProxyConfig']]:
|
def get_config_class(cls) -> Optional[Type['BaseProxyConfig']]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def on_external_config_update(self) -> None:
|
||||||
|
if self.config:
|
||||||
|
self.config.load_and_update()
|
||||||
|
Loading…
Reference in New Issue
Block a user