@@ -888,6 +888,31 @@ def _find_impl(cls, registry):
888888 match = t
889889 return registry .get (match )
890890
891+
892+ def _get_dispatch_param_name (func , * , skip_first = False ):
893+ func_code = func .__code__
894+ pos_param_count = func_code .co_argcount
895+ params = func_code .co_varnames
896+ return next (iter (params [skip_first :pos_param_count ]), None )
897+
898+
899+ def _get_dispatch_annotation (func , param ):
900+ import annotationlib
901+ annotations = annotationlib .get_annotations (func , format = annotationlib .Format .FORWARDREF )
902+ try :
903+ return annotations [param ]
904+ except KeyError :
905+ raise TypeError (
906+ f"Invalid first argument to `register()`: { param !r} . "
907+ f"Add missing annotation to parameter { param !r} of { func .__qualname__ !r} or use `@register(some_class)`."
908+ ) from None
909+
910+
911+ def _get_dispatch_param_and_annotation (func , * , skip_first = False ):
912+ param = _get_dispatch_param_name (func , skip_first = skip_first )
913+ return param , _get_dispatch_annotation (func , param )
914+
915+
891916def singledispatch (func ):
892917 """Single-dispatch generic function decorator.
893918
@@ -935,7 +960,7 @@ def _is_valid_dispatch_type(cls):
935960 return (isinstance (cls , UnionType ) and
936961 all (isinstance (arg , type ) for arg in cls .__args__ ))
937962
938- def register (cls , func = None ):
963+ def register (cls , func = None , _func_is_method = False ):
939964 """generic_func.register(cls, func) -> func
940965
941966 Registers a new implementation for the given *cls* on a *generic_func*.
@@ -960,10 +985,11 @@ def register(cls, func=None):
960985 )
961986 func = cls
962987
963- # only import typing if annotation parsing is necessary
964- from typing import get_type_hints
965- from annotationlib import Format , ForwardRef
966- argname , cls = next (iter (get_type_hints (func , format = Format .FORWARDREF ).items ()))
988+ argname , cls = _get_dispatch_param_and_annotation (
989+ func , skip_first = _func_is_method )
990+
991+ from annotationlib import ForwardRef
992+
967993 if not _is_valid_dispatch_type (cls ):
968994 if isinstance (cls , UnionType ):
969995 raise TypeError (
@@ -1027,7 +1053,7 @@ def register(self, cls, method=None):
10271053
10281054 Registers a new implementation for the given *cls* on a *generic_method*.
10291055 """
1030- return self .dispatcher .register (cls , func = method )
1056+ return self .dispatcher .register (cls , func = method , _func_is_method = True )
10311057
10321058 def __get__ (self , obj , cls = None ):
10331059 return _singledispatchmethod_get (self , obj , cls )
0 commit comments