Skip to content

Commit d3d6e64

Browse files
committed
Resolve first positional param, required to be annotated
1 parent b538c28 commit d3d6e64

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

Lib/functools.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,31 @@ def _find_impl(cls, registry):
888888
match = t
889889
return registry.get(match)
890890

891+
892+
def _get_dispatch_param_name(func, *, skip_first=False):
893+
func_code = func.__code__
894+
pos_param_count = func_code.co_argcount
895+
params = func_code.co_varnames
896+
return next(iter(params[skip_first:pos_param_count]), None)
897+
898+
899+
def _get_dispatch_annotation(func, param):
900+
import annotationlib
901+
annotations = annotationlib.get_annotations(func, format=annotationlib.Format.FORWARDREF)
902+
try:
903+
return annotations[param]
904+
except KeyError:
905+
raise TypeError(
906+
f"Invalid first argument to `register()`: {param!r}. "
907+
f"Add missing annotation to parameter {param!r} of {func.__qualname__!r} or use `@register(some_class)`."
908+
) from None
909+
910+
911+
def _get_dispatch_param_and_annotation(func, *, skip_first=False):
912+
param = _get_dispatch_param_name(func, skip_first=skip_first)
913+
return param, _get_dispatch_annotation(func, param)
914+
915+
891916
def singledispatch(func):
892917
"""Single-dispatch generic function decorator.
893918
@@ -935,7 +960,7 @@ def _is_valid_dispatch_type(cls):
935960
return (isinstance(cls, UnionType) and
936961
all(isinstance(arg, type) for arg in cls.__args__))
937962

938-
def register(cls, func=None):
963+
def register(cls, func=None, _func_is_method=False):
939964
"""generic_func.register(cls, func) -> func
940965
941966
Registers a new implementation for the given *cls* on a *generic_func*.
@@ -960,10 +985,11 @@ def register(cls, func=None):
960985
)
961986
func = cls
962987

963-
# only import typing if annotation parsing is necessary
964-
from typing import get_type_hints
965-
from annotationlib import Format, ForwardRef
966-
argname, cls = next(iter(get_type_hints(func, format=Format.FORWARDREF).items()))
988+
argname, cls = _get_dispatch_param_and_annotation(
989+
func, skip_first=_func_is_method)
990+
991+
from annotationlib import ForwardRef
992+
967993
if not _is_valid_dispatch_type(cls):
968994
if isinstance(cls, UnionType):
969995
raise TypeError(
@@ -1027,7 +1053,7 @@ def register(self, cls, method=None):
10271053
10281054
Registers a new implementation for the given *cls* on a *generic_method*.
10291055
"""
1030-
return self.dispatcher.register(cls, func=method)
1056+
return self.dispatcher.register(cls, func=method, _func_is_method=True)
10311057

10321058
def __get__(self, obj, cls=None):
10331059
return _singledispatchmethod_get(self, obj, cls)

0 commit comments

Comments
 (0)