Skip to content

Commit c582ff3

Browse files
authored
gh-141510: Fix frozendict.fromkeys() for subclasses (#144952)
Copy the frozendict if needed.
1 parent 1ddb412 commit c582ff3

File tree

2 files changed

+61
-2
lines changed

2 files changed

+61
-2
lines changed

Lib/test/test_dict.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,6 +1787,34 @@ def test_hash(self):
17871787
with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"):
17881788
hash(fd)
17891789

1790+
def test_fromkeys(self):
1791+
self.assertEqual(frozendict.fromkeys('abc'),
1792+
frozendict(a=None, b=None, c=None))
1793+
1794+
# Subclass which overrides the constructor
1795+
created = frozendict(x=1)
1796+
class FrozenDictSubclass(frozendict):
1797+
def __new__(self):
1798+
return created
1799+
1800+
fd = FrozenDictSubclass.fromkeys("abc")
1801+
self.assertEqual(fd, frozendict(x=1, a=None, b=None, c=None))
1802+
self.assertEqual(type(fd), FrozenDictSubclass)
1803+
self.assertEqual(created, frozendict(x=1))
1804+
1805+
fd = FrozenDictSubclass.fromkeys(frozendict(y=2))
1806+
self.assertEqual(fd, frozendict(x=1, y=None))
1807+
self.assertEqual(type(fd), FrozenDictSubclass)
1808+
self.assertEqual(created, frozendict(x=1))
1809+
1810+
# Subclass which doesn't override the constructor
1811+
class FrozenDictSubclass2(frozendict):
1812+
pass
1813+
1814+
fd = FrozenDictSubclass2.fromkeys("abc")
1815+
self.assertEqual(fd, frozendict(a=None, b=None, c=None))
1816+
self.assertEqual(type(fd), FrozenDictSubclass2)
1817+
17901818

17911819
if __name__ == "__main__":
17921820
unittest.main()

Objects/dictobject.c

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ As a consequence of this, split keys have a maximum size of 16.
138138
// Forward declarations
139139
static PyObject* frozendict_new(PyTypeObject *type, PyObject *args,
140140
PyObject *kwds);
141+
static int dict_merge(PyObject *a, PyObject *b, int override);
141142

142143

143144
/*[clinic input]
@@ -294,6 +295,8 @@ can_modify_dict(PyDictObject *mp)
294295
return PyUnstable_Object_IsUniquelyReferenced(_PyObject_CAST(mp));
295296
}
296297
else {
298+
// Locking is only required if the dictionary is not
299+
// uniquely referenced.
297300
ASSERT_DICT_LOCKED(mp);
298301
return 1;
299302
}
@@ -3238,6 +3241,8 @@ _PyDict_Pop(PyObject *dict, PyObject *key, PyObject *default_value)
32383241
static PyDictObject *
32393242
dict_dict_fromkeys(PyDictObject *mp, PyObject *iterable, PyObject *value)
32403243
{
3244+
assert(can_modify_dict(mp));
3245+
32413246
PyObject *oldvalue;
32423247
Py_ssize_t pos = 0;
32433248
PyObject *key;
@@ -3263,6 +3268,8 @@ dict_dict_fromkeys(PyDictObject *mp, PyObject *iterable, PyObject *value)
32633268
static PyDictObject *
32643269
dict_set_fromkeys(PyDictObject *mp, PyObject *iterable, PyObject *value)
32653270
{
3271+
assert(can_modify_dict(mp));
3272+
32663273
Py_ssize_t pos = 0;
32673274
PyObject *key;
32683275
Py_hash_t hash;
@@ -3294,9 +3301,31 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value)
32943301
int status;
32953302

32963303
d = _PyObject_CallNoArgs(cls);
3297-
if (d == NULL)
3304+
if (d == NULL) {
32983305
return NULL;
3306+
}
32993307

3308+
// If cls is a frozendict subclass with overridden constructor,
3309+
// copy the frozendict.
3310+
PyTypeObject *cls_type = _PyType_CAST(cls);
3311+
if (PyFrozenDict_Check(d)
3312+
&& PyObject_IsSubclass(cls, (PyObject*)&PyFrozenDict_Type)
3313+
&& cls_type->tp_new != frozendict_new)
3314+
{
3315+
// Subclass-friendly copy
3316+
PyObject *copy = frozendict_new(cls_type, NULL, NULL);
3317+
if (copy == NULL) {
3318+
Py_DECREF(d);
3319+
return NULL;
3320+
}
3321+
if (dict_merge(copy, d, 1) < 0) {
3322+
Py_DECREF(d);
3323+
Py_DECREF(copy);
3324+
return NULL;
3325+
}
3326+
Py_SETREF(d, copy);
3327+
}
3328+
assert(!PyFrozenDict_Check(d) || can_modify_dict((PyDictObject*)d));
33003329

33013330
if (PyDict_CheckExact(d)) {
33023331
if (PyDict_CheckExact(iterable)) {
@@ -3367,7 +3396,7 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value)
33673396
dict_iter_exit:;
33683397
Py_END_CRITICAL_SECTION();
33693398
}
3370-
else if (PyFrozenDict_CheckExact(d)) {
3399+
else if (PyFrozenDict_Check(d)) {
33713400
while ((key = PyIter_Next(it)) != NULL) {
33723401
// setitem_take2_lock_held consumes a reference to key
33733402
status = setitem_take2_lock_held((PyDictObject *)d,
@@ -8002,6 +8031,8 @@ frozendict_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
80028031
if (d == NULL) {
80038032
return NULL;
80048033
}
8034+
assert(can_modify_dict(_PyAnyDict_CAST(d)));
8035+
80058036
PyFrozenDictObject *self = _PyFrozenDictObject_CAST(d);
80068037
self->ma_hash = -1;
80078038

0 commit comments

Comments
 (0)