Add support for end-to-end encryption. Fixes #46

This commit is contained in:
Tulir Asokan 2020-07-12 14:55:41 +03:00
parent 4e767a10e4
commit 69d7a4341b
17 changed files with 203 additions and 24 deletions

4
MANIFEST.in Normal file
View File

@ -0,0 +1,4 @@
include README.md
include LICENSE
include requirements.txt
include optional-requirements.txt

View File

@ -0,0 +1,47 @@
"""Add Matrix state store
Revision ID: 90aa88820eab
Revises: 4b93300852aa
Create Date: 2020-07-12 01:50:06.215623
"""
from alembic import op
import sqlalchemy as sa
from mautrix.client.state_store.sqlalchemy import SerializableType
from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent
# revision identifiers, used by Alembic.
revision = '90aa88820eab'
down_revision = '4b93300852aa'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('mx_room_state',
sa.Column('room_id', sa.String(length=255), nullable=False),
sa.Column('is_encrypted', sa.Boolean(), nullable=True),
sa.Column('has_full_member_list', sa.Boolean(), nullable=True),
sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True),
sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True),
sa.PrimaryKeyConstraint('room_id')
)
op.create_table('mx_user_profile',
sa.Column('room_id', sa.String(length=255), nullable=False),
sa.Column('user_id', sa.String(length=255), nullable=False),
sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False),
sa.Column('displayname', sa.String(), nullable=True),
sa.Column('avatar_url', sa.String(length=255), nullable=True),
sa.PrimaryKeyConstraint('room_id', 'user_id')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('mx_user_profile')
op.drop_table('mx_room_state')
# ### end Alembic commands ###

View File

@ -5,6 +5,18 @@
# Postgres: postgres://username:password@hostname/dbname
database: sqlite:///maubot.db
# Database for encryption data.
crypto_database:
# Type of database. Either "default", "pickle" or "postgres".
# When set to default, using SQLite as the main database will use pickle as the crypto database
# and using Postgres as the main database will use the same one as the crypto database.
#
# When using pickle, individual crypto databases are stored in the pickle_dir directory.
# When using non-default postgres, postgres_uri is used to connect to postgres.
type: default
postgres_uri: postgres://username:password@hostname/dbname
pickle_dir: ./crypto
plugin_directories:
# The directory where uploaded new plugins should be stored.
upload: ./plugins

View File

@ -57,7 +57,7 @@ log.info(f"Initializing maubot {__version__}")
init_zip_loader(config)
db_engine = init_db(config)
clients = init_client_class(loop)
clients = init_client_class(config, loop)
management_api = init_mgmt_api(config, loop)
server = MaubotServer(management_api, config, loop)
plugins = init_plugin_instance_class(config, server, loop)
@ -72,6 +72,9 @@ signal.signal(signal.SIGTERM, signal.default_int_handler)
try:
log.info("Starting server")
loop.run_until_complete(server.start())
if Client.crypto_db:
log.debug("Starting client crypto database")
loop.run_until_complete(Client.crypto_db.start())
log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
log.info("Startup actions complete, running forever")

View File

@ -18,12 +18,13 @@ from io import BytesIO
import zipfile
import os
from mautrix.client.api.types.util import SerializerError
from ruamel.yaml import YAML, YAMLError
from colorama import Fore
from PyInquirer import prompt
import click
from mautrix.types import SerializerError
from ...loader import PluginMeta
from ..cliq.validators import PathValidator
from ..base import app

View File

@ -18,9 +18,10 @@ import asyncio
from colorama import Fore
from aiohttp import WSMsgType, WSMessage, ClientSession
from mautrix.client.api.types.util import Obj
import click
from mautrix.types import Obj
from ..config import get_token
from ..base import app

View File

@ -13,24 +13,46 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
from os import path
import asyncio
import logging
from aiohttp import ClientSession
from yarl import URL
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
PresenceState, StateFilter)
from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
from .lib.store_proxy import ClientStoreProxy
from .lib.store_proxy import SyncStoreProxy
from .db import DBClient
from .matrix import MaubotMatrixClient
try:
from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
PickleCryptoStore)
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass
except ImportError:
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
SQLStateStore = BaseSQLStateStore
try:
from mautrix.util.async_db import Database as AsyncDatabase
from mautrix.crypto import PgCryptoStore
except ImportError:
AsyncDatabase = None
PgCryptoStore = None
if TYPE_CHECKING:
from .instance import PluginInstance
from .config import Config
log = logging.getLogger("maubot.client")
@ -40,10 +62,15 @@ class Client:
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
crypto_pickle_dir: str = None
crypto_db: 'AsyncDatabase' = None
references: Set['PluginInstance']
db_instance: DBClient
client: MaubotMatrixClient
crypto: Optional['OlmMachine']
crypto_store: Optional['CryptoStore']
started: bool
remote_displayname: Optional[str]
@ -61,7 +88,15 @@ class Client:
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client,
log=self.log, loop=self.loop, device_id=self.device_id,
store=ClientStoreProxy(self.db_instance))
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store)
if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
self.crypto_store = self._make_crypto_store()
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
else:
self.crypto_store = None
self.crypto = None
self.client.ignore_initial_sync = True
self.client.ignore_first_sync = True
self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE
@ -71,6 +106,14 @@ class Client:
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
def _make_crypto_store(self) -> 'CryptoStore':
if self.crypto_db:
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
elif self.crypto_pickle_dir:
return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
raise ValueError("Crypto database not configured")
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
async def handler(data: Dict[str, Any]) -> None:
self.sync_ok = ok
@ -130,6 +173,16 @@ class Client:
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
if self.crypto:
self.log.debug("Enabling end-to-end encryption support")
await self.crypto_store.open()
crypto_device_id = await self.crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != self.device_id:
self.log.warning("Mismatching device ID in crypto store and main database. "
"Encryption may not work.")
await self.crypto.load()
if not crypto_device_id:
await self.crypto_store.put_device_id(self.device_id)
self.start_sync()
await self._update_remote_profile()
self.started = True
@ -154,6 +207,8 @@ class Client:
self.started = False
await self.stop_plugins()
self.stop_sync()
if self.crypto:
await self.crypto_store.close()
def clear_cache(self) -> None:
self.stop_sync()
@ -172,6 +227,7 @@ class Client:
"id": self.id,
"homeserver": self.homeserver,
"access_token": self.access_token,
"device_id": self.device_id,
"enabled": self.enabled,
"started": self.started,
"sync": self.sync,
@ -243,11 +299,12 @@ class Client:
return
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
token=access_token or self.access_token, loop=self.loop,
client_session=self.http_client, log=self.log)
client_session=self.http_client, device_id=self.device_id,
log=self.log, state_store=self.global_state_store)
mxid = await new_client.whoami()
if mxid != self.id:
raise ValueError(f"MXID mismatch: {mxid}")
new_client.store = self.db_instance
new_client.sync_store = self.db_instance
self.stop_sync()
self.client = new_client
self.db_instance.homeserver = homeserver
@ -341,7 +398,30 @@ class Client:
# endregion
def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
if OlmMachine:
db_type = config["crypto_database.type"]
if db_type == "default":
db_url = config["database"]
parsed_url = URL(db_url)
if parsed_url.scheme == "sqlite":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif parsed_url.scheme == "postgres":
if not PgCryptoStore:
log.warning("Default database is postgres, but asyncpg is not installed. "
"Encryption will not work.")
else:
Client.crypto_db = AsyncDatabase(url=db_url,
upgrade_table=PgCryptoStore.upgrade_table)
elif db_type == "pickle":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif db_type == "postgres" and PgCryptoStore:
Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
upgrade_table=PgCryptoStore.upgrade_table)
else:
raise ValueError("Unsupported crypto database type")
return Client.all()

View File

@ -32,6 +32,9 @@ class Config(BaseFileConfig):
base = helper.base
copy = helper.copy
copy("database")
copy("crypto_database.type")
copy("crypto_database.postgres_uri")
copy("crypto_database.pickle_dir")
copy("plugin_directories.upload")
copy("plugin_directories.load")
copy("plugin_directories.trash")

View File

@ -23,6 +23,7 @@ import sqlalchemy as sql
from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI
from mautrix.util.db import Base
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
from .config import Config
@ -79,7 +80,7 @@ def init(config: Config) -> Engine:
db = sql.create_engine(config["database"])
Base.metadata.bind = db
for table in (DBPlugin, DBClient):
for table in (DBPlugin, DBClient, RoomState, UserProfile):
table.bind(db)
if not db.has_table("alembic_version"):

View File

@ -13,11 +13,11 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.client import ClientStore
from mautrix.client import SyncStore
from mautrix.types import SyncToken
class ClientStoreProxy(ClientStore):
class SyncStoreProxy(SyncStore):
def __init__(self, db_instance) -> None:
self.db_instance = db_instance

View File

@ -19,8 +19,8 @@ import asyncio
from attr import dataclass
from packaging.version import Version, InvalidVersion
from mautrix.client.api.types.util import (SerializableAttrs, SerializerError, serializer,
deserializer)
from mautrix.types import SerializableAttrs, SerializerError, serializer, deserializer
from ..__meta__ import __version__
from ..plugin_base import Plugin

View File

@ -22,7 +22,8 @@ import os
from ruamel.yaml import YAML, YAMLError
from packaging.version import Version
from mautrix.client.api.types.util import SerializerError
from mautrix.types import SerializerError
from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin

View File

@ -13,13 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union, Awaitable, Optional, Tuple
from typing import Union, Awaitable, Optional, Tuple, List
from html import escape
import asyncio
import attr
from mautrix.client import Client as MatrixClient, SyncStream
from mautrix.util.formatter import parse_html
from mautrix.util import markdown
from mautrix.util import markdown, formatter
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
MessageType, TextMessageEventContent, Format, RelatesTo)
@ -32,7 +33,7 @@ def parse_formatted(message: str, allow_html: bool = False, render_markdown: boo
html = message
else:
return message, escape(message)
return parse_html(html), html
return formatter.parse_html(html), html
class MaubotMessageEvent(MessageEvent):
@ -110,12 +111,12 @@ class MaubotMatrixClient(MatrixClient):
content.set_edit(edits)
return self.send_message(room_id, content, **kwargs)
async def dispatch_event(self, event: Event, source: SyncStream = SyncStream.INTERNAL) -> None:
def dispatch_event(self, event: Event, source: SyncStream) -> List[asyncio.Task]:
if isinstance(event, MessageEvent):
event = MaubotMessageEvent(event, self)
elif source != SyncStream.INTERNAL:
event.client = self
return await super().dispatch_event(event, source)
return super().dispatch_event(event, source)
async def get_event(self, room_id: RoomID, event_id: EventID) -> Event:
event = await super().get_event(room_id, event_id)

View File

@ -36,7 +36,7 @@ from .config import Config
from ..plugin_base import Plugin
from ..loader import PluginMeta
from ..matrix import MaubotMatrixClient
from ..lib.store_proxy import ClientStoreProxy
from ..lib.store_proxy import SyncStoreProxy
from ..__meta__ import __version__
parser = argparse.ArgumentParser(
@ -143,7 +143,7 @@ async def main():
global client, bot
client = MaubotMatrixClient(mxid=user_id, base_url=homeserver, token=access_token,
client_session=http_client, loop=loop, store=ClientStoreProxy(nb),
client_session=http_client, loop=loop, store=SyncStoreProxy(nb),
log=logging.getLogger("maubot.client").getChild(user_id))
while True:

11
optional-requirements.txt Normal file
View File

@ -0,0 +1,11 @@
# Format: #/name defines a new extras_require group called name
# Uncommented lines after the group definition insert things into that group.
#/postgres
psycopg2-binary>=2,<3
#/e2be
asyncpg>=0.20,<0.21
python-olm>=3,<4
pycryptodome>=3,<4
unpaddedbase64>=1,<2

View File

@ -1,4 +1,4 @@
mautrix==0.6.0.beta7
mautrix==0.6.0rc1
aiohttp>=3,<4
SQLAlchemy>=1,<2
alembic>=1,<2

View File

@ -5,6 +5,19 @@ import os
with open("requirements.txt") as reqs:
install_requires = reqs.read().splitlines()
with open("optional-requirements.txt") as reqs:
extras_require = {}
current = []
for line in reqs.read().splitlines():
if line.startswith("#/"):
extras_require[line[2:]] = current = []
elif not line or line.startswith("#"):
continue
else:
current.append(line)
extras_require["all"] = list({dep for deps in extras_require.values() for dep in deps})
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "maubot", "__meta__.py")
__version__ = "UNKNOWN"
with open(path) as f:
@ -25,6 +38,7 @@ setuptools.setup(
packages=setuptools.find_packages(),
install_requires=install_requires,
extras_require=extras_require,
classifiers=[
"Development Status :: 3 - Alpha",