Skip to content

Commit b6a9784

Browse files
committed
Allow the use of unions as match patterns
1 parent 2770d5c commit b6a9784

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

Lib/test/test_patma.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import dataclasses
44
import enum
55
import inspect
6+
from re import I
67
import sys
78
import unittest
89

@@ -2886,6 +2887,14 @@ class B(A): ...
28862887
h = 1
28872888
self.assertEqual(h, 1)
28882889

2890+
def test_patma_union_type(self):
2891+
IntOrStr = int | str
2892+
x = 0
2893+
match x:
2894+
case IntOrStr():
2895+
x = 1
2896+
self.assertEqual(x, 1)
2897+
28892898

28902899
class TestSyntaxErrors(unittest.TestCase):
28912900

@@ -3361,6 +3370,31 @@ class A:
33613370
w = 0
33623371
self.assertIsNone(w)
33633372

3373+
def test_union_type_postional_subpattern(self):
3374+
IntOrStr = int | str
3375+
x = 1
3376+
w = None
3377+
with self.assertRaises(TypeError):
3378+
match x:
3379+
case IntOrStr(x):
3380+
w = 0
3381+
self.assertEqual(x, 1)
3382+
self.assertIsNone(w)
3383+
3384+
def test_union_type_keyword_subpattern(self):
3385+
@dataclasses.dataclass
3386+
class Point2:
3387+
x: int
3388+
y: int
3389+
EitherPoint = Point | Point2
3390+
x = Point(x=1, y=2)
3391+
w = None
3392+
with self.assertRaises(TypeError):
3393+
match x:
3394+
case EitherPoint(x=1, y=2):
3395+
w = 0
3396+
self.assertIsNone(w)
3397+
33643398

33653399
class TestValueErrors(unittest.TestCase):
33663400

Python/ceval.c

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "pycore_sysmodule.h" // _PySys_Audit()
3030
#include "pycore_tuple.h" // _PyTuple_ITEMS()
3131
#include "pycore_typeobject.h" // _PySuper_Lookup()
32+
#include "pycore_unionobject.h" // _PyUnion_Check()
3233
#include "pycore_uop_ids.h" // Uops
3334
#include "pycore_pyerrors.h"
3435

@@ -460,8 +461,8 @@ PyObject*
460461
_PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type,
461462
Py_ssize_t nargs, PyObject *kwargs)
462463
{
463-
if (!PyType_Check(type)) {
464-
const char *e = "called match pattern must be a class";
464+
if (!PyType_Check(type) && !_PyUnion_Check(type)) {
465+
const char *e = "called match pattern must be a class or a union";
465466
_PyErr_Format(tstate, PyExc_TypeError, e);
466467
return NULL;
467468
}
@@ -470,6 +471,16 @@ _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type,
470471
if (PyObject_IsInstance(subject, type) <= 0) {
471472
return NULL;
472473
}
474+
// Subpatterns are not supported for union types:
475+
if (_PyUnion_Check(type)) {
476+
// Return error if any positional or keyword arguments are given:
477+
if (nargs || PyTuple_GET_SIZE(kwargs)) {
478+
const char *e = "union types do not support sub-patterns";
479+
_PyErr_Format(tstate, PyExc_TypeError, e);
480+
return NULL;
481+
}
482+
return PyTuple_New(0);
483+
}
473484
// So far so good:
474485
PyObject *seen = PySet_New(NULL);
475486
if (seen == NULL) {

0 commit comments

Comments
 (0)