diff --git a/maubot/handlers/command.py b/maubot/handlers/command.py index 97cbd89..8bd6534 100644 --- a/maubot/handlers/command.py +++ b/maubot/handlers/command.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, - Dict, Tuple, Set) + Dict, Tuple, Set, Type) from abc import ABC, abstractmethod import asyncio import functools @@ -55,26 +55,27 @@ class CommandHandler: self.__mb_event_handler__: bool = True self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE self.__mb_msgtypes__: List[MessageType] = (MessageType.TEXT,) - self.__instance_vars: Dict[str, CommandHandler] = {} - self.__class_instance: Any = None - - def __copy__(self) -> 'CommandHandler': - new_ch = type(self)(self.__mb_func__) - keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match", - "require_subcommand", "arg_fallthrough", "event_handler", "event_type", "msgtypes"] - for key in keys: - key = f"__mb_{key}__" - setattr(new_ch, key, getattr(self, key)) - return new_ch + self.__bound_copies__: Dict[Any, CommandHandler] = {} + self.__bound_instance__: Any = None def __get__(self, instance, instancetype): + if not instance or self.__bound_instance__: + return self try: - return self.__instance_vars[instance] + return self.__bound_copies__[instance] except KeyError: - copy = self.__copy__() - copy.__class_instance = instance - self.__instance_vars[instance] = copy - return copy + new_ch = type(self)(self.__mb_func__) + keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match", + "require_subcommand", "arg_fallthrough", "event_handler", "event_type", + "msgtypes"] + for key in keys: + key = f"__mb_{key}__" + setattr(new_ch, key, getattr(self, key)) + new_ch.__bound_instance__ = instance + new_ch.__mb_subcommands__ = [subcmd.__get__(instance, instancetype) + for subcmd in self.__mb_subcommands__] + self.__bound_copies__[instance] = new_ch + return new_ch @staticmethod def __command_match_unset(self, val: str) -> str: @@ -108,15 +109,15 @@ class CommandHandler: await evt.reply(self.__mb_full_help__) return - if self.__class_instance: - return await self.__mb_func__(self.__class_instance, evt, **call_args) + if self.__bound_instance__: + return await self.__mb_func__(self.__bound_instance__, evt, **call_args) return await self.__mb_func__(evt, **call_args) async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str) -> Tuple[bool, Any]: command, remaining_val = _split_in_two(remaining_val.strip(), " ") for subcommand in self.__mb_subcommands__: - if subcommand.__mb_is_command_match__(subcommand.__class_instance, command): + if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command): return True, await subcommand(evt, _existing_args=call_args, remaining_val=remaining_val) return False, None @@ -156,7 +157,7 @@ class CommandHandler: @property def __mb_name__(self) -> str: - return self.__mb_get_name__(self.__class_instance) + return self.__mb_get_name__(self.__bound_instance__) @property def __mb_prefix__(self) -> str: diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index c720cc5..fbc7891 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -23,7 +23,6 @@ from sqlalchemy.engine.base import Engine from aiohttp import ClientSession if TYPE_CHECKING: - from mautrix.types import Event from mautrix.util.config import BaseProxyConfig from .client import MaubotMatrixClient