Skip to content

Commit d279720

Browse files
committed
Use a per-thread cache of thread states instead of creating a new thread state for each access.
1 parent 04e3778 commit d279720

File tree

4 files changed

+90
-6
lines changed

4 files changed

+90
-6
lines changed

Include/internal/pycore_pystate.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ extern PyThreadState * _PyThreadState_RemoveExcept(PyThreadState *tstate);
231231
extern void _PyThreadState_DeleteList(PyThreadState *list, int is_after_fork);
232232
extern void _PyThreadState_ClearMimallocHeaps(PyThreadState *tstate);
233233

234+
// Export for '_interpreters' shared extension
235+
PyAPI_FUNC(PyThreadState *) _PyThreadState_NewForExec(PyInterpreterState *interp);
236+
234237
// Export for '_testinternalcapi' shared extension
235238
PyAPI_FUNC(PyObject*) _PyThreadState_GetDict(PyThreadState *tstate);
236239

Lib/test/test_interpreters/test_object_proxy.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import_helper.import_module("_interpreters")
88
from concurrent.interpreters import share, SharedObjectProxy
99
from test.test_interpreters.utils import TestBase
10-
from threading import Barrier, Thread, Lock
10+
from threading import Barrier, Thread, Lock, local
1111
from concurrent import interpreters
1212
from contextlib import contextmanager
1313

@@ -242,6 +242,20 @@ def thread(interp):
242242

243243
self.run_concurrently(thread, proxy=proxy)
244244

245+
def test_retain_thread_local_variables(self):
246+
thread_local = local()
247+
thread_local.value = 42
248+
249+
def test():
250+
old = thread_local.value
251+
thread_local.value = 24
252+
return old
253+
254+
proxy = share(test)
255+
with self.create_interp(proxy=proxy) as interp:
256+
interp.exec("assert proxy() == 42")
257+
self.assertEqual(thread_local.value, 24)
258+
245259

246260
if __name__ == "__main__":
247261
unittest.main()

Modules/_interpretersmodule.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -414,8 +414,7 @@ _sharedobjectproxy_enter(SharedObjectProxy *self, _PyXI_proxy_state *state)
414414
return 0;
415415
}
416416
state->to_restore = tstate;
417-
PyThreadState *for_call = _PyThreadState_NewBound(self->interp,
418-
_PyThreadState_WHENCE_EXEC);
417+
PyThreadState *for_call = _PyThreadState_NewForExec(self->interp);
419418
state->for_call = for_call;
420419
if (for_call == NULL) {
421420
PyErr_NoMemory();
@@ -446,9 +445,10 @@ _sharedobjectproxy_exit(SharedObjectProxy *self, _PyXI_proxy_state *state)
446445
}
447446

448447
assert(state->for_call == _PyThreadState_GET());
449-
PyThreadState_Clear(state->for_call);
450448
PyThreadState_Swap(state->to_restore);
451-
PyThreadState_Delete(state->for_call);
449+
// Whether the thread state was newly created or reused, we don't delete it here.
450+
// It's likely that it will be used again, but if not, the interpreter
451+
// will clean it up at the end anyway.
452452

453453
if (should_throw) {
454454
_PyErr_SetString(state->to_restore, PyExc_RuntimeError, "exception in interpreter");

Python/pystate.c

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,56 @@ _Py_thread_local PyThreadState *_Py_tss_gilstate = NULL;
8080
and is the same as tstate->interp. */
8181
_Py_thread_local PyInterpreterState *_Py_tss_interp = NULL;
8282

83+
/* The last thread state used for each interpreter by this thread. */
84+
_Py_thread_local _Py_hashtable_t *_Py_tss_tstate_map = NULL;
85+
86+
// TODO: Let's add a way to use _Py_hashtable_t statically to avoid the
87+
// extra heap allocation.
88+
89+
static void
90+
mark_thread_state_used(PyThreadState *tstate)
91+
{
92+
assert(tstate != NULL);
93+
if (_Py_tss_tstate_map == NULL) {
94+
_Py_hashtable_allocator_t alloc = {
95+
.malloc = PyMem_RawMalloc,
96+
.free = PyMem_RawFree
97+
};
98+
_Py_tss_tstate_map = _Py_hashtable_new_full(_Py_hashtable_hash_ptr,
99+
_Py_hashtable_compare_direct,
100+
NULL, NULL, &alloc);
101+
if (_Py_tss_tstate_map == NULL) {
102+
return;
103+
}
104+
}
105+
106+
(void)_Py_hashtable_steal(_Py_tss_tstate_map, tstate->interp);
107+
(void)_Py_hashtable_set(_Py_tss_tstate_map, tstate->interp, tstate);
108+
}
109+
110+
static PyThreadState *
111+
last_thread_state_for_interp(PyInterpreterState *interp)
112+
{
113+
assert(interp != NULL);
114+
if (_Py_tss_tstate_map == NULL) {
115+
return NULL;
116+
}
117+
118+
return _Py_hashtable_get(_Py_tss_tstate_map, interp);
119+
}
120+
121+
static void
122+
mark_thread_state_dead(PyThreadState *tstate)
123+
{
124+
if (_Py_tss_tstate_map == NULL) {
125+
return;
126+
}
127+
128+
if (tstate == _Py_hashtable_get(_Py_tss_tstate_map, tstate->interp)) {
129+
(void)_Py_hashtable_steal(_Py_tss_tstate_map, tstate->interp);
130+
}
131+
}
132+
83133
static inline PyThreadState *
84134
current_fast_get(void)
85135
{
@@ -1603,6 +1653,21 @@ _PyThreadState_NewBound(PyInterpreterState *interp, int whence)
16031653
return tstate;
16041654
}
16051655

1656+
/* Get the last thread state used for this interpreter, or create a new
1657+
* one if none exists.
1658+
* The returned thread state may or may not be attached. */
1659+
PyThreadState *
1660+
_PyThreadState_NewForExec(PyInterpreterState *interp)
1661+
{
1662+
assert(interp != NULL);
1663+
PyThreadState *cached = last_thread_state_for_interp(interp);
1664+
if (cached != NULL) {
1665+
return cached;
1666+
}
1667+
1668+
return _PyThreadState_NewBound(interp, _PyThreadState_WHENCE_EXEC);
1669+
}
1670+
16061671
// This must be followed by a call to _PyThreadState_Bind();
16071672
PyThreadState *
16081673
_PyThreadState_New(PyInterpreterState *interp, int whence)
@@ -1649,6 +1714,7 @@ PyThreadState_Clear(PyThreadState *tstate)
16491714
// disabled.
16501715
// XXX assert(!_PyThreadState_IsRunningMain(tstate));
16511716
// XXX assert(!tstate->_status.bound || tstate->_status.unbound);
1717+
mark_thread_state_dead(tstate);
16521718
tstate->_status.finalizing = 1; // just in case
16531719

16541720
/* XXX Conditions we need to enforce:
@@ -1961,7 +2027,6 @@ _PyThreadState_DeleteList(PyThreadState *list, int is_after_fork)
19612027
}
19622028
}
19632029

1964-
19652030
//----------
19662031
// accessors
19672032
//----------
@@ -2168,6 +2233,8 @@ _PyThreadState_Attach(PyThreadState *tstate)
21682233
#if defined(Py_DEBUG)
21692234
errno = err;
21702235
#endif
2236+
2237+
mark_thread_state_used(tstate);
21712238
}
21722239

21732240
static void

0 commit comments

Comments
 (0)