More changes

This commit is contained in:
Tulir Asokan 2018-10-16 16:41:02 +03:00
parent 0b246e44a8
commit eef052b1e9
9 changed files with 195 additions and 61 deletions

View File

@ -10,8 +10,9 @@ plugin_directories:
- ./plugins - ./plugins
server: server:
# The IP:port to listen to. # The IP and port to listen to.
listen: 0.0.0.0:29316 hostname: 0.0.0.0
port: 29316
# The base management API path. # The base management API path.
base_path: /_matrix/maubot base_path: /_matrix/maubot
# The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1. # The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1.

View File

@ -17,9 +17,14 @@ from sqlalchemy import orm
import sqlalchemy as sql import sqlalchemy as sql
import logging.config import logging.config
import argparse import argparse
import asyncio
import copy import copy
import sys
from .config import Config from .config import Config
from .db import Base, init as init_db
from .server import MaubotServer
from .client import Client, init as init_client
from .__meta__ import __version__ from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.", parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
@ -36,7 +41,23 @@ logging.config.dictConfig(copy.deepcopy(config["logging"]))
log = logging.getLogger("maubot") log = logging.getLogger("maubot")
log.debug(f"Initializing maubot {__version__}") log.debug(f"Initializing maubot {__version__}")
db_engine = sql.create_engine(config["database"]) db_engine: sql.engine.Engine = sql.create_engine(config["database"])
db_factory = orm.sessionmaker(bind=db_engine) db_factory = orm.sessionmaker(bind=db_engine)
db_session = orm.scoping.scoped_session(db_factory) db_session = orm.scoping.scoped_session(db_factory)
Base.metadata.bind=db_engine Base.metadata.bind=db_engine
loop = asyncio.get_event_loop()
init_db(db_session)
init_client(loop)
server = MaubotServer(config, loop)
try:
loop.run_until_complete(server.start())
loop.run_forever()
except KeyboardInterrupt:
log.debug("Keyboard interrupt received, stopping...")
for client in Client.cache.values():
client.stop()
loop.run_until_complete(server.stop())
sys.exit(0)

View File

@ -13,62 +13,21 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Union, Callable from typing import Dict, List, Optional
from aiohttp import ClientSession from aiohttp import ClientSession
import asyncio import asyncio
import logging import logging
from mautrix import Client as MatrixClient from mautrix.types import UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership, EventType
from mautrix.client import EventHandler
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
EventType, MessageEvent)
from .command_spec import ParsedCommand
from .db import DBClient from .db import DBClient
from .matrix import MaubotMatrixClient
log = logging.getLogger("maubot.client") log = logging.getLogger("maubot.client")
class MaubotMatrixClient(MatrixClient):
def __init__(self, maubot_client: 'Client', *args, **kwargs):
super().__init__(*args, **kwargs)
self._maubot_client = maubot_client
self.command_handlers: Dict[str, List[EventHandler]] = {}
self.commands: List[ParsedCommand] = []
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
async def _command_event_handler(self, evt: MessageEvent) -> None:
for command in self.commands:
if command.match(evt):
await self._trigger_command(command, evt)
return
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
for handler in self.command_handlers.get(command.name, []):
await handler(evt)
def on(self, var: Union[EventHandler, EventType, str]
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
if isinstance(var, str):
def decorator(func: EventHandler) -> EventHandler:
self.add_command_handler(var, func)
return func
return decorator
return super().on(var)
def add_command_handler(self, command: str, handler: EventHandler) -> None:
self.command_handlers.setdefault(command, []).append(handler)
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
try:
self.command_handlers[command].remove(handler)
except (KeyError, ValueError):
pass
class Client: class Client:
loop: asyncio.AbstractEventLoop
cache: Dict[UserID, 'Client'] = {} cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None http_client: ClientSession = None
@ -78,26 +37,33 @@ class Client:
def __init__(self, db_instance: DBClient) -> None: def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance self.db_instance = db_instance
self.cache[self.id] = self self.cache[self.id] = self
self.client = MaubotMatrixClient(maubot_client=self, self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance,
store=self.db_instance, mxid=self.id, base_url=self.homeserver,
mxid=self.id, token=self.access_token, client_session=self.http_client,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=log.getChild(self.id)) log=log.getChild(self.id))
if self.autojoin: if self.autojoin:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER) self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
def start(self) -> None:
asyncio.ensure_future(self.client.start(), loop=self.loop)
def stop(self) -> None:
self.client.stop()
@classmethod @classmethod
def get(cls, user_id: UserID) -> Optional['Client']: def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
try: try:
return cls.cache[user_id] return cls.cache[user_id]
except KeyError: except KeyError:
db_instance = DBClient.query.get(user_id) db_instance = db_instance or DBClient.query.get(user_id)
if not db_instance: if not db_instance:
return None return None
return Client(db_instance) return Client(db_instance)
@classmethod
def all(cls) -> List['Client']:
return [cls.get(user.id, user) for user in DBClient.query.all()]
# region Properties # region Properties
@property @property
@ -176,3 +142,9 @@ class Client:
async def _handle_invite(self, evt: StateEvent) -> None: async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE: if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id) await self.client.join_room_by_id(evt.room_id)
def init(loop: asyncio.AbstractEventLoop) -> None:
Client.loop = loop
for client in Client.all():
client.start()

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type from typing import Type
from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator) from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator)
from sqlalchemy.orm import Query from sqlalchemy.orm import Query, scoped_session
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
import json import json
@ -89,3 +89,9 @@ class DBCommandSpec(Base):
ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"), ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True) primary_key=True)
spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False) spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False)
def init(session: scoped_session) -> None:
DBPlugin.query = session.query_property()
DBClient.query = session.query_property()
DBCommandSpec.query = session.query_property()

View File

@ -13,14 +13,17 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TypeVar, Type from typing import TypeVar, Type, Dict
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ..plugin_base import Plugin from ..plugin_base import Plugin
PluginClass = TypeVar("PluginClass", bound=Plugin) PluginClass = TypeVar("PluginClass", bound=Plugin)
class PluginLoader(ABC): class PluginLoader(ABC):
id_cache: Dict[str, 'PluginLoader'] = {}
id: str id: str
version: str version: str

View File

@ -29,7 +29,6 @@ class MaubotZipImportError(Exception):
class ZippedPluginLoader(PluginLoader): class ZippedPluginLoader(PluginLoader):
path_cache: Dict[str, 'ZippedPluginLoader'] = {} path_cache: Dict[str, 'ZippedPluginLoader'] = {}
id_cache: Dict[str, 'ZippedPluginLoader'] = {}
path: str path: str
id: str id: str

77
maubot/matrix.py Normal file
View File

@ -0,0 +1,77 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Union, Callable
from mautrix import Client as MatrixClient
from mautrix.client import EventHandler
from mautrix.types import EventType, MessageEvent
from .command_spec import ParsedCommand, CommandSpec
class MaubotMatrixClient(MatrixClient):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.command_handlers: Dict[str, List[EventHandler]] = {}
self.commands: List[ParsedCommand] = []
self.command_specs: Dict[str, CommandSpec] = {}
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
def set_command_spec(self, plugin_id: str, spec: CommandSpec) -> None:
self.command_specs[plugin_id] = spec
self._reparse_command_specs()
def _reparse_command_specs(self) -> None:
self.commands = [parsed_command
for spec in self.command_specs.values()
for parsed_command in spec.parse()]
def remove_command_spec(self, plugin_id: str) -> None:
try:
del self.command_specs[plugin_id]
self._reparse_command_specs()
except KeyError:
pass
async def _command_event_handler(self, evt: MessageEvent) -> None:
for command in self.commands:
if command.match(evt):
await self._trigger_command(command, evt)
return
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
for handler in self.command_handlers.get(command.name, []):
await handler(evt)
def on(self, var: Union[EventHandler, EventType, str]
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
if isinstance(var, str):
def decorator(func: EventHandler) -> EventHandler:
self.add_command_handler(var, func)
return func
return decorator
return super().on(var)
def add_command_handler(self, command: str, handler: EventHandler) -> None:
self.command_handlers.setdefault(command, []).append(handler)
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
try:
self.command_handlers[command].remove(handler)
except (KeyError, ValueError):
pass

View File

@ -22,11 +22,12 @@ if TYPE_CHECKING:
class Plugin(ABC): class Plugin(ABC):
def __init__(self, client: 'MaubotMatrixClient') -> None: def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str) -> None:
self.client = client self.client = client
self.id = plugin_instance_id
def set_command_spec(self, spec: 'CommandSpec') -> None: def set_command_spec(self, spec: 'CommandSpec') -> None:
pass self.client.set_command_spec(self.id, spec)
async def start(self) -> None: async def start(self) -> None:
pass pass

54
maubot/server.py Normal file
View File

@ -0,0 +1,54 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
import asyncio
from mautrix.api import PathBuilder
from .config import Config
from .__meta__ import __version__
class MaubotServer:
def __init__(self, config: Config, loop: asyncio.AbstractEventLoop):
self.loop = loop or asyncio.get_event_loop()
self.app = web.Application(loop=self.loop)
self.config = config
path = PathBuilder(config["server.base_path"])
self.app.router.add_get(path.version, self.version)
as_path = PathBuilder(config["server.appservice_base_path"])
self.app.router.add_put(as_path.transactions, self.handle_transaction)
self.runner = web.AppRunner(self.app)
async def start(self) -> None:
await self.runner.setup()
site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"])
await site.start()
async def stop(self) -> None:
await self.runner.cleanup()
@staticmethod
async def version(_: web.Request) -> web.Response:
return web.json_response({
"version": __version__
})
async def handle_transaction(self, request: web.Request) -> web.Response:
return web.Response(status=501)