@@ -2895,6 +2895,28 @@ def add(self, x, y):
28952895 Abstract ()
28962896
28972897 def test_type_ann_register (self ):
2898+ @functools .singledispatch
2899+ def t (arg ):
2900+ return "base"
2901+ @t .register
2902+ def _ (arg : int ):
2903+ return "int"
2904+ @t .register
2905+ def _ (arg : str ):
2906+ return "str"
2907+ def _ (arg : bytes ):
2908+ return "bytes"
2909+ @t .register
2910+ @functools .wraps (_ )
2911+ def wrapper (* args , ** kwargs ):
2912+ return _ (* args , ** kwargs )
2913+ self .assertEqual (t (0 ), "int" )
2914+ self .assertEqual (t ('' ), "str" )
2915+ self .assertEqual (t (0.0 ), "base" )
2916+ self .assertEqual (t (b'' ), "bytes" )
2917+
2918+ def test_method_type_ann_register (self ):
2919+
28982920 class A :
28992921 @functools .singledispatchmethod
29002922 def t (self , arg ):
@@ -2905,13 +2927,28 @@ def _(self, arg: int):
29052927 @t .register
29062928 def _ (self , arg : str ):
29072929 return "str"
2930+ def _ (self , arg : bytes ):
2931+ return "bytes"
2932+ @t .register
2933+ @functools .wraps (_ )
2934+ def wrapper (self , * args , ** kwargs ):
2935+ return self ._ (* args , ** kwargs )
2936+
29082937 a = A ()
29092938
29102939 self .assertEqual (a .t (0 ), "int" )
29112940 self .assertEqual (a .t ('' ), "str" )
29122941 self .assertEqual (a .t (0.0 ), "base" )
2942+ self .assertEqual (a .t (b'' ), "bytes" )
29132943
29142944 def test_staticmethod_type_ann_register (self ):
2945+ def wrapper_decorator (func ):
2946+ wrapped = func .__func__
2947+ @staticmethod
2948+ @functools .wraps (wrapped )
2949+ def wrapper (* args , ** kwargs ):
2950+ return wrapped (* args , ** kwargs )
2951+ return wrapper
29152952 class A :
29162953 @functools .singledispatchmethod
29172954 @staticmethod
@@ -2925,13 +2962,25 @@ def _(arg: int):
29252962 @staticmethod
29262963 def _ (arg : str ):
29272964 return isinstance (arg , str )
2965+ @t .register
2966+ @wrapper_decorator
2967+ @staticmethod
2968+ def _ (arg : bytes ):
2969+ return isinstance (arg , bytes )
29282970 a = A ()
29292971
29302972 self .assertTrue (A .t (0 ))
29312973 self .assertTrue (A .t ('' ))
29322974 self .assertEqual (A .t (0.0 ), 0.0 )
29332975
29342976 def test_classmethod_type_ann_register (self ):
2977+ def wrapper_decorator (func ):
2978+ wrapped = func .__func__
2979+ @classmethod
2980+ @functools .wraps (wrapped )
2981+ def wrapper (* args , ** kwargs ):
2982+ return wrapped (* args , ** kwargs )
2983+ return wrapper
29352984 class A :
29362985 def __init__ (self , arg ):
29372986 self .arg = arg
@@ -2948,10 +2997,16 @@ def _(cls, arg: int):
29482997 @classmethod
29492998 def _ (cls , arg : str ):
29502999 return cls ("str" )
3000+ @t .register
3001+ @wrapper_decorator
3002+ @classmethod
3003+ def _ (cls , arg : bytes ):
3004+ return cls ("bytes" )
29513005
29523006 self .assertEqual (A .t (0 ).arg , "int" )
29533007 self .assertEqual (A .t ('' ).arg , "str" )
29543008 self .assertEqual (A .t (0.0 ).arg , "base" )
3009+ self .assertEqual (A .t (b'' ).arg , "bytes" )
29553010
29563011 def test_method_wrapping_attributes (self ):
29573012 class A :
@@ -3170,12 +3225,27 @@ def test_invalid_registrations(self):
31703225 @functools .singledispatch
31713226 def i (arg ):
31723227 return "base"
3228+ with self .assertRaises (TypeError ) as exc :
3229+ @i .register
3230+ def _ () -> None :
3231+ return "My function doesn't take arguments"
3232+ self .assertStartsWith (str (exc .exception ), msg_prefix )
3233+ self .assertEndsWith (str (exc .exception ), "does not accept positional arguments." )
3234+
3235+ with self .assertRaises (TypeError ) as exc :
3236+ @i .register
3237+ def _ (* , foo : str ) -> None :
3238+ return "My function takes keyword-only arguments"
3239+ self .assertStartsWith (str (exc .exception ), msg_prefix )
3240+ self .assertEndsWith (str (exc .exception ), "does not accept positional arguments." )
3241+
31733242 with self .assertRaises (TypeError ) as exc :
31743243 @i .register (42 )
31753244 def _ (arg ):
31763245 return "I annotated with a non-type"
31773246 self .assertStartsWith (str (exc .exception ), msg_prefix + "42" )
31783247 self .assertEndsWith (str (exc .exception ), msg_suffix )
3248+
31793249 with self .assertRaises (TypeError ) as exc :
31803250 @i .register
31813251 def _ (arg ):
@@ -3185,6 +3255,17 @@ def _(arg):
31853255 )
31863256 self .assertEndsWith (str (exc .exception ), msg_suffix )
31873257
3258+ with self .assertRaises (TypeError ) as exc :
3259+ @i .register
3260+ def _ (arg , extra : int ):
3261+ return "I did not annotate the right param"
3262+ self .assertStartsWith (str (exc .exception ), msg_prefix +
3263+ "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
3264+ )
3265+ self .assertEndsWith (str (exc .exception ),
3266+ "Add missing type annotation to parameter 'arg' "
3267+ "of this function or use `@register(some_class)`." )
3268+
31883269 with self .assertRaises (TypeError ) as exc :
31893270 @i .register
31903271 def _ (arg : typing .Iterable [str ]):
0 commit comments