Make it run

This commit is contained in:
Tulir Asokan 2018-10-16 22:15:35 +03:00
parent eef052b1e9
commit dce2771588
8 changed files with 88 additions and 28 deletions

View File

@ -20,17 +20,22 @@ import argparse
import asyncio import asyncio
import copy import copy
import sys import sys
import os
from .config import Config from .config import Config
from .db import Base, init as init_db from .db import Base, init as init_db
from .server import MaubotServer from .server import MaubotServer
from .client import Client, init as init_client from .client import Client, init as init_client
from .loader import ZippedPluginLoader, MaubotZipImportError
from .__meta__ import __version__ from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.", parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
prog="python -m maubot") prog="python -m maubot")
parser.add_argument("-c", "--config", type=str, default="config.yaml", parser.add_argument("-c", "--config", type=str, default="config.yaml",
metavar="<path>", help="the path to your config file") metavar="<path>", help="the path to your config file")
parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml",
metavar="<path>", help="the path to the example config "
"(for automatic config updates)")
args = parser.parse_args() args = parser.parse_args()
config = Config(args.config, args.base_config) config = Config(args.config, args.base_config)
@ -38,13 +43,14 @@ config.load()
config.update() config.update()
logging.config.dictConfig(copy.deepcopy(config["logging"])) logging.config.dictConfig(copy.deepcopy(config["logging"]))
log = logging.getLogger("maubot") log = logging.getLogger("maubot.init")
log.debug(f"Initializing maubot {__version__}") log.debug(f"Initializing maubot {__version__}")
db_engine: sql.engine.Engine = sql.create_engine(config["database"]) db_engine: sql.engine.Engine = sql.create_engine(config["database"])
db_factory = orm.sessionmaker(bind=db_engine) db_factory = orm.sessionmaker(bind=db_engine)
db_session = orm.scoping.scoped_session(db_factory) db_session = orm.scoping.scoped_session(db_factory)
Base.metadata.bind = db_engine Base.metadata.bind = db_engine
Base.metadata.create_all()
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -52,8 +58,22 @@ init_db(db_session)
init_client(loop) init_client(loop)
server = MaubotServer(config, loop) server = MaubotServer(config, loop)
loader_log = logging.getLogger("maubot.loader.zip")
loader_log.debug("Preloading plugins...")
for directory in config["plugin_directories"]:
for file in os.listdir(directory):
if not file.endswith(".mbp"):
continue
path = os.path.join(directory, file)
try:
loader = ZippedPluginLoader.get(path)
loader_log.debug(f"Preloaded plugin {loader.id} from {loader.path}.")
except MaubotZipImportError:
loader_log.exception(f"Failed to load plugin at {path}.")
try: try:
loop.run_until_complete(server.start()) loop.run_until_complete(server.start())
log.debug("Startup actions complete, running forever.")
loop.run_forever() loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
log.debug("Keyboard interrupt received, stopping...") log.debug("Keyboard interrupt received, stopping...")

View File

@ -1 +1 @@
__version__ = "0.1.0+dev" __version__ = "0.1.0.dev1"

View File

@ -13,14 +13,28 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
import random
import string
from mautrix.util import BaseConfig from mautrix.util import BaseConfig
class Config(BaseConfig): class Config(BaseConfig):
@staticmethod
def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
def update(self): def update(self):
base, copy, copy_dict = self._pre_update() base, copy, copy_dict = self._pre_update()
copy("database") copy("database")
copy("plugin_directories") copy("plugin_directories")
copy("server.hostname")
copy("server.port")
copy("server.listen") copy("server.listen")
copy("server.base_path") copy("server.base_path")
shared_secret = self["server.shared_secret"]
if shared_secret is None or shared_secret == "generate":
base["server.shared_secret"] = self._new_token()
else:
base["server.shared_secret"] = shared_secret
copy("logging") copy("logging")

View File

@ -60,7 +60,7 @@ class DBPlugin(Base):
nullable=False) nullable=False)
class DBClient(ClientStore, Base): class DBClient(Base):
query: Query query: Query
__tablename__ = "client" __tablename__ = "client"

View File

@ -1,2 +1,2 @@
from .abc import PluginLoader, PluginClass from .abc import PluginLoader, PluginClass
from .zip import ZippedPluginLoader from .zip import ZippedPluginLoader, MaubotZipImportError

View File

@ -21,12 +21,21 @@ from ..plugin_base import Plugin
PluginClass = TypeVar("PluginClass", bound=Plugin) PluginClass = TypeVar("PluginClass", bound=Plugin)
class IDConflictError(Exception):
pass
class PluginLoader(ABC): class PluginLoader(ABC):
id_cache: Dict[str, 'PluginLoader'] = {} id_cache: Dict[str, 'PluginLoader'] = {}
id: str id: str
version: str version: str
@property
@abstractmethod
def source(self) -> str:
pass
@abstractmethod @abstractmethod
def load(self) -> Type[PluginClass]: def load(self) -> Type[PluginClass]:
pass pass

View File

@ -20,7 +20,7 @@ import configparser
from ..lib.zipimport import zipimporter, ZipImportError from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin from ..plugin_base import Plugin
from .abc import PluginLoader, PluginClass from .abc import PluginLoader, PluginClass, IDConflictError
class MaubotZipImportError(Exception): class MaubotZipImportError(Exception):
@ -36,35 +36,50 @@ class ZippedPluginLoader(PluginLoader):
modules: List[str] modules: List[str]
main_class: str main_class: str
main_module: str main_module: str
loaded: bool _loaded: Type[PluginClass]
_importer: zipimporter _importer: zipimporter
def __init__(self, path: str) -> None: def __init__(self, path: str) -> None:
self.path = path self.path = path
self.id = None self.id = None
self.loaded = False self._loaded = None
self._load_meta() self._load_meta()
self._run_preload_checks(self._get_importer()) self._run_preload_checks(self._get_importer())
try:
existing = self.id_cache[self.id]
raise IDConflictError(f"Plugin with id {self.id} already loaded from {existing.source}")
except KeyError:
pass
self.path_cache[self.path] = self self.path_cache[self.path] = self
self.id_cache[self.id] = self self.id_cache[self.id] = self
@classmethod
def get(cls, path: str) -> 'ZippedPluginLoader':
try:
return cls.path_cache[path]
except KeyError:
return cls(path)
@property
def source(self) -> str:
return self.path
def __repr__(self) -> str: def __repr__(self) -> str:
return ("<ZippedPlugin " return ("<ZippedPlugin "
f"path='{self.path}' " f"path='{self.path}' "
f"id='{self.id}' " f"id='{self.id}' "
f"loaded={self.loaded}>") f"loaded={self._loaded}>")
def _load_meta(self) -> None: def _load_meta(self) -> None:
try: try:
file = ZipFile(self.path) file = ZipFile(self.path)
data = file.read("maubot.ini") data = file.read("maubot.ini")
except FileNotFoundError as e: except FileNotFoundError as e:
raise MaubotZipImportError(f"Maubot plugin not found at {self.path}") from e raise MaubotZipImportError("Maubot plugin not found") from e
except BadZipFile as e: except BadZipFile as e:
raise MaubotZipImportError(f"File at {self.path} is not a maubot plugin") from e raise MaubotZipImportError("File is not a maubot plugin") from e
except KeyError as e: except KeyError as e:
raise MaubotZipImportError( raise MaubotZipImportError("File does not contain a maubot plugin definition") from e
"File at {path} does not contain a maubot plugin definition") from e
config = configparser.ConfigParser() config = configparser.ConfigParser()
try: try:
config.read_string(data.decode("utf-8"), source=f"{self.path}/maubot.ini") config.read_string(data.decode("utf-8"), source=f"{self.path}/maubot.ini")
@ -77,8 +92,7 @@ class ZippedPluginLoader(PluginLoader):
if "/" in main_class: if "/" in main_class:
main_module, main_class = main_class.split("/")[:2] main_module, main_class = main_class.split("/")[:2]
except (configparser.Error, KeyError, IndexError, ValueError) as e: except (configparser.Error, KeyError, IndexError, ValueError) as e:
raise MaubotZipImportError( raise MaubotZipImportError("Maubot plugin definition in file is invalid") from e
f"Maubot plugin definition in file at {self.path} is invalid") from e
if self.id and meta_id != self.id: if self.id and meta_id != self.id:
raise MaubotZipImportError("Maubot plugin ID changed during reload") raise MaubotZipImportError("Maubot plugin ID changed during reload")
self.id, self.version, self.modules = meta_id, version, modules self.id, self.version, self.modules = meta_id, version, modules
@ -91,8 +105,7 @@ class ZippedPluginLoader(PluginLoader):
importer.reset_cache() importer.reset_cache()
return importer return importer
except ZipImportError as e: except ZipImportError as e:
raise MaubotZipImportError( raise MaubotZipImportError("File not found or not a maubot plugin") from e
f"File at {self.path} not found or not a maubot plugin") from e
def _run_preload_checks(self, importer: zipimporter) -> None: def _run_preload_checks(self, importer: zipimporter) -> None:
try: try:
@ -102,29 +115,30 @@ class ZippedPluginLoader(PluginLoader):
f"Main class {self.main_class} not in {self.main_module}") f"Main class {self.main_class} not in {self.main_module}")
except ZipImportError as e: except ZipImportError as e:
raise MaubotZipImportError( raise MaubotZipImportError(
f"Main module {self.main_module} not found in {self.path}") from e f"Main module {self.main_module} not found in file") from e
for module in self.modules: for module in self.modules:
try: try:
importer.find_module(module) importer.find_module(module)
except ZipImportError as e: except ZipImportError as e:
raise MaubotZipImportError(f"Module {module} not found in {self.path}") from e raise MaubotZipImportError(f"Module {module} not found in file") from e
def load(self) -> Type[PluginClass]: def load(self, reset_cache: bool = False) -> Type[PluginClass]:
importer = self._get_importer(reset_cache=self.loaded) if self._loaded is not None and not reset_cache:
return self._loaded
importer = self._get_importer(reset_cache=reset_cache)
self._run_preload_checks(importer) self._run_preload_checks(importer)
for module in self.modules: for module in self.modules:
importer.load_module(module) importer.load_module(module)
self.loaded = True
main_mod = sys.modules[self.main_module] main_mod = sys.modules[self.main_module]
plugin = getattr(main_mod, self.main_class) plugin = getattr(main_mod, self.main_class)
if not issubclass(plugin, Plugin): if not issubclass(plugin, Plugin):
raise MaubotZipImportError( raise MaubotZipImportError("Main class of plugin does not extend maubot.Plugin")
f"Main class of plugin at {self.path} does not extend maubot.Plugin") self._loaded = plugin
return plugin return plugin
def reload(self) -> Type[PluginClass]: def reload(self) -> Type[PluginClass]:
self.unload() self.unload()
return self.load() return self.load(reset_cache=True)
def unload(self) -> None: def unload(self) -> None:
for name, mod in list(sys.modules.items()): for name, mod in list(sys.modules.items()):

View File

@ -16,7 +16,7 @@
from aiohttp import web from aiohttp import web
import asyncio import asyncio
from mautrix.api import PathBuilder from mautrix.api import PathBuilder, Method
from .config import Config from .config import Config
from .__meta__ import __version__ from .__meta__ import __version__
@ -29,13 +29,16 @@ class MaubotServer:
self.config = config self.config = config
path = PathBuilder(config["server.base_path"]) path = PathBuilder(config["server.base_path"])
self.app.router.add_get(path.version, self.version) self.add_route(Method.GET, path.version, self.version)
as_path = PathBuilder(config["server.appservice_base_path"]) as_path = PathBuilder(config["server.appservice_base_path"])
self.app.router.add_put(as_path.transactions, self.handle_transaction) self.add_route(Method.PUT, as_path.transactions, self.handle_transaction)
self.runner = web.AppRunner(self.app) self.runner = web.AppRunner(self.app)
def add_route(self, method: Method, path: PathBuilder, handler) -> None:
self.app.router.add_route(method.value, str(path), handler)
async def start(self) -> None: async def start(self) -> None:
await self.runner.setup() await self.runner.setup()
site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"]) site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"])