Add SSO support to mbc auth

This commit is contained in:
Tulir Asokan 2021-11-20 16:23:06 +02:00
parent f2bae18c7a
commit ca7a980081
7 changed files with 177 additions and 40 deletions

View File

@ -73,6 +73,11 @@ def command(help: str) -> Callable[[Callable], Callable]:
required_unless = questions[key].pop("required_unless") required_unless = questions[key].pop("required_unless")
if isinstance(required_unless, str) and kwargs[required_unless]: if isinstance(required_unless, str) and kwargs[required_unless]:
questions.pop(key) questions.pop(key)
elif isinstance(required_unless, list):
for v in required_unless:
if kwargs[v]:
questions.pop(key)
break
elif isinstance(required_unless, dict): elif isinstance(required_unless, dict):
for k, v in required_unless.items(): for k, v in required_unless.items():
if kwargs.get(v, object()) == v: if kwargs.get(v, object()) == v:
@ -118,7 +123,7 @@ def option(short: str, long: str, message: str = None, help: str = None,
click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None, click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None,
validator: Type[Validator] = None, required: bool = False, validator: Type[Validator] = None, required: bool = False,
default: Union[str, bool, None] = None, is_flag: bool = False, prompt: bool = True, default: Union[str, bool, None] = None, is_flag: bool = False, prompt: bool = True,
required_unless: str = None) -> Callable[[Callable], Callable]: required_unless: Union[str, list, dict] = None) -> Callable[[Callable], Callable]:
if not message: if not message:
message = long[2].upper() + long[3:] message = long[2].upper() + long[3:]

View File

@ -13,6 +13,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
import webbrowser
import json import json
from colorama import Fore from colorama import Fore
@ -28,7 +29,9 @@ friendly_errors = {
"server_not_found": "Registration target server not found.\n\n" "server_not_found": "Registration target server not found.\n\n"
"To log in or register through maubot, you must add the server to the\n" "To log in or register through maubot, you must add the server to the\n"
"homeservers section in the config. If you only want to log in,\n" "homeservers section in the config. If you only want to log in,\n"
"leave the `secret` field empty." "leave the `secret` field empty.",
"registration_no_sso": "The register operation is only for registering with a password.\n\n"
"To register with SSO, simply leave out the --register flag.",
} }
@ -43,9 +46,10 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
@cliq.command(help="Log into a Matrix account via the Maubot server") @cliq.command(help="Log into a Matrix account via the Maubot server")
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list") @cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
@cliq.option("-u", "--username", help="The username to log in with", required_unless="list") @cliq.option("-u", "--username", help="The username to log in with",
required_unless=["list", "sso"])
@cliq.option("-p", "--password", help="The password to log in with", inq_type="password", @cliq.option("-p", "--password", help="The password to log in with", inq_type="password",
required_unless="list") required_unless=["list", "sso"])
@cliq.option("-s", "--server", help="The maubot instance to log in through", default="", @cliq.option("-s", "--server", help="The maubot instance to log in through", default="",
required=False, prompt=False) required=False, prompt=False)
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True, @click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
@ -54,39 +58,69 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
"create or update a client in maubot using it", "create or update a client in maubot using it",
is_flag=True, default=False) is_flag=True, default=False)
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False) @click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
@click.option("-o", "--sso", help="Use single sign-on instead of password login",
is_flag=True, default=False)
@click.option("-n", "--device-name", help="The initial e2ee device displayname (only for login)", @click.option("-n", "--device-name", help="The initial e2ee device displayname (only for login)",
default="Maubot", required=False) default="Maubot", required=False)
@cliq.with_authenticated_http @cliq.with_authenticated_http
async def auth(homeserver: str, username: str, password: str, server: str, register: bool, async def auth(homeserver: str, username: str, password: str, server: str, register: bool,
list: bool, update_client: bool, device_name: str, sess: aiohttp.ClientSession list: bool, update_client: bool, device_name: str, sso: bool,
) -> None: sess: aiohttp.ClientSession) -> None:
if list: if list:
await list_servers(server, sess) await list_servers(server, sess)
return return
endpoint = "register" if register else "login" endpoint = "register" if register else "login"
url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint
if update_client: if update_client:
url = url.with_query({"update_client": "true"}) url = url.update_query({"update_client": "true"})
req_data = {"username": username, "password": password, "device_name": device_name} if sso:
url = url.update_query({"sso": "true"})
req_data = {"device_name": device_name}
else:
req_data = {"username": username, "password": password, "device_name": device_name}
action = "registered" if register else "logged in as"
async with sess.post(url, json=req_data) as resp: async with sess.post(url, json=req_data) as resp:
if resp.status == 200: if not 200 <= resp.status < 300:
data = await resp.json() await print_error(resp, action)
action = "registered" if register else "logged in as" elif sso:
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.") await wait_sso(resp, sess, server, homeserver)
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
elif resp.status in (201, 202):
data = await resp.json()
action = "created" if resp.status == 201 else "updated"
print(f"{Fore.GREEN}Successfully {action} client for "
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
else: else:
try: await print_response(resp, action)
err_data = await resp.json()
error = friendly_errors.get(err_data["errcode"], err_data["error"])
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError): async def wait_sso(resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession,
error = await resp.text() server: str, homeserver: str) -> None:
action = "register" if register else "log in" data = await resp.json()
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}") sso_url, reg_id = data["sso_url"], data["id"]
print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}")
webbrowser.open(sso_url, autoraise=True)
print(f"{Fore.GREEN}Waiting for login token...{Fore.RESET}")
wait_url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / "sso" / reg_id / "wait"
async with sess.post(wait_url, json={}) as resp:
await print_response(resp, "logged in as")
async def print_response(resp: aiohttp.ClientResponse, action: str) -> None:
if resp.status == 200:
data = await resp.json()
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
elif resp.status in (201, 202):
data = await resp.json()
action = "created" if resp.status == 201 else "updated"
print(f"{Fore.GREEN}Successfully {action} client for "
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
else:
await print_error(resp, action)
async def print_error(resp: aiohttp.ClientResponse, action: str) -> None:
try:
err_data = await resp.json()
error = friendly_errors.get(err_data["errcode"], err_data["error"])
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
error = await resp.text()
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")

View File

@ -241,7 +241,8 @@ class Client:
"homeserver": self.homeserver, "homeserver": self.homeserver,
"access_token": self.access_token, "access_token": self.access_token,
"device_id": self.device_id, "device_id": self.device_id,
"fingerprint": self.crypto.account.fingerprint if self.crypto else None, "fingerprint": (self.crypto.account.fingerprint if self.crypto and self.crypto.account
else None),
"enabled": self.enabled, "enabled": self.enabled,
"started": self.started, "started": self.started,
"sync": self.sync, "sync": self.sync,

View File

@ -68,8 +68,8 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
enabled=data.get("enabled", True), next_batch=SyncToken(""), enabled=data.get("enabled", True), next_batch=SyncToken(""),
filter_id=FilterID(""), sync=data.get("sync", True), filter_id=FilterID(""), sync=data.get("sync", True),
autojoin=data.get("autojoin", True), online=data.get("online", True), autojoin=data.get("autojoin", True), online=data.get("online", True),
displayname=data.get("displayname", ""), displayname=data.get("displayname", "disable"),
avatar_url=data.get("avatar_url", ""), avatar_url=data.get("avatar_url", "disable"),
device_id=device_id) device_id=device_id)
client = Client(db_instance) client = Client(db_instance)
client.db_instance.insert() client.db_instance.insert()

View File

@ -17,12 +17,15 @@ from typing import Dict, Tuple, NamedTuple, Optional
from json import JSONDecodeError from json import JSONDecodeError
from http import HTTPStatus from http import HTTPStatus
import hashlib import hashlib
import asyncio
import random import random
import string import string
import hmac import hmac
from aiohttp import web from aiohttp import web
from mautrix.api import SynapseAdminPath, Method from yarl import URL
from mautrix.api import SynapseAdminPath, Method, Path
from mautrix.errors import MatrixRequestError from mautrix.errors import MatrixRequestError
from mautrix.client import ClientAPI from mautrix.client import ClientAPI
from mautrix.types import LoginType, LoginResponse from mautrix.types import LoginType, LoginResponse
@ -42,6 +45,7 @@ async def get_known_servers(_: web.Request) -> web.Response:
class AuthRequestInfo(NamedTuple): class AuthRequestInfo(NamedTuple):
server_name: str
client: ClientAPI client: ClientAPI
secret: str secret: str
username: str username: str
@ -49,6 +53,10 @@ class AuthRequestInfo(NamedTuple):
user_type: str user_type: str
device_name: str device_name: str
update_client: bool update_client: bool
sso: bool
truthy_strings = ("1", "true", "yes")
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo], async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
@ -61,23 +69,28 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR
body = await request.json() body = await request.json()
except JSONDecodeError: except JSONDecodeError:
return None, resp.body_not_json return None, resp.body_not_json
sso = request.query.get("sso", "").lower() in truthy_strings
try: try:
username = body["username"] username = body["username"]
password = body["password"] password = body["password"]
except KeyError: except KeyError:
return None, resp.username_or_password_missing if not sso:
return None, resp.username_or_password_missing
username = password = None
try: try:
base_url = server["url"] base_url = server["url"]
except KeyError: except KeyError:
return None, resp.invalid_server return None, resp.invalid_server
return AuthRequestInfo( return AuthRequestInfo(
server_name=server_name,
client=ClientAPI(base_url=base_url, loop=get_loop()), client=ClientAPI(base_url=base_url, loop=get_loop()),
secret=server.get("secret"), secret=server.get("secret"),
username=username, username=username,
password=password, password=password,
user_type=body.get("user_type", "bot"), user_type=body.get("user_type", "bot"),
device_name=body.get("device_name", "Maubot"), device_name=body.get("device_name", "Maubot"),
update_client=request.query.get("update_client", "").lower() in ("1", "true", "yes"), update_client=request.query.get("update_client", "").lower() in truthy_strings,
sso=sso,
), None ), None
@ -102,7 +115,9 @@ async def register(request: web.Request) -> web.Response:
req, err = await read_client_auth_request(request) req, err = await read_client_auth_request(request)
if err is not None: if err is not None:
return err return err
if not req.secret: if req.sso:
return resp.registration_no_sso
elif not req.secret:
return resp.registration_secret_not_found return resp.registration_secret_not_found
path = SynapseAdminPath.v1.register path = SynapseAdminPath.v1.register
res = await req.client.api.request(Method.GET, path) res = await req.client.api.request(Method.GET, path)
@ -137,12 +152,40 @@ async def login(request: web.Request) -> web.Response:
req, err = await read_client_auth_request(request) req, err = await read_client_auth_request(request)
if err is not None: if err is not None:
return err return err
device_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) if req.sso:
return await _do_sso(req)
else:
return await _do_login(req)
async def _do_sso(req: AuthRequestInfo) -> web.Response:
flows = await req.client.get_login_flows()
if not flows.supports_type(LoginType.SSO):
return resp.sso_not_supported
waiter_id = ''.join(random.choices(string.ascii_lowercase + string.digits, k=16))
cfg = get_config()
public_url = (URL(cfg["server.public_url"]) / cfg["server.base_path"].lstrip("/")
/ "client/auth_external_sso/complete" / waiter_id)
sso_url = (req.client.api.base_url
.with_path(str(Path.login.sso.redirect))
.with_query({"redirectUrl": str(public_url)}))
sso_waiters[waiter_id] = req, get_loop().create_future()
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response:
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
device_id = f"maubot_{device_id}"
try: try:
res = await req.client.login(identifier=req.username, login_type=LoginType.PASSWORD, if req.sso:
password=req.password, device_id=f"maubot_{device_id}", res = await req.client.login(token=login_token, login_type=LoginType.TOKEN,
initial_device_display_name=req.device_name, device_id=device_id, store_access_token=False,
store_access_token=False) initial_device_display_name=req.device_name)
else:
res = await req.client.login(identifier=req.username, login_type=LoginType.PASSWORD,
password=req.password, device_id=device_id,
initial_device_display_name=req.device_name,
store_access_token=False)
except MatrixRequestError as e: except MatrixRequestError as e:
return web.json_response({ return web.json_response({
"errcode": e.errcode, "errcode": e.errcode,
@ -155,3 +198,38 @@ async def login(request: web.Request) -> web.Response:
"device_id": res.device_id, "device_id": res.device_id,
}, is_login=True) }, is_login=True)
return web.json_response(res.serialize()) return web.json_response(res.serialize())
sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {}
@routes.post("/client/auth/{server}/sso/{id}/wait")
async def wait_sso(request: web.Request) -> web.Response:
waiter_id = request.match_info["id"]
req, fut = sso_waiters[waiter_id]
try:
login_token = await fut
finally:
sso_waiters.pop(waiter_id, None)
return await _do_login(req, login_token)
@routes.get("/client/auth_external_sso/complete/{id}")
async def complete_sso(request: web.Request) -> web.Response:
try:
_, fut = sso_waiters[request.match_info["id"]]
except KeyError:
return web.Response(status=404, text="Invalid session ID\n")
if fut.cancelled():
return web.Response(status=200, text="The login was cancelled from the Maubot client\n")
elif fut.done():
return web.Response(status=200, text="The login token was already received\n")
try:
fut.set_result(request.query["loginToken"])
except KeyError:
return web.Response(status=400, text="Missing loginToken query parameter\n")
except asyncio.InvalidStateError:
return web.Response(status=500, text="Invalid state\n")
return web.Response(status=200,
text="Login token received, please return to your Maubot client. "
"This tab can be closed.\n")

View File

@ -29,7 +29,12 @@ log = logging.getLogger("maubot.server")
@web.middleware @web.middleware
async def auth(request: web.Request, handler: Handler) -> web.Response: async def auth(request: web.Request, handler: Handler) -> web.Response:
subpath = request.path[len(get_config()["server.base_path"]):] subpath = request.path[len(get_config()["server.base_path"]):]
if subpath.startswith("/auth/") or subpath == "/features" or subpath == "/logs": if (
subpath.startswith("/auth/")
or subpath.startswith("/client/auth_external_sso/complete/")
or subpath == "/features"
or subpath == "/logs"
):
return await handler(request) return await handler(request)
err = check_token(request) err = check_token(request)
if err is not None: if err is not None:

View File

@ -194,6 +194,20 @@ class _Response:
"errcode": "registration_secret_not_found", "errcode": "registration_secret_not_found",
}, status=HTTPStatus.NOT_FOUND) }, status=HTTPStatus.NOT_FOUND)
@property
def registration_no_sso(self) -> web.Response:
return web.json_response({
"error": "The register operation is only for registering with a password",
"errcode": "registration_no_sso",
}, status=HTTPStatus.BAD_REQUEST)
@property
def sso_not_supported(self) -> web.Response:
return web.json_response({
"error": "That server does not seem to support single sign-on",
"errcode": "sso_not_supported",
}, status=HTTPStatus.FORBIDDEN)
@property @property
def plugin_has_no_database(self) -> web.Response: def plugin_has_no_database(self) -> web.Response:
return web.json_response({ return web.json_response({