diff --git a/example-config.yaml b/example-config.yaml
index 6579918..9e7aa77 100644
--- a/example-config.yaml
+++ b/example-config.yaml
@@ -72,18 +72,21 @@ api_features:
logging:
version: 1
formatters:
- precise:
+ colored:
+ (): maubot.lib.color_log.ColorFormatter
+ format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
+ normal:
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
handlers:
file:
class: logging.handlers.RotatingFileHandler
- formatter: precise
- filename: ./logs/maubot.log
+ formatter: normal
+ filename: ./maubot.log
maxBytes: 10485760
backupCount: 10
console:
class: logging.StreamHandler
- formatter: precise
+ formatter: colored
loggers:
maubot:
level: DEBUG
diff --git a/maubot/__main__.py b/maubot/__main__.py
index 6c0a96a..412201c 100644
--- a/maubot/__main__.py
+++ b/maubot/__main__.py
@@ -56,11 +56,11 @@ log.info(f"Initializing maubot {__version__}")
loop = asyncio.get_event_loop()
init_zip_loader(config)
-db_session = init_db(config)
-clients = init_client_class(db_session, loop)
+db_engine = init_db(config)
+clients = init_client_class(loop)
management_api = init_mgmt_api(config, loop)
server = MaubotServer(management_api, config, loop)
-plugins = init_plugin_instance_class(db_session, config, server, loop)
+plugins = init_plugin_instance_class(config, server, loop)
for plugin in plugins:
plugin.load()
@@ -69,30 +69,17 @@ signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
-async def periodic_commit():
- while True:
- await asyncio.sleep(60)
- db_session.commit()
-
-
-periodic_commit_task: asyncio.Future = None
-
try:
log.info("Starting server")
loop.run_until_complete(server.start())
log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop))
log.info("Startup actions complete, running forever")
- periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
loop.run_forever()
except KeyboardInterrupt:
- log.info("Interrupt received, stopping HTTP clients/servers and saving database")
- if periodic_commit_task is not None:
- periodic_commit_task.cancel()
- log.debug("Stopping clients")
+ log.info("Interrupt received, stopping clients")
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()],
loop=loop))
- db_session.commit()
if stop_log_listener is not None:
log.debug("Closing websockets")
loop.run_until_complete(stop_log_listener())
diff --git a/maubot/client.py b/maubot/client.py
index 2b6adfe..9e1c9cc 100644
--- a/maubot/client.py
+++ b/maubot/client.py
@@ -13,11 +13,10 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict, List, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
+from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
import asyncio
import logging
-from sqlalchemy.orm import Session
from aiohttp import ClientSession
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
@@ -35,7 +34,6 @@ log = logging.getLogger("maubot.client")
class Client:
- db: Session = None
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
@@ -148,9 +146,7 @@ class Client:
def clear_cache(self) -> None:
self.stop_sync()
- self.db_instance.filter_id = ""
- self.db_instance.next_batch = ""
- self.db.commit()
+ self.db_instance.edit(filter_id="", next_batch="")
self.start_sync()
def delete(self) -> None:
@@ -158,8 +154,7 @@ class Client:
del self.cache[self.id]
except KeyError:
pass
- self.db.delete(self.db_instance)
- self.db.commit()
+ self.db_instance.delete()
def to_dict(self) -> dict:
return {
@@ -183,14 +178,14 @@ class Client:
try:
return cls.cache[user_id]
except KeyError:
- db_instance = db_instance or DBClient.query.get(user_id)
+ db_instance = db_instance or DBClient.get(user_id)
if not db_instance:
return None
return Client(db_instance)
@classmethod
- def all(cls) -> List['Client']:
- return [cls.get(user.id, user) for user in DBClient.query.all()]
+ def all(cls) -> Iterable['Client']:
+ return (cls.get(user.id, user) for user in DBClient.all())
async def _handle_invite(self, evt: StrippedStateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
@@ -314,8 +309,7 @@ class Client:
# endregion
-def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]:
- Client.db = db
+def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
return Client.all()
diff --git a/maubot/db.py b/maubot/db.py
index b1ef598..36cba6d 100644
--- a/maubot/db.py
+++ b/maubot/db.py
@@ -13,22 +13,19 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import cast
+from typing import Iterable, Optional
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
-from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.engine.base import Engine
import sqlalchemy as sql
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
+from mautrix.bridge.db import Base
from .config import Config
-Base: declarative_base = declarative_base()
-
class DBPlugin(Base):
- query: Query
__tablename__ = "plugin"
id: str = Column(String(255), primary_key=True)
@@ -39,9 +36,16 @@ class DBPlugin(Base):
nullable=False)
config: str = Column(Text, nullable=False, default='')
+ @classmethod
+ def all(cls) -> Iterable['DBPlugin']:
+ return cls._select_all()
+
+ @classmethod
+ def get(cls, id: str) -> Optional['DBPlugin']:
+ return cls._select_one_or_none(cls.c.id == id)
+
class DBClient(Base):
- query: Query
__tablename__ = "client"
id: UserID = Column(String(255), primary_key=True)
@@ -58,15 +62,23 @@ class DBClient(Base):
displayname: str = Column(String(255), nullable=False, default="")
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
+ @classmethod
+ def all(cls) -> Iterable['DBClient']:
+ return cls._select_all()
-def init(config: Config) -> Session:
- db_engine: sql.engine.Engine = sql.create_engine(config["database"])
- db_factory = sessionmaker(bind=db_engine)
- db_session = scoped_session(db_factory)
+ @classmethod
+ def get(cls, id: str) -> Optional['DBClient']:
+ return cls._select_one_or_none(cls.c.id == id)
+
+
+def init(config: Config) -> Engine:
+ db_engine = sql.create_engine(config["database"])
Base.metadata.bind = db_engine
- Base.metadata.create_all()
- DBPlugin.query = db_session.query_property()
- DBClient.query = db_session.query_property()
+ for table in (DBPlugin, DBClient):
+ table.db = db_engine
+ table.t = table.__table__
+ table.c = table.t.c
+ table.column_names = table.c.keys()
- return cast(Session, db_session)
+ return db_engine
diff --git a/maubot/instance.py b/maubot/instance.py
index 7da5928..d13c436 100644
--- a/maubot/instance.py
+++ b/maubot/instance.py
@@ -13,16 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict, List, Optional, TYPE_CHECKING
+from typing import Dict, List, Optional, Iterable, TYPE_CHECKING
from asyncio import AbstractEventLoop
-from aiohttp import web
import os.path
import logging
import io
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML
-from sqlalchemy.orm import Session
import sqlalchemy as sql
from mautrix.util.config import BaseProxyConfig, RecursiveDict
@@ -44,7 +42,6 @@ yaml.indent(4)
class PluginInstance:
- db: Session = None
webserver: 'MaubotServer' = None
mb_config: Config = None
loop: AbstractEventLoop = None
@@ -130,8 +127,7 @@ class PluginInstance:
del self.cache[self.id]
except KeyError:
pass
- self.db.delete(self.db_instance)
- self.db.commit()
+ self.db_instance.delete()
if self.inst_db:
self.inst_db.dispose()
ZippedPluginLoader.trash(
@@ -207,14 +203,14 @@ class PluginInstance:
try:
return cls.cache[instance_id]
except KeyError:
- db_instance = db_instance or DBPlugin.query.get(instance_id)
+ db_instance = db_instance or DBPlugin.get(instance_id)
if not db_instance:
return None
return PluginInstance(db_instance)
@classmethod
- def all(cls) -> List['PluginInstance']:
- return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
+ def all(cls) -> Iterable['PluginInstance']:
+ return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
def update_id(self, new_id: str) -> None:
if new_id is not None and new_id != self.id:
@@ -293,9 +289,8 @@ class PluginInstance:
# endregion
-def init(db: Session, config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop) -> List[
- PluginInstance]:
- PluginInstance.db = db
+def init(config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop
+ ) -> Iterable[PluginInstance]:
PluginInstance.mb_config = config
PluginInstance.loop = loop
PluginInstance.webserver = webserver
diff --git a/maubot/lib/color_log.py b/maubot/lib/color_log.py
new file mode 100644
index 0000000..be894ff
--- /dev/null
+++ b/maubot/lib/color_log.py
@@ -0,0 +1,36 @@
+# maubot - A plugin-based Matrix bot system.
+# Copyright (C) 2019 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from mautrix.util.color_log import (ColorFormatter as BaseColorFormatter, PREFIX, MAU_COLOR,
+ MXID_COLOR, RESET)
+
+INST_COLOR = PREFIX + "35m" # magenta
+LOADER_COLOR = PREFIX + "36m" # blue
+
+
+class ColorFormatter(BaseColorFormatter):
+ def _color_name(self, module: str) -> str:
+ client = "maubot.client"
+ if module.startswith(client):
+ return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}"
+ instance = "maubot.instance"
+ if module.startswith(instance):
+ return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
+ loader = "maubot.loader"
+ if module.startswith(loader):
+ return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
+ if module.startswith("maubot"):
+ return f"{MAU_COLOR}{module}{RESET}"
+ return super()._color_name(module)
diff --git a/maubot/management/api/client.py b/maubot/management/api/client.py
index a5e38d1..7bfccf5 100644
--- a/maubot/management/api/client.py
+++ b/maubot/management/api/client.py
@@ -68,8 +68,7 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
displayname=data.get("displayname", ""),
avatar_url=data.get("avatar_url", ""))
client = Client(db_instance)
- Client.db.add(db_instance)
- Client.db.commit()
+ client.db_instance.insert()
await client.start()
return resp.created(client.to_dict())
diff --git a/maubot/management/api/instance.py b/maubot/management/api/instance.py
index 944c41f..2114907 100644
--- a/maubot/management/api/instance.py
+++ b/maubot/management/api/instance.py
@@ -56,8 +56,7 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
primary_user=primary_user, config=data.get("config", ""))
instance = PluginInstance(db_instance)
instance.load()
- PluginInstance.db.add(db_instance)
- PluginInstance.db.commit()
+ instance.db_instance.insert()
await instance.start()
return resp.created(instance.to_dict())