Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 57 additions & 31 deletions ext/openssl/ossl_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ static ID id_i_cert_store, id_i_ca_file, id_i_ca_path, id_i_verify_mode,
id_i_session_remove_cb, id_i_npn_select_cb, id_i_npn_protocols,
id_i_alpn_select_cb, id_i_alpn_protocols, id_i_servername_cb,
id_i_verify_hostname, id_i_keylog_cb, id_i_tmp_dh_callback;
static ID id_i_io, id_i_context, id_i_hostname;
static ID id_i_io, id_i_context, id_i_hostname, id_i_sync_close;

static int ossl_ssl_ex_ptr_idx;
static int ossl_sslctx_ex_ptr_idx;
Expand Down Expand Up @@ -1590,39 +1590,42 @@ ossl_ssl_s_alloc(VALUE klass)
}

static VALUE
peer_ip_address(VALUE self)
peer_ip_address(VALUE io)
{
VALUE remote_address = rb_funcall(rb_attr_get(self, id_i_io), rb_intern("remote_address"), 0);
VALUE remote_address = rb_funcall(io, rb_intern("remote_address"), 0);

return rb_funcall(remote_address, rb_intern("inspect_sockaddr"), 0);
}

static VALUE
fallback_peer_ip_address(VALUE self, VALUE args)
fallback_peer_ip_address(VALUE self, VALUE exc)
{
return rb_str_new_cstr("(null)");
}

static VALUE
peeraddr_ip_str(VALUE self)
peeraddr_ip_str(VALUE io)
{
VALUE rb_mErrno = rb_const_get(rb_cObject, rb_intern("Errno"));
VALUE rb_eSystemCallError = rb_const_get(rb_mErrno, rb_intern("SystemCallError"));

return rb_rescue2(peer_ip_address, self, fallback_peer_ip_address, (VALUE)0, rb_eSystemCallError, NULL);
return rb_rescue2(peer_ip_address, io, fallback_peer_ip_address, Qnil,
rb_eSystemCallError, (VALUE)0);
}

/*
* call-seq:
* SSLSocket.new(io) => aSSLSocket
* SSLSocket.new(io, ctx) => aSSLSocket
* SSLSocket.new(io, ctx, sync_close:) => aSSLSocket
*
* Creates a new SSL socket from _io_ which must be a real IO object (not an
* IO-like object that responds to read/write).
*
* If _ctx_ is provided the SSL Sockets initial params will be taken from
* the context.
*
* The optional _sync_close_ keyword parameter sets the _sync_close_ instance
* variable. Setting this to +true+ will cause the underlying socket to be
* closed when the SSL/TLS connection is shut down.
*
* The OpenSSL::Buffering module provides additional IO methods.
*
* This method will freeze the SSLContext if one is provided;
Expand All @@ -1631,6 +1634,10 @@ peeraddr_ip_str(VALUE self)
static VALUE
ossl_ssl_initialize(int argc, VALUE *argv, VALUE self)
{
static ID kw_ids[1];
VALUE kw_args[1];
VALUE opts;

VALUE io, v_ctx;
SSL *ssl;
SSL_CTX *ctx;
Expand All @@ -1639,9 +1646,18 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self)
if (ssl)
ossl_raise(eSSLError, "SSL already initialized");

if (rb_scan_args(argc, argv, "11", &io, &v_ctx) == 1)
if (rb_scan_args(argc, argv, "11:", &io, &v_ctx, &opts) == 1)
v_ctx = rb_funcall(cSSLContext, rb_intern("new"), 0);

if (!kw_ids[0]) {
kw_ids[0] = rb_intern_const("sync_close");
}

rb_get_kwargs(opts, kw_ids, 0, 1, kw_args);
if (kw_args[0] != Qundef) {
rb_ivar_set(self, id_i_sync_close, kw_args[0]);
}

GetSSLCTX(v_ctx, ctx);
rb_ivar_set(self, id_i_context, v_ctx);
ossl_sslctx_setup(v_ctx);
Expand Down Expand Up @@ -1696,11 +1712,15 @@ ossl_ssl_setup(VALUE self)
return Qtrue;
}

static int
errno_mapped(void)
{
#ifdef _WIN32
#define ssl_get_error(ssl, ret) (errno = rb_w32_map_errno(WSAGetLastError()), SSL_get_error((ssl), (ret)))
return rb_w32_map_errno(WSAGetLastError());
#else
#define ssl_get_error(ssl, ret) SSL_get_error((ssl), (ret))
return errno;
#endif
}

static void
write_would_block(int nonblock)
Expand Down Expand Up @@ -1741,35 +1761,34 @@ static void
io_wait_writable(VALUE io)
{
#ifdef HAVE_RB_IO_MAYBE_WAIT
if (!rb_io_maybe_wait_writable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) {
if (!rb_io_wait(io, INT2NUM(RUBY_IO_WRITABLE), RUBY_IO_TIMEOUT_DEFAULT)) {
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!");
}
#else
rb_io_t *fptr;
GetOpenFile(io, fptr);
rb_io_wait_writable(fptr->fd);
rb_thread_fd_writable(fptr->fd);
#endif
}

static void
io_wait_readable(VALUE io)
{
#ifdef HAVE_RB_IO_MAYBE_WAIT
if (!rb_io_maybe_wait_readable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) {
if (!rb_io_wait(io, INT2NUM(RUBY_IO_READABLE), RUBY_IO_TIMEOUT_DEFAULT)) {
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!");
}
#else
rb_io_t *fptr;
GetOpenFile(io, fptr);
rb_io_wait_readable(fptr->fd);
rb_thread_wait_fd(fptr->fd);
#endif
}

static VALUE
ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts)
{
SSL *ssl;
int ret, ret2;
VALUE cb_state;
int nonblock = opts != Qfalse;

Expand All @@ -1779,7 +1798,8 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts)

VALUE io = rb_attr_get(self, id_i_io);
for (;;) {
ret = func(ssl);
int ret = func(ssl);
int saved_errno = errno_mapped();

cb_state = rb_attr_get(self, ID_callback_state);
if (!NIL_P(cb_state)) {
Expand All @@ -1791,7 +1811,8 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts)
if (ret > 0)
break;

switch ((ret2 = ssl_get_error(ssl, ret))) {
int code = SSL_get_error(ssl, ret);
switch (code) {
case SSL_ERROR_WANT_WRITE:
if (no_exception_p(opts)) { return sym_wait_writable; }
write_would_block(nonblock);
Expand All @@ -1805,10 +1826,11 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts)
case SSL_ERROR_SYSCALL:
#ifdef __APPLE__
/* See ossl_ssl_write_internal() */
if (errno == EPROTOTYPE)
if (saved_errno == EPROTOTYPE)
continue;
#endif
if (errno) rb_sys_fail(funcname);
if (saved_errno)
rb_exc_raise(rb_syserr_new(saved_errno, funcname));
/* fallthrough */
default: {
VALUE error_append = Qnil;
Expand All @@ -1829,10 +1851,10 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts)
ossl_raise(eSSLError,
"%s%s returned=%d errno=%d peeraddr=%"PRIsVALUE" state=%s%"PRIsVALUE,
funcname,
ret2 == SSL_ERROR_SYSCALL ? " SYSCALL" : "",
ret2,
errno,
peeraddr_ip_str(self),
code == SSL_ERROR_SYSCALL ? " SYSCALL" : "",
code,
saved_errno,
peeraddr_ip_str(io),
SSL_state_string_long(ssl),
error_append);
}
Expand Down Expand Up @@ -1974,6 +1996,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
for (;;) {
rb_str_locktmp(str);
int nread = SSL_read(ssl, RSTRING_PTR(str), ilen);
int saved_errno = errno_mapped();
rb_str_unlocktmp(str);

cb_state = rb_attr_get(self, ID_callback_state);
Expand All @@ -1983,7 +2006,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
rb_jump_tag(NUM2INT(cb_state));
}

switch (ssl_get_error(ssl, nread)) {
switch (SSL_get_error(ssl, nread)) {
case SSL_ERROR_NONE:
rb_str_set_len(str, nread);
return str;
Expand All @@ -2006,8 +2029,8 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
break;
case SSL_ERROR_SYSCALL:
if (!ERR_peek_error()) {
if (errno)
rb_sys_fail(0);
if (saved_errno)
rb_exc_raise(rb_syserr_new(saved_errno, "SSL_read"));
else {
/*
* The underlying BIO returned 0. This is actually a
Expand Down Expand Up @@ -2092,6 +2115,7 @@ ossl_ssl_write_internal_safe(VALUE _args)

for (;;) {
int nwritten = SSL_write(ssl, RSTRING_PTR(str), num);
int saved_errno = errno_mapped();

cb_state = rb_attr_get(self, ID_callback_state);
if (!NIL_P(cb_state)) {
Expand All @@ -2100,7 +2124,7 @@ ossl_ssl_write_internal_safe(VALUE _args)
rb_jump_tag(NUM2INT(cb_state));
}

switch (ssl_get_error(ssl, nwritten)) {
switch (SSL_get_error(ssl, nwritten)) {
case SSL_ERROR_NONE:
return INT2NUM(nwritten);
case SSL_ERROR_WANT_WRITE:
Expand All @@ -2121,10 +2145,11 @@ ossl_ssl_write_internal_safe(VALUE _args)
* make the error handling in line with the socket library.
* [Bug #14713] https://bugs.ruby-lang.org/issues/14713
*/
if (errno == EPROTOTYPE)
if (saved_errno == EPROTOTYPE)
continue;
#endif
if (errno) rb_sys_fail(0);
if (saved_errno)
rb_exc_raise(rb_syserr_new(saved_errno, "SSL_write"));
/* fallthrough */
default:
ossl_raise(eSSLError, "SSL_write");
Expand Down Expand Up @@ -3300,5 +3325,6 @@ Init_ossl_ssl(void)
DefIVarID(io);
DefIVarID(context);
DefIVarID(hostname);
DefIVarID(sync_close);
#endif /* !defined(OPENSSL_NO_SOCK) */
}
66 changes: 46 additions & 20 deletions test/openssl/test_ssl.rb
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,22 @@ def test_sync_close
end
end

def test_sync_close_initialize_opt
start_server do |port|
begin
sock = TCPSocket.new("127.0.0.1", port)
ssl = OpenSSL::SSL::SSLSocket.new(sock, sync_close: true)
assert_equal true, ssl.sync_close
ssl.connect
ssl.puts "abc"; assert_equal "abc\n", ssl.gets
ssl.close
assert_predicate sock, :closed?
ensure
sock&.close
end
end
end

def test_copy_stream
start_server do |port|
server_connect(port) do |ssl|
Expand Down Expand Up @@ -1064,36 +1080,46 @@ def test_tlsext_hostname
end
end

def test_servername_cb_raises_an_exception_on_unknown_objects
hostname = 'example.org'

ctx2 = OpenSSL::SSL::SSLContext.new
ctx2.cert = @svr_cert
ctx2.key = @svr_key
ctx2.servername_cb = lambda { |args| Object.new }

def test_servername_cb_exception
sock1, sock2 = socketpair

t = Thread.new {
s1 = OpenSSL::SSL::SSLSocket.new(sock1)
s1.hostname = "localhost"
assert_raise_with_message(OpenSSL::SSL::SSLError, /unrecognized.name/i) {
s1.connect
}
}

ctx2 = OpenSSL::SSL::SSLContext.new
ctx2.servername_cb = lambda { |args| raise RuntimeError, "foo" }
s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2)
assert_raise_with_message(RuntimeError, "foo") { s2.accept }
assert t.join
ensure
sock1.close
sock2.close
t.kill.join
end

ctx1 = OpenSSL::SSL::SSLContext.new
def test_servername_cb_raises_an_exception_on_unknown_objects
sock1, sock2 = socketpair

s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1)
s1.hostname = hostname
t = Thread.new {
assert_raise(OpenSSL::SSL::SSLError) do
s1.connect
end
s1 = OpenSSL::SSL::SSLSocket.new(sock1)
s1.hostname = "localhost"
assert_raise(OpenSSL::SSL::SSLError) { s1.connect }
}

assert_raise(ArgumentError) do
s2.accept
end

ctx2 = OpenSSL::SSL::SSLContext.new
ctx2.servername_cb = lambda { |args| Object.new }
s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2)
assert_raise(ArgumentError) { s2.accept }
assert t.join
ensure
sock1.close if sock1
sock2.close if sock2
sock1.close
sock2.close
t.kill.join
end

def test_accept_errors_include_peeraddr
Expand Down
8 changes: 7 additions & 1 deletion thread_pthread_mn.c
Original file line number Diff line number Diff line change
Expand Up @@ -617,11 +617,17 @@ kqueue_wait(rb_vm_t *vm)
struct timespec *timeout = NULL;
int timeout_ms = timer_thread_set_timeout(vm);

if (timeout_ms >= 0) {
if (timeout_ms > 0) {
calculated_timeout.tv_sec = timeout_ms / 1000;
calculated_timeout.tv_nsec = (timeout_ms % 1000) * 1000000;
timeout = &calculated_timeout;
}
else if (timeout_ms == 0) {
// Relying on the absence of other members of struct timespec is not strictly portable,
// and kevent needs a 0-valued timespec to mean immediate timeout.
memset(&calculated_timeout, 0, sizeof(struct timespec));
timeout = &calculated_timeout;
}

return kevent(timer_th.event_fd, NULL, 0, timer_th.finished_events, KQUEUE_EVENTS_MAX, timeout);
}
Expand Down