1919# import weakref # Deferred to single_dispatch()
2020from operator import itemgetter
2121from reprlib import recursive_repr
22- from types import GenericAlias , MethodType , MappingProxyType , UnionType
22+ from types import FunctionType , GenericAlias , MethodType , MappingProxyType , UnionType
2323from _thread import RLock
2424
2525################################################################################
@@ -888,20 +888,24 @@ def _find_impl(cls, registry):
888888 match = t
889889 return registry .get (match )
890890
891- def _get_dispatch_param_name (func , * , skip_first_param = False ):
892- if not hasattr (func , '__code__' ):
893- skip_first_param = not isinstance (func , staticmethod )
891+ def _get_dispatch_param (func , * , pos = 0 ):
892+ if isinstance (func , (MethodType , classmethod , staticmethod )):
894893 func = func .__func__
895- func_code = func .__code__
896- pos_param_count = func_code .co_argcount
897- params = func_code .co_varnames
898- try :
899- return params [skip_first_param :pos_param_count ][0 ]
900- except IndexError :
901- raise TypeError (
902- f"Invalid first argument to `register()`: function { func !r} "
903- f"does not accept positional arguments."
904- ) from None
894+ if isinstance (func , FunctionType ):
895+ func_code = func .__code__
896+ try :
897+ return func_code .co_varnames [:func_code .co_argcount ][pos ]
898+ except IndexError :
899+ pass
900+ import inspect
901+ for insp_param in list (inspect .signature (func ).parameters .values ())[pos :]:
902+ if insp_param .KEYWORD_ONLY or insp_param .VAR_KEYWORD :
903+ break
904+ return insp_param .name
905+ raise TypeError (
906+ f"Invalid first argument to `register()`: { func !r} "
907+ f"does not accept positional arguments."
908+ ) from None
905909
906910def _get_dispatch_annotation (func , param ):
907911 import annotationlib , typing
@@ -916,8 +920,8 @@ def _get_dispatch_annotation(func, param):
916920 ) from None
917921 return fwdref_or_typeform
918922
919- def _get_dispatch_param_and_annotation (func , * , skip_first_param = False ):
920- param = _get_dispatch_param_name (func , skip_first_param = skip_first_param )
923+ def _get_dispatch_arg_from_annotations (func , * , pos = 0 ):
924+ param = _get_dispatch_param (func , pos = pos )
921925 return param , _get_dispatch_annotation (func , param )
922926
923927def singledispatch (func ):
@@ -992,8 +996,9 @@ def register(cls, func=None, _func_is_method=False):
992996 )
993997 func = cls
994998
995- argname , cls = _get_dispatch_param_and_annotation (
996- func , skip_first_param = _func_is_method )
999+ # 0 for functions, 1 for methods
1000+ argpos = _func_is_method and not isinstance (func , staticmethod )
1001+ argname , cls = _get_dispatch_arg_from_annotations (func , pos = argpos )
9971002
9981003 from annotationlib import ForwardRef
9991004
0 commit comments