Add support for end-to-end encryption. Fixes #46
This commit is contained in:
parent
4e767a10e4
commit
69d7a4341b
4
MANIFEST.in
Normal file
4
MANIFEST.in
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
include README.md
|
||||||
|
include LICENSE
|
||||||
|
include requirements.txt
|
||||||
|
include optional-requirements.txt
|
47
alembic/versions/90aa88820eab_add_matrix_state_store.py
Normal file
47
alembic/versions/90aa88820eab_add_matrix_state_store.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
"""Add Matrix state store
|
||||||
|
|
||||||
|
Revision ID: 90aa88820eab
|
||||||
|
Revises: 4b93300852aa
|
||||||
|
Create Date: 2020-07-12 01:50:06.215623
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from mautrix.client.state_store.sqlalchemy import SerializableType
|
||||||
|
from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '90aa88820eab'
|
||||||
|
down_revision = '4b93300852aa'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('mx_room_state',
|
||||||
|
sa.Column('room_id', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('is_encrypted', sa.Boolean(), nullable=True),
|
||||||
|
sa.Column('has_full_member_list', sa.Boolean(), nullable=True),
|
||||||
|
sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True),
|
||||||
|
sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('room_id')
|
||||||
|
)
|
||||||
|
op.create_table('mx_user_profile',
|
||||||
|
sa.Column('room_id', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('user_id', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False),
|
||||||
|
sa.Column('displayname', sa.String(), nullable=True),
|
||||||
|
sa.Column('avatar_url', sa.String(length=255), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('room_id', 'user_id')
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table('mx_user_profile')
|
||||||
|
op.drop_table('mx_room_state')
|
||||||
|
# ### end Alembic commands ###
|
@ -5,6 +5,18 @@
|
|||||||
# Postgres: postgres://username:password@hostname/dbname
|
# Postgres: postgres://username:password@hostname/dbname
|
||||||
database: sqlite:///maubot.db
|
database: sqlite:///maubot.db
|
||||||
|
|
||||||
|
# Database for encryption data.
|
||||||
|
crypto_database:
|
||||||
|
# Type of database. Either "default", "pickle" or "postgres".
|
||||||
|
# When set to default, using SQLite as the main database will use pickle as the crypto database
|
||||||
|
# and using Postgres as the main database will use the same one as the crypto database.
|
||||||
|
#
|
||||||
|
# When using pickle, individual crypto databases are stored in the pickle_dir directory.
|
||||||
|
# When using non-default postgres, postgres_uri is used to connect to postgres.
|
||||||
|
type: default
|
||||||
|
postgres_uri: postgres://username:password@hostname/dbname
|
||||||
|
pickle_dir: ./crypto
|
||||||
|
|
||||||
plugin_directories:
|
plugin_directories:
|
||||||
# The directory where uploaded new plugins should be stored.
|
# The directory where uploaded new plugins should be stored.
|
||||||
upload: ./plugins
|
upload: ./plugins
|
||||||
|
@ -57,7 +57,7 @@ log.info(f"Initializing maubot {__version__}")
|
|||||||
|
|
||||||
init_zip_loader(config)
|
init_zip_loader(config)
|
||||||
db_engine = init_db(config)
|
db_engine = init_db(config)
|
||||||
clients = init_client_class(loop)
|
clients = init_client_class(config, loop)
|
||||||
management_api = init_mgmt_api(config, loop)
|
management_api = init_mgmt_api(config, loop)
|
||||||
server = MaubotServer(management_api, config, loop)
|
server = MaubotServer(management_api, config, loop)
|
||||||
plugins = init_plugin_instance_class(config, server, loop)
|
plugins = init_plugin_instance_class(config, server, loop)
|
||||||
@ -72,6 +72,9 @@ signal.signal(signal.SIGTERM, signal.default_int_handler)
|
|||||||
try:
|
try:
|
||||||
log.info("Starting server")
|
log.info("Starting server")
|
||||||
loop.run_until_complete(server.start())
|
loop.run_until_complete(server.start())
|
||||||
|
if Client.crypto_db:
|
||||||
|
log.debug("Starting client crypto database")
|
||||||
|
loop.run_until_complete(Client.crypto_db.start())
|
||||||
log.info("Starting clients and plugins")
|
log.info("Starting clients and plugins")
|
||||||
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
|
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
|
||||||
log.info("Startup actions complete, running forever")
|
log.info("Startup actions complete, running forever")
|
||||||
|
@ -18,12 +18,13 @@ from io import BytesIO
|
|||||||
import zipfile
|
import zipfile
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from mautrix.client.api.types.util import SerializerError
|
|
||||||
from ruamel.yaml import YAML, YAMLError
|
from ruamel.yaml import YAML, YAMLError
|
||||||
from colorama import Fore
|
from colorama import Fore
|
||||||
from PyInquirer import prompt
|
from PyInquirer import prompt
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
from mautrix.types import SerializerError
|
||||||
|
|
||||||
from ...loader import PluginMeta
|
from ...loader import PluginMeta
|
||||||
from ..cliq.validators import PathValidator
|
from ..cliq.validators import PathValidator
|
||||||
from ..base import app
|
from ..base import app
|
||||||
|
@ -18,9 +18,10 @@ import asyncio
|
|||||||
|
|
||||||
from colorama import Fore
|
from colorama import Fore
|
||||||
from aiohttp import WSMsgType, WSMessage, ClientSession
|
from aiohttp import WSMsgType, WSMessage, ClientSession
|
||||||
from mautrix.client.api.types.util import Obj
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
from mautrix.types import Obj
|
||||||
|
|
||||||
from ..config import get_token
|
from ..config import get_token
|
||||||
from ..base import app
|
from ..base import app
|
||||||
|
|
||||||
|
@ -13,24 +13,46 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
|
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
|
||||||
|
from os import path
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
||||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||||
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
|
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
|
||||||
PresenceState, StateFilter)
|
PresenceState, StateFilter)
|
||||||
from mautrix.client import InternalEventType
|
from mautrix.client import InternalEventType
|
||||||
|
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
|
||||||
|
|
||||||
from .lib.store_proxy import ClientStoreProxy
|
from .lib.store_proxy import SyncStoreProxy
|
||||||
from .db import DBClient
|
from .db import DBClient
|
||||||
from .matrix import MaubotMatrixClient
|
from .matrix import MaubotMatrixClient
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
|
||||||
|
PickleCryptoStore)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
|
||||||
|
pass
|
||||||
|
except ImportError:
|
||||||
|
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
|
||||||
|
SQLStateStore = BaseSQLStateStore
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mautrix.util.async_db import Database as AsyncDatabase
|
||||||
|
from mautrix.crypto import PgCryptoStore
|
||||||
|
except ImportError:
|
||||||
|
AsyncDatabase = None
|
||||||
|
PgCryptoStore = None
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .instance import PluginInstance
|
from .instance import PluginInstance
|
||||||
|
from .config import Config
|
||||||
|
|
||||||
log = logging.getLogger("maubot.client")
|
log = logging.getLogger("maubot.client")
|
||||||
|
|
||||||
@ -40,10 +62,15 @@ class Client:
|
|||||||
loop: asyncio.AbstractEventLoop = None
|
loop: asyncio.AbstractEventLoop = None
|
||||||
cache: Dict[UserID, 'Client'] = {}
|
cache: Dict[UserID, 'Client'] = {}
|
||||||
http_client: ClientSession = None
|
http_client: ClientSession = None
|
||||||
|
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
|
||||||
|
crypto_pickle_dir: str = None
|
||||||
|
crypto_db: 'AsyncDatabase' = None
|
||||||
|
|
||||||
references: Set['PluginInstance']
|
references: Set['PluginInstance']
|
||||||
db_instance: DBClient
|
db_instance: DBClient
|
||||||
client: MaubotMatrixClient
|
client: MaubotMatrixClient
|
||||||
|
crypto: Optional['OlmMachine']
|
||||||
|
crypto_store: Optional['CryptoStore']
|
||||||
started: bool
|
started: bool
|
||||||
|
|
||||||
remote_displayname: Optional[str]
|
remote_displayname: Optional[str]
|
||||||
@ -61,7 +88,15 @@ class Client:
|
|||||||
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
||||||
token=self.access_token, client_session=self.http_client,
|
token=self.access_token, client_session=self.http_client,
|
||||||
log=self.log, loop=self.loop, device_id=self.device_id,
|
log=self.log, loop=self.loop, device_id=self.device_id,
|
||||||
store=ClientStoreProxy(self.db_instance))
|
sync_store=SyncStoreProxy(self.db_instance),
|
||||||
|
state_store=self.global_state_store)
|
||||||
|
if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
|
||||||
|
self.crypto_store = self._make_crypto_store()
|
||||||
|
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||||
|
self.client.crypto = self.crypto
|
||||||
|
else:
|
||||||
|
self.crypto_store = None
|
||||||
|
self.crypto = None
|
||||||
self.client.ignore_initial_sync = True
|
self.client.ignore_initial_sync = True
|
||||||
self.client.ignore_first_sync = True
|
self.client.ignore_first_sync = True
|
||||||
self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE
|
self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE
|
||||||
@ -71,6 +106,14 @@ class Client:
|
|||||||
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
||||||
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
||||||
|
|
||||||
|
def _make_crypto_store(self) -> 'CryptoStore':
|
||||||
|
if self.crypto_db:
|
||||||
|
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
|
||||||
|
elif self.crypto_pickle_dir:
|
||||||
|
return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
|
||||||
|
path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
|
||||||
|
raise ValueError("Crypto database not configured")
|
||||||
|
|
||||||
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
|
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
|
||||||
async def handler(data: Dict[str, Any]) -> None:
|
async def handler(data: Dict[str, Any]) -> None:
|
||||||
self.sync_ok = ok
|
self.sync_ok = ok
|
||||||
@ -130,6 +173,16 @@ class Client:
|
|||||||
await self.client.set_displayname(self.displayname)
|
await self.client.set_displayname(self.displayname)
|
||||||
if self.avatar_url != "disable":
|
if self.avatar_url != "disable":
|
||||||
await self.client.set_avatar_url(self.avatar_url)
|
await self.client.set_avatar_url(self.avatar_url)
|
||||||
|
if self.crypto:
|
||||||
|
self.log.debug("Enabling end-to-end encryption support")
|
||||||
|
await self.crypto_store.open()
|
||||||
|
crypto_device_id = await self.crypto_store.get_device_id()
|
||||||
|
if crypto_device_id and crypto_device_id != self.device_id:
|
||||||
|
self.log.warning("Mismatching device ID in crypto store and main database. "
|
||||||
|
"Encryption may not work.")
|
||||||
|
await self.crypto.load()
|
||||||
|
if not crypto_device_id:
|
||||||
|
await self.crypto_store.put_device_id(self.device_id)
|
||||||
self.start_sync()
|
self.start_sync()
|
||||||
await self._update_remote_profile()
|
await self._update_remote_profile()
|
||||||
self.started = True
|
self.started = True
|
||||||
@ -154,6 +207,8 @@ class Client:
|
|||||||
self.started = False
|
self.started = False
|
||||||
await self.stop_plugins()
|
await self.stop_plugins()
|
||||||
self.stop_sync()
|
self.stop_sync()
|
||||||
|
if self.crypto:
|
||||||
|
await self.crypto_store.close()
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
def clear_cache(self) -> None:
|
||||||
self.stop_sync()
|
self.stop_sync()
|
||||||
@ -172,6 +227,7 @@ class Client:
|
|||||||
"id": self.id,
|
"id": self.id,
|
||||||
"homeserver": self.homeserver,
|
"homeserver": self.homeserver,
|
||||||
"access_token": self.access_token,
|
"access_token": self.access_token,
|
||||||
|
"device_id": self.device_id,
|
||||||
"enabled": self.enabled,
|
"enabled": self.enabled,
|
||||||
"started": self.started,
|
"started": self.started,
|
||||||
"sync": self.sync,
|
"sync": self.sync,
|
||||||
@ -243,11 +299,12 @@ class Client:
|
|||||||
return
|
return
|
||||||
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
|
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
|
||||||
token=access_token or self.access_token, loop=self.loop,
|
token=access_token or self.access_token, loop=self.loop,
|
||||||
client_session=self.http_client, log=self.log)
|
client_session=self.http_client, device_id=self.device_id,
|
||||||
|
log=self.log, state_store=self.global_state_store)
|
||||||
mxid = await new_client.whoami()
|
mxid = await new_client.whoami()
|
||||||
if mxid != self.id:
|
if mxid != self.id:
|
||||||
raise ValueError(f"MXID mismatch: {mxid}")
|
raise ValueError(f"MXID mismatch: {mxid}")
|
||||||
new_client.store = self.db_instance
|
new_client.sync_store = self.db_instance
|
||||||
self.stop_sync()
|
self.stop_sync()
|
||||||
self.client = new_client
|
self.client = new_client
|
||||||
self.db_instance.homeserver = homeserver
|
self.db_instance.homeserver = homeserver
|
||||||
@ -341,7 +398,30 @@ class Client:
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||||
Client.http_client = ClientSession(loop=loop)
|
Client.http_client = ClientSession(loop=loop)
|
||||||
Client.loop = loop
|
Client.loop = loop
|
||||||
|
|
||||||
|
if OlmMachine:
|
||||||
|
db_type = config["crypto_database.type"]
|
||||||
|
if db_type == "default":
|
||||||
|
db_url = config["database"]
|
||||||
|
parsed_url = URL(db_url)
|
||||||
|
if parsed_url.scheme == "sqlite":
|
||||||
|
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
|
||||||
|
elif parsed_url.scheme == "postgres":
|
||||||
|
if not PgCryptoStore:
|
||||||
|
log.warning("Default database is postgres, but asyncpg is not installed. "
|
||||||
|
"Encryption will not work.")
|
||||||
|
else:
|
||||||
|
Client.crypto_db = AsyncDatabase(url=db_url,
|
||||||
|
upgrade_table=PgCryptoStore.upgrade_table)
|
||||||
|
elif db_type == "pickle":
|
||||||
|
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
|
||||||
|
elif db_type == "postgres" and PgCryptoStore:
|
||||||
|
Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
|
||||||
|
upgrade_table=PgCryptoStore.upgrade_table)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported crypto database type")
|
||||||
|
|
||||||
return Client.all()
|
return Client.all()
|
||||||
|
@ -32,6 +32,9 @@ class Config(BaseFileConfig):
|
|||||||
base = helper.base
|
base = helper.base
|
||||||
copy = helper.copy
|
copy = helper.copy
|
||||||
copy("database")
|
copy("database")
|
||||||
|
copy("crypto_database.type")
|
||||||
|
copy("crypto_database.postgres_uri")
|
||||||
|
copy("crypto_database.pickle_dir")
|
||||||
copy("plugin_directories.upload")
|
copy("plugin_directories.upload")
|
||||||
copy("plugin_directories.load")
|
copy("plugin_directories.load")
|
||||||
copy("plugin_directories.trash")
|
copy("plugin_directories.trash")
|
||||||
|
@ -23,6 +23,7 @@ import sqlalchemy as sql
|
|||||||
|
|
||||||
from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI
|
from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI
|
||||||
from mautrix.util.db import Base
|
from mautrix.util.db import Base
|
||||||
|
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
|
||||||
|
|
||||||
from .config import Config
|
from .config import Config
|
||||||
|
|
||||||
@ -79,7 +80,7 @@ def init(config: Config) -> Engine:
|
|||||||
db = sql.create_engine(config["database"])
|
db = sql.create_engine(config["database"])
|
||||||
Base.metadata.bind = db
|
Base.metadata.bind = db
|
||||||
|
|
||||||
for table in (DBPlugin, DBClient):
|
for table in (DBPlugin, DBClient, RoomState, UserProfile):
|
||||||
table.bind(db)
|
table.bind(db)
|
||||||
|
|
||||||
if not db.has_table("alembic_version"):
|
if not db.has_table("alembic_version"):
|
||||||
|
@ -13,11 +13,11 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from mautrix.client import ClientStore
|
from mautrix.client import SyncStore
|
||||||
from mautrix.types import SyncToken
|
from mautrix.types import SyncToken
|
||||||
|
|
||||||
|
|
||||||
class ClientStoreProxy(ClientStore):
|
class SyncStoreProxy(SyncStore):
|
||||||
def __init__(self, db_instance) -> None:
|
def __init__(self, db_instance) -> None:
|
||||||
self.db_instance = db_instance
|
self.db_instance = db_instance
|
||||||
|
|
||||||
|
@ -19,8 +19,8 @@ import asyncio
|
|||||||
|
|
||||||
from attr import dataclass
|
from attr import dataclass
|
||||||
from packaging.version import Version, InvalidVersion
|
from packaging.version import Version, InvalidVersion
|
||||||
from mautrix.client.api.types.util import (SerializableAttrs, SerializerError, serializer,
|
|
||||||
deserializer)
|
from mautrix.types import SerializableAttrs, SerializerError, serializer, deserializer
|
||||||
|
|
||||||
from ..__meta__ import __version__
|
from ..__meta__ import __version__
|
||||||
from ..plugin_base import Plugin
|
from ..plugin_base import Plugin
|
||||||
|
@ -22,7 +22,8 @@ import os
|
|||||||
|
|
||||||
from ruamel.yaml import YAML, YAMLError
|
from ruamel.yaml import YAML, YAMLError
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
from mautrix.client.api.types.util import SerializerError
|
|
||||||
|
from mautrix.types import SerializerError
|
||||||
|
|
||||||
from ..lib.zipimport import zipimporter, ZipImportError
|
from ..lib.zipimport import zipimporter, ZipImportError
|
||||||
from ..plugin_base import Plugin
|
from ..plugin_base import Plugin
|
||||||
|
@ -13,13 +13,14 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Union, Awaitable, Optional, Tuple
|
from typing import Union, Awaitable, Optional, Tuple, List
|
||||||
from html import escape
|
from html import escape
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from mautrix.client import Client as MatrixClient, SyncStream
|
from mautrix.client import Client as MatrixClient, SyncStream
|
||||||
from mautrix.util.formatter import parse_html
|
from mautrix.util import markdown, formatter
|
||||||
from mautrix.util import markdown
|
|
||||||
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
|
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
|
||||||
MessageType, TextMessageEventContent, Format, RelatesTo)
|
MessageType, TextMessageEventContent, Format, RelatesTo)
|
||||||
|
|
||||||
@ -32,7 +33,7 @@ def parse_formatted(message: str, allow_html: bool = False, render_markdown: boo
|
|||||||
html = message
|
html = message
|
||||||
else:
|
else:
|
||||||
return message, escape(message)
|
return message, escape(message)
|
||||||
return parse_html(html), html
|
return formatter.parse_html(html), html
|
||||||
|
|
||||||
|
|
||||||
class MaubotMessageEvent(MessageEvent):
|
class MaubotMessageEvent(MessageEvent):
|
||||||
@ -110,12 +111,12 @@ class MaubotMatrixClient(MatrixClient):
|
|||||||
content.set_edit(edits)
|
content.set_edit(edits)
|
||||||
return self.send_message(room_id, content, **kwargs)
|
return self.send_message(room_id, content, **kwargs)
|
||||||
|
|
||||||
async def dispatch_event(self, event: Event, source: SyncStream = SyncStream.INTERNAL) -> None:
|
def dispatch_event(self, event: Event, source: SyncStream) -> List[asyncio.Task]:
|
||||||
if isinstance(event, MessageEvent):
|
if isinstance(event, MessageEvent):
|
||||||
event = MaubotMessageEvent(event, self)
|
event = MaubotMessageEvent(event, self)
|
||||||
elif source != SyncStream.INTERNAL:
|
elif source != SyncStream.INTERNAL:
|
||||||
event.client = self
|
event.client = self
|
||||||
return await super().dispatch_event(event, source)
|
return super().dispatch_event(event, source)
|
||||||
|
|
||||||
async def get_event(self, room_id: RoomID, event_id: EventID) -> Event:
|
async def get_event(self, room_id: RoomID, event_id: EventID) -> Event:
|
||||||
event = await super().get_event(room_id, event_id)
|
event = await super().get_event(room_id, event_id)
|
||||||
|
@ -36,7 +36,7 @@ from .config import Config
|
|||||||
from ..plugin_base import Plugin
|
from ..plugin_base import Plugin
|
||||||
from ..loader import PluginMeta
|
from ..loader import PluginMeta
|
||||||
from ..matrix import MaubotMatrixClient
|
from ..matrix import MaubotMatrixClient
|
||||||
from ..lib.store_proxy import ClientStoreProxy
|
from ..lib.store_proxy import SyncStoreProxy
|
||||||
from ..__meta__ import __version__
|
from ..__meta__ import __version__
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -143,7 +143,7 @@ async def main():
|
|||||||
global client, bot
|
global client, bot
|
||||||
|
|
||||||
client = MaubotMatrixClient(mxid=user_id, base_url=homeserver, token=access_token,
|
client = MaubotMatrixClient(mxid=user_id, base_url=homeserver, token=access_token,
|
||||||
client_session=http_client, loop=loop, store=ClientStoreProxy(nb),
|
client_session=http_client, loop=loop, store=SyncStoreProxy(nb),
|
||||||
log=logging.getLogger("maubot.client").getChild(user_id))
|
log=logging.getLogger("maubot.client").getChild(user_id))
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
11
optional-requirements.txt
Normal file
11
optional-requirements.txt
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# Format: #/name defines a new extras_require group called name
|
||||||
|
# Uncommented lines after the group definition insert things into that group.
|
||||||
|
|
||||||
|
#/postgres
|
||||||
|
psycopg2-binary>=2,<3
|
||||||
|
|
||||||
|
#/e2be
|
||||||
|
asyncpg>=0.20,<0.21
|
||||||
|
python-olm>=3,<4
|
||||||
|
pycryptodome>=3,<4
|
||||||
|
unpaddedbase64>=1,<2
|
@ -1,4 +1,4 @@
|
|||||||
mautrix==0.6.0.beta7
|
mautrix==0.6.0rc1
|
||||||
aiohttp>=3,<4
|
aiohttp>=3,<4
|
||||||
SQLAlchemy>=1,<2
|
SQLAlchemy>=1,<2
|
||||||
alembic>=1,<2
|
alembic>=1,<2
|
||||||
|
14
setup.py
14
setup.py
@ -5,6 +5,19 @@ import os
|
|||||||
with open("requirements.txt") as reqs:
|
with open("requirements.txt") as reqs:
|
||||||
install_requires = reqs.read().splitlines()
|
install_requires = reqs.read().splitlines()
|
||||||
|
|
||||||
|
with open("optional-requirements.txt") as reqs:
|
||||||
|
extras_require = {}
|
||||||
|
current = []
|
||||||
|
for line in reqs.read().splitlines():
|
||||||
|
if line.startswith("#/"):
|
||||||
|
extras_require[line[2:]] = current = []
|
||||||
|
elif not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
current.append(line)
|
||||||
|
|
||||||
|
extras_require["all"] = list({dep for deps in extras_require.values() for dep in deps})
|
||||||
|
|
||||||
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "maubot", "__meta__.py")
|
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "maubot", "__meta__.py")
|
||||||
__version__ = "UNKNOWN"
|
__version__ = "UNKNOWN"
|
||||||
with open(path) as f:
|
with open(path) as f:
|
||||||
@ -25,6 +38,7 @@ setuptools.setup(
|
|||||||
packages=setuptools.find_packages(),
|
packages=setuptools.find_packages(),
|
||||||
|
|
||||||
install_requires=install_requires,
|
install_requires=install_requires,
|
||||||
|
extras_require=extras_require,
|
||||||
|
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
|
Loading…
Reference in New Issue
Block a user