diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..daa36da
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,4 @@
+include README.md
+include LICENSE
+include requirements.txt
+include optional-requirements.txt
diff --git a/alembic/versions/90aa88820eab_add_matrix_state_store.py b/alembic/versions/90aa88820eab_add_matrix_state_store.py
new file mode 100644
index 0000000..37a68eb
--- /dev/null
+++ b/alembic/versions/90aa88820eab_add_matrix_state_store.py
@@ -0,0 +1,47 @@
+"""Add Matrix state store
+
+Revision ID: 90aa88820eab
+Revises: 4b93300852aa
+Create Date: 2020-07-12 01:50:06.215623
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+from mautrix.client.state_store.sqlalchemy import SerializableType
+from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent
+
+
+# revision identifiers, used by Alembic.
+revision = '90aa88820eab'
+down_revision = '4b93300852aa'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('mx_room_state',
+ sa.Column('room_id', sa.String(length=255), nullable=False),
+ sa.Column('is_encrypted', sa.Boolean(), nullable=True),
+ sa.Column('has_full_member_list', sa.Boolean(), nullable=True),
+ sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True),
+ sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True),
+ sa.PrimaryKeyConstraint('room_id')
+ )
+ op.create_table('mx_user_profile',
+ sa.Column('room_id', sa.String(length=255), nullable=False),
+ sa.Column('user_id', sa.String(length=255), nullable=False),
+ sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False),
+ sa.Column('displayname', sa.String(), nullable=True),
+ sa.Column('avatar_url', sa.String(length=255), nullable=True),
+ sa.PrimaryKeyConstraint('room_id', 'user_id')
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table('mx_user_profile')
+ op.drop_table('mx_room_state')
+ # ### end Alembic commands ###
diff --git a/example-config.yaml b/example-config.yaml
index 9e7aa77..908f4a4 100644
--- a/example-config.yaml
+++ b/example-config.yaml
@@ -5,6 +5,18 @@
# Postgres: postgres://username:password@hostname/dbname
database: sqlite:///maubot.db
+# Database for encryption data.
+crypto_database:
+ # Type of database. Either "default", "pickle" or "postgres".
+ # When set to default, using SQLite as the main database will use pickle as the crypto database
+ # and using Postgres as the main database will use the same one as the crypto database.
+ #
+ # When using pickle, individual crypto databases are stored in the pickle_dir directory.
+ # When using non-default postgres, postgres_uri is used to connect to postgres.
+ type: default
+ postgres_uri: postgres://username:password@hostname/dbname
+ pickle_dir: ./crypto
+
plugin_directories:
# The directory where uploaded new plugins should be stored.
upload: ./plugins
diff --git a/maubot/__main__.py b/maubot/__main__.py
index de78526..2ef73f9 100644
--- a/maubot/__main__.py
+++ b/maubot/__main__.py
@@ -57,7 +57,7 @@ log.info(f"Initializing maubot {__version__}")
init_zip_loader(config)
db_engine = init_db(config)
-clients = init_client_class(loop)
+clients = init_client_class(config, loop)
management_api = init_mgmt_api(config, loop)
server = MaubotServer(management_api, config, loop)
plugins = init_plugin_instance_class(config, server, loop)
@@ -72,6 +72,9 @@ signal.signal(signal.SIGTERM, signal.default_int_handler)
try:
log.info("Starting server")
loop.run_until_complete(server.start())
+ if Client.crypto_db:
+ log.debug("Starting client crypto database")
+ loop.run_until_complete(Client.crypto_db.start())
log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
log.info("Startup actions complete, running forever")
diff --git a/maubot/cli/commands/build.py b/maubot/cli/commands/build.py
index c0969dd..9e3724c 100644
--- a/maubot/cli/commands/build.py
+++ b/maubot/cli/commands/build.py
@@ -18,12 +18,13 @@ from io import BytesIO
import zipfile
import os
-from mautrix.client.api.types.util import SerializerError
from ruamel.yaml import YAML, YAMLError
from colorama import Fore
from PyInquirer import prompt
import click
+from mautrix.types import SerializerError
+
from ...loader import PluginMeta
from ..cliq.validators import PathValidator
from ..base import app
diff --git a/maubot/cli/commands/logs.py b/maubot/cli/commands/logs.py
index d16b68e..8d0a578 100644
--- a/maubot/cli/commands/logs.py
+++ b/maubot/cli/commands/logs.py
@@ -18,9 +18,10 @@ import asyncio
from colorama import Fore
from aiohttp import WSMsgType, WSMessage, ClientSession
-from mautrix.client.api.types.util import Obj
import click
+from mautrix.types import Obj
+
from ..config import get_token
from ..base import app
diff --git a/maubot/client.py b/maubot/client.py
index 6eab38a..73da264 100644
--- a/maubot/client.py
+++ b/maubot/client.py
@@ -13,24 +13,46 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
+from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
+from os import path
import asyncio
import logging
from aiohttp import ClientSession
+from yarl import URL
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
PresenceState, StateFilter)
from mautrix.client import InternalEventType
+from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
-from .lib.store_proxy import ClientStoreProxy
+from .lib.store_proxy import SyncStoreProxy
from .db import DBClient
from .matrix import MaubotMatrixClient
+try:
+ from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
+ PickleCryptoStore)
+
+
+ class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
+ pass
+except ImportError:
+ OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
+ SQLStateStore = BaseSQLStateStore
+
+try:
+ from mautrix.util.async_db import Database as AsyncDatabase
+ from mautrix.crypto import PgCryptoStore
+except ImportError:
+ AsyncDatabase = None
+ PgCryptoStore = None
+
if TYPE_CHECKING:
from .instance import PluginInstance
+ from .config import Config
log = logging.getLogger("maubot.client")
@@ -40,10 +62,15 @@ class Client:
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
+ global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
+ crypto_pickle_dir: str = None
+ crypto_db: 'AsyncDatabase' = None
references: Set['PluginInstance']
db_instance: DBClient
client: MaubotMatrixClient
+ crypto: Optional['OlmMachine']
+ crypto_store: Optional['CryptoStore']
started: bool
remote_displayname: Optional[str]
@@ -61,7 +88,15 @@ class Client:
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client,
log=self.log, loop=self.loop, device_id=self.device_id,
- store=ClientStoreProxy(self.db_instance))
+ sync_store=SyncStoreProxy(self.db_instance),
+ state_store=self.global_state_store)
+ if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
+ self.crypto_store = self._make_crypto_store()
+ self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
+ self.client.crypto = self.crypto
+ else:
+ self.crypto_store = None
+ self.crypto = None
self.client.ignore_initial_sync = True
self.client.ignore_first_sync = True
self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE
@@ -71,6 +106,14 @@ class Client:
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
+ def _make_crypto_store(self) -> 'CryptoStore':
+ if self.crypto_db:
+ return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
+ elif self.crypto_pickle_dir:
+ return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
+ path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
+ raise ValueError("Crypto database not configured")
+
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
async def handler(data: Dict[str, Any]) -> None:
self.sync_ok = ok
@@ -130,6 +173,16 @@ class Client:
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
+ if self.crypto:
+ self.log.debug("Enabling end-to-end encryption support")
+ await self.crypto_store.open()
+ crypto_device_id = await self.crypto_store.get_device_id()
+ if crypto_device_id and crypto_device_id != self.device_id:
+ self.log.warning("Mismatching device ID in crypto store and main database. "
+ "Encryption may not work.")
+ await self.crypto.load()
+ if not crypto_device_id:
+ await self.crypto_store.put_device_id(self.device_id)
self.start_sync()
await self._update_remote_profile()
self.started = True
@@ -154,6 +207,8 @@ class Client:
self.started = False
await self.stop_plugins()
self.stop_sync()
+ if self.crypto:
+ await self.crypto_store.close()
def clear_cache(self) -> None:
self.stop_sync()
@@ -172,6 +227,7 @@ class Client:
"id": self.id,
"homeserver": self.homeserver,
"access_token": self.access_token,
+ "device_id": self.device_id,
"enabled": self.enabled,
"started": self.started,
"sync": self.sync,
@@ -243,11 +299,12 @@ class Client:
return
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
token=access_token or self.access_token, loop=self.loop,
- client_session=self.http_client, log=self.log)
+ client_session=self.http_client, device_id=self.device_id,
+ log=self.log, state_store=self.global_state_store)
mxid = await new_client.whoami()
if mxid != self.id:
raise ValueError(f"MXID mismatch: {mxid}")
- new_client.store = self.db_instance
+ new_client.sync_store = self.db_instance
self.stop_sync()
self.client = new_client
self.db_instance.homeserver = homeserver
@@ -341,7 +398,30 @@ class Client:
# endregion
-def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
+def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
+
+ if OlmMachine:
+ db_type = config["crypto_database.type"]
+ if db_type == "default":
+ db_url = config["database"]
+ parsed_url = URL(db_url)
+ if parsed_url.scheme == "sqlite":
+ Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
+ elif parsed_url.scheme == "postgres":
+ if not PgCryptoStore:
+ log.warning("Default database is postgres, but asyncpg is not installed. "
+ "Encryption will not work.")
+ else:
+ Client.crypto_db = AsyncDatabase(url=db_url,
+ upgrade_table=PgCryptoStore.upgrade_table)
+ elif db_type == "pickle":
+ Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
+ elif db_type == "postgres" and PgCryptoStore:
+ Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
+ upgrade_table=PgCryptoStore.upgrade_table)
+ else:
+ raise ValueError("Unsupported crypto database type")
+
return Client.all()
diff --git a/maubot/config.py b/maubot/config.py
index 3901dad..34466cc 100644
--- a/maubot/config.py
+++ b/maubot/config.py
@@ -32,6 +32,9 @@ class Config(BaseFileConfig):
base = helper.base
copy = helper.copy
copy("database")
+ copy("crypto_database.type")
+ copy("crypto_database.postgres_uri")
+ copy("crypto_database.pickle_dir")
copy("plugin_directories.upload")
copy("plugin_directories.load")
copy("plugin_directories.trash")
diff --git a/maubot/db.py b/maubot/db.py
index 1a6b9fb..3817882 100644
--- a/maubot/db.py
+++ b/maubot/db.py
@@ -23,6 +23,7 @@ import sqlalchemy as sql
from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI
from mautrix.util.db import Base
+from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
from .config import Config
@@ -79,7 +80,7 @@ def init(config: Config) -> Engine:
db = sql.create_engine(config["database"])
Base.metadata.bind = db
- for table in (DBPlugin, DBClient):
+ for table in (DBPlugin, DBClient, RoomState, UserProfile):
table.bind(db)
if not db.has_table("alembic_version"):
diff --git a/maubot/lib/store_proxy.py b/maubot/lib/store_proxy.py
index e9e3696..6e402aa 100644
--- a/maubot/lib/store_proxy.py
+++ b/maubot/lib/store_proxy.py
@@ -13,11 +13,11 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from mautrix.client import ClientStore
+from mautrix.client import SyncStore
from mautrix.types import SyncToken
-class ClientStoreProxy(ClientStore):
+class SyncStoreProxy(SyncStore):
def __init__(self, db_instance) -> None:
self.db_instance = db_instance
diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py
index cb46a4c..01713f4 100644
--- a/maubot/loader/abc.py
+++ b/maubot/loader/abc.py
@@ -19,8 +19,8 @@ import asyncio
from attr import dataclass
from packaging.version import Version, InvalidVersion
-from mautrix.client.api.types.util import (SerializableAttrs, SerializerError, serializer,
- deserializer)
+
+from mautrix.types import SerializableAttrs, SerializerError, serializer, deserializer
from ..__meta__ import __version__
from ..plugin_base import Plugin
diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py
index 7262bf0..72daf22 100644
--- a/maubot/loader/zip.py
+++ b/maubot/loader/zip.py
@@ -22,7 +22,8 @@ import os
from ruamel.yaml import YAML, YAMLError
from packaging.version import Version
-from mautrix.client.api.types.util import SerializerError
+
+from mautrix.types import SerializerError
from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin
diff --git a/maubot/matrix.py b/maubot/matrix.py
index 1821a27..968541a 100644
--- a/maubot/matrix.py
+++ b/maubot/matrix.py
@@ -13,13 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Union, Awaitable, Optional, Tuple
+from typing import Union, Awaitable, Optional, Tuple, List
from html import escape
+import asyncio
+
import attr
from mautrix.client import Client as MatrixClient, SyncStream
-from mautrix.util.formatter import parse_html
-from mautrix.util import markdown
+from mautrix.util import markdown, formatter
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
MessageType, TextMessageEventContent, Format, RelatesTo)
@@ -32,7 +33,7 @@ def parse_formatted(message: str, allow_html: bool = False, render_markdown: boo
html = message
else:
return message, escape(message)
- return parse_html(html), html
+ return formatter.parse_html(html), html
class MaubotMessageEvent(MessageEvent):
@@ -110,12 +111,12 @@ class MaubotMatrixClient(MatrixClient):
content.set_edit(edits)
return self.send_message(room_id, content, **kwargs)
- async def dispatch_event(self, event: Event, source: SyncStream = SyncStream.INTERNAL) -> None:
+ def dispatch_event(self, event: Event, source: SyncStream) -> List[asyncio.Task]:
if isinstance(event, MessageEvent):
event = MaubotMessageEvent(event, self)
elif source != SyncStream.INTERNAL:
event.client = self
- return await super().dispatch_event(event, source)
+ return super().dispatch_event(event, source)
async def get_event(self, room_id: RoomID, event_id: EventID) -> Event:
event = await super().get_event(room_id, event_id)
diff --git a/maubot/standalone/__main__.py b/maubot/standalone/__main__.py
index b593610..bcc2c72 100644
--- a/maubot/standalone/__main__.py
+++ b/maubot/standalone/__main__.py
@@ -36,7 +36,7 @@ from .config import Config
from ..plugin_base import Plugin
from ..loader import PluginMeta
from ..matrix import MaubotMatrixClient
-from ..lib.store_proxy import ClientStoreProxy
+from ..lib.store_proxy import SyncStoreProxy
from ..__meta__ import __version__
parser = argparse.ArgumentParser(
@@ -143,7 +143,7 @@ async def main():
global client, bot
client = MaubotMatrixClient(mxid=user_id, base_url=homeserver, token=access_token,
- client_session=http_client, loop=loop, store=ClientStoreProxy(nb),
+ client_session=http_client, loop=loop, store=SyncStoreProxy(nb),
log=logging.getLogger("maubot.client").getChild(user_id))
while True:
diff --git a/optional-requirements.txt b/optional-requirements.txt
new file mode 100644
index 0000000..ffbf672
--- /dev/null
+++ b/optional-requirements.txt
@@ -0,0 +1,11 @@
+# Format: #/name defines a new extras_require group called name
+# Uncommented lines after the group definition insert things into that group.
+
+#/postgres
+psycopg2-binary>=2,<3
+
+#/e2be
+asyncpg>=0.20,<0.21
+python-olm>=3,<4
+pycryptodome>=3,<4
+unpaddedbase64>=1,<2
diff --git a/requirements.txt b/requirements.txt
index e37ad38..cd71bc2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-mautrix==0.6.0.beta7
+mautrix==0.6.0rc1
aiohttp>=3,<4
SQLAlchemy>=1,<2
alembic>=1,<2
diff --git a/setup.py b/setup.py
index a1f0ba4..c2cc680 100644
--- a/setup.py
+++ b/setup.py
@@ -5,6 +5,19 @@ import os
with open("requirements.txt") as reqs:
install_requires = reqs.read().splitlines()
+with open("optional-requirements.txt") as reqs:
+ extras_require = {}
+ current = []
+ for line in reqs.read().splitlines():
+ if line.startswith("#/"):
+ extras_require[line[2:]] = current = []
+ elif not line or line.startswith("#"):
+ continue
+ else:
+ current.append(line)
+
+extras_require["all"] = list({dep for deps in extras_require.values() for dep in deps})
+
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "maubot", "__meta__.py")
__version__ = "UNKNOWN"
with open(path) as f:
@@ -25,6 +38,7 @@ setuptools.setup(
packages=setuptools.find_packages(),
install_requires=install_requires,
+ extras_require=extras_require,
classifiers=[
"Development Status :: 3 - Alpha",