Add configuration for plugins

This commit is contained in:
Tulir Asokan 2018-10-20 18:39:16 +03:00
parent ac5f059ef4
commit b29a7d186e
7 changed files with 57 additions and 9 deletions

View File

@ -70,6 +70,7 @@ class ParsedCommand:
def _init_active(self, command: Command) -> None: def _init_active(self, command: Command) -> None:
self.name = command.syntax self.name = command.syntax
self.is_passive = False self.is_passive = False
self.arguments = []
regex_builder = [] regex_builder = []
sw_builder = [] sw_builder = []

View File

@ -16,10 +16,10 @@
import random import random
import string import string
from mautrix.util import BaseConfig from mautrix.util import BaseFileConfig
class Config(BaseConfig): class Config(BaseFileConfig):
@staticmethod @staticmethod
def _new_token() -> str: def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))

View File

@ -58,6 +58,7 @@ class DBPlugin(Base):
primary_user: UserID = Column(String(255), primary_user: UserID = Column(String(255),
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"), ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
nullable=False) nullable=False)
config: str = Column(Text, nullable=False, default='')
class DBClient(Base): class DBClient(Base):

View File

@ -47,6 +47,10 @@ class PluginLoader(ABC):
def source(self) -> str: def source(self) -> str:
pass pass
@abstractmethod
def read_file(self, path: str) -> bytes:
pass
@abstractmethod @abstractmethod
def load(self) -> Type[PluginClass]: def load(self) -> Type[PluginClass]:
pass pass

View File

@ -41,6 +41,7 @@ class ZippedPluginLoader(PluginLoader):
main_module: str main_module: str
_loaded: Type[PluginClass] _loaded: Type[PluginClass]
_importer: zipimporter _importer: zipimporter
_file: ZipFile
def __init__(self, path: str) -> None: def __init__(self, path: str) -> None:
super().__init__() super().__init__()
@ -76,10 +77,13 @@ class ZippedPluginLoader(PluginLoader):
f"id='{self.id}' " f"id='{self.id}' "
f"loaded={self._loaded is not None}>") f"loaded={self._loaded is not None}>")
def read_file(self, path: str) -> bytes:
return self._file.read(path)
def _load_meta(self) -> None: def _load_meta(self) -> None:
try: try:
file = ZipFile(self.path) self._file = ZipFile(self.path)
data = file.read("maubot.ini") data = self._file.read("maubot.ini")
except FileNotFoundError as e: except FileNotFoundError as e:
raise MaubotZipImportError("Maubot plugin not found") from e raise MaubotZipImportError("Maubot plugin not found") from e
except BadZipFile as e: except BadZipFile as e:

View File

@ -14,8 +14,12 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional from typing import Dict, List, Optional
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML
import logging import logging
import io
from mautrix.util import BaseProxyConfig, RecursiveDict
from mautrix.types import UserID from mautrix.types import UserID
from .db import DBPlugin from .db import DBPlugin
@ -25,6 +29,9 @@ from .plugin_base import Plugin
log = logging.getLogger("maubot.plugin") log = logging.getLogger("maubot.plugin")
yaml = YAML()
yaml.indent(4)
class PluginInstance: class PluginInstance:
cache: Dict[str, 'PluginInstance'] = {} cache: Dict[str, 'PluginInstance'] = {}
@ -34,10 +41,12 @@ class PluginInstance:
loader: PluginLoader loader: PluginLoader
client: Client client: Client
plugin: Plugin plugin: Plugin
config: BaseProxyConfig
def __init__(self, db_instance: DBPlugin): def __init__(self, db_instance: DBPlugin):
self.db_instance = db_instance self.db_instance = db_instance
self.log = logging.getLogger(f"maubot.plugin.{self.id}") self.log = logging.getLogger(f"maubot.plugin.{self.id}")
self.config = None
self.cache[self.id] = self self.cache[self.id] = self
def load(self) -> None: def load(self) -> None:
@ -53,12 +62,31 @@ class PluginInstance:
self.enabled = False self.enabled = False
self.log.debug("Plugin instance dependencies loaded") self.log.debug("Plugin instance dependencies loaded")
def load_config(self) -> CommentedMap:
return yaml.load(self.db_instance.config)
def load_config_base(self) -> Optional[RecursiveDict[CommentedMap]]:
try:
base = self.loader.read_file("base-config.yaml")
return yaml.load(base.decode("utf-8"))
except (FileNotFoundError, KeyError):
return None
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
buf = io.StringIO()
yaml.dump(data, buf)
self.db_instance.config = buf.getvalue()
async def start(self) -> None: async def start(self) -> None:
if not self.enabled: if not self.enabled:
self.log.warn(f"Plugin disabled, not starting.") self.log.warning(f"Plugin disabled, not starting.")
return return
cls = self.loader.load() cls = self.loader.load()
self.plugin = cls(self.client.client, self.id, self.log) config_class = cls.get_config_class()
if config_class:
self.config = config_class(self.load_config, self.load_config_base,
self.save_config)
self.plugin = cls(self.client.client, self.id, self.log, self.config)
self.loader.references |= {self} self.loader.references |= {self}
await self.plugin.start() await self.plugin.start()
self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} " self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} "

View File

@ -13,30 +13,40 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TYPE_CHECKING from typing import Type, Optional, TYPE_CHECKING
from logging import Logger from logging import Logger
from abc import ABC from abc import ABC, abstractmethod
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import MaubotMatrixClient from .client import MaubotMatrixClient
from .command_spec import CommandSpec from .command_spec import CommandSpec
from mautrix.util import BaseProxyConfig
class Plugin(ABC): class Plugin(ABC):
client: 'MaubotMatrixClient' client: 'MaubotMatrixClient'
id: str id: str
log: Logger log: Logger
config: Optional['BaseProxyConfig']
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger) -> None: def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger,
config: Optional['BaseProxyConfig']) -> None:
self.client = client self.client = client
self.id = plugin_instance_id self.id = plugin_instance_id
self.log = log self.log = log
self.config = config
def set_command_spec(self, spec: 'CommandSpec') -> None: def set_command_spec(self, spec: 'CommandSpec') -> None:
self.client.set_command_spec(self.id, spec) self.client.set_command_spec(self.id, spec)
@abstractmethod
async def start(self) -> None: async def start(self) -> None:
pass pass
@abstractmethod
async def stop(self) -> None: async def stop(self) -> None:
pass pass
@classmethod
def get_config_class(cls) -> Optional[Type['BaseProxyConfig']]:
return None