diff --git a/Cargo.toml b/Cargo.toml index 5012df875..2869e5408 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ rustc-std-workspace-alloc = { version = "1.0.0", optional = true } # not aliased # libc backend can be selected via adding `--cfg=rustix_use_libc` to # `RUSTFLAGS` or enabling the `use-libc` cargo feature. [target.'cfg(all(not(rustix_use_libc), not(miri), target_os = "linux", any(target_endian = "little", any(target_arch = "s390x", target_arch = "powerpc")), any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc"), all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "s390x"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64"))))'.dependencies] -linux-raw-sys = { version = "0.11.0", default-features = false, features = ["auxvec", "general", "errno", "ioctl", "no_std", "elf"] } +linux-raw-sys = { version = "0.12.1", default-features = false, features = ["auxvec", "general", "errno", "ioctl", "no_std", "elf"] } libc_errno = { package = "errno", version = "0.3.10", default-features = false, optional = true } libc = { version = "0.2.177", default-features = false, optional = true } @@ -50,7 +50,7 @@ libc = { version = "0.2.177", default-features = false } # Some syscalls do not have libc wrappers, such as in `io_uring`. For these, # the libc backend uses the linux-raw-sys ABI and `libc::syscall`. [target.'cfg(all(any(target_os = "linux", target_os = "android"), any(rustix_use_libc, miri, not(all(target_os = "linux", any(target_endian = "little", any(target_arch = "s390x", target_arch = "powerpc")), any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc"), all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "s390x"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64")))))))'.dependencies] -linux-raw-sys = { version = "0.11.0", default-features = false, features = ["general", "ioctl", "no_std"] } +linux-raw-sys = { version = "0.12.1", default-features = false, features = ["general", "ioctl", "no_std"] } # For the libc backend on Windows, use the Winsock API in windows-sys. [target.'cfg(windows)'.dependencies.windows-sys] @@ -132,7 +132,7 @@ io_uring = ["event", "fs", "net", "thread", "linux-raw-sys/io_uring"] mount = [] # Enable `rustix::net::*`. -net = ["linux-raw-sys/net", "linux-raw-sys/netlink", "linux-raw-sys/if_ether", "linux-raw-sys/xdp"] +net = ["linux-raw-sys/net", "linux-raw-sys/netlink", "linux-raw-sys/if_ether", "linux-raw-sys/vm_sockets", "linux-raw-sys/xdp"] # Enable `rustix::thread::*`. thread = ["linux-raw-sys/prctl"] diff --git a/src/backend/libc/net/read_sockaddr.rs b/src/backend/libc/net/read_sockaddr.rs index 5f8c48ec7..0fc4e3c77 100644 --- a/src/backend/libc/net/read_sockaddr.rs +++ b/src/backend/libc/net/read_sockaddr.rs @@ -11,6 +11,8 @@ use crate::io::Errno; use crate::net::addr::SocketAddrLen; #[cfg(linux_kernel)] use crate::net::netlink::SocketAddrNetlink; +#[cfg(any(linux_kernel, apple))] +use crate::net::vsock::SocketAddrVSock; #[cfg(target_os = "linux")] use crate::net::xdp::{SocketAddrXdp, SocketAddrXdpFlags}; use crate::net::{AddressFamily, Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrV4, SocketAddrV6}; @@ -262,3 +264,15 @@ pub(crate) fn read_sockaddr_netlink(addr: &SocketAddrAny) -> Result() }; Ok(SocketAddrNetlink::new(decode.nl_pid, decode.nl_groups)) } + +#[cfg(any(linux_kernel, apple))] +#[inline] +pub(crate) fn read_sockaddr_vsock(addr: &SocketAddrAny) -> Result { + if addr.address_family() != AddressFamily::VSOCK { + return Err(Errno::AFNOSUPPORT); + } + + assert!(addr.addr_len() as usize >= size_of::()); + let decode = unsafe { &*addr.as_ptr().cast::() }; + Ok(SocketAddrVSock::new(decode.svm_cid, decode.svm_port)) +} diff --git a/src/backend/linux_raw/c.rs b/src/backend/linux_raw/c.rs index 762cdd479..36e938290 100644 --- a/src/backend/linux_raw/c.rs +++ b/src/backend/linux_raw/c.rs @@ -105,6 +105,7 @@ pub(crate) use linux_raw_sys::{ TCP_QUICKACK, TCP_THIN_LINEAR_TIMEOUTS, TCP_USER_TIMEOUT, }, netlink::*, + vm_sockets::sockaddr_vm, xdp::{ sockaddr_xdp, xdp_desc, xdp_mmap_offsets, xdp_mmap_offsets_v1, xdp_options, xdp_ring_offset, xdp_ring_offset_v1, xdp_statistics, xdp_statistics_v1, xdp_umem_reg, diff --git a/src/backend/linux_raw/net/read_sockaddr.rs b/src/backend/linux_raw/net/read_sockaddr.rs index f18cd4833..ec93544e9 100644 --- a/src/backend/linux_raw/net/read_sockaddr.rs +++ b/src/backend/linux_raw/net/read_sockaddr.rs @@ -6,6 +6,7 @@ use crate::backend::c; use crate::io::Errno; use crate::net::addr::SocketAddrLen; use crate::net::netlink::SocketAddrNetlink; +use crate::net::vsock::SocketAddrVSock; #[cfg(target_os = "linux")] use crate::net::xdp::{SocketAddrXdp, SocketAddrXdpFlags}; use crate::net::{ @@ -153,3 +154,19 @@ pub(crate) fn read_sockaddr_netlink(addr: &SocketAddrAny) -> Result() }; Ok(SocketAddrNetlink::new(decode.nl_pid, decode.nl_groups)) } + +#[inline] +pub(crate) fn read_sockaddr_vsock(addr: &SocketAddrAny) -> Result { + if addr.address_family() != AddressFamily::VSOCK { + return Err(Errno::AFNOSUPPORT); + } + + assert!( + addr.addr_len() as usize * 8 >= size_of::(), + "addr len ({}) >= sockaddr_vm len ({})", + addr.addr_len() as usize * 8, + size_of::() + ); + let decode = unsafe { &*addr.as_ptr().cast::() }; + Ok(SocketAddrVSock::new(decode.svm_cid, decode.svm_port)) +} diff --git a/src/net/socket_addr_any.rs b/src/net/socket_addr_any.rs index 7a9530444..c5c8a9c83 100644 --- a/src/net/socket_addr_any.rs +++ b/src/net/socket_addr_any.rs @@ -225,6 +225,12 @@ impl fmt::Debug for SocketAddrAny { return addr.fmt(f); } } + #[cfg(any(linux_kernel, apple))] + AddressFamily::VSOCK => { + if let Ok(addr) = crate::net::vsock::SocketAddrVSock::try_from(self.clone()) { + return addr.fmt(f); + } + } _ => {} } diff --git a/src/net/types.rs b/src/net/types.rs index 370f2468f..3c49928be 100644 --- a/src/net/types.rs +++ b/src/net/types.rs @@ -2081,6 +2081,105 @@ pub mod xdp { pub const XSK_UNALIGNED_BUF_ADDR_MASK: u64 = c::XSK_UNALIGNED_BUF_ADDR_MASK; } +/// `AF_VSOCK` and related types. +#[cfg(any(linux_kernel, apple))] +pub mod vsock { + use crate::backend::net::read_sockaddr::read_sockaddr_vsock; + use crate::net::addr::{call_with_sockaddr, SocketAddrArg, SocketAddrLen, SocketAddrOpaque}; + use crate::net::SocketAddrAny; + + use super::c; + use core::mem; + + /// A VSock socket address. + /// + /// Used to bind to a VSock socket. Not ABI compatible with `sockadr_vm`. + #[derive(Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)] + pub struct SocketAddrVSock { + /// The Context IDentifier (CID) to connect to. + cid: u32, + /// The port to connect to. + port: u32, + } + + impl SocketAddrVSock { + /// Construct a new VSock address. + #[inline] + pub const fn new(cid: u32, port: u32) -> Self { + Self { cid, port } + } + + /// Context IDentifier (CID), referring to the VM to connect to. + #[inline] + pub fn cid(&self) -> u32 { + self.cid + } + + /// Set the context identifier. + #[inline] + pub fn set_cid(&mut self, cid: u32) { + self.cid = cid; + } + + /// Port to connect to. + #[inline] + pub fn port(&self) -> u32 { + self.port + } + + /// Set the port to connect to. + #[inline] + pub fn set_port(&mut self, port: u32) { + self.port = port; + } + } + + #[allow(unsafe_code)] + // SAFETY: `with_sockaddr` calls `f` using `call_with_sockaddr`, which + // handles calling `f` with the needed preconditions. + unsafe impl SocketAddrArg for SocketAddrVSock { + unsafe fn with_sockaddr( + &self, + f: impl FnOnce(*const SocketAddrOpaque, SocketAddrLen) -> R, + ) -> R { + let addr = c::sockaddr_vm { + svm_family: c::AF_VSOCK as _, + svm_cid: self.cid, + svm_port: self.port, + ..mem::zeroed() + }; + + call_with_sockaddr(&addr, f) + } + } + + impl From for SocketAddrAny { + #[inline] + fn from(from: SocketAddrVSock) -> Self { + from.as_any() + } + } + + impl TryFrom for SocketAddrVSock { + type Error = crate::io::Errno; + + fn try_from(addr: SocketAddrAny) -> Result { + read_sockaddr_vsock(&addr) + } + } + + /// CID to connect to any host. + pub const VMADDR_CID_ANY: u32 = 0xFFFFFFFF; + /// CID to connect to the hypervisor. + pub const VMADDR_CID_HYPERVISOR: u32 = 0; + /// CID to connect to the local host. + pub const VMADDR_CID_LOCAL: u32 = 1; + /// CID to connect to the host. + pub const VMADDR_CID_HOST: u32 = 2; + /// Connect to any port. + pub const VMADDR_PORT_ANY: u32 = 0xFFFFFFFF; +} + /// UNIX credentials of socket peer, for use with [`get_socket_peercred`] /// [`SendAncillaryMessage::ScmCredentials`] and /// [`RecvAncillaryMessage::ScmCredentials`]. diff --git a/tests/net/main.rs b/tests/net/main.rs index c4bdef4b8..7fe66510c 100644 --- a/tests/net/main.rs +++ b/tests/net/main.rs @@ -23,6 +23,7 @@ mod unix; mod unix_alloc; mod v4; mod v6; +mod vsock; #[cfg(windows)] mod windows { diff --git a/tests/net/vsock.rs b/tests/net/vsock.rs new file mode 100644 index 000000000..debb83312 --- /dev/null +++ b/tests/net/vsock.rs @@ -0,0 +1,318 @@ +//! Test a simple IPv4 socket server and client. +//! +//! The client send a message and the server sends one back. + +#![cfg(any(linux_kernel, apple))] + +use rustix::net::vsock::{SocketAddrVSock, VMADDR_CID_LOCAL}; +use rustix::net::{ + accept, bind, connect, getsockname, listen, recv, send, socket, AddressFamily, RecvFlags, + ReturnFlags, SendFlags, SocketType, +}; +use std::sync::{Arc, Condvar, Mutex}; +use std::thread; + +const BUFFER_SIZE: usize = 20; + +/// Only run vsock tests if it is supported on the current machine. +fn vsock_supported() -> bool { + let sock = match socket(AddressFamily::VSOCK, SocketType::STREAM, None) { + Ok(sock) => sock, + Err(rustix::io::Errno::AFNOSUPPORT) + | Err(rustix::io::Errno::NOTSUP) + | Err(rustix::io::Errno::NODEV) => return false, + Err(_) => return true, + }; + + match bind(&sock, &SocketAddrVSock::new(VMADDR_CID_LOCAL, 0x1230)) { + Ok(_) => true, + Err(rustix::io::Errno::ADDRNOTAVAIL) => false, + Err(_) => true, + } +} + +fn server(ready: Arc<(Mutex, Condvar)>) { + let connection_socket = socket(AddressFamily::VSOCK, SocketType::STREAM, None).unwrap(); + + let name = SocketAddrVSock::new(VMADDR_CID_LOCAL, 0x1234); + bind(&connection_socket, &name).unwrap(); + + let who = getsockname(&connection_socket).unwrap(); + let who = SocketAddrVSock::try_from(who).unwrap(); + + listen(&connection_socket, 1).unwrap(); + + { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + *port = who.port(); + cvar.notify_all(); + } + + let mut buffer = vec![0; BUFFER_SIZE]; + let data_socket = accept(&connection_socket).unwrap(); + let (nread, actual) = recv(&data_socket, &mut buffer, RecvFlags::empty()).unwrap(); + assert_eq!(String::from_utf8_lossy(&buffer[..nread]), "hello, world"); + assert_eq!(actual, nread); + + send(&data_socket, b"goodnight, moon", SendFlags::empty()).unwrap(); +} + +fn client(ready: Arc<(Mutex, Condvar)>) { + let port = { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + while *port == 0 { + port = cvar.wait(port).unwrap(); + } + *port + }; + + let addr = SocketAddrVSock::new(VMADDR_CID_LOCAL, port); + let mut buffer = vec![0; BUFFER_SIZE]; + + let data_socket = socket(AddressFamily::VSOCK, SocketType::STREAM, None).unwrap(); + connect(&data_socket, &addr).unwrap(); + + send(&data_socket, b"hello, world", SendFlags::empty()).unwrap(); + + let (nread, actual) = recv(&data_socket, &mut buffer, RecvFlags::empty()).unwrap(); + assert_eq!(String::from_utf8_lossy(&buffer[..nread]), "goodnight, moon"); + assert_eq!(actual, nread); +} + +#[test] +fn test_vsock() { + crate::init(); + + if !vsock_supported() { + return; + } + + let ready = Arc::new((Mutex::new(0_u32), Condvar::new())); + let ready_clone = Arc::clone(&ready); + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(ready); + }) + .unwrap(); + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(ready_clone); + }) + .unwrap(); + client.join().unwrap(); + server.join().unwrap(); +} + +#[test] +fn test_vsock_msg() { + crate::init(); + + if !vsock_supported() { + return; + } + + use rustix::io::{IoSlice, IoSliceMut}; + use rustix::net::{recvmsg, sendmsg}; + + fn server(ready: Arc<(Mutex, Condvar)>) { + let connection_socket = socket(AddressFamily::VSOCK, SocketType::STREAM, None).unwrap(); + + let name = SocketAddrVSock::new(VMADDR_CID_LOCAL, 0x1238); + bind(&connection_socket, &name).unwrap(); + + let who = getsockname(&connection_socket).unwrap(); + let who = SocketAddrVSock::try_from(who).unwrap(); + + listen(&connection_socket, 1).unwrap(); + + { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + *port = who.port(); + cvar.notify_all(); + } + + let mut buffer = vec![0; BUFFER_SIZE]; + let data_socket = accept(&connection_socket).unwrap(); + let res = recvmsg( + &data_socket, + &mut [IoSliceMut::new(&mut buffer)], + &mut Default::default(), + RecvFlags::empty(), + ) + .unwrap(); + assert_eq!( + String::from_utf8_lossy(&buffer[..res.bytes]), + "hello, world" + ); + assert_eq!(res.flags, ReturnFlags::empty()); + + sendmsg( + &data_socket, + &[IoSlice::new(b"goodnight, moon")], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + } + + fn client(ready: Arc<(Mutex, Condvar)>) { + let port = { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + while *port == 0 { + port = cvar.wait(port).unwrap(); + } + *port + }; + + let addr = SocketAddrVSock::new(VMADDR_CID_LOCAL, port); + let mut buffer = vec![0; BUFFER_SIZE]; + + let data_socket = socket(AddressFamily::VSOCK, SocketType::STREAM, None).unwrap(); + connect(&data_socket, &addr).unwrap(); + + sendmsg( + &data_socket, + &[IoSlice::new(b"hello, world")], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + + let res = recvmsg( + &data_socket, + &mut [IoSliceMut::new(&mut buffer)], + &mut Default::default(), + RecvFlags::empty(), + ) + .unwrap(); + assert_eq!( + String::from_utf8_lossy(&buffer[..res.bytes]), + "goodnight, moon" + ); + assert_eq!(res.flags, ReturnFlags::empty()); + } + + let ready = Arc::new((Mutex::new(0_u32), Condvar::new())); + let ready_clone = Arc::clone(&ready); + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(ready); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(ready_clone); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +} + +#[cfg(target_os = "linux")] +#[test] +// TODO(notgull): Figure out why this test keeps failing. +#[ignore] +fn test_vsock_sendmmsg() { + crate::init(); + + if !vsock_supported() { + return; + } + + use std::net::TcpStream; + + use rustix::io::IoSlice; + use rustix::net::addr::SocketAddrArg as _; + use rustix::net::{sendmmsg, MMsgHdr}; + + fn server(ready: Arc<(Mutex, Condvar)>) { + let connection_socket = socket(AddressFamily::VSOCK, SocketType::STREAM, None).unwrap(); + + let name = SocketAddrVSock::new(VMADDR_CID_LOCAL, 0x1236); + bind(&connection_socket, &name).unwrap(); + + let who = getsockname(&connection_socket).unwrap(); + let who = SocketAddrVSock::try_from(who).unwrap(); + + listen(&connection_socket, 1).unwrap(); + + { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + *port = who.port(); + cvar.notify_all(); + } + + let mut buffer = vec![0; 13]; + let mut data_socket: TcpStream = accept(&connection_socket).unwrap().into(); + + std::io::Read::read_exact(&mut data_socket, &mut buffer).unwrap(); + assert_eq!(String::from_utf8_lossy(&buffer), "hello...world"); + } + + fn client(ready: Arc<(Mutex, Condvar)>) { + let port = { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + while *port == 0 { + port = cvar.wait(port).unwrap(); + } + *port + }; + + let addr = SocketAddrVSock::new(VMADDR_CID_LOCAL, port); + let data_socket = socket(AddressFamily::VSOCK, SocketType::STREAM, None).unwrap(); + connect(&data_socket, &addr).unwrap(); + + let mut off = 0; + while off < 2 { + let sent = sendmmsg( + &data_socket, + &mut [ + MMsgHdr::new(&[IoSlice::new(b"hello")], &mut Default::default()), + MMsgHdr::new_with_addr( + &addr.as_any(), + &[IoSlice::new(b"...world")], + &mut Default::default(), + ), + ][off..], + SendFlags::empty(), + ) + .unwrap(); + + off += sent; + } + } + + let ready = Arc::new((Mutex::new(0_u32), Condvar::new())); + let ready_clone = Arc::clone(&ready); + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(ready); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(ready_clone); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +}