Refactor __main__.py and fix things
This commit is contained in:
parent
c685eb5e08
commit
7c9668d8bc
@ -1,5 +1,5 @@
|
|||||||
# maubot - A plugin-based Matrix bot system.
|
# maubot - A plugin-based Matrix bot system.
|
||||||
# Copyright (C) 2019 Tulir Asokan
|
# Copyright (C) 2021 Tulir Asokan
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
@ -13,12 +13,9 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
import logging.config
|
|
||||||
import argparse
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import signal
|
|
||||||
import copy
|
from mautrix.util.program import Program
|
||||||
import sys
|
|
||||||
|
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .db import init as init_db
|
from .db import init as init_db
|
||||||
@ -27,70 +24,58 @@ from .client import Client, init as init_client_class
|
|||||||
from .loader.zip import init as init_zip_loader
|
from .loader.zip import init as init_zip_loader
|
||||||
from .instance import init as init_plugin_instance_class
|
from .instance import init as init_plugin_instance_class
|
||||||
from .management.api import init as init_mgmt_api
|
from .management.api import init as init_mgmt_api
|
||||||
|
from .lib.future_awaitable import FutureAwaitable
|
||||||
from .__meta__ import __version__
|
from .__meta__ import __version__
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
|
|
||||||
prog="python -m maubot")
|
|
||||||
parser.add_argument("-c", "--config", type=str, default="config.yaml",
|
|
||||||
metavar="<path>", help="the path to your config file")
|
|
||||||
parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml",
|
|
||||||
metavar="<path>", help="the path to the example config "
|
|
||||||
"(for automatic config updates)")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
config = Config(args.config, args.base_config)
|
class Maubot(Program):
|
||||||
config.load()
|
config: Config
|
||||||
config.update()
|
server: MaubotServer
|
||||||
|
|
||||||
logging.config.dictConfig(copy.deepcopy(config["logging"]))
|
config_class = Config
|
||||||
|
module = "maubot"
|
||||||
|
name = "maubot"
|
||||||
|
version = __version__
|
||||||
|
command = "python -m maubot"
|
||||||
|
description = "A plugin-based Matrix bot system."
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
def prepare_log_websocket(self) -> None:
|
||||||
|
from .management.api.log import init, stop_all
|
||||||
|
init(self.loop)
|
||||||
|
self.add_shutdown_actions(FutureAwaitable(stop_all))
|
||||||
|
|
||||||
stop_log_listener = None
|
def prepare(self) -> None:
|
||||||
if config["api_features.log"]:
|
super().prepare()
|
||||||
from .management.api.log import init as init_log_listener, stop_all as stop_log_listener
|
|
||||||
|
|
||||||
init_log_listener(loop)
|
if self.config["api_features.log"]:
|
||||||
|
self.prepare_log_websocket()
|
||||||
|
|
||||||
log = logging.getLogger("maubot.init")
|
init_zip_loader(self.config)
|
||||||
log.info(f"Initializing maubot {__version__}")
|
init_db(self.config)
|
||||||
|
clients = init_client_class(self.config, self.loop)
|
||||||
|
self.add_startup_actions(*(client.start() for client in clients))
|
||||||
|
management_api = init_mgmt_api(self.config, self.loop)
|
||||||
|
self.server = MaubotServer(management_api, self.config, self.loop)
|
||||||
|
|
||||||
init_zip_loader(config)
|
plugins = init_plugin_instance_class(self.config, self.server, self.loop)
|
||||||
db_engine = init_db(config)
|
for plugin in plugins:
|
||||||
clients = init_client_class(config, loop)
|
|
||||||
management_api = init_mgmt_api(config, loop)
|
|
||||||
server = MaubotServer(management_api, config, loop)
|
|
||||||
plugins = init_plugin_instance_class(config, server, loop)
|
|
||||||
|
|
||||||
for plugin in plugins:
|
|
||||||
plugin.load()
|
plugin.load()
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, signal.default_int_handler)
|
async def start(self) -> None:
|
||||||
signal.signal(signal.SIGTERM, signal.default_int_handler)
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
log.info("Starting server")
|
|
||||||
loop.run_until_complete(server.start())
|
|
||||||
if Client.crypto_db:
|
if Client.crypto_db:
|
||||||
log.debug("Starting client crypto database")
|
self.log.debug("Starting client crypto database")
|
||||||
loop.run_until_complete(Client.crypto_db.start())
|
await Client.crypto_db.start()
|
||||||
log.info("Starting clients and plugins")
|
await super().start()
|
||||||
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
|
await self.server.start()
|
||||||
log.info("Startup actions complete, running forever")
|
|
||||||
loop.run_forever()
|
async def stop(self) -> None:
|
||||||
except KeyboardInterrupt:
|
self.add_shutdown_actions(*(client.stop() for client in Client.cache.values()))
|
||||||
log.info("Interrupt received, stopping clients")
|
await super().stop()
|
||||||
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()]))
|
self.log.debug("Stopping server")
|
||||||
if stop_log_listener is not None:
|
|
||||||
log.debug("Closing websockets")
|
|
||||||
loop.run_until_complete(stop_log_listener())
|
|
||||||
log.debug("Stopping server")
|
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(asyncio.wait_for(server.stop(), 5, loop=loop))
|
await asyncio.wait_for(self.server.stop(), 5)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
log.warning("Stopping server timed out")
|
self.log.warning("Stopping server timed out")
|
||||||
log.debug("Closing event loop")
|
|
||||||
loop.close()
|
|
||||||
log.debug("Everything stopped, shutting down")
|
Maubot().run()
|
||||||
sys.exit(0)
|
|
||||||
|
@ -14,17 +14,15 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
|
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
|
||||||
from os import path
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
from mautrix.errors import MatrixInvalidToken
|
||||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||||
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
|
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
|
||||||
PresenceState, StateFilter)
|
PresenceState, StateFilter, DeviceID)
|
||||||
from mautrix.client import InternalEventType
|
from mautrix.client import InternalEventType
|
||||||
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
|
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
|
||||||
|
|
||||||
@ -33,13 +31,12 @@ from .db import DBClient
|
|||||||
from .matrix import MaubotMatrixClient
|
from .matrix import MaubotMatrixClient
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
|
from mautrix.crypto import OlmMachine, StateStore as CryptoStateStore, CryptoStore
|
||||||
PickleCryptoStore)
|
|
||||||
|
|
||||||
|
|
||||||
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
|
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
|
||||||
pass
|
pass
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
|
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
|
||||||
SQLStateStore = BaseSQLStateStore
|
SQLStateStore = BaseSQLStateStore
|
||||||
|
|
||||||
@ -63,8 +60,7 @@ class Client:
|
|||||||
cache: Dict[UserID, 'Client'] = {}
|
cache: Dict[UserID, 'Client'] = {}
|
||||||
http_client: ClientSession = None
|
http_client: ClientSession = None
|
||||||
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
|
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
|
||||||
crypto_pickle_dir: str = None
|
crypto_db: Optional['AsyncDatabase'] = None
|
||||||
crypto_db: 'AsyncDatabase' = None
|
|
||||||
|
|
||||||
references: Set['PluginInstance']
|
references: Set['PluginInstance']
|
||||||
db_instance: DBClient
|
db_instance: DBClient
|
||||||
@ -90,7 +86,7 @@ class Client:
|
|||||||
log=self.log, loop=self.loop, device_id=self.device_id,
|
log=self.log, loop=self.loop, device_id=self.device_id,
|
||||||
sync_store=SyncStoreProxy(self.db_instance),
|
sync_store=SyncStoreProxy(self.db_instance),
|
||||||
state_store=self.global_state_store)
|
state_store=self.global_state_store)
|
||||||
if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
|
if OlmMachine and self.device_id and self.crypto_db:
|
||||||
self.crypto_store = self._make_crypto_store()
|
self.crypto_store = self._make_crypto_store()
|
||||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||||
self.client.crypto = self.crypto
|
self.client.crypto = self.crypto
|
||||||
@ -109,9 +105,6 @@ class Client:
|
|||||||
def _make_crypto_store(self) -> 'CryptoStore':
|
def _make_crypto_store(self) -> 'CryptoStore':
|
||||||
if self.crypto_db:
|
if self.crypto_db:
|
||||||
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
|
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
|
||||||
elif self.crypto_pickle_dir:
|
|
||||||
return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
|
|
||||||
path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
|
|
||||||
raise ValueError("Crypto database not configured")
|
raise ValueError("Crypto database not configured")
|
||||||
|
|
||||||
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
|
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
|
||||||
@ -330,7 +323,7 @@ class Client:
|
|||||||
return self.db_instance.access_token
|
return self.db_instance.access_token
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device_id(self) -> str:
|
def device_id(self) -> DeviceID:
|
||||||
return self.db_instance.device_id
|
return self.db_instance.device_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -403,25 +396,9 @@ def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
|||||||
Client.loop = loop
|
Client.loop = loop
|
||||||
|
|
||||||
if OlmMachine:
|
if OlmMachine:
|
||||||
db_type = config["crypto_database.type"]
|
db_url = config["crypto_database"]
|
||||||
if db_type == "default":
|
if db_url == "default":
|
||||||
db_url = config["database"]
|
db_url = config["database"]
|
||||||
parsed_url = URL(db_url)
|
Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
|
||||||
if parsed_url.scheme == "sqlite":
|
|
||||||
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
|
|
||||||
elif parsed_url.scheme == "postgres" or parsed_url.scheme == "postgresql":
|
|
||||||
if not PgCryptoStore:
|
|
||||||
log.warning("Default database is postgres, but asyncpg is not installed. "
|
|
||||||
"Encryption will not work.")
|
|
||||||
else:
|
|
||||||
Client.crypto_db = AsyncDatabase(url=db_url,
|
|
||||||
upgrade_table=PgCryptoStore.upgrade_table)
|
|
||||||
elif db_type == "pickle":
|
|
||||||
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
|
|
||||||
elif (db_type == "postgres" or db_type == "postgresql") and PgCryptoStore:
|
|
||||||
Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
|
|
||||||
upgrade_table=PgCryptoStore.upgrade_table)
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported crypto database type")
|
|
||||||
|
|
||||||
return Client.all()
|
return Client.all()
|
||||||
|
@ -32,9 +32,11 @@ class Config(BaseFileConfig):
|
|||||||
base = helper.base
|
base = helper.base
|
||||||
copy = helper.copy
|
copy = helper.copy
|
||||||
copy("database")
|
copy("database")
|
||||||
copy("crypto_database.type")
|
if isinstance(self["crypto_database"], dict):
|
||||||
copy("crypto_database.postgres_uri")
|
if self["crypto_database.type"] == "postgres":
|
||||||
copy("crypto_database.pickle_dir")
|
base["crypto_database"] = self["crypto_database.postgres_uri"]
|
||||||
|
else:
|
||||||
|
copy("crypto_database")
|
||||||
copy("plugin_directories.upload")
|
copy("plugin_directories.upload")
|
||||||
copy("plugin_directories.load")
|
copy("plugin_directories.load")
|
||||||
copy("plugin_directories.trash")
|
copy("plugin_directories.trash")
|
||||||
|
@ -2,22 +2,11 @@
|
|||||||
# Other DBMSes supported by SQLAlchemy may or may not work.
|
# Other DBMSes supported by SQLAlchemy may or may not work.
|
||||||
# Format examples:
|
# Format examples:
|
||||||
# SQLite: sqlite:///filename.db
|
# SQLite: sqlite:///filename.db
|
||||||
# Postgres: postgres://username:password@hostname/dbname
|
# Postgres: postgresql://username:password@hostname/dbname
|
||||||
database: sqlite:///maubot.db
|
database: sqlite:///maubot.db
|
||||||
|
|
||||||
# Database for encryption data.
|
# Separate database URL for the crypto database. "default" means use the same database as above.
|
||||||
crypto_database:
|
crypto_database: default
|
||||||
# Type of database. Either "default", "pickle" or "postgres".
|
|
||||||
# When set to default, using SQLite as the main database will use pickle as the crypto database
|
|
||||||
# and using Postgres as the main database will use the same one as the crypto database.
|
|
||||||
#
|
|
||||||
# When using pickle, individual crypto databases are stored in the pickle_dir directory.
|
|
||||||
# When using non-default postgres, postgres_uri is used to connect to postgres.
|
|
||||||
#
|
|
||||||
# WARNING: The pickle database is dangerous and should not be used in production.
|
|
||||||
type: default
|
|
||||||
postgres_uri: postgres://username:password@hostname/dbname
|
|
||||||
pickle_dir: ./crypto
|
|
||||||
|
|
||||||
plugin_directories:
|
plugin_directories:
|
||||||
# The directory where uploaded new plugins should be stored.
|
# The directory where uploaded new plugins should be stored.
|
9
maubot/lib/future_awaitable.py
Normal file
9
maubot/lib/future_awaitable.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from typing import Callable, Awaitable, Generator, Any
|
||||||
|
|
||||||
|
class FutureAwaitable:
|
||||||
|
def __init__(self, func: Callable[[], Awaitable[None]]) -> None:
|
||||||
|
self._func = func
|
||||||
|
|
||||||
|
def __await__(self) -> Generator[Any, None, None]:
|
||||||
|
return self._func().__await__()
|
||||||
|
|
@ -93,6 +93,7 @@ def init(loop: asyncio.AbstractEventLoop) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def stop_all() -> None:
|
async def stop_all() -> None:
|
||||||
|
log.debug("Closing log listener websockets")
|
||||||
log_root.removeHandler(handler)
|
log_root.removeHandler(handler)
|
||||||
for socket in sockets:
|
for socket in sockets:
|
||||||
try:
|
try:
|
||||||
|
@ -3,9 +3,10 @@
|
|||||||
|
|
||||||
#/postgres
|
#/postgres
|
||||||
psycopg2-binary>=2,<3
|
psycopg2-binary>=2,<3
|
||||||
|
asyncpg>=0.20,<0.26
|
||||||
|
|
||||||
#/e2be
|
#/e2be
|
||||||
asyncpg>=0.20,<0.25
|
aiosqlite>=0.16,<0.18
|
||||||
python-olm>=3,<4
|
python-olm>=3,<4
|
||||||
pycryptodome>=3,<4
|
pycryptodome>=3,<4
|
||||||
unpaddedbase64>=1,<2
|
unpaddedbase64>=1,<2
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
mautrix>=0.10.9,<0.11
|
mautrix>=0.11,<0.12
|
||||||
aiohttp>=3,<4
|
aiohttp>=3,<4
|
||||||
yarl>=1,<2
|
yarl>=1,<2
|
||||||
SQLAlchemy>=1,<1.4
|
SQLAlchemy>=1,<1.4
|
||||||
|
9
setup.py
9
setup.py
@ -57,15 +57,18 @@ setuptools.setup(
|
|||||||
mbc=maubot.cli:app
|
mbc=maubot.cli:app
|
||||||
""",
|
""",
|
||||||
data_files=[
|
data_files=[
|
||||||
(".", ["example-config.yaml", "alembic.ini"]),
|
(".", ["maubot/example-config.yaml", "alembic.ini"]),
|
||||||
("alembic", ["alembic/env.py"]),
|
("alembic", ["alembic/env.py"]),
|
||||||
("alembic/versions", glob.glob("alembic/versions/*.py")),
|
("alembic/versions", glob.glob("alembic/versions/*.py")),
|
||||||
],
|
],
|
||||||
package_data={
|
package_data={
|
||||||
"maubot": ["management/frontend/build/*",
|
"maubot": [
|
||||||
|
"example-config.yaml",
|
||||||
|
"management/frontend/build/*",
|
||||||
"management/frontend/build/static/css/*",
|
"management/frontend/build/static/css/*",
|
||||||
"management/frontend/build/static/js/*",
|
"management/frontend/build/static/js/*",
|
||||||
"management/frontend/build/static/media/*"],
|
"management/frontend/build/static/media/*",
|
||||||
|
],
|
||||||
"maubot.cli": ["res/*"],
|
"maubot.cli": ["res/*"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user