Skip to content

Commit b55de76

Browse files
committed
Rework registration tests
1 parent ef74667 commit b55de76

File tree

1 file changed

+135
-1
lines changed

1 file changed

+135
-1
lines changed

Lib/test/test_functools.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2955,6 +2955,43 @@ def _(cls, arg: str):
29552955
self.assertEqual(A.t('').arg, "str")
29562956
self.assertEqual(A.t(0.0).arg, "base")
29572957

2958+
def test_method_type_ann_register_correct_param_skipped(self):
2959+
class C:
2960+
@functools.singledispatchmethod
2961+
def t(self, x):
2962+
return "base"
2963+
2964+
@t.register
2965+
def _(self: typing.Self, x: int) -> str:
2966+
return "int"
2967+
2968+
@t.register
2969+
@classmethod
2970+
def _(self: type['C'], x: complex) -> str:
2971+
return "complex"
2972+
2973+
@t.register
2974+
@staticmethod # 'x' cannot be skipped.
2975+
def _(x: float) -> str:
2976+
return "float"
2977+
2978+
def _bytes(self: typing.Self, x: bytes) -> None:
2979+
return "bytes"
2980+
2981+
def _bytearray(self: typing.Self, x: bytearray) -> None:
2982+
return "bytearray"
2983+
2984+
c = C()
2985+
C.t.register(c._bytes)
2986+
c.t.register(C._bytearray)
2987+
2988+
self.assertEqual(c.t(NotImplemented), "base")
2989+
self.assertEqual(c.t(42), "int")
2990+
self.assertEqual(c.t(1991j), "complex")
2991+
self.assertEqual(c.t(.572), "float")
2992+
self.assertEqual(c.t(b'ytes'), "bytes")
2993+
self.assertEqual(c.t(bytearray(3)), "bytearray")
2994+
29582995
def test_method_wrapping_attributes(self):
29592996
class A:
29602997
@functools.singledispatchmethod
@@ -3175,7 +3212,7 @@ def i(arg):
31753212
with self.assertRaises(TypeError) as exc:
31763213
@i.register(42)
31773214
def _(arg):
3178-
return "I annotated with a non-type"
3215+
return "I passed a non-type"
31793216
self.assertStartsWith(str(exc.exception), msg_prefix + "42")
31803217
self.assertEndsWith(str(exc.exception), msg_suffix)
31813218
with self.assertRaises(TypeError) as exc:
@@ -3187,6 +3224,10 @@ def _(arg):
31873224
)
31883225
self.assertEndsWith(str(exc.exception), msg_suffix)
31893226

3227+
def test_type_ann_register_invalid_types(self):
3228+
@functools.singledispatch
3229+
def i(arg):
3230+
return "base"
31903231
with self.assertRaises(TypeError) as exc:
31913232
@i.register
31923233
def _(arg: typing.Iterable[str]):
@@ -3213,6 +3254,99 @@ def _(arg: typing.Union[int, typing.Iterable[str]]):
32133254
'int | typing.Iterable[str] not all arguments are classes.'
32143255
)
32153256

3257+
def test_type_ann_register_missing_annotation(self):
3258+
add_missing_re = (
3259+
r"Invalid first argument to `register\(\)`: <function .+>. "
3260+
r"Use either `@register\(some_class\)` or add a type annotation "
3261+
r"to parameter 'arg' of your callable."
3262+
)
3263+
no_positional_re = (
3264+
r"Invalid first argument to `register\(\)`: <function .+> "
3265+
r"does not accept positional arguments."
3266+
)
3267+
3268+
@functools.singledispatch
3269+
def d(arg):
3270+
pass
3271+
3272+
with self.assertRaisesRegex(TypeError, add_missing_re):
3273+
@d.register
3274+
def _(arg) -> int:
3275+
"""I only annotated the return type."""
3276+
return 42
3277+
3278+
with self.assertRaisesRegex(TypeError, add_missing_re):
3279+
@d.register
3280+
def _(arg, /, arg2) -> int:
3281+
"""I did not annotate the first param."""
3282+
return 42
3283+
3284+
with self.assertRaisesRegex(TypeError, no_positional_re):
3285+
@d.register
3286+
def _(*, arg: int = 13, arg2: int = 37) -> int:
3287+
"""I do not accept positional arguments."""
3288+
return 42
3289+
3290+
with self.assertRaisesRegex(TypeError, add_missing_re):
3291+
@d.register
3292+
def _(arg, **kwargs: int):
3293+
"""I only annotated keyword arguments type."""
3294+
return 42
3295+
3296+
def test_method_type_ann_register_missing_annotation(self):
3297+
add_missing_re = (
3298+
r"Invalid first argument to `register\(\)`: <%s.+>. "
3299+
r"Use either `@register\(some_class\)` or add a type annotation "
3300+
r"to parameter 'arg' of your callable."
3301+
)
3302+
no_positional_re = (
3303+
r"Invalid first argument to `register\(\)`: <%s.+> "
3304+
r"does not accept positional arguments."
3305+
)
3306+
3307+
class C:
3308+
@functools.singledispatchmethod
3309+
def d(self, arg):
3310+
return "base"
3311+
3312+
with self.assertRaisesRegex(TypeError, no_positional_re % "function"):
3313+
@d.register
3314+
def _() -> None:
3315+
"""I am not a incorrect method."""
3316+
return 42
3317+
3318+
with self.assertRaisesRegex(TypeError, no_positional_re % "function"):
3319+
@d.register
3320+
def _(self: typing.Self):
3321+
"""I only take self."""
3322+
return 42
3323+
3324+
with self.assertRaisesRegex(TypeError, no_positional_re % "function"):
3325+
@d.register
3326+
def _(self: typing.Self, *, arg):
3327+
"""I did not annotate the key parameter."""
3328+
return 42
3329+
3330+
with self.assertRaisesRegex(TypeError, add_missing_re % "classmethod"):
3331+
@d.register
3332+
@classmethod
3333+
def _(cls: type[typing.Self], arg) -> int:
3334+
"""I did not annotate the key parameter again."""
3335+
return 42
3336+
3337+
with self.assertRaisesRegex(TypeError, add_missing_re % "staticmethod"):
3338+
@d.register
3339+
@staticmethod
3340+
def _(arg, arg2: int, /, *, arg3: int = 1991):
3341+
"""I missed first arg again."""
3342+
return 42
3343+
3344+
def later(self, arg, **kwargs: int):
3345+
return 42
3346+
3347+
with self.assertRaisesRegex(TypeError, add_missing_re % "bound method"):
3348+
C.d.register(C().later)
3349+
32163350
def test_invalid_positional_argument(self):
32173351
@functools.singledispatch
32183352
def f(*args, **kwargs):

0 commit comments

Comments
 (0)