diff --git a/ci/vendor-wit.sh b/ci/vendor-wit.sh index 0c98410245cd..50897003e4c8 100755 --- a/ci/vendor-wit.sh +++ b/ci/vendor-wit.sh @@ -73,7 +73,7 @@ make_vendor "wasi/src/p3" " clocks@13d1c82@wit-0.3.0-draft filesystem@e2a2ddc@wit-0.3.0-draft random@4e94663@wit-0.3.0-draft - sockets@bb247e2@wit-0.3.0-draft + sockets@e863ee2@wit-0.3.0-draft " rm -rf $cache_dir diff --git a/crates/test-programs/src/bin/p3_sockets_ip_name_lookup.rs b/crates/test-programs/src/bin/p3_sockets_ip_name_lookup.rs new file mode 100644 index 000000000000..94548c2191f1 --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_ip_name_lookup.rs @@ -0,0 +1,102 @@ +use futures::join; +use test_programs::p3::wasi::sockets::ip_name_lookup::{ErrorCode, resolve_addresses}; +use test_programs::p3::wasi::sockets::types::IpAddress; + +struct Component; + +test_programs::p3::export!(Component); + +async fn resolve_one(name: &str) -> Result { + Ok(resolve_addresses(name.into()) + .await? + .first() + .unwrap() + .to_owned()) +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + // Valid domains + let (res0, res1) = join!( + resolve_addresses("localhost".into()), + resolve_addresses("example.com".into()) + ); + if res0.is_err() && res1.is_err() { + panic!("should have been able to resolve at least one domain"); + } + + // NB: this is an actual real resolution, so it might time out, might cause + // issues, etc. This result is ignored to prevent flaky failures in CI. + let _ = resolve_addresses("münchen.de".into()).await; + + // Valid IP addresses + assert_eq!( + resolve_one("0.0.0.0").await.unwrap(), + IpAddress::IPV4_UNSPECIFIED + ); + assert_eq!( + resolve_one("127.0.0.1").await.unwrap(), + IpAddress::IPV4_LOOPBACK + ); + assert_eq!( + resolve_one("192.0.2.0").await.unwrap(), + IpAddress::Ipv4((192, 0, 2, 0)) + ); + assert_eq!( + resolve_one("::").await.unwrap(), + IpAddress::IPV6_UNSPECIFIED + ); + assert_eq!(resolve_one("::1").await.unwrap(), IpAddress::IPV6_LOOPBACK); + assert_eq!( + resolve_one("[::]").await.unwrap(), + IpAddress::IPV6_UNSPECIFIED + ); + assert_eq!( + resolve_one("2001:0db8:0:0:0:0:0:0").await.unwrap(), + IpAddress::Ipv6((0x2001, 0x0db8, 0, 0, 0, 0, 0, 0)) + ); + assert_eq!( + resolve_one("dead:beef::").await.unwrap(), + IpAddress::Ipv6((0xdead, 0xbeef, 0, 0, 0, 0, 0, 0)) + ); + assert_eq!( + resolve_one("dead:beef::0").await.unwrap(), + IpAddress::Ipv6((0xdead, 0xbeef, 0, 0, 0, 0, 0, 0)) + ); + assert_eq!( + resolve_one("DEAD:BEEF::0").await.unwrap(), + IpAddress::Ipv6((0xdead, 0xbeef, 0, 0, 0, 0, 0, 0)) + ); + + // Invalid inputs + assert_eq!( + resolve_addresses("".into()).await.unwrap_err(), + ErrorCode::InvalidArgument + ); + assert_eq!( + resolve_addresses(" ".into()).await.unwrap_err(), + ErrorCode::InvalidArgument + ); + assert_eq!( + resolve_addresses("a.b<&>".into()).await.unwrap_err(), + ErrorCode::InvalidArgument + ); + assert_eq!( + resolve_addresses("127.0.0.1:80".into()).await.unwrap_err(), + ErrorCode::InvalidArgument + ); + assert_eq!( + resolve_addresses("[::]:80".into()).await.unwrap_err(), + ErrorCode::InvalidArgument + ); + assert_eq!( + resolve_addresses("http://example.com/".into()) + .await + .unwrap_err(), + ErrorCode::InvalidArgument + ); + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/bin/p3_sockets_tcp_bind.rs b/crates/test-programs/src/bin/p3_sockets_tcp_bind.rs new file mode 100644 index 000000000000..6045fc1c8dc1 --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_tcp_bind.rs @@ -0,0 +1,202 @@ +use futures::join; +use test_programs::p3::sockets::attempt_random_port; +use test_programs::p3::wasi::sockets::types::{ + ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket, +}; +use test_programs::p3::wit_stream; +use wit_bindgen::yield_blocking; + +struct Component; + +test_programs::p3::export!(Component); + +/// Bind a socket and let the system determine a port. +fn test_tcp_bind_ephemeral_port(ip: IpAddress) { + let bind_addr = IpSocketAddress::new(ip, 0); + + let sock = TcpSocket::new(ip.family()); + sock.bind(bind_addr).unwrap(); + + let bound_addr = sock.local_address().unwrap(); + + assert_eq!(bind_addr.ip(), bound_addr.ip()); + assert_ne!(bind_addr.port(), bound_addr.port()); +} + +/// Bind a socket on a specified port. +fn test_tcp_bind_specific_port(ip: IpAddress) { + let sock = TcpSocket::new(ip.family()); + + let bind_addr = attempt_random_port(ip, |bind_addr| sock.bind(bind_addr)).unwrap(); + + let bound_addr = sock.local_address().unwrap(); + + assert_eq!(bind_addr.ip(), bound_addr.ip()); + assert_eq!(bind_addr.port(), bound_addr.port()); +} + +/// Two sockets may not be actively bound to the same address at the same time. +fn test_tcp_bind_addrinuse(ip: IpAddress) { + let bind_addr = IpSocketAddress::new(ip, 0); + + let sock1 = TcpSocket::new(ip.family()); + sock1.bind(bind_addr).unwrap(); + sock1.listen().unwrap(); + + let bound_addr = sock1.local_address().unwrap(); + + let sock2 = TcpSocket::new(ip.family()); + assert_eq!(sock2.bind(bound_addr), Err(ErrorCode::AddressInUse)); +} + +// The WASI runtime should set SO_REUSEADDR for us +async fn test_tcp_bind_reuseaddr(ip: IpAddress) { + let client = TcpSocket::new(ip.family()); + + let bind_addr = { + let listener1 = TcpSocket::new(ip.family()); + + let bind_addr = attempt_random_port(ip, |bind_addr| listener1.bind(bind_addr)).unwrap(); + + let mut accept = listener1.listen().unwrap(); + + let connect_addr = + IpSocketAddress::new(IpAddress::new_loopback(ip.family()), bind_addr.port()); + join!( + async { + client.connect(connect_addr).await.unwrap(); + }, + async { + let sock = accept.next().await.unwrap(); + let (mut data_tx, data_rx) = wit_stream::new(); + join!( + async { + sock.send(data_rx).await.unwrap(); + }, + async { + let remaining = data_tx.write_all(vec![0; 10]).await; + assert!(remaining.is_empty()); + drop(data_tx); + } + ); + }, + ); + + bind_addr + }; + + // If SO_REUSEADDR was configured correctly, the following lines + // shouldn't be affected by the TIME_WAIT state of the just closed + // `listener1` socket. + // + // Note though that the way things are modeled in Wasmtime right now is that + // the TCP socket is kept alive by a spawned task created in `listen` + // meaning that to fully close the socket it requires the spawned task to + // shut down. That may require yielding to the host or similar so try a few + // times to let the host get around to closing the task while testing each + // time to see if we can reuse the address. This loop is bounded because it + // should complete "quickly". + for _ in 0..10 { + let listener2 = TcpSocket::new(ip.family()); + if listener2.bind(bind_addr).is_ok() { + listener2.listen().unwrap(); + return; + } + yield_blocking(); + } + + panic!("looks like REUSEADDR isn't in use?"); +} + +// Try binding to an address that is not configured on the system. +fn test_tcp_bind_addrnotavail(ip: IpAddress) { + let bind_addr = IpSocketAddress::new(ip, 0); + + let sock = TcpSocket::new(ip.family()); + + assert_eq!(sock.bind(bind_addr), Err(ErrorCode::AddressNotBindable)); +} + +/// Bind should validate the address family. +fn test_tcp_bind_wrong_family(family: IpAddressFamily) { + let wrong_ip = match family { + IpAddressFamily::Ipv4 => IpAddress::IPV6_LOOPBACK, + IpAddressFamily::Ipv6 => IpAddress::IPV4_LOOPBACK, + }; + + let sock = TcpSocket::new(family); + let result = sock.bind(IpSocketAddress::new(wrong_ip, 0)); + + assert!(matches!(result, Err(ErrorCode::InvalidArgument))); +} + +/// Bind only works on unicast addresses. +fn test_tcp_bind_non_unicast() { + let ipv4_broadcast = IpSocketAddress::new(IpAddress::IPV4_BROADCAST, 0); + let ipv4_multicast = IpSocketAddress::new(IpAddress::Ipv4((224, 254, 0, 0)), 0); + let ipv6_multicast = IpSocketAddress::new(IpAddress::Ipv6((0xff00, 0, 0, 0, 0, 0, 0, 0)), 0); + + let sock_v4 = TcpSocket::new(IpAddressFamily::Ipv4); + let sock_v6 = TcpSocket::new(IpAddressFamily::Ipv6); + + assert!(matches!( + sock_v4.bind(ipv4_broadcast), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!( + sock_v4.bind(ipv4_multicast), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!( + sock_v6.bind(ipv6_multicast), + Err(ErrorCode::InvalidArgument) + )); +} + +fn test_tcp_bind_dual_stack() { + let sock = TcpSocket::new(IpAddressFamily::Ipv6); + let addr = IpSocketAddress::new(IpAddress::IPV4_MAPPED_LOOPBACK, 0); + + // Binding an IPv4-mapped-IPv6 address on a ipv6-only socket should fail: + assert!(matches!(sock.bind(addr), Err(ErrorCode::InvalidArgument))); +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + const RESERVED_IPV4_ADDRESS: IpAddress = IpAddress::Ipv4((192, 0, 2, 0)); // Reserved for documentation and examples. + const RESERVED_IPV6_ADDRESS: IpAddress = + IpAddress::Ipv6((0x2001, 0x0db8, 0, 0, 0, 0, 0, 0)); // Reserved for documentation and examples. + + test_tcp_bind_ephemeral_port(IpAddress::IPV4_LOOPBACK); + test_tcp_bind_ephemeral_port(IpAddress::IPV6_LOOPBACK); + test_tcp_bind_ephemeral_port(IpAddress::IPV4_UNSPECIFIED); + test_tcp_bind_ephemeral_port(IpAddress::IPV6_UNSPECIFIED); + + test_tcp_bind_specific_port(IpAddress::IPV4_LOOPBACK); + test_tcp_bind_specific_port(IpAddress::IPV6_LOOPBACK); + test_tcp_bind_specific_port(IpAddress::IPV4_UNSPECIFIED); + test_tcp_bind_specific_port(IpAddress::IPV6_UNSPECIFIED); + + test_tcp_bind_reuseaddr(IpAddress::IPV4_LOOPBACK).await; + test_tcp_bind_reuseaddr(IpAddress::IPV6_LOOPBACK).await; + + test_tcp_bind_addrinuse(IpAddress::IPV4_LOOPBACK); + test_tcp_bind_addrinuse(IpAddress::IPV6_LOOPBACK); + test_tcp_bind_addrinuse(IpAddress::IPV4_UNSPECIFIED); + test_tcp_bind_addrinuse(IpAddress::IPV6_UNSPECIFIED); + + test_tcp_bind_addrnotavail(RESERVED_IPV4_ADDRESS); + test_tcp_bind_addrnotavail(RESERVED_IPV6_ADDRESS); + + test_tcp_bind_wrong_family(IpAddressFamily::Ipv4); + test_tcp_bind_wrong_family(IpAddressFamily::Ipv6); + + test_tcp_bind_non_unicast(); + + test_tcp_bind_dual_stack(); + + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/bin/p3_sockets_tcp_connect.rs b/crates/test-programs/src/bin/p3_sockets_tcp_connect.rs new file mode 100644 index 000000000000..a8be14003ea9 --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_tcp_connect.rs @@ -0,0 +1,146 @@ +use futures::join; +use test_programs::p3::wasi::sockets::types::{ + ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket, +}; + +struct Component; + +test_programs::p3::export!(Component); + +const SOME_PORT: u16 = 47; // If the tests pass, this will never actually be connected to. + +/// `0.0.0.0` / `::` is not a valid remote address in WASI. +async fn test_tcp_connect_unspec(family: IpAddressFamily) { + let addr = IpSocketAddress::new(IpAddress::new_unspecified(family), SOME_PORT); + let sock = TcpSocket::new(family); + + assert_eq!(sock.connect(addr).await, Err(ErrorCode::InvalidArgument)); +} + +/// 0 is not a valid remote port. +async fn test_tcp_connect_port_0(family: IpAddressFamily) { + let addr = IpSocketAddress::new(IpAddress::new_loopback(family), 0); + let sock = TcpSocket::new(family); + + assert_eq!(sock.connect(addr).await, Err(ErrorCode::InvalidArgument)); +} + +/// Connect should validate the address family. +async fn test_tcp_connect_wrong_family(family: IpAddressFamily) { + let wrong_ip = match family { + IpAddressFamily::Ipv4 => IpAddress::IPV6_LOOPBACK, + IpAddressFamily::Ipv6 => IpAddress::IPV4_LOOPBACK, + }; + let remote_addr = IpSocketAddress::new(wrong_ip, SOME_PORT); + + let sock = TcpSocket::new(family); + + assert_eq!( + sock.connect(remote_addr).await, + Err(ErrorCode::InvalidArgument) + ); +} + +/// Can only connect to unicast addresses. +async fn test_tcp_connect_non_unicast() { + let ipv4_broadcast = IpSocketAddress::new(IpAddress::IPV4_BROADCAST, SOME_PORT); + let ipv4_multicast = IpSocketAddress::new(IpAddress::Ipv4((224, 254, 0, 0)), SOME_PORT); + let ipv6_multicast = + IpSocketAddress::new(IpAddress::Ipv6((0xff00, 0, 0, 0, 0, 0, 0, 0)), SOME_PORT); + + let sock_v4 = TcpSocket::new(IpAddressFamily::Ipv4); + let sock_v6 = TcpSocket::new(IpAddressFamily::Ipv6); + + assert_eq!( + sock_v4.connect(ipv4_broadcast).await, + Err(ErrorCode::InvalidArgument) + ); + assert_eq!( + sock_v4.connect(ipv4_multicast).await, + Err(ErrorCode::InvalidArgument) + ); + assert_eq!( + sock_v6.connect(ipv6_multicast).await, + Err(ErrorCode::InvalidArgument) + ); +} + +async fn test_tcp_connect_dual_stack() { + // Set-up: + let v4_listener = TcpSocket::new(IpAddressFamily::Ipv4); + v4_listener + .bind(IpSocketAddress::new(IpAddress::IPV4_LOOPBACK, 0)) + .unwrap(); + v4_listener.listen().unwrap(); + + let v4_listener_addr = v4_listener.local_address().unwrap(); + let v6_listener_addr = + IpSocketAddress::new(IpAddress::IPV4_MAPPED_LOOPBACK, v4_listener_addr.port()); + + let v6_client = TcpSocket::new(IpAddressFamily::Ipv6); + + // Tests: + + // Connecting to an IPv4 address on an IPv6 socket should fail: + assert_eq!( + v6_client.connect(v4_listener_addr).await, + Err(ErrorCode::InvalidArgument) + ); + // Connecting to an IPv4-mapped-IPv6 address on an IPv6 socket should fail: + assert_eq!( + v6_client.connect(v6_listener_addr).await, + Err(ErrorCode::InvalidArgument) + ); +} + +/// Client sockets can be explicitly bound. +async fn test_tcp_connect_explicit_bind(family: IpAddressFamily) { + let ip = IpAddress::new_loopback(family); + + let (listener, mut accept) = { + let bind_address = IpSocketAddress::new(ip, 0); + let listener = TcpSocket::new(family); + listener.bind(bind_address).unwrap(); + let accept = listener.listen().unwrap(); + (listener, accept) + }; + + let listener_address = listener.local_address().unwrap(); + let client = TcpSocket::new(family); + + // Manually bind the client: + client.bind(IpSocketAddress::new(ip, 0)).unwrap(); + + // Connect should work: + join!( + async { + client.connect(listener_address).await.unwrap(); + }, + async { + accept.next().await.unwrap(); + } + ); +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + test_tcp_connect_unspec(IpAddressFamily::Ipv4).await; + test_tcp_connect_unspec(IpAddressFamily::Ipv6).await; + + test_tcp_connect_port_0(IpAddressFamily::Ipv4).await; + test_tcp_connect_port_0(IpAddressFamily::Ipv6).await; + + test_tcp_connect_wrong_family(IpAddressFamily::Ipv4).await; + test_tcp_connect_wrong_family(IpAddressFamily::Ipv6).await; + + test_tcp_connect_non_unicast().await; + + test_tcp_connect_dual_stack().await; + + test_tcp_connect_explicit_bind(IpAddressFamily::Ipv4).await; + test_tcp_connect_explicit_bind(IpAddressFamily::Ipv6).await; + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/bin/p3_sockets_tcp_sample_application.rs b/crates/test-programs/src/bin/p3_sockets_tcp_sample_application.rs new file mode 100644 index 000000000000..bc2d717d3dd9 --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_tcp_sample_application.rs @@ -0,0 +1,108 @@ +use futures::join; +use test_programs::p3::wasi::sockets::types::{ + IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress, TcpSocket, +}; +use test_programs::p3::wit_stream; +use wit_bindgen_rt::async_support::StreamResult; + +struct Component; + +test_programs::p3::export!(Component); + +async fn test_tcp_sample_application(family: IpAddressFamily, bind_address: IpSocketAddress) { + let first_message = b"Hello, world!"; + let second_message = b"Greetings, planet!"; + + let listener = TcpSocket::new(family); + + listener.bind(bind_address).unwrap(); + listener.set_listen_backlog_size(32).unwrap(); + let mut accept = listener.listen().unwrap(); + + let addr = listener.local_address().unwrap(); + + join!( + async { + let client = TcpSocket::new(family); + client.connect(addr).await.unwrap(); + let (mut data_tx, data_rx) = wit_stream::new(); + join!( + async { + client.send(data_rx).await.unwrap(); + }, + async { + let (result, _) = data_tx.write(vec![]).await; + assert_eq!(result, StreamResult::Complete(0)); + let remaining = data_tx.write_all(first_message.into()).await; + assert!(remaining.is_empty()); + drop(data_tx); + } + ); + }, + async { + let sock = accept.next().await.unwrap(); + let (mut data_rx, fut) = sock.receive(); + let (result, data) = data_rx.read(Vec::with_capacity(100)).await; + assert_eq!(result, StreamResult::Complete(first_message.len())); + + // Check that we sent and received our message! + assert_eq!(data, first_message); // Not guaranteed to work but should work in practice. + fut.await.unwrap() + }, + ); + + // Another client + join!( + async { + let client = TcpSocket::new(family); + client.connect(addr).await.unwrap(); + let (mut data_tx, data_rx) = wit_stream::new(); + join!( + async { + client.send(data_rx).await.unwrap(); + }, + async { + let remaining = data_tx.write_all(second_message.into()).await; + assert!(remaining.is_empty()); + drop(data_tx); + } + ); + }, + async { + let sock = accept.next().await.unwrap(); + let (mut data_rx, fut) = sock.receive(); + let (result, data) = data_rx.read(Vec::with_capacity(100)).await; + assert_eq!(result, StreamResult::Complete(second_message.len())); + + // Check that we sent and received our message! + assert_eq!(data, second_message); // Not guaranteed to work but should work in practice. + fut.await.unwrap() + } + ); +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + test_tcp_sample_application( + IpAddressFamily::Ipv4, + IpSocketAddress::Ipv4(Ipv4SocketAddress { + port: 0, // use any free port + address: (127, 0, 0, 1), // localhost + }), + ) + .await; + test_tcp_sample_application( + IpAddressFamily::Ipv6, + IpSocketAddress::Ipv6(Ipv6SocketAddress { + port: 0, // use any free port + address: (0, 0, 0, 0, 0, 0, 0, 1), // localhost + flow_info: 0, + scope_id: 0, + }), + ) + .await; + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/bin/p3_sockets_tcp_sockopts.rs b/crates/test-programs/src/bin/p3_sockets_tcp_sockopts.rs new file mode 100644 index 000000000000..b1636fd06889 --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_tcp_sockopts.rs @@ -0,0 +1,238 @@ +use futures::join; +use test_programs::p3::wasi::sockets::types::{ + ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket, +}; + +struct Component; + +test_programs::p3::export!(Component); + +const SECOND: u64 = 1_000_000_000; + +fn test_tcp_sockopt_defaults(family: IpAddressFamily) { + let sock = TcpSocket::new(family); + + assert_eq!(sock.address_family(), family); + + sock.keep_alive_enabled().unwrap(); // Only verify that it has a default value at all, but either value is valid. + assert!(sock.keep_alive_idle_time().unwrap() > 0); + assert!(sock.keep_alive_interval().unwrap() > 0); + assert!(sock.keep_alive_count().unwrap() > 0); + assert!(sock.hop_limit().unwrap() > 0); + assert!(sock.receive_buffer_size().unwrap() > 0); + assert!(sock.send_buffer_size().unwrap() > 0); +} + +fn test_tcp_sockopt_input_ranges(family: IpAddressFamily) { + let sock = TcpSocket::new(family); + + assert!(matches!( + sock.set_listen_backlog_size(0), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!(sock.set_listen_backlog_size(1), Ok(_))); // Unsupported sizes should be silently capped. + assert!(matches!(sock.set_listen_backlog_size(u64::MAX), Ok(_))); // Unsupported sizes should be silently capped. + + assert!(matches!(sock.set_keep_alive_enabled(true), Ok(_))); + assert!(matches!(sock.set_keep_alive_enabled(false), Ok(_))); + + assert!(matches!( + sock.set_keep_alive_idle_time(0), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!(sock.set_keep_alive_idle_time(1), Ok(_))); // Unsupported sizes should be silently clamped. + let idle_time = sock.keep_alive_idle_time().unwrap(); // Check that the special 0/reset behavior was not triggered by the previous line. + assert!(idle_time > 0 && idle_time <= 1 * SECOND); + assert!(matches!(sock.set_keep_alive_idle_time(u64::MAX), Ok(_))); // Unsupported sizes should be silently clamped. + + assert!(matches!( + sock.set_keep_alive_interval(0), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!(sock.set_keep_alive_interval(1), Ok(_))); // Unsupported sizes should be silently clamped. + let idle_time = sock.keep_alive_interval().unwrap(); // Check that the special 0/reset behavior was not triggered by the previous line. + assert!(idle_time > 0 && idle_time <= 1 * SECOND); + assert!(matches!(sock.set_keep_alive_interval(u64::MAX), Ok(_))); // Unsupported sizes should be silently clamped. + + assert!(matches!( + sock.set_keep_alive_count(0), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!(sock.set_keep_alive_count(1), Ok(_))); // Unsupported sizes should be silently clamped. + assert!(matches!(sock.set_keep_alive_count(u32::MAX), Ok(_))); // Unsupported sizes should be silently clamped. + + assert!(matches!( + sock.set_hop_limit(0), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!(sock.set_hop_limit(1), Ok(_))); + assert!(matches!(sock.set_hop_limit(u8::MAX), Ok(_))); + + assert!(matches!( + sock.set_receive_buffer_size(0), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!(sock.set_receive_buffer_size(1), Ok(_))); // Unsupported sizes should be silently capped. + assert!(matches!(sock.set_receive_buffer_size(u64::MAX), Ok(_))); // Unsupported sizes should be silently capped. + assert!(matches!( + sock.set_send_buffer_size(0), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!(sock.set_send_buffer_size(1), Ok(_))); // Unsupported sizes should be silently capped. + assert!(matches!(sock.set_send_buffer_size(u64::MAX), Ok(_))); // Unsupported sizes should be silently capped. +} + +fn test_tcp_sockopt_readback(family: IpAddressFamily) { + let sock = TcpSocket::new(family); + + sock.set_keep_alive_enabled(true).unwrap(); + assert_eq!(sock.keep_alive_enabled().unwrap(), true); + sock.set_keep_alive_enabled(false).unwrap(); + assert_eq!(sock.keep_alive_enabled().unwrap(), false); + + sock.set_keep_alive_idle_time(42 * SECOND).unwrap(); + assert_eq!(sock.keep_alive_idle_time().unwrap(), 42 * SECOND); + + sock.set_keep_alive_interval(42 * SECOND).unwrap(); + assert_eq!(sock.keep_alive_interval().unwrap(), 42 * SECOND); + + sock.set_keep_alive_count(42).unwrap(); + assert_eq!(sock.keep_alive_count().unwrap(), 42); + + sock.set_hop_limit(42).unwrap(); + assert_eq!(sock.hop_limit().unwrap(), 42); + + sock.set_receive_buffer_size(0x10000).unwrap(); + assert_eq!(sock.receive_buffer_size().unwrap(), 0x10000); + + sock.set_send_buffer_size(0x10000).unwrap(); + assert_eq!(sock.send_buffer_size().unwrap(), 0x10000); +} + +async fn test_tcp_sockopt_inheritance(family: IpAddressFamily) { + let bind_addr = IpSocketAddress::new(IpAddress::new_loopback(family), 0); + let listener = TcpSocket::new(family); + + let default_keep_alive = listener.keep_alive_enabled().unwrap(); + + // Configure options on listener: + { + listener + .set_keep_alive_enabled(!default_keep_alive) + .unwrap(); + listener.set_keep_alive_idle_time(42 * SECOND).unwrap(); + listener.set_keep_alive_interval(42 * SECOND).unwrap(); + listener.set_keep_alive_count(42).unwrap(); + listener.set_hop_limit(42).unwrap(); + listener.set_receive_buffer_size(0x10000).unwrap(); + listener.set_send_buffer_size(0x10000).unwrap(); + } + + listener.bind(bind_addr).unwrap(); + let mut accept = listener.listen().unwrap(); + let bound_addr = listener.local_address().unwrap(); + let client = TcpSocket::new(family); + let ((), sock) = join!( + async { + client.connect(bound_addr).await.unwrap(); + }, + async { accept.next().await.unwrap() } + ); + + // Verify options on accepted socket: + { + assert_eq!(sock.keep_alive_enabled().unwrap(), !default_keep_alive); + assert_eq!(sock.keep_alive_idle_time().unwrap(), 42 * SECOND); + assert_eq!(sock.keep_alive_interval().unwrap(), 42 * SECOND); + assert_eq!(sock.keep_alive_count().unwrap(), 42); + assert_eq!(sock.hop_limit().unwrap(), 42); + assert_eq!(sock.receive_buffer_size().unwrap(), 0x10000); + assert_eq!(sock.send_buffer_size().unwrap(), 0x10000); + } + + // Update options on listener to something else: + { + listener.set_keep_alive_enabled(default_keep_alive).unwrap(); + listener.set_keep_alive_idle_time(43 * SECOND).unwrap(); + listener.set_keep_alive_interval(43 * SECOND).unwrap(); + listener.set_keep_alive_count(43).unwrap(); + listener.set_hop_limit(43).unwrap(); + listener.set_receive_buffer_size(0x20000).unwrap(); + listener.set_send_buffer_size(0x20000).unwrap(); + } + + // Verify that the already accepted socket was not affected: + { + assert_eq!(sock.keep_alive_enabled().unwrap(), !default_keep_alive); + assert_eq!(sock.keep_alive_idle_time().unwrap(), 42 * SECOND); + assert_eq!(sock.keep_alive_interval().unwrap(), 42 * SECOND); + assert_eq!(sock.keep_alive_count().unwrap(), 42); + assert_eq!(sock.hop_limit().unwrap(), 42); + assert_eq!(sock.receive_buffer_size().unwrap(), 0x10000); + assert_eq!(sock.send_buffer_size().unwrap(), 0x10000); + } +} + +async fn test_tcp_sockopt_after_listen(family: IpAddressFamily) { + let bind_addr = IpSocketAddress::new(IpAddress::new_loopback(family), 0); + let listener = TcpSocket::new(family); + listener.bind(bind_addr).unwrap(); + let mut accept = listener.listen().unwrap(); + let bound_addr = listener.local_address().unwrap(); + + let default_keep_alive = listener.keep_alive_enabled().unwrap(); + + // Update options while the socket is already listening: + { + listener + .set_keep_alive_enabled(!default_keep_alive) + .unwrap(); + listener.set_keep_alive_idle_time(42 * SECOND).unwrap(); + listener.set_keep_alive_interval(42 * SECOND).unwrap(); + listener.set_keep_alive_count(42).unwrap(); + listener.set_hop_limit(42).unwrap(); + listener.set_receive_buffer_size(0x10000).unwrap(); + listener.set_send_buffer_size(0x10000).unwrap(); + } + + let client = TcpSocket::new(family); + let ((), sock) = join!( + async { + client.connect(bound_addr).await.unwrap(); + }, + async { accept.next().await.unwrap() } + ); + + // Verify options on accepted socket: + { + assert_eq!(sock.keep_alive_enabled().unwrap(), !default_keep_alive); + assert_eq!(sock.keep_alive_idle_time().unwrap(), 42 * SECOND); + assert_eq!(sock.keep_alive_interval().unwrap(), 42 * SECOND); + assert_eq!(sock.keep_alive_count().unwrap(), 42); + assert_eq!(sock.hop_limit().unwrap(), 42); + assert_eq!(sock.receive_buffer_size().unwrap(), 0x10000); + assert_eq!(sock.send_buffer_size().unwrap(), 0x10000); + } +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + test_tcp_sockopt_defaults(IpAddressFamily::Ipv4); + test_tcp_sockopt_defaults(IpAddressFamily::Ipv6); + + test_tcp_sockopt_input_ranges(IpAddressFamily::Ipv4); + test_tcp_sockopt_input_ranges(IpAddressFamily::Ipv6); + + test_tcp_sockopt_readback(IpAddressFamily::Ipv4); + test_tcp_sockopt_readback(IpAddressFamily::Ipv6); + + test_tcp_sockopt_inheritance(IpAddressFamily::Ipv4).await; + test_tcp_sockopt_inheritance(IpAddressFamily::Ipv6).await; + + test_tcp_sockopt_after_listen(IpAddressFamily::Ipv4).await; + test_tcp_sockopt_after_listen(IpAddressFamily::Ipv6).await; + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/bin/p3_sockets_tcp_states.rs b/crates/test-programs/src/bin/p3_sockets_tcp_states.rs new file mode 100644 index 000000000000..060a25ededa5 --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_tcp_states.rs @@ -0,0 +1,216 @@ +use futures::join; +use test_programs::p3::wasi::sockets::types::{ + ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket, +}; + +struct Component; + +test_programs::p3::export!(Component); + +fn test_tcp_unbound_state_invariants(family: IpAddressFamily) { + let sock = TcpSocket::new(family); + + // TODO: Test send and receive + //assert!(matches!( + // sock.shutdown(ShutdownType::Both), + // Err(ErrorCode::InvalidState) + //)); + assert_eq!(sock.local_address(), Err(ErrorCode::InvalidState)); + assert_eq!(sock.remote_address(), Err(ErrorCode::InvalidState)); + assert!(!sock.is_listening()); + assert_eq!(sock.address_family(), family); + + assert_eq!(sock.set_listen_backlog_size(32), Ok(())); + + assert!(sock.keep_alive_enabled().is_ok()); + assert_eq!(sock.set_keep_alive_enabled(false), Ok(())); + assert_eq!(sock.keep_alive_enabled(), Ok(false)); + + assert!(sock.keep_alive_idle_time().is_ok()); + assert_eq!(sock.set_keep_alive_idle_time(1), Ok(())); + + assert!(sock.keep_alive_interval().is_ok()); + assert_eq!(sock.set_keep_alive_interval(1), Ok(())); + + assert!(sock.keep_alive_count().is_ok()); + assert_eq!(sock.set_keep_alive_count(1), Ok(())); + + assert!(sock.hop_limit().is_ok()); + assert_eq!(sock.set_hop_limit(255), Ok(())); + assert_eq!(sock.hop_limit(), Ok(255)); + + assert!(sock.receive_buffer_size().is_ok()); + assert_eq!(sock.set_receive_buffer_size(16000), Ok(())); + + assert!(sock.send_buffer_size().is_ok()); + assert_eq!(sock.set_send_buffer_size(16000), Ok(())); +} + +fn test_tcp_bound_state_invariants(family: IpAddressFamily) { + let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0); + let sock = TcpSocket::new(family); + sock.bind(bind_address).unwrap(); + + assert_eq!(sock.bind(bind_address), Err(ErrorCode::InvalidState)); + // TODO: Test send and receive + //assert!(matches!( + // sock.shutdown(ShutdownType::Both), + // Err(ErrorCode::InvalidState) + //)); + + assert!(sock.local_address().is_ok()); + assert_eq!(sock.remote_address(), Err(ErrorCode::InvalidState)); + assert!(!sock.is_listening()); + assert_eq!(sock.address_family(), family); + + assert_eq!(sock.set_listen_backlog_size(32), Ok(())); + + assert!(sock.keep_alive_enabled().is_ok()); + assert_eq!(sock.set_keep_alive_enabled(false), Ok(())); + assert_eq!(sock.keep_alive_enabled(), Ok(false)); + + assert!(sock.keep_alive_idle_time().is_ok()); + assert_eq!(sock.set_keep_alive_idle_time(1), Ok(())); + + assert!(sock.keep_alive_interval().is_ok()); + assert_eq!(sock.set_keep_alive_interval(1), Ok(())); + + assert!(sock.keep_alive_count().is_ok()); + assert_eq!(sock.set_keep_alive_count(1), Ok(())); + + assert!(sock.hop_limit().is_ok()); + assert_eq!(sock.set_hop_limit(255), Ok(())); + assert_eq!(sock.hop_limit(), Ok(255)); + + assert!(sock.receive_buffer_size().is_ok()); + assert_eq!(sock.set_receive_buffer_size(16000), Ok(())); + + assert!(sock.send_buffer_size().is_ok()); + assert_eq!(sock.set_send_buffer_size(16000), Ok(())); +} + +async fn test_tcp_listening_state_invariants(family: IpAddressFamily) { + let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0); + let sock = TcpSocket::new(family); + sock.bind(bind_address).unwrap(); + sock.listen().unwrap(); + + assert_eq!(sock.bind(bind_address), Err(ErrorCode::InvalidState)); + assert_eq!( + sock.connect(IpSocketAddress::new(IpAddress::new_loopback(family), 1)) + .await, + Err(ErrorCode::InvalidState) + ); + assert!(matches!(sock.listen(), Err(ErrorCode::InvalidState))); + // Skipping: tcp::accept + // TODO: Test send and receive + //assert!(matches!( + // sock.shutdown(ShutdownType::Both), + // Err(ErrorCode::InvalidState) + //)); + + assert!(sock.local_address().is_ok()); + assert_eq!(sock.remote_address(), Err(ErrorCode::InvalidState)); + assert!(sock.is_listening()); + assert_eq!(sock.address_family(), family); + + assert!(matches!( + sock.set_listen_backlog_size(32), + Ok(_) | Err(ErrorCode::NotSupported) + )); + + assert!(sock.keep_alive_enabled().is_ok()); + assert_eq!(sock.set_keep_alive_enabled(false), Ok(())); + assert_eq!(sock.keep_alive_enabled(), Ok(false)); + + assert!(sock.keep_alive_idle_time().is_ok()); + assert_eq!(sock.set_keep_alive_idle_time(1), Ok(())); + + assert!(sock.keep_alive_interval().is_ok()); + assert_eq!(sock.set_keep_alive_interval(1), Ok(())); + + assert!(sock.keep_alive_count().is_ok()); + assert_eq!(sock.set_keep_alive_count(1), Ok(())); + + assert!(sock.hop_limit().is_ok()); + assert_eq!(sock.set_hop_limit(255), Ok(())); + assert_eq!(sock.hop_limit(), Ok(255)); + + assert!(sock.receive_buffer_size().is_ok()); + assert_eq!(sock.set_receive_buffer_size(16000), Ok(())); + + assert!(sock.send_buffer_size().is_ok()); + assert_eq!(sock.set_send_buffer_size(16000), Ok(())); +} + +async fn test_tcp_connected_state_invariants(family: IpAddressFamily) { + let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0); + let sock_listener = TcpSocket::new(family); + sock_listener.bind(bind_address).unwrap(); + let mut accept = sock_listener.listen().unwrap(); + let addr_listener = sock_listener.local_address().unwrap(); + let sock = TcpSocket::new(family); + join!( + async { + sock.connect(addr_listener).await.unwrap(); + }, + async { + accept.next().await.unwrap(); + } + ); + + assert_eq!(sock.bind(bind_address), Err(ErrorCode::InvalidState)); + assert_eq!( + sock.connect(addr_listener).await, + Err(ErrorCode::InvalidState) + ); + assert!(matches!(sock.listen(), Err(ErrorCode::InvalidState))); + // Skipping: tcp::shutdown + + assert!(sock.local_address().is_ok()); + assert!(sock.remote_address().is_ok()); + assert!(!sock.is_listening()); + assert_eq!(sock.address_family(), family); + + assert!(sock.keep_alive_enabled().is_ok()); + assert_eq!(sock.set_keep_alive_enabled(false), Ok(())); + assert_eq!(sock.keep_alive_enabled(), Ok(false)); + + assert!(sock.keep_alive_idle_time().is_ok()); + assert_eq!(sock.set_keep_alive_idle_time(1), Ok(())); + + assert!(sock.keep_alive_interval().is_ok()); + assert_eq!(sock.set_keep_alive_interval(1), Ok(())); + + assert!(sock.keep_alive_count().is_ok()); + assert_eq!(sock.set_keep_alive_count(1), Ok(())); + + assert!(sock.hop_limit().is_ok()); + assert_eq!(sock.set_hop_limit(255), Ok(())); + assert_eq!(sock.hop_limit(), Ok(255)); + + assert!(sock.receive_buffer_size().is_ok()); + assert_eq!(sock.set_receive_buffer_size(16000), Ok(())); + + assert!(sock.send_buffer_size().is_ok()); + assert_eq!(sock.set_send_buffer_size(16000), Ok(())); +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + test_tcp_unbound_state_invariants(IpAddressFamily::Ipv4); + test_tcp_unbound_state_invariants(IpAddressFamily::Ipv6); + + test_tcp_bound_state_invariants(IpAddressFamily::Ipv4); + test_tcp_bound_state_invariants(IpAddressFamily::Ipv6); + + test_tcp_listening_state_invariants(IpAddressFamily::Ipv4).await; + test_tcp_listening_state_invariants(IpAddressFamily::Ipv6).await; + + test_tcp_connected_state_invariants(IpAddressFamily::Ipv4).await; + test_tcp_connected_state_invariants(IpAddressFamily::Ipv6).await; + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/bin/p3_sockets_tcp_streams.rs b/crates/test-programs/src/bin/p3_sockets_tcp_streams.rs new file mode 100644 index 000000000000..0aaf63b06524 --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_tcp_streams.rs @@ -0,0 +1,151 @@ +use core::future::Future; + +use futures::join; +use test_programs::p3::wasi::sockets::types::{ + IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket, +}; +use test_programs::p3::wit_stream; + +struct Component; + +test_programs::p3::export!(Component); + +/// InputStream::read should return `StreamError::Closed` after the connection has been shut down by the server. +async fn test_tcp_input_stream_should_be_closed_by_remote_shutdown(family: IpAddressFamily) { + setup(family, |server, client| async move { + drop(server); + + let (mut client_rx, client_fut) = client.receive(); + + // The input stream should immediately signal StreamError::Closed. + // Notably, it should _not_ return an empty list (the wasi-io equivalent of EWOULDBLOCK) + // See: https://github.com/bytecodealliance/wasmtime/pull/8968 + + // Wait for the shutdown signal to reach the client: + assert!(client_rx.next().await.is_none()); + assert_eq!(client_fut.await, Ok(())); + }) + .await; +} + +/// InputStream::read should return `StreamError::Closed` after the connection has been shut down locally. +async fn test_tcp_input_stream_should_be_closed_by_local_shutdown(family: IpAddressFamily) { + setup(family, |server, client| async move { + let (mut server_tx, server_rx) = wit_stream::new(); + join!( + async { + server.send(server_rx).await.unwrap(); + }, + async { + // On Linux, `recv` continues to work even after `shutdown(sock, SHUT_RD)` + // has been called. To properly test that this behavior doesn't happen in + // WASI, we make sure there's some data to read by the client: + let rest = server_tx.write_all(b"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.".into()).await; + assert!(rest.is_empty()); + drop(server_tx); + }, + ); + + let (client_rx, client_fut) = client.receive(); + + // Shut down socket locally: + drop(client_rx); + // Wait for the shutdown signal to reach the client: + assert_eq!(client_fut.await, Ok(())); + }).await; +} + +/// StreamWriter should return `StreamError::Closed` after the connection has been locally shut down for sending. +async fn test_tcp_output_stream_should_be_closed_by_local_shutdown(family: IpAddressFamily) { + setup(family, |_server, client| async move { + let (client_tx, client_rx) = wit_stream::new(); + join!( + async { + client.send(client_rx).await.unwrap(); + }, + async { + // TODO: Verify if send on the stream should return an error + //assert!(client_tx.send(b"Hi!".into()).await.is_err()); + drop(client_tx); + } + ); + }) + .await; +} + +/// Calling `shutdown` while the StreamWriter is in the middle of a background write should not cause that write to be lost. +async fn test_tcp_shutdown_should_not_lose_data(family: IpAddressFamily) { + setup(family, |server, client| async move { + // Minimize the local send buffer: + client.set_send_buffer_size(1024).unwrap(); + let small_buffer_size = client.send_buffer_size().unwrap(); + + // Create a significantly bigger buffer, so that we can be pretty sure the `write` won't finish immediately: + let big_buffer_size = 100 * small_buffer_size; + assert!(big_buffer_size > small_buffer_size); + let outgoing_data = vec![0; big_buffer_size as usize]; + + // Submit the oversized buffer and immediately initiate the shutdown: + let (mut client_tx, client_rx) = wit_stream::new(); + join!( + async { + client.send(client_rx).await.unwrap(); + }, + async { + let ret = client_tx.write_all(outgoing_data.clone()).await; + assert!(ret.is_empty()); + drop(client_tx); + }, + async { + // The peer should receive _all_ data: + let (server_rx, server_fut) = server.receive(); + let incoming_data = server_rx.collect().await; + assert_eq!( + outgoing_data, incoming_data, + "Received data should match the sent data" + ); + server_fut.await.unwrap(); + }, + ); + }) + .await; +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + test_tcp_input_stream_should_be_closed_by_remote_shutdown(IpAddressFamily::Ipv4).await; + test_tcp_input_stream_should_be_closed_by_remote_shutdown(IpAddressFamily::Ipv6).await; + + test_tcp_input_stream_should_be_closed_by_local_shutdown(IpAddressFamily::Ipv4).await; + test_tcp_input_stream_should_be_closed_by_local_shutdown(IpAddressFamily::Ipv6).await; + + test_tcp_output_stream_should_be_closed_by_local_shutdown(IpAddressFamily::Ipv4).await; + test_tcp_output_stream_should_be_closed_by_local_shutdown(IpAddressFamily::Ipv6).await; + + test_tcp_shutdown_should_not_lose_data(IpAddressFamily::Ipv4).await; + test_tcp_shutdown_should_not_lose_data(IpAddressFamily::Ipv6).await; + Ok(()) + } +} + +fn main() {} + +/// Set up a connected pair of sockets +async fn setup>( + family: IpAddressFamily, + body: impl FnOnce(TcpSocket, TcpSocket) -> Fut, +) { + let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0); + let listener = TcpSocket::new(family); + listener.bind(bind_address).unwrap(); + let mut accept = listener.listen().unwrap(); + let bound_address = listener.local_address().unwrap(); + let client_socket = TcpSocket::new(family); + let ((), accepted_socket) = join!( + async { + client_socket.connect(bound_address).await.unwrap(); + }, + async { accept.next().await.unwrap() }, + ); + body(accepted_socket, client_socket).await; +} diff --git a/crates/test-programs/src/bin/p3_sockets_udp_bind.rs b/crates/test-programs/src/bin/p3_sockets_udp_bind.rs new file mode 100644 index 000000000000..eab803a4c0fb --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_udp_bind.rs @@ -0,0 +1,116 @@ +use test_programs::p3::sockets::attempt_random_port; +use test_programs::p3::wasi::sockets::types::{ + ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, UdpSocket, +}; + +struct Component; + +test_programs::p3::export!(Component); + +/// Bind a socket and let the system determine a port. +fn test_udp_bind_ephemeral_port(ip: IpAddress) { + let bind_addr = IpSocketAddress::new(ip, 0); + + let sock = UdpSocket::new(ip.family()); + sock.bind(bind_addr).unwrap(); + + let bound_addr = sock.local_address().unwrap(); + + assert_eq!(bind_addr.ip(), bound_addr.ip()); + assert_ne!(bind_addr.port(), bound_addr.port()); +} + +/// Bind a socket on a specified port. +fn test_udp_bind_specific_port(ip: IpAddress) { + let sock = UdpSocket::new(ip.family()); + + let bind_addr = attempt_random_port(ip, |bind_addr| sock.bind(bind_addr)).unwrap(); + + let bound_addr = sock.local_address().unwrap(); + + assert_eq!(bind_addr.ip(), bound_addr.ip()); + assert_eq!(bind_addr.port(), bound_addr.port()); +} + +/// Two sockets may not be actively bound to the same address at the same time. +fn test_udp_bind_addrinuse(ip: IpAddress) { + let bind_addr = IpSocketAddress::new(ip, 0); + + let sock1 = UdpSocket::new(ip.family()); + sock1.bind(bind_addr).unwrap(); + + let bound_addr = sock1.local_address().unwrap(); + + let sock2 = UdpSocket::new(ip.family()); + assert!(matches!( + sock2.bind(bound_addr), + Err(ErrorCode::AddressInUse) + )); +} + +// Try binding to an address that is not configured on the system. +fn test_udp_bind_addrnotavail(ip: IpAddress) { + let bind_addr = IpSocketAddress::new(ip, 0); + + let sock = UdpSocket::new(ip.family()); + + assert!(matches!( + sock.bind(bind_addr), + Err(ErrorCode::AddressNotBindable) + )); +} + +/// Bind should validate the address family. +fn test_udp_bind_wrong_family(family: IpAddressFamily) { + let wrong_ip = match family { + IpAddressFamily::Ipv4 => IpAddress::IPV6_LOOPBACK, + IpAddressFamily::Ipv6 => IpAddress::IPV4_LOOPBACK, + }; + + let sock = UdpSocket::new(family); + let result = sock.bind(IpSocketAddress::new(wrong_ip, 0)); + + assert!(matches!(result, Err(ErrorCode::InvalidArgument))); +} + +fn test_udp_bind_dual_stack() { + let sock = UdpSocket::new(IpAddressFamily::Ipv6); + let addr = IpSocketAddress::new(IpAddress::IPV4_MAPPED_LOOPBACK, 0); + + // Binding an IPv4-mapped-IPv6 address on a ipv6-only socket should fail: + assert!(matches!(sock.bind(addr), Err(ErrorCode::InvalidArgument))); +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + const RESERVED_IPV4_ADDRESS: IpAddress = IpAddress::Ipv4((192, 0, 2, 0)); // Reserved for documentation and examples. + const RESERVED_IPV6_ADDRESS: IpAddress = + IpAddress::Ipv6((0x2001, 0x0db8, 0, 0, 0, 0, 0, 0)); // Reserved for documentation and examples. + + test_udp_bind_ephemeral_port(IpAddress::IPV4_LOOPBACK); + test_udp_bind_ephemeral_port(IpAddress::IPV6_LOOPBACK); + test_udp_bind_ephemeral_port(IpAddress::IPV4_UNSPECIFIED); + test_udp_bind_ephemeral_port(IpAddress::IPV6_UNSPECIFIED); + + test_udp_bind_specific_port(IpAddress::IPV4_LOOPBACK); + test_udp_bind_specific_port(IpAddress::IPV6_LOOPBACK); + test_udp_bind_specific_port(IpAddress::IPV4_UNSPECIFIED); + test_udp_bind_specific_port(IpAddress::IPV6_UNSPECIFIED); + + test_udp_bind_addrinuse(IpAddress::IPV4_LOOPBACK); + test_udp_bind_addrinuse(IpAddress::IPV6_LOOPBACK); + test_udp_bind_addrinuse(IpAddress::IPV4_UNSPECIFIED); + test_udp_bind_addrinuse(IpAddress::IPV6_UNSPECIFIED); + + test_udp_bind_addrnotavail(RESERVED_IPV4_ADDRESS); + test_udp_bind_addrnotavail(RESERVED_IPV6_ADDRESS); + + test_udp_bind_wrong_family(IpAddressFamily::Ipv4); + test_udp_bind_wrong_family(IpAddressFamily::Ipv6); + + test_udp_bind_dual_stack(); + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/bin/p3_sockets_udp_connect.rs b/crates/test-programs/src/bin/p3_sockets_udp_connect.rs new file mode 100644 index 000000000000..3791d25c9153 --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_udp_connect.rs @@ -0,0 +1,131 @@ +use test_programs::p3::wasi::sockets::types::{ + ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, UdpSocket, +}; + +struct Component; + +test_programs::p3::export!(Component); + +const SOME_PORT: u16 = 47; // If the tests pass, this will never actually be connected to. + +fn test_udp_connect_disconnect_reconnect(family: IpAddressFamily) { + let unspecified_addr = IpSocketAddress::new(IpAddress::new_unspecified(family), 0); + let remote1 = IpSocketAddress::new(IpAddress::new_loopback(family), 4321); + let remote2 = IpSocketAddress::new(IpAddress::new_loopback(family), 4320); + + let client = UdpSocket::new(family); + client.bind(unspecified_addr).unwrap(); + + assert_eq!(client.disconnect(), Err(ErrorCode::InvalidState)); + assert_eq!(client.remote_address(), Err(ErrorCode::InvalidState)); + + assert_eq!(client.disconnect(), Err(ErrorCode::InvalidState)); + assert_eq!(client.remote_address(), Err(ErrorCode::InvalidState)); + + _ = client.connect(remote1).unwrap(); + assert_eq!(client.remote_address(), Ok(remote1)); + + _ = client.connect(remote1).unwrap(); + assert_eq!(client.remote_address(), Ok(remote1)); + + _ = client.connect(remote2).unwrap(); + assert_eq!(client.remote_address(), Ok(remote2)); + + _ = client.disconnect().unwrap(); + assert_eq!(client.remote_address(), Err(ErrorCode::InvalidState)); + + _ = client.connect(remote1).unwrap(); + assert_eq!(client.remote_address(), Ok(remote1)); +} + +/// `0.0.0.0` / `::` is not a valid remote address in WASI. +fn test_udp_connect_unspec(family: IpAddressFamily) { + let ip = IpAddress::new_unspecified(family); + let addr = IpSocketAddress::new(ip, SOME_PORT); + let sock = UdpSocket::new(family); + sock.bind_unspecified().unwrap(); + + assert!(matches!( + sock.connect(addr), + Err(ErrorCode::InvalidArgument) + )); +} + +/// 0 is not a valid remote port. +fn test_udp_connect_port_0(family: IpAddressFamily) { + let addr = IpSocketAddress::new(IpAddress::new_loopback(family), 0); + let sock = UdpSocket::new(family); + sock.bind_unspecified().unwrap(); + + assert!(matches!( + sock.connect(addr), + Err(ErrorCode::InvalidArgument) + )); +} + +/// Connect should validate the address family. +fn test_udp_connect_wrong_family(family: IpAddressFamily) { + let wrong_ip = match family { + IpAddressFamily::Ipv4 => IpAddress::IPV6_LOOPBACK, + IpAddressFamily::Ipv6 => IpAddress::IPV4_LOOPBACK, + }; + let remote_addr = IpSocketAddress::new(wrong_ip, SOME_PORT); + + let sock = UdpSocket::new(family); + sock.bind_unspecified().unwrap(); + + assert!(matches!( + sock.connect(remote_addr), + Err(ErrorCode::InvalidArgument) + )); +} + +fn test_udp_connect_dual_stack() { + // Set-up: + let v4_server = UdpSocket::new(IpAddressFamily::Ipv4); + v4_server + .bind(IpSocketAddress::new(IpAddress::IPV4_LOOPBACK, 0)) + .unwrap(); + + let v4_server_addr = v4_server.local_address().unwrap(); + let v6_server_addr = + IpSocketAddress::new(IpAddress::IPV4_MAPPED_LOOPBACK, v4_server_addr.port()); + + // Tests: + let v6_client = UdpSocket::new(IpAddressFamily::Ipv6); + + v6_client.bind_unspecified().unwrap(); + + // Connecting to an IPv4 address on an IPv6 socket should fail: + assert!(matches!( + v6_client.connect(v4_server_addr), + Err(ErrorCode::InvalidArgument) + )); + + // Connecting to an IPv4-mapped-IPv6 address on an IPv6 socket should fail: + assert!(matches!( + v6_client.connect(v6_server_addr), + Err(ErrorCode::InvalidArgument) + )); +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + test_udp_connect_disconnect_reconnect(IpAddressFamily::Ipv4); + test_udp_connect_disconnect_reconnect(IpAddressFamily::Ipv6); + + test_udp_connect_unspec(IpAddressFamily::Ipv4); + test_udp_connect_unspec(IpAddressFamily::Ipv6); + + test_udp_connect_port_0(IpAddressFamily::Ipv4); + test_udp_connect_port_0(IpAddressFamily::Ipv6); + + test_udp_connect_wrong_family(IpAddressFamily::Ipv4); + test_udp_connect_wrong_family(IpAddressFamily::Ipv6); + + test_udp_connect_dual_stack(); + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/bin/p3_sockets_udp_sample_application.rs b/crates/test-programs/src/bin/p3_sockets_udp_sample_application.rs new file mode 100644 index 000000000000..35526cf35a58 --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_udp_sample_application.rs @@ -0,0 +1,87 @@ +use futures::join; +use test_programs::p3::wasi::sockets::types::{ + IpAddress, IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress, UdpSocket, +}; + +struct Component; + +test_programs::p3::export!(Component); + +async fn test_udp_sample_application(family: IpAddressFamily, bind_address: IpSocketAddress) { + let unspecified_addr = IpSocketAddress::new(IpAddress::new_unspecified(family), 0); + + let first_message = &[]; + let second_message = b"Hello, world!"; + let third_message = b"Greetings, planet!"; + + let server = UdpSocket::new(family); + + server.bind(bind_address).unwrap(); + let addr = server.local_address().unwrap(); + + let client = UdpSocket::new(family); + client.bind(unspecified_addr).unwrap(); + client.connect(addr).unwrap(); + let client_addr = client.local_address().unwrap(); + join!( + async { + client.send(first_message.to_vec(), None).await.unwrap(); + client + .send(second_message.to_vec(), Some(addr)) + .await + .unwrap(); + }, + async { + // Check that we've received our sent messages. + let (buf, addr) = server.receive().await.unwrap(); + assert_eq!(buf, first_message); + assert_eq!(addr, client_addr); + + let (buf, addr) = server.receive().await.unwrap(); + assert_eq!(buf, second_message); + assert_eq!(addr, client_addr); + } + ); + join!( + async { + // Another client + let client = UdpSocket::new(family); + client.bind(unspecified_addr).unwrap(); + client + .send(third_message.to_vec(), Some(addr)) + .await + .unwrap(); + }, + async { + // Check that we sent and received our message! + let (buf, _) = server.receive().await.unwrap(); + assert_eq!(buf, third_message); + }, + ); +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + test_udp_sample_application( + IpAddressFamily::Ipv4, + IpSocketAddress::Ipv4(Ipv4SocketAddress { + port: 0, // use any free port + address: (127, 0, 0, 1), // localhost + }), + ) + .await; + test_udp_sample_application( + IpAddressFamily::Ipv6, + IpSocketAddress::Ipv6(Ipv6SocketAddress { + port: 0, // use any free port + address: (0, 0, 0, 0, 0, 0, 0, 1), // localhost + flow_info: 0, + scope_id: 0, + }), + ) + .await; + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/bin/p3_sockets_udp_sockopts.rs b/crates/test-programs/src/bin/p3_sockets_udp_sockopts.rs new file mode 100644 index 000000000000..7330428e8225 --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_udp_sockopts.rs @@ -0,0 +1,68 @@ +use test_programs::p3::wasi::sockets::types::{ErrorCode, IpAddressFamily, UdpSocket}; + +struct Component; + +test_programs::p3::export!(Component); + +fn test_udp_sockopt_defaults(family: IpAddressFamily) { + let sock = UdpSocket::new(family); + + assert_eq!(sock.address_family(), family); + + assert!(sock.unicast_hop_limit().unwrap() > 0); + assert!(sock.receive_buffer_size().unwrap() > 0); + assert!(sock.send_buffer_size().unwrap() > 0); +} + +fn test_udp_sockopt_input_ranges(family: IpAddressFamily) { + let sock = UdpSocket::new(family); + + assert!(matches!( + sock.set_unicast_hop_limit(0), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!(sock.set_unicast_hop_limit(1), Ok(_))); + assert!(matches!(sock.set_unicast_hop_limit(u8::MAX), Ok(_))); + + assert!(matches!( + sock.set_receive_buffer_size(0), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!(sock.set_receive_buffer_size(1), Ok(_))); // Unsupported sizes should be silently capped. + assert!(matches!(sock.set_receive_buffer_size(u64::MAX), Ok(_))); // Unsupported sizes should be silently capped. + assert!(matches!( + sock.set_send_buffer_size(0), + Err(ErrorCode::InvalidArgument) + )); + assert!(matches!(sock.set_send_buffer_size(1), Ok(_))); // Unsupported sizes should be silently capped. + assert!(matches!(sock.set_send_buffer_size(u64::MAX), Ok(_))); // Unsupported sizes should be silently capped. +} + +fn test_udp_sockopt_readback(family: IpAddressFamily) { + let sock = UdpSocket::new(family); + + sock.set_unicast_hop_limit(42).unwrap(); + assert_eq!(sock.unicast_hop_limit().unwrap(), 42); + + sock.set_receive_buffer_size(0x10000).unwrap(); + assert_eq!(sock.receive_buffer_size().unwrap(), 0x10000); + + sock.set_send_buffer_size(0x10000).unwrap(); + assert_eq!(sock.send_buffer_size().unwrap(), 0x10000); +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + test_udp_sockopt_defaults(IpAddressFamily::Ipv4); + test_udp_sockopt_defaults(IpAddressFamily::Ipv6); + + test_udp_sockopt_input_ranges(IpAddressFamily::Ipv4); + test_udp_sockopt_input_ranges(IpAddressFamily::Ipv6); + + test_udp_sockopt_readback(IpAddressFamily::Ipv4); + test_udp_sockopt_readback(IpAddressFamily::Ipv6); + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/bin/p3_sockets_udp_states.rs b/crates/test-programs/src/bin/p3_sockets_udp_states.rs new file mode 100644 index 000000000000..45da9552e630 --- /dev/null +++ b/crates/test-programs/src/bin/p3_sockets_udp_states.rs @@ -0,0 +1,96 @@ +use test_programs::p3::wasi::sockets::types::{ + ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, UdpSocket, +}; + +struct Component; + +test_programs::p3::export!(Component); + +async fn test_udp_unbound_state_invariants(family: IpAddressFamily) { + let sock = UdpSocket::new(family); + + // Skipping: udp::start_bind + + assert_eq!( + sock.send(b"test".into(), None).await, + Err(ErrorCode::InvalidArgument) + ); + assert_eq!(sock.disconnect(), Err(ErrorCode::InvalidState)); + assert_eq!(sock.local_address(), Err(ErrorCode::InvalidState)); + assert_eq!(sock.remote_address(), Err(ErrorCode::InvalidState)); + assert_eq!(sock.address_family(), family); + + assert!(matches!(sock.unicast_hop_limit(), Ok(_))); + assert!(matches!(sock.set_unicast_hop_limit(255), Ok(_))); + assert!(matches!(sock.receive_buffer_size(), Ok(_))); + assert!(matches!(sock.set_receive_buffer_size(16000), Ok(_))); + assert!(matches!(sock.send_buffer_size(), Ok(_))); + assert!(matches!(sock.set_send_buffer_size(16000), Ok(_))); +} + +fn test_udp_bound_state_invariants(family: IpAddressFamily) { + let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0); + let sock = UdpSocket::new(family); + sock.bind(bind_address).unwrap(); + + assert!(matches!( + sock.bind(bind_address), + Err(ErrorCode::InvalidState) + )); + // Skipping: udp::stream + + assert!(matches!(sock.local_address(), Ok(_))); + assert!(matches!( + sock.remote_address(), + Err(ErrorCode::InvalidState) + )); + assert_eq!(sock.address_family(), family); + + assert!(matches!(sock.unicast_hop_limit(), Ok(_))); + assert!(matches!(sock.set_unicast_hop_limit(255), Ok(_))); + assert!(matches!(sock.receive_buffer_size(), Ok(_))); + assert!(matches!(sock.set_receive_buffer_size(16000), Ok(_))); + assert!(matches!(sock.send_buffer_size(), Ok(_))); + assert!(matches!(sock.set_send_buffer_size(16000), Ok(_))); +} + +fn test_udp_connected_state_invariants(family: IpAddressFamily) { + let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0); + let connect_address = IpSocketAddress::new(IpAddress::new_loopback(family), 54321); + let sock = UdpSocket::new(family); + sock.bind(bind_address).unwrap(); + sock.connect(connect_address).unwrap(); + + assert!(matches!( + sock.bind(bind_address), + Err(ErrorCode::InvalidState) + )); + // Skipping: udp::stream + + assert!(matches!(sock.local_address(), Ok(_))); + assert!(matches!(sock.remote_address(), Ok(_))); + assert_eq!(sock.address_family(), family); + + assert!(matches!(sock.unicast_hop_limit(), Ok(_))); + assert!(matches!(sock.set_unicast_hop_limit(255), Ok(_))); + assert!(matches!(sock.receive_buffer_size(), Ok(_))); + assert!(matches!(sock.set_receive_buffer_size(16000), Ok(_))); + assert!(matches!(sock.send_buffer_size(), Ok(_))); + assert!(matches!(sock.set_send_buffer_size(16000), Ok(_))); +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + test_udp_unbound_state_invariants(IpAddressFamily::Ipv4).await; + test_udp_unbound_state_invariants(IpAddressFamily::Ipv6).await; + + test_udp_bound_state_invariants(IpAddressFamily::Ipv4); + test_udp_bound_state_invariants(IpAddressFamily::Ipv6); + + test_udp_connected_state_invariants(IpAddressFamily::Ipv4); + test_udp_connected_state_invariants(IpAddressFamily::Ipv6); + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/p3/mod.rs b/crates/test-programs/src/p3/mod.rs index 09b94a1498af..c6bd6e3fdd1e 100644 --- a/crates/test-programs/src/p3/mod.rs +++ b/crates/test-programs/src/p3/mod.rs @@ -1,3 +1,5 @@ +pub mod sockets; + wit_bindgen::generate!({ inline: " package wasmtime:test; diff --git a/crates/test-programs/src/p3/sockets.rs b/crates/test-programs/src/p3/sockets.rs new file mode 100644 index 000000000000..654b9491d7ee --- /dev/null +++ b/crates/test-programs/src/p3/sockets.rs @@ -0,0 +1,157 @@ +use core::ops::Range; + +use crate::p3::wasi::random; +use crate::p3::wasi::sockets::types::{ + ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress, + UdpSocket, +}; + +impl IpAddress { + pub const IPV4_BROADCAST: IpAddress = IpAddress::Ipv4((255, 255, 255, 255)); + + pub const IPV4_LOOPBACK: IpAddress = IpAddress::Ipv4((127, 0, 0, 1)); + pub const IPV6_LOOPBACK: IpAddress = IpAddress::Ipv6((0, 0, 0, 0, 0, 0, 0, 1)); + + pub const IPV4_UNSPECIFIED: IpAddress = IpAddress::Ipv4((0, 0, 0, 0)); + pub const IPV6_UNSPECIFIED: IpAddress = IpAddress::Ipv6((0, 0, 0, 0, 0, 0, 0, 0)); + + pub const IPV4_MAPPED_LOOPBACK: IpAddress = + IpAddress::Ipv6((0, 0, 0, 0, 0, 0xFFFF, 0x7F00, 0x0001)); + + pub const fn new_loopback(family: IpAddressFamily) -> IpAddress { + match family { + IpAddressFamily::Ipv4 => Self::IPV4_LOOPBACK, + IpAddressFamily::Ipv6 => Self::IPV6_LOOPBACK, + } + } + + pub const fn new_unspecified(family: IpAddressFamily) -> IpAddress { + match family { + IpAddressFamily::Ipv4 => Self::IPV4_UNSPECIFIED, + IpAddressFamily::Ipv6 => Self::IPV6_UNSPECIFIED, + } + } + + pub const fn family(&self) -> IpAddressFamily { + match self { + IpAddress::Ipv4(_) => IpAddressFamily::Ipv4, + IpAddress::Ipv6(_) => IpAddressFamily::Ipv6, + } + } +} + +impl PartialEq for IpAddress { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Ipv4(left), Self::Ipv4(right)) => left == right, + (Self::Ipv6(left), Self::Ipv6(right)) => left == right, + _ => false, + } + } +} + +impl IpSocketAddress { + pub const fn new(ip: IpAddress, port: u16) -> IpSocketAddress { + match ip { + IpAddress::Ipv4(addr) => IpSocketAddress::Ipv4(Ipv4SocketAddress { + port, + address: addr, + }), + IpAddress::Ipv6(addr) => IpSocketAddress::Ipv6(Ipv6SocketAddress { + port, + address: addr, + flow_info: 0, + scope_id: 0, + }), + } + } + + pub const fn ip(&self) -> IpAddress { + match self { + IpSocketAddress::Ipv4(addr) => IpAddress::Ipv4(addr.address), + IpSocketAddress::Ipv6(addr) => IpAddress::Ipv6(addr.address), + } + } + + pub const fn port(&self) -> u16 { + match self { + IpSocketAddress::Ipv4(addr) => addr.port, + IpSocketAddress::Ipv6(addr) => addr.port, + } + } + + pub const fn family(&self) -> IpAddressFamily { + match self { + IpSocketAddress::Ipv4(_) => IpAddressFamily::Ipv4, + IpSocketAddress::Ipv6(_) => IpAddressFamily::Ipv6, + } + } +} + +impl PartialEq for Ipv4SocketAddress { + fn eq(&self, other: &Self) -> bool { + self.port == other.port && self.address == other.address + } +} + +impl PartialEq for Ipv6SocketAddress { + fn eq(&self, other: &Self) -> bool { + self.port == other.port + && self.flow_info == other.flow_info + && self.address == other.address + && self.scope_id == other.scope_id + } +} + +impl PartialEq for IpSocketAddress { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Ipv4(l0), Self::Ipv4(r0)) => l0 == r0, + (Self::Ipv6(l0), Self::Ipv6(r0)) => l0 == r0, + _ => false, + } + } +} + +fn generate_random_u16(range: Range) -> u16 { + let start = range.start as u64; + let end = range.end as u64; + let port = start + (random::random::get_random_u64() % (end - start)); + port as u16 +} + +/// Execute the inner function with a randomly generated port. +/// To prevent random failures, we make a few attempts before giving up. +pub fn attempt_random_port( + local_address: IpAddress, + mut f: F, +) -> Result +where + F: FnMut(IpSocketAddress) -> Result<(), ErrorCode>, +{ + const MAX_ATTEMPTS: u32 = 10; + let mut i = 0; + loop { + i += 1; + + let port: u16 = generate_random_u16(1024..u16::MAX); + let sock_addr = IpSocketAddress::new(local_address, port); + + match f(sock_addr) { + Ok(_) => return Ok(sock_addr), + Err(e) if i >= MAX_ATTEMPTS => return Err(e), + // Try again if the port is already taken. This can sometimes show up as `AccessDenied` on Windows. + Err(ErrorCode::AddressInUse | ErrorCode::AccessDenied) => {} + Err(e) => return Err(e), + } + } +} + +impl UdpSocket { + pub fn bind_unspecified(&self) -> Result<(), ErrorCode> { + let ip = IpAddress::new_unspecified(self.address_family()); + let port = 0; + + self.bind(IpSocketAddress::new(ip, port)) + } +} diff --git a/crates/wasi/src/clocks.rs b/crates/wasi/src/clocks.rs index bb030e1d1c32..b08e3c3340bb 100644 --- a/crates/wasi/src/clocks.rs +++ b/crates/wasi/src/clocks.rs @@ -2,6 +2,24 @@ use cap_std::time::{Duration, Instant, SystemClock}; use cap_std::{AmbientAuthority, ambient_authority}; use cap_time_ext::{MonotonicClockExt as _, SystemClockExt as _}; +pub struct WasiClocksCtx { + pub wall_clock: Box, + pub monotonic_clock: Box, +} + +impl Default for WasiClocksCtx { + fn default() -> Self { + Self { + wall_clock: wall_clock(), + monotonic_clock: monotonic_clock(), + } + } +} + +pub trait WasiClocksView: Send { + fn clocks(&mut self) -> &mut WasiClocksCtx; +} + impl WasiClocksView for &mut T { fn clocks(&mut self) -> &mut WasiClocksCtx { T::clocks(self) @@ -20,24 +38,6 @@ impl WasiClocksView for WasiClocksCtx { } } -pub trait WasiClocksView: Send { - fn clocks(&mut self) -> &mut WasiClocksCtx; -} - -pub struct WasiClocksCtx { - pub wall_clock: Box, - pub monotonic_clock: Box, -} - -impl Default for WasiClocksCtx { - fn default() -> Self { - Self { - wall_clock: wall_clock(), - monotonic_clock: monotonic_clock(), - } - } -} - pub trait HostWallClock: Send { fn resolution(&self) -> Duration; fn now(&self) -> Duration; diff --git a/crates/wasi/src/ctx.rs b/crates/wasi/src/ctx.rs index e2cec900297e..a13adf7d6a3a 100644 --- a/crates/wasi/src/ctx.rs +++ b/crates/wasi/src/ctx.rs @@ -1,7 +1,7 @@ use crate::cli::WasiCliCtx; use crate::clocks::{HostMonotonicClock, HostWallClock, WasiClocksCtx}; -use crate::net::{SocketAddrCheck, SocketAddrUse}; use crate::random::WasiRandomCtx; +use crate::sockets::{SocketAddrCheck, SocketAddrUse, WasiSocketsCtx}; use cap_rand::RngCore; use std::future::Future; use std::net::SocketAddr; @@ -19,11 +19,10 @@ use std::sync::Arc; /// [`Store`]: wasmtime::Store #[derive(Default)] pub(crate) struct WasiCtxBuilder { - pub(crate) random: WasiRandomCtx, - pub(crate) clocks: WasiClocksCtx, pub(crate) cli: WasiCliCtx, - pub(crate) socket_addr_check: SocketAddrCheck, - pub(crate) allowed_network_uses: AllowedNetworkUses, + pub(crate) clocks: WasiClocksCtx, + pub(crate) random: WasiRandomCtx, + pub(crate) sockets: WasiSocketsCtx, pub(crate) allow_blocking_current_thread: bool, } @@ -46,8 +45,6 @@ impl WasiCtxBuilder { /// These defaults can all be updated via the various builder configuration /// methods below. pub(crate) fn new(stdin: I, stdout: O, stderr: O) -> Self { - let random = WasiRandomCtx::default(); - let clocks = WasiClocksCtx::default(); let cli = WasiCliCtx { environment: Vec::default(), arguments: Vec::default(), @@ -56,12 +53,14 @@ impl WasiCtxBuilder { stdout, stderr, }; + let clocks = WasiClocksCtx::default(); + let random = WasiRandomCtx::default(); + let sockets = WasiSocketsCtx::default(); Self { - random, - clocks, cli, - socket_addr_check: SocketAddrCheck::default(), - allowed_network_uses: AllowedNetworkUses::default(), + clocks, + random, + sockets, allow_blocking_current_thread: false, } } @@ -241,7 +240,7 @@ impl WasiCtxBuilder { + Sync + 'static, { - self.socket_addr_check = SocketAddrCheck(Arc::new(check)); + self.sockets.socket_addr_check = SocketAddrCheck(Arc::new(check)); self } @@ -249,7 +248,7 @@ impl WasiCtxBuilder { /// /// By default this is disabled. pub fn allow_ip_name_lookup(&mut self, enable: bool) -> &mut Self { - self.allowed_network_uses.ip_name_lookup = enable; + self.sockets.allowed_network_uses.ip_name_lookup = enable; self } @@ -258,7 +257,7 @@ impl WasiCtxBuilder { /// This is enabled by default, but can be disabled if UDP should be blanket /// disabled. pub fn allow_udp(&mut self, enable: bool) -> &mut Self { - self.allowed_network_uses.udp = enable; + self.sockets.allowed_network_uses.udp = enable; self } @@ -267,47 +266,7 @@ impl WasiCtxBuilder { /// This is enabled by default, but can be disabled if TCP should be blanket /// disabled. pub fn allow_tcp(&mut self, enable: bool) -> &mut Self { - self.allowed_network_uses.tcp = enable; + self.sockets.allowed_network_uses.tcp = enable; self } } - -pub struct AllowedNetworkUses { - pub ip_name_lookup: bool, - pub udp: bool, - pub tcp: bool, -} - -impl Default for AllowedNetworkUses { - fn default() -> Self { - Self { - ip_name_lookup: false, - udp: true, - tcp: true, - } - } -} - -impl AllowedNetworkUses { - pub(crate) fn check_allowed_udp(&self) -> std::io::Result<()> { - if !self.udp { - return Err(std::io::Error::new( - std::io::ErrorKind::PermissionDenied, - "UDP is not allowed", - )); - } - - Ok(()) - } - - pub(crate) fn check_allowed_tcp(&self) -> std::io::Result<()> { - if !self.tcp { - return Err(std::io::Error::new( - std::io::ErrorKind::PermissionDenied, - "TCP is not allowed", - )); - } - - Ok(()) - } -} diff --git a/crates/wasi/src/fs.rs b/crates/wasi/src/filesystem.rs similarity index 100% rename from crates/wasi/src/fs.rs rename to crates/wasi/src/filesystem.rs diff --git a/crates/wasi/src/lib.rs b/crates/wasi/src/lib.rs index 57ac8e4ac09f..53d098b38817 100644 --- a/crates/wasi/src/lib.rs +++ b/crates/wasi/src/lib.rs @@ -16,8 +16,7 @@ pub mod cli; pub mod clocks; mod ctx; mod error; -mod fs; -mod net; +mod filesystem; pub mod p2; #[cfg(feature = "p3")] pub mod p3; @@ -27,13 +26,14 @@ pub mod preview0; pub mod preview1; pub mod random; pub mod runtime; +pub mod sockets; pub use self::clocks::{HostMonotonicClock, HostWallClock}; pub(crate) use self::ctx::WasiCtxBuilder; pub use self::error::{I32Exit, TrappableError}; -pub use self::fs::{DirPerms, FilePerms, OpenMode}; -pub use self::net::{Network, SocketAddrUse}; +pub use self::filesystem::{DirPerms, FilePerms, OpenMode}; pub use self::random::{Deterministic, thread_rng}; +pub use self::sockets::{AllowedNetworkUses, SocketAddrUse}; #[doc(no_inline)] pub use async_trait::async_trait; #[doc(no_inline)] diff --git a/crates/wasi/src/net.rs b/crates/wasi/src/net.rs deleted file mode 100644 index d0d4aa374cb9..000000000000 --- a/crates/wasi/src/net.rs +++ /dev/null @@ -1,72 +0,0 @@ -use std::future::Future; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; - -/// Value taken from rust std library. -pub const DEFAULT_TCP_BACKLOG: u32 = 128; - -pub struct Network { - pub socket_addr_check: SocketAddrCheck, - pub allow_ip_name_lookup: bool, -} - -impl Network { - pub async fn check_socket_addr( - &self, - addr: SocketAddr, - reason: SocketAddrUse, - ) -> std::io::Result<()> { - self.socket_addr_check.check(addr, reason).await - } -} - -/// A check that will be called for each socket address that is used of whether the address is permitted. -#[derive(Clone)] -pub struct SocketAddrCheck( - pub(crate) Arc< - dyn Fn(SocketAddr, SocketAddrUse) -> Pin + Send + Sync>> - + Send - + Sync, - >, -); - -impl SocketAddrCheck { - pub async fn check(&self, addr: SocketAddr, reason: SocketAddrUse) -> std::io::Result<()> { - if (self.0)(addr, reason).await { - Ok(()) - } else { - Err(std::io::Error::new( - std::io::ErrorKind::PermissionDenied, - "An address was not permitted by the socket address check.", - )) - } - } -} - -impl Default for SocketAddrCheck { - fn default() -> Self { - Self(Arc::new(|_, _| Box::pin(async { false }))) - } -} - -/// The reason what a socket address is being used for. -#[derive(Clone, Copy, Debug)] -pub enum SocketAddrUse { - /// Binding TCP socket - TcpBind, - /// Connecting TCP socket - TcpConnect, - /// Binding UDP socket - UdpBind, - /// Connecting UDP socket - UdpConnect, - /// Sending datagram on non-connected UDP socket - UdpOutgoingDatagram, -} - -#[derive(Copy, Clone)] -pub enum SocketAddressFamily { - Ipv4, - Ipv6, -} diff --git a/crates/wasi/src/p2/bindings.rs b/crates/wasi/src/p2/bindings.rs index 60308e924cc0..7cfdd1956e4f 100644 --- a/crates/wasi/src/p2/bindings.rs +++ b/crates/wasi/src/p2/bindings.rs @@ -396,7 +396,7 @@ mod async_io { // Configure all other resources to be concrete types defined in // this crate - "wasi:sockets/network/network": crate::net::Network, + "wasi:sockets/network/network": crate::p2::network::Network, "wasi:sockets/tcp/tcp-socket": crate::p2::tcp::TcpSocket, "wasi:sockets/udp/udp-socket": crate::p2::udp::UdpSocket, "wasi:sockets/udp/incoming-datagram-stream": crate::p2::udp::IncomingDatagramStream, diff --git a/crates/wasi/src/p2/ctx.rs b/crates/wasi/src/p2/ctx.rs index 6d49d4c44f4d..92e0fa5f37a9 100644 --- a/crates/wasi/src/p2/ctx.rs +++ b/crates/wasi/src/p2/ctx.rs @@ -1,11 +1,10 @@ use crate::cli::WasiCliCtx; use crate::clocks::{HostMonotonicClock, HostWallClock, WasiClocksCtx}; -use crate::ctx::AllowedNetworkUses; -use crate::net::{SocketAddrCheck, SocketAddrUse}; use crate::p2::filesystem::Dir; use crate::p2::pipe; use crate::p2::stdio::{self, StdinStream, StdoutStream}; use crate::random::WasiRandomCtx; +use crate::sockets::{AllowedNetworkUses, SocketAddrCheck, SocketAddrUse, WasiSocketsCtx}; use crate::{DirPerms, FilePerms, OpenMode}; use anyhow::Result; use cap_rand::RngCore; @@ -435,17 +434,6 @@ impl WasiCtxBuilder { let Self { common: crate::WasiCtxBuilder { - random: - WasiRandomCtx { - random, - insecure_random, - insecure_random_seed, - }, - clocks: - WasiClocksCtx { - wall_clock, - monotonic_clock, - }, cli: WasiCliCtx { environment: env, @@ -455,8 +443,22 @@ impl WasiCtxBuilder { stdout, stderr, }, - socket_addr_check, - allowed_network_uses, + clocks: + WasiClocksCtx { + wall_clock, + monotonic_clock, + }, + random: + WasiRandomCtx { + random, + insecure_random, + insecure_random_seed, + }, + sockets: + WasiSocketsCtx { + socket_addr_check, + allowed_network_uses, + }, allow_blocking_current_thread, }, preopens, diff --git a/crates/wasi/src/p2/host/instance_network.rs b/crates/wasi/src/p2/host/instance_network.rs index 40af0d34541f..85c5a780184d 100644 --- a/crates/wasi/src/p2/host/instance_network.rs +++ b/crates/wasi/src/p2/host/instance_network.rs @@ -1,5 +1,5 @@ -use crate::net::Network; use crate::p2::bindings::sockets::instance_network; +use crate::p2::network::Network; use crate::p2::{IoView, WasiImpl, WasiView}; use wasmtime::component::Resource; diff --git a/crates/wasi/src/p2/host/network.rs b/crates/wasi/src/p2/host/network.rs index 2b316daa4821..30059aed838f 100644 --- a/crates/wasi/src/p2/host/network.rs +++ b/crates/wasi/src/p2/host/network.rs @@ -2,8 +2,8 @@ use crate::p2::bindings::sockets::network::{ self, ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress, }; -use crate::p2::network::{from_ipv4_addr, from_ipv6_addr, to_ipv4_addr, to_ipv6_addr}; use crate::p2::{IoView, SocketError, WasiImpl, WasiView}; +use crate::sockets::util::{from_ipv4_addr, from_ipv6_addr, to_ipv4_addr, to_ipv6_addr}; use anyhow::Error; use rustix::io::Errno; use std::io; @@ -242,331 +242,3 @@ impl From for IpAddressFamily { } } } - -pub(crate) mod util { - use std::io; - use std::net::{IpAddr, Ipv6Addr, SocketAddr}; - use std::time::Duration; - - use crate::net::SocketAddressFamily; - use cap_net_ext::{AddressFamily, Blocking, UdpSocketExt}; - use rustix::fd::{AsFd, OwnedFd}; - use rustix::io::Errno; - use rustix::net::sockopt; - - pub fn validate_unicast(addr: &SocketAddr) -> io::Result<()> { - match to_canonical(&addr.ip()) { - IpAddr::V4(ipv4) => { - if ipv4.is_multicast() || ipv4.is_broadcast() { - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Both IPv4 broadcast and multicast addresses are not supported", - )) - } else { - Ok(()) - } - } - IpAddr::V6(ipv6) => { - if ipv6.is_multicast() { - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "IPv6 multicast addresses are not supported", - )) - } else { - Ok(()) - } - } - } - } - - pub fn validate_remote_address(addr: &SocketAddr) -> io::Result<()> { - if to_canonical(&addr.ip()).is_unspecified() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Remote address may not be `0.0.0.0` or `::`", - )); - } - - if addr.port() == 0 { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Remote port may not be 0", - )); - } - - Ok(()) - } - - pub fn validate_address_family( - addr: &SocketAddr, - socket_family: &SocketAddressFamily, - ) -> io::Result<()> { - match (socket_family, addr.ip()) { - (SocketAddressFamily::Ipv4, IpAddr::V4(_)) => Ok(()), - (SocketAddressFamily::Ipv6, IpAddr::V6(ipv6)) => { - if is_deprecated_ipv4_compatible(&ipv6) { - // Reject IPv4-*compatible* IPv6 addresses. They have been deprecated - // since 2006, OS handling of them is inconsistent and our own - // validations don't take them into account either. - // Note that these are not the same as IPv4-*mapped* IPv6 addresses. - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "IPv4-compatible IPv6 addresses are not supported", - )) - } else if ipv6.to_ipv4_mapped().is_some() { - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "IPv4-mapped IPv6 address passed to an IPv6-only socket", - )) - } else { - Ok(()) - } - } - _ => Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Address family mismatch", - )), - } - } - - // Can be removed once `IpAddr::to_canonical` becomes stable. - pub fn to_canonical(addr: &IpAddr) -> IpAddr { - match addr { - IpAddr::V4(ipv4) => IpAddr::V4(*ipv4), - IpAddr::V6(ipv6) => { - if let Some(ipv4) = ipv6.to_ipv4_mapped() { - IpAddr::V4(ipv4) - } else { - IpAddr::V6(*ipv6) - } - } - } - } - - fn is_deprecated_ipv4_compatible(addr: &Ipv6Addr) -> bool { - matches!(addr.segments(), [0, 0, 0, 0, 0, 0, _, _]) - && *addr != Ipv6Addr::UNSPECIFIED - && *addr != Ipv6Addr::LOCALHOST - } - - /* - * Syscalls wrappers with (opinionated) portability fixes. - */ - - pub fn udp_socket(family: AddressFamily, blocking: Blocking) -> io::Result { - // Delegate socket creation to cap_net_ext. They handle a couple of things for us: - // - On Windows: call WSAStartup if not done before. - // - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation, - // or afterwards using ioctl or fcntl. Exact method depends on the platform. - - let socket = cap_std::net::UdpSocket::new(family, blocking)?; - Ok(OwnedFd::from(socket)) - } - - pub fn udp_bind(sockfd: Fd, addr: &SocketAddr) -> rustix::io::Result<()> { - rustix::net::bind(sockfd, addr).map_err(|error| match error { - // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS - // Windows returns WSAENOBUFS when the ephemeral ports have been exhausted. - #[cfg(windows)] - Errno::NOBUFS => Errno::ADDRINUSE, - _ => error, - }) - } - - pub fn udp_disconnect(sockfd: Fd) -> rustix::io::Result<()> { - match rustix::net::connect_unspec(sockfd) { - // BSD platforms return an error even if the UDP socket was disconnected successfully. - // - // MacOS was kind enough to document this: https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man2/connect.2.html - // > Datagram sockets may dissolve the association by connecting to an - // > invalid address, such as a null address or an address with the address - // > family set to AF_UNSPEC (the error EAFNOSUPPORT will be harmlessly - // > returned). - // - // ... except that this appears to be incomplete, because experiments - // have shown that MacOS actually returns EINVAL, depending on the - // address family of the socket. - #[cfg(target_os = "macos")] - Err(Errno::INVAL | Errno::AFNOSUPPORT) => Ok(()), - r => r, - } - } - - // Even though SO_REUSEADDR is a SOL_* level option, this function contain a - // compatibility fix specific to TCP. That's why it contains the `_tcp_` infix instead of `_socket_`. - pub fn set_tcp_reuseaddr(sockfd: Fd, value: bool) -> rustix::io::Result<()> { - // When a TCP socket is closed, the system may - // temporarily reserve that specific address+port pair in a so called - // TIME_WAIT state. During that period, any attempt to rebind to that pair - // will fail. Setting SO_REUSEADDR to true bypasses that behaviour. Unlike - // the name "SO_REUSEADDR" might suggest, it does not allow multiple - // active sockets to share the same local address. - - // On Windows that behavior is the default, so there is no need to manually - // configure such an option. But (!), Windows _does_ have an identically - // named socket option which allows users to "hijack" active sockets. - // This is definitely not what we want to do here. - - // Microsoft's own documentation[1] states that we should set SO_EXCLUSIVEADDRUSE - // instead (to the inverse value), however the github issue below[2] seems - // to indicate that that may no longer be correct. - // [1]: https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse - // [2]: https://github.com/python-trio/trio/issues/928 - - #[cfg(not(windows))] - sockopt::set_socket_reuseaddr(sockfd, value)?; - #[cfg(windows)] - let _ = (sockfd, value); - - Ok(()) - } - - pub fn set_tcp_keepidle(sockfd: Fd, value: Duration) -> rustix::io::Result<()> { - if value <= Duration::ZERO { - // WIT: "If the provided value is 0, an `invalid-argument` error is returned." - return Err(Errno::INVAL); - } - - // Ensure that the value passed to the actual syscall never gets rounded down to 0. - const MIN_SECS: u64 = 1; - - // Cap it at Linux' maximum, which appears to have the lowest limit across our supported platforms. - const MAX_SECS: u64 = i16::MAX as u64; - - sockopt::set_tcp_keepidle( - sockfd, - value.clamp(Duration::from_secs(MIN_SECS), Duration::from_secs(MAX_SECS)), - ) - } - - pub fn set_tcp_keepintvl(sockfd: Fd, value: Duration) -> rustix::io::Result<()> { - if value <= Duration::ZERO { - // WIT: "If the provided value is 0, an `invalid-argument` error is returned." - return Err(Errno::INVAL); - } - - // Ensure that any fractional value passed to the actual syscall never gets rounded down to 0. - const MIN_SECS: u64 = 1; - - // Cap it at Linux' maximum, which appears to have the lowest limit across our supported platforms. - const MAX_SECS: u64 = i16::MAX as u64; - - sockopt::set_tcp_keepintvl( - sockfd, - value.clamp(Duration::from_secs(MIN_SECS), Duration::from_secs(MAX_SECS)), - ) - } - - pub fn set_tcp_keepcnt(sockfd: Fd, value: u32) -> rustix::io::Result<()> { - if value == 0 { - // WIT: "If the provided value is 0, an `invalid-argument` error is returned." - return Err(Errno::INVAL); - } - - const MIN_CNT: u32 = 1; - // Cap it at Linux' maximum, which appears to have the lowest limit across our supported platforms. - const MAX_CNT: u32 = i8::MAX as u32; - - sockopt::set_tcp_keepcnt(sockfd, value.clamp(MIN_CNT, MAX_CNT)) - } - - pub fn get_ip_ttl(sockfd: Fd) -> rustix::io::Result { - sockopt::ip_ttl(sockfd)? - .try_into() - .map_err(|_| Errno::OPNOTSUPP) - } - - pub fn get_ipv6_unicast_hops(sockfd: Fd) -> rustix::io::Result { - sockopt::ipv6_unicast_hops(sockfd) - } - - pub fn set_ip_ttl(sockfd: Fd, value: u8) -> rustix::io::Result<()> { - match value { - // WIT: "If the provided value is 0, an `invalid-argument` error is returned." - // - // A well-behaved IP application should never send out new packets with TTL 0. - // We validate the value ourselves because OS'es are not consistent in this. - // On Linux the validation is even inconsistent between their IPv4 and IPv6 implementation. - 0 => Err(Errno::INVAL), - _ => sockopt::set_ip_ttl(sockfd, value.into()), - } - } - - pub fn set_ipv6_unicast_hops(sockfd: Fd, value: u8) -> rustix::io::Result<()> { - match value { - 0 => Err(Errno::INVAL), // See `set_ip_ttl` - _ => sockopt::set_ipv6_unicast_hops(sockfd, Some(value)), - } - } - - fn normalize_get_buffer_size(value: usize) -> usize { - if cfg!(target_os = "linux") { - // Linux doubles the value passed to setsockopt to allow space for bookkeeping overhead. - // getsockopt returns this internally doubled value. - // We'll half the value to at least get it back into the same ballpark that the application requested it in. - // - // This normalized behavior is tested for in: test-programs/src/bin/preview2_tcp_sockopts.rs - value / 2 - } else { - value - } - } - - fn normalize_set_buffer_size(value: usize) -> usize { - value.clamp(1, i32::MAX as usize) - } - - pub fn get_socket_recv_buffer_size(sockfd: Fd) -> rustix::io::Result { - let value = sockopt::socket_recv_buffer_size(sockfd)?; - Ok(normalize_get_buffer_size(value)) - } - - pub fn get_socket_send_buffer_size(sockfd: Fd) -> rustix::io::Result { - let value = sockopt::socket_send_buffer_size(sockfd)?; - Ok(normalize_get_buffer_size(value)) - } - - pub fn set_socket_recv_buffer_size( - sockfd: Fd, - value: usize, - ) -> rustix::io::Result<()> { - if value == 0 { - // WIT: "If the provided value is 0, an `invalid-argument` error is returned." - return Err(Errno::INVAL); - } - - let value = normalize_set_buffer_size(value); - - match sockopt::set_socket_recv_buffer_size(sockfd, value) { - // Most platforms (Linux, Windows, Fuchsia, Solaris, Illumos, Haiku, ESP-IDF, ..and more?) treat the value - // passed to SO_SNDBUF/SO_RCVBUF as a performance tuning hint and silently clamp the input if it exceeds - // their capability. - // As far as I can see, only the *BSD family views this option as a hard requirement and fails when the - // value is out of range. We normalize this behavior in favor of the more commonly understood - // "performance hint" semantics. In other words; even ENOBUFS is "Ok". - // A future improvement could be to query the corresponding sysctl on *BSD platforms and clamp the input - // `size` ourselves, to completely close the gap with other platforms. - // - // This normalized behavior is tested for in: test-programs/src/bin/preview2_tcp_sockopts.rs - Err(Errno::NOBUFS) => Ok(()), - r => r, - } - } - - pub fn set_socket_send_buffer_size( - sockfd: Fd, - value: usize, - ) -> rustix::io::Result<()> { - if value == 0 { - // WIT: "If the provided value is 0, an `invalid-argument` error is returned." - return Err(Errno::INVAL); - } - - let value = normalize_set_buffer_size(value); - - match sockopt::set_socket_send_buffer_size(sockfd, value) { - Err(Errno::NOBUFS) => Ok(()), // See set_socket_recv_buffer_size - r => r, - } - } -} diff --git a/crates/wasi/src/p2/host/tcp.rs b/crates/wasi/src/p2/host/tcp.rs index 1f04fef247b2..f5c723227c60 100644 --- a/crates/wasi/src/p2/host/tcp.rs +++ b/crates/wasi/src/p2/host/tcp.rs @@ -1,9 +1,9 @@ -use crate::net::{SocketAddrUse, SocketAddressFamily}; use crate::p2::bindings::{ sockets::network::{IpAddressFamily, IpSocketAddress, Network}, sockets::tcp::{self, ShutdownType}, }; use crate::p2::{SocketResult, WasiImpl, WasiView}; +use crate::sockets::{SocketAddrUse, SocketAddressFamily}; use std::net::SocketAddr; use std::time::Duration; use wasmtime::component::Resource; @@ -197,8 +197,7 @@ where ) -> SocketResult<()> { let table = self.table(); let socket = table.get_mut(&this)?; - let duration = Duration::from_nanos(value); - socket.set_keep_alive_idle_time(duration) + socket.set_keep_alive_idle_time(value) } fn keep_alive_interval(&mut self, this: Resource) -> SocketResult { @@ -249,7 +248,7 @@ where let table = self.table(); let socket = table.get(&this)?; - Ok(socket.receive_buffer_size()? as u64) + Ok(socket.receive_buffer_size()?) } fn set_receive_buffer_size( @@ -259,7 +258,6 @@ where ) -> SocketResult<()> { let table = self.table(); let socket = table.get_mut(&this)?; - let value = value.try_into().unwrap_or(usize::MAX); socket.set_receive_buffer_size(value) } @@ -267,7 +265,7 @@ where let table = self.table(); let socket = table.get(&this)?; - Ok(socket.send_buffer_size()? as u64) + Ok(socket.send_buffer_size()?) } fn set_send_buffer_size( @@ -277,7 +275,6 @@ where ) -> SocketResult<()> { let table = self.table(); let socket = table.get_mut(&this)?; - let value = value.try_into().unwrap_or(usize::MAX); socket.set_send_buffer_size(value) } diff --git a/crates/wasi/src/p2/host/udp.rs b/crates/wasi/src/p2/host/udp.rs index 84491ff7ddd1..553c422e1811 100644 --- a/crates/wasi/src/p2/host/udp.rs +++ b/crates/wasi/src/p2/host/udp.rs @@ -1,9 +1,13 @@ -use crate::net::{SocketAddrUse, SocketAddressFamily}; use crate::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network}; use crate::p2::bindings::sockets::udp; -use crate::p2::host::network::util; use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState, UdpState}; use crate::p2::{IoView, Pollable, SocketError, SocketResult, WasiImpl, WasiView}; +use crate::sockets::util::{ + get_ip_ttl, get_ipv6_unicast_hops, is_valid_address_family, is_valid_remote_address, + receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size, + set_unicast_hop_limit, udp_bind, udp_disconnect, +}; +use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily}; use anyhow::anyhow; use async_trait::async_trait; use io_lifetimes::AsSocketlike; @@ -13,11 +17,6 @@ use tokio::io::Interest; use wasmtime::component::Resource; use wasmtime_wasi_io::poll::DynPollable; -/// Theoretical maximum byte size of a UDP datagram, the real limit is lower, -/// but we do not account for e.g. the transport layer here for simplicity. -/// In practice, datagrams are typically less than 1500 bytes. -const MAX_UDP_DATAGRAM_SIZE: usize = u16::MAX as usize; - impl udp::Host for WasiImpl where T: WasiView {} impl udp::HostUdpSocket for WasiImpl @@ -49,23 +48,15 @@ where let socket = table.get(&this)?; let local_address: SocketAddr = local_address.into(); - util::validate_address_family(&local_address, &socket.family)?; + if !is_valid_address_family(local_address.ip(), socket.family) { + return Err(ErrorCode::InvalidArgument.into()); + } { check.check(local_address, SocketAddrUse::UdpBind).await?; // Perform the OS bind call. - util::udp_bind(socket.udp_socket(), &local_address).map_err(|error| match error { - // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html: - // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket - // - // The most common reasons for this error should have already - // been handled by our own validation slightly higher up in this - // function. This error mapping is here just in case there is - // an edge case we didn't catch. - Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, - _ => ErrorCode::from(error), - })?; + udp_bind(socket.udp_socket(), local_address)?; } let socket = table.get_mut(&this)?; @@ -121,7 +112,7 @@ where // Step #1: Disconnect if let UdpState::Connected = socket.udp_state { - util::udp_disconnect(socket.udp_socket())?; + udp_disconnect(socket.udp_socket())?; socket.udp_state = UdpState::Bound; } @@ -130,8 +121,11 @@ where let Some(check) = socket.socket_addr_check.as_ref() else { return Err(ErrorCode::InvalidState.into()); }; - util::validate_remote_address(&connect_addr)?; - util::validate_address_family(&connect_addr, &socket.family)?; + if !is_valid_remote_address(connect_addr) + || !is_valid_address_family(connect_addr.ip(), socket.family) + { + return Err(ErrorCode::InvalidArgument.into()); + } check.check(connect_addr, SocketAddrUse::UdpConnect).await?; rustix::net::connect(socket.udp_socket(), &connect_addr).map_err( @@ -218,8 +212,8 @@ where let socket = table.get(&this)?; let ttl = match socket.family { - SocketAddressFamily::Ipv4 => util::get_ip_ttl(socket.udp_socket())?, - SocketAddressFamily::Ipv6 => util::get_ipv6_unicast_hops(socket.udp_socket())?, + SocketAddressFamily::Ipv4 => get_ip_ttl(socket.udp_socket())?, + SocketAddressFamily::Ipv6 => get_ipv6_unicast_hops(socket.udp_socket())?, }; Ok(ttl) @@ -233,10 +227,7 @@ where let table = self.table(); let socket = table.get(&this)?; - match socket.family { - SocketAddressFamily::Ipv4 => util::set_ip_ttl(socket.udp_socket(), value)?, - SocketAddressFamily::Ipv6 => util::set_ipv6_unicast_hops(socket.udp_socket(), value)?, - } + set_unicast_hop_limit(socket.udp_socket(), socket.family, value)?; Ok(()) } @@ -245,8 +236,8 @@ where let table = self.table(); let socket = table.get(&this)?; - let value = util::get_socket_recv_buffer_size(socket.udp_socket())?; - Ok(value as u64) + let value = receive_buffer_size(socket.udp_socket())?; + Ok(value) } fn set_receive_buffer_size( @@ -256,9 +247,8 @@ where ) -> SocketResult<()> { let table = self.table(); let socket = table.get(&this)?; - let value = value.try_into().unwrap_or(usize::MAX); - util::set_socket_recv_buffer_size(socket.udp_socket(), value)?; + set_receive_buffer_size(socket.udp_socket(), value)?; Ok(()) } @@ -266,8 +256,8 @@ where let table = self.table(); let socket = table.get(&this)?; - let value = util::get_socket_send_buffer_size(socket.udp_socket())?; - Ok(value as u64) + let value = send_buffer_size(socket.udp_socket())?; + Ok(value) } fn set_send_buffer_size( @@ -277,9 +267,8 @@ where ) -> SocketResult<()> { let table = self.table(); let socket = table.get(&this)?; - let value = value.try_into().unwrap_or(usize::MAX); - util::set_socket_send_buffer_size(socket.udp_socket(), value)?; + set_send_buffer_size(socket.udp_socket(), value)?; Ok(()) } @@ -448,8 +437,10 @@ where _ => return Err(ErrorCode::InvalidArgument.into()), }; - util::validate_remote_address(&addr)?; - util::validate_address_family(&addr, &stream.family)?; + if !is_valid_remote_address(addr) || !is_valid_address_family(addr.ip(), stream.family) + { + return Err(ErrorCode::InvalidArgument.into()); + } if stream.remote_address == Some(addr) { stream.inner.try_send(&datagram.data)?; diff --git a/crates/wasi/src/p2/ip_name_lookup.rs b/crates/wasi/src/p2/ip_name_lookup.rs index bc3d8a09a0ed..b0252fa7c2c2 100644 --- a/crates/wasi/src/p2/ip_name_lookup.rs +++ b/crates/wasi/src/p2/ip_name_lookup.rs @@ -1,18 +1,16 @@ use crate::p2::bindings::sockets::ip_name_lookup::{Host, HostResolveAddressStream}; use crate::p2::bindings::sockets::network::{ErrorCode, IpAddress, Network}; -use crate::p2::host::network::util; use crate::p2::{IoView, SocketError, WasiImpl, WasiView}; use crate::runtime::{AbortOnDropJoinHandle, spawn_blocking}; use anyhow::Result; use std::mem; -use std::net::{Ipv6Addr, ToSocketAddrs}; +use std::net::ToSocketAddrs; use std::pin::Pin; -use std::str::FromStr; use std::vec; use wasmtime::component::Resource; use wasmtime_wasi_io::poll::{DynPollable, Pollable, subscribe}; -use super::network::{from_ipv4_addr, from_ipv6_addr}; +use crate::sockets::util::{from_ipv4_addr, from_ipv6_addr, parse_host}; pub enum ResolveAddressStream { Waiting(AbortOnDropJoinHandle, SocketError>>), @@ -30,7 +28,7 @@ where ) -> Result, SocketError> { let network = self.table().get(&network)?; - let host = parse(&name)?; + let host = parse_host(&name)?; if !network.allow_ip_name_lookup { return Err(ErrorCode::PermanentResolverFailure.into()); @@ -92,24 +90,6 @@ impl Pollable for ResolveAddressStream { } } -fn parse(name: &str) -> Result { - // `url::Host::parse` serves us two functions: - // 1. validate the input is a valid domain name or IP, - // 2. convert unicode domains to punycode. - match url::Host::parse(&name) { - Ok(host) => Ok(host), - - // `url::Host::parse` doesn't understand bare IPv6 addresses without [brackets] - Err(_) => { - if let Ok(addr) = Ipv6Addr::from_str(name) { - Ok(url::Host::Ipv6(addr)) - } else { - Err(ErrorCode::InvalidArgument.into()) - } - } - } -} - fn blocking_resolve(host: &url::Host) -> Result, SocketError> { match host { url::Host::Ipv4(v4addr) => Ok(vec![IpAddress::Ipv4(from_ipv4_addr(*v4addr))]), @@ -121,7 +101,7 @@ fn blocking_resolve(host: &url::Host) -> Result, SocketError> { let addresses = (domain.as_str(), 0) .to_socket_addrs() .map_err(|_| ErrorCode::NameUnresolvable)? // If/when we use `getaddrinfo` directly, map the error properly. - .map(|addr| util::to_canonical(&addr.ip()).into()) + .map(|addr| addr.ip().to_canonical().into()) .collect(); Ok(addresses) diff --git a/crates/wasi/src/p2/mod.rs b/crates/wasi/src/p2/mod.rs index b31fb1856607..45ca8e3c52a6 100644 --- a/crates/wasi/src/p2/mod.rs +++ b/crates/wasi/src/p2/mod.rs @@ -244,7 +244,7 @@ mod write_stream; pub use self::ctx::{WasiCtx, WasiCtxBuilder}; pub use self::filesystem::{FsError, FsResult}; -pub use self::network::{SocketError, SocketResult}; +pub use self::network::{Network, SocketError, SocketResult}; pub use self::stdio::{ AsyncStdinStream, AsyncStdoutStream, InputFile, IsATTY, OutputFile, Stderr, Stdin, StdinStream, Stdout, StdoutStream, stderr, stdin, stdout, diff --git a/crates/wasi/src/p2/network.rs b/crates/wasi/src/p2/network.rs index 0a3090308ad1..d608624411b8 100644 --- a/crates/wasi/src/p2/network.rs +++ b/crates/wasi/src/p2/network.rs @@ -1,5 +1,8 @@ -use crate::TrappableError; -use crate::p2::bindings::sockets::network::{ErrorCode, Ipv4Address, Ipv6Address}; +use core::net::SocketAddr; + +use crate::p2::bindings::sockets::network::ErrorCode; +use crate::sockets::SocketAddrCheck; +use crate::{SocketAddrUse, TrappableError}; pub type SocketResult = Result; @@ -23,22 +26,44 @@ impl From for SocketError { } } -pub(crate) fn to_ipv4_addr(addr: Ipv4Address) -> std::net::Ipv4Addr { - let (x0, x1, x2, x3) = addr; - std::net::Ipv4Addr::new(x0, x1, x2, x3) +impl From for SocketError { + fn from(error: crate::sockets::util::ErrorCode) -> Self { + ErrorCode::from(error).into() + } } -pub(crate) fn from_ipv4_addr(addr: std::net::Ipv4Addr) -> Ipv4Address { - let [x0, x1, x2, x3] = addr.octets(); - (x0, x1, x2, x3) +impl From for ErrorCode { + fn from(error: crate::sockets::util::ErrorCode) -> Self { + match error { + crate::sockets::util::ErrorCode::Unknown => Self::Unknown, + crate::sockets::util::ErrorCode::AccessDenied => Self::AccessDenied, + crate::sockets::util::ErrorCode::NotSupported => Self::NotSupported, + crate::sockets::util::ErrorCode::InvalidArgument => Self::InvalidArgument, + crate::sockets::util::ErrorCode::OutOfMemory => Self::OutOfMemory, + crate::sockets::util::ErrorCode::Timeout => Self::Timeout, + crate::sockets::util::ErrorCode::InvalidState => Self::InvalidState, + crate::sockets::util::ErrorCode::AddressNotBindable => Self::AddressNotBindable, + crate::sockets::util::ErrorCode::AddressInUse => Self::AddressInUse, + crate::sockets::util::ErrorCode::RemoteUnreachable => Self::RemoteUnreachable, + crate::sockets::util::ErrorCode::ConnectionRefused => Self::ConnectionRefused, + crate::sockets::util::ErrorCode::ConnectionReset => Self::ConnectionReset, + crate::sockets::util::ErrorCode::ConnectionAborted => Self::ConnectionAborted, + crate::sockets::util::ErrorCode::DatagramTooLarge => Self::DatagramTooLarge, + } + } } -pub(crate) fn to_ipv6_addr(addr: Ipv6Address) -> std::net::Ipv6Addr { - let (x0, x1, x2, x3, x4, x5, x6, x7) = addr; - std::net::Ipv6Addr::new(x0, x1, x2, x3, x4, x5, x6, x7) +pub struct Network { + pub socket_addr_check: SocketAddrCheck, + pub allow_ip_name_lookup: bool, } -pub(crate) fn from_ipv6_addr(addr: std::net::Ipv6Addr) -> Ipv6Address { - let [x0, x1, x2, x3, x4, x5, x6, x7] = addr.segments(); - (x0, x1, x2, x3, x4, x5, x6, x7) +impl Network { + pub async fn check_socket_addr( + &self, + addr: SocketAddr, + reason: SocketAddrUse, + ) -> std::io::Result<()> { + self.socket_addr_check.check(addr, reason).await + } } diff --git a/crates/wasi/src/p2/tcp.rs b/crates/wasi/src/p2/tcp.rs index f552b145d63e..188b2148d81f 100644 --- a/crates/wasi/src/p2/tcp.rs +++ b/crates/wasi/src/p2/tcp.rs @@ -1,11 +1,16 @@ -use crate::net::{DEFAULT_TCP_BACKLOG, SocketAddressFamily}; use crate::p2::bindings::sockets::tcp::ErrorCode; -use crate::p2::host::network; use crate::p2::{ DynInputStream, DynOutputStream, InputStream, OutputStream, Pollable, SocketError, SocketResult, StreamError, }; use crate::runtime::{AbortOnDropJoinHandle, with_ambient_tokio_runtime}; +use crate::sockets::util::{ + get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address, + is_valid_unicast_address, receive_buffer_size, send_buffer_size, set_keep_alive_count, + set_keep_alive_idle_time, set_keep_alive_interval, set_receive_buffer_size, + set_send_buffer_size, set_unicast_hop_limit, tcp_bind, +}; +use crate::sockets::{DEFAULT_TCP_BACKLOG, SocketAddressFamily}; use anyhow::Result; use cap_net_ext::AddressFamily; use futures::Future; @@ -96,13 +101,13 @@ pub struct TcpSocket { // on all platforms. So we keep track of which options have been explicitly // set and manually apply those values to newly accepted clients. #[cfg(target_os = "macos")] - receive_buffer_size: Option, + receive_buffer_size: Option, #[cfg(target_os = "macos")] - send_buffer_size: Option, + send_buffer_size: Option, #[cfg(target_os = "macos")] hop_limit: Option, #[cfg(target_os = "macos")] - keep_alive_idle_time: Option, + keep_alive_idle_time: Option, } impl TcpSocket { @@ -166,49 +171,21 @@ impl TcpSocket { } impl TcpSocket { - pub fn start_bind(&mut self, local_address: SocketAddr) -> io::Result<()> { + pub fn start_bind(&mut self, local_address: SocketAddr) -> Result<(), ErrorCode> { let tokio_socket = match &self.tcp_state { TcpState::Default(socket) => socket, TcpState::BindStarted(..) => return Err(Errno::ALREADY.into()), - _ => return Err(Errno::ISCONN.into()), + _ => return Err(ErrorCode::InvalidState), }; - network::util::validate_unicast(&local_address)?; - network::util::validate_address_family(&local_address, &self.family)?; - + if !is_valid_unicast_address(local_address.ip()) + || !is_valid_address_family(local_address.ip(), self.family) { - // Automatically bypass the TIME_WAIT state when the user is trying - // to bind to a specific port: - let reuse_addr = local_address.port() > 0; - - // Unconditionally (re)set SO_REUSEADDR, even when the value is false. - // This ensures we're not accidentally affected by any socket option - // state left behind by a previous failed call to this method (start_bind). - network::util::set_tcp_reuseaddr(&tokio_socket, reuse_addr)?; - - // Perform the OS bind call. - tokio_socket.bind(local_address).map_err(|error| { - match Errno::from_io_error(&error) { - // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html: - // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket - // - // The most common reasons for this error should have already - // been handled by our own validation slightly higher up in this - // function. This error mapping is here just in case there is - // an edge case we didn't catch. - Some(Errno::AFNOSUPPORT) => io::Error::new( - io::ErrorKind::InvalidInput, - "The specified address is not a valid address for the address family of the specified socket", - ), - - // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS - // Windows returns WSAENOBUFS when the ephemeral ports have been exhausted. - #[cfg(windows)] - Some(Errno::NOBUFS) => io::Error::new(io::ErrorKind::AddrInUse, "no more free local ports"), + return Err(ErrorCode::InvalidArgument); + }; - _ => error, - } - })?; + { + tcp_bind(&tokio_socket, local_address)?; self.tcp_state = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) { TcpState::Default(socket) => TcpState::BindStarted(socket), @@ -244,9 +221,12 @@ impl TcpSocket { _ => return Err(ErrorCode::InvalidState.into()), }; - network::util::validate_unicast(&remote_address)?; - network::util::validate_remote_address(&remote_address)?; - network::util::validate_address_family(&remote_address, &self.family)?; + if !is_valid_unicast_address(remote_address.ip()) + || !is_valid_remote_address(remote_address) + || !is_valid_address_family(remote_address.ip(), self.family) + { + return Err(ErrorCode::InvalidArgument.into()); + }; let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) = std::mem::replace(&mut self.tcp_state, TcpState::Closed) @@ -420,20 +400,20 @@ impl TcpSocket { // and only if a specific value was explicitly set on the listener. if let Some(size) = self.receive_buffer_size { - _ = network::util::set_socket_recv_buffer_size(&client, size); // Ignore potential error. + _ = set_receive_buffer_size(&client, size); // Ignore potential error. } if let Some(size) = self.send_buffer_size { - _ = network::util::set_socket_send_buffer_size(&client, size); // Ignore potential error. + _ = set_send_buffer_size(&client, size); // Ignore potential error. } // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't. if let (SocketAddressFamily::Ipv6, Some(ttl)) = (self.family, self.hop_limit) { - _ = network::util::set_ipv6_unicast_hops(&client, ttl); // Ignore potential error. + _ = rustix::net::sockopt::set_ipv6_unicast_hops(&client, Some(ttl)); // Ignore potential error. } if let Some(value) = self.keep_alive_idle_time { - _ = network::util::set_tcp_keepidle(&client, value); // Ignore potential error. + _ = set_keep_alive_idle_time(&client, value); // Ignore potential error. } } @@ -530,15 +510,15 @@ impl TcpSocket { Ok(sockopt::tcp_keepidle(view)?) } - pub fn set_keep_alive_idle_time(&mut self, duration: std::time::Duration) -> SocketResult<()> { + pub fn set_keep_alive_idle_time(&mut self, value: u64) -> SocketResult<()> { { let view = &*self.as_std_view()?; - network::util::set_tcp_keepidle(view, duration)?; + set_keep_alive_idle_time(view, value)?; } #[cfg(target_os = "macos")] { - self.keep_alive_idle_time = Some(duration); + self.keep_alive_idle_time = Some(value); } Ok(()) @@ -551,7 +531,7 @@ impl TcpSocket { pub fn set_keep_alive_interval(&self, duration: std::time::Duration) -> SocketResult<()> { let view = &*self.as_std_view()?; - Ok(network::util::set_tcp_keepintvl(view, duration)?) + Ok(set_keep_alive_interval(view, duration)?) } pub fn keep_alive_count(&self) -> SocketResult { @@ -561,17 +541,13 @@ impl TcpSocket { pub fn set_keep_alive_count(&self, value: u32) -> SocketResult<()> { let view = &*self.as_std_view()?; - Ok(network::util::set_tcp_keepcnt(view, value)?) + Ok(set_keep_alive_count(view, value)?) } pub fn hop_limit(&self) -> SocketResult { let view = &*self.as_std_view()?; - let ttl = match self.family { - SocketAddressFamily::Ipv4 => network::util::get_ip_ttl(view)?, - SocketAddressFamily::Ipv6 => network::util::get_ipv6_unicast_hops(view)?, - }; - + let ttl = get_unicast_hop_limit(view, self.family)?; Ok(ttl) } @@ -579,10 +555,7 @@ impl TcpSocket { { let view = &*self.as_std_view()?; - match self.family { - SocketAddressFamily::Ipv4 => network::util::set_ip_ttl(view, value)?, - SocketAddressFamily::Ipv6 => network::util::set_ipv6_unicast_hops(view, value)?, - } + set_unicast_hop_limit(view, self.family, value)?; } #[cfg(target_os = "macos")] @@ -593,17 +566,17 @@ impl TcpSocket { Ok(()) } - pub fn receive_buffer_size(&self) -> SocketResult { + pub fn receive_buffer_size(&self) -> SocketResult { let view = &*self.as_std_view()?; - Ok(network::util::get_socket_recv_buffer_size(view)?) + Ok(receive_buffer_size(view)?) } - pub fn set_receive_buffer_size(&mut self, value: usize) -> SocketResult<()> { + pub fn set_receive_buffer_size(&mut self, value: u64) -> SocketResult<()> { { let view = &*self.as_std_view()?; - network::util::set_socket_recv_buffer_size(view, value)?; + set_receive_buffer_size(view, value)?; } #[cfg(target_os = "macos")] @@ -614,17 +587,17 @@ impl TcpSocket { Ok(()) } - pub fn send_buffer_size(&self) -> SocketResult { + pub fn send_buffer_size(&self) -> SocketResult { let view = &*self.as_std_view()?; - Ok(network::util::get_socket_send_buffer_size(view)?) + Ok(send_buffer_size(view)?) } - pub fn set_send_buffer_size(&mut self, value: usize) -> SocketResult<()> { + pub fn set_send_buffer_size(&mut self, value: u64) -> SocketResult<()> { { let view = &*self.as_std_view()?; - network::util::set_socket_send_buffer_size(view, value)?; + set_send_buffer_size(view, value)?; } #[cfg(target_os = "macos")] diff --git a/crates/wasi/src/p2/udp.rs b/crates/wasi/src/p2/udp.rs index 3d98e56ac7a2..a066fb97eb15 100644 --- a/crates/wasi/src/p2/udp.rs +++ b/crates/wasi/src/p2/udp.rs @@ -1,8 +1,8 @@ -use crate::net::{SocketAddrCheck, SocketAddressFamily}; -use crate::p2::host::network::util; use crate::runtime::with_ambient_tokio_runtime; +use crate::sockets::util::udp_socket; +use crate::sockets::{SocketAddrCheck, SocketAddressFamily}; use async_trait::async_trait; -use cap_net_ext::{AddressFamily, Blocking}; +use cap_net_ext::AddressFamily; use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike}; use std::io; use std::net::SocketAddr; @@ -59,7 +59,7 @@ impl UdpSocket { pub fn new(family: AddressFamily) -> io::Result { // Create a new host socket and set it to non-blocking, which is needed // by our async implementation. - let fd = util::udp_socket(family, Blocking::No)?; + let fd = udp_socket(family)?; let socket_address_family = match family { AddressFamily::Ipv4 => SocketAddressFamily::Ipv4, @@ -69,7 +69,7 @@ impl UdpSocket { } }; - let socket = Self::setup_tokio_udp_socket(fd)?; + let socket = Self::setup_tokio_udp_socket(fd.into())?; Ok(UdpSocket { inner: Arc::new(socket), diff --git a/crates/wasi/src/p3/bindings.rs b/crates/wasi/src/p3/bindings.rs index 253794a03524..29718611c525 100644 --- a/crates/wasi/src/p3/bindings.rs +++ b/crates/wasi/src/p3/bindings.rs @@ -94,6 +94,8 @@ mod generated { with: { "wasi:cli/terminal-input/terminal-input": crate::p3::cli::TerminalInput, "wasi:cli/terminal-output/terminal-output": crate::p3::cli::TerminalOutput, + "wasi:sockets/types/tcp-socket": crate::p3::sockets::tcp::TcpSocket, + "wasi:sockets/types/udp-socket": crate::p3::sockets::udp::UdpSocket, } }); } diff --git a/crates/wasi/src/p3/cli/host.rs b/crates/wasi/src/p3/cli/host.rs index 4bba471fb2eb..c37a569594ee 100644 --- a/crates/wasi/src/p3/cli/host.rs +++ b/crates/wasi/src/p3/cli/host.rs @@ -1,5 +1,6 @@ use crate::I32Exit; use crate::cli::IsTerminal; +use crate::p3::DEFAULT_BUFFER_CAPACITY; use crate::p3::bindings::cli::{ environment, exit, stderr, stdin, stdout, terminal_input, terminal_output, terminal_stderr, terminal_stdin, terminal_stdout, @@ -24,7 +25,7 @@ where V: AsyncRead + Send + Sync + Unpin + 'static, { async fn run(mut self, store: &Accessor) -> wasmtime::Result<()> { - let mut buf = BytesMut::with_capacity(8192); + let mut buf = BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY); while !self.tx.is_closed() { match self.rx.read_buf(&mut buf).await { Ok(0) => return Ok(()), @@ -57,7 +58,7 @@ where V: AsyncWrite + Send + Sync + Unpin + 'static, { async fn run(mut self, store: &Accessor) -> wasmtime::Result<()> { - let mut buf = BytesMut::with_capacity(8192); + let mut buf = BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY); while !self.rx.is_closed() { buf = self.rx.read(store, buf).await; match self.tx.write_all(&buf).await { diff --git a/crates/wasi/src/p3/ctx.rs b/crates/wasi/src/p3/ctx.rs index 7a42fac4670b..6876038f25d3 100644 --- a/crates/wasi/src/p3/ctx.rs +++ b/crates/wasi/src/p3/ctx.rs @@ -1,9 +1,9 @@ use crate::cli::WasiCliCtx; use crate::clocks::{HostMonotonicClock, HostWallClock, WasiClocksCtx}; -use crate::net::SocketAddrUse; use crate::p3::cli::{InputStream, OutputStream}; use crate::p3::filesystem::Dir; use crate::random::WasiRandomCtx; +use crate::sockets::{SocketAddrUse, WasiSocketsCtx}; use crate::{DirPerms, FilePerms, OpenMode}; use anyhow::Result; use cap_rand::RngCore; @@ -432,6 +432,7 @@ impl WasiCtxBuilder { random, clocks, cli, + sockets, .. }, built: _, @@ -440,9 +441,10 @@ impl WasiCtxBuilder { self.built = true; WasiCtx { - random, - clocks, cli, + clocks, + random, + sockets, } } } @@ -496,9 +498,10 @@ impl WasiCtxBuilder { /// ``` #[derive(Default)] pub struct WasiCtx { - pub random: WasiRandomCtx, - pub clocks: WasiClocksCtx, pub cli: WasiCliCtx, Box>, + pub clocks: WasiClocksCtx, + pub random: WasiRandomCtx, + pub sockets: WasiSocketsCtx, } impl WasiCtx { diff --git a/crates/wasi/src/p3/mod.rs b/crates/wasi/src/p3/mod.rs index 36f2cc4e1224..70a702b4b0a9 100644 --- a/crates/wasi/src/p3/mod.rs +++ b/crates/wasi/src/p3/mod.rs @@ -14,16 +14,21 @@ pub mod clocks; mod ctx; pub mod filesystem; pub mod random; +pub mod sockets; mod view; use wasmtime::component::Linker; use crate::p3::bindings::LinkOptions; use crate::p3::cli::WasiCliCtxView; +use crate::sockets::WasiSocketsCtxView; pub use self::ctx::{WasiCtx, WasiCtxBuilder}; pub use self::view::{WasiCtxView, WasiView}; +// Default buffer capacity to use for reads of byte-sized values. +const DEFAULT_BUFFER_CAPACITY: usize = 8192; + /// Add all WASI interfaces from this module into the `linker` provided. /// /// This function will add all interfaces implemented by this module to the @@ -84,12 +89,10 @@ where pub fn add_to_linker_with_options( linker: &mut Linker, options: &LinkOptions, -) -> anyhow::Result<()> +) -> wasmtime::Result<()> where T: WasiView + 'static, { - clocks::add_to_linker_impl(linker, |x| &mut x.ctx().ctx.clocks)?; - random::add_to_linker_impl(linker, |x| &mut x.ctx().ctx.random)?; cli::add_to_linker_impl(linker, &options.into(), |x| { let WasiCtxView { ctx, table } = x.ctx(); WasiCliCtxView { @@ -97,5 +100,14 @@ where table, } })?; + clocks::add_to_linker_impl(linker, |x| &mut x.ctx().ctx.clocks)?; + random::add_to_linker_impl(linker, |x| &mut x.ctx().ctx.random)?; + sockets::add_to_linker_impl(linker, |x| { + let WasiCtxView { ctx, table } = x.ctx(); + WasiSocketsCtxView { + ctx: &mut ctx.sockets, + table, + } + })?; Ok(()) } diff --git a/crates/wasi/src/p3/sockets/conv.rs b/crates/wasi/src/p3/sockets/conv.rs new file mode 100644 index 000000000000..ee4920dfa0d8 --- /dev/null +++ b/crates/wasi/src/p3/sockets/conv.rs @@ -0,0 +1,227 @@ +use core::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}; + +use std::net::ToSocketAddrs; + +use rustix::io::Errno; +use tracing::debug; + +use crate::p3::bindings::sockets::types; +use crate::sockets::util::{from_ipv4_addr, from_ipv6_addr, to_ipv4_addr, to_ipv6_addr}; + +impl From for types::IpAddress { + fn from(addr: IpAddr) -> Self { + match addr { + IpAddr::V4(v4) => Self::Ipv4(from_ipv4_addr(v4)), + IpAddr::V6(v6) => Self::Ipv6(from_ipv6_addr(v6)), + } + } +} + +impl From for IpAddr { + fn from(addr: types::IpAddress) -> Self { + match addr { + types::IpAddress::Ipv4(v4) => Self::V4(to_ipv4_addr(v4)), + types::IpAddress::Ipv6(v6) => Self::V6(to_ipv6_addr(v6)), + } + } +} + +impl From for SocketAddr { + fn from(addr: types::IpSocketAddress) -> Self { + match addr { + types::IpSocketAddress::Ipv4(ipv4) => Self::V4(ipv4.into()), + types::IpSocketAddress::Ipv6(ipv6) => Self::V6(ipv6.into()), + } + } +} + +impl From for types::IpSocketAddress { + fn from(addr: SocketAddr) -> Self { + match addr { + SocketAddr::V4(v4) => Self::Ipv4(v4.into()), + SocketAddr::V6(v6) => Self::Ipv6(v6.into()), + } + } +} + +impl From for SocketAddrV4 { + fn from(addr: types::Ipv4SocketAddress) -> Self { + Self::new(to_ipv4_addr(addr.address), addr.port) + } +} + +impl From for types::Ipv4SocketAddress { + fn from(addr: SocketAddrV4) -> Self { + Self { + address: from_ipv4_addr(*addr.ip()), + port: addr.port(), + } + } +} + +impl From for SocketAddrV6 { + fn from(addr: types::Ipv6SocketAddress) -> Self { + Self::new( + to_ipv6_addr(addr.address), + addr.port, + addr.flow_info, + addr.scope_id, + ) + } +} + +impl From for types::Ipv6SocketAddress { + fn from(addr: SocketAddrV6) -> Self { + Self { + address: from_ipv6_addr(*addr.ip()), + port: addr.port(), + flow_info: addr.flowinfo(), + scope_id: addr.scope_id(), + } + } +} + +impl ToSocketAddrs for types::IpSocketAddress { + type Iter = ::Iter; + + fn to_socket_addrs(&self) -> std::io::Result { + SocketAddr::from(*self).to_socket_addrs() + } +} + +impl ToSocketAddrs for types::Ipv4SocketAddress { + type Iter = ::Iter; + + fn to_socket_addrs(&self) -> std::io::Result { + SocketAddrV4::from(*self).to_socket_addrs() + } +} + +impl ToSocketAddrs for types::Ipv6SocketAddress { + type Iter = ::Iter; + + fn to_socket_addrs(&self) -> std::io::Result { + SocketAddrV6::from(*self).to_socket_addrs() + } +} + +impl From for cap_net_ext::AddressFamily { + fn from(family: types::IpAddressFamily) -> Self { + match family { + types::IpAddressFamily::Ipv4 => Self::Ipv4, + types::IpAddressFamily::Ipv6 => Self::Ipv6, + } + } +} + +impl From for types::IpAddressFamily { + fn from(family: cap_net_ext::AddressFamily) -> Self { + match family { + cap_net_ext::AddressFamily::Ipv4 => Self::Ipv4, + cap_net_ext::AddressFamily::Ipv6 => Self::Ipv6, + } + } +} + +impl From for types::ErrorCode { + fn from(value: std::io::Error) -> Self { + (&value).into() + } +} + +impl From<&std::io::Error> for types::ErrorCode { + fn from(value: &std::io::Error) -> Self { + // Attempt the more detailed native error code first: + if let Some(errno) = Errno::from_io_error(value) { + return errno.into(); + } + + match value.kind() { + std::io::ErrorKind::AddrInUse => Self::AddressInUse, + std::io::ErrorKind::AddrNotAvailable => Self::AddressNotBindable, + std::io::ErrorKind::ConnectionAborted => Self::ConnectionAborted, + std::io::ErrorKind::ConnectionRefused => Self::ConnectionRefused, + std::io::ErrorKind::ConnectionReset => Self::ConnectionReset, + std::io::ErrorKind::InvalidInput => Self::InvalidArgument, + std::io::ErrorKind::NotConnected => Self::InvalidState, + std::io::ErrorKind::OutOfMemory => Self::OutOfMemory, + std::io::ErrorKind::PermissionDenied => Self::AccessDenied, + std::io::ErrorKind::TimedOut => Self::Timeout, + std::io::ErrorKind::Unsupported => Self::NotSupported, + _ => { + debug!("unknown I/O error: {value}"); + Self::Unknown + } + } + } +} + +impl From for types::ErrorCode { + fn from(value: Errno) -> Self { + (&value).into() + } +} + +impl From<&Errno> for types::ErrorCode { + fn from(value: &Errno) -> Self { + match *value { + #[cfg(not(windows))] + Errno::PERM => Self::AccessDenied, + Errno::ACCESS => Self::AccessDenied, + Errno::ADDRINUSE => Self::AddressInUse, + Errno::ADDRNOTAVAIL => Self::AddressNotBindable, + Errno::TIMEDOUT => Self::Timeout, + Errno::CONNREFUSED => Self::ConnectionRefused, + Errno::CONNRESET => Self::ConnectionReset, + Errno::CONNABORTED => Self::ConnectionAborted, + Errno::INVAL => Self::InvalidArgument, + Errno::HOSTUNREACH => Self::RemoteUnreachable, + Errno::HOSTDOWN => Self::RemoteUnreachable, + Errno::NETDOWN => Self::RemoteUnreachable, + Errno::NETUNREACH => Self::RemoteUnreachable, + #[cfg(target_os = "linux")] + Errno::NONET => Self::RemoteUnreachable, + Errno::ISCONN => Self::InvalidState, + Errno::NOTCONN => Self::InvalidState, + Errno::DESTADDRREQ => Self::InvalidState, + Errno::MSGSIZE => Self::DatagramTooLarge, + #[cfg(not(windows))] + Errno::NOMEM => Self::OutOfMemory, + Errno::NOBUFS => Self::OutOfMemory, + Errno::OPNOTSUPP => Self::NotSupported, + Errno::NOPROTOOPT => Self::NotSupported, + Errno::PFNOSUPPORT => Self::NotSupported, + Errno::PROTONOSUPPORT => Self::NotSupported, + Errno::PROTOTYPE => Self::NotSupported, + Errno::SOCKTNOSUPPORT => Self::NotSupported, + Errno::AFNOSUPPORT => Self::NotSupported, + + // FYI, EINPROGRESS should have already been handled by connect. + _ => { + debug!("unknown I/O error: {value}"); + Self::Unknown + } + } + } +} + +impl From for types::ErrorCode { + fn from(code: crate::sockets::util::ErrorCode) -> Self { + match code { + crate::sockets::util::ErrorCode::Unknown => Self::Unknown, + crate::sockets::util::ErrorCode::AccessDenied => Self::AccessDenied, + crate::sockets::util::ErrorCode::NotSupported => Self::NotSupported, + crate::sockets::util::ErrorCode::InvalidArgument => Self::InvalidArgument, + crate::sockets::util::ErrorCode::OutOfMemory => Self::OutOfMemory, + crate::sockets::util::ErrorCode::Timeout => Self::Timeout, + crate::sockets::util::ErrorCode::InvalidState => Self::InvalidState, + crate::sockets::util::ErrorCode::AddressNotBindable => Self::AddressNotBindable, + crate::sockets::util::ErrorCode::AddressInUse => Self::AddressInUse, + crate::sockets::util::ErrorCode::RemoteUnreachable => Self::RemoteUnreachable, + crate::sockets::util::ErrorCode::ConnectionRefused => Self::ConnectionRefused, + crate::sockets::util::ErrorCode::ConnectionReset => Self::ConnectionReset, + crate::sockets::util::ErrorCode::ConnectionAborted => Self::ConnectionAborted, + crate::sockets::util::ErrorCode::DatagramTooLarge => Self::DatagramTooLarge, + } + } +} diff --git a/crates/wasi/src/p3/sockets/host/ip_name_lookup.rs b/crates/wasi/src/p3/sockets/host/ip_name_lookup.rs new file mode 100644 index 000000000000..d4d3dbbc9a32 --- /dev/null +++ b/crates/wasi/src/p3/sockets/host/ip_name_lookup.rs @@ -0,0 +1,39 @@ +use tokio::net::lookup_host; +use wasmtime::component::Accessor; + +use crate::p3::bindings::sockets::ip_name_lookup::{ErrorCode, Host, HostWithStore}; +use crate::p3::bindings::sockets::types; +use crate::p3::sockets::WasiSockets; +use crate::sockets::WasiSocketsCtxView; +use crate::sockets::util::{from_ipv4_addr, from_ipv6_addr, parse_host}; + +impl HostWithStore for WasiSockets { + async fn resolve_addresses( + store: &Accessor, + name: String, + ) -> wasmtime::Result, ErrorCode>> { + let Ok(host) = parse_host(&name) else { + return Ok(Err(ErrorCode::InvalidArgument)); + }; + if !store.with(|mut view| view.get().ctx.allowed_network_uses.ip_name_lookup) { + return Ok(Err(ErrorCode::PermanentResolverFailure)); + } + match host { + url::Host::Ipv4(addr) => Ok(Ok(vec![types::IpAddress::Ipv4(from_ipv4_addr(addr))])), + url::Host::Ipv6(addr) => Ok(Ok(vec![types::IpAddress::Ipv6(from_ipv6_addr(addr))])), + url::Host::Domain(domain) => { + // This is only resolving names, not ports, so force the port to be 0. + if let Ok(addrs) = lookup_host((domain.as_str(), 0)).await { + Ok(Ok(addrs + .map(|addr| addr.ip().to_canonical().into()) + .collect())) + } else { + // If/when we use `getaddrinfo` directly, map the error properly. + Ok(Err(ErrorCode::NameUnresolvable)) + } + } + } + } +} + +impl Host for WasiSocketsCtxView<'_> {} diff --git a/crates/wasi/src/p3/sockets/host/mod.rs b/crates/wasi/src/p3/sockets/host/mod.rs new file mode 100644 index 000000000000..aa4d333fbfda --- /dev/null +++ b/crates/wasi/src/p3/sockets/host/mod.rs @@ -0,0 +1,2 @@ +mod ip_name_lookup; +mod types; diff --git a/crates/wasi/src/p3/sockets/host/types/mod.rs b/crates/wasi/src/p3/sockets/host/types/mod.rs new file mode 100644 index 000000000000..b64fd44dbdf3 --- /dev/null +++ b/crates/wasi/src/p3/sockets/host/types/mod.rs @@ -0,0 +1,24 @@ +use core::net::SocketAddr; + +use wasmtime::component::Accessor; + +use crate::p3::bindings::sockets::types::Host; +use crate::p3::sockets::WasiSockets; +use crate::sockets::{SocketAddrCheck, SocketAddrUse, WasiSocketsCtxView}; + +mod tcp; +mod udp; + +impl Host for WasiSocketsCtxView<'_> {} + +fn get_socket_addr_check(store: &Accessor) -> SocketAddrCheck { + store.with(|mut view| view.get().ctx.socket_addr_check.clone()) +} + +async fn is_addr_allowed( + store: &Accessor, + addr: SocketAddr, + reason: SocketAddrUse, +) -> bool { + get_socket_addr_check(store)(addr, reason).await +} diff --git a/crates/wasi/src/p3/sockets/host/types/tcp.rs b/crates/wasi/src/p3/sockets/host/types/tcp.rs new file mode 100644 index 000000000000..530fa3fbe144 --- /dev/null +++ b/crates/wasi/src/p3/sockets/host/types/tcp.rs @@ -0,0 +1,590 @@ +use core::future::poll_fn; +use core::mem; +use core::net::SocketAddr; +use core::pin::pin; +use core::task::Poll; + +use std::io::Cursor; +use std::net::Shutdown; +use std::sync::Arc; + +use anyhow::{Context as _, ensure}; +use bytes::BytesMut; +use io_lifetimes::AsSocketlike as _; +use rustix::io::Errno; +use tokio::net::{TcpListener, TcpStream}; +use wasmtime::component::{ + Accessor, AccessorTask, FutureWriter, HostFuture, HostStream, Resource, ResourceTable, + StreamWriter, +}; + +use crate::p3::DEFAULT_BUFFER_CAPACITY; +use crate::p3::bindings::sockets::types::{ + Duration, ErrorCode, HostTcpSocket, HostTcpSocketWithStore, IpAddressFamily, IpSocketAddress, + TcpSocket, +}; +use crate::p3::sockets::WasiSockets; +use crate::p3::sockets::tcp::{NonInheritedOptions, TcpState}; +use crate::sockets::util::{ + is_valid_address_family, is_valid_remote_address, is_valid_unicast_address, +}; +use crate::sockets::{SocketAddrUse, SocketAddressFamily, WasiSocketsCtxView}; + +use super::is_addr_allowed; + +fn is_tcp_allowed(store: &Accessor) -> bool { + store.with(|mut view| view.get().ctx.allowed_network_uses.tcp) +} + +fn get_socket<'a>( + table: &'a ResourceTable, + socket: &'a Resource, +) -> wasmtime::Result<&'a TcpSocket> { + table + .get(socket) + .context("failed to get socket resource from table") +} + +fn get_socket_mut<'a>( + table: &'a mut ResourceTable, + socket: &'a Resource, +) -> wasmtime::Result<&'a mut TcpSocket> { + table + .get_mut(socket) + .context("failed to get socket resource from table") +} + +struct ListenTask { + listener: Arc, + family: SocketAddressFamily, + tx: StreamWriter>>, + options: NonInheritedOptions, +} + +impl AccessorTask> for ListenTask { + async fn run(mut self, store: &Accessor) -> wasmtime::Result<()> { + while !self.tx.is_closed() { + let Some(res) = ({ + let mut accept = pin!(self.listener.accept()); + let mut tx = pin!(self.tx.watch_reader(store)); + poll_fn(|cx| match tx.as_mut().poll(cx) { + Poll::Ready(()) => return Poll::Ready(None), + Poll::Pending => accept.as_mut().poll(cx).map(Some), + }) + .await + }) else { + return Ok(()); + }; + let state = match res { + Ok((stream, _addr)) => { + self.options.apply(self.family, &stream); + TcpState::Connected(Arc::new(stream)) + } + Err(err) => { + match Errno::from_io_error(&err) { + // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS + // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress, + // > or the service provider is still processing a callback function. + // + // wasi-sockets doesn't have an equivalent to the EINPROGRESS error, + // because in POSIX this error is only returned by a non-blocking + // `connect` and wasi-sockets has a different solution for that. + #[cfg(windows)] + Some(Errno::INPROGRESS) => TcpState::Error(ErrorCode::Unknown), + + // Normalize Linux' non-standard behavior. + // + // From https://man7.org/linux/man-pages/man2/accept.2.html: + // > Linux accept() passes already-pending network errors on the + // > new socket as an error code from accept(). This behavior + // > differs from other BSD socket implementations. (...) + #[cfg(target_os = "linux")] + Some( + Errno::CONNRESET + | Errno::NETRESET + | Errno::HOSTUNREACH + | Errno::HOSTDOWN + | Errno::NETDOWN + | Errno::NETUNREACH + | Errno::PROTO + | Errno::NOPROTOOPT + | Errno::NONET + | Errno::OPNOTSUPP, + ) => TcpState::Error(ErrorCode::ConnectionAborted), + _ => TcpState::Error(err.into()), + } + } + }; + let socket = store.with(|mut view| { + view.get() + .table + .push(TcpSocket::from_state(state, self.family)) + .context("failed to push socket resource to table") + })?; + if let Some(socket) = self.tx.write(store, Some(socket)).await { + debug_assert!(self.tx.is_closed()); + store.with(|mut view| { + view.get() + .table + .delete(socket) + .context("failed to delete socket resource from table") + })?; + return Ok(()); + } + } + Ok(()) + } +} + +struct ResultWriteTask { + result: Result<(), ErrorCode>, + result_tx: FutureWriter>, +} + +impl AccessorTask> for ResultWriteTask { + async fn run(self, store: &Accessor) -> wasmtime::Result<()> { + self.result_tx.write(store, self.result).await; + Ok(()) + } +} + +struct ReceiveTask { + stream: Arc, + data_tx: StreamWriter>, + result_tx: FutureWriter>, +} + +impl AccessorTask> for ReceiveTask { + async fn run(mut self, store: &Accessor) -> wasmtime::Result<()> { + let mut buf = BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY); + let res = loop { + match self.stream.try_read_buf(&mut buf) { + Ok(0) => { + break Ok(()); + } + Ok(..) => { + buf = self + .data_tx + .write_all(store, Cursor::new(buf)) + .await + .into_inner(); + if self.data_tx.is_closed() { + break Ok(()); + } + buf.clear(); + } + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => { + let Some(res) = ({ + let mut readable = pin!(self.stream.readable()); + let mut tx = pin!(self.data_tx.watch_reader(store)); + poll_fn(|cx| match tx.as_mut().poll(cx) { + Poll::Ready(()) => return Poll::Ready(None), + Poll::Pending => readable.as_mut().poll(cx).map(Some), + }) + .await + }) else { + break Ok(()); + }; + if let Err(err) = res { + break Err(err.into()); + } + } + Err(err) => { + break Err(err.into()); + } + } + }; + _ = self + .stream + .as_socketlike_view::() + .shutdown(Shutdown::Read); + + // Write the result async from a separate task to ensure that all resources used by this + // task are freed + store.spawn(ResultWriteTask { + result: res, + result_tx: self.result_tx, + }); + Ok(()) + } +} + +impl HostTcpSocketWithStore for WasiSockets { + async fn bind( + store: &Accessor, + socket: Resource, + local_address: IpSocketAddress, + ) -> wasmtime::Result> { + let local_address = SocketAddr::from(local_address); + if !is_tcp_allowed(store) + || !is_addr_allowed(store, local_address, SocketAddrUse::TcpBind).await + { + return Ok(Err(ErrorCode::AccessDenied)); + } + store.with(|mut view| { + let socket = get_socket_mut(view.get().table, &socket)?; + Ok(socket.bind(local_address)) + }) + } + + async fn connect( + store: &Accessor, + socket: Resource, + remote_address: IpSocketAddress, + ) -> wasmtime::Result> { + let remote_address = SocketAddr::from(remote_address); + if !is_tcp_allowed(store) + || !is_addr_allowed(store, remote_address, SocketAddrUse::TcpConnect).await + { + return Ok(Err(ErrorCode::AccessDenied)); + } + match store.with(|mut view| { + let ip = remote_address.ip(); + let socket = get_socket_mut(view.get().table, &socket)?; + if !is_valid_unicast_address(ip) + || !is_valid_remote_address(remote_address) + || !is_valid_address_family(ip, socket.family) + { + return anyhow::Ok(Err(ErrorCode::InvalidArgument)); + } + match mem::replace(&mut socket.tcp_state, TcpState::Connecting) { + TcpState::Default(sock) | TcpState::Bound(sock) => Ok(Ok(sock)), + tcp_state => { + socket.tcp_state = tcp_state; + Ok(Err(ErrorCode::InvalidState)) + } + } + })? { + Ok(sock) => { + // FIXME: handle possible cancellation of the outer `connect` + // https://github.com/bytecodealliance/wasmtime/pull/11291#discussion_r2223917986 + let res = sock.connect(remote_address).await; + store.with(|mut view| { + let socket = get_socket_mut(view.get().table, &socket)?; + ensure!( + matches!(socket.tcp_state, TcpState::Connecting), + "corrupted socket state" + ); + match res { + Ok(stream) => { + socket.tcp_state = TcpState::Connected(Arc::new(stream)); + Ok(Ok(())) + } + Err(err) => { + socket.tcp_state = TcpState::Closed; + Ok(Err(err.into())) + } + } + }) + } + Err(err) => Ok(Err(err)), + } + } + + async fn listen( + store: &Accessor, + socket: Resource, + ) -> wasmtime::Result>, ErrorCode>> { + store.with(|mut view| { + let (tx, rx) = view + .instance() + .stream::<_, _, Option<_>>(&mut view) + .context("failed to create stream")?; + if !view.get().ctx.allowed_network_uses.tcp { + return Ok(Err(ErrorCode::AccessDenied)); + } + let TcpSocket { + tcp_state, + listen_backlog_size, + family, + options, + } = get_socket_mut(view.get().table, &socket)?; + let sock = match mem::replace(tcp_state, TcpState::Closed) { + TcpState::Default(sock) | TcpState::Bound(sock) => sock, + prev => { + *tcp_state = prev; + return Ok(Err(ErrorCode::InvalidState)); + } + }; + let listener = match sock.listen(*listen_backlog_size) { + Ok(listener) => listener, + Err(err) => { + match Errno::from_io_error(&err) { + // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE + // According to the docs, `listen` can return EMFILE on Windows. + // This is odd, because we're not trying to create a new socket + // or file descriptor of any kind. So we rewrite it to less + // surprising error code. + // + // At the time of writing, this behavior has never been experimentally + // observed by any of the wasmtime authors, so we're relying fully + // on Microsoft's documentation here. + #[cfg(windows)] + Some(Errno::MFILE) => return Ok(Err(ErrorCode::OutOfMemory)), + + _ => return Ok(Err(err.into())), + } + } + }; + let listener = Arc::new(listener); + *tcp_state = TcpState::Listening(Arc::clone(&listener)); + let task = ListenTask { + listener, + family: *family, + tx, + options: options.clone(), + }; + view.spawn(task); + Ok(Ok(rx.into())) + }) + } + + async fn send( + store: &Accessor, + socket: Resource, + data: HostStream, + ) -> wasmtime::Result> { + let (stream, mut data) = match store.with(|mut view| -> wasmtime::Result<_> { + let data = data.into_reader::>(&mut view); + let sock = get_socket(view.get().table, &socket)?; + if let TcpState::Connected(stream) | TcpState::Receiving(stream) = &sock.tcp_state { + Ok(Ok((Arc::clone(&stream), data))) + } else { + Ok(Err(ErrorCode::InvalidState)) + } + })? { + Ok((stream, data)) => (stream, data), + Err(err) => return Ok(Err(err)), + }; + let mut buf = Vec::with_capacity(8096); + let mut result = Ok(()); + while !data.is_closed() { + buf = data.read(store, buf).await; + let mut slice = buf.as_slice(); + while !slice.is_empty() { + match stream.try_write(&slice) { + Ok(n) => slice = &slice[n..], + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => { + if let Err(err) = stream.writable().await { + result = Err(err.into()); + break; + } + } + Err(err) => { + result = Err(err.into()); + break; + } + } + } + buf.clear(); + } + _ = stream + .as_socketlike_view::() + .shutdown(Shutdown::Write); + Ok(result) + } + + async fn receive( + store: &Accessor, + socket: Resource, + ) -> wasmtime::Result<(HostStream, HostFuture>)> { + store.with(|mut view| { + let instance = view.instance(); + let (data_tx, data_rx) = instance + .stream::<_, _, BytesMut>(&mut view) + .context("failed to create stream")?; + let TcpSocket { tcp_state, .. } = get_socket_mut(view.get().table, &socket)?; + match mem::replace(tcp_state, TcpState::Closed) { + TcpState::Connected(stream) => { + *tcp_state = TcpState::Receiving(Arc::clone(&stream)); + let (result_tx, result_rx) = instance + .future(|| unreachable!(), &mut view) + .context("failed to create future")?; + view.spawn(ReceiveTask { + stream, + data_tx, + result_tx, + }); + Ok((data_rx.into(), result_rx.into())) + } + prev => { + *tcp_state = prev; + let (_, result_rx) = instance + .future(|| Err(ErrorCode::InvalidState), &mut view) + .context("failed to create future")?; + Ok((data_rx.into(), result_rx.into())) + } + } + }) + } +} + +impl HostTcpSocket for WasiSocketsCtxView<'_> { + fn new(&mut self, address_family: IpAddressFamily) -> wasmtime::Result> { + let socket = TcpSocket::new(address_family.into()).context("failed to create socket")?; + self.table + .push(socket) + .context("failed to push socket resource to table") + } + + fn local_address( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.local_address()) + } + + fn remote_address( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.remote_address()) + } + + fn is_listening(&mut self, socket: Resource) -> wasmtime::Result { + let sock = get_socket(self.table, &socket)?; + Ok(sock.is_listening()) + } + + fn address_family(&mut self, socket: Resource) -> wasmtime::Result { + let sock = get_socket(self.table, &socket)?; + Ok(sock.address_family()) + } + + fn set_listen_backlog_size( + &mut self, + socket: Resource, + value: u64, + ) -> wasmtime::Result> { + let sock = get_socket_mut(self.table, &socket)?; + Ok(sock.set_listen_backlog_size(value)) + } + + fn keep_alive_enabled( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.keep_alive_enabled()) + } + + fn set_keep_alive_enabled( + &mut self, + socket: Resource, + value: bool, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.set_keep_alive_enabled(value)) + } + + fn keep_alive_idle_time( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.keep_alive_idle_time()) + } + + fn set_keep_alive_idle_time( + &mut self, + socket: Resource, + value: Duration, + ) -> wasmtime::Result> { + let sock = get_socket_mut(self.table, &socket)?; + Ok(sock.set_keep_alive_idle_time(value)) + } + + fn keep_alive_interval( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.keep_alive_interval()) + } + + fn set_keep_alive_interval( + &mut self, + socket: Resource, + value: Duration, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.set_keep_alive_interval(value)) + } + + fn keep_alive_count( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.keep_alive_count()) + } + + fn set_keep_alive_count( + &mut self, + socket: Resource, + value: u32, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.set_keep_alive_count(value)) + } + + fn hop_limit( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.hop_limit()) + } + + fn set_hop_limit( + &mut self, + socket: Resource, + value: u8, + ) -> wasmtime::Result> { + let sock = get_socket_mut(self.table, &socket)?; + Ok(sock.set_hop_limit(value)) + } + + fn receive_buffer_size( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.receive_buffer_size()) + } + + fn set_receive_buffer_size( + &mut self, + socket: Resource, + value: u64, + ) -> wasmtime::Result> { + let sock = get_socket_mut(self.table, &socket)?; + Ok(sock.set_receive_buffer_size(value)) + } + + fn send_buffer_size( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.send_buffer_size()) + } + + fn set_send_buffer_size( + &mut self, + socket: Resource, + value: u64, + ) -> wasmtime::Result> { + let sock = get_socket_mut(self.table, &socket)?; + Ok(sock.set_send_buffer_size(value)) + } + + fn drop(&mut self, sock: Resource) -> wasmtime::Result<()> { + self.table + .delete(sock) + .context("failed to delete socket resource from table")?; + Ok(()) + } +} diff --git a/crates/wasi/src/p3/sockets/host/types/udp.rs b/crates/wasi/src/p3/sockets/host/types/udp.rs new file mode 100644 index 000000000000..2518590c8b10 --- /dev/null +++ b/crates/wasi/src/p3/sockets/host/types/udp.rs @@ -0,0 +1,208 @@ +use core::net::SocketAddr; + +use anyhow::Context as _; +use wasmtime::component::{Accessor, Resource, ResourceTable}; + +use crate::p3::bindings::sockets::types::{ + ErrorCode, HostUdpSocket, HostUdpSocketWithStore, IpAddressFamily, IpSocketAddress, +}; +use crate::p3::sockets::WasiSockets; +use crate::p3::sockets::udp::UdpSocket; +use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, WasiSocketsCtxView}; + +use super::is_addr_allowed; + +fn is_udp_allowed(store: &Accessor) -> bool { + store.with(|mut view| view.get().ctx.allowed_network_uses.udp) +} + +fn get_socket<'a>( + table: &'a ResourceTable, + socket: &'a Resource, +) -> wasmtime::Result<&'a UdpSocket> { + table + .get(socket) + .context("failed to get socket resource from table") +} + +fn get_socket_mut<'a>( + table: &'a mut ResourceTable, + socket: &'a Resource, +) -> wasmtime::Result<&'a mut UdpSocket> { + table + .get_mut(socket) + .context("failed to get socket resource from table") +} + +impl HostUdpSocketWithStore for WasiSockets { + async fn bind( + store: &Accessor, + socket: Resource, + local_address: IpSocketAddress, + ) -> wasmtime::Result> { + let local_address = SocketAddr::from(local_address); + if !is_udp_allowed(store) + || !is_addr_allowed(store, local_address, SocketAddrUse::UdpBind).await + { + return Ok(Err(ErrorCode::AccessDenied)); + } + store.with(|mut view| { + let socket = get_socket_mut(view.get().table, &socket)?; + Ok(socket.bind(local_address)) + }) + } + + async fn connect( + store: &Accessor, + socket: Resource, + remote_address: IpSocketAddress, + ) -> wasmtime::Result> { + let remote_address = SocketAddr::from(remote_address); + if !is_udp_allowed(store) + || !is_addr_allowed(store, remote_address, SocketAddrUse::UdpConnect).await + { + return Ok(Err(ErrorCode::AccessDenied)); + } + store.with(|mut view| { + let socket = get_socket_mut(view.get().table, &socket)?; + Ok(socket.connect(remote_address)) + }) + } + + async fn send( + store: &Accessor, + socket: Resource, + data: Vec, + remote_address: Option, + ) -> wasmtime::Result> { + if data.len() > MAX_UDP_DATAGRAM_SIZE { + return Ok(Err(ErrorCode::DatagramTooLarge)); + } + if !is_udp_allowed(store) { + return Ok(Err(ErrorCode::AccessDenied)); + } + if let Some(addr) = remote_address { + let addr = SocketAddr::from(addr); + if !is_addr_allowed(store, addr, SocketAddrUse::UdpOutgoingDatagram).await { + return Ok(Err(ErrorCode::AccessDenied)); + } + let fut = store.with(|mut view| { + get_socket(view.get().table, &socket).map(|sock| sock.send_to(data, addr)) + })?; + Ok(fut.await) + } else { + let fut = store.with(|mut view| { + get_socket(view.get().table, &socket).map(|sock| sock.send(data)) + })?; + Ok(fut.await) + } + } + + async fn receive( + store: &Accessor, + socket: Resource, + ) -> wasmtime::Result, IpSocketAddress), ErrorCode>> { + if !is_udp_allowed(store) { + return Ok(Err(ErrorCode::AccessDenied)); + } + let fut = store + .with(|mut view| get_socket(view.get().table, &socket).map(|sock| sock.receive()))?; + Ok(fut.await) + } +} + +impl HostUdpSocket for WasiSocketsCtxView<'_> { + fn new(&mut self, address_family: IpAddressFamily) -> wasmtime::Result> { + let socket = UdpSocket::new(address_family.into()).context("failed to create socket")?; + self.table + .push(socket) + .context("failed to push socket resource to table") + } + + fn disconnect( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let socket = get_socket_mut(self.table, &socket)?; + Ok(socket.disconnect()) + } + + fn local_address( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.local_address()) + } + + fn remote_address( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.remote_address()) + } + + fn address_family(&mut self, socket: Resource) -> wasmtime::Result { + let sock = get_socket(self.table, &socket)?; + Ok(sock.address_family()) + } + + fn unicast_hop_limit( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.unicast_hop_limit()) + } + + fn set_unicast_hop_limit( + &mut self, + socket: Resource, + value: u8, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.set_unicast_hop_limit(value)) + } + + fn receive_buffer_size( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.receive_buffer_size()) + } + + fn set_receive_buffer_size( + &mut self, + socket: Resource, + value: u64, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.set_receive_buffer_size(value)) + } + + fn send_buffer_size( + &mut self, + socket: Resource, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.send_buffer_size()) + } + + fn set_send_buffer_size( + &mut self, + socket: Resource, + value: u64, + ) -> wasmtime::Result> { + let sock = get_socket(self.table, &socket)?; + Ok(sock.set_send_buffer_size(value)) + } + + fn drop(&mut self, sock: Resource) -> wasmtime::Result<()> { + self.table + .delete(sock) + .context("failed to delete socket resource from table")?; + Ok(()) + } +} diff --git a/crates/wasi/src/p3/sockets/mod.rs b/crates/wasi/src/p3/sockets/mod.rs new file mode 100644 index 000000000000..069bf77999b8 --- /dev/null +++ b/crates/wasi/src/p3/sockets/mod.rs @@ -0,0 +1,82 @@ +use crate::p3::bindings::sockets; +use crate::sockets::{WasiSocketsCtxView, WasiSocketsView}; +use wasmtime::component::{HasData, Linker}; + +mod conv; +mod host; +pub mod tcp; +pub mod udp; + +/// Add all WASI interfaces from this module into the `linker` provided. +/// +/// This function will add all interfaces implemented by this module to the +/// [`Linker`], which corresponds to the `wasi:sockets/imports` world supported by +/// this module. +/// +/// This is low-level API for advanced use cases, +/// [`wasmtime_wasi::p3::add_to_linker`](crate::p3::add_to_linker) can be used instead +/// to add *all* wasip3 interfaces (including the ones from this module) to the `linker`. +/// +/// # Example +/// +/// ``` +/// use wasmtime::{Engine, Result, Store, Config}; +/// use wasmtime::component::{Linker, ResourceTable}; +/// use wasmtime_wasi::sockets::{WasiSocketsCtx, WasiSocketsCtxView, WasiSocketsView}; +/// +/// fn main() -> Result<()> { +/// let mut config = Config::new(); +/// config.async_support(true); +/// config.wasm_component_model_async(true); +/// let engine = Engine::new(&config)?; +/// +/// let mut linker = Linker::::new(&engine); +/// wasmtime_wasi::p3::sockets::add_to_linker(&mut linker)?; +/// // ... add any further functionality to `linker` if desired ... +/// +/// let mut store = Store::new( +/// &engine, +/// MyState::default(), +/// ); +/// +/// // ... use `linker` to instantiate within `store` ... +/// +/// Ok(()) +/// } +/// +/// #[derive(Default)] +/// struct MyState { +/// sockets: WasiSocketsCtx, +/// table: ResourceTable, +/// } +/// +/// impl WasiSocketsView for MyState { +/// fn sockets(&mut self) -> WasiSocketsCtxView<'_> { +/// WasiSocketsCtxView { +/// ctx: &mut self.sockets, +/// table: &mut self.table, +/// } +/// } +/// } +/// ``` +pub fn add_to_linker(linker: &mut Linker) -> wasmtime::Result<()> +where + T: WasiSocketsView + 'static, +{ + add_to_linker_impl(linker, T::sockets) +} + +pub(crate) fn add_to_linker_impl( + linker: &mut Linker, + host_getter: fn(&mut T) -> WasiSocketsCtxView<'_>, +) -> wasmtime::Result<()> { + sockets::ip_name_lookup::add_to_linker::<_, WasiSockets>(linker, host_getter)?; + sockets::types::add_to_linker::<_, WasiSockets>(linker, host_getter)?; + Ok(()) +} + +struct WasiSockets; + +impl HasData for WasiSockets { + type Data<'a> = WasiSocketsCtxView<'a>; +} diff --git a/crates/wasi/src/p3/sockets/tcp.rs b/crates/wasi/src/p3/sockets/tcp.rs new file mode 100644 index 000000000000..e327b1ef75e9 --- /dev/null +++ b/crates/wasi/src/p3/sockets/tcp.rs @@ -0,0 +1,409 @@ +use core::fmt::Debug; +use core::mem; +use core::net::SocketAddr; + +use std::sync::Arc; + +use cap_net_ext::AddressFamily; +use io_lifetimes::AsSocketlike as _; +use io_lifetimes::views::SocketlikeView; +use rustix::net::sockopt; + +use crate::p3::bindings::sockets::types::{Duration, ErrorCode, IpAddressFamily, IpSocketAddress}; +use crate::runtime::with_ambient_tokio_runtime; +use crate::sockets::util::{ + get_unicast_hop_limit, is_valid_address_family, is_valid_unicast_address, receive_buffer_size, + send_buffer_size, set_keep_alive_count, set_keep_alive_idle_time, set_keep_alive_interval, + set_receive_buffer_size, set_send_buffer_size, set_unicast_hop_limit, tcp_bind, +}; +use crate::sockets::{DEFAULT_TCP_BACKLOG, SocketAddressFamily}; + +/// The state of a TCP socket. +/// +/// This represents the various states a socket can be in during the +/// activities of binding, listening, accepting, and connecting. +pub enum TcpState { + /// The initial state for a newly-created socket. + Default(tokio::net::TcpSocket), + + /// Binding finished. The socket has an address but is not yet listening for connections. + Bound(tokio::net::TcpSocket), + + /// The socket is now listening and waiting for an incoming connection. + Listening(Arc), + + /// An outgoing connection is started. + Connecting, + + /// A connection has been established. + Connected(Arc), + + /// A connection has been established and `receive` has been called. + Receiving(Arc), + + Error(ErrorCode), + + Closed, +} + +impl Debug for TcpState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Default(_) => f.debug_tuple("Default").finish(), + Self::Bound(_) => f.debug_tuple("Bound").finish(), + Self::Listening { .. } => f.debug_tuple("Listening").finish(), + Self::Connecting => f.debug_tuple("Connecting").finish(), + Self::Connected { .. } => f.debug_tuple("Connected").finish(), + Self::Receiving { .. } => f.debug_tuple("Receiving").finish(), + Self::Error(..) => f.debug_tuple("Error").finish(), + Self::Closed => write!(f, "Closed"), + } + } +} + +/// A host TCP socket, plus associated bookkeeping. +pub struct TcpSocket { + /// The current state in the bind/listen/accept/connect progression. + pub tcp_state: TcpState, + + /// The desired listen queue size. + pub listen_backlog_size: u32, + + pub family: SocketAddressFamily, + + pub options: NonInheritedOptions, +} + +impl TcpSocket { + /// Create a new socket in the given family. + pub fn new(family: AddressFamily) -> std::io::Result { + with_ambient_tokio_runtime(|| { + let (socket, family) = match family { + AddressFamily::Ipv4 => { + let socket = tokio::net::TcpSocket::new_v4()?; + (socket, SocketAddressFamily::Ipv4) + } + AddressFamily::Ipv6 => { + let socket = tokio::net::TcpSocket::new_v6()?; + sockopt::set_ipv6_v6only(&socket, true)?; + (socket, SocketAddressFamily::Ipv6) + } + }; + + Ok(Self::from_state(TcpState::Default(socket), family)) + }) + } + + /// Create a `TcpSocket` from an existing socket. + pub fn from_state(state: TcpState, family: SocketAddressFamily) -> Self { + Self { + tcp_state: state, + listen_backlog_size: DEFAULT_TCP_BACKLOG, + family, + options: Default::default(), + } + } + + pub fn as_std_view(&self) -> Result, ErrorCode> { + match &self.tcp_state { + TcpState::Default(socket) | TcpState::Bound(socket) => Ok(socket.as_socketlike_view()), + TcpState::Connected(stream) | TcpState::Receiving(stream) => { + Ok(stream.as_socketlike_view()) + } + TcpState::Listening(listener) => Ok(listener.as_socketlike_view()), + TcpState::Connecting | TcpState::Closed => Err(ErrorCode::InvalidState), + TcpState::Error(err) => Err(*err), + } + } + + pub fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> { + let ip = addr.ip(); + if !is_valid_unicast_address(ip) || !is_valid_address_family(ip, self.family) { + return Err(ErrorCode::InvalidArgument); + } + match mem::replace(&mut self.tcp_state, TcpState::Closed) { + TcpState::Default(sock) => { + if let Err(err) = tcp_bind(&sock, addr) { + self.tcp_state = TcpState::Default(sock); + Err(err.into()) + } else { + self.tcp_state = TcpState::Bound(sock); + Ok(()) + } + } + tcp_state => { + self.tcp_state = tcp_state; + Err(ErrorCode::InvalidState) + } + } + } + + pub fn local_address(&self) -> Result { + match &self.tcp_state { + TcpState::Bound(socket) => { + let addr = socket.local_addr()?; + Ok(addr.into()) + } + TcpState::Connected(stream) | TcpState::Receiving(stream) => { + let addr = stream.local_addr()?; + Ok(addr.into()) + } + TcpState::Listening(listener) => { + let addr = listener.local_addr()?; + Ok(addr.into()) + } + TcpState::Error(err) => Err(*err), + _ => Err(ErrorCode::InvalidState), + } + } + + pub fn remote_address(&self) -> Result { + match &self.tcp_state { + TcpState::Connected(stream) | TcpState::Receiving(stream) => { + let addr = stream.peer_addr()?; + Ok(addr.into()) + } + TcpState::Error(err) => Err(*err), + _ => Err(ErrorCode::InvalidState), + } + } + + pub fn is_listening(&self) -> bool { + matches!(self.tcp_state, TcpState::Listening { .. }) + } + + pub fn address_family(&self) -> IpAddressFamily { + match self.family { + SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4, + SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6, + } + } + + pub fn set_listen_backlog_size(&mut self, value: u64) -> Result<(), ErrorCode> { + const MIN_BACKLOG: u32 = 1; + const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further. + + if value == 0 { + return Err(ErrorCode::InvalidArgument); + } + // Silently clamp backlog size. This is OK for us to do, because operating systems do this too. + let value = value + .try_into() + .unwrap_or(MAX_BACKLOG) + .clamp(MIN_BACKLOG, MAX_BACKLOG); + match &self.tcp_state { + TcpState::Default(..) | TcpState::Bound(..) => { + // Socket not listening yet. Stash value for first invocation to `listen`. + self.listen_backlog_size = value; + Ok(()) + } + TcpState::Listening(listener) => { + // Try to update the backlog by calling `listen` again. + // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact. + if rustix::net::listen(&listener, value.try_into().unwrap_or(i32::MAX)).is_err() { + return Err(ErrorCode::NotSupported); + } + self.listen_backlog_size = value; + Ok(()) + } + TcpState::Error(err) => Err(*err), + _ => Err(ErrorCode::InvalidState), + } + } + + pub fn keep_alive_enabled(&self) -> Result { + let fd = &*self.as_std_view()?; + let v = sockopt::socket_keepalive(fd)?; + Ok(v) + } + + pub fn set_keep_alive_enabled(&self, value: bool) -> Result<(), ErrorCode> { + let fd = &*self.as_std_view()?; + sockopt::set_socket_keepalive(fd, value)?; + Ok(()) + } + + pub fn keep_alive_idle_time(&self) -> Result { + let fd = &*self.as_std_view()?; + let v = sockopt::tcp_keepidle(fd)?; + Ok(v.as_nanos().try_into().unwrap_or(u64::MAX)) + } + + pub fn set_keep_alive_idle_time(&mut self, value: Duration) -> Result<(), ErrorCode> { + let value = { + let fd = self.as_std_view()?; + set_keep_alive_idle_time(&*fd, value)? + }; + self.options.set_keep_alive_idle_time(value); + Ok(()) + } + + pub fn keep_alive_interval(&self) -> Result { + let fd = &*self.as_std_view()?; + let v = sockopt::tcp_keepintvl(fd)?; + Ok(v.as_nanos().try_into().unwrap_or(u64::MAX)) + } + + pub fn set_keep_alive_interval(&self, value: Duration) -> Result<(), ErrorCode> { + let fd = &*self.as_std_view()?; + set_keep_alive_interval(fd, core::time::Duration::from_nanos(value))?; + Ok(()) + } + + pub fn keep_alive_count(&self) -> Result { + let fd = &*self.as_std_view()?; + let v = sockopt::tcp_keepcnt(fd)?; + Ok(v) + } + + pub fn set_keep_alive_count(&self, value: u32) -> Result<(), ErrorCode> { + let fd = &*self.as_std_view()?; + set_keep_alive_count(fd, value)?; + Ok(()) + } + + pub fn hop_limit(&self) -> Result { + let fd = &*self.as_std_view()?; + let n = get_unicast_hop_limit(fd, self.family)?; + Ok(n) + } + + pub fn set_hop_limit(&mut self, value: u8) -> Result<(), ErrorCode> { + { + let fd = &*self.as_std_view()?; + set_unicast_hop_limit(fd, self.family, value)?; + } + self.options.set_hop_limit(value); + Ok(()) + } + + pub fn receive_buffer_size(&self) -> Result { + let fd = &*self.as_std_view()?; + let n = receive_buffer_size(fd)?; + Ok(n) + } + + pub fn set_receive_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> { + let res = { + let fd = &*self.as_std_view()?; + set_receive_buffer_size(fd, value)? + }; + self.options.set_receive_buffer_size(res); + Ok(()) + } + + pub fn send_buffer_size(&self) -> Result { + let fd = &*self.as_std_view()?; + let n = send_buffer_size(fd)?; + Ok(n) + } + + pub fn set_send_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> { + let res = { + let fd = &*self.as_std_view()?; + set_send_buffer_size(fd, value)? + }; + self.options.set_send_buffer_size(res); + Ok(()) + } +} + +#[cfg(not(target_os = "macos"))] +pub use inherits_option::*; +#[cfg(not(target_os = "macos"))] +mod inherits_option { + use crate::sockets::SocketAddressFamily; + use tokio::net::TcpStream; + + #[derive(Default, Clone)] + pub struct NonInheritedOptions; + + impl NonInheritedOptions { + pub fn set_keep_alive_idle_time(&mut self, _value: u64) {} + + pub fn set_hop_limit(&mut self, _value: u8) {} + + pub fn set_receive_buffer_size(&mut self, _value: usize) {} + + pub fn set_send_buffer_size(&mut self, _value: usize) {} + + pub fn apply(&self, _family: SocketAddressFamily, _stream: &TcpStream) {} + } +} + +#[cfg(target_os = "macos")] +pub use does_not_inherit_options::*; +#[cfg(target_os = "macos")] +mod does_not_inherit_options { + use crate::sockets::SocketAddressFamily; + use rustix::net::sockopt; + use std::sync::Arc; + use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering::Relaxed}; + use std::time::Duration; + use tokio::net::TcpStream; + + // The socket options below are not automatically inherited from the listener + // on all platforms. So we keep track of which options have been explicitly + // set and manually apply those values to newly accepted clients. + #[derive(Default, Clone)] + pub struct NonInheritedOptions(Arc); + + #[derive(Default)] + struct Inner { + receive_buffer_size: AtomicUsize, + send_buffer_size: AtomicUsize, + hop_limit: AtomicU8, + keep_alive_idle_time: AtomicU64, // nanoseconds + } + + impl NonInheritedOptions { + pub fn set_keep_alive_idle_time(&mut self, value: u64) { + self.0.keep_alive_idle_time.store(value, Relaxed); + } + + pub fn set_hop_limit(&mut self, value: u8) { + self.0.hop_limit.store(value, Relaxed); + } + + pub fn set_receive_buffer_size(&mut self, value: usize) { + self.0.receive_buffer_size.store(value, Relaxed); + } + + pub fn set_send_buffer_size(&mut self, value: usize) { + self.0.send_buffer_size.store(value, Relaxed); + } + + pub fn apply(&self, family: SocketAddressFamily, stream: &TcpStream) { + // Manually inherit socket options from listener. We only have to + // do this on platforms that don't already do this automatically + // and only if a specific value was explicitly set on the listener. + + let receive_buffer_size = self.0.receive_buffer_size.load(Relaxed); + if receive_buffer_size > 0 { + // Ignore potential error. + _ = sockopt::set_socket_recv_buffer_size(&stream, receive_buffer_size); + } + + let send_buffer_size = self.0.send_buffer_size.load(Relaxed); + if send_buffer_size > 0 { + // Ignore potential error. + _ = sockopt::set_socket_send_buffer_size(&stream, send_buffer_size); + } + + // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't. + if family == SocketAddressFamily::Ipv6 { + let hop_limit = self.0.hop_limit.load(Relaxed); + if hop_limit > 0 { + // Ignore potential error. + _ = sockopt::set_ipv6_unicast_hops(&stream, Some(hop_limit)); + } + } + + let keep_alive_idle_time = self.0.keep_alive_idle_time.load(Relaxed); + if keep_alive_idle_time > 0 { + // Ignore potential error. + _ = sockopt::set_tcp_keepidle(&stream, Duration::from_nanos(keep_alive_idle_time)); + } + } + } +} diff --git a/crates/wasi/src/p3/sockets/udp.rs b/crates/wasi/src/p3/sockets/udp.rs new file mode 100644 index 000000000000..6e2d635cfdd0 --- /dev/null +++ b/crates/wasi/src/p3/sockets/udp.rs @@ -0,0 +1,285 @@ +use core::future::Future; +use core::net::SocketAddr; + +use std::sync::Arc; + +use cap_net_ext::AddressFamily; +use io_lifetimes::AsSocketlike as _; +use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _}; +use rustix::io::Errno; +use rustix::net::connect; +use tracing::debug; + +use crate::p3::bindings::sockets::types::{ErrorCode, IpAddressFamily, IpSocketAddress}; +use crate::runtime::with_ambient_tokio_runtime; +use crate::sockets::util::{ + get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address, receive_buffer_size, + send_buffer_size, set_receive_buffer_size, set_send_buffer_size, set_unicast_hop_limit, + udp_bind, udp_disconnect, udp_socket, +}; +use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddressFamily}; + +/// The state of a UDP socket. +/// +/// This represents the various states a socket can be in during the +/// activities of binding, and connecting. +enum UdpState { + /// The initial state for a newly-created socket. + Default, + + /// Binding finished via `finish_bind`. The socket has an address but + /// is not yet listening for connections. + Bound, + + /// The socket is "connected" to a peer address. + Connected(SocketAddr), +} + +/// A host UDP socket, plus associated bookkeeping. +/// +/// The inner state is wrapped in an Arc because the same underlying socket is +/// used for implementing the stream types. +pub struct UdpSocket { + socket: Arc, + + /// The current state in the bind/connect progression. + udp_state: UdpState, + + /// Socket address family. + family: SocketAddressFamily, +} + +impl UdpSocket { + /// Create a new socket in the given family. + pub fn new(family: AddressFamily) -> std::io::Result { + // Delegate socket creation to cap_net_ext. They handle a couple of things for us: + // - On Windows: call WSAStartup if not done before. + // - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation, + // or afterwards using ioctl or fcntl. Exact method depends on the platform. + + let fd = udp_socket(family)?; + + let socket_address_family = match family { + AddressFamily::Ipv4 => SocketAddressFamily::Ipv4, + AddressFamily::Ipv6 => { + rustix::net::sockopt::set_ipv6_v6only(&fd, true)?; + SocketAddressFamily::Ipv6 + } + }; + + let socket = with_ambient_tokio_runtime(|| { + tokio::net::UdpSocket::try_from(unsafe { + std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike()) + }) + })?; + + Ok(Self { + socket: Arc::new(socket), + udp_state: UdpState::Default, + family: socket_address_family, + }) + } + + pub fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> { + if !matches!(self.udp_state, UdpState::Default) { + return Err(ErrorCode::InvalidState); + } + if !is_valid_address_family(addr.ip(), self.family) { + return Err(ErrorCode::InvalidArgument); + } + udp_bind(&self.socket, addr)?; + self.udp_state = UdpState::Bound; + Ok(()) + } + + pub fn disconnect(&mut self) -> Result<(), ErrorCode> { + if !matches!(self.udp_state, UdpState::Connected(..)) { + return Err(ErrorCode::InvalidState); + } + udp_disconnect(&self.socket)?; + self.udp_state = UdpState::Bound; + Ok(()) + } + + pub fn connect(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> { + if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) { + return Err(ErrorCode::InvalidArgument); + } + + // We disconnect & (re)connect in two distinct steps for two reasons: + // - To leave our socket instance in a consistent state in case the + // connect fails. + // - When reconnecting to a different address, Linux sometimes fails + // if there isn't a disconnect in between. + + // Step #1: Disconnect + if let UdpState::Connected(..) = self.udp_state { + udp_disconnect(&self.socket)?; + self.udp_state = UdpState::Bound; + } + // Step #2: (Re)connect + connect(&self.socket, &addr).map_err(|error| match error { + Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `udp_bind` implementation. + Errno::INPROGRESS => { + debug!("UDP connect returned EINPROGRESS, which should never happen"); + ErrorCode::Unknown + } + err => err.into(), + })?; + self.udp_state = UdpState::Connected(addr); + Ok(()) + } + + pub fn send(&self, buf: Vec) -> impl Future> + use<> { + let socket = if let UdpState::Connected(..) = self.udp_state { + Ok(Arc::clone(&self.socket)) + } else { + Err(ErrorCode::InvalidArgument) + }; + async move { + let socket = socket?; + send(&socket, &buf).await + } + } + + pub fn send_to( + &self, + buf: Vec, + addr: SocketAddr, + ) -> impl Future> + use<> { + enum Mode { + Send(Arc), + SendTo(Arc, SocketAddr), + } + let socket = match &self.udp_state { + UdpState::Default | UdpState::Bound => Ok(Mode::SendTo(Arc::clone(&self.socket), addr)), + UdpState::Connected(caddr) if addr == *caddr => { + Ok(Mode::Send(Arc::clone(&self.socket))) + } + UdpState::Connected(..) => Err(ErrorCode::InvalidArgument), + }; + async move { + match socket? { + Mode::Send(socket) => send(&socket, &buf).await, + Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await, + } + } + } + + pub fn receive( + &self, + ) -> impl Future, IpSocketAddress), ErrorCode>> + use<> { + enum Mode { + Recv(Arc, IpSocketAddress), + RecvFrom(Arc), + } + let socket = match self.udp_state { + UdpState::Default => Err(ErrorCode::InvalidState), + UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))), + UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr.into())), + }; + async move { + let socket = socket?; + let mut buf = vec![0; MAX_UDP_DATAGRAM_SIZE]; + let (n, addr) = match socket { + Mode::Recv(socket, addr) => { + let n = socket.recv(&mut buf).await?; + (n, addr) + } + Mode::RecvFrom(socket) => { + let (n, addr) = socket.recv_from(&mut buf).await?; + (n, addr.into()) + } + }; + buf.truncate(n); + Ok((buf, addr)) + } + } + + pub fn local_address(&self) -> Result { + if matches!(self.udp_state, UdpState::Default) { + return Err(ErrorCode::InvalidState); + } + let addr = self + .socket + .as_socketlike_view::() + .local_addr()?; + Ok(addr.into()) + } + + pub fn remote_address(&self) -> Result { + if !matches!(self.udp_state, UdpState::Connected(..)) { + return Err(ErrorCode::InvalidState); + } + let addr = self + .socket + .as_socketlike_view::() + .peer_addr()?; + Ok(addr.into()) + } + + pub fn address_family(&self) -> IpAddressFamily { + match self.family { + SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4, + SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6, + } + } + + pub fn unicast_hop_limit(&self) -> Result { + let n = get_unicast_hop_limit(&self.socket, self.family)?; + Ok(n) + } + + pub fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> { + set_unicast_hop_limit(&self.socket, self.family, value)?; + Ok(()) + } + + pub fn receive_buffer_size(&self) -> Result { + let n = receive_buffer_size(&self.socket)?; + Ok(n) + } + + pub fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> { + set_receive_buffer_size(&self.socket, value)?; + Ok(()) + } + + pub fn send_buffer_size(&self) -> Result { + let n = send_buffer_size(&self.socket)?; + Ok(n) + } + + pub fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> { + set_send_buffer_size(&self.socket, value)?; + Ok(()) + } +} + +async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> { + let n = socket.send(buf).await?; + // From Rust stdlib docs: + // > Note that the operating system may refuse buffers larger than 65507. + // > However, partial writes are not possible until buffer sizes above `i32::MAX`. + // + // For example, on Windows, at most `i32::MAX` bytes will be written + if n != buf.len() { + Err(ErrorCode::Unknown) + } else { + Ok(()) + } +} + +async fn send_to( + socket: &tokio::net::UdpSocket, + buf: &[u8], + addr: SocketAddr, +) -> Result<(), ErrorCode> { + let n = socket.send_to(buf, addr).await?; + // See [`send`] documentation + if n != buf.len() { + Err(ErrorCode::Unknown) + } else { + Ok(()) + } +} diff --git a/crates/wasi/src/p3/view.rs b/crates/wasi/src/p3/view.rs index 7674f22051bd..71222d8f794b 100644 --- a/crates/wasi/src/p3/view.rs +++ b/crates/wasi/src/p3/view.rs @@ -2,11 +2,6 @@ use wasmtime::component::ResourceTable; use crate::p3::ctx::WasiCtx; -pub struct WasiCtxView<'a> { - pub ctx: &'a mut WasiCtx, - pub table: &'a mut ResourceTable, -} - /// A trait which provides access to the [`WasiCtx`] inside the embedder's `T` /// of [`Store`][`Store`]. /// @@ -19,7 +14,7 @@ pub struct WasiCtxView<'a> { /// # Example /// /// ``` -/// use wasmtime_wasi::p3::{WasiCtx, WasiCtxBuilder, WasiView, WasiCtxView}; +/// use wasmtime_wasi::p3::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView}; /// use wasmtime::component::ResourceTable; /// /// struct MyState { @@ -56,3 +51,8 @@ impl WasiView for Box { T::ctx(self) } } + +pub struct WasiCtxView<'a> { + pub ctx: &'a mut WasiCtx, + pub table: &'a mut ResourceTable, +} diff --git a/crates/wasi/src/p3/wit/deps/sockets/ip-name-lookup.wit b/crates/wasi/src/p3/wit/deps/sockets/ip-name-lookup.wit index 7cc8b03e35f2..73b4b201f2d3 100644 --- a/crates/wasi/src/p3/wit/deps/sockets/ip-name-lookup.wit +++ b/crates/wasi/src/p3/wit/deps/sockets/ip-name-lookup.wit @@ -58,5 +58,5 @@ interface ip-name-lookup { /// - /// - @since(version = 0.3.0) - resolve-addresses: func(name: string) -> result, error-code>; + resolve-addresses: async func(name: string) -> result, error-code>; } diff --git a/crates/wasi/src/p3/wit/deps/sockets/types.wit b/crates/wasi/src/p3/wit/deps/sockets/types.wit index 156cc502675c..86315314fc3f 100644 --- a/crates/wasi/src/p3/wit/deps/sockets/types.wit +++ b/crates/wasi/src/p3/wit/deps/sockets/types.wit @@ -219,7 +219,7 @@ interface types { /// - /// - @since(version = 0.3.0) - connect: func(remote-address: ip-socket-address) -> result<_, error-code>; + connect: async func(remote-address: ip-socket-address) -> result<_, error-code>; /// Start listening return a stream of new inbound connections. /// @@ -309,7 +309,7 @@ interface types { /// - /// - @since(version = 0.3.0) - send: func(data: stream) -> result<_, error-code>; + send: async func(data: stream) -> result<_, error-code>; /// Read data from peer. /// @@ -624,7 +624,7 @@ interface types { /// - /// - @since(version = 0.3.0) - send: func(data: list, remote-address: option) -> result<_, error-code>; + send: async func(data: list, remote-address: option) -> result<_, error-code>; /// Receive a message on the socket. /// @@ -650,7 +650,7 @@ interface types { /// - /// - @since(version = 0.3.0) - receive: func() -> result, ip-socket-address>, error-code>; + receive: async func() -> result, ip-socket-address>, error-code>; /// Get the current bound address. /// diff --git a/crates/wasi/src/random.rs b/crates/wasi/src/random.rs index 474f819155cd..7359c17c3ec2 100644 --- a/crates/wasi/src/random.rs +++ b/crates/wasi/src/random.rs @@ -1,27 +1,5 @@ use cap_rand::{Rng as _, RngCore, SeedableRng as _}; -impl WasiRandomView for &mut T { - fn random(&mut self) -> &mut WasiRandomCtx { - T::random(self) - } -} - -impl WasiRandomView for Box { - fn random(&mut self) -> &mut WasiRandomCtx { - T::random(self) - } -} - -impl WasiRandomView for WasiRandomCtx { - fn random(&mut self) -> &mut WasiRandomCtx { - self - } -} - -pub trait WasiRandomView: Send { - fn random(&mut self) -> &mut WasiRandomCtx; -} - pub struct WasiRandomCtx { pub random: Box, pub insecure_random: Box, @@ -49,6 +27,28 @@ impl Default for WasiRandomCtx { } } +pub trait WasiRandomView: Send { + fn random(&mut self) -> &mut WasiRandomCtx; +} + +impl WasiRandomView for &mut T { + fn random(&mut self) -> &mut WasiRandomCtx { + T::random(self) + } +} + +impl WasiRandomView for Box { + fn random(&mut self) -> &mut WasiRandomCtx { + T::random(self) + } +} + +impl WasiRandomView for WasiRandomCtx { + fn random(&mut self) -> &mut WasiRandomCtx { + self + } +} + /// Implement `insecure-random` using a deterministic cycle of bytes. pub struct Deterministic { cycle: std::iter::Cycle>, diff --git a/crates/wasi/src/sockets/mod.rs b/crates/wasi/src/sockets/mod.rs new file mode 100644 index 000000000000..48a48b031fb6 --- /dev/null +++ b/crates/wasi/src/sockets/mod.rs @@ -0,0 +1,159 @@ +use core::future::Future; +use core::ops::Deref; + +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; + +pub(crate) mod util; + +use wasmtime::component::ResourceTable; + +/// Value taken from rust std library. +pub const DEFAULT_TCP_BACKLOG: u32 = 128; + +/// Theoretical maximum byte size of a UDP datagram, the real limit is lower, +/// but we do not account for e.g. the transport layer here for simplicity. +/// In practice, datagrams are typically less than 1500 bytes. +pub const MAX_UDP_DATAGRAM_SIZE: usize = u16::MAX as usize; + +#[derive(Clone, Default)] +pub struct WasiSocketsCtx { + pub socket_addr_check: SocketAddrCheck, + pub allowed_network_uses: AllowedNetworkUses, +} + +pub struct WasiSocketsCtxView<'a> { + pub ctx: &'a mut WasiSocketsCtx, + pub table: &'a mut ResourceTable, +} + +pub trait WasiSocketsView: Send { + fn sockets(&mut self) -> WasiSocketsCtxView<'_>; +} + +impl WasiSocketsView for &mut T { + fn sockets(&mut self) -> WasiSocketsCtxView<'_> { + T::sockets(self) + } +} + +impl WasiSocketsView for Box { + fn sockets(&mut self) -> WasiSocketsCtxView<'_> { + T::sockets(self) + } +} + +#[derive(Copy, Clone)] +pub struct AllowedNetworkUses { + pub ip_name_lookup: bool, + pub udp: bool, + pub tcp: bool, +} + +impl Default for AllowedNetworkUses { + fn default() -> Self { + Self { + ip_name_lookup: false, + udp: true, + tcp: true, + } + } +} + +impl AllowedNetworkUses { + pub(crate) fn check_allowed_udp(&self) -> std::io::Result<()> { + if !self.udp { + return Err(std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "UDP is not allowed", + )); + } + + Ok(()) + } + + pub(crate) fn check_allowed_tcp(&self) -> std::io::Result<()> { + if !self.tcp { + return Err(std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "TCP is not allowed", + )); + } + + Ok(()) + } +} + +/// A check that will be called for each socket address that is used of whether the address is permitted. +#[derive(Clone)] +pub struct SocketAddrCheck( + pub(crate) Arc< + dyn Fn(SocketAddr, SocketAddrUse) -> Pin + Send + Sync>> + + Send + + Sync, + >, +); + +impl SocketAddrCheck { + /// A check that will be called for each socket address that is used. + /// + /// Returning `true` will permit socket connections to the `SocketAddr`, + /// while returning `false` will reject the connection. + pub fn new( + f: impl Fn(SocketAddr, SocketAddrUse) -> Pin + Send + Sync>> + + Send + + Sync + + 'static, + ) -> Self { + Self(Arc::new(f)) + } + + pub async fn check(&self, addr: SocketAddr, reason: SocketAddrUse) -> std::io::Result<()> { + if (self.0)(addr, reason).await { + Ok(()) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "An address was not permitted by the socket address check.", + )) + } + } +} + +impl Deref for SocketAddrCheck { + type Target = dyn Fn(SocketAddr, SocketAddrUse) -> Pin + Send + Sync>> + + Send + + Sync; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl Default for SocketAddrCheck { + fn default() -> Self { + Self(Arc::new(|_, _| Box::pin(async { false }))) + } +} + +/// The reason what a socket address is being used for. +#[derive(Clone, Copy, Debug)] +pub enum SocketAddrUse { + /// Binding TCP socket + TcpBind, + /// Connecting TCP socket + TcpConnect, + /// Binding UDP socket + UdpBind, + /// Connecting UDP socket + UdpConnect, + /// Sending datagram on non-connected UDP socket + UdpOutgoingDatagram, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum SocketAddressFamily { + Ipv4, + Ipv6, +} diff --git a/crates/wasi/src/sockets/util.rs b/crates/wasi/src/sockets/util.rs new file mode 100644 index 000000000000..dafe6e25845e --- /dev/null +++ b/crates/wasi/src/sockets/util.rs @@ -0,0 +1,431 @@ +use core::fmt; +use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use core::str::FromStr as _; +use core::time::Duration; + +use cap_net_ext::{AddressFamily, Blocking, UdpSocketExt}; +use rustix::fd::AsFd; +use rustix::io::Errno; +use rustix::net::{bind, connect_unspec, sockopt}; +use tracing::debug; + +use crate::sockets::SocketAddressFamily; + +#[derive(Debug)] +pub enum ErrorCode { + Unknown, + AccessDenied, + NotSupported, + InvalidArgument, + OutOfMemory, + Timeout, + InvalidState, + AddressNotBindable, + AddressInUse, + RemoteUnreachable, + ConnectionRefused, + ConnectionReset, + ConnectionAborted, + DatagramTooLarge, +} + +impl fmt::Display for ErrorCode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl std::error::Error for ErrorCode {} + +fn is_deprecated_ipv4_compatible(addr: Ipv6Addr) -> bool { + matches!(addr.segments(), [0, 0, 0, 0, 0, 0, _, _]) + && addr != Ipv6Addr::UNSPECIFIED + && addr != Ipv6Addr::LOCALHOST +} + +pub fn is_valid_address_family(addr: IpAddr, socket_family: SocketAddressFamily) -> bool { + match (socket_family, addr) { + (SocketAddressFamily::Ipv4, IpAddr::V4(..)) => true, + (SocketAddressFamily::Ipv6, IpAddr::V6(ipv6)) => { + // Reject IPv4-*compatible* IPv6 addresses. They have been deprecated + // since 2006, OS handling of them is inconsistent and our own + // validations don't take them into account either. + // Note that these are not the same as IPv4-*mapped* IPv6 addresses. + !is_deprecated_ipv4_compatible(ipv6) && ipv6.to_ipv4_mapped().is_none() + } + _ => false, + } +} + +pub fn is_valid_remote_address(addr: SocketAddr) -> bool { + !addr.ip().to_canonical().is_unspecified() && addr.port() != 0 +} + +pub fn is_valid_unicast_address(addr: IpAddr) -> bool { + match addr.to_canonical() { + IpAddr::V4(ipv4) => !ipv4.is_multicast() && !ipv4.is_broadcast(), + IpAddr::V6(ipv6) => !ipv6.is_multicast(), + } +} + +pub fn to_ipv4_addr(addr: (u8, u8, u8, u8)) -> Ipv4Addr { + let (x0, x1, x2, x3) = addr; + Ipv4Addr::new(x0, x1, x2, x3) +} + +pub fn from_ipv4_addr(addr: Ipv4Addr) -> (u8, u8, u8, u8) { + let [x0, x1, x2, x3] = addr.octets(); + (x0, x1, x2, x3) +} + +pub fn to_ipv6_addr(addr: (u16, u16, u16, u16, u16, u16, u16, u16)) -> Ipv6Addr { + let (x0, x1, x2, x3, x4, x5, x6, x7) = addr; + Ipv6Addr::new(x0, x1, x2, x3, x4, x5, x6, x7) +} + +pub fn from_ipv6_addr(addr: Ipv6Addr) -> (u16, u16, u16, u16, u16, u16, u16, u16) { + let [x0, x1, x2, x3, x4, x5, x6, x7] = addr.segments(); + (x0, x1, x2, x3, x4, x5, x6, x7) +} + +/* + * Syscalls wrappers with (opinionated) portability fixes. + */ + +pub fn normalize_get_buffer_size(value: usize) -> usize { + if cfg!(target_os = "linux") { + // Linux doubles the value passed to setsockopt to allow space for bookkeeping overhead. + // getsockopt returns this internally doubled value. + // We'll half the value to at least get it back into the same ballpark that the application requested it in. + // + // This normalized behavior is tested for in: test-programs/src/bin/preview2_tcp_sockopts.rs + value / 2 + } else { + value + } +} + +pub fn normalize_set_buffer_size(value: usize) -> usize { + value.clamp(1, i32::MAX as usize) +} + +impl From for ErrorCode { + fn from(value: std::io::Error) -> Self { + (&value).into() + } +} + +impl From<&std::io::Error> for ErrorCode { + fn from(value: &std::io::Error) -> Self { + // Attempt the more detailed native error code first: + if let Some(errno) = Errno::from_io_error(value) { + return errno.into(); + } + + match value.kind() { + std::io::ErrorKind::AddrInUse => Self::AddressInUse, + std::io::ErrorKind::AddrNotAvailable => Self::AddressNotBindable, + std::io::ErrorKind::ConnectionAborted => Self::ConnectionAborted, + std::io::ErrorKind::ConnectionRefused => Self::ConnectionRefused, + std::io::ErrorKind::ConnectionReset => Self::ConnectionReset, + std::io::ErrorKind::InvalidInput => Self::InvalidArgument, + std::io::ErrorKind::NotConnected => Self::InvalidState, + std::io::ErrorKind::OutOfMemory => Self::OutOfMemory, + std::io::ErrorKind::PermissionDenied => Self::AccessDenied, + std::io::ErrorKind::TimedOut => Self::Timeout, + std::io::ErrorKind::Unsupported => Self::NotSupported, + _ => { + debug!("unknown I/O error: {value}"); + Self::Unknown + } + } + } +} + +impl From for ErrorCode { + fn from(value: Errno) -> Self { + (&value).into() + } +} + +impl From<&Errno> for ErrorCode { + fn from(value: &Errno) -> Self { + match *value { + #[cfg(not(windows))] + Errno::PERM => Self::AccessDenied, + Errno::ACCESS => Self::AccessDenied, + Errno::ADDRINUSE => Self::AddressInUse, + Errno::ADDRNOTAVAIL => Self::AddressNotBindable, + Errno::TIMEDOUT => Self::Timeout, + Errno::CONNREFUSED => Self::ConnectionRefused, + Errno::CONNRESET => Self::ConnectionReset, + Errno::CONNABORTED => Self::ConnectionAborted, + Errno::INVAL => Self::InvalidArgument, + Errno::HOSTUNREACH => Self::RemoteUnreachable, + Errno::HOSTDOWN => Self::RemoteUnreachable, + Errno::NETDOWN => Self::RemoteUnreachable, + Errno::NETUNREACH => Self::RemoteUnreachable, + #[cfg(target_os = "linux")] + Errno::NONET => Self::RemoteUnreachable, + Errno::ISCONN => Self::InvalidState, + Errno::NOTCONN => Self::InvalidState, + Errno::DESTADDRREQ => Self::InvalidState, + Errno::MSGSIZE => Self::DatagramTooLarge, + #[cfg(not(windows))] + Errno::NOMEM => Self::OutOfMemory, + Errno::NOBUFS => Self::OutOfMemory, + Errno::OPNOTSUPP => Self::NotSupported, + Errno::NOPROTOOPT => Self::NotSupported, + Errno::PFNOSUPPORT => Self::NotSupported, + Errno::PROTONOSUPPORT => Self::NotSupported, + Errno::PROTOTYPE => Self::NotSupported, + Errno::SOCKTNOSUPPORT => Self::NotSupported, + Errno::AFNOSUPPORT => Self::NotSupported, + + // FYI, EINPROGRESS should have already been handled by connect. + _ => { + debug!("unknown I/O error: {value}"); + Self::Unknown + } + } + } +} + +pub fn get_ip_ttl(fd: impl AsFd) -> Result { + let v = sockopt::ip_ttl(fd)?; + let Ok(v) = v.try_into() else { + return Err(ErrorCode::NotSupported); + }; + Ok(v) +} + +pub fn get_ipv6_unicast_hops(fd: impl AsFd) -> Result { + let v = sockopt::ipv6_unicast_hops(fd)?; + Ok(v) +} + +pub fn get_unicast_hop_limit(fd: impl AsFd, family: SocketAddressFamily) -> Result { + match family { + SocketAddressFamily::Ipv4 => get_ip_ttl(fd), + SocketAddressFamily::Ipv6 => get_ipv6_unicast_hops(fd), + } +} + +pub fn set_unicast_hop_limit( + fd: impl AsFd, + family: SocketAddressFamily, + value: u8, +) -> Result<(), ErrorCode> { + if value == 0 { + // WIT: "If the provided value is 0, an `invalid-argument` error is returned." + // + // A well-behaved IP application should never send out new packets with TTL 0. + // We validate the value ourselves because OS'es are not consistent in this. + // On Linux the validation is even inconsistent between their IPv4 and IPv6 implementation. + return Err(ErrorCode::InvalidArgument); + } + match family { + SocketAddressFamily::Ipv4 => { + sockopt::set_ip_ttl(fd, value.into())?; + } + SocketAddressFamily::Ipv6 => { + sockopt::set_ipv6_unicast_hops(fd, Some(value))?; + } + } + Ok(()) +} + +pub fn receive_buffer_size(fd: impl AsFd) -> Result { + let v = sockopt::socket_recv_buffer_size(fd)?; + Ok(normalize_get_buffer_size(v).try_into().unwrap_or(u64::MAX)) +} + +pub fn set_receive_buffer_size(fd: impl AsFd, value: u64) -> Result { + if value == 0 { + // WIT: "If the provided value is 0, an `invalid-argument` error is returned." + return Err(ErrorCode::InvalidArgument); + } + let value = value.try_into().unwrap_or(usize::MAX); + let value = normalize_set_buffer_size(value); + match sockopt::set_socket_recv_buffer_size(fd, value) { + // Most platforms (Linux, Windows, Fuchsia, Solaris, Illumos, Haiku, ESP-IDF, ..and more?) treat the value + // passed to SO_SNDBUF/SO_RCVBUF as a performance tuning hint and silently clamp the input if it exceeds + // their capability. + // As far as I can see, only the *BSD family views this option as a hard requirement and fails when the + // value is out of range. We normalize this behavior in favor of the more commonly understood + // "performance hint" semantics. In other words; even ENOBUFS is "Ok". + // A future improvement could be to query the corresponding sysctl on *BSD platforms and clamp the input + // `size` ourselves, to completely close the gap with other platforms. + // + // This normalized behavior is tested for in: test-programs/src/bin/preview2_tcp_sockopts.rs + Err(Errno::NOBUFS) => {} + Err(err) => return Err(err.into()), + _ => {} + }; + Ok(value) +} + +pub fn send_buffer_size(fd: impl AsFd) -> Result { + let v = sockopt::socket_send_buffer_size(fd)?; + Ok(normalize_get_buffer_size(v).try_into().unwrap_or(u64::MAX)) +} + +pub fn set_send_buffer_size(fd: impl AsFd, value: u64) -> Result { + if value == 0 { + // WIT: "If the provided value is 0, an `invalid-argument` error is returned." + return Err(ErrorCode::InvalidArgument); + } + let value = value.try_into().unwrap_or(usize::MAX); + let value = normalize_set_buffer_size(value); + match sockopt::set_socket_send_buffer_size(fd, value) { + Err(Errno::NOBUFS) => {} + Err(err) => return Err(err.into()), + _ => {} + }; + Ok(value) +} + +pub fn set_keep_alive_idle_time(fd: impl AsFd, value: u64) -> Result { + const NANOS_PER_SEC: u64 = 1_000_000_000; + + // Ensure that the value passed to the actual syscall never gets rounded down to 0. + const MIN: u64 = NANOS_PER_SEC; + + // Cap it at Linux' maximum, which appears to have the lowest limit across our supported platforms. + const MAX: u64 = (i16::MAX as u64) * NANOS_PER_SEC; + + if value <= 0 { + // WIT: "If the provided value is 0, an `invalid-argument` error is returned." + return Err(ErrorCode::InvalidArgument); + } + let value = value.clamp(MIN, MAX); + sockopt::set_tcp_keepidle(fd, Duration::from_nanos(value))?; + Ok(value) +} + +pub fn set_keep_alive_interval(fd: impl AsFd, value: Duration) -> Result<(), ErrorCode> { + // Ensure that any fractional value passed to the actual syscall never gets rounded down to 0. + const MIN: Duration = Duration::from_secs(1); + + // Cap it at Linux' maximum, which appears to have the lowest limit across our supported platforms. + const MAX: Duration = Duration::from_secs(i16::MAX as u64); + + if value <= Duration::ZERO { + // WIT: "If the provided value is 0, an `invalid-argument` error is returned." + return Err(ErrorCode::InvalidArgument); + } + sockopt::set_tcp_keepintvl(fd, value.clamp(MIN, MAX))?; + Ok(()) +} + +pub fn set_keep_alive_count(fd: impl AsFd, value: u32) -> Result<(), ErrorCode> { + const MIN_CNT: u32 = 1; + // Cap it at Linux' maximum, which appears to have the lowest limit across our supported platforms. + const MAX_CNT: u32 = i8::MAX as u32; + + if value == 0 { + // WIT: "If the provided value is 0, an `invalid-argument` error is returned." + return Err(ErrorCode::InvalidArgument); + } + sockopt::set_tcp_keepcnt(fd, value.clamp(MIN_CNT, MAX_CNT))?; + Ok(()) +} + +pub fn tcp_bind( + socket: &tokio::net::TcpSocket, + local_address: SocketAddr, +) -> Result<(), ErrorCode> { + // Automatically bypass the TIME_WAIT state when binding to a specific port + // Unconditionally (re)set SO_REUSEADDR, even when the value is false. + // This ensures we're not accidentally affected by any socket option + // state left behind by a previous failed call to this method. + #[cfg(not(windows))] + if let Err(err) = sockopt::set_socket_reuseaddr(&socket, local_address.port() > 0) { + return Err(err.into()); + } + + // Perform the OS bind call. + socket + .bind(local_address) + .map_err(|err| match Errno::from_io_error(&err) { + // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html: + // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket + // + // The most common reasons for this error should have already + // been handled by our own validation slightly higher up in this + // function. This error mapping is here just in case there is + // an edge case we didn't catch. + Some(Errno::AFNOSUPPORT) => ErrorCode::InvalidArgument, + // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS + // Windows returns WSAENOBUFS when the ephemeral ports have been exhausted. + #[cfg(windows)] + Some(Errno::NOBUFS) => ErrorCode::AddressInUse, + _ => err.into(), + }) +} + +pub fn udp_socket(family: AddressFamily) -> std::io::Result { + // Delegate socket creation to cap_net_ext. They handle a couple of things for us: + // - On Windows: call WSAStartup if not done before. + // - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation, + // or afterwards using ioctl or fcntl. Exact method depends on the platform. + + let socket = cap_std::net::UdpSocket::new(family, Blocking::No)?; + Ok(socket) +} + +pub fn udp_bind(sockfd: impl AsFd, addr: SocketAddr) -> Result<(), ErrorCode> { + bind(sockfd, &addr).map_err(|err| match err { + // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS + // Windows returns WSAENOBUFS when the ephemeral ports have been exhausted. + #[cfg(windows)] + Errno::NOBUFS => ErrorCode::AddressInUse, + // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html: + // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket + // + // The most common reasons for this error should have already + // been handled by our own validation slightly higher up in this + // function. This error mapping is here just in case there is + // an edge case we didn't catch. + Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, + _ => err.into(), + }) +} + +pub fn udp_disconnect(sockfd: impl AsFd) -> Result<(), ErrorCode> { + match connect_unspec(sockfd) { + // BSD platforms return an error even if the UDP socket was disconnected successfully. + // + // MacOS was kind enough to document this: https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man2/connect.2.html + // > Datagram sockets may dissolve the association by connecting to an + // > invalid address, such as a null address or an address with the address + // > family set to AF_UNSPEC (the error EAFNOSUPPORT will be harmlessly + // > returned). + // + // ... except that this appears to be incomplete, because experiments + // have shown that MacOS actually returns EINVAL, depending on the + // address family of the socket. + #[cfg(target_os = "macos")] + Err(Errno::INVAL | Errno::AFNOSUPPORT) => Ok(()), + Err(err) => Err(err.into()), + Ok(()) => Ok(()), + } +} + +pub fn parse_host(name: &str) -> Result { + // `url::Host::parse` serves us two functions: + // 1. validate the input is a valid domain name or IP, + // 2. convert unicode domains to punycode. + match url::Host::parse(&name) { + Ok(host) => Ok(host), + + // `url::Host::parse` doesn't understand bare IPv6 addresses without [brackets] + Err(_) => { + if let Ok(addr) = Ipv6Addr::from_str(name) { + Ok(url::Host::Ipv6(addr)) + } else { + Err(ErrorCode::InvalidArgument) + } + } + } +} diff --git a/crates/wasi/tests/all/p3/mod.rs b/crates/wasi/tests/all/p3/mod.rs index 9f9d247b9c52..3adf0cd3f4a7 100644 --- a/crates/wasi/tests/all/p3/mod.rs +++ b/crates/wasi/tests/all/p3/mod.rs @@ -97,6 +97,11 @@ async fn run(path: &str) -> anyhow::Result<()> { foreach_p3!(assert_test_exists); +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_cli() -> anyhow::Result<()> { + run(P3_CLI_COMPONENT).await +} + #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn p3_clocks_sleep() -> anyhow::Result<()> { run(P3_CLOCKS_SLEEP_COMPONENT).await @@ -108,6 +113,61 @@ async fn p3_random_imports() -> anyhow::Result<()> { } #[test_log::test(tokio::test(flavor = "multi_thread"))] -async fn p3_cli() -> anyhow::Result<()> { - run(P3_CLI_COMPONENT).await +async fn p3_sockets_ip_name_lookup() -> anyhow::Result<()> { + run(P3_SOCKETS_IP_NAME_LOOKUP_COMPONENT).await +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_sockets_tcp_bind() -> anyhow::Result<()> { + run(P3_SOCKETS_TCP_BIND_COMPONENT).await +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_sockets_tcp_connect() -> anyhow::Result<()> { + run(P3_SOCKETS_TCP_CONNECT_COMPONENT).await +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_sockets_tcp_sample_application() -> anyhow::Result<()> { + run(P3_SOCKETS_TCP_SAMPLE_APPLICATION_COMPONENT).await +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_sockets_tcp_sockopts() -> anyhow::Result<()> { + run(P3_SOCKETS_TCP_SOCKOPTS_COMPONENT).await +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_sockets_tcp_states() -> anyhow::Result<()> { + run(P3_SOCKETS_TCP_STATES_COMPONENT).await +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_sockets_tcp_streams() -> anyhow::Result<()> { + run(P3_SOCKETS_TCP_STREAMS_COMPONENT).await +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_sockets_udp_bind() -> anyhow::Result<()> { + run(P3_SOCKETS_UDP_BIND_COMPONENT).await +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_sockets_udp_connect() -> anyhow::Result<()> { + run(P3_SOCKETS_UDP_CONNECT_COMPONENT).await +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_sockets_udp_sample_application() -> anyhow::Result<()> { + run(P3_SOCKETS_UDP_SAMPLE_APPLICATION_COMPONENT).await +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_sockets_udp_sockopts() -> anyhow::Result<()> { + run(P3_SOCKETS_UDP_SOCKOPTS_COMPONENT).await +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn p3_sockets_udp_states() -> anyhow::Result<()> { + run(P3_SOCKETS_UDP_STATES_COMPONENT).await }