Skip to content

Commit 82616f9

Browse files
committed
Support all callables
1 parent 8d86f9e commit 82616f9

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

Lib/functools.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# import weakref # Deferred to single_dispatch()
2020
from operator import itemgetter
2121
from reprlib import recursive_repr
22-
from types import GenericAlias, MethodType, MappingProxyType, UnionType
22+
from types import FunctionType, GenericAlias, MethodType, MappingProxyType, UnionType
2323
from _thread import RLock
2424

2525
################################################################################
@@ -888,20 +888,24 @@ def _find_impl(cls, registry):
888888
match = t
889889
return registry.get(match)
890890

891-
def _get_dispatch_param_name(func, *, skip_first_param=False):
892-
if not hasattr(func, '__code__'):
893-
skip_first_param = not isinstance(func, staticmethod)
891+
def _get_dispatch_param(func, *, pos=0):
892+
if isinstance(func, (MethodType, classmethod, staticmethod)):
894893
func = func.__func__
895-
func_code = func.__code__
896-
pos_param_count = func_code.co_argcount
897-
params = func_code.co_varnames
898-
try:
899-
return params[skip_first_param:pos_param_count][0]
900-
except IndexError:
901-
raise TypeError(
902-
f"Invalid first argument to `register()`: function {func!r}"
903-
f"does not accept positional arguments."
904-
) from None
894+
if isinstance(func, FunctionType):
895+
func_code = func.__code__
896+
try:
897+
return func_code.co_varnames[:func_code.co_argcount][pos]
898+
except IndexError:
899+
pass
900+
import inspect
901+
for insp_param in list(inspect.signature(func).parameters.values())[pos:]:
902+
if insp_param.KEYWORD_ONLY or insp_param.VAR_KEYWORD:
903+
break
904+
return insp_param.name
905+
raise TypeError(
906+
f"Invalid first argument to `register()`: {func!r}"
907+
f"does not accept positional arguments."
908+
) from None
905909

906910
def _get_dispatch_annotation(func, param):
907911
import annotationlib, typing
@@ -916,8 +920,8 @@ def _get_dispatch_annotation(func, param):
916920
) from None
917921
return fwdref_or_typeform
918922

919-
def _get_dispatch_param_and_annotation(func, *, skip_first_param=False):
920-
param = _get_dispatch_param_name(func, skip_first_param=skip_first_param)
923+
def _get_dispatch_arg_from_annotations(func, *, pos=0):
924+
param = _get_dispatch_param(func, pos=pos)
921925
return param, _get_dispatch_annotation(func, param)
922926

923927
def singledispatch(func):
@@ -992,8 +996,9 @@ def register(cls, func=None, _func_is_method=False):
992996
)
993997
func = cls
994998

995-
argname, cls = _get_dispatch_param_and_annotation(
996-
func, skip_first_param=_func_is_method)
999+
# 0 for functions, 1 for methods
1000+
argpos = _func_is_method and not isinstance(func, staticmethod)
1001+
argname, cls = _get_dispatch_arg_from_annotations(func, pos=argpos)
9971002

9981003
from annotationlib import ForwardRef
9991004

0 commit comments

Comments
 (0)