Skip to content

Commit 114f250

Browse files
add match-case support for unions
1 parent cb2b850 commit 114f250

File tree

2 files changed

+240
-43
lines changed

2 files changed

+240
-43
lines changed

Lib/test/test_patma.py

Lines changed: 219 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import dis
55
import enum
66
import inspect
7-
from re import I
87
import sys
98
import unittest
109
from test import support
@@ -16,6 +15,13 @@ class Point:
1615
y: int
1716

1817

18+
@dataclasses.dataclass
19+
class Point3D:
20+
x: int
21+
y: int
22+
z: int
23+
24+
1925
class TestCompiler(unittest.TestCase):
2026

2127
def test_refleaks(self):
@@ -2891,11 +2897,81 @@ class B(A): ...
28912897

28922898
def test_patma_union_type(self):
28932899
IntOrStr = int | str
2894-
x = 0
2895-
match x:
2900+
w = None
2901+
match 0:
28962902
case IntOrStr():
2897-
x = 1
2898-
self.assertEqual(x, 1)
2903+
w = 0
2904+
self.assertEqual(w, 0)
2905+
2906+
def test_patma_union_no_match(self):
2907+
StrOrBytes = str | bytes
2908+
w = None
2909+
match 0:
2910+
case StrOrBytes():
2911+
w = 0
2912+
self.assertIsNone(w)
2913+
2914+
def test_union_type_positional_subpattern(self):
2915+
IntOrStr = int | str
2916+
w = None
2917+
match 0:
2918+
case IntOrStr(y):
2919+
w = y
2920+
self.assertEqual(w, 0)
2921+
2922+
def test_union_type_keyword_subpattern(self):
2923+
EitherPoint = Point | Point3D
2924+
p = Point(x=1, y=2)
2925+
w = None
2926+
match p:
2927+
case EitherPoint(x=1, y=2):
2928+
w = 0
2929+
self.assertEqual(w, 0)
2930+
2931+
def test_patma_union_arg(self):
2932+
p = Point(x=1, y=2)
2933+
IntOrStr = int | str
2934+
w = None
2935+
match p:
2936+
case Point(IntOrStr(), IntOrStr()):
2937+
w = 0
2938+
self.assertEqual(w, 0)
2939+
2940+
def test_patma_union_kwarg(self):
2941+
p = Point(x=1, y=2)
2942+
IntOrStr = int | str
2943+
w = None
2944+
match p:
2945+
case Point(x=IntOrStr(), y=IntOrStr()):
2946+
w = 0
2947+
self.assertEqual(w, 0)
2948+
2949+
def test_patma_union_arg_no_match(self):
2950+
p = Point(x=1, y=2)
2951+
StrOrBytes = str | bytes
2952+
w = None
2953+
match p:
2954+
case Point(StrOrBytes(), StrOrBytes()):
2955+
w = 0
2956+
self.assertIsNone(w)
2957+
2958+
def test_patma_union_kwarg_no_match(self):
2959+
p = Point(x=1, y=2)
2960+
StrOrBytes = str | bytes
2961+
w = None
2962+
match p:
2963+
case Point(x=StrOrBytes(), y=StrOrBytes()):
2964+
w = 0
2965+
self.assertIsNone(w)
2966+
2967+
def test_union_type_match_second_member(self):
2968+
EitherPoint = Point | Point3D
2969+
p = Point3D(x=1, y=2, z=3)
2970+
w = None
2971+
match p:
2972+
case EitherPoint(x=1, y=2, z=3):
2973+
w = 0
2974+
self.assertEqual(w, 0)
28992975

29002976

29012977
class TestSyntaxErrors(unittest.TestCase):
@@ -3239,8 +3315,28 @@ def test_mapping_pattern_duplicate_key_edge_case3(self):
32393315
pass
32403316
""")
32413317

3318+
32423319
class TestTypeErrors(unittest.TestCase):
32433320

3321+
def test_generic_type(self):
3322+
t = list[str]
3323+
w = None
3324+
with self.assertRaises(TypeError):
3325+
match ["s"]:
3326+
case t():
3327+
w = 0
3328+
self.assertIsNone(w)
3329+
3330+
def test_legacy_generic_type(self):
3331+
from typing import List
3332+
t = List[str]
3333+
w = None
3334+
with self.assertRaises(TypeError):
3335+
match ["s"]:
3336+
case t():
3337+
w = 0
3338+
self.assertIsNone(w)
3339+
32443340
def test_accepts_positional_subpatterns_0(self):
32453341
class Class:
32463342
__match_args__ = ()
@@ -3350,6 +3446,124 @@ def test_class_pattern_not_type(self):
33503446
w = 0
33513447
self.assertIsNone(w)
33523448

3449+
def test_class_or_union_not_specialform(self):
3450+
from typing import Literal
3451+
name = type(Literal).__name__
3452+
msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)"
3453+
w = None
3454+
with self.assertRaisesRegex(TypeError, msg):
3455+
match 1:
3456+
case Literal():
3457+
w = 0
3458+
self.assertIsNone(w)
3459+
3460+
def test_legacy_union_type(self):
3461+
from typing import Union
3462+
IntOrStr = Union[int, str]
3463+
name = type(IntOrStr).__name__
3464+
msg = rf"called match pattern must be a class or types.UnionType \(got {name}\)"
3465+
w = None
3466+
with self.assertRaisesRegex(TypeError, msg):
3467+
match 1:
3468+
case IntOrStr():
3469+
w = 0
3470+
self.assertIsNone(w)
3471+
3472+
def test_expanded_union_mirrors_isinstance_success(self):
3473+
ListOfInt = list[int]
3474+
t = int | ListOfInt
3475+
try: # get the isinstance result
3476+
reference = isinstance(1, t)
3477+
except TypeError as exc:
3478+
reference = exc
3479+
3480+
try: # get the match-case result
3481+
match 1:
3482+
case int() | ListOfInt():
3483+
result = True
3484+
case _:
3485+
result = False
3486+
except TypeError as exc:
3487+
result = exc
3488+
3489+
# we should ge the same result
3490+
self.assertIs(result, True)
3491+
self.assertIs(reference, True)
3492+
3493+
def test_expanded_union_mirrors_isinstance_failure(self):
3494+
ListOfInt = list[int]
3495+
t = ListOfInt | int
3496+
3497+
try: # get the isinstance result
3498+
reference = isinstance(1, t)
3499+
except TypeError as exc:
3500+
reference = exc
3501+
3502+
try: # get the match-case result
3503+
match 1:
3504+
case ListOfInt() | int():
3505+
result = True
3506+
case _:
3507+
result = False
3508+
except TypeError as exc:
3509+
result = exc
3510+
3511+
# we should ge the same result
3512+
self.assertIsInstance(result, TypeError)
3513+
self.assertIsInstance(reference, TypeError)
3514+
3515+
def test_union_type_mirrors_isinstance_success(self):
3516+
t = int | list[int]
3517+
3518+
try: # get the isinstance result
3519+
reference = isinstance(1, t)
3520+
except TypeError as exc:
3521+
reference = exc
3522+
3523+
try: # get the match-case result
3524+
match 1:
3525+
case t():
3526+
result = True
3527+
case _:
3528+
result = False
3529+
except TypeError as exc:
3530+
result = exc
3531+
3532+
# we should ge the same result
3533+
self.assertIs(result, True)
3534+
self.assertIs(reference, True)
3535+
3536+
def test_union_type_mirrors_isinstance_failure(self):
3537+
t = list[int] | int
3538+
3539+
try: # get the isinstance result
3540+
reference = isinstance(1, t)
3541+
except TypeError as exc:
3542+
reference = exc
3543+
3544+
try: # get the match-case result
3545+
match 1:
3546+
case t():
3547+
result = True
3548+
case _:
3549+
result = False
3550+
except TypeError as exc:
3551+
result = exc
3552+
3553+
# we should ge the same result
3554+
self.assertIsInstance(result, TypeError)
3555+
self.assertIsInstance(reference, TypeError)
3556+
3557+
def test_generic_union_type(self):
3558+
from collections.abc import Sequence, Set
3559+
t = Sequence[str] | Set[str]
3560+
w = None
3561+
with self.assertRaises(TypeError):
3562+
match ["s"]:
3563+
case t():
3564+
w = 0
3565+
self.assertIsNone(w)
3566+
33533567
def test_regular_protocol(self):
33543568
from typing import Protocol
33553569
class P(Protocol): ...
@@ -3379,31 +3593,6 @@ class A:
33793593
w = 0
33803594
self.assertIsNone(w)
33813595

3382-
def test_union_type_postional_subpattern(self):
3383-
IntOrStr = int | str
3384-
x = 1
3385-
w = None
3386-
with self.assertRaises(TypeError):
3387-
match x:
3388-
case IntOrStr(x):
3389-
w = 0
3390-
self.assertEqual(x, 1)
3391-
self.assertIsNone(w)
3392-
3393-
def test_union_type_keyword_subpattern(self):
3394-
@dataclasses.dataclass
3395-
class Point2:
3396-
x: int
3397-
y: int
3398-
EitherPoint = Point | Point2
3399-
x = Point(x=1, y=2)
3400-
w = None
3401-
with self.assertRaises(TypeError):
3402-
match x:
3403-
case EitherPoint(x=1, y=2):
3404-
w = 0
3405-
self.assertIsNone(w)
3406-
34073596

34083597
class TestValueErrors(unittest.TestCase):
34093598

Python/ceval.c

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -726,26 +726,34 @@ PyObject*
726726
_PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type,
727727
Py_ssize_t nargs, PyObject *kwargs)
728728
{
729-
if (!PyType_Check(type) && !_PyUnion_Check(type)) {
730-
const char *e = "called match pattern must be a class or a union";
731-
_PyErr_Format(tstate, PyExc_TypeError, e);
729+
// Recurse on unions.
730+
if (_PyUnion_Check(type)) {
731+
// get union members
732+
PyObject *members = _Py_union_args(type);
733+
const Py_ssize_t n = PyTuple_GET_SIZE(members);
734+
735+
// iterate over union members and return first match
736+
for (Py_ssize_t i = 0; i < n; i++) {
737+
PyObject *member = PyTuple_GET_ITEM(members, i);
738+
PyObject *attrs = _PyEval_MatchClass(tstate, subject, member, nargs, kwargs);
739+
// match found
740+
if (attrs != NULL) {
741+
return attrs;
742+
}
743+
}
744+
// no match found
745+
return NULL;
746+
}
747+
if (!PyType_Check(type)) {
748+
const char *e = "called match pattern must be a class or types.UnionType (got %s)";
749+
_PyErr_Format(tstate, PyExc_TypeError, e, Py_TYPE(type)->tp_name);
732750
return NULL;
733751
}
734752
assert(PyTuple_CheckExact(kwargs));
735753
// First, an isinstance check:
736754
if (PyObject_IsInstance(subject, type) <= 0) {
737755
return NULL;
738756
}
739-
// Subpatterns are not supported for union types:
740-
if (_PyUnion_Check(type)) {
741-
// Return error if any positional or keyword arguments are given:
742-
if (nargs || PyTuple_GET_SIZE(kwargs)) {
743-
const char *e = "union types do not support sub-patterns";
744-
_PyErr_Format(tstate, PyExc_TypeError, e);
745-
return NULL;
746-
}
747-
return PyTuple_New(0);
748-
}
749757
// So far so good:
750758
PyObject *seen = PySet_New(NULL);
751759
if (seen == NULL) {

0 commit comments

Comments
 (0)