From 6969b2e273bfd083a062d737ac6b2c8ca72674a9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Oct 2018 22:51:34 +0300 Subject: [PATCH] Add db and client stuff --- maubot/client.py | 116 +++++++++++++++++++++++++++++++++++++++++++++++ maubot/db.py | 65 +++++++++++++++++++++++++- 2 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 maubot/client.py diff --git a/maubot/client.py b/maubot/client.py new file mode 100644 index 0000000..a5c4521 --- /dev/null +++ b/maubot/client.py @@ -0,0 +1,116 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2018 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import Dict, Optional +from aiohttp import ClientSession +import logging + +from mautrix import ClientAPI +from mautrix.types import UserID, SyncToken, FilterID, ContentURI + +from .db import DBClient + +log = logging.getLogger("maubot.client") + + +class Client: + cache: Dict[UserID, 'Client'] = {} + http_client: ClientSession = None + + def __init__(self, db_instance: DBClient) -> None: + self.db_instance: DBClient = db_instance + self.cache[self.id] = self + self.client: ClientAPI = ClientAPI(mxid=self.id, + base_url=self.homeserver, + token=self.access_token, + client_session=self.http_client, + log=log.getChild(self.id)) + + @classmethod + def get(cls, id: UserID) -> Optional['Client']: + try: + return cls.cache[id] + except KeyError: + db_instance = DBClient.query.get(id) + if not db_instance: + return None + return Client(db_instance) + + # region Properties + @property + def id(self) -> UserID: + return self.db_instance.id + + @property + def homeserver(self) -> str: + return self.db_instance.id + + @property + def access_token(self) -> str: + return self.db_instance.access_token + + @access_token.setter + def access_token(self, value: str) -> None: + self.db_instance.access_token = value + + @property + def next_batch(self) -> SyncToken: + return self.db_instance.next_batch + + @next_batch.setter + def next_batch(self, value: SyncToken) -> None: + self.db_instance.next_batch = value + + @property + def filter_id(self) -> FilterID: + return self.db_instance.filter_id + + @filter_id.setter + def filter_id(self, value: FilterID) -> None: + self.db_instance.filter_id = value + + @property + def sync(self) -> bool: + return self.db_instance.sync + + @sync.setter + def sync(self, value: bool) -> None: + self.db_instance.sync = value + + @property + def autojoin(self) -> bool: + return self.db_instance.autojoin + + @autojoin.setter + def autojoin(self, value: bool) -> None: + self.db_instance.autojoin = value + + @property + def displayname(self) -> str: + return self.db_instance.displayname + + @displayname.setter + def displayname(self, value: str) -> None: + self.db_instance.displayname = value + + @property + def avatar_url(self) -> ContentURI: + return self.db_instance.avatar_url + + @avatar_url.setter + def avatar_url(self, value: ContentURI) -> None: + self.db_instance.avatar_url = value + + # endregion diff --git a/maubot/db.py b/maubot/db.py index 93b7c45..812cc3e 100644 --- a/maubot/db.py +++ b/maubot/db.py @@ -13,8 +13,71 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from sqlalchemy import (Column) +from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator) +from sqlalchemy.orm import Query from sqlalchemy.ext.declarative import declarative_base +import json + +from mautrix.types import JSON, UserID, FilterID, SyncToken, ContentURI Base: declarative_base = declarative_base() + +class JSONEncodedDict(TypeDecorator): + impl = Text + + @property + def python_type(self): + return dict + + def process_literal_param(self, value, _): + return json.dumps(value) if value is not None else None + + def process_bind_param(self, value, _): + return json.dumps(value) if value is not None else None + + def process_result_value(self, value, _): + return json.loads(value) if value is not None else None + + +class DBPlugin(Base): + query: Query + __tablename__ = "plugin" + + id: str = Column(String(255), primary_key=True) + type: str = Column(String(255), nullable=False) + enabled: bool = Column(Boolean, nullable=False, default=False) + primary_user: str = Column(String(255), + ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"), + nullable=False) + + +class DBClient(Base): + query: Query + __tablename__ = "client" + + id: UserID = Column(String(255), primary_key=True) + homeserver: str = Column(String(255), nullable=False) + access_token: str = Column(String(255), nullable=False) + + next_batch: SyncToken = Column(String(255), nullable=False, default="") + filter_id: FilterID = Column(String(255), nullable=False, default="") + + sync: bool = Column(Boolean, nullable=False, default=True) + autojoin: bool = Column(Boolean, nullable=False, default=True) + + displayname: str = Column(String(255), nullable=False, default="") + avatar_url: ContentURI = Column(String(255), nullable=False, default="") + + +class DBCommandSpec(Base): + query: Query + __tablename__ = "command_spec" + + owner: str = Column(String(255), + ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"), + primary_key=True) + client: UserID = Column(String(255), + ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"), + primary_key=True) + spec: JSON = Column(JSONEncodedDict, nullable=False)