Add option to create/update client with mbc auth

This commit is contained in:
Tulir Asokan 2021-11-19 17:10:51 +02:00
parent 8c3e3a3255
commit 85e5ea401c
8 changed files with 234 additions and 110 deletions

View File

@ -15,15 +15,41 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, Callable, Union, Optional
import functools
import inspect
import asyncio
import aiohttp
from prompt_toolkit.validation import Validator
from questionary import prompt
import click
from ..base import app
from ..config import get_token
from .validators import Required, ClickValidator
def with_http(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
async with aiohttp.ClientSession() as sess:
return await func(*args, sess=sess, **kwargs)
return wrapper
def with_authenticated_http(func):
@functools.wraps(func)
async def wrapper(*args, server: str, **kwargs):
server, token = get_token(server)
if not token:
return
async with aiohttp.ClientSession(headers={"Authorization": f"Bearer {token}"}) as sess:
return await func(*args, sess=sess, server=server, **kwargs)
return wrapper
def command(help: str) -> Callable[[Callable], Callable]:
def decorator(func) -> Callable:
questions = func.__inquirer_questions__.copy()
@ -52,7 +78,10 @@ def command(help: str) -> Callable[[Callable], Callable]:
if not resp and question_list:
return
kwargs = {**kwargs, **resp}
func(*args, **kwargs)
res = func(*args, **kwargs)
if inspect.isawaitable(res):
asyncio.run(res)
return app.command(help=help)(wrapper)

View File

@ -13,13 +13,11 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from urllib.parse import quote
from urllib.request import urlopen, Request
from urllib.error import HTTPError
import functools
import json
from colorama import Fore
from yarl import URL
import aiohttp
import click
from ..config import get_token
@ -27,8 +25,6 @@ from ..cliq import cliq
history_count: int = 10
enc = functools.partial(quote, safe="")
friendly_errors = {
"server_not_found": "Registration target server not found.\n\n"
"To log in or register through maubot, you must add the server to the\n"
@ -37,6 +33,15 @@ friendly_errors = {
}
async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
url = URL(server) / "_matrix/maubot/v1/client/auth/servers"
async with sess.get(url) as resp:
data = await resp.json()
print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}")
for server in data.keys():
print(f"* {Fore.CYAN}{server}{Fore.RESET}")
@cliq.command(help="Log into a Matrix account via the Maubot server")
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
@cliq.option("-u", "--username", help="The username to log in with", required_unless="list")
@ -46,42 +51,40 @@ friendly_errors = {
required=False, prompt=False)
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
default=False)
@click.option("-c", "--update-client", help="Instead of returning the access token, "
"create or update a client in maubot using it",
is_flag=True, default=False)
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
def auth(homeserver: str, username: str, password: str, server: str, register: bool, list: bool
) -> None:
server, token = get_token(server)
if not token:
return
headers = {"Authorization": f"Bearer {token}"}
@cliq.with_authenticated_http
async def auth(homeserver: str, username: str, password: str, server: str, register: bool,
list: bool, update_client: bool, sess: aiohttp.ClientSession) -> None:
if list:
url = f"{server}/_matrix/maubot/v1/client/auth/servers"
with urlopen(Request(url, headers=headers)) as resp_data:
resp = json.load(resp_data)
print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}")
for server in resp.keys():
print(f"* {Fore.CYAN}{server}{Fore.RESET}")
await list_servers(server, sess)
return
endpoint = "register" if register else "login"
headers["Content-Type"] = "application/json"
url = f"{server}/_matrix/maubot/v1/client/auth/{enc(homeserver)}/{endpoint}"
req = Request(url, headers=headers,
data=json.dumps({
"username": username,
"password": password,
}).encode("utf-8"))
try:
with urlopen(req) as resp_data:
resp = json.load(resp_data)
url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint
if update_client:
url = url.with_query({"update_client": "true"})
req_data = {"username": username, "password": password}
async with sess.post(url, json=req_data) as resp:
if resp.status == 200:
data = await resp.json()
action = "registered" if register else "logged in as"
print(f"{Fore.GREEN}Successfully {action} "
f"{Fore.CYAN}{resp['user_id']}{Fore.GREEN}.")
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{resp['access_token']}{Fore.RESET}")
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{resp['device_id']}{Fore.RESET}")
except HTTPError as e:
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
elif resp.status in (201, 202):
data = await resp.json()
action = "created" if resp.status == 201 else "updated"
print(f"{Fore.GREEN}Successfully {action} client for "
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
else:
try:
err_data = json.load(e)
err_data = await resp.json()
error = friendly_errors.get(err_data["errcode"], err_data["error"])
except (json.JSONDecodeError, KeyError):
error = str(e)
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
error = await resp.text()
action = "register" if register else "log in"
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -86,10 +86,8 @@ class Client:
log=self.log, loop=self.loop, device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store)
if OlmMachine and self.device_id and self.crypto_db:
self.crypto_store = self._make_crypto_store()
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
if self.enable_crypto:
self._prepare_crypto()
else:
self.crypto_store = None
self.crypto = None
@ -102,10 +100,15 @@ class Client:
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
def _make_crypto_store(self) -> 'CryptoStore':
if self.crypto_db:
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
raise ValueError("Crypto database not configured")
@property
def enable_crypto(self) -> bool:
return bool(OlmMachine and self.device_id and self.crypto_db)
def _prepare_crypto(self) -> None:
self.crypto_store = PgCryptoStore(account_id=self.id, pickle_key="mau.crypto",
db=self.crypto_db)
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
async def handler(data: Dict[str, Any]) -> None:
@ -121,6 +124,19 @@ class Client:
except Exception:
self.log.exception("Failed to start")
async def _start_crypto(self) -> None:
self.log.debug("Enabling end-to-end encryption support")
await self.crypto_store.open()
crypto_device_id = await self.crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != self.device_id:
self.log.warning("Mismatching device ID in crypto store and main database, "
"resetting encryption")
await self.crypto_store.delete()
crypto_device_id = None
await self.crypto.load()
if not crypto_device_id:
await self.crypto_store.put_device_id(self.device_id)
async def _start(self, try_n: Optional[int] = 0) -> None:
if not self.enabled:
self.log.debug("Not starting disabled client")
@ -129,7 +145,7 @@ class Client:
self.log.warning("Ignoring start() call to started client")
return
try:
user_id = await self.client.whoami()
whoami = await self.client.whoami()
except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client")
self.db_instance.enabled = False
@ -143,8 +159,13 @@ class Client:
f"retrying in {(try_n + 1) * 10}s: {e}")
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
return
if user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
if whoami.user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
self.db_instance.enabled = False
return
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
self.log.error(f"Device ID mismatch: expected {self.device_id}, "
f"but got {whoami.device_id}")
self.db_instance.enabled = False
return
if not self.filter_id:
@ -167,15 +188,7 @@ class Client:
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
if self.crypto:
self.log.debug("Enabling end-to-end encryption support")
await self.crypto_store.open()
crypto_device_id = await self.crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != self.device_id:
self.log.warning("Mismatching device ID in crypto store and main database. "
"Encryption may not work.")
await self.crypto.load()
if not crypto_device_id:
await self.crypto_store.put_device_id(self.device_id)
await self._start_crypto()
self.start_sync()
await self._update_remote_profile()
self.started = True
@ -285,23 +298,31 @@ class Client:
else:
await self._update_remote_profile()
async def update_access_details(self, access_token: str, homeserver: str) -> None:
async def update_access_details(self, access_token: str, homeserver: str,
device_id: Optional[str] = None) -> None:
if not access_token and not homeserver:
return
elif access_token == self.access_token and homeserver == self.homeserver:
return
device_id = device_id or self.device_id
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
token=access_token or self.access_token, loop=self.loop,
client_session=self.http_client, device_id=self.device_id,
device_id=device_id, client_session=self.http_client,
log=self.log, state_store=self.global_state_store)
mxid = await new_client.whoami()
if mxid != self.id:
raise ValueError(f"MXID mismatch: {mxid}")
whoami = await new_client.whoami()
if whoami.user_id != self.id:
raise ValueError(f"MXID mismatch: {whoami.user_id}")
elif whoami.device_id and device_id and whoami.device_id != device_id:
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
new_client.sync_store = SyncStoreProxy(self.db_instance)
self.stop_sync()
self.client = new_client
self.db_instance.homeserver = homeserver
self.db_instance.access_token = access_token
self.db_instance.device_id = device_id
if self.enable_crypto:
self._prepare_crypto()
await self._start_crypto()
self.start_sync()
async def _update_remote_profile(self) -> None:

View File

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -45,10 +45,11 @@ async def get_client(request: web.Request) -> web.Response:
async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
homeserver = data.get("homeserver", None)
access_token = data.get("access_token", None)
device_id = data.get("device_id", None)
new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token,
loop=Client.loop, client_session=Client.http_client)
try:
mxid = await new_client.whoami()
whoami = await new_client.whoami()
except MatrixInvalidToken:
return resp.bad_client_access_token
except MatrixRequestError:
@ -56,27 +57,31 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
except MatrixConnectionError:
return resp.bad_client_connection_details
if user_id is None:
existing_client = Client.get(mxid, None)
existing_client = Client.get(whoami.user_id, None)
if existing_client is not None:
return resp.user_exists
elif mxid != user_id:
return resp.mxid_mismatch(mxid)
db_instance = DBClient(id=mxid, homeserver=homeserver, access_token=access_token,
elif whoami.user_id != user_id:
return resp.mxid_mismatch(whoami.user_id)
elif whoami.device_id and device_id and whoami.device_id != device_id:
return resp.device_id_mismatch(whoami.device_id)
db_instance = DBClient(id=whoami.user_id, homeserver=homeserver, access_token=access_token,
enabled=data.get("enabled", True), next_batch=SyncToken(""),
filter_id=FilterID(""), sync=data.get("sync", True),
autojoin=data.get("autojoin", True), online=data.get("online", True),
displayname=data.get("displayname", ""),
avatar_url=data.get("avatar_url", ""))
avatar_url=data.get("avatar_url", ""),
device_id=device_id)
client = Client(db_instance)
client.db_instance.insert()
await client.start()
return resp.created(client.to_dict())
async def _update_client(client: Client, data: dict) -> web.Response:
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
try:
await client.update_access_details(data.get("access_token", None),
data.get("homeserver", None))
data.get("homeserver", None),
data.get("device_id", None))
except MatrixInvalidToken:
return resp.bad_client_access_token
except MatrixRequestError:
@ -93,7 +98,16 @@ async def _update_client(client: Client, data: dict) -> web.Response:
client.autojoin = data.get("autojoin", client.autojoin)
client.online = data.get("online", client.online)
client.sync = data.get("sync", client.sync)
return resp.updated(client.to_dict())
return resp.updated(client.to_dict(), is_login=is_login)
async def _create_or_update_client(user_id: UserID, data: dict, is_login: bool = False
) -> web.Response:
client = Client.get(user_id, None)
if not client:
return await _create_client(user_id, data)
else:
return await _update_client(client, data, is_login=is_login)
@routes.post("/client/new")
@ -108,15 +122,11 @@ async def create_client(request: web.Request) -> web.Response:
@routes.put("/client/{id}")
async def update_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
try:
data = await request.json()
except JSONDecodeError:
return resp.body_not_json
if not client:
return await _create_client(user_id, data)
else:
return await _update_client(client, data)
return await _create_or_update_client(user_id, data)
@routes.delete("/client/{id}")

View File

@ -25,10 +25,11 @@ from aiohttp import web
from mautrix.api import SynapseAdminPath, Method
from mautrix.errors import MatrixRequestError
from mautrix.client import ClientAPI
from mautrix.types import LoginType
from mautrix.types import LoginType, LoginResponse
from .base import routes, get_config, get_loop
from .responses import resp
from .client import _create_or_update_client, _create_client
def known_homeservers() -> Dict[str, Dict[str, str]]:
@ -46,6 +47,7 @@ class AuthRequestInfo(NamedTuple):
username: str
password: str
user_type: str
update_client: bool
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
@ -70,15 +72,16 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR
secret = server.get("secret")
api = ClientAPI(base_url=base_url, loop=get_loop())
user_type = body.get("user_type", "bot")
return AuthRequestInfo(api, secret, username, password, user_type), None
update_client = request.query.get("update_client", "").lower() in ("1", "true", "yes")
return AuthRequestInfo(api, secret, username, password, user_type, update_client), None
def generate_mac(secret: str, nonce: str, user: str, password: str, admin: bool = False,
def generate_mac(secret: str, nonce: str, username: str, password: str, admin: bool = False,
user_type: str = None) -> str:
mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
mac.update(nonce.encode("utf-8"))
mac.update(b"\x00")
mac.update(user.encode("utf-8"))
mac.update(username.encode("utf-8"))
mac.update(b"\x00")
mac.update(password.encode("utf-8"))
mac.update(b"\x00")
@ -94,28 +97,34 @@ async def register(request: web.Request) -> web.Response:
info, err = await read_client_auth_request(request)
if err is not None:
return err
client: ClientAPI
client, secret, username, password, user_type = info
if not secret:
if not info.secret:
return resp.registration_secret_not_found
path = SynapseAdminPath.v1.register
res = await client.api.request(Method.GET, path)
res = await info.client.api.request(Method.GET, path)
content = {
"nonce": res["nonce"],
"username": username,
"password": password,
"username": info.username,
"password": info.password,
"admin": False,
"mac": generate_mac(secret, res["nonce"], username, password, user_type=user_type),
"user_type": user_type,
"user_type": info.user_type,
}
content["mac"] = generate_mac(**content, secret=info.secret)
try:
return web.json_response(await client.api.request(Method.POST, path, content=content))
raw_res = await info.client.api.request(Method.POST, path, content=content)
except MatrixRequestError as e:
return web.json_response({
"errcode": e.errcode,
"error": e.message,
"http_status": e.http_status,
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
login_res = LoginResponse.deserialize(raw_res)
if info.update_client:
return await _create_client(login_res.user_id, {
"homeserver": str(info.client.api.base_url),
"access_token": login_res.access_token,
"device_id": login_res.device_id,
})
return web.json_response(login_res.serialize())
@routes.post("/client/auth/{server}/login")
@ -129,9 +138,15 @@ async def login(request: web.Request) -> web.Response:
res = await client.login(identifier=info.username, login_type=LoginType.PASSWORD,
password=info.password, device_id=f"maubot_{device_id}",
initial_device_display_name="Maubot", store_access_token=False)
return web.json_response(res.serialize())
except MatrixRequestError as e:
return web.json_response({
"errcode": e.errcode,
"error": e.message,
}, status=e.http_status)
if info.update_client:
return await _create_or_update_client(res.user_id, {
"homeserver": str(client.api.base_url),
"access_token": res.access_token,
"device_id": res.device_id,
}, is_login=True)
return web.json_response(res.serialize())

View File

@ -69,6 +69,13 @@ class _Response:
"errcode": "mxid_mismatch",
}, status=HTTPStatus.BAD_REQUEST)
def device_id_mismatch(self, found: str) -> web.Response:
return web.json_response({
"error": "The Matrix device ID of the client and the device ID of the access token "
f"don't match. Access token is for device {found}",
"errcode": "mxid_mismatch",
}, status=HTTPStatus.BAD_REQUEST)
@property
def pid_mismatch(self) -> web.Response:
return web.json_response({
@ -294,8 +301,9 @@ class _Response:
def found(data: dict) -> web.Response:
return web.json_response(data, status=HTTPStatus.OK)
def updated(self, data: dict) -> web.Response:
return self.found(data)
@staticmethod
def updated(data: dict, is_login: bool = False) -> web.Response:
return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK)
def logged_in(self, token: str) -> web.Response:
return self.found({

View File

@ -366,7 +366,7 @@ paths:
schema:
$ref: '#/components/schemas/MatrixClient'
responses:
200:
202:
description: Client updated
content:
application/json:
@ -454,6 +454,12 @@ paths:
required: true
schema:
type: string
- name: update_client
in: query
description: Should maubot store the access details in a Client instead of returning them?
required: false
schema:
type: boolean
post:
operationId: client_auth_register
summary: |
@ -475,18 +481,29 @@ paths:
properties:
access_token:
type: string
example: token_here
example: syt_123_456_789
user_id:
type: string
example: '@putkiteippi:maunium.net'
home_server:
type: string
example: maunium.net
device_id:
type: string
example: device_id_here
example: maubot_F00BAR12
201:
description: Client created (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
401:
$ref: '#/components/responses/Unauthorized'
409:
description: |
There is already a client with the user ID of that token.
This should usually not happen, because the user ID was just created.
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
500:
$ref: '#/components/responses/MatrixServerError'
'/client/auth/{server}/login':
@ -497,6 +514,12 @@ paths:
required: true
schema:
type: string
- name: update_client
in: query
description: Should maubot store the access details in a Client instead of returning them?
required: false
schema:
type: boolean
post:
operationId: client_auth_login
summary: Log in to the given Matrix server via the maubot server
@ -519,10 +542,22 @@ paths:
example: '@putkiteippi:maunium.net'
access_token:
type: string
example: token_here
example: syt_123_456_789
device_id:
type: string
example: device_id_here
example: maubot_F00BAR12
201:
description: Client created (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
202:
description: Client updated (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
401:
$ref: '#/components/responses/Unauthorized'
500:
@ -641,6 +676,9 @@ components:
access_token:
type: string
description: The Matrix access token for this client.
device_id:
type: string
description: The Matrix device ID corresponding to the access token.
enabled:
type: boolean
example: true

View File

@ -144,13 +144,13 @@ async def main():
while True:
try:
whoami_user_id = await client.whoami()
whoami = await client.whoami()
except Exception:
log.exception("Failed to connect to homeserver, retrying in 10 seconds...")
await asyncio.sleep(10)
continue
if whoami_user_id != user_id:
log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami_user_id}")
if whoami.user_id != user_id:
log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami.user_id}")
sys.exit(1)
break