diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index ebdf5455163c65..9dc99fbf5cf7d2 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -5277,6 +5277,20 @@ def msg_cb(conn, direction, version, content_type, msg_type, data): with self.assertRaises(TypeError): client_context._msg_callback = object() + def test_msg_callback_exception(self): + client_context, server_context, hostname = testing_context() + + def msg_cb(conn, direction, version, content_type, msg_type, data): + raise RuntimeError("msg_cb exception") + + client_context._msg_callback = msg_cb + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + with self.assertRaisesRegex(RuntimeError, "msg_cb exception"): + s.connect((HOST, server.port)) + def test_msg_callback_tls12(self): client_context, server_context, hostname = testing_context() client_context.maximum_version = ssl.TLSVersion.TLSv1_2 diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 5d2f075ed0c675..7dd57e7892af41 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -360,11 +360,6 @@ typedef struct { enum py_ssl_server_or_client socket_type; PyObject *owner; /* Python level "owner" passed to servername callback */ PyObject *server_hostname; - /* Some SSL callbacks don't have error reporting. Callback wrappers - * store exception information on the socket. The handshake, read, write, - * and shutdown methods check for chained exceptions. - */ - PyObject *exc; } PySSLSocket; #define PySSLSocket_CAST(op) ((PySSLSocket *)(op)) @@ -657,18 +652,12 @@ fill_and_set_sslerror(_sslmodulestate *state, PyUnicodeWriter_Discard(writer); } -static int -PySSL_ChainExceptions(PySSLSocket *sslsock) { - if (sslsock->exc == NULL) - return 0; - - _PyErr_ChainExceptions1(sslsock->exc); - sslsock->exc = NULL; - return -1; -} - +// Set the appropriate SSL error exception. +// err - error information from SSL and libc +// exc - if not NULL, an exception from _debughelpers.c callback to be chained static PyObject * -PySSL_SetError(PySSLSocket *sslsock, _PySSLError err, const char *filename, int lineno) +PySSL_SetError(PySSLSocket *sslsock, _PySSLError err, PyObject *exc, + const char *filename, int lineno) { PyObject *type; char *errstr = NULL; @@ -776,7 +765,7 @@ PySSL_SetError(PySSLSocket *sslsock, _PySSLError err, const char *filename, int } fill_and_set_sslerror(state, sslsock, type, p, errstr, lineno, e); ERR_clear_error(); - PySSL_ChainExceptions(sslsock); + _PyErr_ChainExceptions1(exc); // chain any exceptions from callbacks return NULL; } @@ -908,7 +897,6 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, self->shutdown_seen_zero = 0; self->owner = NULL; self->server_hostname = NULL; - self->exc = NULL; /* Make sure the SSL error state is initialized */ ERR_clear_error(); @@ -1029,6 +1017,7 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self) { int ret; _PySSLError err; + PyObject *exc = NULL; int sockstate, nonblocking; PySocketSockObject *sock = GET_SOCKET(self); PyTime_t timeout, deadline = 0; @@ -1064,6 +1053,12 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self) Py_END_ALLOW_THREADS; _PySSL_FIX_ERRNO; + // Get any exception that occurred in a debughelpers.c callback + exc = PyErr_GetRaisedException(); + if (exc != NULL) { + break; + } + if (PyErr_CheckSignals()) goto error; @@ -1098,13 +1093,15 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self) Py_XDECREF(sock); if (ret < 1) - return PySSL_SetError(self, err, __FILE__, __LINE__); - if (PySSL_ChainExceptions(self) < 0) + return PySSL_SetError(self, err, exc, __FILE__, __LINE__); + if (exc != NULL) { + PyErr_SetRaisedException(exc); return NULL; + } Py_RETURN_NONE; error: + assert(exc == NULL); Py_XDECREF(sock); - PySSL_ChainExceptions(self); return NULL; } @@ -2434,17 +2431,7 @@ _ssl__SSLSocket_owner_set_impl(PySSLSocket *self, PyObject *value) static int PySSL_traverse(PyObject *op, visitproc visit, void *arg) { - PySSLSocket *self = PySSLSocket_CAST(op); - Py_VISIT(self->exc); - Py_VISIT(Py_TYPE(self)); - return 0; -} - -static int -PySSL_clear(PyObject *op) -{ - PySSLSocket *self = PySSLSocket_CAST(op); - Py_CLEAR(self->exc); + Py_VISIT(Py_TYPE(op)); return 0; } @@ -2619,6 +2606,7 @@ _ssl__SSLSocket_sendfile_impl(PySSLSocket *self, int fd, Py_off_t offset, Py_ssize_t retval; int sockstate; _PySSLError err; + PyObject *exc = NULL; PySocketSockObject *sock = GET_SOCKET(self); PyTime_t timeout, deadline = 0; int has_timeout; @@ -2666,6 +2654,11 @@ _ssl__SSLSocket_sendfile_impl(PySSLSocket *self, int fd, Py_off_t offset, Py_END_ALLOW_THREADS; _PySSL_FIX_ERRNO; + exc = PyErr_GetRaisedException(); + if (exc != NULL) { + break; + } + if (PyErr_CheckSignals()) { goto error; } @@ -2715,15 +2708,18 @@ _ssl__SSLSocket_sendfile_impl(PySSLSocket *self, int fd, Py_off_t offset, } Py_XDECREF(sock); if (retval < 0) { - return PySSL_SetError(self, err, __FILE__, __LINE__); + return PySSL_SetError(self, err, exc, __FILE__, __LINE__); } - if (PySSL_ChainExceptions(self) < 0) { + if (exc != NULL) { + PyErr_SetRaisedException(exc); return NULL; } return PyLong_FromSize_t(retval); error: Py_XDECREF(sock); - (void)PySSL_ChainExceptions(self); + if (exc != NULL) { + _PyErr_ChainExceptions1(exc); + } return NULL; } #endif /* BIO_get_ktls_send */ @@ -2747,6 +2743,7 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b) int retval; int sockstate; _PySSLError err; + PyObject *exc = NULL; int nonblocking; PySocketSockObject *sock = GET_SOCKET(self); PyTime_t timeout, deadline = 0; @@ -2797,6 +2794,11 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b) Py_END_ALLOW_THREADS; _PySSL_FIX_ERRNO; + exc = PyErr_GetRaisedException(); + if (exc != NULL) { + break; + } + if (PyErr_CheckSignals()) goto error; @@ -2828,13 +2830,15 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b) Py_XDECREF(sock); if (retval == 0) - return PySSL_SetError(self, err, __FILE__, __LINE__); - if (PySSL_ChainExceptions(self) < 0) + return PySSL_SetError(self, err, exc, __FILE__, __LINE__); + if (exc != NULL) { + PyErr_SetRaisedException(exc); return NULL; + } return PyLong_FromSize_t(count); error: + assert(exc == NULL); Py_XDECREF(sock); - PySSL_ChainExceptions(self); return NULL; } @@ -2860,7 +2864,7 @@ _ssl__SSLSocket_pending_impl(PySSLSocket *self) _PySSL_FIX_ERRNO; if (count < 0) - return PySSL_SetError(self, err, __FILE__, __LINE__); + return PySSL_SetError(self, err, NULL, __FILE__, __LINE__); else return PyLong_FromLong(count); } @@ -2888,6 +2892,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, int retval; int sockstate; _PySSLError err; + PyObject *exc = NULL; int nonblocking; PySocketSockObject *sock = GET_SOCKET(self); PyTime_t timeout, deadline = 0; @@ -2955,6 +2960,11 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, Py_END_ALLOW_THREADS; _PySSL_FIX_ERRNO; + exc = PyErr_GetRaisedException(); + if (exc != NULL) { + break; + } + if (PyErr_CheckSignals()) goto error; @@ -2986,13 +2996,18 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, err.ssl == SSL_ERROR_WANT_WRITE); if (retval == 0) { - PySSL_SetError(self, err, __FILE__, __LINE__); + PySSL_SetError(self, err, exc, __FILE__, __LINE__); + exc = NULL; goto error; } - if (self->exc != NULL) + else if (exc != NULL) { + PyErr_SetRaisedException(exc); + exc = NULL; goto error; + } done: + assert(exc == NULL); Py_XDECREF(sock); if (!group_right_1) { return PyBytesWriter_FinishWithSize(writer, count); @@ -3002,7 +3017,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, } error: - PySSL_ChainExceptions(self); + assert(exc == NULL); Py_XDECREF(sock); if (!group_right_1) { PyBytesWriter_Discard(writer); @@ -3022,6 +3037,7 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self) /*[clinic end generated code: output=ca1aa7ed9d25ca42 input=98d9635cd4e16514]*/ { _PySSLError err; + PyObject *exc = NULL; int sockstate, nonblocking, ret; int zeros = 0; PySocketSockObject *sock = GET_SOCKET(self); @@ -3067,6 +3083,11 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self) Py_END_ALLOW_THREADS; _PySSL_FIX_ERRNO; + exc = PyErr_GetRaisedException(); + if (exc != NULL) { + break; + } + /* If err == 1, a secure shutdown with SSL_shutdown() is complete */ if (ret > 0) break; @@ -3113,11 +3134,14 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self) } if (ret < 0) { Py_XDECREF(sock); - PySSL_SetError(self, err, __FILE__, __LINE__); + PySSL_SetError(self, err, exc, __FILE__, __LINE__); + return NULL; + } + else if (exc != NULL) { + Py_XDECREF(sock); + PyErr_SetRaisedException(exc); return NULL; } - if (self->exc != NULL) - goto error; if (sock) /* It's already INCREF'ed */ return (PyObject *) sock; @@ -3125,8 +3149,8 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self) Py_RETURN_NONE; error: + assert(exc == NULL); Py_XDECREF(sock); - PySSL_ChainExceptions(self); return NULL; } @@ -3335,7 +3359,6 @@ static PyType_Slot PySSLSocket_slots[] = { {Py_tp_getset, ssl_getsetlist}, {Py_tp_dealloc, PySSL_dealloc}, {Py_tp_traverse, PySSL_traverse}, - {Py_tp_clear, PySSL_clear}, {0, 0}, }; diff --git a/Modules/_ssl/debughelpers.c b/Modules/_ssl/debughelpers.c index 866c172e4996f7..e0cb7ca9a09f91 100644 --- a/Modules/_ssl/debughelpers.c +++ b/Modules/_ssl/debughelpers.c @@ -26,6 +26,8 @@ _PySSL_msg_callback(int write_p, int version, int content_type, return; } + PyObject *exc = PyErr_GetRaisedException(); + PyObject *ssl_socket; /* ssl.SSLSocket or ssl.SSLObject */ if (ssl_obj->owner) PyWeakref_GetRef(ssl_obj->owner, &ssl_socket); @@ -73,13 +75,13 @@ _PySSL_msg_callback(int write_p, int version, int content_type, version, content_type, msg_type, buf, len ); - if (res == NULL) { - ssl_obj->exc = PyErr_GetRaisedException(); - } else { - Py_DECREF(res); - } + Py_XDECREF(res); Py_XDECREF(ssl_socket); + if (exc != NULL) { + _PyErr_ChainExceptions1(exc); + } + PyGILState_Release(threadstate); } @@ -122,10 +124,13 @@ _PySSL_keylog_callback(const SSL *ssl, const char *line) { PyGILState_STATE threadstate; PySSLSocket *ssl_obj = NULL; /* ssl._SSLSocket, borrowed ref */ + PyObject *exc; int res, e; threadstate = PyGILState_Ensure(); + exc = PyErr_GetRaisedException(); + ssl_obj = (PySSLSocket *)SSL_get_app_data(ssl); assert(Py_IS_TYPE(ssl_obj, get_state_sock(ssl_obj)->PySSLSocket_Type)); PyThread_type_lock lock = get_state_sock(ssl_obj)->keylog_lock; @@ -153,10 +158,12 @@ _PySSL_keylog_callback(const SSL *ssl, const char *line) errno = e; PyErr_SetFromErrnoWithFilenameObject(PyExc_OSError, ssl_obj->ctx->keylog_filename); - ssl_obj->exc = PyErr_GetRaisedException(); } done: + if (exc != NULL) { + _PyErr_ChainExceptions1(exc); + } PyGILState_Release(threadstate); }