Skip to content

Commit cb2b850

Browse files
tmke8randolf-scholz
authored andcommitted
Allow the use of unions as match patterns
1 parent 958657b commit cb2b850

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

Lib/test/test_patma.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import dis
55
import enum
66
import inspect
7+
from re import I
78
import sys
89
import unittest
910
from test import support
@@ -2888,6 +2889,14 @@ class B(A): ...
28882889
h = 1
28892890
self.assertEqual(h, 1)
28902891

2892+
def test_patma_union_type(self):
2893+
IntOrStr = int | str
2894+
x = 0
2895+
match x:
2896+
case IntOrStr():
2897+
x = 1
2898+
self.assertEqual(x, 1)
2899+
28912900

28922901
class TestSyntaxErrors(unittest.TestCase):
28932902

@@ -3370,6 +3379,31 @@ class A:
33703379
w = 0
33713380
self.assertIsNone(w)
33723381

3382+
def test_union_type_postional_subpattern(self):
3383+
IntOrStr = int | str
3384+
x = 1
3385+
w = None
3386+
with self.assertRaises(TypeError):
3387+
match x:
3388+
case IntOrStr(x):
3389+
w = 0
3390+
self.assertEqual(x, 1)
3391+
self.assertIsNone(w)
3392+
3393+
def test_union_type_keyword_subpattern(self):
3394+
@dataclasses.dataclass
3395+
class Point2:
3396+
x: int
3397+
y: int
3398+
EitherPoint = Point | Point2
3399+
x = Point(x=1, y=2)
3400+
w = None
3401+
with self.assertRaises(TypeError):
3402+
match x:
3403+
case EitherPoint(x=1, y=2):
3404+
w = 0
3405+
self.assertIsNone(w)
3406+
33733407

33743408
class TestValueErrors(unittest.TestCase):
33753409

Python/ceval.c

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "pycore_template.h" // _PyTemplate_Build()
4040
#include "pycore_traceback.h" // _PyTraceBack_FromFrame
4141
#include "pycore_tuple.h" // _PyTuple_ITEMS()
42+
#include "pycore_unionobject.h" // _PyUnion_Check()
4243
#include "pycore_uop_ids.h" // Uops
4344

4445
#include "dictobject.h"
@@ -725,8 +726,8 @@ PyObject*
725726
_PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type,
726727
Py_ssize_t nargs, PyObject *kwargs)
727728
{
728-
if (!PyType_Check(type)) {
729-
const char *e = "called match pattern must be a class";
729+
if (!PyType_Check(type) && !_PyUnion_Check(type)) {
730+
const char *e = "called match pattern must be a class or a union";
730731
_PyErr_Format(tstate, PyExc_TypeError, e);
731732
return NULL;
732733
}
@@ -735,6 +736,16 @@ _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type,
735736
if (PyObject_IsInstance(subject, type) <= 0) {
736737
return NULL;
737738
}
739+
// Subpatterns are not supported for union types:
740+
if (_PyUnion_Check(type)) {
741+
// Return error if any positional or keyword arguments are given:
742+
if (nargs || PyTuple_GET_SIZE(kwargs)) {
743+
const char *e = "union types do not support sub-patterns";
744+
_PyErr_Format(tstate, PyExc_TypeError, e);
745+
return NULL;
746+
}
747+
return PyTuple_New(0);
748+
}
738749
// So far so good:
739750
PyObject *seen = PySet_New(NULL);
740751
if (seen == NULL) {

0 commit comments

Comments
 (0)