diff --git a/example-config.yaml b/example-config.yaml index 6579918..9e7aa77 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -72,18 +72,21 @@ api_features: logging: version: 1 formatters: - precise: + colored: + (): maubot.lib.color_log.ColorFormatter + format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" + normal: format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" handlers: file: class: logging.handlers.RotatingFileHandler - formatter: precise - filename: ./logs/maubot.log + formatter: normal + filename: ./maubot.log maxBytes: 10485760 backupCount: 10 console: class: logging.StreamHandler - formatter: precise + formatter: colored loggers: maubot: level: DEBUG diff --git a/maubot/__main__.py b/maubot/__main__.py index 6c0a96a..412201c 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -56,11 +56,11 @@ log.info(f"Initializing maubot {__version__}") loop = asyncio.get_event_loop() init_zip_loader(config) -db_session = init_db(config) -clients = init_client_class(db_session, loop) +db_engine = init_db(config) +clients = init_client_class(loop) management_api = init_mgmt_api(config, loop) server = MaubotServer(management_api, config, loop) -plugins = init_plugin_instance_class(db_session, config, server, loop) +plugins = init_plugin_instance_class(config, server, loop) for plugin in plugins: plugin.load() @@ -69,30 +69,17 @@ signal.signal(signal.SIGINT, signal.default_int_handler) signal.signal(signal.SIGTERM, signal.default_int_handler) -async def periodic_commit(): - while True: - await asyncio.sleep(60) - db_session.commit() - - -periodic_commit_task: asyncio.Future = None - try: log.info("Starting server") loop.run_until_complete(server.start()) log.info("Starting clients and plugins") loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop)) log.info("Startup actions complete, running forever") - periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop) loop.run_forever() except KeyboardInterrupt: - log.info("Interrupt received, stopping HTTP clients/servers and saving database") - if periodic_commit_task is not None: - periodic_commit_task.cancel() - log.debug("Stopping clients") + log.info("Interrupt received, stopping clients") loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()], loop=loop)) - db_session.commit() if stop_log_listener is not None: log.debug("Closing websockets") loop.run_until_complete(stop_log_listener()) diff --git a/maubot/client.py b/maubot/client.py index 2b6adfe..9e1c9cc 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -13,11 +13,10 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, List, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING +from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING import asyncio import logging -from sqlalchemy.orm import Session from aiohttp import ClientSession from mautrix.errors import MatrixInvalidToken, MatrixRequestError @@ -35,7 +34,6 @@ log = logging.getLogger("maubot.client") class Client: - db: Session = None log: logging.Logger = None loop: asyncio.AbstractEventLoop = None cache: Dict[UserID, 'Client'] = {} @@ -148,9 +146,7 @@ class Client: def clear_cache(self) -> None: self.stop_sync() - self.db_instance.filter_id = "" - self.db_instance.next_batch = "" - self.db.commit() + self.db_instance.edit(filter_id="", next_batch="") self.start_sync() def delete(self) -> None: @@ -158,8 +154,7 @@ class Client: del self.cache[self.id] except KeyError: pass - self.db.delete(self.db_instance) - self.db.commit() + self.db_instance.delete() def to_dict(self) -> dict: return { @@ -183,14 +178,14 @@ class Client: try: return cls.cache[user_id] except KeyError: - db_instance = db_instance or DBClient.query.get(user_id) + db_instance = db_instance or DBClient.get(user_id) if not db_instance: return None return Client(db_instance) @classmethod - def all(cls) -> List['Client']: - return [cls.get(user.id, user) for user in DBClient.query.all()] + def all(cls) -> Iterable['Client']: + return (cls.get(user.id, user) for user in DBClient.all()) async def _handle_invite(self, evt: StrippedStateEvent) -> None: if evt.state_key == self.id and evt.content.membership == Membership.INVITE: @@ -314,8 +309,7 @@ class Client: # endregion -def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]: - Client.db = db +def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]: Client.http_client = ClientSession(loop=loop) Client.loop = loop return Client.all() diff --git a/maubot/db.py b/maubot/db.py index b1ef598..36cba6d 100644 --- a/maubot/db.py +++ b/maubot/db.py @@ -13,22 +13,19 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import cast +from typing import Iterable, Optional from sqlalchemy import Column, String, Boolean, ForeignKey, Text -from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.engine.base import Engine import sqlalchemy as sql from mautrix.types import UserID, FilterID, SyncToken, ContentURI +from mautrix.bridge.db import Base from .config import Config -Base: declarative_base = declarative_base() - class DBPlugin(Base): - query: Query __tablename__ = "plugin" id: str = Column(String(255), primary_key=True) @@ -39,9 +36,16 @@ class DBPlugin(Base): nullable=False) config: str = Column(Text, nullable=False, default='') + @classmethod + def all(cls) -> Iterable['DBPlugin']: + return cls._select_all() + + @classmethod + def get(cls, id: str) -> Optional['DBPlugin']: + return cls._select_one_or_none(cls.c.id == id) + class DBClient(Base): - query: Query __tablename__ = "client" id: UserID = Column(String(255), primary_key=True) @@ -58,15 +62,23 @@ class DBClient(Base): displayname: str = Column(String(255), nullable=False, default="") avatar_url: ContentURI = Column(String(255), nullable=False, default="") + @classmethod + def all(cls) -> Iterable['DBClient']: + return cls._select_all() -def init(config: Config) -> Session: - db_engine: sql.engine.Engine = sql.create_engine(config["database"]) - db_factory = sessionmaker(bind=db_engine) - db_session = scoped_session(db_factory) + @classmethod + def get(cls, id: str) -> Optional['DBClient']: + return cls._select_one_or_none(cls.c.id == id) + + +def init(config: Config) -> Engine: + db_engine = sql.create_engine(config["database"]) Base.metadata.bind = db_engine - Base.metadata.create_all() - DBPlugin.query = db_session.query_property() - DBClient.query = db_session.query_property() + for table in (DBPlugin, DBClient): + table.db = db_engine + table.t = table.__table__ + table.c = table.t.c + table.column_names = table.c.keys() - return cast(Session, db_session) + return db_engine diff --git a/maubot/instance.py b/maubot/instance.py index 7da5928..d13c436 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -13,16 +13,14 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional, Iterable, TYPE_CHECKING from asyncio import AbstractEventLoop -from aiohttp import web import os.path import logging import io from ruamel.yaml.comments import CommentedMap from ruamel.yaml import YAML -from sqlalchemy.orm import Session import sqlalchemy as sql from mautrix.util.config import BaseProxyConfig, RecursiveDict @@ -44,7 +42,6 @@ yaml.indent(4) class PluginInstance: - db: Session = None webserver: 'MaubotServer' = None mb_config: Config = None loop: AbstractEventLoop = None @@ -130,8 +127,7 @@ class PluginInstance: del self.cache[self.id] except KeyError: pass - self.db.delete(self.db_instance) - self.db.commit() + self.db_instance.delete() if self.inst_db: self.inst_db.dispose() ZippedPluginLoader.trash( @@ -207,14 +203,14 @@ class PluginInstance: try: return cls.cache[instance_id] except KeyError: - db_instance = db_instance or DBPlugin.query.get(instance_id) + db_instance = db_instance or DBPlugin.get(instance_id) if not db_instance: return None return PluginInstance(db_instance) @classmethod - def all(cls) -> List['PluginInstance']: - return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()] + def all(cls) -> Iterable['PluginInstance']: + return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all()) def update_id(self, new_id: str) -> None: if new_id is not None and new_id != self.id: @@ -293,9 +289,8 @@ class PluginInstance: # endregion -def init(db: Session, config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop) -> List[ - PluginInstance]: - PluginInstance.db = db +def init(config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop + ) -> Iterable[PluginInstance]: PluginInstance.mb_config = config PluginInstance.loop = loop PluginInstance.webserver = webserver diff --git a/maubot/lib/color_log.py b/maubot/lib/color_log.py new file mode 100644 index 0000000..be894ff --- /dev/null +++ b/maubot/lib/color_log.py @@ -0,0 +1,36 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2019 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from mautrix.util.color_log import (ColorFormatter as BaseColorFormatter, PREFIX, MAU_COLOR, + MXID_COLOR, RESET) + +INST_COLOR = PREFIX + "35m" # magenta +LOADER_COLOR = PREFIX + "36m" # blue + + +class ColorFormatter(BaseColorFormatter): + def _color_name(self, module: str) -> str: + client = "maubot.client" + if module.startswith(client): + return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}" + instance = "maubot.instance" + if module.startswith(instance): + return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}" + loader = "maubot.loader" + if module.startswith(loader): + return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}" + if module.startswith("maubot"): + return f"{MAU_COLOR}{module}{RESET}" + return super()._color_name(module) diff --git a/maubot/management/api/client.py b/maubot/management/api/client.py index a5e38d1..7bfccf5 100644 --- a/maubot/management/api/client.py +++ b/maubot/management/api/client.py @@ -68,8 +68,7 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response: displayname=data.get("displayname", ""), avatar_url=data.get("avatar_url", "")) client = Client(db_instance) - Client.db.add(db_instance) - Client.db.commit() + client.db_instance.insert() await client.start() return resp.created(client.to_dict()) diff --git a/maubot/management/api/instance.py b/maubot/management/api/instance.py index 944c41f..2114907 100644 --- a/maubot/management/api/instance.py +++ b/maubot/management/api/instance.py @@ -56,8 +56,7 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response: primary_user=primary_user, config=data.get("config", "")) instance = PluginInstance(db_instance) instance.load() - PluginInstance.db.add(db_instance) - PluginInstance.db.commit() + instance.db_instance.insert() await instance.start() return resp.created(instance.to_dict())