Skip to content

Commit 4dd0652

Browse files
committed
gh-142830: prevent crashes when replacing sqlite3 callbacks
1 parent c2f25b5 commit 4dd0652

File tree

3 files changed

+135
-7
lines changed

3 files changed

+135
-7
lines changed

Lib/test/test_sqlite3/test_dbapi.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,5 +2029,104 @@ def test_row_is_a_sequence(self):
20292029
self.assertIsInstance(row, Sequence)
20302030

20312031

2032+
class CallbackTests(unittest.TestCase):
2033+
2034+
def setUp(self):
2035+
super().setUp()
2036+
self.cx = sqlite.connect(":memory:")
2037+
self.addCleanup(self.cx.close)
2038+
self.cu = self.cx.cursor()
2039+
self.cu.execute("create table test(a number)")
2040+
2041+
class Handler:
2042+
cx = self.cx
2043+
2044+
self.handler_class = Handler
2045+
2046+
def assert_not_authorized(self, func, /, *args, **kwargs):
2047+
with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"):
2048+
func(*args, **kwargs)
2049+
2050+
def assert_interrupted(self, func, /, *args, **kwargs):
2051+
with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"):
2052+
func(*args, **kwargs)
2053+
2054+
def assert_invalid_trace(self, func, /, *args, **kwargs):
2055+
# Exception in trace callbacks are entirely suppressed.
2056+
pass
2057+
2058+
# When a handler has an invalid signature, the exception raised is
2059+
# the same that would be raised if the handler "negatively" replied.
2060+
2061+
def test_authorizer_invalid_signature(self):
2062+
self.cx.set_authorizer(lambda: None)
2063+
self.assert_not_authorized(self.cx.execute, "select * from test")
2064+
2065+
def test_progress_handler_invalid_signature(self):
2066+
self.cx.set_progress_handler(lambda x: None, 1)
2067+
self.assert_interrupted(self.cx.execute, "select * from test")
2068+
2069+
def test_trace_callback_invalid_signature_traceback(self):
2070+
self.cx.set_trace_callback(lambda: None)
2071+
self.assert_invalid_trace(self.cx.execute, "select * from test")
2072+
2073+
# Tests for checking that callback context mutations do not crash.
2074+
# Regression tests for https://github.com/python/cpython/issues/142830.
2075+
2076+
def test_authorizer_concurrent_mutation_in_call(self):
2077+
class Handler(self.handler_class):
2078+
def __call__(self, *a, **kw):
2079+
self.cx.set_authorizer(None)
2080+
raise ValueError
2081+
2082+
self.cx.set_authorizer(Handler())
2083+
self.assert_not_authorized(self.cx.execute, "select * from test")
2084+
2085+
def test_authorizer_concurrent_mutation_with_overflown_value(self):
2086+
_testcapi = import_helper.import_module("_testcapi")
2087+
2088+
class Handler(self.handler_class):
2089+
def __call__(self, *a, **kw):
2090+
self.cx.set_authorizer(None)
2091+
# We expect 'int' at the C level, so this one will raise
2092+
# when converting via PyLong_Int().
2093+
return _testcapi.INT_MAX + 1
2094+
2095+
self.cx.set_authorizer(Handler())
2096+
self.assert_not_authorized(self.cx.execute, "select * from test")
2097+
2098+
def test_progress_handler_concurrent_mutation_in_call(self):
2099+
class Handler(self.handler_class):
2100+
def __call__(self, *a, **kw):
2101+
self.cx.set_authorizer(None)
2102+
raise ValueError
2103+
2104+
self.cx.set_progress_handler(Handler(), 1)
2105+
self.assert_interrupted(self.cx.execute, "select * from test")
2106+
2107+
def test_progress_handler_concurrent_mutation_in_conversion(self):
2108+
class Handler(self.handler_class):
2109+
def __bool__(self):
2110+
# clear the progress handler
2111+
self.cx.set_progress_handler(None, 1)
2112+
raise ValueError # force PyObject_True() to fail
2113+
2114+
self.cx.set_progress_handler(Handler.__init__, 1)
2115+
self.assert_interrupted(self.cx.execute, "select * from test")
2116+
2117+
def test_trace_callback_concurrent_mutation_in_call(self):
2118+
class Handler:
2119+
def __call__(self, statement):
2120+
# clear the progress handler
2121+
self.cx.set_progress_handler(None, 1)
2122+
raise ValueError
2123+
2124+
self.cx.set_trace_callback(Handler())
2125+
self.assert_invalid_trace(self.cx.execute, "select * from test")
2126+
2127+
# TODO(picnixz): increase test coverage for other callbacks
2128+
# such as 'func', 'step', 'finalize', and 'collation'.
2129+
2130+
20322131
if __name__ == "__main__":
20332132
unittest.main()
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
:mod:`sqlite3`: fix use-after-free crashes when the connection's callbacks
2+
are mutated during a callback execution. Patch by Bénédikt Tran.

Modules/_sqlite/connection.c

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,9 @@ func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
914914
if (args) {
915915
pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
916916
assert(ctx != NULL);
917+
Py_INCREF(ctx);
917918
py_retval = PyObject_CallObject(ctx->callable, args);
919+
Py_DECREF(ctx);
918920
Py_DECREF(args);
919921
}
920922

@@ -942,6 +944,8 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
942944

943945
pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
944946
assert(ctx != NULL);
947+
// Hold a reference to 'ctx' to prevent concurrent mutations.
948+
Py_INCREF(ctx);
945949

946950
aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
947951
if (aggregate_instance == NULL) {
@@ -971,6 +975,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
971975
}
972976

973977
function_result = PyObject_CallObject(stepmethod, args);
978+
Py_CLEAR(ctx);
974979
Py_DECREF(args);
975980

976981
if (!function_result) {
@@ -979,6 +984,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
979984
}
980985

981986
error:
987+
Py_XDECREF(ctx);
982988
Py_XDECREF(stepmethod);
983989
Py_XDECREF(function_result);
984990

@@ -1011,8 +1017,10 @@ final_callback(sqlite3_context *context)
10111017

10121018
pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
10131019
assert(ctx != NULL);
1020+
Py_INCREF(ctx);
10141021
function_result = PyObject_CallMethodNoArgs(*aggregate_instance,
10151022
ctx->state->str_finalize);
1023+
Py_DECREF(ctx);
10161024
Py_DECREF(*aggregate_instance);
10171025

10181026
ok = 0;
@@ -1163,6 +1171,8 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
11631171

11641172
pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
11651173
assert(ctx != NULL);
1174+
// Hold a reference to 'ctx' to prevent concurrent mutations.
1175+
Py_INCREF(ctx);
11661176

11671177
int size = sizeof(PyObject *);
11681178
PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
@@ -1191,9 +1201,11 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
11911201
"user-defined aggregate's 'inverse' method raised error");
11921202
goto exit;
11931203
}
1204+
Py_CLEAR(ctx);
11941205
Py_DECREF(res);
11951206

11961207
exit:
1208+
Py_XDECREF(ctx);
11971209
Py_XDECREF(method);
11981210
PyGILState_Release(gilstate);
11991211
}
@@ -1217,7 +1229,10 @@ value_callback(sqlite3_context *context)
12171229
assert(cls != NULL);
12181230
assert(*cls != NULL);
12191231

1232+
Py_INCREF(ctx);
12201233
PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value);
1234+
Py_DECREF(ctx);
1235+
12211236
if (res == NULL) {
12221237
int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError);
12231238
set_sqlite_error(context, attr_err
@@ -1360,10 +1375,11 @@ authorizer_callback(void *ctx_vp, int action, const char *arg1,
13601375

13611376
assert(ctx_vp != NULL);
13621377
pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp);
1363-
PyObject *callable = ctx->callable;
1364-
ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname,
1365-
access_attempt_source);
1378+
// Hold a reference to 'ctx' to prevent concurrent mutations.
1379+
Py_INCREF(ctx);
13661380

1381+
ret = PyObject_CallFunction(ctx->callable, "issss", action, arg1, arg2,
1382+
dbname, access_attempt_source);
13671383
if (ret == NULL) {
13681384
print_or_clear_traceback(ctx);
13691385
rc = SQLITE_DENY;
@@ -1381,6 +1397,7 @@ authorizer_callback(void *ctx_vp, int action, const char *arg1,
13811397
}
13821398
Py_DECREF(ret);
13831399
}
1400+
Py_DECREF(ctx);
13841401

13851402
PyGILState_Release(gilstate);
13861403
return rc;
@@ -1396,8 +1413,10 @@ progress_callback(void *ctx_vp)
13961413

13971414
assert(ctx_vp != NULL);
13981415
pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp);
1399-
PyObject *callable = ctx->callable;
1400-
ret = PyObject_CallNoArgs(callable);
1416+
// Hold a reference to 'ctx' to prevent concurrent mutations.
1417+
Py_INCREF(ctx);
1418+
1419+
ret = PyObject_CallNoArgs(ctx->callable);
14011420
if (!ret) {
14021421
/* abort query if error occurred */
14031422
rc = -1;
@@ -1409,7 +1428,7 @@ progress_callback(void *ctx_vp)
14091428
if (rc < 0) {
14101429
print_or_clear_traceback(ctx);
14111430
}
1412-
1431+
Py_DECREF(ctx);
14131432
PyGILState_Release(gilstate);
14141433
return rc;
14151434
}
@@ -1455,7 +1474,9 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
14551474
sqlite3_free((void *)expanded_sql);
14561475
}
14571476
if (py_statement) {
1477+
Py_INCREF(ctx);
14581478
PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement);
1479+
Py_DECREF(ctx);
14591480
Py_DECREF(py_statement);
14601481
Py_XDECREF(ret);
14611482
}
@@ -1889,6 +1910,7 @@ collation_callback(void *context, int text1_length, const void *text1_data,
18891910
{
18901911
PyGILState_STATE gilstate = PyGILState_Ensure();
18911912

1913+
pysqlite_CallbackContext *ctx = NULL;
18921914
PyObject* string1 = 0;
18931915
PyObject* string2 = 0;
18941916
PyObject* retval = NULL;
@@ -1910,8 +1932,11 @@ collation_callback(void *context, int text1_length, const void *text1_data,
19101932
goto finally;
19111933
}
19121934

1913-
pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(context);
1935+
ctx = pysqlite_CallbackContext_CAST(context);
19141936
assert(ctx != NULL);
1937+
// Hold a reference to 'ctx' to prevent concurrent mutations.
1938+
Py_INCREF(ctx);
1939+
19151940
PyObject *args[] = { NULL, string1, string2 }; // Borrowed refs.
19161941
size_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET;
19171942
retval = PyObject_Vectorcall(ctx->callable, args + 1, nargsf, NULL);
@@ -1931,8 +1956,10 @@ collation_callback(void *context, int text1_length, const void *text1_data,
19311956
else if (longval < 0)
19321957
result = -1;
19331958
}
1959+
Py_CLEAR(ctx);
19341960

19351961
finally:
1962+
Py_XDECREF(ctx);
19361963
Py_XDECREF(string1);
19371964
Py_XDECREF(string2);
19381965
Py_XDECREF(retval);

0 commit comments

Comments
 (0)