diff --git a/crates/test-programs/src/bin/cli_no_udp.rs b/crates/test-programs/src/bin/cli_no_udp.rs index 1dd0240612db..a225336b6832 100644 --- a/crates/test-programs/src/bin/cli_no_udp.rs +++ b/crates/test-programs/src/bin/cli_no_udp.rs @@ -1,16 +1,11 @@ //! This test assumes that it will be run without udp support enabled -use test_programs::wasi::sockets::{ - network::IpAddress, - udp::{ErrorCode, IpAddressFamily, IpSocketAddress, Network, UdpSocket}, -}; -fn main() { - let net = Network::default(); - let family = IpAddressFamily::Ipv4; - let remote1 = IpSocketAddress::new(IpAddress::new_loopback(family), 4321); - let sock = UdpSocket::new(family).unwrap(); +#![deny(warnings)] +use test_programs::wasi::sockets::udp::{ErrorCode, IpAddressFamily, UdpSocket}; - let bind = sock.blocking_bind(&net, remote1); - eprintln!("Result of binding: {bind:?}"); - assert!(matches!(bind, Err(ErrorCode::AccessDenied))); +fn main() { + assert!(matches!( + UdpSocket::new(IpAddressFamily::Ipv4), + Err(ErrorCode::AccessDenied) + )); } diff --git a/crates/wasi/src/p2/bindings.rs b/crates/wasi/src/p2/bindings.rs index 454bda9056b2..7fb5e7b136da 100644 --- a/crates/wasi/src/p2/bindings.rs +++ b/crates/wasi/src/p2/bindings.rs @@ -173,7 +173,7 @@ pub mod sync { "wasi:sockets/tcp/tcp-socket": super::super::sockets::tcp::TcpSocket, "wasi:sockets/udp/incoming-datagram-stream": super::super::sockets::udp::IncomingDatagramStream, "wasi:sockets/udp/outgoing-datagram-stream": super::super::sockets::udp::OutgoingDatagramStream, - "wasi:sockets/udp/udp-socket": super::super::sockets::udp::UdpSocket, + "wasi:sockets/udp/udp-socket": crate::sockets::UdpSocket, // Error host trait from wasmtime-wasi-io is synchronous, so we can alias it "wasi:io/error": wasmtime_wasi_io::bindings::wasi::io::error, @@ -394,7 +394,7 @@ mod async_io { // this crate "wasi:sockets/network/network": crate::p2::network::Network, "wasi:sockets/tcp/tcp-socket": crate::p2::tcp::TcpSocket, - "wasi:sockets/udp/udp-socket": crate::p2::udp::UdpSocket, + "wasi:sockets/udp/udp-socket": crate::sockets::UdpSocket, "wasi:sockets/udp/incoming-datagram-stream": crate::p2::udp::IncomingDatagramStream, "wasi:sockets/udp/outgoing-datagram-stream": crate::p2::udp::OutgoingDatagramStream, "wasi:sockets/ip-name-lookup/resolve-address-stream": crate::p2::ip_name_lookup::ResolveAddressStream, diff --git a/crates/wasi/src/p2/host/udp.rs b/crates/wasi/src/p2/host/udp.rs index 579499d4b070..23a3b53c1ea4 100644 --- a/crates/wasi/src/p2/host/udp.rs +++ b/crates/wasi/src/p2/host/udp.rs @@ -1,19 +1,13 @@ use crate::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network}; use crate::p2::bindings::sockets::udp; -use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState, UdpState}; +use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState}; use crate::p2::{Pollable, SocketError, SocketResult}; -use crate::sockets::util::{ - get_ip_ttl, get_ipv6_unicast_hops, is_valid_address_family, is_valid_remote_address, - receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size, - set_unicast_hop_limit, udp_bind, udp_disconnect, -}; +use crate::sockets::util::{is_valid_address_family, is_valid_remote_address}; use crate::sockets::{ - MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, WasiSocketsCtxView, + MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, UdpSocket, WasiSocketsCtxView, }; use anyhow::anyhow; use async_trait::async_trait; -use io_lifetimes::AsSocketlike; -use rustix::io::Errno; use std::net::SocketAddr; use tokio::io::Interest; use wasmtime::component::Resource; @@ -28,51 +22,20 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { network: Resource, local_address: IpSocketAddress, ) -> SocketResult<()> { - self.ctx.allowed_network_uses.check_allowed_udp()?; - - match self.table.get(&this)?.udp_state { - UdpState::Default => {} - UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()), - UdpState::Bound | UdpState::Connected => return Err(ErrorCode::InvalidState.into()), - } - - // Set the socket addr check on the socket so later functions have access to it through the socket handle + let local_address = SocketAddr::from(local_address); let check = self.table.get(&network)?.socket_addr_check.clone(); - self.table - .get_mut(&this)? - .socket_addr_check - .replace(check.clone()); - - let socket = self.table.get(&this)?; - let local_address: SocketAddr = local_address.into(); - - if !is_valid_address_family(local_address.ip(), socket.family) { - return Err(ErrorCode::InvalidArgument.into()); - } - - { - check.check(local_address, SocketAddrUse::UdpBind).await?; - - // Perform the OS bind call. - udp_bind(socket.udp_socket(), local_address)?; - } + check.check(local_address, SocketAddrUse::UdpBind).await?; let socket = self.table.get_mut(&this)?; - socket.udp_state = UdpState::BindStarted; + socket.bind(local_address)?; + socket.set_socket_addr_check(Some(check)); Ok(()) } fn finish_bind(&mut self, this: Resource) -> SocketResult<()> { - let socket = self.table.get_mut(&this)?; - - match socket.udp_state { - UdpState::BindStarted => { - socket.udp_state = UdpState::Bound; - Ok(()) - } - _ => Err(ErrorCode::NotInProgress.into()), - } + self.table.get_mut(&this)?.finish_bind()?; + Ok(()) } async fn stream( @@ -95,9 +58,8 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { let socket = self.table.get_mut(&this)?; let remote_address = remote_address.map(SocketAddr::from); - match socket.udp_state { - UdpState::Bound | UdpState::Connected => {} - _ => return Err(ErrorCode::InvalidState.into()), + if !socket.is_bound() { + return Err(ErrorCode::InvalidState.into()); } // We disconnect & (re)connect in two distinct steps for two reasons: @@ -107,48 +69,29 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { // if there isn't a disconnect in between. // Step #1: Disconnect - if let UdpState::Connected = socket.udp_state { - udp_disconnect(socket.udp_socket())?; - socket.udp_state = UdpState::Bound; + if socket.is_connected() { + socket.disconnect()?; } // Step #2: (Re)connect if let Some(connect_addr) = remote_address { - let Some(check) = socket.socket_addr_check.as_ref() else { + let Some(check) = socket.socket_addr_check() else { return Err(ErrorCode::InvalidState.into()); }; - if !is_valid_remote_address(connect_addr) - || !is_valid_address_family(connect_addr.ip(), socket.family) - { - return Err(ErrorCode::InvalidArgument.into()); - } check.check(connect_addr, SocketAddrUse::UdpConnect).await?; - - rustix::net::connect(socket.udp_socket(), &connect_addr).map_err( - |error| match error { - Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `bind` implementation. - Errno::INPROGRESS => { - tracing::debug!( - "UDP connect returned EINPROGRESS, which should never happen" - ); - ErrorCode::Unknown - } - _ => ErrorCode::from(error), - }, - )?; - socket.udp_state = UdpState::Connected; + socket.connect(connect_addr)?; } let incoming_stream = IncomingDatagramStream { - inner: socket.inner.clone(), + inner: socket.socket().clone(), remote_address, }; let outgoing_stream = OutgoingDatagramStream { - inner: socket.inner.clone(), + inner: socket.socket().clone(), remote_address, - family: socket.family, + family: socket.address_family(), send_state: SendState::Idle, - socket_addr_check: socket.socket_addr_check.clone(), + socket_addr_check: socket.socket_addr_check().cloned(), }; Ok(( @@ -159,33 +102,12 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { fn local_address(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - match socket.udp_state { - UdpState::Default => return Err(ErrorCode::InvalidState.into()), - UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()), - _ => {} - } - - let addr = socket - .udp_socket() - .as_socketlike_view::() - .local_addr()?; - Ok(addr.into()) + Ok(socket.local_address()?.into()) } fn remote_address(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - match socket.udp_state { - UdpState::Connected => {} - _ => return Err(ErrorCode::InvalidState.into()), - } - - let addr = socket - .udp_socket() - .as_socketlike_view::() - .peer_addr()?; - Ok(addr.into()) + Ok(socket.remote_address()?.into()) } fn address_family( @@ -193,22 +115,12 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { this: Resource, ) -> Result { let socket = self.table.get(&this)?; - - match socket.family { - SocketAddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4), - SocketAddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6), - } + Ok(socket.address_family().into()) } fn unicast_hop_limit(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - let ttl = match socket.family { - SocketAddressFamily::Ipv4 => get_ip_ttl(socket.udp_socket())?, - SocketAddressFamily::Ipv6 => get_ipv6_unicast_hops(socket.udp_socket())?, - }; - - Ok(ttl) + Ok(socket.unicast_hop_limit()?) } fn set_unicast_hop_limit( @@ -217,17 +129,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { value: u8, ) -> SocketResult<()> { let socket = self.table.get(&this)?; - - set_unicast_hop_limit(socket.udp_socket(), socket.family, value)?; - + socket.set_unicast_hop_limit(value)?; Ok(()) } fn receive_buffer_size(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - let value = receive_buffer_size(socket.udp_socket())?; - Ok(value) + Ok(socket.receive_buffer_size()?) } fn set_receive_buffer_size( @@ -236,33 +144,22 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { value: u64, ) -> SocketResult<()> { let socket = self.table.get(&this)?; - - set_receive_buffer_size(socket.udp_socket(), value)?; + socket.set_receive_buffer_size(value)?; Ok(()) } fn send_buffer_size(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - let value = send_buffer_size(socket.udp_socket())?; - Ok(value) + Ok(socket.send_buffer_size()?) } - fn set_send_buffer_size( - &mut self, - this: Resource, - value: u64, - ) -> SocketResult<()> { + fn set_send_buffer_size(&mut self, this: Resource, value: u64) -> SocketResult<()> { let socket = self.table.get(&this)?; - - set_send_buffer_size(socket.udp_socket(), value)?; + socket.set_send_buffer_size(value)?; Ok(()) } - fn subscribe( - &mut self, - this: Resource, - ) -> anyhow::Result> { + fn subscribe(&mut self, this: Resource) -> anyhow::Result> { wasmtime_wasi_io::poll::subscribe(self.table, this) } @@ -276,6 +173,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> { } } +#[async_trait] +impl Pollable for UdpSocket { + async fn ready(&mut self) { + // None of the socket-level operations block natively + } +} + impl udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> { fn receive( &mut self, @@ -504,6 +408,15 @@ impl Pollable for OutgoingDatagramStream { } } +impl From for IpAddressFamily { + fn from(family: SocketAddressFamily) -> IpAddressFamily { + match family { + SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4, + SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6, + } + } +} + pub mod sync { use wasmtime::component::Resource; diff --git a/crates/wasi/src/p2/host/udp_create_socket.rs b/crates/wasi/src/p2/host/udp_create_socket.rs index 3f4a05c50e00..ac05dc51d838 100644 --- a/crates/wasi/src/p2/host/udp_create_socket.rs +++ b/crates/wasi/src/p2/host/udp_create_socket.rs @@ -1,6 +1,6 @@ use crate::p2::SocketResult; use crate::p2::bindings::{sockets::network::IpAddressFamily, sockets::udp_create_socket}; -use crate::p2::udp::UdpSocket; +use crate::sockets::UdpSocket; use crate::sockets::WasiSocketsCtxView; use wasmtime::component::Resource; @@ -9,7 +9,7 @@ impl udp_create_socket::Host for WasiSocketsCtxView<'_> { &mut self, address_family: IpAddressFamily, ) -> SocketResult> { - let socket = UdpSocket::new(address_family.into())?; + let socket = UdpSocket::new(self.ctx, address_family.into())?; let socket = self.table.push(socket)?; Ok(socket) } diff --git a/crates/wasi/src/p2/network.rs b/crates/wasi/src/p2/network.rs index 68b3f7adb09a..8534d0320485 100644 --- a/crates/wasi/src/p2/network.rs +++ b/crates/wasi/src/p2/network.rs @@ -48,6 +48,7 @@ impl From for ErrorCode { crate::sockets::util::ErrorCode::ConnectionReset => Self::ConnectionReset, crate::sockets::util::ErrorCode::ConnectionAborted => Self::ConnectionAborted, crate::sockets::util::ErrorCode::DatagramTooLarge => Self::DatagramTooLarge, + crate::sockets::util::ErrorCode::NotInProgress => Self::NotInProgress, } } } diff --git a/crates/wasi/src/p2/udp.rs b/crates/wasi/src/p2/udp.rs index a066fb97eb15..0165a9550410 100644 --- a/crates/wasi/src/p2/udp.rs +++ b/crates/wasi/src/p2/udp.rs @@ -1,94 +1,6 @@ -use crate::runtime::with_ambient_tokio_runtime; -use crate::sockets::util::udp_socket; use crate::sockets::{SocketAddrCheck, SocketAddressFamily}; -use async_trait::async_trait; -use cap_net_ext::AddressFamily; -use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike}; -use std::io; use std::net::SocketAddr; use std::sync::Arc; -use wasmtime_wasi_io::poll::Pollable; - -/// The state of a UDP socket. -/// -/// This represents the various states a socket can be in during the -/// activities of binding, and connecting. -pub(crate) enum UdpState { - /// The initial state for a newly-created socket. - Default, - - /// Binding started via `start_bind`. - BindStarted, - - /// Binding finished via `finish_bind`. The socket has an address but - /// is not yet listening for connections. - Bound, - - /// The socket is "connected" to a peer address. - Connected, -} - -/// A host UDP socket, plus associated bookkeeping. -/// -/// The inner state is wrapped in an Arc because the same underlying socket is -/// used for implementing the stream types. -pub struct UdpSocket { - /// The part of a `UdpSocket` which is reference-counted so that we - /// can pass it to async tasks. - pub(crate) inner: Arc, - - /// The current state in the bind/connect progression. - pub(crate) udp_state: UdpState, - - /// Socket address family. - pub(crate) family: SocketAddressFamily, - - /// The check of allowed addresses - pub(crate) socket_addr_check: Option, -} - -#[async_trait] -impl Pollable for UdpSocket { - async fn ready(&mut self) { - // None of the socket-level operations block natively - } -} - -impl UdpSocket { - /// Create a new socket in the given family. - pub fn new(family: AddressFamily) -> io::Result { - // Create a new host socket and set it to non-blocking, which is needed - // by our async implementation. - let fd = udp_socket(family)?; - - let socket_address_family = match family { - AddressFamily::Ipv4 => SocketAddressFamily::Ipv4, - AddressFamily::Ipv6 => { - rustix::net::sockopt::set_ipv6_v6only(&fd, true)?; - SocketAddressFamily::Ipv6 - } - }; - - let socket = Self::setup_tokio_udp_socket(fd.into())?; - - Ok(UdpSocket { - inner: Arc::new(socket), - udp_state: UdpState::Default, - family: socket_address_family, - socket_addr_check: None, - }) - } - - fn setup_tokio_udp_socket(fd: rustix::fd::OwnedFd) -> io::Result { - let std_socket = - unsafe { std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike()) }; - with_ambient_tokio_runtime(|| tokio::net::UdpSocket::try_from(std_socket)) - } - - pub fn udp_socket(&self) -> &tokio::net::UdpSocket { - &self.inner - } -} pub struct IncomingDatagramStream { pub(crate) inner: Arc, diff --git a/crates/wasi/src/p3/bindings.rs b/crates/wasi/src/p3/bindings.rs index ae0c1eb6902a..89c9bd5888b8 100644 --- a/crates/wasi/src/p3/bindings.rs +++ b/crates/wasi/src/p3/bindings.rs @@ -95,7 +95,7 @@ mod generated { "wasi:cli/terminal-input/terminal-input": crate::p3::cli::TerminalInput, "wasi:cli/terminal-output/terminal-output": crate::p3::cli::TerminalOutput, "wasi:sockets/types/tcp-socket": crate::p3::sockets::tcp::TcpSocket, - "wasi:sockets/types/udp-socket": crate::p3::sockets::udp::UdpSocket, + "wasi:sockets/types/udp-socket": crate::sockets::UdpSocket, }, trappable_error_type: { "wasi:sockets/types/error-code" => crate::p3::sockets::SocketError, diff --git a/crates/wasi/src/p3/sockets/conv.rs b/crates/wasi/src/p3/sockets/conv.rs index ee4920dfa0d8..03e1aff33a15 100644 --- a/crates/wasi/src/p3/sockets/conv.rs +++ b/crates/wasi/src/p3/sockets/conv.rs @@ -1,13 +1,12 @@ +use crate::p3::bindings::sockets::types; +use crate::p3::sockets::SocketError; +use crate::sockets::SocketAddressFamily; +use crate::sockets::util::{from_ipv4_addr, from_ipv6_addr, to_ipv4_addr, to_ipv6_addr}; use core::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}; - -use std::net::ToSocketAddrs; - use rustix::io::Errno; +use std::net::ToSocketAddrs; use tracing::debug; -use crate::p3::bindings::sockets::types; -use crate::sockets::util::{from_ipv4_addr, from_ipv6_addr, to_ipv4_addr, to_ipv6_addr}; - impl From for types::IpAddress { fn from(addr: IpAddr) -> Self { match addr { @@ -123,6 +122,15 @@ impl From for types::IpAddressFamily { } } +impl From for types::IpAddressFamily { + fn from(family: SocketAddressFamily) -> Self { + match family { + SocketAddressFamily::Ipv4 => Self::Ipv4, + SocketAddressFamily::Ipv6 => Self::Ipv6, + } + } +} + impl From for types::ErrorCode { fn from(value: std::io::Error) -> Self { (&value).into() @@ -222,6 +230,13 @@ impl From for types::ErrorCode { crate::sockets::util::ErrorCode::ConnectionReset => Self::ConnectionReset, crate::sockets::util::ErrorCode::ConnectionAborted => Self::ConnectionAborted, crate::sockets::util::ErrorCode::DatagramTooLarge => Self::DatagramTooLarge, + crate::sockets::util::ErrorCode::NotInProgress => Self::InvalidState, } } } + +impl From for SocketError { + fn from(code: crate::sockets::util::ErrorCode) -> Self { + SocketError::from(types::ErrorCode::from(code)) + } +} diff --git a/crates/wasi/src/p3/sockets/host/types/udp.rs b/crates/wasi/src/p3/sockets/host/types/udp.rs index e6a48b92ac22..e9f4ecf0c39a 100644 --- a/crates/wasi/src/p3/sockets/host/types/udp.rs +++ b/crates/wasi/src/p3/sockets/host/types/udp.rs @@ -1,20 +1,14 @@ +use super::is_addr_allowed; use crate::TrappableError; use crate::p3::bindings::sockets::types::{ ErrorCode, HostUdpSocket, HostUdpSocketWithStore, IpAddressFamily, IpSocketAddress, }; -use crate::p3::sockets::udp::UdpSocket; use crate::p3::sockets::{SocketResult, WasiSockets}; -use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, WasiSocketsCtxView}; -use anyhow::Context as _; -use core::net::SocketAddr; +use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, UdpSocket, WasiSocketsCtxView}; +use anyhow::Context; +use std::net::SocketAddr; use wasmtime::component::{Accessor, Resource, ResourceTable}; -use super::is_addr_allowed; - -fn is_udp_allowed(store: &Accessor) -> bool { - store.with(|mut view| view.get().ctx.allowed_network_uses.udp) -} - fn get_socket<'a>( table: &'a ResourceTable, socket: &'a Resource, @@ -42,14 +36,13 @@ impl HostUdpSocketWithStore for WasiSockets { local_address: IpSocketAddress, ) -> SocketResult<()> { let local_address = SocketAddr::from(local_address); - if !is_udp_allowed(store) - || !is_addr_allowed(store, local_address, SocketAddrUse::UdpBind).await - { + if !is_addr_allowed(store, local_address, SocketAddrUse::UdpBind).await { return Err(ErrorCode::AccessDenied.into()); } store.with(|mut view| { let socket = get_socket_mut(view.get().table, &socket)?; socket.bind(local_address)?; + socket.finish_bind()?; Ok(()) }) } @@ -60,9 +53,7 @@ impl HostUdpSocketWithStore for WasiSockets { remote_address: IpSocketAddress, ) -> SocketResult<()> { let remote_address = SocketAddr::from(remote_address); - if !is_udp_allowed(store) - || !is_addr_allowed(store, remote_address, SocketAddrUse::UdpConnect).await - { + if !is_addr_allowed(store, remote_address, SocketAddrUse::UdpConnect).await { return Err(ErrorCode::AccessDenied.into()); } store.with(|mut view| { @@ -81,9 +72,6 @@ impl HostUdpSocketWithStore for WasiSockets { if data.len() > MAX_UDP_DATAGRAM_SIZE { return Err(ErrorCode::DatagramTooLarge.into()); } - if !is_udp_allowed(store) { - return Err(ErrorCode::AccessDenied.into()); - } if let Some(addr) = remote_address { let addr = SocketAddr::from(addr); if !is_addr_allowed(store, addr, SocketAddrUse::UdpOutgoingDatagram).await { @@ -107,19 +95,16 @@ impl HostUdpSocketWithStore for WasiSockets { store: &Accessor, socket: Resource, ) -> SocketResult<(Vec, IpSocketAddress)> { - if !is_udp_allowed(store) { - return Err(ErrorCode::AccessDenied.into()); - } let fut = store .with(|mut view| get_socket(view.get().table, &socket).map(|sock| sock.receive()))?; let (result, addr) = fut.await?; - Ok((result, addr)) + Ok((result, addr.into())) } } impl HostUdpSocket for WasiSocketsCtxView<'_> { fn new(&mut self, address_family: IpAddressFamily) -> wasmtime::Result> { - let socket = UdpSocket::new(address_family.into()).context("failed to create socket")?; + let socket = UdpSocket::new(self.ctx, address_family.into())?; self.table .push(socket) .context("failed to push socket resource to table") @@ -133,17 +118,17 @@ impl HostUdpSocket for WasiSocketsCtxView<'_> { fn local_address(&mut self, socket: Resource) -> SocketResult { let sock = get_socket(self.table, &socket)?; - Ok(sock.local_address()?) + Ok(sock.local_address()?.into()) } fn remote_address(&mut self, socket: Resource) -> SocketResult { let sock = get_socket(self.table, &socket)?; - Ok(sock.remote_address()?) + Ok(sock.remote_address()?.into()) } fn address_family(&mut self, socket: Resource) -> wasmtime::Result { let sock = get_socket(self.table, &socket)?; - Ok(sock.address_family()) + Ok(sock.address_family().into()) } fn unicast_hop_limit(&mut self, socket: Resource) -> SocketResult { diff --git a/crates/wasi/src/p3/sockets/mod.rs b/crates/wasi/src/p3/sockets/mod.rs index ade030de49bf..07b4db6bd3d1 100644 --- a/crates/wasi/src/p3/sockets/mod.rs +++ b/crates/wasi/src/p3/sockets/mod.rs @@ -6,7 +6,6 @@ use wasmtime::component::Linker; mod conv; mod host; pub mod tcp; -pub mod udp; pub type SocketResult = Result; pub type SocketError = TrappableError; diff --git a/crates/wasi/src/sockets/mod.rs b/crates/wasi/src/sockets/mod.rs index a7c8fbff2aa5..ce78e78f08c6 100644 --- a/crates/wasi/src/sockets/mod.rs +++ b/crates/wasi/src/sockets/mod.rs @@ -5,8 +5,11 @@ use std::pin::Pin; use std::sync::Arc; use wasmtime::component::{HasData, ResourceTable}; +mod udp; pub(crate) mod util; +pub use udp::UdpSocket; + pub(crate) struct WasiSockets; impl HasData for WasiSockets { diff --git a/crates/wasi/src/p3/sockets/udp.rs b/crates/wasi/src/sockets/udp.rs similarity index 74% rename from crates/wasi/src/p3/sockets/udp.rs rename to crates/wasi/src/sockets/udp.rs index e55054334d71..82b6ea270755 100644 --- a/crates/wasi/src/p3/sockets/udp.rs +++ b/crates/wasi/src/sockets/udp.rs @@ -1,24 +1,20 @@ -use core::future::Future; -use core::net::SocketAddr; - -use std::sync::Arc; - +use crate::runtime::with_ambient_tokio_runtime; +use crate::sockets::util::{ + ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address, + receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size, + set_unicast_hop_limit, udp_bind, udp_disconnect, udp_socket, +}; +use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx}; use cap_net_ext::AddressFamily; use io_lifetimes::AsSocketlike as _; use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _}; use rustix::io::Errno; use rustix::net::connect; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; use tracing::debug; -use crate::p3::bindings::sockets::types::{ErrorCode, IpAddressFamily, IpSocketAddress}; -use crate::runtime::with_ambient_tokio_runtime; -use crate::sockets::util::{ - get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address, receive_buffer_size, - send_buffer_size, set_receive_buffer_size, set_send_buffer_size, set_unicast_hop_limit, - udp_bind, udp_disconnect, udp_socket, -}; -use crate::sockets::{MAX_UDP_DATAGRAM_SIZE, SocketAddressFamily}; - /// The state of a UDP socket. /// /// This represents the various states a socket can be in during the @@ -27,11 +23,18 @@ enum UdpState { /// The initial state for a newly-created socket. Default, + /// TODO + BindStarted, + /// Binding finished via `finish_bind`. The socket has an address but /// is not yet listening for connections. Bound, /// The socket is "connected" to a peer address. + #[cfg_attr( + not(feature = "p3"), + expect(dead_code, reason = "p2 has its own way of managing sending/receiving") + )] Connected(SocketAddr), } @@ -47,11 +50,17 @@ pub struct UdpSocket { /// Socket address family. family: SocketAddressFamily, + + /// If set, use this custom check for addrs, otherwise use what's in + /// `WasiSocketsCtx`. + socket_addr_check: Option, } impl UdpSocket { /// Create a new socket in the given family. - pub(crate) fn new(family: AddressFamily) -> std::io::Result { + pub(crate) fn new(cx: &WasiSocketsCtx, family: AddressFamily) -> io::Result { + cx.allowed_network_uses.check_allowed_udp()?; + // Delegate socket creation to cap_net_ext. They handle a couple of things for us: // - On Windows: call WSAStartup if not done before. // - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation, @@ -77,6 +86,7 @@ impl UdpSocket { socket: Arc::new(socket), udp_state: UdpState::Default, family: socket_address_family, + socket_addr_check: None, }) } @@ -88,12 +98,30 @@ impl UdpSocket { return Err(ErrorCode::InvalidArgument); } udp_bind(&self.socket, addr)?; - self.udp_state = UdpState::Bound; + self.udp_state = UdpState::BindStarted; Ok(()) } + pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> { + match self.udp_state { + UdpState::BindStarted => { + self.udp_state = UdpState::Bound; + Ok(()) + } + _ => Err(ErrorCode::NotInProgress), + } + } + + pub(crate) fn is_connected(&self) -> bool { + matches!(self.udp_state, UdpState::Connected(..)) + } + + pub(crate) fn is_bound(&self) -> bool { + matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound) + } + pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> { - if !matches!(self.udp_state, UdpState::Connected(..)) { + if !self.is_connected() { return Err(ErrorCode::InvalidState); } udp_disconnect(&self.socket)?; @@ -106,6 +134,11 @@ impl UdpSocket { return Err(ErrorCode::InvalidArgument); } + match self.udp_state { + UdpState::Bound | UdpState::Connected(_) => {} + _ => return Err(ErrorCode::InvalidState), + } + // We disconnect & (re)connect in two distinct steps for two reasons: // - To leave our socket instance in a consistent state in case the // connect fails. @@ -130,6 +163,7 @@ impl UdpSocket { Ok(()) } + #[cfg(feature = "p3")] pub(crate) fn send(&self, buf: Vec) -> impl Future> + use<> { let socket = if let UdpState::Connected(..) = self.udp_state { Ok(Arc::clone(&self.socket)) @@ -142,6 +176,7 @@ impl UdpSocket { } } + #[cfg(feature = "p3")] pub(crate) fn send_to( &self, buf: Vec, @@ -152,6 +187,7 @@ impl UdpSocket { SendTo(Arc, SocketAddr), } let socket = match &self.udp_state { + UdpState::BindStarted => Err(ErrorCode::InvalidState), UdpState::Default | UdpState::Bound => Ok(Mode::SendTo(Arc::clone(&self.socket), addr)), UdpState::Connected(caddr) if addr == *caddr => { Ok(Mode::Send(Arc::clone(&self.socket))) @@ -166,21 +202,22 @@ impl UdpSocket { } } + #[cfg(feature = "p3")] pub(crate) fn receive( &self, - ) -> impl Future, IpSocketAddress), ErrorCode>> + use<> { + ) -> impl Future, SocketAddr), ErrorCode>> + use<> { enum Mode { - Recv(Arc, IpSocketAddress), + Recv(Arc, SocketAddr), RecvFrom(Arc), } let socket = match self.udp_state { - UdpState::Default => Err(ErrorCode::InvalidState), + UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState), UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))), UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr.into())), }; async move { let socket = socket?; - let mut buf = vec![0; MAX_UDP_DATAGRAM_SIZE]; + let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE]; let (n, addr) = match socket { Mode::Recv(socket, addr) => { let n = socket.recv(&mut buf).await?; @@ -188,7 +225,7 @@ impl UdpSocket { } Mode::RecvFrom(socket) => { let (n, addr) = socket.recv_from(&mut buf).await?; - (n, addr.into()) + (n, addr) } }; buf.truncate(n); @@ -196,18 +233,18 @@ impl UdpSocket { } } - pub(crate) fn local_address(&self) -> Result { - if matches!(self.udp_state, UdpState::Default) { + pub(crate) fn local_address(&self) -> Result { + if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) { return Err(ErrorCode::InvalidState); } let addr = self .socket .as_socketlike_view::() .local_addr()?; - Ok(addr.into()) + Ok(addr) } - pub(crate) fn remote_address(&self) -> Result { + pub(crate) fn remote_address(&self) -> Result { if !matches!(self.udp_state, UdpState::Connected(..)) { return Err(ErrorCode::InvalidState); } @@ -215,14 +252,11 @@ impl UdpSocket { .socket .as_socketlike_view::() .peer_addr()?; - Ok(addr.into()) + Ok(addr) } - pub(crate) fn address_family(&self) -> IpAddressFamily { - match self.family { - SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4, - SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6, - } + pub(crate) fn address_family(&self) -> SocketAddressFamily { + self.family } pub(crate) fn unicast_hop_limit(&self) -> Result { @@ -254,8 +288,21 @@ impl UdpSocket { set_send_buffer_size(&self.socket, value)?; Ok(()) } + + pub(crate) fn socket(&self) -> &Arc { + &self.socket + } + + pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> { + self.socket_addr_check.as_ref() + } + + pub(crate) fn set_socket_addr_check(&mut self, check: Option) { + self.socket_addr_check = check; + } } +#[cfg(feature = "p3")] async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> { let n = socket.send(buf).await?; // From Rust stdlib docs: @@ -270,6 +317,7 @@ async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCod } } +#[cfg(feature = "p3")] async fn send_to( socket: &tokio::net::UdpSocket, buf: &[u8], diff --git a/crates/wasi/src/sockets/util.rs b/crates/wasi/src/sockets/util.rs index dafe6e25845e..cb1043f48f3e 100644 --- a/crates/wasi/src/sockets/util.rs +++ b/crates/wasi/src/sockets/util.rs @@ -27,6 +27,7 @@ pub enum ErrorCode { ConnectionReset, ConnectionAborted, DatagramTooLarge, + NotInProgress, } impl fmt::Display for ErrorCode {