Fix CommandHandler descriptor

This commit is contained in:
Tulir Asokan 2019-01-18 22:58:43 +02:00
parent 4ea980cb93
commit 8b5c637f76
2 changed files with 22 additions and 22 deletions

View File

@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
Dict, Tuple, Set) Dict, Tuple, Set, Type)
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
import functools import functools
@ -55,26 +55,27 @@ class CommandHandler:
self.__mb_event_handler__: bool = True self.__mb_event_handler__: bool = True
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
self.__mb_msgtypes__: List[MessageType] = (MessageType.TEXT,) self.__mb_msgtypes__: List[MessageType] = (MessageType.TEXT,)
self.__instance_vars: Dict[str, CommandHandler] = {} self.__bound_copies__: Dict[Any, CommandHandler] = {}
self.__class_instance: Any = None self.__bound_instance__: Any = None
def __copy__(self) -> 'CommandHandler':
new_ch = type(self)(self.__mb_func__)
keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match",
"require_subcommand", "arg_fallthrough", "event_handler", "event_type", "msgtypes"]
for key in keys:
key = f"__mb_{key}__"
setattr(new_ch, key, getattr(self, key))
return new_ch
def __get__(self, instance, instancetype): def __get__(self, instance, instancetype):
if not instance or self.__bound_instance__:
return self
try: try:
return self.__instance_vars[instance] return self.__bound_copies__[instance]
except KeyError: except KeyError:
copy = self.__copy__() new_ch = type(self)(self.__mb_func__)
copy.__class_instance = instance keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match",
self.__instance_vars[instance] = copy "require_subcommand", "arg_fallthrough", "event_handler", "event_type",
return copy "msgtypes"]
for key in keys:
key = f"__mb_{key}__"
setattr(new_ch, key, getattr(self, key))
new_ch.__bound_instance__ = instance
new_ch.__mb_subcommands__ = [subcmd.__get__(instance, instancetype)
for subcmd in self.__mb_subcommands__]
self.__bound_copies__[instance] = new_ch
return new_ch
@staticmethod @staticmethod
def __command_match_unset(self, val: str) -> str: def __command_match_unset(self, val: str) -> str:
@ -108,15 +109,15 @@ class CommandHandler:
await evt.reply(self.__mb_full_help__) await evt.reply(self.__mb_full_help__)
return return
if self.__class_instance: if self.__bound_instance__:
return await self.__mb_func__(self.__class_instance, evt, **call_args) return await self.__mb_func__(self.__bound_instance__, evt, **call_args)
return await self.__mb_func__(evt, **call_args) return await self.__mb_func__(evt, **call_args)
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any], async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, Any]: remaining_val: str) -> Tuple[bool, Any]:
command, remaining_val = _split_in_two(remaining_val.strip(), " ") command, remaining_val = _split_in_two(remaining_val.strip(), " ")
for subcommand in self.__mb_subcommands__: for subcommand in self.__mb_subcommands__:
if subcommand.__mb_is_command_match__(subcommand.__class_instance, command): if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command):
return True, await subcommand(evt, _existing_args=call_args, return True, await subcommand(evt, _existing_args=call_args,
remaining_val=remaining_val) remaining_val=remaining_val)
return False, None return False, None
@ -156,7 +157,7 @@ class CommandHandler:
@property @property
def __mb_name__(self) -> str: def __mb_name__(self) -> str:
return self.__mb_get_name__(self.__class_instance) return self.__mb_get_name__(self.__bound_instance__)
@property @property
def __mb_prefix__(self) -> str: def __mb_prefix__(self) -> str:

View File

@ -23,7 +23,6 @@ from sqlalchemy.engine.base import Engine
from aiohttp import ClientSession from aiohttp import ClientSession
if TYPE_CHECKING: if TYPE_CHECKING:
from mautrix.types import Event
from mautrix.util.config import BaseProxyConfig from mautrix.util.config import BaseProxyConfig
from .client import MaubotMatrixClient from .client import MaubotMatrixClient