From d2fbf3e937e17405b8adb4d1671c3b05e8c772ec Mon Sep 17 00:00:00 2001 From: Soheab <33902984+Soheab@users.noreply.github.com> Date: Thu, 12 Feb 2026 11:54:30 +0100 Subject: [PATCH 1/4] Pushing to testing --- discord/commands/_options.py | 556 ++++++++++++++++ discord/commands/core.py | 245 ++----- discord/commands/options.py | 955 +++++++++++++++------------- intro_typing.py | 526 +++++++++++++++ test_exts/options_showcase.py | 194 ++++++ tests/test_slash_command_options.py | 195 ++++++ 6 files changed, 2028 insertions(+), 643 deletions(-) create mode 100644 discord/commands/_options.py create mode 100644 intro_typing.py create mode 100644 test_exts/options_showcase.py create mode 100644 tests/test_slash_command_options.py diff --git a/discord/commands/_options.py b/discord/commands/_options.py new file mode 100644 index 0000000000..1613630bc0 --- /dev/null +++ b/discord/commands/_options.py @@ -0,0 +1,556 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import inspect +import logging +import sys +import types +from collections.abc import Awaitable, Callable, Iterable +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Optional, + Type, + TypeVar, + Union, + get_args, +) + +if sys.version_info >= (3, 12): + from typing import TypeAliasType +else: + from typing_extensions import TypeAliasType + +from ..abc import GuildChannel, Mentionable +from ..channel import ( + CategoryChannel, + DMChannel, + ForumChannel, + MediaChannel, + StageChannel, + TextChannel, + Thread, + VoiceChannel, +) +from ..commands import ApplicationContext, AutocompleteContext +from ..enums import ChannelType +from ..enums import Enum as DiscordEnum +from ..enums import SlashCommandOptionType +from ..utils import MISSING, basic_autocomplete + +if TYPE_CHECKING: + from ..cog import Cog + from ..ext.commands import Converter + from ..member import Member + from ..message import Attachment + from ..role import Role + from ..user import User + + InputType = Union[ + Type[str], + Type[bool], + Type[int], + Type[float], + Type[GuildChannel], + Type[Thread], + Type[Member], + Type[User], + Type[Attachment], + Type[Role], + Type[Mentionable], + SlashCommandOptionType, + Converter, + Type[Converter], + Type[Enum], + Type[DiscordEnum], + ] + + AutocompleteReturnType = Union[ + Iterable["OptionChoice"], Iterable[str], Iterable[int], Iterable[float] + ] + T = TypeVar("T", bound=AutocompleteReturnType) + MaybeAwaitable = Union[T, Awaitable[T]] + AutocompleteFunction = Union[ + Callable[[AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]], + Callable[[Cog, AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]], + Callable[ + [AutocompleteContext, Any], # pyright: ignore [reportExplicitAny] + MaybeAwaitable[AutocompleteReturnType], + ], + Callable[ + [Cog, AutocompleteContext, Any], # pyright: ignore [reportExplicitAny] + MaybeAwaitable[AutocompleteReturnType], + ], + ] + + +__all__ = ( + "ThreadOption", + "Option", + "OptionChoice", + "option", +) + +CHANNEL_TYPE_MAP = { + TextChannel: ChannelType.text, + VoiceChannel: ChannelType.voice, + StageChannel: ChannelType.stage_voice, + CategoryChannel: ChannelType.category, + Thread: ChannelType.public_thread, + ForumChannel: ChannelType.forum, + MediaChannel: ChannelType.media, + DMChannel: ChannelType.private, +} + +_log = logging.getLogger(__name__) + + +class ThreadOption: + """Represents a class that can be passed as the ``input_type`` for an :class:`Option` class. + + .. versionadded:: 2.0 + + Parameters + ---------- + thread_type: Literal["public", "private", "news"] + The thread type to expect for this options input. + """ + + def __init__(self, thread_type: Literal["public", "private", "news"]): + type_map = { + "public": ChannelType.public_thread, + "private": ChannelType.private_thread, + "news": ChannelType.news_thread, + } + self._type = type_map[thread_type] + + +class Option: + """Represents a selectable option for a slash command. + + Attributes + ---------- + input_type: Union[Type[:class:`str`], Type[:class:`bool`], Type[:class:`int`], Type[:class:`float`], Type[:class:`.abc.GuildChannel`], Type[:class:`Thread`], Type[:class:`Member`], Type[:class:`User`], Type[:class:`Attachment`], Type[:class:`Role`], Type[:class:`.abc.Mentionable`], :class:`SlashCommandOptionType`, Type[:class:`.ext.commands.Converter`], Type[:class:`enums.Enum`], Type[:class:`Enum`]] + The type of input that is expected for this option. This can be a :class:`SlashCommandOptionType`, + an associated class, a channel type, a :class:`Converter`, a converter class or an :class:`enum.Enum`. + If a :class:`enum.Enum` is used and it has up to 25 values, :attr:`choices` will be automatically filled. If the :class:`enum.Enum` has more than 25 values, :attr:`autocomplete` will be implemented with :func:`discord.utils.basic_autocomplete` instead. + name: :class:`str` + The name of this option visible in the UI. + Inherits from the variable name if not provided as a parameter. + description: Optional[:class:`str`] + The description of this option. + Must be 100 characters or fewer. If :attr:`input_type` is a :class:`enum.Enum` and :attr:`description` is not specified, :attr:`input_type`'s docstring will be used. + choices: Optional[List[Union[:class:`Any`, :class:`OptionChoice`]]] + The list of available choices for this option. + Can be a list of values or :class:`OptionChoice` objects (which represent a name:value pair). + If provided, the input from the user must match one of the choices in the list. + required: Optional[:class:`bool`] + Whether this option is required. + default: Optional[:class:`Any`] + The default value for this option. If provided, ``required`` will be considered ``False``. + min_value: Optional[:class:`int`] + The minimum value that can be entered. + Only applies to Options with an :attr:`.input_type` of :class:`int` or :class:`float`. + max_value: Optional[:class:`int`] + The maximum value that can be entered. + Only applies to Options with an :attr:`.input_type` of :class:`int` or :class:`float`. + min_length: Optional[:class:`int`] + The minimum length of the string that can be entered. Must be between 0 and 6000 (inclusive). + Only applies to Options with an :attr:`input_type` of :class:`str`. + max_length: Optional[:class:`int`] + The maximum length of the string that can be entered. Must be between 1 and 6000 (inclusive). + Only applies to Options with an :attr:`input_type` of :class:`str`. + channel_types: list[:class:`discord.ChannelType`] | None + A list of channel types that can be selected in this option. + Only applies to Options with an :attr:`input_type` of :class:`discord.SlashCommandOptionType.channel`. + If this argument is used, :attr:`input_type` will be ignored. + name_localizations: Dict[:class:`str`, :class:`str`] + The name localizations for this option. The values of this should be ``"locale": "name"``. + See `here `_ for a list of valid locales. + description_localizations: Dict[:class:`str`, :class:`str`] + The description localizations for this option. The values of this should be ``"locale": "description"``. + See `here `_ for a list of valid locales. + + Examples + -------- + Basic usage: :: + + @bot.slash_command(guild_ids=[...]) + async def hello( + ctx: discord.ApplicationContext, + name: Option(str, "Enter your name"), + age: Option(int, "Enter your age", min_value=1, max_value=99, default=18) + # passing the default value makes an argument optional + # you also can create optional argument using: + # age: Option(int, "Enter your age") = 18 + ): + await ctx.respond(f"Hello! Your name is {name} and you are {age} years old.") + + .. versionadded:: 2.0 + """ + + input_type: SlashCommandOptionType + converter: Converter | type[Converter] | None = None + + def __init__( + self, input_type: InputType = str, /, description: str | None = None, **kwargs + ) -> None: + self.name: str | None = kwargs.pop("name", None) + if self.name is not None: + self.name = str(self.name) + self._parameter_name = self.name # default + input_type = self._parse_type_alias(input_type) + input_type = self._strip_none_type(input_type) + self._raw_type: InputType | tuple = input_type + + enum_choices = [] + input_type_is_class = isinstance(input_type, type) + if input_type_is_class and issubclass(input_type, (Enum, DiscordEnum)): + if description is None and input_type.__doc__ is not None: + description = inspect.cleandoc(input_type.__doc__) + if description and len(description) > 100: + description = description[:97] + "..." + _log.warning( + "Option %s's description was truncated due to Enum %s's docstring exceeding 100 characters.", + self.name, + input_type, + ) + enum_choices = [OptionChoice(e.name, e.value) for e in input_type] + value_class = enum_choices[0].value.__class__ + if value_class in SlashCommandOptionType.__members__ and all( + isinstance(elem.value, value_class) for elem in enum_choices + ): + input_type = SlashCommandOptionType.from_datatype( + enum_choices[0].value.__class__ + ) + else: + enum_choices = [OptionChoice(e.name, str(e.value)) for e in input_type] + input_type = SlashCommandOptionType.string + + self.description = description or "No description provided" + self.channel_types: list[ChannelType] = kwargs.pop("channel_types", []) + + if self.channel_types: + self.input_type = SlashCommandOptionType.channel + elif isinstance(input_type, SlashCommandOptionType): + self.input_type = input_type + else: + from ..ext.commands import Converter + + if isinstance(input_type, tuple) and any( + issubclass(op, ApplicationContext) for op in input_type + ): + input_type = next( + op for op in input_type if issubclass(op, ApplicationContext) + ) + + if ( + isinstance(input_type, Converter) + or input_type_is_class + and issubclass(input_type, Converter) + ): + self.converter = input_type + self._raw_type = str + self.input_type = SlashCommandOptionType.string + else: + try: + self.input_type = SlashCommandOptionType.from_datatype(input_type) + except TypeError as exc: + from ..ext.commands.converter import CONVERTER_MAPPING + + if input_type not in CONVERTER_MAPPING: + raise exc + self.converter = CONVERTER_MAPPING[input_type] + self._raw_type = str + self.input_type = SlashCommandOptionType.string + else: + if self.input_type == SlashCommandOptionType.channel: + if not isinstance(self._raw_type, tuple): + if hasattr(input_type, "__args__"): + self._raw_type = input_type.__args__ # type: ignore # Union.__args__ + else: + self._raw_type = (input_type,) + if not self.channel_types: + self.channel_types = [ + CHANNEL_TYPE_MAP[t] + for t in self._raw_type + if t is not GuildChannel + ] + self.required: bool = ( + kwargs.pop("required", True) if "default" not in kwargs else False + ) + self.default = kwargs.pop("default", None) + + self._autocomplete: AutocompleteFunction | None = None + self.autocomplete = kwargs.pop("autocomplete", None) + if len(enum_choices) > 25: + self.choices: list[OptionChoice] = [] + for e in enum_choices: + e.value = str(e.value) + self.autocomplete = basic_autocomplete(enum_choices) + self.input_type = SlashCommandOptionType.string + else: + self.choices: list[OptionChoice] = enum_choices or [ + o if isinstance(o, OptionChoice) else OptionChoice(o) + for o in kwargs.pop("choices", []) + ] + + if self.input_type == SlashCommandOptionType.integer: + minmax_types = (int, type(None)) + minmax_typehint = Optional[int] + elif self.input_type == SlashCommandOptionType.number: + minmax_types = (int, float, type(None)) + minmax_typehint = Optional[Union[int, float]] + else: + minmax_types = (type(None),) + minmax_typehint = type(None) + + if self.input_type == SlashCommandOptionType.string: + minmax_length_types = (int, type(None)) + minmax_length_typehint = Optional[int] + else: + minmax_length_types = (type(None),) + minmax_length_typehint = type(None) + + self.min_value: int | float | None = kwargs.pop("min_value", None) + self.max_value: int | float | None = kwargs.pop("max_value", None) + self.min_length: int | None = kwargs.pop("min_length", None) + self.max_length: int | None = kwargs.pop("max_length", None) + + if ( + self.input_type != SlashCommandOptionType.integer + and self.input_type != SlashCommandOptionType.number + and (self.min_value or self.max_value) + ): + raise AttributeError( + "Option does not take min_value or max_value if not of type " + "SlashCommandOptionType.integer or SlashCommandOptionType.number" + ) + if self.input_type != SlashCommandOptionType.string and ( + self.min_length or self.max_length + ): + raise AttributeError( + "Option does not take min_length or max_length if not of type str" + ) + + if self.min_value is not None and not isinstance(self.min_value, minmax_types): + raise TypeError( + f"Expected {minmax_typehint} for min_value, got" + f' "{type(self.min_value).__name__}"' + ) + if self.max_value is not None and not isinstance(self.max_value, minmax_types): + raise TypeError( + f"Expected {minmax_typehint} for max_value, got" + f' "{type(self.max_value).__name__}"' + ) + + if self.min_length is not None: + if not isinstance(self.min_length, minmax_length_types): + raise TypeError( + f"Expected {minmax_length_typehint} for min_length," + f' got "{type(self.min_length).__name__}"' + ) + if self.min_length < 0 or self.min_length > 6000: + raise AttributeError( + "min_length must be between 0 and 6000 (inclusive)" + ) + if self.max_length is not None: + if not isinstance(self.max_length, minmax_length_types): + raise TypeError( + f"Expected {minmax_length_typehint} for max_length," + f' got "{type(self.max_length).__name__}"' + ) + if self.max_length < 1 or self.max_length > 6000: + raise AttributeError("max_length must between 1 and 6000 (inclusive)") + + self.name_localizations = kwargs.pop("name_localizations", MISSING) + self.description_localizations = kwargs.pop( + "description_localizations", MISSING + ) + + if input_type is None: + raise TypeError("input_type cannot be NoneType.") + + @staticmethod + def _parse_type_alias(input_type: InputType) -> InputType: + if isinstance(input_type, TypeAliasType): + return input_type.__value__ + return input_type + + @staticmethod + def _strip_none_type(input_type): + if isinstance(input_type, SlashCommandOptionType): + return input_type + + if input_type is type(None): + raise TypeError("Option type cannot be only NoneType") + + args = () + if isinstance(input_type, types.UnionType): + args = get_args(input_type) + elif getattr(input_type, "__origin__", None) is Union: + args = get_args(input_type) + elif isinstance(input_type, tuple): + args = input_type + + if args: + filtered = tuple(t for t in args if t is not type(None)) + if not filtered: + raise TypeError("Option type cannot be only NoneType") + if len(filtered) == 1: + return filtered[0] + + return filtered + + return input_type + + def to_dict(self) -> dict: + as_dict = { + "name": self.name, + "description": self.description, + "type": self.input_type.value, + "required": self.required, + "choices": [c.to_dict() for c in self.choices], + "autocomplete": bool(self.autocomplete), + } + if self.name_localizations is not MISSING: + as_dict["name_localizations"] = self.name_localizations + if self.description_localizations is not MISSING: + as_dict["description_localizations"] = self.description_localizations + if self.channel_types: + as_dict["channel_types"] = [t.value for t in self.channel_types] + if self.min_value is not None: + as_dict["min_value"] = self.min_value + if self.max_value is not None: + as_dict["max_value"] = self.max_value + if self.min_length is not None: + as_dict["min_length"] = self.min_length + if self.max_length is not None: + as_dict["max_length"] = self.max_length + + return as_dict + + def __repr__(self): + return f"" + + @property + def autocomplete(self) -> AutocompleteFunction | None: + """ + The autocomplete handler for the option. Accepts a callable (sync or async) + that takes a single required argument of :class:`AutocompleteContext` or two arguments + of :class:`discord.Cog` (being the command's cog) and :class:`AutocompleteContext`. + The callable must return an iterable of :class:`str` or :class:`OptionChoice`. + Alternatively, :func:`discord.utils.basic_autocomplete` may be used in place of the callable. + + Returns + ------- + Optional[AutocompleteFunction] + + .. versionchanged:: 2.7 + + .. note:: + Does not validate the input value against the autocomplete results. + """ + return self._autocomplete + + @autocomplete.setter + def autocomplete(self, value: AutocompleteFunction | None) -> None: + self._autocomplete = value + # this is done here so it does not have to be computed every time the autocomplete is invoked + if self._autocomplete is not None: + self._autocomplete._is_instance_method = ( # pyright: ignore [reportFunctionMemberAccess] + sum( + 1 + for param in inspect.signature( + self._autocomplete + ).parameters.values() + if param.default == param.empty # pyright: ignore[reportAny] + and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) + ) + == 2 + ) + + +class OptionChoice: + """ + Represents a name:value pairing for a selected :class:`.Option`. + + .. versionadded:: 2.0 + + Attributes + ---------- + name: :class:`str` + The name of the choice. Shown in the UI when selecting an option. + value: Optional[Union[:class:`str`, :class:`int`, :class:`float`]] + The value of the choice. If not provided, will use the value of ``name``. + name_localizations: Dict[:class:`str`, :class:`str`] + The name localizations for this choice. The values of this should be ``"locale": "name"``. + See `here `_ for a list of valid locales. + """ + + def __init__( + self, + name: str, + value: str | int | float | None = None, + name_localizations: dict[str, str] = MISSING, + ): + self.name = str(name) + self.value = value if value is not None else name + self.name_localizations = name_localizations + + def to_dict(self) -> dict[str, str | int | float]: + as_dict = {"name": self.name, "value": self.value} + if self.name_localizations is not MISSING: + as_dict["name_localizations"] = self.name_localizations + + return as_dict + + +def option(name, input_type=None, **kwargs): + """A decorator that can be used instead of typehinting :class:`.Option`. + + .. versionadded:: 2.0 + + Attributes + ---------- + parameter_name: :class:`str` + The name of the target function parameter this option is mapped to. + This allows you to have a separate UI ``name`` and parameter name. + """ + + def decorator(func): + resolved_name = kwargs.pop("parameter_name", None) or name + itype = ( + kwargs.pop("type", None) + or input_type + or func.__annotations__.get(resolved_name, str) + ) + func.__annotations__[resolved_name] = Option(itype, name=name, **kwargs) + return func + + return decorator diff --git a/discord/commands/core.py b/discord/commands/core.py index 499e326a61..48e0b757d3 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -50,9 +50,7 @@ from ..enums import ( IntegrationType, InteractionContextType, - MessageType, SlashCommandOptionType, - try_enum, ) from ..errors import ( ApplicationCommandError, @@ -62,7 +60,6 @@ InvalidArgument, ValidationError, ) -from ..member import Member from ..message import Attachment, Message from ..object import Object from ..role import Role @@ -70,12 +67,8 @@ from ..user import User from ..utils import MISSING, async_all, find, maybe_coroutine, utcnow, warn_deprecated from .context import ApplicationContext, AutocompleteContext -from .options import Option, OptionChoice +from .options import Option, OptionChoice, _get_options -if sys.version_info >= (3, 11): - from typing import Annotated, Literal, get_args, get_origin -else: - from typing_extensions import Annotated, Literal, get_args, get_origin __all__ = ( "_BaseCommand", @@ -759,9 +752,7 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: self.attached_to_group: bool = False - self._options_kwargs = kwargs.get("options", []) - self.options: list[Option] = [] - self._validate_parameters() + self.options: list[Option] = kwargs.get("options", []) or self._parse_options() try: checks = func.__commands_checks__ @@ -774,13 +765,6 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: self._before_invoke = None self._after_invoke = None - def _validate_parameters(self): - params = self._get_signature_parameters() - if kwop := self._options_kwargs: - self.options = self._match_option_param_names(params, kwop) - else: - self.options = self._parse_options(params) - def _check_required_params(self, params): params = iter(params.items()) required_params = ( @@ -796,146 +780,41 @@ def _check_required_params(self, params): return params - def _parse_options(self, params, *, check_params: bool = True) -> list[Option]: - if check_params: - params = self._check_required_params(params) - else: - params = iter(params.items()) - - final_options = [] - for p_name, p_obj in params: - option = p_obj.annotation - if option == inspect.Parameter.empty: - option = str - - option = Option._strip_none_type(option) - if self._is_typing_literal(option): - literal_values = get_args(option) - if not all(isinstance(v, (str, int, float)) for v in literal_values): - raise TypeError( - "Literal values for choices must be str, int, or float." - ) - - value_type = type(literal_values[0]) - if not all(isinstance(v, value_type) for v in literal_values): - raise TypeError( - "All Literal values for choices must be of the same type." - ) - - option = Option( - value_type, - choices=[ - OptionChoice(name=str(v), value=v) for v in literal_values - ], - ) - - if self._is_typing_annotated(option): - type_hint = get_args(option)[0] - metadata = option.__metadata__ - # If multiple Options in metadata, the first will be used. - option_gen = (elem for elem in metadata if isinstance(elem, Option)) - option = next(option_gen, Option()) - # Handle Optional - if self._is_typing_optional(type_hint): - option.input_type = SlashCommandOptionType.from_datatype( - get_args(type_hint)[0] - ) - option.default = None - else: - option.input_type = SlashCommandOptionType.from_datatype(type_hint) - - if self._is_typing_union(option): - if self._is_typing_optional(option): - option = Option(option.__args__[0], default=None) - else: - option = Option(option.__args__) - - if not isinstance(option, Option): - if isinstance(p_obj.default, Option): - if p_obj.default.input_type is None: - p_obj.default.input_type = SlashCommandOptionType.from_datatype( - option - ) - option = p_obj.default - else: - option = Option(option) - - if option.default is None and not p_obj.default == inspect.Parameter.empty: - if isinstance(p_obj.default, Option): - pass - elif isinstance(p_obj.default, type) and issubclass( - p_obj.default, (DiscordEnum, Enum) - ): - option = Option(p_obj.default) - else: - option.default = p_obj.default - option.required = False - if option.name is None: - option.name = p_name - if option.name != p_name or option._parameter_name is None: - option._parameter_name = p_name - - _validate_names(option) + def _parse_options(self) -> list[Option]: + final_options = _get_options( + self.callback, cog=type(self.cog) if self.cog else None + ) + for option in final_options.values(): + _validate_names(option._param_name) _validate_descriptions(option) - final_options.append(option) - - return final_options - - def _match_option_param_names(self, params, options): - options = list(options) - params = self._check_required_params(params) - - check_annotations: list[Callable[[Option, type], bool]] = [ - lambda o, a: o.input_type == SlashCommandOptionType.string - and o.converter is not None, # pass on converters - lambda o, a: isinstance( - o.input_type, SlashCommandOptionType - ), # pass on slash cmd option type enums - lambda o, a: isinstance(o._raw_type, tuple) and a == Union[o._raw_type], # type: ignore # union types - lambda o, a: self._is_typing_optional(a) - and not o.required - and o._raw_type in a.__args__, # optional - lambda o, a: isinstance(a, type) - and issubclass(a, o._raw_type), # 'normal' types - ] - for o in options: - _validate_names(o) - _validate_descriptions(o) - try: - p_name, p_obj = next(params) - except StopIteration: # not enough params for all the options - raise ClientException("Too many arguments passed to the options kwarg.") - p_obj = p_obj.annotation - - if not any(check(o, p_obj) for check in check_annotations): - raise TypeError( - f"Parameter {p_name} does not match input type of {o.name}." - ) - o._parameter_name = p_name - - left_out_params = OrderedDict() - for k, v in params: - left_out_params[k] = v - options.extend(self._parse_options(left_out_params, check_params=False)) - - return options - - def _is_typing_union(self, annotation): - return getattr(annotation, "__origin__", None) is Union or type( - annotation - ) is getattr( - types, "UnionType", Union - ) # type: ignore - - def _is_typing_optional(self, annotation): - return self._is_typing_union(annotation) and type(None) in annotation.__args__ # type: ignore - - def _is_typing_literal(self, annotation): - return get_origin(annotation) is Literal - - def _is_typing_annotated(self, annotation): - return get_origin(annotation) is Annotated + return list(final_options.values()) + + # def _match_option_param_names(self, params, options): + # options = list(options) + # params = self._check_required_params(params) + # + # for o in options: + # _validate_names(o) + # _validate_descriptions(o) + # try: + # p_name, p_obj = next(params) + # except StopIteration: # not enough params for all the options + # raise ClientException("Too many arguments passed to the options kwarg.") + # p_obj = p_obj.annotation + # + # if not any(check(o, p_obj) for check in check_annotations): + # raise TypeError( + # f"Parameter {p_name} does not match input type of {o.name}." + # ) + # o._parameter_name = p_name + # + # left_out_params = OrderedDict() + # for k, v in params: + # left_out_params[k] = v + # options.extend(self._parse_options(left_out_params, check_params=False)) + # + # return options @property def cog(self): @@ -952,7 +831,7 @@ def cog(self, value): or value is None and old_cog is not None ): - self._validate_parameters() + self._parse_options() @property def is_subcommand(self) -> bool: @@ -999,7 +878,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: arg = arg["value"] # Checks if input_type is user, role or channel - if op.input_type in ( + if op._api_type in ( SlashCommandOptionType.user, SlashCommandOptionType.role, SlashCommandOptionType.channel, @@ -1008,7 +887,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: ): resolved = ctx.interaction.data.get("resolved", {}) if ( - op.input_type + op._api_type in (SlashCommandOptionType.user, SlashCommandOptionType.mentionable) and (_data := resolved.get("members", {}).get(arg)) is not None ): @@ -1018,7 +897,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: _data["user"] = _user_data cache_flag = ctx.interaction._state.member_cache_flags.interaction arg = ctx.guild._get_and_update_member(_data, int(arg), cache_flag) - elif op.input_type is SlashCommandOptionType.mentionable: + elif op._api_type is SlashCommandOptionType.mentionable: if (_data := resolved.get("users", {}).get(arg)) is not None: arg = User(state=ctx.interaction._state, data=_data) elif (_data := resolved.get("roles", {}).get(arg)) is not None: @@ -1028,9 +907,9 @@ async def _invoke(self, ctx: ApplicationContext) -> None: else: arg = Object(id=int(arg)) elif ( - _data := resolved.get(f"{op.input_type.name}s", {}).get(arg) + _data := resolved.get(f"{op._api_type.name}s", {}).get(arg) ) is not None: - if op.input_type is SlashCommandOptionType.channel and ( + if op._api_type is SlashCommandOptionType.channel and ( int(arg) in ctx.guild._channels or int(arg) in ctx.guild._threads ): @@ -1044,12 +923,12 @@ async def _invoke(self, ctx: ApplicationContext) -> None: else: obj_type = None kw = {} - if op.input_type is SlashCommandOptionType.user: + if op._api_type is SlashCommandOptionType.user: obj_type = User - elif op.input_type is SlashCommandOptionType.role: + elif op._api_type is SlashCommandOptionType.role: obj_type = Role kw["guild"] = ctx.guild - elif op.input_type is SlashCommandOptionType.channel: + elif op._api_type is SlashCommandOptionType.channel: # NOTE: # This is a fallback in case the channel/thread is not found in the # guild's channels/threads. For channels, if this fallback occurs, at the very minimum, @@ -1058,7 +937,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: # flags, and more will be missing due to a lack of data sent by Discord. obj_type = _threaded_guild_channel_factory(_data["type"])[0] kw["guild"] = ctx.guild - elif op.input_type is SlashCommandOptionType.attachment: + elif op._api_type is SlashCommandOptionType.attachment: obj_type = Attachment arg = obj_type(state=ctx.interaction._state, data=_data, **kw) else: @@ -1066,18 +945,12 @@ async def _invoke(self, ctx: ApplicationContext) -> None: arg = Object(id=int(arg)) elif ( - op.input_type == SlashCommandOptionType.string + op._api_type == SlashCommandOptionType.string and (converter := op.converter) is not None ): - from discord.ext.commands import Converter + arg = await converter.convert(ctx, arg) - if isinstance(converter, Converter): - if isinstance(converter, type): - arg = await converter().convert(ctx, arg) - else: - arg = await converter.convert(ctx, arg) - - elif op._raw_type in ( + elif op._api_type in ( SlashCommandOptionType.integer, SlashCommandOptionType.number, SlashCommandOptionType.string, @@ -1085,20 +958,22 @@ async def _invoke(self, ctx: ApplicationContext) -> None: ): pass - elif issubclass(op._raw_type, Enum): + elif inspect.isclass(op._param_type) and issubclass( + op._param_type, (Enum, DiscordEnum) + ): if isinstance(arg, str) and arg.isdigit(): try: - arg = op._raw_type(int(arg)) + arg = op._param_type(int(arg)) except ValueError: - arg = op._raw_type(arg) + arg = op._param_type(arg) elif choice := find(lambda c: c.value == arg, op.choices): - arg = getattr(op._raw_type, choice.name) + arg = getattr(op._param_type, choice.name) - kwargs[op._parameter_name] = arg + kwargs[op._param_name] = arg for o in self.options: - if o._parameter_name not in kwargs: - kwargs[o._parameter_name] = o.default + if o._param_name not in kwargs: + kwargs[o._param_name] = o.default if self.cog is not None: await self.callback(self.cog, ctx, **kwargs) @@ -1113,9 +988,9 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): for op in ctx.interaction.data.get("options", []): if op.get("focused", False): option = find(lambda o: o.name == op["name"], self.options) - values.update( - {i["name"]: i["value"] for i in ctx.interaction.data["options"]} - ) + values.update({ + i["name"]: i["value"] for i in ctx.interaction.data["options"] + }) ctx.command = self ctx.focused = option ctx.value = op.get("value") diff --git a/discord/commands/options.py b/discord/commands/options.py index e62d81d80a..1e0d30dc00 100644 --- a/discord/commands/options.py +++ b/discord/commands/options.py @@ -1,463 +1,216 @@ -""" -The MIT License (MIT) - -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - from __future__ import annotations - -import inspect +from collections import OrderedDict +from collections.abc import Awaitable, Iterable, Callable +from enum import Enum, IntEnum import logging import sys import types -from collections.abc import Awaitable, Callable, Iterable -from enum import Enum +import inspect from typing import ( TYPE_CHECKING, + Annotated, Any, Literal, - Optional, - Type, TypeVar, Union, get_args, + get_origin, ) +from ..enums import SlashCommandOptionType, Enum as DiscordEnum, ChannelType +from .context import AutocompleteContext + -if sys.version_info >= (3, 12): - from typing import TypeAliasType -else: - from typing_extensions import TypeAliasType +from ..utils import ( + resolve_annotation, + normalise_optional_params, + MISSING, + basic_autocomplete, +) -from ..abc import GuildChannel, Mentionable +from ..abc import GuildChannel +from ..message import Attachment +from ..role import Role from ..channel import ( + TextChannel, + VoiceChannel, CategoryChannel, - DMChannel, ForumChannel, - MediaChannel, StageChannel, - TextChannel, Thread, - VoiceChannel, ) -from ..commands import ApplicationContext, AutocompleteContext -from ..enums import ChannelType -from ..enums import Enum as DiscordEnum -from ..enums import SlashCommandOptionType -from ..utils import MISSING, basic_autocomplete +from ..member import Member +from ..user import User if TYPE_CHECKING: from ..cog import Cog - from ..ext.commands import Converter - from ..member import Member - from ..message import Attachment - from ..role import Role - from ..user import User - - InputType = Union[ - Type[str], - Type[bool], - Type[int], - Type[float], - Type[GuildChannel], - Type[Thread], - Type[Member], - Type[User], - Type[Attachment], - Type[Role], - Type[Mentionable], - SlashCommandOptionType, - Converter, - Type[Converter], - Type[Enum], - Type[DiscordEnum], - ] - - AutocompleteReturnType = Union[ - Iterable["OptionChoice"], Iterable[str], Iterable[int], Iterable[float] - ] - T = TypeVar("T", bound=AutocompleteReturnType) - MaybeAwaitable = Union[T, Awaitable[T]] - AutocompleteFunction = Union[ - Callable[[AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]], - Callable[[Cog, AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]], - Callable[ - [AutocompleteContext, Any], # pyright: ignore [reportExplicitAny] - MaybeAwaitable[AutocompleteReturnType], - ], - Callable[ - [Cog, AutocompleteContext, Any], # pyright: ignore [reportExplicitAny] - MaybeAwaitable[AutocompleteReturnType], - ], - ] -__all__ = ( - "ThreadOption", - "Option", - "OptionChoice", - "option", -) +PY_310 = sys.version_info >= (3, 10) # for UnionType +PY_311 = sys.version_info >= (3, 11) # for StrEnum + +PY_314 = sys.version_info >= (3, 14) +StrEnum = None +if PY_311: + from enum import StrEnum # type: ignore -CHANNEL_TYPE_MAP = { + StrEnum = StrEnum + +if TYPE_CHECKING: + from discord.ext.commands import Converter + + AutocompleteReturnType = ( + Iterable["OptionChoice"] | Iterable[str] | Iterable[int] | Iterable[float] + ) + + T = TypeVar("T", bound=AutocompleteReturnType) + MaybeAwaitable = T | Awaitable[T] + AutocompleteFunction = ( + Callable[[AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]] + | Callable[[Cog, AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]] + | Callable[ + [AutocompleteContext, Any], + MaybeAwaitable[AutocompleteReturnType], + ] + | Callable[ + [Cog, AutocompleteContext, Any], + MaybeAwaitable[AutocompleteReturnType], + ] + ) + + ValidChannelType = ( + TextChannel | VoiceChannel | CategoryChannel | ForumChannel | StageChannel + ) + + ValidOptionType = ( + type[str] + | type[bool] + | type[int] + | type[float] + | type[GuildChannel] + | type[ValidChannelType] + | type[Thread] + | type[Member] + | type[User] + | type[Attachment] + | type[Role] + | SlashCommandOptionType + | type[Literal] + | Converter[Any] + | type[Converter[Any]] + | type[Enum] + | type[DiscordEnum] + ) + + ValidChoicesType = ( + Iterable["OptionChoice"] + | Iterable[str] + | Iterable[int] + | Iterable[float] + | type[Enum] + | type[DiscordEnum] + ) + +CLS_TO_CHANNEL_TYPE: dict[type[GuildChannel | Thread], ChannelType] = { TextChannel: ChannelType.text, VoiceChannel: ChannelType.voice, - StageChannel: ChannelType.stage_voice, CategoryChannel: ChannelType.category, - Thread: ChannelType.public_thread, ForumChannel: ChannelType.forum, - MediaChannel: ChannelType.media, - DMChannel: ChannelType.private, + StageChannel: ChannelType.stage_voice, + Thread: ChannelType.public_thread, +} +CHANNEL_TYPE_TO_CLS: dict[ChannelType, type[GuildChannel | Thread]] = { + v: k for k, v in CLS_TO_CHANNEL_TYPE.items() +} +OPTION_TYPE_TO_SLASH_OPTION_TYPE: dict[ValidOptionType, SlashCommandOptionType] = { + str: SlashCommandOptionType.string, + bool: SlashCommandOptionType.boolean, + int: SlashCommandOptionType.integer, + float: SlashCommandOptionType.number, + GuildChannel: SlashCommandOptionType.channel, + Thread: SlashCommandOptionType.channel, + Member: SlashCommandOptionType.user, + User: SlashCommandOptionType.user, + Attachment: SlashCommandOptionType.attachment, + Role: SlashCommandOptionType.role, } -_log = logging.getLogger(__name__) - - -class ThreadOption: - """Represents a class that can be passed as the ``input_type`` for an :class:`Option` class. - - .. versionadded:: 2.0 - - Parameters - ---------- - thread_type: Literal["public", "private", "news"] - The thread type to expect for this options input. - """ - - def __init__(self, thread_type: Literal["public", "private", "news"]): - type_map = { - "public": ChannelType.public_thread, - "private": ChannelType.private_thread, - "news": ChannelType.news_thread, - } - self._type = type_map[thread_type] - - -class Option: - """Represents a selectable option for a slash command. - - Attributes - ---------- - input_type: Union[Type[:class:`str`], Type[:class:`bool`], Type[:class:`int`], Type[:class:`float`], Type[:class:`.abc.GuildChannel`], Type[:class:`Thread`], Type[:class:`Member`], Type[:class:`User`], Type[:class:`Attachment`], Type[:class:`Role`], Type[:class:`.abc.Mentionable`], :class:`SlashCommandOptionType`, Type[:class:`.ext.commands.Converter`], Type[:class:`enums.Enum`], Type[:class:`Enum`]] - The type of input that is expected for this option. This can be a :class:`SlashCommandOptionType`, - an associated class, a channel type, a :class:`Converter`, a converter class or an :class:`enum.Enum`. - If a :class:`enum.Enum` is used and it has up to 25 values, :attr:`choices` will be automatically filled. If the :class:`enum.Enum` has more than 25 values, :attr:`autocomplete` will be implemented with :func:`discord.utils.basic_autocomplete` instead. - name: :class:`str` - The name of this option visible in the UI. - Inherits from the variable name if not provided as a parameter. - description: Optional[:class:`str`] - The description of this option. - Must be 100 characters or fewer. If :attr:`input_type` is a :class:`enum.Enum` and :attr:`description` is not specified, :attr:`input_type`'s docstring will be used. - choices: Optional[List[Union[:class:`Any`, :class:`OptionChoice`]]] - The list of available choices for this option. - Can be a list of values or :class:`OptionChoice` objects (which represent a name:value pair). - If provided, the input from the user must match one of the choices in the list. - required: Optional[:class:`bool`] - Whether this option is required. - default: Optional[:class:`Any`] - The default value for this option. If provided, ``required`` will be considered ``False``. - min_value: Optional[:class:`int`] - The minimum value that can be entered. - Only applies to Options with an :attr:`.input_type` of :class:`int` or :class:`float`. - max_value: Optional[:class:`int`] - The maximum value that can be entered. - Only applies to Options with an :attr:`.input_type` of :class:`int` or :class:`float`. - min_length: Optional[:class:`int`] - The minimum length of the string that can be entered. Must be between 0 and 6000 (inclusive). - Only applies to Options with an :attr:`input_type` of :class:`str`. - max_length: Optional[:class:`int`] - The maximum length of the string that can be entered. Must be between 1 and 6000 (inclusive). - Only applies to Options with an :attr:`input_type` of :class:`str`. - channel_types: list[:class:`discord.ChannelType`] | None - A list of channel types that can be selected in this option. - Only applies to Options with an :attr:`input_type` of :class:`discord.SlashCommandOptionType.channel`. - If this argument is used, :attr:`input_type` will be ignored. - name_localizations: Dict[:class:`str`, :class:`str`] - The name localizations for this option. The values of this should be ``"locale": "name"``. - See `here `_ for a list of valid locales. - description_localizations: Dict[:class:`str`, :class:`str`] - The description localizations for this option. The values of this should be ``"locale": "description"``. - See `here `_ for a list of valid locales. - - Examples - -------- - Basic usage: :: - - @bot.slash_command(guild_ids=[...]) - async def hello( - ctx: discord.ApplicationContext, - name: Option(str, "Enter your name"), - age: Option(int, "Enter your age", min_value=1, max_value=99, default=18) - # passing the default value makes an argument optional - # you also can create optional argument using: - # age: Option(int, "Enter your age") = 18 - ): - await ctx.respond(f"Hello! Your name is {name} and you are {age} years old.") - .. versionadded:: 2.0 - """ +_log = logging.getLogger(__name__) - input_type: SlashCommandOptionType - converter: Converter | type[Converter] | None = None +class OptionChoice: def __init__( - self, input_type: InputType = str, /, description: str | None = None, **kwargs + self, + name: str, + value: str | int | float | None = None, + name_localizations: dict[str, str] = MISSING, ) -> None: - self.name: str | None = kwargs.pop("name", None) - if self.name is not None: - self.name = str(self.name) - self._parameter_name = self.name # default - input_type = self._parse_type_alias(input_type) - input_type = self._strip_none_type(input_type) - self._raw_type: InputType | tuple = input_type - - enum_choices = [] - input_type_is_class = isinstance(input_type, type) - if input_type_is_class and issubclass(input_type, (Enum, DiscordEnum)): - if description is None and input_type.__doc__ is not None: - description = inspect.cleandoc(input_type.__doc__) - if description and len(description) > 100: - description = description[:97] + "..." - _log.warning( - "Option %s's description was truncated due to Enum %s's docstring exceeding 100 characters.", - self.name, - input_type, - ) - enum_choices = [OptionChoice(e.name, e.value) for e in input_type] - value_class = enum_choices[0].value.__class__ - if value_class in SlashCommandOptionType.__members__ and all( - isinstance(elem.value, value_class) for elem in enum_choices - ): - input_type = SlashCommandOptionType.from_datatype( - enum_choices[0].value.__class__ - ) - else: - enum_choices = [OptionChoice(e.name, str(e.value)) for e in input_type] - input_type = SlashCommandOptionType.string + self.name: str = name + self.value: str | int | float = value if value is not None else name + self.name_localizations: dict[str, str] = name_localizations - self.description = description or "No description provided" - self.channel_types: list[ChannelType] = kwargs.pop("channel_types", []) + if not isinstance(self.value, (str, int, float)): + raise TypeError( + f"Option choice value must be of type str, int, or float, not {type(self.value)}." + ) - if self.channel_types: - self.input_type = SlashCommandOptionType.channel - elif isinstance(input_type, SlashCommandOptionType): - self.input_type = input_type - else: - from ..ext.commands import Converter + self._api_type: SlashCommandOptionType = SlashCommandOptionType.from_datatype( + type(self.value) + ) # type: ignore - if isinstance(input_type, tuple) and any( - issubclass(op, ApplicationContext) for op in input_type - ): - input_type = next( - op for op in input_type if issubclass(op, ApplicationContext) - ) + def to_dict(self) -> dict[str, str | int | float]: + base = {"name": self.name, "value": self.value} + if self.name_localizations: + base["name_localizations"] = self.name_localizations - if ( - isinstance(input_type, Converter) - or input_type_is_class - and issubclass(input_type, Converter) - ): - self.converter = input_type - self._raw_type = str - self.input_type = SlashCommandOptionType.string - else: - try: - self.input_type = SlashCommandOptionType.from_datatype(input_type) - except TypeError as exc: - from ..ext.commands.converter import CONVERTER_MAPPING - - if input_type not in CONVERTER_MAPPING: - raise exc - self.converter = CONVERTER_MAPPING[input_type] - self._raw_type = str - self.input_type = SlashCommandOptionType.string - else: - if self.input_type == SlashCommandOptionType.channel: - if not isinstance(self._raw_type, tuple): - if hasattr(input_type, "__args__"): - self._raw_type = input_type.__args__ # type: ignore # Union.__args__ - else: - self._raw_type = (input_type,) - if not self.channel_types: - self.channel_types = [ - CHANNEL_TYPE_MAP[t] - for t in self._raw_type - if t is not GuildChannel - ] - self.required: bool = ( - kwargs.pop("required", True) if "default" not in kwargs else False - ) - self.default = kwargs.pop("default", None) + return base - self._autocomplete: AutocompleteFunction | None = None - self._autocomplete_is_instance_method: bool = False - self.autocomplete = kwargs.pop("autocomplete", None) - if len(enum_choices) > 25: - self.choices: list[OptionChoice] = [] - for e in enum_choices: - e.value = str(e.value) - self.autocomplete = basic_autocomplete(enum_choices) - self.input_type = SlashCommandOptionType.string - else: - self.choices: list[OptionChoice] = enum_choices or [ - o if isinstance(o, OptionChoice) else OptionChoice(o) - for o in kwargs.pop("choices", []) - ] - - if self.input_type == SlashCommandOptionType.integer: - minmax_types = (int, type(None)) - minmax_typehint = Optional[int] - elif self.input_type == SlashCommandOptionType.number: - minmax_types = (int, float, type(None)) - minmax_typehint = Optional[Union[int, float]] - else: - minmax_types = (type(None),) - minmax_typehint = type(None) - if self.input_type == SlashCommandOptionType.string: - minmax_length_types = (int, type(None)) - minmax_length_typehint = Optional[int] - else: - minmax_length_types = (type(None),) - minmax_length_typehint = type(None) - - self.min_value: int | float | None = kwargs.pop("min_value", None) - self.max_value: int | float | None = kwargs.pop("max_value", None) - self.min_length: int | None = kwargs.pop("min_length", None) - self.max_length: int | None = kwargs.pop("max_length", None) - - if ( - self.input_type != SlashCommandOptionType.integer - and self.input_type != SlashCommandOptionType.number - and (self.min_value or self.max_value) - ): - raise AttributeError( - "Option does not take min_value or max_value if not of type " - "SlashCommandOptionType.integer or SlashCommandOptionType.number" - ) - if self.input_type != SlashCommandOptionType.string and ( - self.min_length or self.max_length - ): - raise AttributeError( - "Option does not take min_length or max_length if not of type str" - ) +class Option: + def __init__( + self, + input_type: ValidOptionType = str, + /, + *, + name: str | None = None, + parameter_name: str | None = None, + name_localizations: dict[str, str] | None = None, + description: str | None = None, + description_localizations: dict[str, str] | None = None, + required: bool = True, + default: int | str | float | None = None, + choices: ValidChoicesType | None = None, + channel_types: list[ChannelType] | None = None, + min_value: int | float | None = None, + max_value: int | float | None = None, + min_length: int | None = None, + max_length: int | None = None, + autocomplete: AutocompleteFunction | None = None, + ) -> None: - if self.min_value is not None and not isinstance(self.min_value, minmax_types): - raise TypeError( - f"Expected {minmax_typehint} for min_value, got" - f' "{type(self.min_value).__name__}"' - ) - if self.max_value is not None and not isinstance(self.max_value, minmax_types): - raise TypeError( - f"Expected {minmax_typehint} for max_value, got" - f' "{type(self.max_value).__name__}"' - ) + self.name: str | None = name + self._param_name: str | None = parameter_name or name - if self.min_length is not None: - if not isinstance(self.min_length, minmax_length_types): - raise TypeError( - f"Expected {minmax_length_typehint} for min_length," - f' got "{type(self.min_length).__name__}"' - ) - if self.min_length < 0 or self.min_length > 6000: - raise AttributeError( - "min_length must be between 0 and 6000 (inclusive)" - ) - if self.max_length is not None: - if not isinstance(self.max_length, minmax_length_types): - raise TypeError( - f"Expected {minmax_length_typehint} for max_length," - f' got "{type(self.max_length).__name__}"' - ) - if self.max_length < 1 or self.max_length > 6000: - raise AttributeError("max_length must between 1 and 6000 (inclusive)") + self.description: str | None = description - self.name_localizations = kwargs.pop("name_localizations", MISSING) - self.description_localizations = kwargs.pop( - "description_localizations", MISSING - ) + self._param_type: ValidOptionType = input_type + self._api_type: SlashCommandOptionType | None = None + self.converter: Converter[Any] | None = None - if input_type is None: - raise TypeError("input_type cannot be NoneType.") - - @staticmethod - def _parse_type_alias(input_type: InputType) -> InputType: - if isinstance(input_type, TypeAliasType): - return input_type.__value__ - return input_type - - @staticmethod - def _strip_none_type(input_type): - if isinstance(input_type, SlashCommandOptionType): - return input_type - - if input_type is type(None): - raise TypeError("Option type cannot be only NoneType") - - args = () - if isinstance(input_type, types.UnionType): - args = get_args(input_type) - elif getattr(input_type, "__origin__", None) is Union: - args = get_args(input_type) - elif isinstance(input_type, tuple): - args = input_type - - if args: - filtered = tuple(t for t in args if t is not type(None)) - if not filtered: - raise TypeError("Option type cannot be only NoneType") - if len(filtered) == 1: - return filtered[0] - - return filtered - - return input_type - - def to_dict(self) -> dict: - as_dict = { - "name": self.name, - "description": self.description, - "type": self.input_type.value, - "required": self.required, - "choices": [c.to_dict() for c in self.choices], - "autocomplete": bool(self.autocomplete), - } - if self.name_localizations is not MISSING: - as_dict["name_localizations"] = self.name_localizations - if self.description_localizations is not MISSING: - as_dict["description_localizations"] = self.description_localizations - if self.channel_types: - as_dict["channel_types"] = [t.value for t in self.channel_types] - if self.min_value is not None: - as_dict["min_value"] = self.min_value - if self.max_value is not None: - as_dict["max_value"] = self.max_value - if self.min_length is not None: - as_dict["min_length"] = self.min_length - if self.max_length is not None: - as_dict["max_length"] = self.max_length + self.required: bool = required if default is None else False + self.default: int | str | float | None = default - return as_dict + self.choices: list[OptionChoice] = self._handle_choices(choices) + self.name_localizations: dict[str, str] = name_localizations or {} + self.description_localizations: dict[str, str] = description_localizations or {} + self.channel_types: list[ChannelType] = channel_types or [] + self.min_value: int | float | None = min_value + self.max_value: int | float | None = max_value + self.min_length: int | None = min_length + self.max_length: int | None = max_length - def __repr__(self): - return f"" + self._autocomplete: AutocompleteFunction | None = None + self.autocomplete = autocomplete @property def autocomplete(self) -> AutocompleteFunction | None: @@ -496,62 +249,348 @@ def autocomplete(self, value: AutocompleteFunction | None) -> None: == 2 ) + def _copy_from(self, other: Option) -> None: + self.name = other.name + self._param_name = other._param_name + self.description = other.description + self._param_type = other._param_type + self._api_type = other._api_type + self.converter = other.converter + self.required = other.required + self.default = other.default + self.choices = other.choices.copy() + self.name_localizations = other.name_localizations.copy() + self.description_localizations = other.description_localizations.copy() + self.channel_types = other.channel_types.copy() + self.min_value = other.min_value + self.max_value = other.max_value + self.min_length = other.min_length + self.max_length = other.max_length + self.autocomplete = other.autocomplete + + def _validate_minmax_value(self) -> None: + if self._api_type not in ( + SlashCommandOptionType.integer, + SlashCommandOptionType.number, + ): + raise ValueError( + f"max_value is only applicable for int and float parameter types, not {self._param_type}." + ) -class OptionChoice: - """ - Represents a name:value pairing for a selected :class:`.Option`. - - .. versionadded:: 2.0 - - Attributes - ---------- - name: :class:`str` - The name of the choice. Shown in the UI when selecting an option. - value: Optional[Union[:class:`str`, :class:`int`, :class:`float`]] - The value of the choice. If not provided, will use the value of ``name``. - name_localizations: Dict[:class:`str`, :class:`str`] - The name localizations for this choice. The values of this should be ``"locale": "name"``. - See `here `_ for a list of valid locales. - """ + def _validate_minmax_length(self, max_length: int | None = None) -> None: + if self._api_type is not SlashCommandOptionType.string: + raise ValueError( + f"min_length and max_length are only applicable for string option types, not {self._api_type}." + ) - def __init__( - self, - name: str, - value: str | int | float | None = None, - name_localizations: dict[str, str] = MISSING, - ): - self.name = str(name) - self.value = value if value is not None else name - self.name_localizations = name_localizations + if max_length is not None and (max_length < 1 or max_length > 6000): + raise ValueError("max_length must be between 1 and 6000.") + + def _handle_type(self, param_type: ValidOptionType | None = None) -> None: + param_type = param_type or self._param_type + if isinstance(param_type, SlashCommandOptionType): + self._api_type = param_type + return + + from discord.ext.commands.converter import CONVERTER_MAPPING, Converter + + try: + self._api_type = OPTION_TYPE_TO_SLASH_OPTION_TYPE[param_type] # type: ignore + except KeyError: + try: + converter_cls = CONVERTER_MAPPING[param_type] # type: ignore + self.converter = converter_cls() # type: ignore + self._api_type = OPTION_TYPE_TO_SLASH_OPTION_TYPE[converter_cls] # type: ignore + except KeyError: + pass + + origin = get_origin(param_type) + args = get_args(param_type) + + if inspect.isclass(param_type): + print( + "Handling class type:", + param_type, + type(param_type), + get_origin(param_type), + get_args(param_type), + param_type.__name__, + ) + if issubclass(param_type, (Enum, DiscordEnum)): # type: ignore + self._parse_choices_from_enum(param_type) + elif issubclass(param_type, Converter): # type: ignore + self.converter = param_type() # type: ignore + elif isinstance(param_type, Converter): + self.converter = param_type + elif origin is Annotated: + self._handle_type(args[0]) + return + elif get_origin(param_type) in (Union, types.UnionType): + union_args = get_args(param_type) + non_none_args = normalise_optional_params(union_args)[:-1] + if len(non_none_args) == 1: + self._handle_type(non_none_args[0]) + return + if any( + c in CLS_TO_CHANNEL_TYPE for c in non_none_args if isinstance(c, type) + ): + self._api_type = SlashCommandOptionType.channel + self.channel_types = [CLS_TO_CHANNEL_TYPE[c] for c in non_none_args] + return + if any( + isinstance(c, type) and issubclass(c, (Member, User)) + for c in non_none_args + ): + self._api_type = SlashCommandOptionType.user + return + elif get_origin(param_type) is Literal: + literal_args = get_args(param_type) + if all(isinstance(arg, str) for arg in literal_args): + self._api_type = SlashCommandOptionType.string + elif all(isinstance(arg, int) for arg in literal_args): + self._api_type = SlashCommandOptionType.integer + elif all(isinstance(arg, float) for arg in literal_args): + self._api_type = SlashCommandOptionType.number + else: + raise TypeError( + f"Unsupported literal choice types in annotation: {literal_args}. " + f"All literal choices must be of the same type and must be str, int, or float." + ) - def to_dict(self) -> dict[str, str | int | float]: - as_dict = {"name": self.name, "value": self.value} - if self.name_localizations is not MISSING: - as_dict["name_localizations"] = self.name_localizations + self._handle_choices(literal_args) + return + + if self.min_length is not None or self.max_length is not None: + self._validate_minmax_length(self.max_length) + + if self.min_value is not None or self.max_value is not None: + self._validate_minmax_value() + + def _handle_choices(self, choices: ValidChoicesType | None) -> list[OptionChoice]: + if not choices: + return [] + + final_choices: list[OptionChoice] = [] + + if isinstance(choices, type) and (issubclass(choices, (Enum, DiscordEnum))): + return self._parse_choices_from_enum(choices) + + if isinstance(choices, Iterable): + for choice in choices: + if isinstance(choice, OptionChoice): + final_choices.append(choice) + elif isinstance(choice, (str, int, float)): + final_choices.append(OptionChoice(name=str(choice), value=choice)) + else: + raise TypeError( + f"Invalid choice type: {type(choice)}. Choices must be OptionChoice instances or str/int/float." + ) + else: + raise TypeError( + f"Invalid choices type: {type(choices)}. Choices must be an iterable of OptionChoice or str/int/float, or an Enum class." + ) + + print( + "Final parsed choices:", + final_choices, + len(final_choices), + self.autocomplete, + ) + if len(final_choices) > 25 and self.autocomplete is None: + self.choices = [] + self.autocomplete = basic_autocomplete(final_choices) + _log.info( + "Option '%s' has more than 25 choices, so choices were cleared and basic autocomplete was set up automatically.", + self.name, + ) + + return final_choices + + def _parse_choices_from_enum(self, enum_cls: type[Enum]) -> list[OptionChoice]: + print("Parsing choices from Enum:", enum_cls, type(enum_cls)) + if self.description is None and enum_cls.__doc__ is not None: + description = inspect.cleandoc(enum_cls.__doc__) + if len(description) > 100: + description = description[:97] + "..." + _log.warning( + "Option %s's description was truncated due to Enum %s's docstring exceeding 100 characters.", + self.name, + self._api_type, + ) - return as_dict + self.description = description + if issubclass(enum_cls, IntEnum): + self._api_type = SlashCommandOptionType.integer + elif StrEnum and issubclass(enum_cls, StrEnum): + self._api_type = SlashCommandOptionType.string + else: + first_member_type: type = next(iter(enum_cls)).value + print("First member type of Enum:", first_member_type) + if not isinstance(first_member_type, (str, int, float)): + raise TypeError( + f"For parameter {self._param_name}: Enum choices must have values of type str, int, or float. Found {type(first_member_type)} in {enum_cls}." + ) -def option(name, input_type=None, **kwargs): - """A decorator that can be used instead of typehinting :class:`.Option`. + self._api_type = SlashCommandOptionType.from_datatype( + type(first_member_type) + ) - .. versionadded:: 2.0 + return self._handle_choices([ + OptionChoice(name=member.name, value=member.value) for member in enum_cls + ]) - Attributes - ---------- - parameter_name: :class:`str` - The name of the target function parameter this option is mapped to. - This allows you to have a separate UI ``name`` and parameter name. - """ + def to_dict(self) -> dict[str, Any]: + if not self._api_type: + raise ValueError("Option type has not been set.") - def decorator(func): - resolved_name = kwargs.pop("parameter_name", None) or name - itype = ( - kwargs.pop("type", None) - or input_type - or func.__annotations__.get(resolved_name, str) + base = { + "type": self._api_type.value, + "name": self._param_name, + "description": self.description, + "required": self.required, + } + if self.choices: + base["choices"] = [choice.to_dict() for choice in self.choices] + if self.name_localizations: + base["name_localizations"] = self.name_localizations + if self.description_localizations: + base["description_localizations"] = self.description_localizations + if self.channel_types: + base["channel_types"] = [ct.value for ct in self.channel_types] + if self.min_value is not None: + base["min_value"] = self.min_value + if self.max_value is not None: + base["max_value"] = self.max_value + if self.min_length is not None: + base["min_length"] = self.min_length + if self.max_length is not None: + base["max_length"] = self.max_length + + return base + + +def _get_options( + func: Callable[..., Any], *, cog: type[Cog] | None = None +) -> dict[str, Option]: + signature = inspect.signature( + func, globals=func.__globals__, locals=func.__globals__ + ) + + res: dict[str, Option] = {} + parameters = OrderedDict(signature.parameters) + + existing_options: dict[str, Option] = getattr(func, "__options__", {}).copy() + print(f"Existing options for function '{func.__name__}': {existing_options}") + + # skip 'self' and 'context' parameters if they exist + param_items = list(parameters.items()) + if cog is not None: + if param_items and param_items[0][0] != "self": + raise ValueError( + f"First parameter of method '{func.__name__}' must be 'self' when it's in a cog, but got '{param_items[0][0]}'." + ) + skip_count = 2 + else: + skip_count = 2 if param_items and param_items[0][0] == "self" else 1 + + skip_count = min(skip_count, len(param_items)) + + for param_name, param in param_items[skip_count:]: + existing = existing_options.pop(param_name, None) + option = Option( + name=param_name, + parameter_name=param_name, + description=inspect.cleandoc(param.__doc__) if param.__doc__ else None, ) - func.__annotations__[resolved_name] = Option(itype, name=name, **kwargs) + + annotation_is_option = False + if param.annotation is not param.empty: + annotation = resolve_annotation( + param.annotation, func.__globals__, func.__globals__, {} + ) + if isinstance(annotation, Option): + annotation_is_option = True + option._copy_from(annotation) + if option.name is None: + option.name = param_name + if option._param_name is None: + option._param_name = param_name + else: + option._param_type = annotation + + if existing: + option._copy_from(existing) + + if param.default is not param.empty: + option.default = param.default + option.required = False + + if param.annotation is param.empty and not annotation_is_option: + continue + + try: + option._handle_type() + except Exception as e: + raise TypeError( + f"Error processing parameter '{param_name}' of function '{func.__name__}': {e}" + ) from e + + if existing_options: + for param_name in existing_options: + raise ValueError( + f"Option '{existing_options[param_name].name}' specified for parameter " + f"'{param_name}' in function '{func.__name__}' was not found in the function. " + ) + + return res + + +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) + + +def option( + name: str, + input_type: ValidOptionType = str, + /, + *, + parameter_name: str | None = None, + name_localizations: dict[str, str] | None = None, + description: str | None = MISSING, + description_localizations: dict[str, str] | None = None, + required: bool = True, + default: int | str | float | None = None, + choices: ValidChoicesType | None = None, + channel_types: list[ChannelType] | None = None, + min_value: int | float | None = None, + max_value: int | float | None = None, + min_length: int | None = None, + max_length: int | None = None, + autocomplete: AutocompleteFunction | None = None, +) -> Callable[[CallableT], CallableT]: + def inner(func: CallableT) -> CallableT: + opt = Option( + input_type, + name=name, + parameter_name=parameter_name, + name_localizations=name_localizations, + description=description, + description_localizations=description_localizations, + required=required, + default=default, + choices=choices, + channel_types=channel_types, + min_value=min_value, + max_value=max_value, + min_length=min_length, + max_length=max_length, + autocomplete=autocomplete, + ) + try: + func.__options__[name] = opt # type: ignore + except AttributeError: + func.__options__ = {name: opt} # type: ignore + return func - return decorator + return inner diff --git a/intro_typing.py b/intro_typing.py new file mode 100644 index 0000000000..6776431f97 --- /dev/null +++ b/intro_typing.py @@ -0,0 +1,526 @@ +from __future__ import annotations +from collections.abc import Awaitable, Iterable, Callable +from enum import Enum, IntEnum +import logging +import sys +import types +from typing import Annotated, Any, Literal, TypeVar, Union, get_args, get_origin +import discord +from discord.enums import SlashCommandOptionType, Enum as DiscordEnum +from discord.commands import AutocompleteContext +from discord.ext.commands import Converter +from discord.cog import Cog +import inspect + + +PY_310 = sys.version_info >= (3, 10) # for UnionType +PY_311 = sys.version_info >= (3, 11) # for StrEnum + +PY_314 = sys.version_info >= (3, 14) +StrEnum = None +if PY_311: + from enum import StrEnum # type: ignore + + StrEnum = StrEnum + +AutocompleteReturnType = ( + Iterable["OptionChoice"] | Iterable[str] | Iterable[int] | Iterable[float] +) + +T = TypeVar("T", bound=AutocompleteReturnType) +MaybeAwaitable = T | Awaitable[T] +AutocompleteFunction = ( + Callable[[AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]] + | Callable[[Cog, AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]] + | Callable[ + [AutocompleteContext, Any], + MaybeAwaitable[AutocompleteReturnType], + ] + | Callable[ + [Cog, AutocompleteContext, Any], + MaybeAwaitable[AutocompleteReturnType], + ] +) + +ValidChannelType = ( + discord.TextChannel + | discord.VoiceChannel + | discord.CategoryChannel + | discord.ForumChannel + | discord.StageChannel +) + +ValidOptionType = ( + type[str] + | type[bool] + | type[int] + | type[float] + | type[discord.abc.GuildChannel] + | type[ValidChannelType] + | type[discord.Thread] + | type[discord.Member] + | type[discord.User] + | type[discord.Attachment] + | type[discord.Role] + | type[discord.Role] + | SlashCommandOptionType + | type[Literal] + | Converter[Any] + | type[Converter[Any]] + | type[Enum] + | type[DiscordEnum] +) + +ValidChoicesType = ( + Iterable["OptionChoice"] + | Iterable[str] + | Iterable[int] + | Iterable[float] + | type[Enum] + | type[DiscordEnum] +) + +CLS_TO_CHANNEL_TYPE: dict[ + type[discord.abc.GuildChannel | discord.Thread], discord.ChannelType +] = { + discord.TextChannel: discord.ChannelType.text, + discord.VoiceChannel: discord.ChannelType.voice, + discord.CategoryChannel: discord.ChannelType.category, + discord.ForumChannel: discord.ChannelType.forum, + discord.StageChannel: discord.ChannelType.stage_voice, + discord.Thread: discord.ChannelType.public_thread, +} +CHANNEL_TYPE_TO_CLS: dict[ + discord.ChannelType, type[discord.abc.GuildChannel | discord.Thread] +] = {v: k for k, v in CLS_TO_CHANNEL_TYPE.items()} +OPTION_TYPE_TO_SLASH_OPTION_TYPE: dict[ValidOptionType, SlashCommandOptionType] = { + str: SlashCommandOptionType.string, + bool: SlashCommandOptionType.boolean, + int: SlashCommandOptionType.integer, + float: SlashCommandOptionType.number, + discord.abc.GuildChannel: SlashCommandOptionType.channel, + discord.Thread: SlashCommandOptionType.channel, + discord.Member: SlashCommandOptionType.user, + discord.User: SlashCommandOptionType.user, + discord.Attachment: SlashCommandOptionType.attachment, + discord.Role: SlashCommandOptionType.role, +} +OPTION_TYPE_TO_DEFAULT_TYPES: dict[SlashCommandOptionType, tuple[type[Any], ...]] = { + SlashCommandOptionType.string: (str,), + SlashCommandOptionType.boolean: (bool,), + SlashCommandOptionType.integer: (int,), + SlashCommandOptionType.number: (float,), + SlashCommandOptionType.channel: (type(None),), # Channel IDs are passed as strings + SlashCommandOptionType.user: (type(None),), # User IDs are passed as strings + SlashCommandOptionType.attachment: ( + type(None), + ), # Attachment IDs are passed as strings + SlashCommandOptionType.role: (type(None),), # Role IDs are passed as strings +} + +_log = logging.getLogger(__name__) + + +class OptionChoice: + def __init__( + self, + name: str, + value: str | int | float | None = None, + name_localizations: dict[str, str] = discord.MISSING, + ) -> None: + self.name: str = name + self.value: str | int | float = value if value is not None else name + self.name_localizations: dict[str, str] = name_localizations + + def to_dict(self) -> dict[str, str | int | float]: + base = {"name": self.name, "value": self.value} + if self.name_localizations: + base["name_localizations"] = self.name_localizations + + return base + + +class Option: + def __init__( + self, + input_type: ValidOptionType = str, + /, + *, + name: str | None = None, + description: str | None = None, + required: bool = True, + default: int | str | float | None = None, + choices: ValidChoicesType | None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + channel_types: list[discord.ChannelType] | None = None, + min_value: int | float | None = None, + max_value: int | float | None = None, + min_length: int | None = None, + max_length: int | None = None, + autocomplete: AutocompleteFunction | None = None, + ) -> None: + self.name: str | None = name + self._param_name: str | None = None + + self.description: str | None = description + + self._param_type: ValidOptionType = input_type + self._api_type: SlashCommandOptionType | None = None + self.converter: Converter[Any] | None = None + self._handle_type(input_type) + + self.required: bool = required if default is None else False + self.default: int | str | float | None = default + + self.choices: list[OptionChoice] = self._handle_choices(choices) + self.name_localizations: dict[str, str] = name_localizations or {} + self.description_localizations: dict[str, str] = description_localizations or {} + self.channel_types: list[discord.ChannelType] = channel_types or [] + self.min_value: int | float | None = min_value + self.max_value: int | float | None = max_value + self.min_length: int | None = min_length + self.max_length: int | None = max_length + self.autocomplete: AutocompleteFunction | None = autocomplete + + def _validate_max_value(self) -> None: + if self._param_type not in (int, float, None): + raise ValueError( + f"max_value is only applicable for int and float parameter types, not {self._param_type}." + ) + + def _handle_type(self, param_type: ValidOptionType) -> None: + if isinstance(param_type, SlashCommandOptionType): + self._api_type = param_type + return + + if ( + isinstance(param_type, type) + and ( + issubclass(param_type, (Enum, DiscordEnum)) # type: ignore + ) + and not self.choices + ): + self._parse_choices_from_enum(param_type) + return + + api_type = OPTION_TYPE_TO_SLASH_OPTION_TYPE.get(param_type) + if not api_type: + if isinstance(param_type, type) and issubclass(param_type, Converter): # type: ignore + self.converter = param_type() + api_type = SlashCommandOptionType.string + elif isinstance(param_type, Converter): + self.converter = param_type + api_type = SlashCommandOptionType.string + else: + raise TypeError(f"Unsupported option type: {param_type}") + + self._api_type = api_type + self._param_type = param_type + + def _handle_choices(self, choices: ValidChoicesType | None) -> list[OptionChoice]: + if not choices: + return [] + + final_choices: list[OptionChoice] = [] + + if isinstance(choices, type) and (issubclass(choices, (Enum, DiscordEnum))): + return self._parse_choices_from_enum(choices) + + if isinstance(choices, Iterable): + for choice in choices: + if isinstance(choice, OptionChoice): + final_choices.append(choice) + elif isinstance(choice, (str, int, float)): + final_choices.append(OptionChoice(name=str(choice), value=choice)) + else: + raise TypeError( + f"Invalid choice type: {type(choice)}. Choices must be OptionChoice instances or str/int/float." + ) + else: + raise TypeError( + f"Invalid choices type: {type(choices)}. Choices must be an iterable of OptionChoice or str/int/float, or an Enum class." + ) + + return final_choices + + def _parse_choices_from_enum(self, enum_cls: type[Enum]) -> list[OptionChoice]: + if self.description is None and enum_cls.__doc__ is not None: + description = inspect.cleandoc(enum_cls.__doc__) + if len(description) > 100: + description = description[:97] + "..." + _log.warning( + "Option %s's description was truncated due to Enum %s's docstring exceeding 100 characters.", + self.name, + self._api_type, + ) + + self.description = description + + if issubclass(enum_cls, IntEnum): + self._api_type = SlashCommandOptionType.integer + self._param_type = int + elif StrEnum and issubclass(enum_cls, StrEnum): + self._api_type = SlashCommandOptionType.string + self._param_type = str + else: + first_member_type: type = type(next(iter(enum_cls)).value) + if not isinstance(first_member_type, (str, int, float)): + raise TypeError( + f"Enum choices must have values of type str, int, or float. Found {type(first_member_type)} in {enum_cls}." + ) + + self._api_type = SlashCommandOptionType.from_datatype(first_member_type) + self._param_type = first_member_type # type: ignore + + return self._handle_choices(enum_cls) + + def to_dict(self) -> dict[str, Any]: + if not self._api_type: + raise ValueError("Option type has not been set.") + + base = { + "type": self._api_type.value, + "name": self.name, + "description": self.description, + "required": self.required, + } + if self.choices: + base["choices"] = [choice.to_dict() for choice in self.choices] + if self.name_localizations: + base["name_localizations"] = self.name_localizations + if self.description_localizations: + base["description_localizations"] = self.description_localizations + if self.channel_types: + base["channel_types"] = [ct.value for ct in self.channel_types] + if self.min_value is not None: + base["min_value"] = self.min_value + if self.max_value is not None: + base["max_value"] = self.max_value + if self.min_length is not None: + base["min_length"] = self.min_length + if self.max_length is not None: + base["max_length"] = self.max_length + + return base + + +class InspectedAnnotation: + def __init__( + self, + name: str, + annotation: ValidOptionType = str, + is_optional: bool = False, + is_union: bool = False, + is_literal: bool = False, + args: list[type] | None = None, + default: type | None = None, + ) -> None: + self.name = name + self.annotation = annotation + self.is_optional = is_optional + self.is_union = is_union + self.is_literal = is_literal + self.args = get_args(annotation) if args is None else args + self.default = default + + self.channel_types: list[discord.ChannelType] = [] + + self.inner_type: ValidOptionType = annotation + self.api_type = SlashCommandOptionType.string + + def check(self) -> None: + if isinstance(self.inner_type, SlashCommandOptionType): + return + + if ( + isinstance(self.inner_type, type) + and issubclass(self.inner_type, Converter) # type: ignore + or isinstance(self.inner_type, Converter) + ): + return + + api_type = OPTION_TYPE_TO_SLASH_OPTION_TYPE.get(self.inner_type) # type: ignore + if api_type: + self.api_type = api_type + + if self.is_literal: + if all(isinstance(arg, str) for arg in self.args): + self.api_type = SlashCommandOptionType.string + elif all(isinstance(arg, int) for arg in self.args): + self.api_type = SlashCommandOptionType.integer + elif all(isinstance(arg, float) for arg in self.args): + self.api_type = SlashCommandOptionType.number + else: + raise TypeError( + f"Unsupported literal choice types in annotation for parameter {self.name}: {self.args}. " + f"All literal choices must be of the same type and must be str, int, or float." + ) + elif self.is_union and self.args: + if any(c in CLS_TO_CHANNEL_TYPE for c in self.args if isinstance(c, type)): + self.api_type = SlashCommandOptionType.channel + self.channel_types = [CLS_TO_CHANNEL_TYPE[c] for c in self.args] + elif any(issubclass(c, (discord.Member, discord.User)) for c in self.args): + self.api_type = SlashCommandOptionType.user + elif self.is_union and not self.is_optional: + raise TypeError( + f"Unsupported Union annotation for parameter {self.name}: {self.annotation}. " + f"Union types are not supported unless they are Optional or a Union of channel types or " + f"a Union of Member/User types." + ) + else: + raise TypeError( + f"Unsupported annotation type for parameter {self.name}: {self.annotation}. " + f"Type must be a valid option type, a Converter, or a SlashCommandOptionType." + ) + + valid_default_types = OPTION_TYPE_TO_DEFAULT_TYPES.get( + self.api_type, (str, int, float, type(None)) + ) + + if self.default is not None and not isinstance( + self.default, valid_default_types + ): + raise ValueError( + f"Invalid default value for parameter {self.name}: {self.default}. " + f"Default value must be of type {valid_default_types}." + ) + + +def inspect_annotations(func: Callable[..., Any]) -> dict[str, InspectedAnnotation]: + signature = inspect.signature( + func, + globals=globals(), + locals=locals(), + ) + + res: dict[str, InspectedAnnotation] = {} + parameters = signature.parameters + for param_name, param in parameters.items(): + obj = InspectedAnnotation(name=param_name, default=param.default) + annotation = param.annotation + if annotation is param.empty: + continue + + origin = get_origin(annotation) + obj.annotation = annotation + + if origin is Annotated: + obj.inner_type = obj.args[0] + elif origin is Union: + obj.is_optional = type(None) in obj.args + obj.args = [arg for arg in obj.args if arg is not type(None)] + if len(obj.args) == 1 and obj.is_optional: + obj.inner_type = obj.args[0] + elif origin is Literal: + obj.is_literal = True + + obj.check() + res[param_name] = obj + + return res + + +# --- TEST CASES FOR inspect_annotations --- +from typing import Optional + + +class MyEnum(Enum): + A = 1 + B = 2 + + +def test_optional(a: Optional[int] = None): + pass + + +def test_union(b: Union[int, str]): + pass + + +def test_literal(c: Literal["foo", "bar"]): + pass + + +def test_enum(d: MyEnum): + pass + + +def test_converter(e: Converter): + pass + + +def test_channel(f: discord.TextChannel): + pass + + +def test_thread(g: discord.Thread): + pass + + +def test_member_user(h: Union[discord.Member, discord.User]): + pass + + +test_functions = [ + test_optional, + test_union, + test_literal, + test_enum, + test_converter, + test_channel, + test_thread, + test_member_user, +] + +for func in test_functions: + print(f"\nTesting: {func.__name__}") + inspected = inspect_annotations(func) + for name, annotation in inspected.items(): + print(f"Parameter: {name}") + print(f" Annotation: {annotation.annotation}") + print(f" Is Optional: {annotation.is_optional}") + print(f" Is Union: {annotation.is_union}") + print(f" Is Literal: {annotation.is_literal}") + print(f" Args: {annotation.args}") + print(f" Default: {annotation.default}") + print(f" API Type: {annotation.api_type}") + print(f" Channel Types: {annotation.channel_types}") + + +# --- MULTI-PARAMETER TEST CASES --- +def test_multi_1(a: int, b: Optional[str] = None, c: Literal[1, 2, 3] = 1): + pass + + +def test_multi_2( + x: Union[discord.TextChannel, discord.VoiceChannel], y: MyEnum, z: Converter +): + pass + + +def test_multi_3( + p: discord.Member, q: Union[discord.Member, discord.User], r: float = 3.14 +): + pass + + +multi_param_functions = [ + test_multi_1, + test_multi_2, + test_multi_3, +] + +for func in multi_param_functions: + print(f"\nTesting: {func.__name__}") + inspected = inspect_annotations(func) + for name, annotation in inspected.items(): + print(f"Parameter: {name}") + print(f" Annotation: {annotation.annotation}") + print(f" Is Optional: {annotation.is_optional}") + print(f" Is Union: {annotation.is_union}") + print(f" Is Literal: {annotation.is_literal}") + print(f" Args: {annotation.args}") + print(f" Default: {annotation.default}") + print(f" API Type: {annotation.api_type}") + print(f" Channel Types: {annotation.channel_types}") diff --git a/test_exts/options_showcase.py b/test_exts/options_showcase.py new file mode 100644 index 0000000000..7ee3ef286f --- /dev/null +++ b/test_exts/options_showcase.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from enum import Enum, IntEnum +from typing import Annotated, Literal, Optional, Union + +import discord +from discord import OptionChoice, option +from discord.enums import SlashCommandOptionType +from discord.ext import commands + + +class Color(Enum): + red = "red" + green = "green" + blue = "blue" + + +class Priority(IntEnum): + low = 1 + medium = 2 + high = 3 + + +async def color_autocomplete( + ctx: discord.AutocompleteContext, +) -> list[OptionChoice]: + choices = ["red", "green", "blue", "orange", "yellow"] + return [OptionChoice(name=c) for c in choices if ctx.value.lower() in c] + + +class OptionShowcase(commands.Cog): + def __init__(self, bot: discord.Bot): + self.bot = bot + + @commands.slash_command(name="opt_primitives") + async def opt_primitives( + self, + ctx: discord.ApplicationContext, + text: str, + count: int, + ratio: float, + flag: bool, + ) -> None: + await ctx.respond( + f"text={text} count={count} ratio={ratio} flag={flag}" + ) + + @commands.slash_command(name="opt_choices") + @option("color", description="Pick a color", choices=["red", "green", "blue"]) + @option("priority", description="Pick a priority", choices=Priority) + async def opt_choices( + self, + ctx: discord.ApplicationContext, + color: str, + priority: Priority, + ) -> None: + await ctx.respond(f"color={color} priority={priority}") + + @commands.slash_command(name="opt_optionchoice") + @option( + "size", + description="Pick a size", + choices=[ + OptionChoice(name="Small", value="S"), + OptionChoice(name="Large", value="L"), + ], + ) + async def opt_optionchoice( + self, + ctx: discord.ApplicationContext, + size: str, + ) -> None: + await ctx.respond(f"size={size}") + + @commands.slash_command(name="opt_literal") + async def opt_literal( + self, + ctx: discord.ApplicationContext, + mode: Literal["fast", "safe"], + level: Literal[1, 2, 3], + ) -> None: + await ctx.respond(f"mode={mode} level={level}") + + @commands.slash_command(name="opt_optional") + async def opt_optional( + self, + ctx: discord.ApplicationContext, + note: Optional[str] = None, + amount: Optional[int] = None, + ) -> None: + await ctx.respond(f"note={note} amount={amount}") + + @commands.slash_command(name="opt_channels") + async def opt_channels( + self, + ctx: discord.ApplicationContext, + channel: Union[discord.TextChannel, discord.VoiceChannel], + thread: discord.Thread, + ) -> None: + await ctx.respond( + f"channel={channel.mention} thread={thread.mention}" + ) + + @commands.slash_command(name="opt_channeltypes") + @option( + "channel", + discord.abc.GuildChannel, + description="Pick a text or voice channel", + channel_types=[ + discord.ChannelType.text, + discord.ChannelType.voice, + ], + ) + async def opt_channeltypes( + self, + ctx: discord.ApplicationContext, + channel: discord.abc.GuildChannel, + ) -> None: + await ctx.respond(f"channel={channel.mention}") + + @commands.slash_command(name="opt_users") + async def opt_users( + self, + ctx: discord.ApplicationContext, + member: discord.Member, + user: discord.User, + role: discord.Role, + ) -> None: + await ctx.respond( + f"member={member.mention} user={user.mention} role={role.name}" + ) + + @commands.slash_command(name="opt_union_user") + async def opt_union_user( + self, + ctx: discord.ApplicationContext, + target: Union[discord.Member, discord.User], + ) -> None: + await ctx.respond(f"target={target.mention}") + + @commands.slash_command(name="opt_attachment") + async def opt_attachment( + self, + ctx: discord.ApplicationContext, + file: discord.Attachment, + ) -> None: + if file: + await ctx.respond(f"file={file.filename}") + else: + await ctx.respond("No file provided") + + @commands.slash_command(name="opt_decorator") + @option( + "value", + SlashCommandOptionType.integer, + description="Value 1-10", + min_value=1, + max_value=10, + ) + @option( + "text", + SlashCommandOptionType.string, + description="Text 1-10 chars", + min_length=1, + max_length=10, + ) + async def opt_decorator( + self, + ctx: discord.ApplicationContext, + value: int, + text: str, + ) -> None: + await ctx.respond(f"value={value} text={text}") + + @commands.slash_command(name="opt_annotated") + async def opt_annotated( + self, + ctx: discord.ApplicationContext, + tag: Annotated[str, "ignored"], + ) -> None: + await ctx.respond(f"tag={tag}") + + @commands.slash_command(name="opt_autocomplete") + @option("color", description="Pick a color", autocomplete=color_autocomplete) + async def opt_autocomplete( + self, + ctx: discord.ApplicationContext, + color: str, + ) -> None: + await ctx.respond(f"color={color}") + + +def setup(bot: discord.Bot) -> None: + bot.add_cog(OptionShowcase(bot)) diff --git a/tests/test_slash_command_options.py b/tests/test_slash_command_options.py new file mode 100644 index 0000000000..95c47e68ba --- /dev/null +++ b/tests/test_slash_command_options.py @@ -0,0 +1,195 @@ +from enum import Enum +from typing import Literal, Optional, Union + +import pytest +from typing_extensions import Annotated + +import discord +from discord import ChannelType, SlashCommandOptionType +from discord.commands.core import SlashCommand +from discord.commands.options import Option, OptionChoice + + +def _build_registered_slash_command(func): + cmd = SlashCommand(func) + bot = discord.Bot() + bot.add_application_command(cmd) + return cmd + + +def test_slash_command_parses_basic_option(): + async def greet(ctx, name: str): + await ctx.respond(name) + + cmd = _build_registered_slash_command(greet) + options = cmd.to_dict().get("options") + + assert len(options) == 1 + assert options[0]["name"] == "name" + assert options[0]["type"] == SlashCommandOptionType.string.value + + +def test_slash_command_sets_default_and_required_flag(): + async def greet(ctx, name: str = "pycord"): + await ctx.respond(name) + + cmd = _build_registered_slash_command(greet) + + assert len(cmd.options) == 1 + assert cmd.options[0].default == "pycord" + assert cmd.options[0].required is False + + +def test_slash_command_uses_option_decorator_parameter_name_mapping(): + @discord.option( + "display_name", + str, + parameter_name="display-name", + description="Displayed name", + ) + async def greet(ctx, display_name: str): + await ctx.respond(display_name) + + cmd = _build_registered_slash_command(greet) + option = cmd.to_dict().get("options")[0] + + assert option["name"] == "display-name" + assert cmd.options[0]._param_name == "display_name" + assert option["description"] == "Displayed name" + + +def test_option_choice_rejects_invalid_value_type(): + with pytest.raises(TypeError): + OptionChoice(name="broken", value=object()) + + +def test_option_to_dict_requires_parsed_type(): + option = Option(name="value", parameter_name="value", description="Some value") + + with pytest.raises(ValueError, match="Option type has not been set"): + option.to_dict() + + +def test_option_rejects_min_length_for_non_string(): + option = Option( + int, + name="amount", + parameter_name="amount", + description="Amount", + min_length=2, + ) + + with pytest.raises(ValueError, match="min_length and max_length"): + option._handle_type() + + +def test_option_rejects_min_value_for_non_numeric(): + option = Option( + str, + name="label", + parameter_name="label", + description="Label", + min_value=1, + ) + + with pytest.raises(ValueError, match="max_value is only applicable"): + option._handle_type() + + +def test_option_parses_literal_annotation_choices(): + option = Option( + Literal["red", "green"], + name="color", + parameter_name="color", + description="Color", + ) + option._handle_type() + + assert option._api_type is SlashCommandOptionType.string + assert [choice.value for choice in option.choices] == ["red", "green"] + + +def test_option_parses_union_channel_types(): + option = Option( + Union[discord.TextChannel, discord.VoiceChannel], + name="where", + parameter_name="where", + description="Where", + ) + option._handle_type() + + assert option._api_type is SlashCommandOptionType.channel + assert sorted(option.channel_types, key=lambda t: t.value) == [ + ChannelType.text, + ChannelType.voice, + ] + + +def test_option_switches_to_autocomplete_above_25_choices(): + option = Option( + str, + name="pick", + parameter_name="pick", + description="Pick", + choices=[f"choice-{i}" for i in range(26)], + ) + + assert option.autocomplete is not None + assert option.choices == [] + + +def test_annotation_metadata_option_is_used(): + async def echo(ctx, text: Annotated[str, discord.Option(description="Some text")]): + await ctx.respond(text) + + cmd = _build_registered_slash_command(echo) + option = cmd.to_dict().get("options")[0] + + assert option["type"] == SlashCommandOptionType.string.value + assert option["description"] == "Some text" + + +def test_annotation_optional_type_parses_to_string(): + async def echo(ctx, text: Optional[str]): + await ctx.respond(text) + + cmd = _build_registered_slash_command(echo) + option = cmd.to_dict().get("options")[0] + + assert option["type"] == SlashCommandOptionType.string.value + + +def test_annotation_literal_exposes_choices(): + async def pick(ctx, value: Literal["a", "b"]): + await ctx.respond(value) + + cmd = _build_registered_slash_command(pick) + option = cmd.to_dict().get("options")[0] + + assert option["type"] == SlashCommandOptionType.string.value + assert [choice["value"] for choice in option["choices"]] == ["a", "b"] + + +def test_annotation_literal_with_mixed_types_raises(): + async def pick(ctx, value: Literal["a", 1]): + await ctx.respond(value) + + with pytest.raises(TypeError, match="Error processing parameter 'value'"): + SlashCommand(pick) + + +def test_option_parses_choices_from_enum(): + class Flavor(Enum): + VANILLA = "vanilla" + CHOCOLATE = "chocolate" + + option = Option( + Flavor, + name="flavor", + parameter_name="flavor", + description="Flavor", + ) + option._handle_type() + + assert option._api_type is SlashCommandOptionType.string + assert [choice.value for choice in option.choices] == ["vanilla", "chocolate"] From d6fa10e5b6b78ae3b7e70af4b04376609384615c Mon Sep 17 00:00:00 2001 From: Soheab <33902984+Soheab@users.noreply.github.com> Date: Thu, 12 Feb 2026 20:44:06 +0100 Subject: [PATCH 2/4] ai: 1, forwardref: 0 --- discord/commands/options.py | 292 ++++++++++++++++++++++++++++++---- test_exts/options_showcase.py | 44 ++++- 2 files changed, 299 insertions(+), 37 deletions(-) diff --git a/discord/commands/options.py b/discord/commands/options.py index 1e0d30dc00..2e408b3ed9 100644 --- a/discord/commands/options.py +++ b/discord/commands/options.py @@ -1,7 +1,9 @@ from __future__ import annotations +import ast from collections import OrderedDict from collections.abc import Awaitable, Iterable, Callable from enum import Enum, IntEnum +import importlib import logging import sys import types @@ -16,6 +18,8 @@ get_args, get_origin, ) + + from ..enums import SlashCommandOptionType, Enum as DiscordEnum, ChannelType from .context import AutocompleteContext @@ -25,6 +29,7 @@ normalise_optional_params, MISSING, basic_autocomplete, + deprecated, ) from ..abc import GuildChannel @@ -37,6 +42,8 @@ ForumChannel, StageChannel, Thread, + DMChannel, + MediaChannel, ) from ..member import Member from ..user import User @@ -94,7 +101,7 @@ | type[Attachment] | type[Role] | SlashCommandOptionType - | type[Literal] + | type[Literal] # pyright: ignore[reportMissingTypeArgument] | Converter[Any] | type[Converter[Any]] | type[Enum] @@ -110,15 +117,17 @@ | type[DiscordEnum] ) -CLS_TO_CHANNEL_TYPE: dict[type[GuildChannel | Thread], ChannelType] = { +CLS_TO_CHANNEL_TYPE: dict[type[GuildChannel | DMChannel | Thread], ChannelType] = { TextChannel: ChannelType.text, VoiceChannel: ChannelType.voice, - CategoryChannel: ChannelType.category, - ForumChannel: ChannelType.forum, StageChannel: ChannelType.stage_voice, + CategoryChannel: ChannelType.category, Thread: ChannelType.public_thread, + ForumChannel: ChannelType.forum, + MediaChannel: ChannelType.media, + DMChannel: ChannelType.private, } -CHANNEL_TYPE_TO_CLS: dict[ChannelType, type[GuildChannel | Thread]] = { +CHANNEL_TYPE_TO_CLS: dict[ChannelType, type[GuildChannel | DMChannel | Thread]] = { v: k for k, v in CLS_TO_CHANNEL_TYPE.items() } OPTION_TYPE_TO_SLASH_OPTION_TYPE: dict[ValidOptionType, SlashCommandOptionType] = { @@ -138,6 +147,61 @@ _log = logging.getLogger(__name__) +def _is_type_checking_statement(node: ast.AST) -> bool: + if isinstance(node, ast.Name): + return node.id == "TYPE_CHECKING" + if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name): + return node.value.id == "typing" and node.attr == "TYPE_CHECKING" + return False + + +def _get_type_checking_locals(func: Callable[..., Any]) -> dict[str, Any]: + module = inspect.getmodule(func) + module_file = getattr(module, "__file__", None) + module_name = getattr(module, "__name__", None) + if module is None or module_file is None or module_name is None: + return {} + + try: + with open(module_file, encoding="utf-8") as f: + tree = ast.parse(f.read(), filename=module_file) + except Exception: + return {} + + resolved: dict[str, Any] = {} + for node in tree.body: + if not isinstance(node, ast.If) or not _is_type_checking_statement(node.test): + continue + + for stmt in node.body: + if isinstance(stmt, ast.Import): + for alias in stmt.names: + try: + imported = importlib.import_module(alias.name) + except Exception: + continue + resolved[alias.asname or alias.name.split(".")[-1]] = imported + elif isinstance(stmt, ast.ImportFrom): + if stmt.module is None: + continue + try: + imported_module = importlib.import_module( + stmt.module, package=module_name + ) + except Exception: + continue + for alias in stmt.names: + if alias.name == "*": + continue + if not hasattr(imported_module, alias.name): + continue + resolved[alias.asname or alias.name] = getattr( + imported_module, alias.name + ) + + return resolved + + class OptionChoice: def __init__( self, @@ -149,17 +213,20 @@ def __init__( self.value: str | int | float = value if value is not None else name self.name_localizations: dict[str, str] = name_localizations - if not isinstance(self.value, (str, int, float)): + if not isinstance(self.value, (str, int, float)): # pyright: ignore[reportUnnecessaryIsInstance] raise TypeError( f"Option choice value must be of type str, int, or float, not {type(self.value)}." ) - self._api_type: SlashCommandOptionType = SlashCommandOptionType.from_datatype( + self._api_type: SlashCommandOptionType = SlashCommandOptionType.from_datatype( # type: ignore type(self.value) ) # type: ignore - def to_dict(self) -> dict[str, str | int | float]: - base = {"name": self.name, "value": self.value} + def to_dict(self) -> dict[str, str | int | float | dict[str, str]]: + base: dict[str, str | int | float | dict[str, str]] = { + "name": self.name, + "value": self.value, + } if self.name_localizations: base["name_localizations"] = self.name_localizations @@ -200,7 +267,6 @@ def __init__( self.required: bool = required if default is None else False self.default: int | str | float | None = default - self.choices: list[OptionChoice] = self._handle_choices(choices) self.name_localizations: dict[str, str] = name_localizations or {} self.description_localizations: dict[str, str] = description_localizations or {} self.channel_types: list[ChannelType] = channel_types or [] @@ -212,6 +278,8 @@ def __init__( self._autocomplete: AutocompleteFunction | None = None self.autocomplete = autocomplete + self.choices: list[OptionChoice] = self._handle_choices(choices) + @property def autocomplete(self) -> AutocompleteFunction | None: """ @@ -320,12 +388,12 @@ def _handle_type(self, param_type: ValidOptionType | None = None) -> None: self._parse_choices_from_enum(param_type) elif issubclass(param_type, Converter): # type: ignore self.converter = param_type() # type: ignore - elif isinstance(param_type, Converter): + elif isinstance(param_type, Converter): # pyright: ignore[reportUnnecessaryIsInstance] self.converter = param_type elif origin is Annotated: self._handle_type(args[0]) return - elif get_origin(param_type) in (Union, types.UnionType): + elif get_origin(param_type) is Union: # pyright: ignore[reportUnnecessaryComparison] union_args = get_args(param_type) non_none_args = normalise_optional_params(union_args)[:-1] if len(non_none_args) == 1: @@ -343,7 +411,7 @@ def _handle_type(self, param_type: ValidOptionType | None = None) -> None: ): self._api_type = SlashCommandOptionType.user return - elif get_origin(param_type) is Literal: + elif get_origin(param_type) is Literal: # pyright: ignore[reportUnnecessaryComparison] literal_args = get_args(param_type) if all(isinstance(arg, str) for arg in literal_args): self._api_type = SlashCommandOptionType.string @@ -372,14 +440,14 @@ def _handle_choices(self, choices: ValidChoicesType | None) -> list[OptionChoice final_choices: list[OptionChoice] = [] - if isinstance(choices, type) and (issubclass(choices, (Enum, DiscordEnum))): + if isinstance(choices, type) and (issubclass(choices, (Enum, DiscordEnum))): # pyright: ignore[reportUnnecessaryIsInstance] return self._parse_choices_from_enum(choices) - if isinstance(choices, Iterable): + if isinstance(choices, Iterable): # pyright: ignore[reportUnnecessaryIsInstance] for choice in choices: if isinstance(choice, OptionChoice): final_choices.append(choice) - elif isinstance(choice, (str, int, float)): + elif isinstance(choice, (str, int, float)): # pyright: ignore[reportUnnecessaryIsInstance] final_choices.append(OptionChoice(name=str(choice), value=choice)) else: raise TypeError( @@ -432,7 +500,7 @@ def _parse_choices_from_enum(self, enum_cls: type[Enum]) -> list[OptionChoice]: f"For parameter {self._param_name}: Enum choices must have values of type str, int, or float. Found {type(first_member_type)} in {enum_cls}." ) - self._api_type = SlashCommandOptionType.from_datatype( + self._api_type = SlashCommandOptionType.from_datatype( # type: ignore type(first_member_type) ) @@ -444,7 +512,7 @@ def to_dict(self) -> dict[str, Any]: if not self._api_type: raise ValueError("Option type has not been set.") - base = { + base: dict[str, Any] = { "type": self._api_type.value, "name": self._param_name, "description": self.description, @@ -470,6 +538,43 @@ def to_dict(self) -> dict[str, Any]: return base +ValidThreadType = Literal[ + "public", + "private", + "news", + ChannelType.news_thread, + ChannelType.private_thread, + ChannelType.public_thread, +] + + +@deprecated( + "ThreadOption is deprecated and will be removed in a future version. Please use Option with the appropriate channel_types instead.", + since="2.9", +) +class ThreadOption(Option): + """Represents a class that can be passed as the ``input_type`` for an :class:`Option` class. + + .. versionadded:: 2.0 + + Parameters + ---------- + thread_type: Literal["public", "private", "news", :attr:`ChannelType.news_thread`, :attr:`ChannelType.private_thread`, :attr:`ChannelType.public_thread`] + The thread type to expect for this options input. + """ + + def __init__(self, thread_type: ValidThreadType) -> types.NoneType: + type_map = { + "public": ChannelType.public_thread, + "private": ChannelType.private_thread, + "news": ChannelType.news_thread, + } + return super().__init__( + Thread, + channel_types=[type_map.get(thread_type, thread_type)], # type: ignore + ) + + def _get_options( func: Callable[..., Any], *, cog: type[Cog] | None = None ) -> dict[str, Option]: @@ -487,8 +592,8 @@ def _get_options( param_items = list(parameters.items()) if cog is not None: if param_items and param_items[0][0] != "self": - raise ValueError( - f"First parameter of method '{func.__name__}' must be 'self' when it's in a cog, but got '{param_items[0][0]}'." + _log.warning( + f"First parameter of method '{func.__name__}' should be 'self' when it's in a cog, but got '{param_items[0][0]!r}'." ) skip_count = 2 else: @@ -506,21 +611,32 @@ def _get_options( annotation_is_option = False if param.annotation is not param.empty: - annotation = resolve_annotation( - param.annotation, func.__globals__, func.__globals__, {} - ) + try: + annotation = resolve_annotation( + param.annotation, func.__globals__, func.__globals__, {} + ) + except NameError: + # Only attempt TYPE_CHECKING import recovery for forward refs. + if not isinstance(param.annotation, str): + raise + eval_locals = func.__globals__.copy() + + eval_locals.update(_get_type_checking_locals(func)) + annotation = resolve_annotation( + param.annotation, func.__globals__, eval_locals, {} + ) if isinstance(annotation, Option): annotation_is_option = True - option._copy_from(annotation) + option._copy_from(annotation) # pyright: ignore[reportPrivateUsage] if option.name is None: option.name = param_name - if option._param_name is None: - option._param_name = param_name + if option._param_name is None: # pyright: ignore[reportPrivateUsage] + option._param_name = param_name # pyright: ignore[reportPrivateUsage] else: - option._param_type = annotation + option._param_type = annotation # pyright: ignore[reportPrivateUsage] if existing: - option._copy_from(existing) + option._copy_from(existing) # pyright: ignore[reportPrivateUsage] if param.default is not param.empty: option.default = param.default @@ -530,7 +646,7 @@ def _get_options( continue try: - option._handle_type() + option._handle_type() # pyright: ignore[reportPrivateUsage] except Exception as e: raise TypeError( f"Error processing parameter '{param_name}' of function '{func.__name__}': {e}" @@ -568,6 +684,67 @@ def option( max_length: int | None = None, autocomplete: AutocompleteFunction | None = None, ) -> Callable[[CallableT], CallableT]: + """Decorator to specify metadata for a command option. + + You may use multiple instances of this decorator to + specify metadata for multiple parameters in the same command. + + Parameters + ---------- + name: :class:`str` + The name of the option as it will appear in Discord. This is required. + input_type: ValidOptionType + The type of the option. This can be inferred from type annotations, but may be specified + here for convenience. Defaults to :class:`str` if not specified or inferred. + parameter_name: Optional[:class:`str`] + The name of the parameter this option corresponds to. If not specified, it will be assumed + to be the same as the option name. This is only necessary to specify if the parameter name + is different from the option name. + name_localizations: Optional[Dict[:class:`str`, :class:`str`]] + A mapping of locale codes to localized option names. + description: Optional[:class:`str`] + The description of the option as it will appear in Discord. If not specified, it will be + inferred from the parameter's docstring if available. + description_localizations: Optional[Dict[:class:`str`, :class:`str`]] + A mapping of locale codes to localized option descriptions. + required: :class:`bool` + Whether the option is required. Defaults to ``True``. If a default value is provided, + this will be set to ``False``. + default: Optional[Union[:class:`int`, :class:`str`, :class:`float`]] + The default value for the option. If provided, the option will not be required. + choices: Optional[ + Union[ + Iterable[:class:`OptionChoice`], + Iterable[:class:`str`], + Iterable[:class:`int`], + Iterable[:class:`float`], + Type[Enum], + Type[DiscordEnum] + ] + ] + A list of choices for the option. Each choice can be an instance of :class:`OptionChoice` + or a raw value (str, int, or float). If an Enum class is provided, choices will be generated + from the Enum members. If more than 25 choices are provided, they will be cleared and basic + autocomplete will be set up automatically. + channel_types: Optional[List[:class:`ChannelType`]] + A list of channel types to limit a channel option to. Only applicable if the option type is a + channel or a union of channel types. + min_value: Optional[Union[:class:`int`, :class:`float`]] + The minimum value for the option. Only applicable for int and float option types. + max_value: Optional[Union[:class:`int`, :class:`float`]] + The maximum value for the option. Only applicable for int and float option types. + min_length: Optional[:class:`int`] + The minimum length for the option. Only applicable for string option types. + max_length: Optional[:class:`int`] + The maximum length for the option. Only applicable for string option types. + autocomplete: Optional[AutocompleteFunction] + An autocomplete handler for the option. Accepts a callable (sync or async) + that takes a single required argument of :class:`AutocompleteContext` or two arguments + of :class:`discord.Cog` (being the command's cog) and :class:`AutocompleteContext`. + The callable must return an iterable of :class:`str` or :class:`OptionChoice`. Alternatively, + :func:`discord.utils.basic_autocomplete` may be used in place of the callable. + """ + def inner(func: CallableT) -> CallableT: opt = Option( input_type, @@ -594,3 +771,60 @@ def inner(func: CallableT) -> CallableT: return func return inner + + +def options(**options: Option) -> Callable[[CallableT], CallableT]: + """Decorator to specify multiple options for a command at once. + + You may not use both this decorator and the :func:`option` decorator + to specify metadata for the same parameter. This will raise a + :exc:`ValueError`. + + Parameters + ---------- + **options + Keyword arguments where the key is the parameter name and + the value is an instance of :class:`Option` containing the + option metadata for that parameter. + + Example + ------- + + .. code-block:: python3 + + @commands.slash_command(name="opt_multiple") + @options( + target=discord.Option(description="The target member or user",), + hidden=discord.Option(description="Whether the option should be hidden"), + ) + async def opt_multiple( + self, + ctx: discord.ApplicationContext, + target: Union[discord.Member, discord.User], + hidden: bool = False, + ) -> None: + await ctx.respond(f"target={target.mention}") + """ + + def inner(func: CallableT) -> CallableT: + if not all(isinstance(opt, Option) for opt in options.values()): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( + "All values passed to @options must be instances of Option." + ) + + existing_options: dict[str, Option] = getattr(func, "__options__", {}).copy() + if not existing_options: + existing_options = options + else: + for name, opt in options.items(): + if name in existing_options: + raise ValueError( + f"Duplicate option metadata for parameter '{name}' in function '{func.__name__}'. " + "Please don't use both @option and @options decorators to specify metadata for the same parameter." + ) + existing_options[name] = opt + + func.__options__ = existing_options # type: ignore + return func + + return inner diff --git a/test_exts/options_showcase.py b/test_exts/options_showcase.py index 7ee3ef286f..46a740596d 100644 --- a/test_exts/options_showcase.py +++ b/test_exts/options_showcase.py @@ -1,13 +1,17 @@ from __future__ import annotations from enum import Enum, IntEnum -from typing import Annotated, Literal, Optional, Union +from typing import TYPE_CHECKING, Annotated, Literal, Optional, Union import discord from discord import OptionChoice, option from discord.enums import SlashCommandOptionType from discord.ext import commands +if TYPE_CHECKING: + from discord import Member as TcMember + from discord import User as TcUser + class Color(Enum): red = "red" @@ -32,6 +36,24 @@ class OptionShowcase(commands.Cog): def __init__(self, bot: discord.Bot): self.bot = bot + @staticmethod + def create_command(): + @commands.slash_command(name="created_command") + async def created_command( + _, + ctx: discord.ApplicationContext, + text: str, + ) -> None: + await ctx.respond(f"text={text}") + + return created_command + + @commands.slash_command(name="autobackup") + async def autobackup( + self, ctx: discord.ApplicationContext, mode: bool, interval: int = 24 + ) -> None: + await ctx.respond(f"mode={mode} interval={interval}") + @commands.slash_command(name="opt_primitives") async def opt_primitives( self, @@ -41,9 +63,7 @@ async def opt_primitives( ratio: float, flag: bool, ) -> None: - await ctx.respond( - f"text={text} count={count} ratio={ratio} flag={flag}" - ) + await ctx.respond(f"text={text} count={count} ratio={ratio} flag={flag}") @commands.slash_command(name="opt_choices") @option("color", description="Pick a color", choices=["red", "green", "blue"]) @@ -97,9 +117,7 @@ async def opt_channels( channel: Union[discord.TextChannel, discord.VoiceChannel], thread: discord.Thread, ) -> None: - await ctx.respond( - f"channel={channel.mention} thread={thread.mention}" - ) + await ctx.respond(f"channel={channel.mention} thread={thread.mention}") @commands.slash_command(name="opt_channeltypes") @option( @@ -138,6 +156,14 @@ async def opt_union_user( ) -> None: await ctx.respond(f"target={target.mention}") + @commands.slash_command(name="opt_type_checking") + async def opt_type_checking( + self, + ctx: discord.ApplicationContext, + target: Union[TcMember, TcUser], + ) -> None: + await ctx.respond(f"target={target.mention}") + @commands.slash_command(name="opt_attachment") async def opt_attachment( self, @@ -191,4 +217,6 @@ async def opt_autocomplete( def setup(bot: discord.Bot) -> None: - bot.add_cog(OptionShowcase(bot)) + cog = OptionShowcase(bot) + cog.create_command() + bot.add_cog(cog) From e1a5083f7f3ed3c1c44dfdc1e588dde918ff840f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Feb 2026 16:45:04 +0000 Subject: [PATCH 3/4] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/commands/core.py | 10 ++-- discord/commands/options.py | 87 +++++++++++++++++++++-------------- intro_typing.py | 33 ++++++------- test_exts/options_showcase.py | 10 ++-- 4 files changed, 75 insertions(+), 65 deletions(-) diff --git a/discord/commands/core.py b/discord/commands/core.py index 48e0b757d3..5ac401d64c 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -30,8 +30,6 @@ import functools import inspect import re -import sys -import types from collections import OrderedDict from enum import Enum from typing import ( @@ -42,7 +40,6 @@ Generator, Generic, TypeVar, - Union, ) from ..channel import PartialMessageable, _threaded_guild_channel_factory @@ -69,7 +66,6 @@ from .context import ApplicationContext, AutocompleteContext from .options import Option, OptionChoice, _get_options - __all__ = ( "_BaseCommand", "ApplicationCommand", @@ -988,9 +984,9 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): for op in ctx.interaction.data.get("options", []): if op.get("focused", False): option = find(lambda o: o.name == op["name"], self.options) - values.update({ - i["name"]: i["value"] for i in ctx.interaction.data["options"] - }) + values.update( + {i["name"]: i["value"] for i in ctx.interaction.data["options"]} + ) ctx.command = self ctx.focused = option ctx.value = op.get("value") diff --git a/discord/commands/options.py b/discord/commands/options.py index 2e408b3ed9..47e8a9890a 100644 --- a/discord/commands/options.py +++ b/discord/commands/options.py @@ -1,13 +1,14 @@ from __future__ import annotations + import ast -from collections import OrderedDict -from collections.abc import Awaitable, Iterable, Callable -from enum import Enum, IntEnum import importlib +import inspect import logging import sys import types -import inspect +from collections import OrderedDict +from collections.abc import Awaitable, Callable, Iterable +from enum import Enum, IntEnum from typing import ( TYPE_CHECKING, Annotated, @@ -19,34 +20,32 @@ get_origin, ) - -from ..enums import SlashCommandOptionType, Enum as DiscordEnum, ChannelType -from .context import AutocompleteContext - - -from ..utils import ( - resolve_annotation, - normalise_optional_params, - MISSING, - basic_autocomplete, - deprecated, -) - from ..abc import GuildChannel -from ..message import Attachment -from ..role import Role from ..channel import ( - TextChannel, - VoiceChannel, CategoryChannel, + DMChannel, ForumChannel, + MediaChannel, StageChannel, + TextChannel, Thread, - DMChannel, - MediaChannel, + VoiceChannel, ) +from ..enums import ChannelType +from ..enums import Enum as DiscordEnum +from ..enums import SlashCommandOptionType from ..member import Member +from ..message import Attachment +from ..role import Role from ..user import User +from ..utils import ( + MISSING, + basic_autocomplete, + deprecated, + normalise_optional_params, + resolve_annotation, +) +from .context import AutocompleteContext if TYPE_CHECKING: from ..cog import Cog @@ -213,7 +212,9 @@ def __init__( self.value: str | int | float = value if value is not None else name self.name_localizations: dict[str, str] = name_localizations - if not isinstance(self.value, (str, int, float)): # pyright: ignore[reportUnnecessaryIsInstance] + if not isinstance( + self.value, (str, int, float) + ): # pyright: ignore[reportUnnecessaryIsInstance] raise TypeError( f"Option choice value must be of type str, int, or float, not {type(self.value)}." ) @@ -388,12 +389,16 @@ def _handle_type(self, param_type: ValidOptionType | None = None) -> None: self._parse_choices_from_enum(param_type) elif issubclass(param_type, Converter): # type: ignore self.converter = param_type() # type: ignore - elif isinstance(param_type, Converter): # pyright: ignore[reportUnnecessaryIsInstance] + elif isinstance( + param_type, Converter + ): # pyright: ignore[reportUnnecessaryIsInstance] self.converter = param_type elif origin is Annotated: self._handle_type(args[0]) return - elif get_origin(param_type) is Union: # pyright: ignore[reportUnnecessaryComparison] + elif ( + get_origin(param_type) is Union + ): # pyright: ignore[reportUnnecessaryComparison] union_args = get_args(param_type) non_none_args = normalise_optional_params(union_args)[:-1] if len(non_none_args) == 1: @@ -411,7 +416,9 @@ def _handle_type(self, param_type: ValidOptionType | None = None) -> None: ): self._api_type = SlashCommandOptionType.user return - elif get_origin(param_type) is Literal: # pyright: ignore[reportUnnecessaryComparison] + elif ( + get_origin(param_type) is Literal + ): # pyright: ignore[reportUnnecessaryComparison] literal_args = get_args(param_type) if all(isinstance(arg, str) for arg in literal_args): self._api_type = SlashCommandOptionType.string @@ -440,14 +447,20 @@ def _handle_choices(self, choices: ValidChoicesType | None) -> list[OptionChoice final_choices: list[OptionChoice] = [] - if isinstance(choices, type) and (issubclass(choices, (Enum, DiscordEnum))): # pyright: ignore[reportUnnecessaryIsInstance] + if isinstance(choices, type) and ( + issubclass(choices, (Enum, DiscordEnum)) + ): # pyright: ignore[reportUnnecessaryIsInstance] return self._parse_choices_from_enum(choices) - if isinstance(choices, Iterable): # pyright: ignore[reportUnnecessaryIsInstance] + if isinstance( + choices, Iterable + ): # pyright: ignore[reportUnnecessaryIsInstance] for choice in choices: if isinstance(choice, OptionChoice): final_choices.append(choice) - elif isinstance(choice, (str, int, float)): # pyright: ignore[reportUnnecessaryIsInstance] + elif isinstance( + choice, (str, int, float) + ): # pyright: ignore[reportUnnecessaryIsInstance] final_choices.append(OptionChoice(name=str(choice), value=choice)) else: raise TypeError( @@ -504,9 +517,9 @@ def _parse_choices_from_enum(self, enum_cls: type[Enum]) -> list[OptionChoice]: type(first_member_type) ) - return self._handle_choices([ - OptionChoice(name=member.name, value=member.value) for member in enum_cls - ]) + return self._handle_choices( + [OptionChoice(name=member.name, value=member.value) for member in enum_cls] + ) def to_dict(self) -> dict[str, Any]: if not self._api_type: @@ -631,7 +644,9 @@ def _get_options( if option.name is None: option.name = param_name if option._param_name is None: # pyright: ignore[reportPrivateUsage] - option._param_name = param_name # pyright: ignore[reportPrivateUsage] + option._param_name = ( + param_name # pyright: ignore[reportPrivateUsage] + ) else: option._param_type = annotation # pyright: ignore[reportPrivateUsage] @@ -807,7 +822,9 @@ async def opt_multiple( """ def inner(func: CallableT) -> CallableT: - if not all(isinstance(opt, Option) for opt in options.values()): # pyright: ignore[reportUnnecessaryIsInstance] + if not all( + isinstance(opt, Option) for opt in options.values() + ): # pyright: ignore[reportUnnecessaryIsInstance] raise TypeError( "All values passed to @options must be instances of Option." ) diff --git a/intro_typing.py b/intro_typing.py index 6776431f97..d3da084c8b 100644 --- a/intro_typing.py +++ b/intro_typing.py @@ -1,17 +1,18 @@ from __future__ import annotations -from collections.abc import Awaitable, Iterable, Callable -from enum import Enum, IntEnum + +import inspect import logging import sys -import types +from collections.abc import Awaitable, Callable, Iterable +from enum import Enum, IntEnum from typing import Annotated, Any, Literal, TypeVar, Union, get_args, get_origin + import discord -from discord.enums import SlashCommandOptionType, Enum as DiscordEnum +from discord.cog import Cog from discord.commands import AutocompleteContext +from discord.enums import Enum as DiscordEnum +from discord.enums import SlashCommandOptionType from discord.ext.commands import Converter -from discord.cog import Cog -import inspect - PY_310 = sys.version_info >= (3, 10) # for UnionType PY_311 = sys.version_info >= (3, 11) # for StrEnum @@ -196,9 +197,7 @@ def _handle_type(self, param_type: ValidOptionType) -> None: if ( isinstance(param_type, type) - and ( - issubclass(param_type, (Enum, DiscordEnum)) # type: ignore - ) + and (issubclass(param_type, (Enum, DiscordEnum))) # type: ignore and not self.choices ): self._parse_choices_from_enum(param_type) @@ -430,11 +429,11 @@ class MyEnum(Enum): B = 2 -def test_optional(a: Optional[int] = None): +def test_optional(a: int | None = None): pass -def test_union(b: Union[int, str]): +def test_union(b: int | str): pass @@ -458,7 +457,7 @@ def test_thread(g: discord.Thread): pass -def test_member_user(h: Union[discord.Member, discord.User]): +def test_member_user(h: discord.Member | discord.User): pass @@ -489,19 +488,17 @@ def test_member_user(h: Union[discord.Member, discord.User]): # --- MULTI-PARAMETER TEST CASES --- -def test_multi_1(a: int, b: Optional[str] = None, c: Literal[1, 2, 3] = 1): +def test_multi_1(a: int, b: str | None = None, c: Literal[1, 2, 3] = 1): pass def test_multi_2( - x: Union[discord.TextChannel, discord.VoiceChannel], y: MyEnum, z: Converter + x: discord.TextChannel | discord.VoiceChannel, y: MyEnum, z: Converter ): pass -def test_multi_3( - p: discord.Member, q: Union[discord.Member, discord.User], r: float = 3.14 -): +def test_multi_3(p: discord.Member, q: discord.Member | discord.User, r: float = 3.14): pass diff --git a/test_exts/options_showcase.py b/test_exts/options_showcase.py index 46a740596d..93d21f66ea 100644 --- a/test_exts/options_showcase.py +++ b/test_exts/options_showcase.py @@ -105,8 +105,8 @@ async def opt_literal( async def opt_optional( self, ctx: discord.ApplicationContext, - note: Optional[str] = None, - amount: Optional[int] = None, + note: str | None = None, + amount: int | None = None, ) -> None: await ctx.respond(f"note={note} amount={amount}") @@ -114,7 +114,7 @@ async def opt_optional( async def opt_channels( self, ctx: discord.ApplicationContext, - channel: Union[discord.TextChannel, discord.VoiceChannel], + channel: discord.TextChannel | discord.VoiceChannel, thread: discord.Thread, ) -> None: await ctx.respond(f"channel={channel.mention} thread={thread.mention}") @@ -152,7 +152,7 @@ async def opt_users( async def opt_union_user( self, ctx: discord.ApplicationContext, - target: Union[discord.Member, discord.User], + target: discord.Member | discord.User, ) -> None: await ctx.respond(f"target={target.mention}") @@ -160,7 +160,7 @@ async def opt_union_user( async def opt_type_checking( self, ctx: discord.ApplicationContext, - target: Union[TcMember, TcUser], + target: TcMember | TcUser, ) -> None: await ctx.respond(f"target={target.mention}") From eca7ecc92a9cdb48d8512a49f254fe69d9c00631 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 28 Feb 2026 16:23:39 +0000 Subject: [PATCH 4/4] style(pre-commit): auto fixes from pre-commit.com hooks --- intro_typing.py | 1 - test_exts/options_showcase.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/intro_typing.py b/intro_typing.py index d3da084c8b..9c1070fcbb 100644 --- a/intro_typing.py +++ b/intro_typing.py @@ -421,7 +421,6 @@ def inspect_annotations(func: Callable[..., Any]) -> dict[str, InspectedAnnotati # --- TEST CASES FOR inspect_annotations --- -from typing import Optional class MyEnum(Enum): diff --git a/test_exts/options_showcase.py b/test_exts/options_showcase.py index 93d21f66ea..2739c1f662 100644 --- a/test_exts/options_showcase.py +++ b/test_exts/options_showcase.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum, IntEnum -from typing import TYPE_CHECKING, Annotated, Literal, Optional, Union +from typing import TYPE_CHECKING, Annotated, Literal import discord from discord import OptionChoice, option