Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5277,6 +5277,20 @@ def msg_cb(conn, direction, version, content_type, msg_type, data):
with self.assertRaises(TypeError):
client_context._msg_callback = object()

def test_msg_callback_exception(self):
client_context, server_context, hostname = testing_context()

def msg_cb(conn, direction, version, content_type, msg_type, data):
raise RuntimeError("msg_cb exception")

client_context._msg_callback = msg_cb
server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=hostname) as s:
with self.assertRaisesRegex(RuntimeError, "msg_cb exception"):
s.connect((HOST, server.port))

def test_msg_callback_tls12(self):
client_context, server_context, hostname = testing_context()
client_context.maximum_version = ssl.TLSVersion.TLSv1_2
Expand Down
110 changes: 63 additions & 47 deletions Modules/_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,6 @@ typedef struct {
enum py_ssl_server_or_client socket_type;
PyObject *owner; /* Python level "owner" passed to servername callback */
PyObject *server_hostname;
/* Some SSL callbacks don't have error reporting. Callback wrappers
* store exception information on the socket. The handshake, read, write,
* and shutdown methods check for chained exceptions.
*/
PyObject *exc;
} PySSLSocket;

#define PySSLSocket_CAST(op) ((PySSLSocket *)(op))
Expand Down Expand Up @@ -657,18 +652,12 @@ fill_and_set_sslerror(_sslmodulestate *state,
PyUnicodeWriter_Discard(writer);
}

static int
PySSL_ChainExceptions(PySSLSocket *sslsock) {
if (sslsock->exc == NULL)
return 0;

_PyErr_ChainExceptions1(sslsock->exc);
sslsock->exc = NULL;
return -1;
}

// Set the appropriate SSL error exception.
// err - error information from SSL and libc
// exc - if not NULL, an exception from _debughelpers.c callback to be chained
static PyObject *
PySSL_SetError(PySSLSocket *sslsock, _PySSLError err, const char *filename, int lineno)
PySSL_SetError(PySSLSocket *sslsock, _PySSLError err, PyObject *exc,
const char *filename, int lineno)
{
PyObject *type;
char *errstr = NULL;
Expand Down Expand Up @@ -776,7 +765,7 @@ PySSL_SetError(PySSLSocket *sslsock, _PySSLError err, const char *filename, int
}
fill_and_set_sslerror(state, sslsock, type, p, errstr, lineno, e);
ERR_clear_error();
PySSL_ChainExceptions(sslsock);
_PyErr_ChainExceptions1(exc); // chain any exceptions from callbacks
return NULL;
}

Expand Down Expand Up @@ -908,7 +897,6 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
self->shutdown_seen_zero = 0;
self->owner = NULL;
self->server_hostname = NULL;
self->exc = NULL;

/* Make sure the SSL error state is initialized */
ERR_clear_error();
Expand Down Expand Up @@ -1029,6 +1017,7 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self)
{
int ret;
_PySSLError err;
PyObject *exc = NULL;
int sockstate, nonblocking;
PySocketSockObject *sock = GET_SOCKET(self);
PyTime_t timeout, deadline = 0;
Expand Down Expand Up @@ -1064,6 +1053,12 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self)
Py_END_ALLOW_THREADS;
_PySSL_FIX_ERRNO;

// Get any exception that occurred in a debughelpers.c callback
exc = PyErr_GetRaisedException();
if (exc != NULL) {
break;
}

if (PyErr_CheckSignals())
goto error;

Expand Down Expand Up @@ -1098,13 +1093,14 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self)
Py_XDECREF(sock);

if (ret < 1)
return PySSL_SetError(self, err, __FILE__, __LINE__);
if (PySSL_ChainExceptions(self) < 0)
return PySSL_SetError(self, err, exc, __FILE__, __LINE__);
if (exc != NULL) {
PyErr_SetRaisedException(exc);
return NULL;
}
Py_RETURN_NONE;
error:
Py_XDECREF(sock);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might add: assert(exc == NULL);.

PySSL_ChainExceptions(self);
return NULL;
}

Expand Down Expand Up @@ -2434,17 +2430,7 @@ _ssl__SSLSocket_owner_set_impl(PySSLSocket *self, PyObject *value)
static int
PySSL_traverse(PyObject *op, visitproc visit, void *arg)
{
PySSLSocket *self = PySSLSocket_CAST(op);
Py_VISIT(self->exc);
Py_VISIT(Py_TYPE(self));
return 0;
}

static int
PySSL_clear(PyObject *op)
{
PySSLSocket *self = PySSLSocket_CAST(op);
Py_CLEAR(self->exc);
Py_VISIT(Py_TYPE(op));
return 0;
}

Expand Down Expand Up @@ -2619,6 +2605,7 @@ _ssl__SSLSocket_sendfile_impl(PySSLSocket *self, int fd, Py_off_t offset,
Py_ssize_t retval;
int sockstate;
_PySSLError err;
PyObject *exc = NULL;
PySocketSockObject *sock = GET_SOCKET(self);
PyTime_t timeout, deadline = 0;
int has_timeout;
Expand Down Expand Up @@ -2666,6 +2653,11 @@ _ssl__SSLSocket_sendfile_impl(PySSLSocket *self, int fd, Py_off_t offset,
Py_END_ALLOW_THREADS;
_PySSL_FIX_ERRNO;

exc = PyErr_GetRaisedException();
if (exc != NULL) {
break;
}

if (PyErr_CheckSignals()) {
goto error;
}
Expand Down Expand Up @@ -2715,15 +2707,18 @@ _ssl__SSLSocket_sendfile_impl(PySSLSocket *self, int fd, Py_off_t offset,
}
Py_XDECREF(sock);
if (retval < 0) {
return PySSL_SetError(self, err, __FILE__, __LINE__);
return PySSL_SetError(self, err, exc, __FILE__, __LINE__);
}
if (PySSL_ChainExceptions(self) < 0) {
if (exc != NULL) {
PyErr_SetRaisedException(exc);
return NULL;
}
return PyLong_FromSize_t(retval);
error:
Py_XDECREF(sock);
(void)PySSL_ChainExceptions(self);
if (exc != NULL) {
_PyErr_ChainExceptions1(exc);
}
return NULL;
}
#endif /* BIO_get_ktls_send */
Expand All @@ -2747,6 +2742,7 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b)
int retval;
int sockstate;
_PySSLError err;
PyObject *exc = NULL;
int nonblocking;
PySocketSockObject *sock = GET_SOCKET(self);
PyTime_t timeout, deadline = 0;
Expand Down Expand Up @@ -2797,6 +2793,11 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b)
Py_END_ALLOW_THREADS;
_PySSL_FIX_ERRNO;

exc = PyErr_GetRaisedException();
if (exc != NULL) {
break;
}

if (PyErr_CheckSignals())
goto error;

Expand Down Expand Up @@ -2828,13 +2829,14 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b)

Py_XDECREF(sock);
if (retval == 0)
return PySSL_SetError(self, err, __FILE__, __LINE__);
if (PySSL_ChainExceptions(self) < 0)
return PySSL_SetError(self, err, exc, __FILE__, __LINE__);
if (exc != NULL) {
PyErr_SetRaisedException(exc);
return NULL;
}
return PyLong_FromSize_t(count);
error:
Py_XDECREF(sock);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might add: assert(exc == NULL);.

PySSL_ChainExceptions(self);
return NULL;
}

Expand All @@ -2860,7 +2862,7 @@ _ssl__SSLSocket_pending_impl(PySSLSocket *self)
_PySSL_FIX_ERRNO;

if (count < 0)
return PySSL_SetError(self, err, __FILE__, __LINE__);
return PySSL_SetError(self, err, NULL, __FILE__, __LINE__);
else
return PyLong_FromLong(count);
}
Expand Down Expand Up @@ -2888,6 +2890,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
int retval;
int sockstate;
_PySSLError err;
PyObject *exc;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency with other functions:

Suggested change
PyObject *exc;
PyObject *exc = NULL;

int nonblocking;
PySocketSockObject *sock = GET_SOCKET(self);
PyTime_t timeout, deadline = 0;
Expand Down Expand Up @@ -2955,6 +2958,11 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
Py_END_ALLOW_THREADS;
_PySSL_FIX_ERRNO;

exc = PyErr_GetRaisedException();
if (exc != NULL) {
break;
}

if (PyErr_CheckSignals())
goto error;

Expand Down Expand Up @@ -2986,11 +2994,13 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
err.ssl == SSL_ERROR_WANT_WRITE);

if (retval == 0) {
PySSL_SetError(self, err, __FILE__, __LINE__);
PySSL_SetError(self, err, exc, __FILE__, __LINE__);
goto error;
}
if (self->exc != NULL)
else if (exc != NULL) {
PyErr_SetRaisedException(exc);
goto error;
}

done:
Py_XDECREF(sock);
Expand All @@ -3002,7 +3012,6 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
}

error:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might add: assert(exc == NULL); here and in the done: label.

PySSL_ChainExceptions(self);
Py_XDECREF(sock);
if (!group_right_1) {
PyBytesWriter_Discard(writer);
Expand All @@ -3022,6 +3031,7 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
/*[clinic end generated code: output=ca1aa7ed9d25ca42 input=98d9635cd4e16514]*/
{
_PySSLError err;
PyObject *exc = NULL;
int sockstate, nonblocking, ret;
int zeros = 0;
PySocketSockObject *sock = GET_SOCKET(self);
Expand Down Expand Up @@ -3067,6 +3077,11 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
Py_END_ALLOW_THREADS;
_PySSL_FIX_ERRNO;

exc = PyErr_GetRaisedException();
if (exc != NULL) {
break;
}

/* If err == 1, a secure shutdown with SSL_shutdown() is complete */
if (ret > 0)
break;
Expand Down Expand Up @@ -3113,11 +3128,14 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
}
if (ret < 0) {
Py_XDECREF(sock);
PySSL_SetError(self, err, __FILE__, __LINE__);
PySSL_SetError(self, err, exc, __FILE__, __LINE__);
return NULL;
}
else if (exc != NULL) {
Py_XDECREF(sock);
PyErr_SetRaisedException(exc);
return NULL;
}
if (self->exc != NULL)
goto error;
if (sock)
/* It's already INCREF'ed */
return (PyObject *) sock;
Expand All @@ -3126,7 +3144,6 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)

error:
Py_XDECREF(sock);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might add: assert(exc == NULL);.

PySSL_ChainExceptions(self);
return NULL;
}

Expand Down Expand Up @@ -3335,7 +3352,6 @@ static PyType_Slot PySSLSocket_slots[] = {
{Py_tp_getset, ssl_getsetlist},
{Py_tp_dealloc, PySSL_dealloc},
{Py_tp_traverse, PySSL_traverse},
{Py_tp_clear, PySSL_clear},
{0, 0},
};

Expand Down
19 changes: 13 additions & 6 deletions Modules/_ssl/debughelpers.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ _PySSL_msg_callback(int write_p, int version, int content_type,
return;
}

PyObject *exc = PyErr_GetRaisedException();

PyObject *ssl_socket; /* ssl.SSLSocket or ssl.SSLObject */
if (ssl_obj->owner)
PyWeakref_GetRef(ssl_obj->owner, &ssl_socket);
Expand Down Expand Up @@ -73,13 +75,13 @@ _PySSL_msg_callback(int write_p, int version, int content_type,
version, content_type, msg_type,
buf, len
);
if (res == NULL) {
ssl_obj->exc = PyErr_GetRaisedException();
} else {
Py_DECREF(res);
}
Py_XDECREF(res);
Py_XDECREF(ssl_socket);

if (exc != NULL) {
_PyErr_ChainExceptions1(exc);
}

PyGILState_Release(threadstate);
}

Expand Down Expand Up @@ -122,10 +124,13 @@ _PySSL_keylog_callback(const SSL *ssl, const char *line)
{
PyGILState_STATE threadstate;
PySSLSocket *ssl_obj = NULL; /* ssl._SSLSocket, borrowed ref */
PyObject *exc;
int res, e;

threadstate = PyGILState_Ensure();

exc = PyErr_GetRaisedException();

ssl_obj = (PySSLSocket *)SSL_get_app_data(ssl);
assert(Py_IS_TYPE(ssl_obj, get_state_sock(ssl_obj)->PySSLSocket_Type));
PyThread_type_lock lock = get_state_sock(ssl_obj)->keylog_lock;
Expand Down Expand Up @@ -153,10 +158,12 @@ _PySSL_keylog_callback(const SSL *ssl, const char *line)
errno = e;
PyErr_SetFromErrnoWithFilenameObject(PyExc_OSError,
ssl_obj->ctx->keylog_filename);
ssl_obj->exc = PyErr_GetRaisedException();
}

done:
if (exc != NULL) {
_PyErr_ChainExceptions1(exc);
}
PyGILState_Release(threadstate);
}

Expand Down
Loading