Stop using SQLAlchemy ORM and add colorful logs

This commit is contained in:
Tulir Asokan 2019-09-01 14:46:08 +03:00
parent 59998b99b1
commit b59eab2953
8 changed files with 90 additions and 65 deletions

View File

@ -72,18 +72,21 @@ api_features:
logging: logging:
version: 1 version: 1
formatters: formatters:
precise: colored:
(): maubot.lib.color_log.ColorFormatter
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
normal:
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
handlers: handlers:
file: file:
class: logging.handlers.RotatingFileHandler class: logging.handlers.RotatingFileHandler
formatter: precise formatter: normal
filename: ./logs/maubot.log filename: ./maubot.log
maxBytes: 10485760 maxBytes: 10485760
backupCount: 10 backupCount: 10
console: console:
class: logging.StreamHandler class: logging.StreamHandler
formatter: precise formatter: colored
loggers: loggers:
maubot: maubot:
level: DEBUG level: DEBUG

View File

@ -56,11 +56,11 @@ log.info(f"Initializing maubot {__version__}")
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
init_zip_loader(config) init_zip_loader(config)
db_session = init_db(config) db_engine = init_db(config)
clients = init_client_class(db_session, loop) clients = init_client_class(loop)
management_api = init_mgmt_api(config, loop) management_api = init_mgmt_api(config, loop)
server = MaubotServer(management_api, config, loop) server = MaubotServer(management_api, config, loop)
plugins = init_plugin_instance_class(db_session, config, server, loop) plugins = init_plugin_instance_class(config, server, loop)
for plugin in plugins: for plugin in plugins:
plugin.load() plugin.load()
@ -69,30 +69,17 @@ signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler) signal.signal(signal.SIGTERM, signal.default_int_handler)
async def periodic_commit():
while True:
await asyncio.sleep(60)
db_session.commit()
periodic_commit_task: asyncio.Future = None
try: try:
log.info("Starting server") log.info("Starting server")
loop.run_until_complete(server.start()) loop.run_until_complete(server.start())
log.info("Starting clients and plugins") log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop)) loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop))
log.info("Startup actions complete, running forever") log.info("Startup actions complete, running forever")
periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
loop.run_forever() loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
log.info("Interrupt received, stopping HTTP clients/servers and saving database") log.info("Interrupt received, stopping clients")
if periodic_commit_task is not None:
periodic_commit_task.cancel()
log.debug("Stopping clients")
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()], loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()],
loop=loop)) loop=loop))
db_session.commit()
if stop_log_listener is not None: if stop_log_listener is not None:
log.debug("Closing websockets") log.debug("Closing websockets")
loop.run_until_complete(stop_log_listener()) loop.run_until_complete(stop_log_listener())

View File

@ -13,11 +13,10 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
import asyncio import asyncio
import logging import logging
from sqlalchemy.orm import Session
from aiohttp import ClientSession from aiohttp import ClientSession
from mautrix.errors import MatrixInvalidToken, MatrixRequestError from mautrix.errors import MatrixInvalidToken, MatrixRequestError
@ -35,7 +34,6 @@ log = logging.getLogger("maubot.client")
class Client: class Client:
db: Session = None
log: logging.Logger = None log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {} cache: Dict[UserID, 'Client'] = {}
@ -148,9 +146,7 @@ class Client:
def clear_cache(self) -> None: def clear_cache(self) -> None:
self.stop_sync() self.stop_sync()
self.db_instance.filter_id = "" self.db_instance.edit(filter_id="", next_batch="")
self.db_instance.next_batch = ""
self.db.commit()
self.start_sync() self.start_sync()
def delete(self) -> None: def delete(self) -> None:
@ -158,8 +154,7 @@ class Client:
del self.cache[self.id] del self.cache[self.id]
except KeyError: except KeyError:
pass pass
self.db.delete(self.db_instance) self.db_instance.delete()
self.db.commit()
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
@ -183,14 +178,14 @@ class Client:
try: try:
return cls.cache[user_id] return cls.cache[user_id]
except KeyError: except KeyError:
db_instance = db_instance or DBClient.query.get(user_id) db_instance = db_instance or DBClient.get(user_id)
if not db_instance: if not db_instance:
return None return None
return Client(db_instance) return Client(db_instance)
@classmethod @classmethod
def all(cls) -> List['Client']: def all(cls) -> Iterable['Client']:
return [cls.get(user.id, user) for user in DBClient.query.all()] return (cls.get(user.id, user) for user in DBClient.all())
async def _handle_invite(self, evt: StrippedStateEvent) -> None: async def _handle_invite(self, evt: StrippedStateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE: if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
@ -314,8 +309,7 @@ class Client:
# endregion # endregion
def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]: def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.db = db
Client.http_client = ClientSession(loop=loop) Client.http_client = ClientSession(loop=loop)
Client.loop = loop Client.loop = loop
return Client.all() return Client.all()

View File

@ -13,22 +13,19 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import cast from typing import Iterable, Optional
from sqlalchemy import Column, String, Boolean, ForeignKey, Text from sqlalchemy import Column, String, Boolean, ForeignKey, Text
from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.declarative import declarative_base
import sqlalchemy as sql import sqlalchemy as sql
from mautrix.types import UserID, FilterID, SyncToken, ContentURI from mautrix.types import UserID, FilterID, SyncToken, ContentURI
from mautrix.bridge.db import Base
from .config import Config from .config import Config
Base: declarative_base = declarative_base()
class DBPlugin(Base): class DBPlugin(Base):
query: Query
__tablename__ = "plugin" __tablename__ = "plugin"
id: str = Column(String(255), primary_key=True) id: str = Column(String(255), primary_key=True)
@ -39,9 +36,16 @@ class DBPlugin(Base):
nullable=False) nullable=False)
config: str = Column(Text, nullable=False, default='') config: str = Column(Text, nullable=False, default='')
@classmethod
def all(cls) -> Iterable['DBPlugin']:
return cls._select_all()
@classmethod
def get(cls, id: str) -> Optional['DBPlugin']:
return cls._select_one_or_none(cls.c.id == id)
class DBClient(Base): class DBClient(Base):
query: Query
__tablename__ = "client" __tablename__ = "client"
id: UserID = Column(String(255), primary_key=True) id: UserID = Column(String(255), primary_key=True)
@ -58,15 +62,23 @@ class DBClient(Base):
displayname: str = Column(String(255), nullable=False, default="") displayname: str = Column(String(255), nullable=False, default="")
avatar_url: ContentURI = Column(String(255), nullable=False, default="") avatar_url: ContentURI = Column(String(255), nullable=False, default="")
@classmethod
def all(cls) -> Iterable['DBClient']:
return cls._select_all()
def init(config: Config) -> Session: @classmethod
db_engine: sql.engine.Engine = sql.create_engine(config["database"]) def get(cls, id: str) -> Optional['DBClient']:
db_factory = sessionmaker(bind=db_engine) return cls._select_one_or_none(cls.c.id == id)
db_session = scoped_session(db_factory)
def init(config: Config) -> Engine:
db_engine = sql.create_engine(config["database"])
Base.metadata.bind = db_engine Base.metadata.bind = db_engine
Base.metadata.create_all()
DBPlugin.query = db_session.query_property() for table in (DBPlugin, DBClient):
DBClient.query = db_session.query_property() table.db = db_engine
table.t = table.__table__
table.c = table.t.c
table.column_names = table.c.keys()
return cast(Session, db_session) return db_engine

View File

@ -13,16 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, TYPE_CHECKING from typing import Dict, List, Optional, Iterable, TYPE_CHECKING
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from aiohttp import web
import os.path import os.path
import logging import logging
import io import io
from ruamel.yaml.comments import CommentedMap from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML from ruamel.yaml import YAML
from sqlalchemy.orm import Session
import sqlalchemy as sql import sqlalchemy as sql
from mautrix.util.config import BaseProxyConfig, RecursiveDict from mautrix.util.config import BaseProxyConfig, RecursiveDict
@ -44,7 +42,6 @@ yaml.indent(4)
class PluginInstance: class PluginInstance:
db: Session = None
webserver: 'MaubotServer' = None webserver: 'MaubotServer' = None
mb_config: Config = None mb_config: Config = None
loop: AbstractEventLoop = None loop: AbstractEventLoop = None
@ -130,8 +127,7 @@ class PluginInstance:
del self.cache[self.id] del self.cache[self.id]
except KeyError: except KeyError:
pass pass
self.db.delete(self.db_instance) self.db_instance.delete()
self.db.commit()
if self.inst_db: if self.inst_db:
self.inst_db.dispose() self.inst_db.dispose()
ZippedPluginLoader.trash( ZippedPluginLoader.trash(
@ -207,14 +203,14 @@ class PluginInstance:
try: try:
return cls.cache[instance_id] return cls.cache[instance_id]
except KeyError: except KeyError:
db_instance = db_instance or DBPlugin.query.get(instance_id) db_instance = db_instance or DBPlugin.get(instance_id)
if not db_instance: if not db_instance:
return None return None
return PluginInstance(db_instance) return PluginInstance(db_instance)
@classmethod @classmethod
def all(cls) -> List['PluginInstance']: def all(cls) -> Iterable['PluginInstance']:
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()] return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
def update_id(self, new_id: str) -> None: def update_id(self, new_id: str) -> None:
if new_id is not None and new_id != self.id: if new_id is not None and new_id != self.id:
@ -293,9 +289,8 @@ class PluginInstance:
# endregion # endregion
def init(db: Session, config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop) -> List[ def init(config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop
PluginInstance]: ) -> Iterable[PluginInstance]:
PluginInstance.db = db
PluginInstance.mb_config = config PluginInstance.mb_config = config
PluginInstance.loop = loop PluginInstance.loop = loop
PluginInstance.webserver = webserver PluginInstance.webserver = webserver

36
maubot/lib/color_log.py Normal file
View File

@ -0,0 +1,36 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.color_log import (ColorFormatter as BaseColorFormatter, PREFIX, MAU_COLOR,
MXID_COLOR, RESET)
INST_COLOR = PREFIX + "35m" # magenta
LOADER_COLOR = PREFIX + "36m" # blue
class ColorFormatter(BaseColorFormatter):
def _color_name(self, module: str) -> str:
client = "maubot.client"
if module.startswith(client):
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}"
instance = "maubot.instance"
if module.startswith(instance):
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
loader = "maubot.loader"
if module.startswith(loader):
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
if module.startswith("maubot"):
return f"{MAU_COLOR}{module}{RESET}"
return super()._color_name(module)

View File

@ -68,8 +68,7 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
displayname=data.get("displayname", ""), displayname=data.get("displayname", ""),
avatar_url=data.get("avatar_url", "")) avatar_url=data.get("avatar_url", ""))
client = Client(db_instance) client = Client(db_instance)
Client.db.add(db_instance) client.db_instance.insert()
Client.db.commit()
await client.start() await client.start()
return resp.created(client.to_dict()) return resp.created(client.to_dict())

View File

@ -56,8 +56,7 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
primary_user=primary_user, config=data.get("config", "")) primary_user=primary_user, config=data.get("config", ""))
instance = PluginInstance(db_instance) instance = PluginInstance(db_instance)
instance.load() instance.load()
PluginInstance.db.add(db_instance) instance.db_instance.insert()
PluginInstance.db.commit()
await instance.start() await instance.start()
return resp.created(instance.to_dict()) return resp.created(instance.to_dict())