diff --git a/maubot/client.py b/maubot/client.py index a5c4521..3d73b3f 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -17,8 +17,9 @@ from typing import Dict, Optional from aiohttp import ClientSession import logging -from mautrix import ClientAPI -from mautrix.types import UserID, SyncToken, FilterID, ContentURI +from mautrix import Client as MatrixClient +from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership, + EventType) from .db import DBClient @@ -32,11 +33,13 @@ class Client: def __init__(self, db_instance: DBClient) -> None: self.db_instance: DBClient = db_instance self.cache[self.id] = self - self.client: ClientAPI = ClientAPI(mxid=self.id, - base_url=self.homeserver, - token=self.access_token, - client_session=self.http_client, - log=log.getChild(self.id)) + self.client: MatrixClient = MatrixClient(mxid=self.id, + base_url=self.homeserver, + token=self.access_token, + client_session=self.http_client, + log=log.getChild(self.id)) + if self.autojoin: + self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER) @classmethod def get(cls, id: UserID) -> Optional['Client']: @@ -49,6 +52,7 @@ class Client: return Client(db_instance) # region Properties + @property def id(self) -> UserID: return self.db_instance.id @@ -63,6 +67,7 @@ class Client: @access_token.setter def access_token(self, value: str) -> None: + self.client.api.token = value self.db_instance.access_token = value @property @@ -95,6 +100,12 @@ class Client: @autojoin.setter def autojoin(self, value: bool) -> None: + if value == self.db_instance.autojoin: + return + if value: + self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER) + else: + self.client.remove_event_handler(self.handle_invite, EventType.ROOM_MEMBER) self.db_instance.autojoin = value @property @@ -114,3 +125,7 @@ class Client: self.db_instance.avatar_url = value # endregion + + async def handle_invite(self, evt: StateEvent): + if evt.state_key == self.id and evt.content.membership == Membership.INVITE: + await self.client.join_room_by_id(evt.room_id)