Skip to content

Commit 4e7e015

Browse files
committed
gh-142829: Fix use-after-free in Context.__eq__ via re-entrant ContextVar.set
1 parent 25397f9 commit 4e7e015

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

Lib/test/test_context.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,20 @@ def fun():
556556

557557
ctx.run(fun)
558558

559+
def test_context_eq_reentrant_contextvar_set(self):
560+
var = contextvars.ContextVar("v")
561+
ctx1 = contextvars.Context()
562+
ctx2 = contextvars.Context()
563+
564+
class ReentrantEq:
565+
def __eq__(self, other):
566+
ctx1.run(lambda: var.set(object()))
567+
return True
568+
569+
ctx1.run(var.set, ReentrantEq())
570+
ctx2.run(var.set, object())
571+
ctx1 == ctx2
572+
559573

560574
# HAMT Tests
561575

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fix a use-after-free crash in :class:`contextvars.Context` comparison when a
2+
custom ``__eq__`` method modifies the context via
3+
:meth:`~contextvars.ContextVar.set`.

Python/hamt.c

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2328,6 +2328,10 @@ _PyHamt_Eq(PyHamtObject *v, PyHamtObject *w)
23282328
return 0;
23292329
}
23302330

2331+
Py_INCREF(v);
2332+
Py_INCREF(w);
2333+
2334+
int res = 1;
23312335
PyHamtIteratorState iter;
23322336
hamt_iter_t iter_res;
23332337
hamt_find_t find_res;
@@ -2343,25 +2347,38 @@ _PyHamt_Eq(PyHamtObject *v, PyHamtObject *w)
23432347
find_res = hamt_find(w, v_key, &w_val);
23442348
switch (find_res) {
23452349
case F_ERROR:
2346-
return -1;
2350+
res = -1;
2351+
goto done;
23472352

23482353
case F_NOT_FOUND:
2349-
return 0;
2354+
res = 0;
2355+
goto done;
23502356

23512357
case F_FOUND: {
2358+
Py_INCREF(v_key);
2359+
Py_INCREF(v_val);
2360+
Py_INCREF(w_val);
23522361
int cmp = PyObject_RichCompareBool(v_val, w_val, Py_EQ);
2362+
Py_DECREF(v_key);
2363+
Py_DECREF(v_val);
2364+
Py_DECREF(w_val);
23532365
if (cmp < 0) {
2354-
return -1;
2366+
res = -1;
2367+
goto done;
23552368
}
23562369
if (cmp == 0) {
2357-
return 0;
2370+
res = 0;
2371+
goto done;
23582372
}
23592373
}
23602374
}
23612375
}
23622376
} while (iter_res != I_END);
23632377

2364-
return 1;
2378+
done:
2379+
Py_DECREF(v);
2380+
Py_DECREF(w);
2381+
return res;
23652382
}
23662383

23672384
Py_ssize_t

0 commit comments

Comments
 (0)