Add SSO support to mbc auth

This commit is contained in:
Tulir Asokan 2021-11-20 16:23:06 +02:00
parent f2bae18c7a
commit ca7a980081
7 changed files with 177 additions and 40 deletions

View File

@ -73,6 +73,11 @@ def command(help: str) -> Callable[[Callable], Callable]:
required_unless = questions[key].pop("required_unless")
if isinstance(required_unless, str) and kwargs[required_unless]:
questions.pop(key)
elif isinstance(required_unless, list):
for v in required_unless:
if kwargs[v]:
questions.pop(key)
break
elif isinstance(required_unless, dict):
for k, v in required_unless.items():
if kwargs.get(v, object()) == v:
@ -118,7 +123,7 @@ def option(short: str, long: str, message: str = None, help: str = None,
click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None,
validator: Type[Validator] = None, required: bool = False,
default: Union[str, bool, None] = None, is_flag: bool = False, prompt: bool = True,
required_unless: str = None) -> Callable[[Callable], Callable]:
required_unless: Union[str, list, dict] = None) -> Callable[[Callable], Callable]:
if not message:
message = long[2].upper() + long[3:]

View File

@ -13,6 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import webbrowser
import json
from colorama import Fore
@ -28,7 +29,9 @@ friendly_errors = {
"server_not_found": "Registration target server not found.\n\n"
"To log in or register through maubot, you must add the server to the\n"
"homeservers section in the config. If you only want to log in,\n"
"leave the `secret` field empty."
"leave the `secret` field empty.",
"registration_no_sso": "The register operation is only for registering with a password.\n\n"
"To register with SSO, simply leave out the --register flag.",
}
@ -43,9 +46,10 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
@cliq.command(help="Log into a Matrix account via the Maubot server")
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
@cliq.option("-u", "--username", help="The username to log in with", required_unless="list")
@cliq.option("-u", "--username", help="The username to log in with",
required_unless=["list", "sso"])
@cliq.option("-p", "--password", help="The password to log in with", inq_type="password",
required_unless="list")
required_unless=["list", "sso"])
@cliq.option("-s", "--server", help="The maubot instance to log in through", default="",
required=False, prompt=False)
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
@ -54,39 +58,69 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
"create or update a client in maubot using it",
is_flag=True, default=False)
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
@click.option("-o", "--sso", help="Use single sign-on instead of password login",
is_flag=True, default=False)
@click.option("-n", "--device-name", help="The initial e2ee device displayname (only for login)",
default="Maubot", required=False)
@cliq.with_authenticated_http
async def auth(homeserver: str, username: str, password: str, server: str, register: bool,
list: bool, update_client: bool, device_name: str, sess: aiohttp.ClientSession
) -> None:
list: bool, update_client: bool, device_name: str, sso: bool,
sess: aiohttp.ClientSession) -> None:
if list:
await list_servers(server, sess)
return
endpoint = "register" if register else "login"
url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint
if update_client:
url = url.with_query({"update_client": "true"})
req_data = {"username": username, "password": password, "device_name": device_name}
url = url.update_query({"update_client": "true"})
if sso:
url = url.update_query({"sso": "true"})
req_data = {"device_name": device_name}
else:
req_data = {"username": username, "password": password, "device_name": device_name}
action = "registered" if register else "logged in as"
async with sess.post(url, json=req_data) as resp:
if resp.status == 200:
data = await resp.json()
action = "registered" if register else "logged in as"
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
elif resp.status in (201, 202):
data = await resp.json()
action = "created" if resp.status == 201 else "updated"
print(f"{Fore.GREEN}Successfully {action} client for "
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
if not 200 <= resp.status < 300:
await print_error(resp, action)
elif sso:
await wait_sso(resp, sess, server, homeserver)
else:
try:
err_data = await resp.json()
error = friendly_errors.get(err_data["errcode"], err_data["error"])
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
error = await resp.text()
action = "register" if register else "log in"
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")
await print_response(resp, action)
async def wait_sso(resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession,
server: str, homeserver: str) -> None:
data = await resp.json()
sso_url, reg_id = data["sso_url"], data["id"]
print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}")
webbrowser.open(sso_url, autoraise=True)
print(f"{Fore.GREEN}Waiting for login token...{Fore.RESET}")
wait_url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / "sso" / reg_id / "wait"
async with sess.post(wait_url, json={}) as resp:
await print_response(resp, "logged in as")
async def print_response(resp: aiohttp.ClientResponse, action: str) -> None:
if resp.status == 200:
data = await resp.json()
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
elif resp.status in (201, 202):
data = await resp.json()
action = "created" if resp.status == 201 else "updated"
print(f"{Fore.GREEN}Successfully {action} client for "
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
else:
await print_error(resp, action)
async def print_error(resp: aiohttp.ClientResponse, action: str) -> None:
try:
err_data = await resp.json()
error = friendly_errors.get(err_data["errcode"], err_data["error"])
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
error = await resp.text()
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")

View File

@ -241,7 +241,8 @@ class Client:
"homeserver": self.homeserver,
"access_token": self.access_token,
"device_id": self.device_id,
"fingerprint": self.crypto.account.fingerprint if self.crypto else None,
"fingerprint": (self.crypto.account.fingerprint if self.crypto and self.crypto.account
else None),
"enabled": self.enabled,
"started": self.started,
"sync": self.sync,

View File

@ -68,8 +68,8 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
enabled=data.get("enabled", True), next_batch=SyncToken(""),
filter_id=FilterID(""), sync=data.get("sync", True),
autojoin=data.get("autojoin", True), online=data.get("online", True),
displayname=data.get("displayname", ""),
avatar_url=data.get("avatar_url", ""),
displayname=data.get("displayname", "disable"),
avatar_url=data.get("avatar_url", "disable"),
device_id=device_id)
client = Client(db_instance)
client.db_instance.insert()

View File

@ -17,12 +17,15 @@ from typing import Dict, Tuple, NamedTuple, Optional
from json import JSONDecodeError
from http import HTTPStatus
import hashlib
import asyncio
import random
import string
import hmac
from aiohttp import web
from mautrix.api import SynapseAdminPath, Method
from yarl import URL
from mautrix.api import SynapseAdminPath, Method, Path
from mautrix.errors import MatrixRequestError
from mautrix.client import ClientAPI
from mautrix.types import LoginType, LoginResponse
@ -42,6 +45,7 @@ async def get_known_servers(_: web.Request) -> web.Response:
class AuthRequestInfo(NamedTuple):
server_name: str
client: ClientAPI
secret: str
username: str
@ -49,6 +53,10 @@ class AuthRequestInfo(NamedTuple):
user_type: str
device_name: str
update_client: bool
sso: bool
truthy_strings = ("1", "true", "yes")
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
@ -61,23 +69,28 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR
body = await request.json()
except JSONDecodeError:
return None, resp.body_not_json
sso = request.query.get("sso", "").lower() in truthy_strings
try:
username = body["username"]
password = body["password"]
except KeyError:
return None, resp.username_or_password_missing
if not sso:
return None, resp.username_or_password_missing
username = password = None
try:
base_url = server["url"]
except KeyError:
return None, resp.invalid_server
return AuthRequestInfo(
server_name=server_name,
client=ClientAPI(base_url=base_url, loop=get_loop()),
secret=server.get("secret"),
username=username,
password=password,
user_type=body.get("user_type", "bot"),
device_name=body.get("device_name", "Maubot"),
update_client=request.query.get("update_client", "").lower() in ("1", "true", "yes"),
update_client=request.query.get("update_client", "").lower() in truthy_strings,
sso=sso,
), None
@ -102,7 +115,9 @@ async def register(request: web.Request) -> web.Response:
req, err = await read_client_auth_request(request)
if err is not None:
return err
if not req.secret:
if req.sso:
return resp.registration_no_sso
elif not req.secret:
return resp.registration_secret_not_found
path = SynapseAdminPath.v1.register
res = await req.client.api.request(Method.GET, path)
@ -137,12 +152,40 @@ async def login(request: web.Request) -> web.Response:
req, err = await read_client_auth_request(request)
if err is not None:
return err
device_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8))
if req.sso:
return await _do_sso(req)
else:
return await _do_login(req)
async def _do_sso(req: AuthRequestInfo) -> web.Response:
flows = await req.client.get_login_flows()
if not flows.supports_type(LoginType.SSO):
return resp.sso_not_supported
waiter_id = ''.join(random.choices(string.ascii_lowercase + string.digits, k=16))
cfg = get_config()
public_url = (URL(cfg["server.public_url"]) / cfg["server.base_path"].lstrip("/")
/ "client/auth_external_sso/complete" / waiter_id)
sso_url = (req.client.api.base_url
.with_path(str(Path.login.sso.redirect))
.with_query({"redirectUrl": str(public_url)}))
sso_waiters[waiter_id] = req, get_loop().create_future()
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response:
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
device_id = f"maubot_{device_id}"
try:
res = await req.client.login(identifier=req.username, login_type=LoginType.PASSWORD,
password=req.password, device_id=f"maubot_{device_id}",
initial_device_display_name=req.device_name,
store_access_token=False)
if req.sso:
res = await req.client.login(token=login_token, login_type=LoginType.TOKEN,
device_id=device_id, store_access_token=False,
initial_device_display_name=req.device_name)
else:
res = await req.client.login(identifier=req.username, login_type=LoginType.PASSWORD,
password=req.password, device_id=device_id,
initial_device_display_name=req.device_name,
store_access_token=False)
except MatrixRequestError as e:
return web.json_response({
"errcode": e.errcode,
@ -155,3 +198,38 @@ async def login(request: web.Request) -> web.Response:
"device_id": res.device_id,
}, is_login=True)
return web.json_response(res.serialize())
sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {}
@routes.post("/client/auth/{server}/sso/{id}/wait")
async def wait_sso(request: web.Request) -> web.Response:
waiter_id = request.match_info["id"]
req, fut = sso_waiters[waiter_id]
try:
login_token = await fut
finally:
sso_waiters.pop(waiter_id, None)
return await _do_login(req, login_token)
@routes.get("/client/auth_external_sso/complete/{id}")
async def complete_sso(request: web.Request) -> web.Response:
try:
_, fut = sso_waiters[request.match_info["id"]]
except KeyError:
return web.Response(status=404, text="Invalid session ID\n")
if fut.cancelled():
return web.Response(status=200, text="The login was cancelled from the Maubot client\n")
elif fut.done():
return web.Response(status=200, text="The login token was already received\n")
try:
fut.set_result(request.query["loginToken"])
except KeyError:
return web.Response(status=400, text="Missing loginToken query parameter\n")
except asyncio.InvalidStateError:
return web.Response(status=500, text="Invalid state\n")
return web.Response(status=200,
text="Login token received, please return to your Maubot client. "
"This tab can be closed.\n")

View File

@ -29,7 +29,12 @@ log = logging.getLogger("maubot.server")
@web.middleware
async def auth(request: web.Request, handler: Handler) -> web.Response:
subpath = request.path[len(get_config()["server.base_path"]):]
if subpath.startswith("/auth/") or subpath == "/features" or subpath == "/logs":
if (
subpath.startswith("/auth/")
or subpath.startswith("/client/auth_external_sso/complete/")
or subpath == "/features"
or subpath == "/logs"
):
return await handler(request)
err = check_token(request)
if err is not None:

View File

@ -194,6 +194,20 @@ class _Response:
"errcode": "registration_secret_not_found",
}, status=HTTPStatus.NOT_FOUND)
@property
def registration_no_sso(self) -> web.Response:
return web.json_response({
"error": "The register operation is only for registering with a password",
"errcode": "registration_no_sso",
}, status=HTTPStatus.BAD_REQUEST)
@property
def sso_not_supported(self) -> web.Response:
return web.json_response({
"error": "That server does not seem to support single sign-on",
"errcode": "sso_not_supported",
}, status=HTTPStatus.FORBIDDEN)
@property
def plugin_has_no_database(self) -> web.Response:
return web.json_response({