From b29a7d186ea5bfab35ca879612c4a8b0654d6d76 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 20 Oct 2018 18:39:16 +0300 Subject: [PATCH] Add configuration for plugins --- maubot/command_spec.py | 1 + maubot/config.py | 4 ++-- maubot/db.py | 1 + maubot/loader/abc.py | 4 ++++ maubot/loader/zip.py | 8 ++++++-- maubot/plugin.py | 32 ++++++++++++++++++++++++++++++-- maubot/plugin_base.py | 16 +++++++++++++--- 7 files changed, 57 insertions(+), 9 deletions(-) diff --git a/maubot/command_spec.py b/maubot/command_spec.py index 8a8dc6a..a2102da 100644 --- a/maubot/command_spec.py +++ b/maubot/command_spec.py @@ -70,6 +70,7 @@ class ParsedCommand: def _init_active(self, command: Command) -> None: self.name = command.syntax self.is_passive = False + self.arguments = [] regex_builder = [] sw_builder = [] diff --git a/maubot/config.py b/maubot/config.py index d9f719f..585ff3f 100644 --- a/maubot/config.py +++ b/maubot/config.py @@ -16,10 +16,10 @@ import random import string -from mautrix.util import BaseConfig +from mautrix.util import BaseFileConfig -class Config(BaseConfig): +class Config(BaseFileConfig): @staticmethod def _new_token() -> str: return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) diff --git a/maubot/db.py b/maubot/db.py index 8a3d428..0e46c27 100644 --- a/maubot/db.py +++ b/maubot/db.py @@ -58,6 +58,7 @@ class DBPlugin(Base): primary_user: UserID = Column(String(255), ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"), nullable=False) + config: str = Column(Text, nullable=False, default='') class DBClient(Base): diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py index 20b593a..f41d848 100644 --- a/maubot/loader/abc.py +++ b/maubot/loader/abc.py @@ -47,6 +47,10 @@ class PluginLoader(ABC): def source(self) -> str: pass + @abstractmethod + def read_file(self, path: str) -> bytes: + pass + @abstractmethod def load(self) -> Type[PluginClass]: pass diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py index 57f341d..a6a107e 100644 --- a/maubot/loader/zip.py +++ b/maubot/loader/zip.py @@ -41,6 +41,7 @@ class ZippedPluginLoader(PluginLoader): main_module: str _loaded: Type[PluginClass] _importer: zipimporter + _file: ZipFile def __init__(self, path: str) -> None: super().__init__() @@ -76,10 +77,13 @@ class ZippedPluginLoader(PluginLoader): f"id='{self.id}' " f"loaded={self._loaded is not None}>") + def read_file(self, path: str) -> bytes: + return self._file.read(path) + def _load_meta(self) -> None: try: - file = ZipFile(self.path) - data = file.read("maubot.ini") + self._file = ZipFile(self.path) + data = self._file.read("maubot.ini") except FileNotFoundError as e: raise MaubotZipImportError("Maubot plugin not found") from e except BadZipFile as e: diff --git a/maubot/plugin.py b/maubot/plugin.py index 691bf1e..d6c5eb0 100644 --- a/maubot/plugin.py +++ b/maubot/plugin.py @@ -14,8 +14,12 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import Dict, List, Optional +from ruamel.yaml.comments import CommentedMap +from ruamel.yaml import YAML import logging +import io +from mautrix.util import BaseProxyConfig, RecursiveDict from mautrix.types import UserID from .db import DBPlugin @@ -25,6 +29,9 @@ from .plugin_base import Plugin log = logging.getLogger("maubot.plugin") +yaml = YAML() +yaml.indent(4) + class PluginInstance: cache: Dict[str, 'PluginInstance'] = {} @@ -34,10 +41,12 @@ class PluginInstance: loader: PluginLoader client: Client plugin: Plugin + config: BaseProxyConfig def __init__(self, db_instance: DBPlugin): self.db_instance = db_instance self.log = logging.getLogger(f"maubot.plugin.{self.id}") + self.config = None self.cache[self.id] = self def load(self) -> None: @@ -53,12 +62,31 @@ class PluginInstance: self.enabled = False self.log.debug("Plugin instance dependencies loaded") + def load_config(self) -> CommentedMap: + return yaml.load(self.db_instance.config) + + def load_config_base(self) -> Optional[RecursiveDict[CommentedMap]]: + try: + base = self.loader.read_file("base-config.yaml") + return yaml.load(base.decode("utf-8")) + except (FileNotFoundError, KeyError): + return None + + def save_config(self, data: RecursiveDict[CommentedMap]) -> None: + buf = io.StringIO() + yaml.dump(data, buf) + self.db_instance.config = buf.getvalue() + async def start(self) -> None: if not self.enabled: - self.log.warn(f"Plugin disabled, not starting.") + self.log.warning(f"Plugin disabled, not starting.") return cls = self.loader.load() - self.plugin = cls(self.client.client, self.id, self.log) + config_class = cls.get_config_class() + if config_class: + self.config = config_class(self.load_config, self.load_config_base, + self.save_config) + self.plugin = cls(self.client.client, self.id, self.log, self.config) self.loader.references |= {self} await self.plugin.start() self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} " diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index 58f3435..81ac7f0 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -13,30 +13,40 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import TYPE_CHECKING +from typing import Type, Optional, TYPE_CHECKING from logging import Logger -from abc import ABC +from abc import ABC, abstractmethod if TYPE_CHECKING: from .client import MaubotMatrixClient from .command_spec import CommandSpec + from mautrix.util import BaseProxyConfig class Plugin(ABC): client: 'MaubotMatrixClient' id: str log: Logger + config: Optional['BaseProxyConfig'] - def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger) -> None: + def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger, + config: Optional['BaseProxyConfig']) -> None: self.client = client self.id = plugin_instance_id self.log = log + self.config = config def set_command_spec(self, spec: 'CommandSpec') -> None: self.client.set_command_spec(self.id, spec) + @abstractmethod async def start(self) -> None: pass + @abstractmethod async def stop(self) -> None: pass + + @classmethod + def get_config_class(cls) -> Optional[Type['BaseProxyConfig']]: + return None