diff --git a/maubot/cli/cliq/cliq.py b/maubot/cli/cliq/cliq.py index 5806cb6..a65e77a 100644 --- a/maubot/cli/cliq/cliq.py +++ b/maubot/cli/cliq/cliq.py @@ -73,6 +73,11 @@ def command(help: str) -> Callable[[Callable], Callable]: required_unless = questions[key].pop("required_unless") if isinstance(required_unless, str) and kwargs[required_unless]: questions.pop(key) + elif isinstance(required_unless, list): + for v in required_unless: + if kwargs[v]: + questions.pop(key) + break elif isinstance(required_unless, dict): for k, v in required_unless.items(): if kwargs.get(v, object()) == v: @@ -118,7 +123,7 @@ def option(short: str, long: str, message: str = None, help: str = None, click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None, validator: Type[Validator] = None, required: bool = False, default: Union[str, bool, None] = None, is_flag: bool = False, prompt: bool = True, - required_unless: str = None) -> Callable[[Callable], Callable]: + required_unless: Union[str, list, dict] = None) -> Callable[[Callable], Callable]: if not message: message = long[2].upper() + long[3:] diff --git a/maubot/cli/commands/auth.py b/maubot/cli/commands/auth.py index 28afafb..444e3b3 100644 --- a/maubot/cli/commands/auth.py +++ b/maubot/cli/commands/auth.py @@ -13,6 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import webbrowser import json from colorama import Fore @@ -28,7 +29,9 @@ friendly_errors = { "server_not_found": "Registration target server not found.\n\n" "To log in or register through maubot, you must add the server to the\n" "homeservers section in the config. If you only want to log in,\n" - "leave the `secret` field empty." + "leave the `secret` field empty.", + "registration_no_sso": "The register operation is only for registering with a password.\n\n" + "To register with SSO, simply leave out the --register flag.", } @@ -43,9 +46,10 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None: @cliq.command(help="Log into a Matrix account via the Maubot server") @cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list") -@cliq.option("-u", "--username", help="The username to log in with", required_unless="list") +@cliq.option("-u", "--username", help="The username to log in with", + required_unless=["list", "sso"]) @cliq.option("-p", "--password", help="The password to log in with", inq_type="password", - required_unless="list") + required_unless=["list", "sso"]) @cliq.option("-s", "--server", help="The maubot instance to log in through", default="", required=False, prompt=False) @click.option("-r", "--register", help="Register instead of logging in", is_flag=True, @@ -54,39 +58,69 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None: "create or update a client in maubot using it", is_flag=True, default=False) @click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False) +@click.option("-o", "--sso", help="Use single sign-on instead of password login", + is_flag=True, default=False) @click.option("-n", "--device-name", help="The initial e2ee device displayname (only for login)", default="Maubot", required=False) @cliq.with_authenticated_http async def auth(homeserver: str, username: str, password: str, server: str, register: bool, - list: bool, update_client: bool, device_name: str, sess: aiohttp.ClientSession - ) -> None: + list: bool, update_client: bool, device_name: str, sso: bool, + sess: aiohttp.ClientSession) -> None: if list: await list_servers(server, sess) return endpoint = "register" if register else "login" url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint if update_client: - url = url.with_query({"update_client": "true"}) - req_data = {"username": username, "password": password, "device_name": device_name} + url = url.update_query({"update_client": "true"}) + if sso: + url = url.update_query({"sso": "true"}) + req_data = {"device_name": device_name} + else: + req_data = {"username": username, "password": password, "device_name": device_name} + action = "registered" if register else "logged in as" async with sess.post(url, json=req_data) as resp: - if resp.status == 200: - data = await resp.json() - action = "registered" if register else "logged in as" - print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.") - print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}") - print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}") - elif resp.status in (201, 202): - data = await resp.json() - action = "created" if resp.status == 201 else "updated" - print(f"{Fore.GREEN}Successfully {action} client for " - f"{Fore.CYAN}{data['id']}{Fore.GREEN} / " - f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}") + if not 200 <= resp.status < 300: + await print_error(resp, action) + elif sso: + await wait_sso(resp, sess, server, homeserver) else: - try: - err_data = await resp.json() - error = friendly_errors.get(err_data["errcode"], err_data["error"]) - except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError): - error = await resp.text() - action = "register" if register else "log in" - print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}") + await print_response(resp, action) + + +async def wait_sso(resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession, + server: str, homeserver: str) -> None: + data = await resp.json() + sso_url, reg_id = data["sso_url"], data["id"] + print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}") + webbrowser.open(sso_url, autoraise=True) + print(f"{Fore.GREEN}Waiting for login token...{Fore.RESET}") + wait_url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / "sso" / reg_id / "wait" + async with sess.post(wait_url, json={}) as resp: + await print_response(resp, "logged in as") + + +async def print_response(resp: aiohttp.ClientResponse, action: str) -> None: + if resp.status == 200: + data = await resp.json() + print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.") + print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}") + print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}") + elif resp.status in (201, 202): + data = await resp.json() + action = "created" if resp.status == 201 else "updated" + print(f"{Fore.GREEN}Successfully {action} client for " + f"{Fore.CYAN}{data['id']}{Fore.GREEN} / " + f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}") + else: + await print_error(resp, action) + + +async def print_error(resp: aiohttp.ClientResponse, action: str) -> None: + try: + err_data = await resp.json() + error = friendly_errors.get(err_data["errcode"], err_data["error"]) + except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError): + error = await resp.text() + print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}") diff --git a/maubot/client.py b/maubot/client.py index b02ec24..5495c5b 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -241,7 +241,8 @@ class Client: "homeserver": self.homeserver, "access_token": self.access_token, "device_id": self.device_id, - "fingerprint": self.crypto.account.fingerprint if self.crypto else None, + "fingerprint": (self.crypto.account.fingerprint if self.crypto and self.crypto.account + else None), "enabled": self.enabled, "started": self.started, "sync": self.sync, diff --git a/maubot/management/api/client.py b/maubot/management/api/client.py index 27f2150..3d88506 100644 --- a/maubot/management/api/client.py +++ b/maubot/management/api/client.py @@ -68,8 +68,8 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response: enabled=data.get("enabled", True), next_batch=SyncToken(""), filter_id=FilterID(""), sync=data.get("sync", True), autojoin=data.get("autojoin", True), online=data.get("online", True), - displayname=data.get("displayname", ""), - avatar_url=data.get("avatar_url", ""), + displayname=data.get("displayname", "disable"), + avatar_url=data.get("avatar_url", "disable"), device_id=device_id) client = Client(db_instance) client.db_instance.insert() diff --git a/maubot/management/api/client_auth.py b/maubot/management/api/client_auth.py index 863767d..5957f43 100644 --- a/maubot/management/api/client_auth.py +++ b/maubot/management/api/client_auth.py @@ -17,12 +17,15 @@ from typing import Dict, Tuple, NamedTuple, Optional from json import JSONDecodeError from http import HTTPStatus import hashlib +import asyncio import random import string import hmac from aiohttp import web -from mautrix.api import SynapseAdminPath, Method +from yarl import URL + +from mautrix.api import SynapseAdminPath, Method, Path from mautrix.errors import MatrixRequestError from mautrix.client import ClientAPI from mautrix.types import LoginType, LoginResponse @@ -42,6 +45,7 @@ async def get_known_servers(_: web.Request) -> web.Response: class AuthRequestInfo(NamedTuple): + server_name: str client: ClientAPI secret: str username: str @@ -49,6 +53,10 @@ class AuthRequestInfo(NamedTuple): user_type: str device_name: str update_client: bool + sso: bool + + +truthy_strings = ("1", "true", "yes") async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo], @@ -61,23 +69,28 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR body = await request.json() except JSONDecodeError: return None, resp.body_not_json + sso = request.query.get("sso", "").lower() in truthy_strings try: username = body["username"] password = body["password"] except KeyError: - return None, resp.username_or_password_missing + if not sso: + return None, resp.username_or_password_missing + username = password = None try: base_url = server["url"] except KeyError: return None, resp.invalid_server return AuthRequestInfo( + server_name=server_name, client=ClientAPI(base_url=base_url, loop=get_loop()), secret=server.get("secret"), username=username, password=password, user_type=body.get("user_type", "bot"), device_name=body.get("device_name", "Maubot"), - update_client=request.query.get("update_client", "").lower() in ("1", "true", "yes"), + update_client=request.query.get("update_client", "").lower() in truthy_strings, + sso=sso, ), None @@ -102,7 +115,9 @@ async def register(request: web.Request) -> web.Response: req, err = await read_client_auth_request(request) if err is not None: return err - if not req.secret: + if req.sso: + return resp.registration_no_sso + elif not req.secret: return resp.registration_secret_not_found path = SynapseAdminPath.v1.register res = await req.client.api.request(Method.GET, path) @@ -137,12 +152,40 @@ async def login(request: web.Request) -> web.Response: req, err = await read_client_auth_request(request) if err is not None: return err - device_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + if req.sso: + return await _do_sso(req) + else: + return await _do_login(req) + + +async def _do_sso(req: AuthRequestInfo) -> web.Response: + flows = await req.client.get_login_flows() + if not flows.supports_type(LoginType.SSO): + return resp.sso_not_supported + waiter_id = ''.join(random.choices(string.ascii_lowercase + string.digits, k=16)) + cfg = get_config() + public_url = (URL(cfg["server.public_url"]) / cfg["server.base_path"].lstrip("/") + / "client/auth_external_sso/complete" / waiter_id) + sso_url = (req.client.api.base_url + .with_path(str(Path.login.sso.redirect)) + .with_query({"redirectUrl": str(public_url)})) + sso_waiters[waiter_id] = req, get_loop().create_future() + return web.json_response({"sso_url": str(sso_url), "id": waiter_id}) + + +async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response: + device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + device_id = f"maubot_{device_id}" try: - res = await req.client.login(identifier=req.username, login_type=LoginType.PASSWORD, - password=req.password, device_id=f"maubot_{device_id}", - initial_device_display_name=req.device_name, - store_access_token=False) + if req.sso: + res = await req.client.login(token=login_token, login_type=LoginType.TOKEN, + device_id=device_id, store_access_token=False, + initial_device_display_name=req.device_name) + else: + res = await req.client.login(identifier=req.username, login_type=LoginType.PASSWORD, + password=req.password, device_id=device_id, + initial_device_display_name=req.device_name, + store_access_token=False) except MatrixRequestError as e: return web.json_response({ "errcode": e.errcode, @@ -155,3 +198,38 @@ async def login(request: web.Request) -> web.Response: "device_id": res.device_id, }, is_login=True) return web.json_response(res.serialize()) + + +sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {} + + +@routes.post("/client/auth/{server}/sso/{id}/wait") +async def wait_sso(request: web.Request) -> web.Response: + waiter_id = request.match_info["id"] + req, fut = sso_waiters[waiter_id] + try: + login_token = await fut + finally: + sso_waiters.pop(waiter_id, None) + return await _do_login(req, login_token) + + +@routes.get("/client/auth_external_sso/complete/{id}") +async def complete_sso(request: web.Request) -> web.Response: + try: + _, fut = sso_waiters[request.match_info["id"]] + except KeyError: + return web.Response(status=404, text="Invalid session ID\n") + if fut.cancelled(): + return web.Response(status=200, text="The login was cancelled from the Maubot client\n") + elif fut.done(): + return web.Response(status=200, text="The login token was already received\n") + try: + fut.set_result(request.query["loginToken"]) + except KeyError: + return web.Response(status=400, text="Missing loginToken query parameter\n") + except asyncio.InvalidStateError: + return web.Response(status=500, text="Invalid state\n") + return web.Response(status=200, + text="Login token received, please return to your Maubot client. " + "This tab can be closed.\n") diff --git a/maubot/management/api/middleware.py b/maubot/management/api/middleware.py index ff6b4c1..ce9b253 100644 --- a/maubot/management/api/middleware.py +++ b/maubot/management/api/middleware.py @@ -29,7 +29,12 @@ log = logging.getLogger("maubot.server") @web.middleware async def auth(request: web.Request, handler: Handler) -> web.Response: subpath = request.path[len(get_config()["server.base_path"]):] - if subpath.startswith("/auth/") or subpath == "/features" or subpath == "/logs": + if ( + subpath.startswith("/auth/") + or subpath.startswith("/client/auth_external_sso/complete/") + or subpath == "/features" + or subpath == "/logs" + ): return await handler(request) err = check_token(request) if err is not None: diff --git a/maubot/management/api/responses.py b/maubot/management/api/responses.py index d6dee04..5fbedb8 100644 --- a/maubot/management/api/responses.py +++ b/maubot/management/api/responses.py @@ -194,6 +194,20 @@ class _Response: "errcode": "registration_secret_not_found", }, status=HTTPStatus.NOT_FOUND) + @property + def registration_no_sso(self) -> web.Response: + return web.json_response({ + "error": "The register operation is only for registering with a password", + "errcode": "registration_no_sso", + }, status=HTTPStatus.BAD_REQUEST) + + @property + def sso_not_supported(self) -> web.Response: + return web.json_response({ + "error": "That server does not seem to support single sign-on", + "errcode": "sso_not_supported", + }, status=HTTPStatus.FORBIDDEN) + @property def plugin_has_no_database(self) -> web.Response: return web.json_response({