Skip to content

Commit 0a7ceb5

Browse files
added subpattern support
1 parent cb2b850 commit 0a7ceb5

File tree

2 files changed

+82
-39
lines changed

2 files changed

+82
-39
lines changed

Lib/test/test_patma.py

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import dis
55
import enum
66
import inspect
7-
from re import I
87
import sys
98
import unittest
109
from test import support
@@ -16,6 +15,13 @@ class Point:
1615
y: int
1716

1817

18+
@dataclasses.dataclass
19+
class Point3D:
20+
x: int
21+
y: int
22+
z: int
23+
24+
1925
class TestCompiler(unittest.TestCase):
2026

2127
def test_refleaks(self):
@@ -2897,6 +2903,60 @@ def test_patma_union_type(self):
28972903
x = 1
28982904
self.assertEqual(x, 1)
28992905

2906+
def test_union_type_positional_subpattern(self):
2907+
IntOrStr = int | str
2908+
x = 1
2909+
w = None
2910+
match x:
2911+
case IntOrStr(y):
2912+
w = y
2913+
self.assertEqual(w, 1)
2914+
2915+
def test_union_type_keyword_subpattern(self):
2916+
EitherPoint = Point | Point3D
2917+
p = Point(x=1, y=2)
2918+
w = None
2919+
match p:
2920+
case EitherPoint(x=1, y=2):
2921+
w = 1
2922+
self.assertEqual(w, 1)
2923+
2924+
def test_patma_union_no_match(self):
2925+
IntOrStr = int | str
2926+
x = None
2927+
match x:
2928+
case IntOrStr():
2929+
x = 1
2930+
self.assertIsNone(x)
2931+
2932+
def test_patma_union_arg(self):
2933+
p = Point(x=1, y=2)
2934+
IntOrStr = int | str
2935+
w = None
2936+
match p:
2937+
case Point(IntOrStr(), IntOrStr()):
2938+
w = 1
2939+
self.assertEqual(w, 1)
2940+
2941+
def test_patma_union_kwarg(self):
2942+
p = Point(x=1, y=2)
2943+
IntOrStr = int | str
2944+
w = None
2945+
match p:
2946+
case Point(x=IntOrStr(), y=IntOrStr()):
2947+
w = 1
2948+
self.assertEqual(w, 1)
2949+
2950+
def test_union_type_match_second_member(self):
2951+
EitherPoint = Point | Point3D
2952+
p = Point3D(x=1, y=2, z=3)
2953+
w = None
2954+
match p:
2955+
case EitherPoint(x=1, y=2, z=3):
2956+
w = 1
2957+
self.assertEqual(w, 1)
2958+
2959+
29002960

29012961
class TestSyntaxErrors(unittest.TestCase):
29022962

@@ -3379,31 +3439,6 @@ class A:
33793439
w = 0
33803440
self.assertIsNone(w)
33813441

3382-
def test_union_type_postional_subpattern(self):
3383-
IntOrStr = int | str
3384-
x = 1
3385-
w = None
3386-
with self.assertRaises(TypeError):
3387-
match x:
3388-
case IntOrStr(x):
3389-
w = 0
3390-
self.assertEqual(x, 1)
3391-
self.assertIsNone(w)
3392-
3393-
def test_union_type_keyword_subpattern(self):
3394-
@dataclasses.dataclass
3395-
class Point2:
3396-
x: int
3397-
y: int
3398-
EitherPoint = Point | Point2
3399-
x = Point(x=1, y=2)
3400-
w = None
3401-
with self.assertRaises(TypeError):
3402-
match x:
3403-
case EitherPoint(x=1, y=2):
3404-
w = 0
3405-
self.assertIsNone(w)
3406-
34073442

34083443
class TestValueErrors(unittest.TestCase):
34093444

Python/ceval.c

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -726,26 +726,34 @@ PyObject*
726726
_PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type,
727727
Py_ssize_t nargs, PyObject *kwargs)
728728
{
729-
if (!PyType_Check(type) && !_PyUnion_Check(type)) {
730-
const char *e = "called match pattern must be a class or a union";
731-
_PyErr_Format(tstate, PyExc_TypeError, e);
729+
// Recurse on unions.
730+
if (_PyUnion_Check(type)) {
731+
// get union members
732+
PyObject *members = _Py_union_args(type);
733+
Py_ssize_t n = PyTuple_GET_SIZE(members);
734+
735+
// iterate over union members and return first match
736+
for (Py_ssize_t i = 0; i < n; i++) {
737+
PyObject *member = PyTuple_GET_ITEM(members, i);
738+
PyObject *attrs = _PyEval_MatchClass(tstate, subject, member, nargs, kwargs);
739+
// match found
740+
if (attrs != NULL) {
741+
return attrs;
742+
}
743+
}
744+
// no match found
745+
return NULL;
746+
}
747+
if (!PyType_Check(type)) {
748+
const char *e = "called match pattern must be a class or a union, (got %s, %s)";
749+
_PyErr_Format(tstate, PyExc_TypeError, e, Py_TYPE(type)->tp_name, subject->ob_type->tp_name);
732750
return NULL;
733751
}
734752
assert(PyTuple_CheckExact(kwargs));
735753
// First, an isinstance check:
736754
if (PyObject_IsInstance(subject, type) <= 0) {
737755
return NULL;
738756
}
739-
// Subpatterns are not supported for union types:
740-
if (_PyUnion_Check(type)) {
741-
// Return error if any positional or keyword arguments are given:
742-
if (nargs || PyTuple_GET_SIZE(kwargs)) {
743-
const char *e = "union types do not support sub-patterns";
744-
_PyErr_Format(tstate, PyExc_TypeError, e);
745-
return NULL;
746-
}
747-
return PyTuple_New(0);
748-
}
749757
// So far so good:
750758
PyObject *seen = PySet_New(NULL);
751759
if (seen == NULL) {

0 commit comments

Comments
 (0)