maubot/maubot/client.py

171 lines
5.4 KiB
Python
Raw Normal View History

2018-10-10 19:51:34 +00:00
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
2018-10-16 13:41:02 +00:00
from typing import Dict, List, Optional
2018-10-10 19:51:34 +00:00
from aiohttp import ClientSession
2018-10-15 21:25:23 +00:00
import asyncio
2018-10-10 19:51:34 +00:00
import logging
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
2018-10-17 20:43:56 +00:00
EventType, Filter, RoomFilter, RoomEventFilter)
2018-10-10 19:51:34 +00:00
from .db import DBClient
2018-10-16 13:41:02 +00:00
from .matrix import MaubotMatrixClient
2018-10-10 19:51:34 +00:00
# Module-level logger for this file; Client instances derive per-user child loggers from it.
log = logging.getLogger(name="maubot.client")
class Client:
    """Runtime wrapper binding a ``DBClient`` row to a live ``MaubotMatrixClient``.

    Configuration values (homeserver, token, filter, profile, autojoin, ...) are
    exposed as properties that read from and write to ``self.db_instance``.
    """

    # Event loop used to schedule client tasks. Has no default: it must be
    # assigned (normally by the module-level ``init()``) before constructing a Client.
    loop: asyncio.AbstractEventLoop

    # Registry of constructed clients keyed by Matrix user ID. This is a class
    # attribute, so the dict is shared by every instance.
    cache: Dict[UserID, 'Client'] = {}
    # HTTP session handed to every MaubotMatrixClient; assigned by ``init()``.
    http_client: Optional[ClientSession] = None
    # Task scheduled by start(). Kept so the task is strongly referenced while it
    # runs (the asyncio event loop itself only holds weak references to tasks).
    start_task: Optional[asyncio.Future] = None

    db_instance: DBClient
    client: MaubotMatrixClient

    def __init__(self, db_instance: DBClient) -> None:
        """Build the Matrix client for ``db_instance`` and register it in ``cache``.

        Args:
            db_instance: Database row holding this client's user ID, homeserver,
                access token and settings.

        ``Client.loop`` must already be set; ``Client.http_client`` is passed
        through as the client session.
        """
        self.db_instance = db_instance
        self.log = log.getChild(self.id)
        self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
                                         token=self.access_token, client_session=self.http_client,
                                         log=self.log, loop=self.loop, store=self.db_instance)
        if self.autojoin:
            self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
        # Register only after construction succeeded, so a failure above cannot
        # leave a half-initialised instance (one without ``self.client``) in the cache.
        self.cache[self.id] = self

    def start(self) -> None:
        """Schedule :meth:`_start` on ``self.loop`` without waiting for it to finish.

        The resulting task is stored in ``self.start_task`` so it is not
        garbage-collected mid-execution.
        """
        self.start_task = asyncio.ensure_future(self._start(), loop=self.loop)

    async def _start(self) -> None:
        """Ensure a sync filter exists, apply profile settings, then start the client.

        Exceptions are logged instead of propagated, because start() schedules
        this coroutine as a background task that nothing awaits.
        """
        try:
            if not self.filter_id:
                # Create the filter once; the filter_id setter persists it on db_instance.
                self.filter_id = await self.client.create_filter(Filter(
                    room=RoomFilter(
                        timeline=RoomEventFilter(
                            limit=50,
                        ),
                    ),
                ))
            # The literal value "disable" skips updating that profile field.
            if self.displayname != "disable":
                await self.client.set_displayname(self.displayname)
            if self.avatar_url != "disable":
                await self.client.set_avatar_url(self.avatar_url)
            await self.client.start(self.filter_id)
        except Exception:
            self.log.exception("starting raised exception")

    def stop(self) -> None:
        """Stop the underlying Matrix client."""
        self.client.stop()

    @classmethod
    def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
        """Return the cached client for ``user_id``, creating one on a cache miss.

        Args:
            user_id: Matrix user ID to look up.
            db_instance: Row to build the client from on a miss; if omitted it is
                fetched with ``DBClient.query.get(user_id)``.

        Returns:
            The client, or ``None`` if it is not cached and no row was found.
        """
        try:
            return cls.cache[user_id]
        except KeyError:
            # Fall through: construct outside the handler so that errors raised
            # by the constructor are not chained onto this KeyError.
            pass
        db_instance = db_instance or DBClient.query.get(user_id)
        if not db_instance:
            return None
        return cls(db_instance)

    @classmethod
    def all(cls) -> List['Client']:
        """Return a client for every row from ``DBClient.query.all()``, creating uncached ones."""
        return [cls.get(user.id, user) for user in DBClient.query.all()]

    async def _handle_invite(self, evt: StrippedStateEvent) -> None:
        """Autojoin handler: join the room when this client's own user is invited."""
        if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
            await self.client.join_room(evt.room_id)

    # region Properties

    @property
    def id(self) -> UserID:
        """Matrix user ID of this client (read-only)."""
        return self.db_instance.id

    @property
    def homeserver(self) -> str:
        """Homeserver base URL (read-only)."""
        return self.db_instance.homeserver

    @property
    def access_token(self) -> str:
        """Access token; setting it updates both the live API client and the DB row."""
        return self.db_instance.access_token

    @access_token.setter
    def access_token(self, value: str) -> None:
        self.client.api.token = value
        self.db_instance.access_token = value

    @property
    def next_batch(self) -> SyncToken:
        """Stored sync token, read from/written to the DB row."""
        return self.db_instance.next_batch

    @next_batch.setter
    def next_batch(self, value: SyncToken) -> None:
        self.db_instance.next_batch = value

    @property
    def filter_id(self) -> FilterID:
        """Stored sync filter ID; _start() creates one when this is falsy."""
        return self.db_instance.filter_id

    @filter_id.setter
    def filter_id(self, value: FilterID) -> None:
        self.db_instance.filter_id = value

    @property
    def sync(self) -> bool:
        """Stored ``sync`` flag, read from/written to the DB row."""
        return self.db_instance.sync

    @sync.setter
    def sync(self, value: bool) -> None:
        self.db_instance.sync = value

    @property
    def autojoin(self) -> bool:
        """Whether invites are auto-accepted; toggling (un)registers the invite handler."""
        return self.db_instance.autojoin

    @autojoin.setter
    def autojoin(self, value: bool) -> None:
        # No change: avoid registering the handler twice or removing a missing one.
        if value == self.db_instance.autojoin:
            return
        if value:
            self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
        else:
            self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
        self.db_instance.autojoin = value

    @property
    def displayname(self) -> str:
        """Profile display name applied on start; the value "disable" skips it."""
        return self.db_instance.displayname

    @displayname.setter
    def displayname(self, value: str) -> None:
        self.db_instance.displayname = value

    @property
    def avatar_url(self) -> ContentURI:
        """Profile avatar URI applied on start; the value "disable" skips it."""
        return self.db_instance.avatar_url

    @avatar_url.setter
    def avatar_url(self, value: ContentURI) -> None:
        self.db_instance.avatar_url = value

    # endregion
2018-10-12 21:30:05 +00:00
2018-10-16 13:41:02 +00:00
def init(loop: asyncio.AbstractEventLoop) -> None:
    """Set up the shared HTTP session and event loop, then start every stored client.

    Client.__init__ reads ``Client.loop`` and ``Client.http_client``, so both are
    assigned here before ``Client.all()`` constructs any clients.

    Args:
        loop: Event loop the session and all client tasks are bound to.
    """
    Client.http_client = ClientSession(loop=loop)
    Client.loop = loop
    every_client = Client.all()
    for bot in every_client:
        bot.start()