maubot/maubot/instance.py
2020-04-22 22:40:43 +03:00

306 lines
11 KiB
Python

# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Iterable, TYPE_CHECKING
from asyncio import AbstractEventLoop
import os.path
import logging
import io
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML
import sqlalchemy as sql
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.types import UserID
from .db import DBPlugin
from .config import Config
from .client import Client
from .loader import PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin
if TYPE_CHECKING:
    # These names are only used inside string annotations, so the import is
    # skipped at runtime and only seen by static type checkers.
    from .server import MaubotServer, PluginWebApp

# Shared ruamel.yaml handler used to parse and dump instance configuration.
yaml = YAML()
yaml.indent(4)
yaml.width = 200

# Parent logger; every PluginInstance derives a child logger named after its ID.
log = logging.getLogger("maubot.instance")
class PluginInstance:
    """One configured instance of a plugin, run with a single client.

    The persistent fields (ID, type, enabled flag, primary user and config
    text) live on the wrapped ``DBPlugin`` row; this class adds the runtime
    state: the resolved loader and client, the running plugin object, and the
    optional per-instance database engine and web app.

    Every constructed instance registers itself in the class-level ``cache``,
    which ``get()`` consults before creating a new object.
    """

    # Process-wide state, assigned once by the module-level ``init()``.
    webserver: 'MaubotServer' = None
    mb_config: Config = None
    loop: AbstractEventLoop = None
    # Registry of constructed instances keyed by instance ID. It is only ever
    # mutated in place, so the one dict is shared by all instances on purpose.
    cache: Dict[str, 'PluginInstance'] = {}
    plugin_directories: List[str] = []

    log: logging.Logger
    loader: PluginLoader
    client: Client
    plugin: Plugin
    config: BaseProxyConfig
    base_cfg: Optional[RecursiveDict[CommentedMap]]
    base_cfg_str: Optional[str]
    inst_db: sql.engine.Engine
    inst_db_tables: Dict[str, sql.Table]
    inst_webapp: Optional['PluginWebApp']
    inst_webapp_url: Optional[str]
    started: bool

    def __init__(self, db_instance: DBPlugin):
        """Wrap ``db_instance`` and register the new object in ``cache``.

        All runtime dependencies start out as ``None``; ``load()`` resolves
        them later.
        """
        self.db_instance = db_instance
        self.log = log.getChild(self.id)
        self.config = None
        self.started = False
        self.loader = None
        self.client = None
        self.plugin = None
        self.inst_db = None
        self.inst_db_tables = None
        self.inst_webapp = None
        self.inst_webapp_url = None
        self.base_cfg = None
        self.base_cfg_str = None
        self.cache[self.id] = self

    def to_dict(self) -> dict:
        """Return a plain-dict summary of the instance.

        ``"database"`` is only truthy when the instance has a database engine
        and the ``api_features.instance_database`` config flag is set.
        """
        return {
            "id": self.id,
            "type": self.type,
            "enabled": self.enabled,
            "started": self.started,
            "primary_user": self.primary_user,
            "config": self.db_instance.config,
            "base_config": self.base_cfg_str,
            "database": (self.inst_db is not None
                         and self.mb_config["api_features.instance_database"]),
        }

    def get_db_tables(self) -> Dict[str, sql.Table]:
        """Return the tables of the instance database, reflecting them on demand.

        The result is cached in ``inst_db_tables`` until ``start()`` or
        ``stop()`` resets it. An empty result is falsy and therefore reflected
        again on the next call.
        """
        if not self.inst_db_tables:
            metadata = sql.MetaData()
            metadata.reflect(self.inst_db)
            self.inst_db_tables = metadata.tables
        return self.inst_db_tables

    def load(self) -> bool:
        """Resolve the loader, client, database engine and web app.

        Returns:
            ``True`` when every dependency was resolved. ``False`` when the
            loader or the client could not be found; in that case the
            instance is also marked as disabled.
        """
        if not self.loader:
            try:
                self.loader = PluginLoader.find(self.type)
            except KeyError:
                self.log.error(f"Failed to find loader for type {self.type}")
                self.db_instance.enabled = False
                return False
        if not self.client:
            self.client = Client.get(self.primary_user)
            if not self.client:
                self.log.error(f"Failed to get client for user {self.primary_user}")
                self.db_instance.enabled = False
                return False
        if self.loader.meta.database:
            db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
            self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
        if self.loader.meta.webapp:
            self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id)
        self.log.debug("Plugin instance dependencies loaded")
        # References are only registered once everything above succeeded, so a
        # partially loaded instance may hold a loader it is not registered with.
        self.loader.references.add(self)
        self.client.references.add(self)
        return True

    def delete(self) -> None:
        """Unregister the instance and delete its database row and resources.

        The per-instance database file is passed to
        ``ZippedPluginLoader.trash`` and the web app, if any, is removed from
        the web server.
        """
        # discard() rather than remove(): after a partially failed load() the
        # loader/client can be set without this instance being registered, and
        # remove() would raise KeyError and abort the deletion halfway.
        if self.loader is not None:
            self.loader.references.discard(self)
        if self.client is not None:
            self.client.references.discard(self)
        self.cache.pop(self.id, None)
        self.db_instance.delete()
        if self.inst_db:
            self.inst_db.dispose()
            ZippedPluginLoader.trash(
                os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
                reason="deleted")
        if self.inst_webapp:
            self.webserver.remove_instance_webapp(self.id)

    @staticmethod
    def _dump_yaml(data) -> str:
        """Serialize ``data`` with the module-level YAML handler and return the text."""
        buf = io.StringIO()
        yaml.dump(data, buf)
        return buf.getvalue()

    def load_config(self) -> CommentedMap:
        """Parse the config text stored on the database row."""
        return yaml.load(self.db_instance.config)

    def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
        """Serialize ``data`` to YAML and store it on the database row."""
        self.db_instance.config = self._dump_yaml(data)

    async def start(self) -> None:
        """Load the plugin class, build its config and start the plugin.

        Does nothing (apart from logging a warning) when the instance is
        already started, is disabled, or its dependencies cannot be loaded.
        If the plugin raises while starting, the error is logged and the
        instance is marked as disabled.
        """
        if self.started:
            self.log.warning("Ignoring start() call to already started plugin")
            return
        elif not self.enabled:
            self.log.warning("Plugin disabled, not starting.")
            return
        if not self.client or not self.loader:
            self.log.warning("Missing plugin instance dependencies, attempting to load...")
            if not self.load():
                return
        cls = await self.loader.load()
        config_class = cls.get_config_class()
        if config_class:
            try:
                base = await self.loader.read_file("base-config.yaml")
                self.base_cfg = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
                self.base_cfg_str = self._dump_yaml(self.base_cfg._data)
            except (FileNotFoundError, KeyError):
                # The plugin does not ship a base config.
                self.base_cfg = None
                self.base_cfg_str = None
            if self.base_cfg:
                base_cfg_func = self.base_cfg.clone
            else:
                def base_cfg_func() -> None:
                    return None
            self.config = config_class(self.load_config, base_cfg_func, self.save_config)
        else:
            # Without a config class, drop whatever an earlier start() left
            # behind (e.g. before update_type() switched the plugin), so the
            # new plugin is not handed another plugin's config object.
            self.config = None
            self.base_cfg = None
            self.base_cfg_str = None
        self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client,
                          instance_id=self.id, log=self.log, config=self.config,
                          database=self.inst_db, webapp=self.inst_webapp,
                          webapp_url=self.inst_webapp_url)
        try:
            await self.plugin.internal_start()
        except Exception:
            self.log.exception("Failed to start instance")
            self.db_instance.enabled = False
            return
        self.started = True
        # Force the next get_db_tables() call to reflect the database again.
        self.inst_db_tables = None
        self.log.info(f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} "
                      f"with user {self.client.id}")

    async def stop(self) -> None:
        """Stop the running plugin; errors raised while stopping are logged only."""
        if not self.started:
            self.log.warning("Ignoring stop() call to non-running plugin")
            return
        self.log.debug("Stopping plugin instance...")
        # Marked as stopped up front, so the flag is already cleared even if
        # internal_stop() fails.
        self.started = False
        try:
            await self.plugin.internal_stop()
        except Exception:
            self.log.exception("Failed to stop instance")
        self.plugin = None
        self.inst_db_tables = None

    @classmethod
    def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None
            ) -> Optional['PluginInstance']:
        """Return the instance with the given ID, creating it if needed.

        The cache is checked first. On a miss, ``db_instance`` (or, when it is
        not given, the row from ``DBPlugin.get``) is wrapped in a new
        instance. Returns ``None`` when no row is available.
        """
        try:
            return cls.cache[instance_id]
        except KeyError:
            db_instance = db_instance or DBPlugin.get(instance_id)
            if not db_instance:
                return None
            return PluginInstance(db_instance)

    @classmethod
    def all(cls) -> Iterable['PluginInstance']:
        """Lazily yield an instance for every row returned by ``DBPlugin.all()``."""
        return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())

    def update_id(self, new_id: str) -> None:
        """Rename the instance to the lower-cased ``new_id``.

        ``None`` or an ID equal to the current one is ignored. The ``cache``
        entry is moved to the new ID so ``get()`` keeps returning this object
        instead of building a second instance for the same row.
        """
        if new_id is None or new_id == self.id:
            return
        old_id = self.id
        normalized_id = new_id.lower()
        self.db_instance.id = normalized_id
        if normalized_id != old_id:
            if self.cache.get(old_id) is self:
                del self.cache[old_id]
            self.cache[normalized_id] = self
        # NOTE(review): the logger name set in __init__ and the database path
        # built in load() still use the old ID after a rename - confirm whether
        # callers are expected to reload the instance.

    def update_config(self, config: str) -> None:
        """Store new config text and notify the running plugin, if any.

        Empty or unchanged config text is ignored.
        """
        if not config or self.db_instance.config == config:
            return
        self.db_instance.config = config
        if self.started and self.plugin is not None:
            self.plugin.on_external_config_update()

    async def update_primary_user(self, primary_user: UserID) -> bool:
        """Switch the instance to another client, restarting it around the change.

        Returns:
            ``False`` when no client exists for ``primary_user``; ``True``
            otherwise, including when the value is empty or unchanged.
        """
        if not primary_user or primary_user == self.primary_user:
            return True
        client = Client.get(primary_user)
        if not client:
            return False
        await self.stop()
        self.db_instance.primary_user = client.id
        if self.client:
            # discard(): the old client may never have had this instance
            # registered if load() failed before registering references.
            self.client.references.discard(self)
        self.client = client
        self.client.references.add(self)
        await self.start()
        self.log.debug(f"Primary user switched to {self.client.id}")
        return True

    async def update_type(self, type: str) -> bool:
        """Switch the instance to another plugin type, restarting it around the change.

        Returns:
            ``False`` when no loader exists for ``type``; ``True`` otherwise,
            including when the value is empty or unchanged.
        """
        if not type or type == self.type:
            return True
        try:
            loader = PluginLoader.find(type)
        except KeyError:
            return False
        await self.stop()
        self.db_instance.type = loader.meta.id
        if self.loader:
            # discard(): the old loader may never have had this instance
            # registered if load() failed before registering references.
            self.loader.references.discard(self)
        self.loader = loader
        self.loader.references.add(self)
        await self.start()
        self.log.debug(f"Type switched to {self.loader.meta.id}")
        return True

    async def update_started(self, started: bool) -> None:
        """Start or stop the instance so that it matches ``started``; ``None`` is ignored."""
        if started is not None and started != self.started:
            await (self.start() if started else self.stop())

    def update_enabled(self, enabled: bool) -> None:
        """Persist a new enabled flag; ``None`` or an unchanged value is ignored."""
        if enabled is not None and enabled != self.enabled:
            self.db_instance.enabled = enabled

    # region Properties

    @property
    def id(self) -> str:
        """Instance ID as stored on the database row."""
        return self.db_instance.id

    @id.setter
    def id(self, value: str) -> None:
        self.db_instance.id = value

    @property
    def type(self) -> str:
        """Plugin type as stored on the database row."""
        return self.db_instance.type

    @property
    def enabled(self) -> bool:
        """Enabled flag as stored on the database row."""
        return self.db_instance.enabled

    @property
    def primary_user(self) -> UserID:
        """Primary user ID as stored on the database row."""
        return self.db_instance.primary_user

    # endregion
def init(config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop
         ) -> Iterable[PluginInstance]:
    """Store the process-wide objects on ``PluginInstance`` and return all instances.

    Args:
        config: Main configuration, stored as ``PluginInstance.mb_config``.
        webserver: Web server, stored as ``PluginInstance.webserver``.
        loop: Event loop, stored as ``PluginInstance.loop``.

    Returns:
        Whatever ``PluginInstance.all()`` returns.
    """
    PluginInstance.webserver = webserver
    PluginInstance.loop = loop
    PluginInstance.mb_config = config
    instances = PluginInstance.all()
    return instances