@@ -888,25 +888,29 @@ def _find_impl(cls, registry):
888888 match = t
889889 return registry .get (match )
890890
891- def _get_dispatch_param (func , * , pos = 0 ):
892- """Finds the positional user-specified parameter at position *pos*
893- of a callable or descriptor.
891+ def _get_dispatch_param (func , * , _dispatchmethod = False ):
892+ """Finds the first positional and user-specified parameter in a callable
893+ or descriptor.
894894
895895 Used by singledispatch for registration by type annotation of the parameter.
896- *pos* should either be 0 (for functions and staticmethods) or 1 (for methods).
897896 """
898897 # Fast path for typical callables and descriptors.
899- if isinstance (func , (classmethod , staticmethod )):
898+ # 0 from singledispatch(), 1 from singledispatchmethod()
899+ idx = _dispatchmethod
900+ if isinstance (func , staticmethod ):
901+ idx = 0
902+ func = func .__func__
903+ elif isinstance (func , classmethod ):
900904 func = func .__func__
901905 if isinstance (func , FunctionType ) and not hasattr (func , "__wrapped__" ):
902906 func_code = func .__code__
903907 try :
904- return func_code .co_varnames [:func_code .co_argcount ][pos ]
908+ return func_code .co_varnames [:func_code .co_argcount ][idx ]
905909 except IndexError :
906910 pass
907911 # Fallback path for more nuanced inspection of ambiguous callables.
908912 import inspect
909- for param in list (inspect .signature (func ).parameters .values ())[pos :]:
913+ for param in list (inspect .signature (func ).parameters .values ())[idx :]:
910914 if param .kind in (param .KEYWORD_ONLY , param .VAR_KEYWORD ):
911915 break
912916 return param .name
@@ -959,7 +963,7 @@ def _is_valid_dispatch_type(cls):
959963 return (isinstance (cls , UnionType ) and
960964 all (isinstance (arg , type ) for arg in cls .__args__ ))
961965
962- def register (cls , func = None , _func_is_method = False ):
966+ def register (cls , func = None , _dispatchmethod = False ):
963967 """generic_func.register(cls, func) -> func
964968
965969 Registers a new implementation for the given *cls* on a *generic_func*.
@@ -984,10 +988,7 @@ def register(cls, func=None, _func_is_method=False):
984988 )
985989 func = cls
986990
987- # 0 for functions, 1 for methods where first argument should be skipped
988- argpos = _func_is_method and not isinstance (func , staticmethod )
989-
990- argname = _get_dispatch_param (func , pos = argpos )
991+ argname = _get_dispatch_param (func , _dispatchmethod = _dispatchmethod )
991992 if argname is None :
992993 raise TypeError (
993994 f"Invalid first argument to `register()`: { func !r} "
@@ -1071,7 +1072,7 @@ def register(self, cls, method=None):
10711072
10721073 Registers a new implementation for the given *cls* on a *generic_method*.
10731074 """
1074- return self .dispatcher .register (cls , func = method , _func_is_method = True )
1075+ return self .dispatcher .register (cls , func = method , _dispatchmethod = True )
10751076
10761077 def __get__ (self , obj , cls = None ):
10771078 return _singledispatchmethod_get (self , obj , cls )
0 commit comments