diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..daa36da --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,4 @@ +include README.md +include LICENSE +include requirements.txt +include optional-requirements.txt diff --git a/alembic/versions/90aa88820eab_add_matrix_state_store.py b/alembic/versions/90aa88820eab_add_matrix_state_store.py new file mode 100644 index 0000000..37a68eb --- /dev/null +++ b/alembic/versions/90aa88820eab_add_matrix_state_store.py @@ -0,0 +1,47 @@ +"""Add Matrix state store + +Revision ID: 90aa88820eab +Revises: 4b93300852aa +Create Date: 2020-07-12 01:50:06.215623 + +""" +from alembic import op +import sqlalchemy as sa + +from mautrix.client.state_store.sqlalchemy import SerializableType +from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent + + +# revision identifiers, used by Alembic. +revision = '90aa88820eab' +down_revision = '4b93300852aa' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('mx_room_state', + sa.Column('room_id', sa.String(length=255), nullable=False), + sa.Column('is_encrypted', sa.Boolean(), nullable=True), + sa.Column('has_full_member_list', sa.Boolean(), nullable=True), + sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True), + sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True), + sa.PrimaryKeyConstraint('room_id') + ) + op.create_table('mx_user_profile', + sa.Column('room_id', sa.String(length=255), nullable=False), + sa.Column('user_id', sa.String(length=255), nullable=False), + sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False), + sa.Column('displayname', sa.String(), nullable=True), + sa.Column('avatar_url', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('room_id', 'user_id') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('mx_user_profile') + op.drop_table('mx_room_state') + # ### end Alembic commands ### diff --git a/example-config.yaml b/example-config.yaml index 9e7aa77..908f4a4 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -5,6 +5,18 @@ # Postgres: postgres://username:password@hostname/dbname database: sqlite:///maubot.db +# Database for encryption data. +crypto_database: + # Type of database. Either "default", "pickle" or "postgres". + # When set to default, using SQLite as the main database will use pickle as the crypto database + # and using Postgres as the main database will use the same one as the crypto database. + # + # When using pickle, individual crypto databases are stored in the pickle_dir directory. + # When using non-default postgres, postgres_uri is used to connect to postgres. + type: default + postgres_uri: postgres://username:password@hostname/dbname + pickle_dir: ./crypto + plugin_directories: # The directory where uploaded new plugins should be stored. upload: ./plugins diff --git a/maubot/__main__.py b/maubot/__main__.py index de78526..2ef73f9 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -57,7 +57,7 @@ log.info(f"Initializing maubot {__version__}") init_zip_loader(config) db_engine = init_db(config) -clients = init_client_class(loop) +clients = init_client_class(config, loop) management_api = init_mgmt_api(config, loop) server = MaubotServer(management_api, config, loop) plugins = init_plugin_instance_class(config, server, loop) @@ -72,6 +72,9 @@ signal.signal(signal.SIGTERM, signal.default_int_handler) try: log.info("Starting server") loop.run_until_complete(server.start()) + if Client.crypto_db: + log.debug("Starting client crypto database") + loop.run_until_complete(Client.crypto_db.start()) log.info("Starting clients and plugins") loop.run_until_complete(asyncio.gather(*[client.start() for client in clients])) log.info("Startup actions complete, running forever") diff --git a/maubot/cli/commands/build.py b/maubot/cli/commands/build.py index c0969dd..9e3724c 100644 --- a/maubot/cli/commands/build.py +++ b/maubot/cli/commands/build.py @@ -18,12 +18,13 @@ from io import BytesIO import zipfile import os -from mautrix.client.api.types.util import SerializerError from ruamel.yaml import YAML, YAMLError from colorama import Fore from PyInquirer import prompt import click +from mautrix.types import SerializerError + from ...loader import PluginMeta from ..cliq.validators import PathValidator from ..base import app diff --git a/maubot/cli/commands/logs.py b/maubot/cli/commands/logs.py index d16b68e..8d0a578 100644 --- a/maubot/cli/commands/logs.py +++ b/maubot/cli/commands/logs.py @@ -18,9 +18,10 @@ import asyncio from colorama import Fore from aiohttp import WSMsgType, WSMessage, ClientSession -from mautrix.client.api.types.util import Obj import click +from mautrix.types import Obj + from ..config import get_token from ..base import app diff --git a/maubot/client.py b/maubot/client.py index 6eab38a..73da264 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -13,24 +13,46 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING +from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING +from os import path import asyncio import logging from aiohttp import ClientSession +from yarl import URL from mautrix.errors import MatrixInvalidToken, MatrixRequestError from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership, StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter, PresenceState, StateFilter) from mautrix.client import InternalEventType +from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore -from .lib.store_proxy import ClientStoreProxy +from .lib.store_proxy import SyncStoreProxy from .db import DBClient from .matrix import MaubotMatrixClient +try: + from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore, + PickleCryptoStore) + + + class SQLStateStore(BaseSQLStateStore, CryptoStateStore): + pass +except ImportError: + OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None + SQLStateStore = BaseSQLStateStore + +try: + from mautrix.util.async_db import Database as AsyncDatabase + from mautrix.crypto import PgCryptoStore +except ImportError: + AsyncDatabase = None + PgCryptoStore = None + if TYPE_CHECKING: from .instance import PluginInstance + from .config import Config log = logging.getLogger("maubot.client") @@ -40,10 +62,15 @@ class Client: loop: asyncio.AbstractEventLoop = None cache: Dict[UserID, 'Client'] = {} http_client: ClientSession = None + global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore() + crypto_pickle_dir: str = None + crypto_db: 'AsyncDatabase' = None references: Set['PluginInstance'] db_instance: DBClient client: MaubotMatrixClient + crypto: Optional['OlmMachine'] + crypto_store: Optional['CryptoStore'] started: bool remote_displayname: Optional[str] @@ -61,7 +88,15 @@ class Client: self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver, token=self.access_token, client_session=self.http_client, log=self.log, loop=self.loop, device_id=self.device_id, - store=ClientStoreProxy(self.db_instance)) + sync_store=SyncStoreProxy(self.db_instance), + state_store=self.global_state_store) + if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir): + self.crypto_store = self._make_crypto_store() + self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store) + self.client.crypto = self.crypto + else: + self.crypto_store = None + self.crypto = None self.client.ignore_initial_sync = True self.client.ignore_first_sync = True self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE @@ -71,6 +106,14 @@ class Client: self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False)) self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True)) + def _make_crypto_store(self) -> 'CryptoStore': + if self.crypto_db: + return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db) + elif self.crypto_pickle_dir: + return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto", + path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle")) + raise ValueError("Crypto database not configured") + def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]: async def handler(data: Dict[str, Any]) -> None: self.sync_ok = ok @@ -130,6 +173,16 @@ class Client: await self.client.set_displayname(self.displayname) if self.avatar_url != "disable": await self.client.set_avatar_url(self.avatar_url) + if self.crypto: + self.log.debug("Enabling end-to-end encryption support") + await self.crypto_store.open() + crypto_device_id = await self.crypto_store.get_device_id() + if crypto_device_id and crypto_device_id != self.device_id: + self.log.warning("Mismatching device ID in crypto store and main database. " + "Encryption may not work.") + await self.crypto.load() + if not crypto_device_id: + await self.crypto_store.put_device_id(self.device_id) self.start_sync() await self._update_remote_profile() self.started = True @@ -154,6 +207,8 @@ class Client: self.started = False await self.stop_plugins() self.stop_sync() + if self.crypto: + await self.crypto_store.close() def clear_cache(self) -> None: self.stop_sync() @@ -172,6 +227,7 @@ class Client: "id": self.id, "homeserver": self.homeserver, "access_token": self.access_token, + "device_id": self.device_id, "enabled": self.enabled, "started": self.started, "sync": self.sync, @@ -243,11 +299,12 @@ class Client: return new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver, token=access_token or self.access_token, loop=self.loop, - client_session=self.http_client, log=self.log) + client_session=self.http_client, device_id=self.device_id, + log=self.log, state_store=self.global_state_store) mxid = await new_client.whoami() if mxid != self.id: raise ValueError(f"MXID mismatch: {mxid}") - new_client.store = self.db_instance + new_client.sync_store = self.db_instance self.stop_sync() self.client = new_client self.db_instance.homeserver = homeserver @@ -341,7 +398,30 @@ class Client: # endregion -def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]: +def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]: Client.http_client = ClientSession(loop=loop) Client.loop = loop + + if OlmMachine: + db_type = config["crypto_database.type"] + if db_type == "default": + db_url = config["database"] + parsed_url = URL(db_url) + if parsed_url.scheme == "sqlite": + Client.crypto_pickle_dir = config["crypto_database.pickle_dir"] + elif parsed_url.scheme == "postgres": + if not PgCryptoStore: + log.warning("Default database is postgres, but asyncpg is not installed. " + "Encryption will not work.") + else: + Client.crypto_db = AsyncDatabase(url=db_url, + upgrade_table=PgCryptoStore.upgrade_table) + elif db_type == "pickle": + Client.crypto_pickle_dir = config["crypto_database.pickle_dir"] + elif db_type == "postgres" and PgCryptoStore: + Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"], + upgrade_table=PgCryptoStore.upgrade_table) + else: + raise ValueError("Unsupported crypto database type") + return Client.all() diff --git a/maubot/config.py b/maubot/config.py index 3901dad..34466cc 100644 --- a/maubot/config.py +++ b/maubot/config.py @@ -32,6 +32,9 @@ class Config(BaseFileConfig): base = helper.base copy = helper.copy copy("database") + copy("crypto_database.type") + copy("crypto_database.postgres_uri") + copy("crypto_database.pickle_dir") copy("plugin_directories.upload") copy("plugin_directories.load") copy("plugin_directories.trash") diff --git a/maubot/db.py b/maubot/db.py index 1a6b9fb..3817882 100644 --- a/maubot/db.py +++ b/maubot/db.py @@ -23,6 +23,7 @@ import sqlalchemy as sql from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI from mautrix.util.db import Base +from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile from .config import Config @@ -79,7 +80,7 @@ def init(config: Config) -> Engine: db = sql.create_engine(config["database"]) Base.metadata.bind = db - for table in (DBPlugin, DBClient): + for table in (DBPlugin, DBClient, RoomState, UserProfile): table.bind(db) if not db.has_table("alembic_version"): diff --git a/maubot/lib/store_proxy.py b/maubot/lib/store_proxy.py index e9e3696..6e402aa 100644 --- a/maubot/lib/store_proxy.py +++ b/maubot/lib/store_proxy.py @@ -13,11 +13,11 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from mautrix.client import ClientStore +from mautrix.client import SyncStore from mautrix.types import SyncToken -class ClientStoreProxy(ClientStore): +class SyncStoreProxy(SyncStore): def __init__(self, db_instance) -> None: self.db_instance = db_instance diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py index cb46a4c..01713f4 100644 --- a/maubot/loader/abc.py +++ b/maubot/loader/abc.py @@ -19,8 +19,8 @@ import asyncio from attr import dataclass from packaging.version import Version, InvalidVersion -from mautrix.client.api.types.util import (SerializableAttrs, SerializerError, serializer, - deserializer) + +from mautrix.types import SerializableAttrs, SerializerError, serializer, deserializer from ..__meta__ import __version__ from ..plugin_base import Plugin diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py index 7262bf0..72daf22 100644 --- a/maubot/loader/zip.py +++ b/maubot/loader/zip.py @@ -22,7 +22,8 @@ import os from ruamel.yaml import YAML, YAMLError from packaging.version import Version -from mautrix.client.api.types.util import SerializerError + +from mautrix.types import SerializerError from ..lib.zipimport import zipimporter, ZipImportError from ..plugin_base import Plugin diff --git a/maubot/matrix.py b/maubot/matrix.py index 1821a27..968541a 100644 --- a/maubot/matrix.py +++ b/maubot/matrix.py @@ -13,13 +13,14 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Union, Awaitable, Optional, Tuple +from typing import Union, Awaitable, Optional, Tuple, List from html import escape +import asyncio + import attr from mautrix.client import Client as MatrixClient, SyncStream -from mautrix.util.formatter import parse_html -from mautrix.util import markdown +from mautrix.util import markdown, formatter from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent, MessageType, TextMessageEventContent, Format, RelatesTo) @@ -32,7 +33,7 @@ def parse_formatted(message: str, allow_html: bool = False, render_markdown: boo html = message else: return message, escape(message) - return parse_html(html), html + return formatter.parse_html(html), html class MaubotMessageEvent(MessageEvent): @@ -110,12 +111,12 @@ class MaubotMatrixClient(MatrixClient): content.set_edit(edits) return self.send_message(room_id, content, **kwargs) - async def dispatch_event(self, event: Event, source: SyncStream = SyncStream.INTERNAL) -> None: + def dispatch_event(self, event: Event, source: SyncStream) -> List[asyncio.Task]: if isinstance(event, MessageEvent): event = MaubotMessageEvent(event, self) elif source != SyncStream.INTERNAL: event.client = self - return await super().dispatch_event(event, source) + return super().dispatch_event(event, source) async def get_event(self, room_id: RoomID, event_id: EventID) -> Event: event = await super().get_event(room_id, event_id) diff --git a/maubot/standalone/__main__.py b/maubot/standalone/__main__.py index b593610..bcc2c72 100644 --- a/maubot/standalone/__main__.py +++ b/maubot/standalone/__main__.py @@ -36,7 +36,7 @@ from .config import Config from ..plugin_base import Plugin from ..loader import PluginMeta from ..matrix import MaubotMatrixClient -from ..lib.store_proxy import ClientStoreProxy +from ..lib.store_proxy import SyncStoreProxy from ..__meta__ import __version__ parser = argparse.ArgumentParser( @@ -143,7 +143,7 @@ async def main(): global client, bot client = MaubotMatrixClient(mxid=user_id, base_url=homeserver, token=access_token, - client_session=http_client, loop=loop, store=ClientStoreProxy(nb), + client_session=http_client, loop=loop, store=SyncStoreProxy(nb), log=logging.getLogger("maubot.client").getChild(user_id)) while True: diff --git a/optional-requirements.txt b/optional-requirements.txt new file mode 100644 index 0000000..ffbf672 --- /dev/null +++ b/optional-requirements.txt @@ -0,0 +1,11 @@ +# Format: #/name defines a new extras_require group called name +# Uncommented lines after the group definition insert things into that group. + +#/postgres +psycopg2-binary>=2,<3 + +#/e2be +asyncpg>=0.20,<0.21 +python-olm>=3,<4 +pycryptodome>=3,<4 +unpaddedbase64>=1,<2 diff --git a/requirements.txt b/requirements.txt index e37ad38..cd71bc2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -mautrix==0.6.0.beta7 +mautrix==0.6.0rc1 aiohttp>=3,<4 SQLAlchemy>=1,<2 alembic>=1,<2 diff --git a/setup.py b/setup.py index a1f0ba4..c2cc680 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,19 @@ import os with open("requirements.txt") as reqs: install_requires = reqs.read().splitlines() +with open("optional-requirements.txt") as reqs: + extras_require = {} + current = [] + for line in reqs.read().splitlines(): + if line.startswith("#/"): + extras_require[line[2:]] = current = [] + elif not line or line.startswith("#"): + continue + else: + current.append(line) + +extras_require["all"] = list({dep for deps in extras_require.values() for dep in deps}) + path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "maubot", "__meta__.py") __version__ = "UNKNOWN" with open(path) as f: @@ -25,6 +38,7 @@ setuptools.setup( packages=setuptools.find_packages(), install_requires=install_requires, + extras_require=extras_require, classifiers=[ "Development Status :: 3 - Alpha",