Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 92 additions & 35 deletions discord/ext/bridge/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
SlashCommandOptionType,
)

from ...utils import MISSING, find, get, warn_deprecated
from ...utils import MISSING, find, get
from ..commands import (
BadArgument,
)
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(self, func, **kwargs):
self.brief = kwargs.pop("brief", None)
super().__init__(func, **kwargs)

async def dispatch_error(
async def dispatch_error( # type: ignore
self, ctx: BridgeApplicationContext, error: Exception
) -> None:
await super().dispatch_error(ctx, error)
Expand All @@ -107,7 +107,7 @@ def __init__(self, func, **kwargs):
f"{option.annotation.__class__.__name__} is not supported in bridge commands. Use BridgeOption instead."
)

async def dispatch_error(self, ctx: BridgeExtContext, error: Exception) -> None:
async def dispatch_error(self, ctx: BridgeExtContext, error: Exception) -> None: # type: ignore
await super().dispatch_error(ctx, error)
ctx.bot.dispatch("bridge_command_error", ctx, error)

Expand All @@ -122,7 +122,7 @@ async def transform(self, ctx: Context, param: inspect.Parameter) -> Any:
class BridgeSlashGroup(SlashCommandGroup):
"""A subclass of :class:`.SlashCommandGroup` that is used for bridge commands."""

__slots__ = ("module",)
__slots__ = ("module",) # type: ignore

def __init__(self, callback, *args, **kwargs):
if perms := getattr(callback, "__default_member_permissions__", None):
Expand All @@ -132,21 +132,23 @@ def __init__(self, callback, *args, **kwargs):
self.__original_kwargs__["callback"] = callback
self.__command = None

async def _invoke(self, ctx: BridgeApplicationContext) -> None:
if not (options := ctx.interaction.data.get("options")):
async def _invoke(self, ctx: BridgeApplicationContext) -> None: # type: ignore
if not (
options := ctx.interaction.data and ctx.interaction.data.get("options")
):
if not self.__command:
self.__command = BridgeSlashCommand(self.callback)
ctx.command = self.__command
return await ctx.command.invoke(ctx)
option = options[0]
resolved = ctx.interaction.data.get("resolved", None)
resolved = ctx.interaction.data and ctx.interaction.data.get("resolved", None)
command = find(lambda x: x.name == option["name"], self.subcommands)
option["resolved"] = resolved
ctx.interaction.data = option
await command.invoke(ctx)
option["resolved"] = resolved # type: ignore
ctx.interaction.data = option # type: ignore
await command.invoke(ctx) # type: ignore


class BridgeExtGroup(BridgeExtCommand, Group):
class BridgeExtGroup(BridgeExtCommand, Group): # type: ignore
"""A subclass of :class:`.ext.commands.Group` that is used for bridge commands."""


Expand Down Expand Up @@ -175,11 +177,14 @@ class BridgeCommand:

__special_attrs__ = ["slash_variant", "ext_variant", "parent"]

def __init__(self, callback, **kwargs):
def __init__(self, callback, **kwargs: Any) -> None:
self.parent = kwargs.pop("parent", None)
self.slash_variant: BridgeSlashCommand = kwargs.pop(
"slash_variant", None
) or BridgeSlashCommand(callback, **kwargs)
) or BridgeSlashCommand(
callback, **kwargs
) # type: ignore

self.ext_variant: BridgeExtCommand = kwargs.pop(
"ext_variant", None
) or BridgeExtCommand(callback, **kwargs)
Expand Down Expand Up @@ -259,13 +264,17 @@ def add_to(self, bot: ExtBot) -> None:
bot.add_command(self.ext_variant)

async def invoke(
self, ctx: BridgeExtContext | BridgeApplicationContext, /, *args, **kwargs
):
self,
ctx: BridgeExtContext | BridgeApplicationContext,
/,
*args: Any,
**kwargs: Any,
) -> None:
if ctx.is_app:
return await self.slash_variant.invoke(ctx)
return await self.ext_variant.invoke(ctx)
return await self.slash_variant.invoke(ctx) # type: ignore
return await self.ext_variant.invoke(ctx) # type: ignore

def error(self, coro):
def error(self, coro: Callable[..., Any]) -> Callable[..., Any]:
"""A decorator that registers a coroutine as a local error handler.

This error handler is limited to the command it is defined to.
Expand All @@ -291,7 +300,7 @@ def error(self, coro):

return coro

def before_invoke(self, coro):
def before_invoke(self, coro: Callable[..., Any]) -> Callable[..., Any]:
"""A decorator that registers a coroutine as a pre-invoke hook.

This hook is called directly before the command is called, making
Expand All @@ -315,7 +324,7 @@ def before_invoke(self, coro):

return coro

def after_invoke(self, coro):
def after_invoke(self, coro: Callable[..., Any]) -> Callable[..., Any]:
"""A decorator that registers a coroutine as a post-invoke hook.

This hook is called directly after the command is called, making it
Expand Down Expand Up @@ -371,27 +380,31 @@ class BridgeCommandGroup(BridgeCommand):
"mapped",
]

ext_variant: BridgeExtGroup
slash_variant: BridgeSlashGroup
ext_variant: BridgeExtGroup # type: ignore
slash_variant: BridgeSlashGroup # type: ignore

def __init__(self, callback, *args, **kwargs):
ext_var = BridgeExtGroup(callback, *args, **kwargs)
def __init__(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
ext_var = kwargs.pop("ext_variant", BridgeExtGroup(callback, *args, **kwargs))
kwargs.update({"name": ext_var.name})
slash_var = kwargs.pop(
"slash_variant", BridgeSlashGroup(callback, *args, **kwargs)
)

super().__init__(
callback,
ext_variant=ext_var,
slash_variant=BridgeSlashGroup(callback, *args, **kwargs),
slash_variant=slash_var,
parent=kwargs.pop("parent", None),
)

self.subcommands: list[BridgeCommand] = []
self.subcommands: list[BridgeCommand | BridgeCommandGroup] = []

self.mapped: SlashCommand | None = None
if map_to := getattr(callback, "__custom_map_to__", None):
kwargs.update(map_to)
self.mapped = self.slash_variant.command(**kwargs)(callback)

def walk_commands(self) -> Iterator[BridgeCommand]:
def walk_commands(self) -> Iterator[BridgeCommand | BridgeCommandGroup]:
"""An iterator that recursively walks through all the bridge group's subcommands.

Yields
Expand Down Expand Up @@ -421,14 +434,57 @@ def wrap(callback):
**kwargs,
cls=BridgeExtCommand,
)(callback)

command = BridgeCommand(
callback, parent=self, slash_variant=slash, ext_variant=ext
callback,
parent=self,
slash_variant=slash,
ext_variant=ext,
)
self.subcommands.append(command)
return command

return wrap

def subgroup(
self,
*,
name: str = MISSING,
description: str | None = None,
**kwargs: Any,
) -> Callable[..., BridgeCommandGroup]:
"""A decorator to register a function as a subgroup.

Parameters
----------
kwargs: Optional[Dict[:class:`str`, Any]]
Keyword arguments that are directly passed to the respective command constructors. (:class:`.SlashCommandGroup` and :class:`.ext.commands.Group`)
"""

def wrap(callback: Callable[..., Any]) -> BridgeCommandGroup:
slash = self.slash_variant.create_subgroup(
name=name,
description=description,
**kwargs,
)

ext = self.ext_variant.group(
name=name,
description=description or "...",
**kwargs,
cls=BridgeExtGroup,
)(callback)
group = BridgeCommandGroup(
callback,
parent=self,
slash_variant=slash,
ext_variant=ext,
)
self.subcommands.append(group)
return group

return wrap


def bridge_command(**kwargs):
"""A decorator that is used to wrap a function as a bridge command.
Expand All @@ -455,7 +511,8 @@ def bridge_group(**kwargs):
"""

def decorator(callback):
return BridgeCommandGroup(callback, **kwargs)
name = kwargs.pop("name", callback.__name__)
return BridgeCommandGroup(callback, name=name, **kwargs)

return decorator

Expand Down Expand Up @@ -510,7 +567,7 @@ def predicate(func: Callable | ApplicationCommand):
if isinstance(func, ApplicationCommand):
func.guild_only = True
else:
func.__guild_only__ = True
func.__guild_only__ = True # type: ignore

from ..commands import guild_only

Expand All @@ -534,7 +591,7 @@ def predicate(func: Callable | ApplicationCommand):
if isinstance(func, ApplicationCommand):
func.nsfw = True
else:
func.__nsfw__ = True
func.__nsfw__ = True # type: ignore

from ..commands import is_nsfw

Expand Down Expand Up @@ -565,7 +622,7 @@ def predicate(func: Callable | ApplicationCommand):
if isinstance(func, ApplicationCommand):
func.default_member_permissions = _perms
else:
func.__default_member_permissions__ = _perms
func.__default_member_permissions__ = _perms # type: ignore

return func

Expand All @@ -591,7 +648,7 @@ async def convert(self, ctx, argument):


class AttachmentConverter(Converter):
async def convert(self, ctx: Context, arg: str):
async def convert(self, ctx: Context, arg: str): # type: ignore
try:
attach = ctx.message.attachments[0]
except IndexError:
Expand All @@ -601,7 +658,7 @@ async def convert(self, ctx: Context, arg: str):


class BooleanConverter(Converter):
async def convert(self, ctx, arg: bool):
async def convert(self, ctx, arg: bool): # type: ignore
return _convert_to_bool(str(arg))


Expand Down Expand Up @@ -633,7 +690,7 @@ def __init__(self, input_type, *args, **kwargs):
async def convert(self, ctx, argument: str) -> Any:
try:
if self.converter is not None:
converted = await self.converter().convert(ctx, argument)
converted = await self.converter().convert(ctx, argument) # type: ignore
else:
converter = BRIDGE_CONVERTER_MAPPING.get(self.input_type)
if isinstance(converter, type) and issubclass(converter, Converter):
Expand Down
Loading