Skip to content

Commit 8350e71

Browse files
committed
Crystalize the decision tree
1 parent 0f75d98 commit 8350e71

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

Lib/functools.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -888,24 +888,27 @@ def _find_impl(cls, registry):
888888
match = t
889889
return registry.get(match)
890890

891-
def _get_dispatch_param(func, *, _dispatchmethod=False):
891+
def _get_dispatch_param(func, *, _inside_dispatchmethod=False):
892892
"""Finds the first positional and user-specified parameter in a callable
893893
or descriptor.
894894
895895
Used by singledispatch for registration by type annotation of the parameter.
896896
"""
897897
# Fast path for typical callables and descriptors.
898-
# idx is 0 when singledispatch() and 1 when singledispatchmethod().
898+
899+
# For staticmethods always pick the first parameter.
899900
if isinstance(func, staticmethod):
900901
idx = 0
901902
func = func.__func__
902-
elif isinstance(func, classmethod):
903+
# For classmethods and bound methods always pick the second parameter.
904+
elif isinstance(func, (classmethod, MethodType)):
903905
idx = 1
904906
func = func.__func__
905-
elif _dispatchmethod and not isinstance(func, MethodType):
906-
idx = 1
907+
# For unbound methods and functions, pick:
908+
# - the first parameter if calling from singledispatch()
909+
# - the second parameter if calling from singledispatchmethod()
907910
else:
908-
idx = 0
911+
idx = _inside_dispatchmethod
909912

910913
if isinstance(func, FunctionType) and not hasattr(func, "__wrapped__"):
911914
# Method from inspect._signature_from_function.
@@ -971,7 +974,7 @@ def _is_valid_dispatch_type(cls):
971974
return (isinstance(cls, UnionType) and
972975
all(isinstance(arg, type) for arg in cls.__args__))
973976

974-
def register(cls, func=None, _dispatchmethod=False):
977+
def register(cls, func=None, _inside_dispatchmethod=False):
975978
"""generic_func.register(cls, func) -> func
976979
977980
Registers a new implementation for the given *cls* on a *generic_func*.
@@ -996,7 +999,8 @@ def register(cls, func=None, _dispatchmethod=False):
996999
)
9971000
func = cls
9981001

999-
argname = _get_dispatch_param(func, _dispatchmethod=_dispatchmethod)
1002+
argname = _get_dispatch_param(
1003+
func, _inside_dispatchmethod=_inside_dispatchmethod)
10001004
if argname is None:
10011005
raise TypeError(
10021006
f"Invalid first argument to `register()`: {func!r} "
@@ -1075,12 +1079,12 @@ def __init__(self, func):
10751079
self.dispatcher = singledispatch(func)
10761080
self.func = func
10771081

1078-
def register(self, cls, method=None, _dispatchmethod=True):
1082+
def register(self, cls, method=None):
10791083
"""generic_method.register(cls, func) -> func
10801084
10811085
Registers a new implementation for the given *cls* on a *generic_method*.
10821086
"""
1083-
return self.dispatcher.register(cls, func=method, _dispatchmethod=_dispatchmethod)
1087+
return self.dispatcher.register(cls, func=method, _inside_dispatchmethod=True)
10841088

10851089
def __get__(self, obj, cls=None):
10861090
return _singledispatchmethod_get(self, obj, cls)

0 commit comments

Comments
 (0)