Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions crates/test-programs/src/bin/cli_no_udp.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
//! This test assumes that it will be run without udp support enabled
use test_programs::wasi::sockets::{
network::IpAddress,
udp::{ErrorCode, IpAddressFamily, IpSocketAddress, Network, UdpSocket},
};

fn main() {
let net = Network::default();
let family = IpAddressFamily::Ipv4;
let remote1 = IpSocketAddress::new(IpAddress::new_loopback(family), 4321);
let sock = UdpSocket::new(family).unwrap();
#![deny(warnings)]
use test_programs::wasi::sockets::udp::{ErrorCode, IpAddressFamily, UdpSocket};

let bind = sock.blocking_bind(&net, remote1);
eprintln!("Result of binding: {bind:?}");
assert!(matches!(bind, Err(ErrorCode::AccessDenied)));
fn main() {
assert!(matches!(
UdpSocket::new(IpAddressFamily::Ipv4),
Err(ErrorCode::AccessDenied)
));
}
4 changes: 2 additions & 2 deletions crates/wasi/src/p2/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ pub mod sync {
"wasi:sockets/tcp/tcp-socket": super::super::sockets::tcp::TcpSocket,
"wasi:sockets/udp/incoming-datagram-stream": super::super::sockets::udp::IncomingDatagramStream,
"wasi:sockets/udp/outgoing-datagram-stream": super::super::sockets::udp::OutgoingDatagramStream,
"wasi:sockets/udp/udp-socket": super::super::sockets::udp::UdpSocket,
"wasi:sockets/udp/udp-socket": crate::sockets::UdpSocket,

// Error host trait from wasmtime-wasi-io is synchronous, so we can alias it
"wasi:io/error": wasmtime_wasi_io::bindings::wasi::io::error,
Expand Down Expand Up @@ -394,7 +394,7 @@ mod async_io {
// this crate
"wasi:sockets/network/network": crate::p2::network::Network,
"wasi:sockets/tcp/tcp-socket": crate::p2::tcp::TcpSocket,
"wasi:sockets/udp/udp-socket": crate::p2::udp::UdpSocket,
"wasi:sockets/udp/udp-socket": crate::sockets::UdpSocket,
"wasi:sockets/udp/incoming-datagram-stream": crate::p2::udp::IncomingDatagramStream,
"wasi:sockets/udp/outgoing-datagram-stream": crate::p2::udp::OutgoingDatagramStream,
"wasi:sockets/ip-name-lookup/resolve-address-stream": crate::p2::ip_name_lookup::ResolveAddressStream,
Expand Down
179 changes: 46 additions & 133 deletions crates/wasi/src/p2/host/udp.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
use crate::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network};
use crate::p2::bindings::sockets::udp;
use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState, UdpState};
use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState};
use crate::p2::{Pollable, SocketError, SocketResult};
use crate::sockets::util::{
get_ip_ttl, get_ipv6_unicast_hops, is_valid_address_family, is_valid_remote_address,
receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
set_unicast_hop_limit, udp_bind, udp_disconnect,
};
use crate::sockets::util::{is_valid_address_family, is_valid_remote_address};
use crate::sockets::{
MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, WasiSocketsCtxView,
MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, UdpSocket, WasiSocketsCtxView,
};
use anyhow::anyhow;
use async_trait::async_trait;
use io_lifetimes::AsSocketlike;
use rustix::io::Errno;
use std::net::SocketAddr;
use tokio::io::Interest;
use wasmtime::component::Resource;
Expand All @@ -28,51 +22,20 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
network: Resource<Network>,
local_address: IpSocketAddress,
) -> SocketResult<()> {
self.ctx.allowed_network_uses.check_allowed_udp()?;

match self.table.get(&this)?.udp_state {
UdpState::Default => {}
UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()),
UdpState::Bound | UdpState::Connected => return Err(ErrorCode::InvalidState.into()),
}

// Set the socket addr check on the socket so later functions have access to it through the socket handle
let local_address = SocketAddr::from(local_address);
let check = self.table.get(&network)?.socket_addr_check.clone();
self.table
.get_mut(&this)?
.socket_addr_check
.replace(check.clone());

let socket = self.table.get(&this)?;
let local_address: SocketAddr = local_address.into();

if !is_valid_address_family(local_address.ip(), socket.family) {
return Err(ErrorCode::InvalidArgument.into());
}

{
check.check(local_address, SocketAddrUse::UdpBind).await?;

// Perform the OS bind call.
udp_bind(socket.udp_socket(), local_address)?;
}
check.check(local_address, SocketAddrUse::UdpBind).await?;

let socket = self.table.get_mut(&this)?;
socket.udp_state = UdpState::BindStarted;
socket.bind(local_address)?;
socket.set_socket_addr_check(Some(check));

Ok(())
}

fn finish_bind(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<()> {
let socket = self.table.get_mut(&this)?;

match socket.udp_state {
UdpState::BindStarted => {
socket.udp_state = UdpState::Bound;
Ok(())
}
_ => Err(ErrorCode::NotInProgress.into()),
}
self.table.get_mut(&this)?.finish_bind()?;
Ok(())
}

async fn stream(
Expand All @@ -95,9 +58,8 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
let socket = self.table.get_mut(&this)?;
let remote_address = remote_address.map(SocketAddr::from);

match socket.udp_state {
UdpState::Bound | UdpState::Connected => {}
_ => return Err(ErrorCode::InvalidState.into()),
if !socket.is_bound() {
return Err(ErrorCode::InvalidState.into());
}

// We disconnect & (re)connect in two distinct steps for two reasons:
Expand All @@ -107,48 +69,29 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
// if there isn't a disconnect in between.

// Step #1: Disconnect
if let UdpState::Connected = socket.udp_state {
udp_disconnect(socket.udp_socket())?;
socket.udp_state = UdpState::Bound;
if socket.is_connected() {
socket.disconnect()?;
}

// Step #2: (Re)connect
if let Some(connect_addr) = remote_address {
let Some(check) = socket.socket_addr_check.as_ref() else {
let Some(check) = socket.socket_addr_check() else {
return Err(ErrorCode::InvalidState.into());
};
if !is_valid_remote_address(connect_addr)
|| !is_valid_address_family(connect_addr.ip(), socket.family)
{
return Err(ErrorCode::InvalidArgument.into());
}
check.check(connect_addr, SocketAddrUse::UdpConnect).await?;

rustix::net::connect(socket.udp_socket(), &connect_addr).map_err(
|error| match error {
Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `bind` implementation.
Errno::INPROGRESS => {
tracing::debug!(
"UDP connect returned EINPROGRESS, which should never happen"
);
ErrorCode::Unknown
}
_ => ErrorCode::from(error),
},
)?;
socket.udp_state = UdpState::Connected;
socket.connect(connect_addr)?;
}

let incoming_stream = IncomingDatagramStream {
inner: socket.inner.clone(),
inner: socket.socket().clone(),
remote_address,
};
let outgoing_stream = OutgoingDatagramStream {
inner: socket.inner.clone(),
inner: socket.socket().clone(),
remote_address,
family: socket.family,
family: socket.address_family(),
send_state: SendState::Idle,
socket_addr_check: socket.socket_addr_check.clone(),
socket_addr_check: socket.socket_addr_check().cloned(),
};

Ok((
Expand All @@ -159,56 +102,25 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {

fn local_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
let socket = self.table.get(&this)?;

match socket.udp_state {
UdpState::Default => return Err(ErrorCode::InvalidState.into()),
UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()),
_ => {}
}

let addr = socket
.udp_socket()
.as_socketlike_view::<std::net::UdpSocket>()
.local_addr()?;
Ok(addr.into())
Ok(socket.local_address()?.into())
}

fn remote_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
let socket = self.table.get(&this)?;

match socket.udp_state {
UdpState::Connected => {}
_ => return Err(ErrorCode::InvalidState.into()),
}

let addr = socket
.udp_socket()
.as_socketlike_view::<std::net::UdpSocket>()
.peer_addr()?;
Ok(addr.into())
Ok(socket.remote_address()?.into())
}

fn address_family(
&mut self,
this: Resource<udp::UdpSocket>,
) -> Result<IpAddressFamily, anyhow::Error> {
let socket = self.table.get(&this)?;

match socket.family {
SocketAddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4),
SocketAddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6),
}
Ok(socket.address_family().into())
}

fn unicast_hop_limit(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u8> {
let socket = self.table.get(&this)?;

let ttl = match socket.family {
SocketAddressFamily::Ipv4 => get_ip_ttl(socket.udp_socket())?,
SocketAddressFamily::Ipv6 => get_ipv6_unicast_hops(socket.udp_socket())?,
};

Ok(ttl)
Ok(socket.unicast_hop_limit()?)
}

fn set_unicast_hop_limit(
Expand All @@ -217,17 +129,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
value: u8,
) -> SocketResult<()> {
let socket = self.table.get(&this)?;

set_unicast_hop_limit(socket.udp_socket(), socket.family, value)?;

socket.set_unicast_hop_limit(value)?;
Ok(())
}

fn receive_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
let socket = self.table.get(&this)?;

let value = receive_buffer_size(socket.udp_socket())?;
Ok(value)
Ok(socket.receive_buffer_size()?)
}

fn set_receive_buffer_size(
Expand All @@ -236,33 +144,22 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
value: u64,
) -> SocketResult<()> {
let socket = self.table.get(&this)?;

set_receive_buffer_size(socket.udp_socket(), value)?;
socket.set_receive_buffer_size(value)?;
Ok(())
}

fn send_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
let socket = self.table.get(&this)?;

let value = send_buffer_size(socket.udp_socket())?;
Ok(value)
Ok(socket.send_buffer_size()?)
}

fn set_send_buffer_size(
&mut self,
this: Resource<udp::UdpSocket>,
value: u64,
) -> SocketResult<()> {
fn set_send_buffer_size(&mut self, this: Resource<UdpSocket>, value: u64) -> SocketResult<()> {
let socket = self.table.get(&this)?;

set_send_buffer_size(socket.udp_socket(), value)?;
socket.set_send_buffer_size(value)?;
Ok(())
}

fn subscribe(
&mut self,
this: Resource<udp::UdpSocket>,
) -> anyhow::Result<Resource<DynPollable>> {
fn subscribe(&mut self, this: Resource<UdpSocket>) -> anyhow::Result<Resource<DynPollable>> {
wasmtime_wasi_io::poll::subscribe(self.table, this)
}

Expand All @@ -276,6 +173,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
}
}

#[async_trait]
impl Pollable for UdpSocket {
async fn ready(&mut self) {
// None of the socket-level operations block natively
}
}

impl udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> {
fn receive(
&mut self,
Expand Down Expand Up @@ -504,6 +408,15 @@ impl Pollable for OutgoingDatagramStream {
}
}

impl From<SocketAddressFamily> for IpAddressFamily {
fn from(family: SocketAddressFamily) -> IpAddressFamily {
match family {
SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4,
SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6,
}
}
}

pub mod sync {
use wasmtime::component::Resource;

Expand Down
4 changes: 2 additions & 2 deletions crates/wasi/src/p2/host/udp_create_socket.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::p2::SocketResult;
use crate::p2::bindings::{sockets::network::IpAddressFamily, sockets::udp_create_socket};
use crate::p2::udp::UdpSocket;
use crate::sockets::UdpSocket;
use crate::sockets::WasiSocketsCtxView;
use wasmtime::component::Resource;

Expand All @@ -9,7 +9,7 @@ impl udp_create_socket::Host for WasiSocketsCtxView<'_> {
&mut self,
address_family: IpAddressFamily,
) -> SocketResult<Resource<UdpSocket>> {
let socket = UdpSocket::new(address_family.into())?;
let socket = UdpSocket::new(self.ctx, address_family.into())?;
let socket = self.table.push(socket)?;
Ok(socket)
}
Expand Down
1 change: 1 addition & 0 deletions crates/wasi/src/p2/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ impl From<crate::sockets::util::ErrorCode> for ErrorCode {
crate::sockets::util::ErrorCode::ConnectionReset => Self::ConnectionReset,
crate::sockets::util::ErrorCode::ConnectionAborted => Self::ConnectionAborted,
crate::sockets::util::ErrorCode::DatagramTooLarge => Self::DatagramTooLarge,
crate::sockets::util::ErrorCode::NotInProgress => Self::NotInProgress,
}
}
}
Expand Down
Loading