From d4a1b6af62c26f90682b227f82c35f3298332260 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Fri, 1 Aug 2025 14:37:20 -0700 Subject: [PATCH 1/6] Use the same `UdpSocket` in WASIp{2,3} This commit refactors the implementation of `wasi:sockets` for WASIp2 and WASIp3 to use the same underlying host data structure for the `UdpSocket` resource in WIT. Previously each version of WASI had its own socket which resulted in duplicated code. There's some minor differences between WASIp2 and WASIp3 but it's easy enough to paper over with the same socket type. This is intended to help with the maintainability of this going forward to only have one type to operate on rather than two (which also ensures that bugfixes for one should affect the other). One other change made in this commit is that sprinkled checks for whether or not UDP is allowed are all removed and canonicalized during UDP socket creation. This means that UDP socket creation is the only location that checks for whether UDP is allowed. Once a UDP socket is created it can be used freely regardless of whether the UDP setting is enabled or disabled. This is not intended to have a large practical effect but it does mean the behavior of hosts that deny UDP but manually give access to a UDP socket resource to a component may behave subtly differently. --- crates/wasi/src/p2/bindings.rs | 4 +- crates/wasi/src/p2/host/udp.rs | 180 +++++-------------- crates/wasi/src/p2/host/udp_create_socket.rs | 4 +- crates/wasi/src/p2/network.rs | 1 + crates/wasi/src/p2/udp.rs | 88 --------- crates/wasi/src/p3/bindings.rs | 2 +- crates/wasi/src/p3/sockets/conv.rs | 27 ++- crates/wasi/src/p3/sockets/host/types/udp.rs | 39 ++-- crates/wasi/src/p3/sockets/mod.rs | 1 - crates/wasi/src/sockets/mod.rs | 3 + crates/wasi/src/{p3 => }/sockets/udp.rs | 98 +++++++--- crates/wasi/src/sockets/util.rs | 1 + 12 files changed, 159 insertions(+), 289 deletions(-) rename crates/wasi/src/{p3 => }/sockets/udp.rs (77%) diff --git a/crates/wasi/src/p2/bindings.rs b/crates/wasi/src/p2/bindings.rs index 454bda9056b2..7fb5e7b136da 100644 --- a/crates/wasi/src/p2/bindings.rs +++ b/crates/wasi/src/p2/bindings.rs @@ -173,7 +173,7 @@ pub mod sync { "wasi:sockets/tcp/tcp-socket": super::super::sockets::tcp::TcpSocket, "wasi:sockets/udp/incoming-datagram-stream": super::super::sockets::udp::IncomingDatagramStream, "wasi:sockets/udp/outgoing-datagram-stream": super::super::sockets::udp::OutgoingDatagramStream, - "wasi:sockets/udp/udp-socket": super::super::sockets::udp::UdpSocket, + "wasi:sockets/udp/udp-socket": crate::sockets::UdpSocket, // Error host trait from wasmtime-wasi-io is synchronous, so we can alias it "wasi:io/error": wasmtime_wasi_io::bindings::wasi::io::error, @@ -394,7 +394,7 @@ mod async_io { // this crate "wasi:sockets/network/network": crate::p2::network::Network, "wasi:sockets/tcp/tcp-socket": crate::p2::tcp::TcpSocket, - "wasi:sockets/udp/udp-socket": crate::p2::udp::UdpSocket, + "wasi:sockets/udp/udp-socket": crate::sockets::UdpSocket, "wasi:sockets/udp/incoming-datagram-stream": crate::p2::udp::IncomingDatagramStream, "wasi:sockets/udp/outgoing-datagram-stream": crate::p2::udp::OutgoingDatagramStream, "wasi:sockets/ip-name-lookup/resolve-address-stream": crate::p2::ip_name_lookup::ResolveAddressStream, diff --git a/crates/wasi/src/p2/host/udp.rs b/crates/wasi/src/p2/host/udp.rs index 579499d4b070..508f744c52f7 100644 --- a/crates/wasi/src/p2/host/udp.rs +++ b/crates/wasi/src/p2/host/udp.rs @@ -1,19 +1,13 @@ use crate::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network}; use crate::p2::bindings::sockets::udp; -use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState, UdpState}; +use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState}; use crate::p2::{Pollable, SocketError, SocketResult}; -use crate::sockets::util::{ - get_ip_ttl, get_ipv6_unicast_hops, is_valid_address_family, is_valid_remote_address, - receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size, - set_unicast_hop_limit, udp_bind, udp_disconnect, -}; +use crate::sockets::util::{is_valid_address_family, is_valid_remote_address}; use crate::sockets::{ - MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, WasiSocketsCtxView, + MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, UdpSocket, WasiSocketsCtxView, }; use anyhow::anyhow; use async_trait::async_trait; -use io_lifetimes::AsSocketlike; -use rustix::io::Errno; use std::net::SocketAddr; use tokio::io::Interest; use wasmtime::component::Resource; @@ -28,51 +22,20 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { network: Resource, local_address: IpSocketAddress, ) -> SocketResult<()> { - self.ctx.allowed_network_uses.check_allowed_udp()?; - - match self.table.get(&this)?.udp_state { - UdpState::Default => {} - UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()), - UdpState::Bound | UdpState::Connected => return Err(ErrorCode::InvalidState.into()), - } - - // Set the socket addr check on the socket so later functions have access to it through the socket handle + let local_address = SocketAddr::from(local_address); let check = self.table.get(&network)?.socket_addr_check.clone(); - self.table - .get_mut(&this)? - .socket_addr_check - .replace(check.clone()); - - let socket = self.table.get(&this)?; - let local_address: SocketAddr = local_address.into(); - - if !is_valid_address_family(local_address.ip(), socket.family) { - return Err(ErrorCode::InvalidArgument.into()); - } - - { - check.check(local_address, SocketAddrUse::UdpBind).await?; - - // Perform the OS bind call. - udp_bind(socket.udp_socket(), local_address)?; - } + check.check(local_address, SocketAddrUse::UdpBind).await?; let socket = self.table.get_mut(&this)?; - socket.udp_state = UdpState::BindStarted; + socket.bind(local_address)?; + socket.set_socket_addr_check(Some(check.clone())); Ok(()) } fn finish_bind(&mut self, this: Resource) -> SocketResult<()> { - let socket = self.table.get_mut(&this)?; - - match socket.udp_state { - UdpState::BindStarted => { - socket.udp_state = UdpState::Bound; - Ok(()) - } - _ => Err(ErrorCode::NotInProgress.into()), - } + self.table.get_mut(&this)?.finish_bind()?; + Ok(()) } async fn stream( @@ -95,9 +58,8 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { let socket = self.table.get_mut(&this)?; let remote_address = remote_address.map(SocketAddr::from); - match socket.udp_state { - UdpState::Bound | UdpState::Connected => {} - _ => return Err(ErrorCode::InvalidState.into()), + if !socket.is_bound() { + return Err(ErrorCode::InvalidState.into()); } // We disconnect & (re)connect in two distinct steps for two reasons: @@ -107,48 +69,30 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { // if there isn't a disconnect in between. // Step #1: Disconnect - if let UdpState::Connected = socket.udp_state { - udp_disconnect(socket.udp_socket())?; - socket.udp_state = UdpState::Bound; + if socket.is_connected() { + socket.disconnect()?; } // Step #2: (Re)connect if let Some(connect_addr) = remote_address { - let Some(check) = socket.socket_addr_check.as_ref() else { + let connect_addr = SocketAddr::from(connect_addr); + let Some(check) = socket.socket_addr_check() else { return Err(ErrorCode::InvalidState.into()); }; - if !is_valid_remote_address(connect_addr) - || !is_valid_address_family(connect_addr.ip(), socket.family) - { - return Err(ErrorCode::InvalidArgument.into()); - } check.check(connect_addr, SocketAddrUse::UdpConnect).await?; - - rustix::net::connect(socket.udp_socket(), &connect_addr).map_err( - |error| match error { - Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `bind` implementation. - Errno::INPROGRESS => { - tracing::debug!( - "UDP connect returned EINPROGRESS, which should never happen" - ); - ErrorCode::Unknown - } - _ => ErrorCode::from(error), - }, - )?; - socket.udp_state = UdpState::Connected; + socket.connect(connect_addr)?; } let incoming_stream = IncomingDatagramStream { - inner: socket.inner.clone(), + inner: socket.socket().clone(), remote_address, }; let outgoing_stream = OutgoingDatagramStream { - inner: socket.inner.clone(), + inner: socket.socket().clone(), remote_address, - family: socket.family, + family: socket.address_family(), send_state: SendState::Idle, - socket_addr_check: socket.socket_addr_check.clone(), + socket_addr_check: socket.socket_addr_check().cloned(), }; Ok(( @@ -159,33 +103,12 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { fn local_address(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - match socket.udp_state { - UdpState::Default => return Err(ErrorCode::InvalidState.into()), - UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()), - _ => {} - } - - let addr = socket - .udp_socket() - .as_socketlike_view::() - .local_addr()?; - Ok(addr.into()) + Ok(socket.local_address()?.into()) } fn remote_address(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - match socket.udp_state { - UdpState::Connected => {} - _ => return Err(ErrorCode::InvalidState.into()), - } - - let addr = socket - .udp_socket() - .as_socketlike_view::() - .peer_addr()?; - Ok(addr.into()) + Ok(socket.remote_address()?.into()) } fn address_family( @@ -193,22 +116,12 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { this: Resource, ) -> Result { let socket = self.table.get(&this)?; - - match socket.family { - SocketAddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4), - SocketAddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6), - } + Ok(socket.address_family().into()) } fn unicast_hop_limit(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - let ttl = match socket.family { - SocketAddressFamily::Ipv4 => get_ip_ttl(socket.udp_socket())?, - SocketAddressFamily::Ipv6 => get_ipv6_unicast_hops(socket.udp_socket())?, - }; - - Ok(ttl) + Ok(socket.unicast_hop_limit()?) } fn set_unicast_hop_limit( @@ -217,17 +130,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { value: u8, ) -> SocketResult<()> { let socket = self.table.get(&this)?; - - set_unicast_hop_limit(socket.udp_socket(), socket.family, value)?; - + socket.set_unicast_hop_limit(value)?; Ok(()) } fn receive_buffer_size(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - let value = receive_buffer_size(socket.udp_socket())?; - Ok(value) + Ok(socket.receive_buffer_size()?) } fn set_receive_buffer_size( @@ -236,33 +145,22 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { value: u64, ) -> SocketResult<()> { let socket = self.table.get(&this)?; - - set_receive_buffer_size(socket.udp_socket(), value)?; + socket.set_receive_buffer_size(value)?; Ok(()) } fn send_buffer_size(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - let value = send_buffer_size(socket.udp_socket())?; - Ok(value) + Ok(socket.send_buffer_size()?) } - fn set_send_buffer_size( - &mut self, - this: Resource, - value: u64, - ) -> SocketResult<()> { + fn set_send_buffer_size(&mut self, this: Resource, value: u64) -> SocketResult<()> { let socket = self.table.get(&this)?; - - set_send_buffer_size(socket.udp_socket(), value)?; + socket.set_send_buffer_size(value)?; Ok(()) } - fn subscribe( - &mut self, - this: Resource, - ) -> anyhow::Result> { + fn subscribe(&mut self, this: Resource) -> anyhow::Result> { wasmtime_wasi_io::poll::subscribe(self.table, this) } @@ -276,6 +174,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { } } +#[async_trait] +impl Pollable for UdpSocket { + async fn ready(&mut self) { + // None of the socket-level operations block natively + } +} + impl udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> { fn receive( &mut self, @@ -504,6 +409,15 @@ impl Pollable for OutgoingDatagramStream { } } +impl From for IpAddressFamily { + fn from(family: SocketAddressFamily) -> IpAddressFamily { + match family { + SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4, + SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6, + } + } +} + pub mod sync { use wasmtime::component::Resource; diff --git a/crates/wasi/src/p2/host/udp_create_socket.rs b/crates/wasi/src/p2/host/udp_create_socket.rs index 3f4a05c50e00..ac05dc51d838 100644 --- a/crates/wasi/src/p2/host/udp_create_socket.rs +++ b/crates/wasi/src/p2/host/udp_create_socket.rs @@ -1,6 +1,6 @@ use crate::p2::SocketResult; use crate::p2::bindings::{sockets::network::IpAddressFamily, sockets::udp_create_socket}; -use crate::p2::udp::UdpSocket; +use crate::sockets::UdpSocket; use crate::sockets::WasiSocketsCtxView; use wasmtime::component::Resource; @@ -9,7 +9,7 @@ impl udp_create_socket::Host for WasiSocketsCtxView<'_> { &mut self, address_family: IpAddressFamily, ) -> SocketResult> { - let socket = UdpSocket::new(address_family.into())?; + let socket = UdpSocket::new(self.ctx, address_family.into())?; let socket = self.table.push(socket)?; Ok(socket) } diff --git a/crates/wasi/src/p2/network.rs b/crates/wasi/src/p2/network.rs index 68b3f7adb09a..8534d0320485 100644 --- a/crates/wasi/src/p2/network.rs +++ b/crates/wasi/src/p2/network.rs @@ -48,6 +48,7 @@ impl From for ErrorCode { crate::sockets::util::ErrorCode::ConnectionReset => Self::ConnectionReset, crate::sockets::util::ErrorCode::ConnectionAborted => Self::ConnectionAborted, crate::sockets::util::ErrorCode::DatagramTooLarge => Self::DatagramTooLarge, + crate::sockets::util::ErrorCode::NotInProgress => Self::NotInProgress, } } } diff --git a/crates/wasi/src/p2/udp.rs b/crates/wasi/src/p2/udp.rs index a066fb97eb15..0165a9550410 100644 --- a/crates/wasi/src/p2/udp.rs +++ b/crates/wasi/src/p2/udp.rs @@ -1,94 +1,6 @@ -use crate::runtime::with_ambient_tokio_runtime; -use crate::sockets::util::udp_socket; use crate::sockets::{SocketAddrCheck, SocketAddressFamily}; -use async_trait::async_trait; -use cap_net_ext::AddressFamily; -use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike}; -use std::io; use std::net::SocketAddr; use std::sync::Arc; -use wasmtime_wasi_io::poll::Pollable; - -/// The state of a UDP socket. -/// -/// This represents the various states a socket can be in during the -/// activities of binding, and connecting. -pub(crate) enum UdpState { - /// The initial state for a newly-created socket. - Default, - - /// Binding started via `start_bind`. - BindStarted, - - /// Binding finished via `finish_bind`. The socket has an address but - /// is not yet listening for connections. - Bound, - - /// The socket is "connected" to a peer address. - Connected, -} - -/// A host UDP socket, plus associated bookkeeping. -/// -/// The inner state is wrapped in an Arc because the same underlying socket is -/// used for implementing the stream types. -pub struct UdpSocket { - /// The part of a `UdpSocket` which is reference-counted so that we - /// can pass it to async tasks. - pub(crate) inner: Arc, - - /// The current state in the bind/connect progression. - pub(crate) udp_state: UdpState, - - /// Socket address family. - pub(crate) family: SocketAddressFamily, - - /// The check of allowed addresses - pub(crate) socket_addr_check: Option, -} - -#[async_trait] -impl Pollable for UdpSocket { - async fn ready(&mut self) { - // None of the socket-level operations block natively - } -} - -impl UdpSocket { - /// Create a new socket in the given family. - pub fn new(family: AddressFamily) -> io::Result { - // Create a new host socket and set it to non-blocking, which is needed - // by our async implementation. - let fd = udp_socket(family)?; - - let socket_address_family = match family { - AddressFamily::Ipv4 => SocketAddressFamily::Ipv4, - AddressFamily::Ipv6 => { - rustix::net::sockopt::set_ipv6_v6only(&fd, true)?; - SocketAddressFamily::Ipv6 - } - }; - - let socket = Self::setup_tokio_udp_socket(fd.into())?; - - Ok(UdpSocket { - inner: Arc::new(socket), - udp_state: UdpState::Default, - family: socket_address_family, - socket_addr_check: None, - }) - } - - fn setup_tokio_udp_socket(fd: rustix::fd::OwnedFd) -> io::Result { - let std_socket = - unsafe { std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike()) }; - with_ambient_tokio_runtime(|| tokio::net::UdpSocket::try_from(std_socket)) - } - - pub fn udp_socket(&self) -> &tokio::net::UdpSocket { - &self.inner - } -} pub struct IncomingDatagramStream { pub(crate) inner: Arc, diff --git a/crates/wasi/src/p3/bindings.rs b/crates/wasi/src/p3/bindings.rs index ae0c1eb6902a..89c9bd5888b8 100644 --- a/crates/wasi/src/p3/bindings.rs +++ b/crates/wasi/src/p3/bindings.rs @@ -95,7 +95,7 @@ mod generated { "wasi:cli/terminal-input/terminal-input": crate::p3::cli::TerminalInput, "wasi:cli/terminal-output/terminal-output": crate::p3::cli::TerminalOutput, "wasi:sockets/types/tcp-socket": crate::p3::sockets::tcp::TcpSocket, - "wasi:sockets/types/udp-socket": crate::p3::sockets::udp::UdpSocket, + "wasi:sockets/types/udp-socket": crate::sockets::UdpSocket, }, trappable_error_type: { "wasi:sockets/types/error-code" => crate::p3::sockets::SocketError, diff --git a/crates/wasi/src/p3/sockets/conv.rs b/crates/wasi/src/p3/sockets/conv.rs index ee4920dfa0d8..03e1aff33a15 100644 --- a/crates/wasi/src/p3/sockets/conv.rs +++ b/crates/wasi/src/p3/sockets/conv.rs @@ -1,13 +1,12 @@ +use crate::p3::bindings::sockets::types; +use crate::p3::sockets::SocketError; +use crate::sockets::SocketAddressFamily; +use crate::sockets::util::{from_ipv4_addr, from_ipv6_addr, to_ipv4_addr, to_ipv6_addr}; use core::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}; - -use std::net::ToSocketAddrs; - use rustix::io::Errno; +use std::net::ToSocketAddrs; use tracing::debug; -use crate::p3::bindings::sockets::types; -use crate::sockets::util::{from_ipv4_addr, from_ipv6_addr, to_ipv4_addr, to_ipv6_addr}; - impl From for types::IpAddress { fn from(addr: IpAddr) -> Self { match addr { @@ -123,6 +122,15 @@ impl From for types::IpAddressFamily { } } +impl From for types::IpAddressFamily { + fn from(family: SocketAddressFamily) -> Self { + match family { + SocketAddressFamily::Ipv4 => Self::Ipv4, + SocketAddressFamily::Ipv6 => Self::Ipv6, + } + } +} + impl From for types::ErrorCode { fn from(value: std::io::Error) -> Self { (&value).into() @@ -222,6 +230,13 @@ impl From for types::ErrorCode { crate::sockets::util::ErrorCode::ConnectionReset => Self::ConnectionReset, crate::sockets::util::ErrorCode::ConnectionAborted => Self::ConnectionAborted, crate::sockets::util::ErrorCode::DatagramTooLarge => Self::DatagramTooLarge, + crate::sockets::util::ErrorCode::NotInProgress => Self::InvalidState, } } } + +impl From for SocketError { + fn from(code: crate::sockets::util::ErrorCode) -> Self { + SocketError::from(types::ErrorCode::from(code)) + } +} diff --git a/crates/wasi/src/p3/sockets/host/types/udp.rs b/crates/wasi/src/p3/sockets/host/types/udp.rs index e6a48b92ac22..e9f4ecf0c39a 100644 --- a/crates/wasi/src/p3/sockets/host/types/udp.rs +++ b/crates/wasi/src/p3/sockets/host/types/udp.rs @@ -1,20 +1,14 @@ +use super::is_addr_allowed; use crate::TrappableError; use crate::p3::bindings::sockets::types::{ ErrorCode, HostUdpSocket, HostUdpSocketWithStore, IpAddressFamily, IpSocketAddress, }; -use crate::p3::sockets::udp::UdpSocket; use crate::p3::sockets::{SocketResult, WasiSockets}; -use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, WasiSocketsCtxView}; -use anyhow::Context as _; -use core::net::SocketAddr; +use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, UdpSocket, WasiSocketsCtxView}; +use anyhow::Context; +use std::net::SocketAddr; use wasmtime::component::{Accessor, Resource, ResourceTable}; -use super::is_addr_allowed; - -fn is_udp_allowed(store: &Accessor) -> bool { - store.with(|mut view| view.get().ctx.allowed_network_uses.udp) -} - fn get_socket<'a>( table: &'a ResourceTable, socket: &'a Resource, @@ -42,14 +36,13 @@ impl HostUdpSocketWithStore for WasiSockets { local_address: IpSocketAddress, ) -> SocketResult<()> { let local_address = SocketAddr::from(local_address); - if !is_udp_allowed(store) - || !is_addr_allowed(store, local_address, SocketAddrUse::UdpBind).await - { + if !is_addr_allowed(store, local_address, SocketAddrUse::UdpBind).await { return Err(ErrorCode::AccessDenied.into()); } store.with(|mut view| { let socket = get_socket_mut(view.get().table, &socket)?; socket.bind(local_address)?; + socket.finish_bind()?; Ok(()) }) } @@ -60,9 +53,7 @@ impl HostUdpSocketWithStore for WasiSockets { remote_address: IpSocketAddress, ) -> SocketResult<()> { let remote_address = SocketAddr::from(remote_address); - if !is_udp_allowed(store) - || !is_addr_allowed(store, remote_address, SocketAddrUse::UdpConnect).await - { + if !is_addr_allowed(store, remote_address, SocketAddrUse::UdpConnect).await { return Err(ErrorCode::AccessDenied.into()); } store.with(|mut view| { @@ -81,9 +72,6 @@ impl HostUdpSocketWithStore for WasiSockets { if data.len() > MAX_UDP_DATAGRAM_SIZE { return Err(ErrorCode::DatagramTooLarge.into()); } - if !is_udp_allowed(store) { - return Err(ErrorCode::AccessDenied.into()); - } if let Some(addr) = remote_address { let addr = SocketAddr::from(addr); if !is_addr_allowed(store, addr, SocketAddrUse::UdpOutgoingDatagram).await { @@ -107,19 +95,16 @@ impl HostUdpSocketWithStore for WasiSockets { store: &Accessor, socket: Resource, ) -> SocketResult<(Vec, IpSocketAddress)> { - if !is_udp_allowed(store) { - return Err(ErrorCode::AccessDenied.into()); - } let fut = store .with(|mut view| get_socket(view.get().table, &socket).map(|sock| sock.receive()))?; let (result, addr) = fut.await?; - Ok((result, addr)) + Ok((result, addr.into())) } } impl HostUdpSocket for WasiSocketsCtxView<'_> { fn new(&mut self, address_family: IpAddressFamily) -> wasmtime::Result> { - let socket = UdpSocket::new(address_family.into()).context("failed to create socket")?; + let socket = UdpSocket::new(self.ctx, address_family.into())?; self.table .push(socket) .context("failed to push socket resource to table") @@ -133,17 +118,17 @@ impl HostUdpSocket for WasiSocketsCtxView<'_> { fn local_address(&mut self, socket: Resource) -> SocketResult { let sock = get_socket(self.table, &socket)?; - Ok(sock.local_address()?) + Ok(sock.local_address()?.into()) } fn remote_address(&mut self, socket: Resource) -> SocketResult { let sock = get_socket(self.table, &socket)?; - Ok(sock.remote_address()?) + Ok(sock.remote_address()?.into()) } fn address_family(&mut self, socket: Resource) -> wasmtime::Result { let sock = get_socket(self.table, &socket)?; - Ok(sock.address_family()) + Ok(sock.address_family().into()) } fn unicast_hop_limit(&mut self, socket: Resource) -> SocketResult { diff --git a/crates/wasi/src/p3/sockets/mod.rs b/crates/wasi/src/p3/sockets/mod.rs index ade030de49bf..07b4db6bd3d1 100644 --- a/crates/wasi/src/p3/sockets/mod.rs +++ b/crates/wasi/src/p3/sockets/mod.rs @@ -6,7 +6,6 @@ use wasmtime::component::Linker; mod conv; mod host; pub mod tcp; -pub mod udp; pub type SocketResult = Result; pub type SocketError = TrappableError; diff --git a/crates/wasi/src/sockets/mod.rs b/crates/wasi/src/sockets/mod.rs index a7c8fbff2aa5..ce78e78f08c6 100644 --- a/crates/wasi/src/sockets/mod.rs +++ b/crates/wasi/src/sockets/mod.rs @@ -5,8 +5,11 @@ use std::pin::Pin; use std::sync::Arc; use wasmtime::component::{HasData, ResourceTable}; +mod udp; pub(crate) mod util; +pub use udp::UdpSocket; + pub(crate) struct WasiSockets; impl HasData for WasiSockets { diff --git a/crates/wasi/src/p3/sockets/udp.rs b/crates/wasi/src/sockets/udp.rs similarity index 77% rename from crates/wasi/src/p3/sockets/udp.rs rename to crates/wasi/src/sockets/udp.rs index e55054334d71..95c0af1bcb14 100644 --- a/crates/wasi/src/p3/sockets/udp.rs +++ b/crates/wasi/src/sockets/udp.rs @@ -1,24 +1,21 @@ -use core::future::Future; -use core::net::SocketAddr; - -use std::sync::Arc; - +use crate::runtime::with_ambient_tokio_runtime; +use crate::sockets::util::{ + ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address, + receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size, + set_unicast_hop_limit, udp_bind, udp_disconnect, udp_socket, +}; +use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx}; use cap_net_ext::AddressFamily; +use core::future::Future; use io_lifetimes::AsSocketlike as _; use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _}; use rustix::io::Errno; use rustix::net::connect; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; use tracing::debug; -use crate::p3::bindings::sockets::types::{ErrorCode, IpAddressFamily, IpSocketAddress}; -use crate::runtime::with_ambient_tokio_runtime; -use crate::sockets::util::{ - get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address, receive_buffer_size, - send_buffer_size, set_receive_buffer_size, set_send_buffer_size, set_unicast_hop_limit, - udp_bind, udp_disconnect, udp_socket, -}; -use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddressFamily}; - /// The state of a UDP socket. /// /// This represents the various states a socket can be in during the @@ -27,6 +24,9 @@ enum UdpState { /// The initial state for a newly-created socket. Default, + /// TODO + BindStarted, + /// Binding finished via `finish_bind`. The socket has an address but /// is not yet listening for connections. Bound, @@ -47,11 +47,17 @@ pub struct UdpSocket { /// Socket address family. family: SocketAddressFamily, + + /// If set, use this custom check for addrs, otherwise use what's in + /// `WasiSocketsCtx`. + socket_addr_check: Option, } impl UdpSocket { /// Create a new socket in the given family. - pub(crate) fn new(family: AddressFamily) -> std::io::Result { + pub(crate) fn new(cx: &WasiSocketsCtx, family: AddressFamily) -> io::Result { + cx.allowed_network_uses.check_allowed_udp()?; + // Delegate socket creation to cap_net_ext. They handle a couple of things for us: // - On Windows: call WSAStartup if not done before. // - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation, @@ -77,6 +83,7 @@ impl UdpSocket { socket: Arc::new(socket), udp_state: UdpState::Default, family: socket_address_family, + socket_addr_check: None, }) } @@ -88,12 +95,30 @@ impl UdpSocket { return Err(ErrorCode::InvalidArgument); } udp_bind(&self.socket, addr)?; - self.udp_state = UdpState::Bound; + self.udp_state = UdpState::BindStarted; Ok(()) } + pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> { + match self.udp_state { + UdpState::BindStarted => { + self.udp_state = UdpState::Bound; + Ok(()) + } + _ => Err(ErrorCode::NotInProgress), + } + } + + pub(crate) fn is_connected(&self) -> bool { + matches!(self.udp_state, UdpState::Connected(..)) + } + + pub(crate) fn is_bound(&self) -> bool { + matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound) + } + pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> { - if !matches!(self.udp_state, UdpState::Connected(..)) { + if !self.is_connected() { return Err(ErrorCode::InvalidState); } udp_disconnect(&self.socket)?; @@ -106,6 +131,11 @@ impl UdpSocket { return Err(ErrorCode::InvalidArgument); } + match self.udp_state { + UdpState::Bound | UdpState::Connected(_) => {} + _ => return Err(ErrorCode::InvalidState.into()), + } + // We disconnect & (re)connect in two distinct steps for two reasons: // - To leave our socket instance in a consistent state in case the // connect fails. @@ -152,6 +182,7 @@ impl UdpSocket { SendTo(Arc, SocketAddr), } let socket = match &self.udp_state { + UdpState::BindStarted => Err(ErrorCode::InvalidState), UdpState::Default | UdpState::Bound => Ok(Mode::SendTo(Arc::clone(&self.socket), addr)), UdpState::Connected(caddr) if addr == *caddr => { Ok(Mode::Send(Arc::clone(&self.socket))) @@ -168,13 +199,13 @@ impl UdpSocket { pub(crate) fn receive( &self, - ) -> impl Future, IpSocketAddress), ErrorCode>> + use<> { + ) -> impl Future, SocketAddr), ErrorCode>> + use<> { enum Mode { - Recv(Arc, IpSocketAddress), + Recv(Arc, SocketAddr), RecvFrom(Arc), } let socket = match self.udp_state { - UdpState::Default => Err(ErrorCode::InvalidState), + UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState), UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))), UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr.into())), }; @@ -188,7 +219,7 @@ impl UdpSocket { } Mode::RecvFrom(socket) => { let (n, addr) = socket.recv_from(&mut buf).await?; - (n, addr.into()) + (n, addr) } }; buf.truncate(n); @@ -196,8 +227,8 @@ impl UdpSocket { } } - pub(crate) fn local_address(&self) -> Result { - if matches!(self.udp_state, UdpState::Default) { + pub(crate) fn local_address(&self) -> Result { + if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) { return Err(ErrorCode::InvalidState); } let addr = self @@ -207,7 +238,7 @@ impl UdpSocket { Ok(addr.into()) } - pub(crate) fn remote_address(&self) -> Result { + pub(crate) fn remote_address(&self) -> Result { if !matches!(self.udp_state, UdpState::Connected(..)) { return Err(ErrorCode::InvalidState); } @@ -218,11 +249,8 @@ impl UdpSocket { Ok(addr.into()) } - pub(crate) fn address_family(&self) -> IpAddressFamily { - match self.family { - SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4, - SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6, - } + pub(crate) fn address_family(&self) -> SocketAddressFamily { + self.family } pub(crate) fn unicast_hop_limit(&self) -> Result { @@ -254,6 +282,18 @@ impl UdpSocket { set_send_buffer_size(&self.socket, value)?; Ok(()) } + + pub(crate) fn socket(&self) -> &Arc { + &self.socket + } + + pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> { + self.socket_addr_check.as_ref() + } + + pub(crate) fn set_socket_addr_check(&mut self, check: Option) { + self.socket_addr_check = check; + } } async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> { diff --git a/crates/wasi/src/sockets/util.rs b/crates/wasi/src/sockets/util.rs index dafe6e25845e..cb1043f48f3e 100644 --- a/crates/wasi/src/sockets/util.rs +++ b/crates/wasi/src/sockets/util.rs @@ -27,6 +27,7 @@ pub enum ErrorCode { ConnectionReset, ConnectionAborted, DatagramTooLarge, + NotInProgress, } impl fmt::Display for ErrorCode { From ff093a8c1e55c19c622a59d04292f2f075897eb5 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 5 Aug 2025 12:19:38 -0700 Subject: [PATCH 2/6] Review comments --- crates/wasi/src/p2/host/udp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/wasi/src/p2/host/udp.rs b/crates/wasi/src/p2/host/udp.rs index 508f744c52f7..c7cd34108dd9 100644 --- a/crates/wasi/src/p2/host/udp.rs +++ b/crates/wasi/src/p2/host/udp.rs @@ -28,7 +28,7 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { let socket = self.table.get_mut(&this)?; socket.bind(local_address)?; - socket.set_socket_addr_check(Some(check.clone())); + socket.set_socket_addr_check(Some(check)); Ok(()) } From 296a41974f4242104e9f614e48ff463203ff7865 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 5 Aug 2025 12:22:24 -0700 Subject: [PATCH 3/6] Fix p3-less warnings --- crates/wasi/src/sockets/udp.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/crates/wasi/src/sockets/udp.rs b/crates/wasi/src/sockets/udp.rs index 95c0af1bcb14..2e2c1e1b8f23 100644 --- a/crates/wasi/src/sockets/udp.rs +++ b/crates/wasi/src/sockets/udp.rs @@ -4,9 +4,8 @@ use crate::sockets::util::{ receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size, set_unicast_hop_limit, udp_bind, udp_disconnect, udp_socket, }; -use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx}; +use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx}; use cap_net_ext::AddressFamily; -use core::future::Future; use io_lifetimes::AsSocketlike as _; use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _}; use rustix::io::Errno; @@ -32,6 +31,10 @@ enum UdpState { Bound, /// The socket is "connected" to a peer address. + #[cfg_attr( + not(feature = "p3"), + expect(dead_code, reason = "p2 has its own way of managing sending/receiving") + )] Connected(SocketAddr), } @@ -160,6 +163,7 @@ impl UdpSocket { Ok(()) } + #[cfg(feature = "p3")] pub(crate) fn send(&self, buf: Vec) -> impl Future> + use<> { let socket = if let UdpState::Connected(..) = self.udp_state { Ok(Arc::clone(&self.socket)) @@ -172,6 +176,7 @@ impl UdpSocket { } } + #[cfg(feature = "p3")] pub(crate) fn send_to( &self, buf: Vec, @@ -197,6 +202,7 @@ impl UdpSocket { } } + #[cfg(feature = "p3")] pub(crate) fn receive( &self, ) -> impl Future, SocketAddr), ErrorCode>> + use<> { @@ -211,7 +217,7 @@ impl UdpSocket { }; async move { let socket = socket?; - let mut buf = vec![0; MAX_UDP_DATAGRAM_SIZE]; + let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE]; let (n, addr) = match socket { Mode::Recv(socket, addr) => { let n = socket.recv(&mut buf).await?; @@ -296,6 +302,7 @@ impl UdpSocket { } } +#[cfg(feature = "p3")] async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> { let n = socket.send(buf).await?; // From Rust stdlib docs: @@ -310,6 +317,7 @@ async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCod } } +#[cfg(feature = "p3")] async fn send_to( socket: &tokio::net::UdpSocket, buf: &[u8], From 75040a1e17b2226d640f72dea4a2c9ec37224545 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 5 Aug 2025 13:06:07 -0700 Subject: [PATCH 4/6] Update UDP denial test --- crates/test-programs/src/bin/cli_no_udp.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/crates/test-programs/src/bin/cli_no_udp.rs b/crates/test-programs/src/bin/cli_no_udp.rs index 1dd0240612db..221822319d14 100644 --- a/crates/test-programs/src/bin/cli_no_udp.rs +++ b/crates/test-programs/src/bin/cli_no_udp.rs @@ -8,9 +8,8 @@ fn main() { let net = Network::default(); let family = IpAddressFamily::Ipv4; let remote1 = IpSocketAddress::new(IpAddress::new_loopback(family), 4321); - let sock = UdpSocket::new(family).unwrap(); - - let bind = sock.blocking_bind(&net, remote1); - eprintln!("Result of binding: {bind:?}"); - assert!(matches!(bind, Err(ErrorCode::AccessDenied))); + assert!(matches!( + UdpSocket::new(family), + Err(ErrorCode::AccessDenied) + )); } From 363bfc5b15d72557b8581d6b5ef4edea55562977 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 5 Aug 2025 13:06:13 -0700 Subject: [PATCH 5/6] Fix some clippy issues --- crates/wasi/src/p2/host/udp.rs | 1 - crates/wasi/src/sockets/udp.rs | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/crates/wasi/src/p2/host/udp.rs b/crates/wasi/src/p2/host/udp.rs index c7cd34108dd9..23a3b53c1ea4 100644 --- a/crates/wasi/src/p2/host/udp.rs +++ b/crates/wasi/src/p2/host/udp.rs @@ -75,7 +75,6 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { // Step #2: (Re)connect if let Some(connect_addr) = remote_address { - let connect_addr = SocketAddr::from(connect_addr); let Some(check) = socket.socket_addr_check() else { return Err(ErrorCode::InvalidState.into()); }; diff --git a/crates/wasi/src/sockets/udp.rs b/crates/wasi/src/sockets/udp.rs index 2e2c1e1b8f23..82b6ea270755 100644 --- a/crates/wasi/src/sockets/udp.rs +++ b/crates/wasi/src/sockets/udp.rs @@ -136,7 +136,7 @@ impl UdpSocket { match self.udp_state { UdpState::Bound | UdpState::Connected(_) => {} - _ => return Err(ErrorCode::InvalidState.into()), + _ => return Err(ErrorCode::InvalidState), } // We disconnect & (re)connect in two distinct steps for two reasons: @@ -241,7 +241,7 @@ impl UdpSocket { .socket .as_socketlike_view::() .local_addr()?; - Ok(addr.into()) + Ok(addr) } pub(crate) fn remote_address(&self) -> Result { @@ -252,7 +252,7 @@ impl UdpSocket { .socket .as_socketlike_view::() .peer_addr()?; - Ok(addr.into()) + Ok(addr) } pub(crate) fn address_family(&self) -> SocketAddressFamily { From 4aaa7137db686dc7b8bcc78885fa1399c16dba84 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 5 Aug 2025 17:03:39 -0700 Subject: [PATCH 6/6] Fix no-udp test warnings --- crates/test-programs/src/bin/cli_no_udp.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/crates/test-programs/src/bin/cli_no_udp.rs b/crates/test-programs/src/bin/cli_no_udp.rs index 221822319d14..a225336b6832 100644 --- a/crates/test-programs/src/bin/cli_no_udp.rs +++ b/crates/test-programs/src/bin/cli_no_udp.rs @@ -1,15 +1,11 @@ //! This test assumes that it will be run without udp support enabled -use test_programs::wasi::sockets::{ - network::IpAddress, - udp::{ErrorCode, IpAddressFamily, IpSocketAddress, Network, UdpSocket}, -}; + +#![deny(warnings)] +use test_programs::wasi::sockets::udp::{ErrorCode, IpAddressFamily, UdpSocket}; fn main() { - let net = Network::default(); - let family = IpAddressFamily::Ipv4; - let remote1 = IpSocketAddress::new(IpAddress::new_loopback(family), 4321); assert!(matches!( - UdpSocket::new(family), + UdpSocket::new(IpAddressFamily::Ipv4), Err(ErrorCode::AccessDenied) )); }