From 6be0408f40b3b70ccac6b3701dc34827ddd9d209 Mon Sep 17 00:00:00 2001 From: badeend Date: Sat, 14 Jun 2025 11:23:20 +0200 Subject: [PATCH 01/10] Reorganize the wasi-tls folder and isolate the `rustls`-specific bits from the rest of the implementation. - `client.rs`: The rustls parts. - `io.rs`: WASI I/O conversion utility types. - `host.rs`: The host types + impls. - `bindings.rs`: The generated bindings. --- crates/wasi-tls/Cargo.toml | 1 + crates/wasi-tls/src/bindings.rs | 21 ++ crates/wasi-tls/src/client.rs | 55 +++ crates/wasi-tls/src/host.rs | 149 ++++++++ crates/wasi-tls/src/io.rs | 384 ++++++++++++++++++++ crates/wasi-tls/src/lib.rs | 621 +------------------------------- 6 files changed, 620 insertions(+), 611 deletions(-) create mode 100644 crates/wasi-tls/src/bindings.rs create mode 100644 crates/wasi-tls/src/client.rs create mode 100644 crates/wasi-tls/src/host.rs create mode 100644 crates/wasi-tls/src/io.rs diff --git a/crates/wasi-tls/Cargo.toml b/crates/wasi-tls/Cargo.toml index 7bc7a29c26dc..b5f32292aa16 100644 --- a/crates/wasi-tls/Cargo.toml +++ b/crates/wasi-tls/Cargo.toml @@ -18,6 +18,7 @@ tokio = { workspace = true, features = [ "net", "rt-multi-thread", "time", + "io-util" ] } wasmtime = { workspace = true, features = ["runtime", "component-model"] } wasmtime-wasi = { workspace = true } diff --git a/crates/wasi-tls/src/bindings.rs b/crates/wasi-tls/src/bindings.rs new file mode 100644 index 000000000000..355034ee512b --- /dev/null +++ b/crates/wasi-tls/src/bindings.rs @@ -0,0 +1,21 @@ +//! Auto-generated bindings. + +#[expect(missing_docs, reason = "bindgen-generated code")] +mod generated { + wasmtime::component::bindgen!({ + path: "wit", + world: "wasi:tls/imports", + with: { + "wasi:io": wasmtime_wasi::p2::bindings::io, + "wasi:tls/types/client-connection": crate::HostClientConnection, + "wasi:tls/types/client-handshake": crate::HostClientHandshake, + "wasi:tls/types/future-client-streams": crate::HostFutureClientStreams, + }, + trappable_imports: true, + async: { + only_imports: [], + } + }); +} + +pub use generated::wasi::tls::*; diff --git a/crates/wasi-tls/src/client.rs b/crates/wasi-tls/src/client.rs new file mode 100644 index 000000000000..f89691700b70 --- /dev/null +++ b/crates/wasi-tls/src/client.rs @@ -0,0 +1,55 @@ +//! A uniform TLS client interface, abstracting away the differences between the +//! `rustls` and `native-tls` implementations. + +use rustls::pki_types::ServerName; +use std::io; +use std::sync::Arc; +use std::sync::LazyLock; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// A client TLS handshake configuration object. +/// +/// At the time of writing, there's nothing to configure (yet). +pub struct Handshake { + transport: IO, + server_name: String, +} +impl Handshake +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + /// Create a new handshake. + pub fn new(server_name: String, transport: IO) -> Self { + Self { + server_name, + transport, + } + } + + /// Run the handshake to completion. + pub async fn finish(self) -> io::Result> { + let domain = ServerName::try_from(self.server_name) + .map_err(|_| io::Error::other("invalid server name"))?; + + let stream = tokio_rustls::TlsConnector::from(Self::client_config()) + .connect(domain, self.transport) + .await?; + Ok(stream) + } + + fn client_config() -> Arc { + static CONFIG: LazyLock> = LazyLock::new(|| { + let roots = rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.into(), + }; + let config = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth(); + Arc::new(config) + }); + Arc::clone(&CONFIG) + } +} + +/// A TLS client connection. +pub type Connection = tokio_rustls::client::TlsStream; diff --git a/crates/wasi-tls/src/host.rs b/crates/wasi-tls/src/host.rs new file mode 100644 index 000000000000..784fa3257ad9 --- /dev/null +++ b/crates/wasi-tls/src/host.rs @@ -0,0 +1,149 @@ +use anyhow::Result; +use wasmtime::component::Resource; +use wasmtime_wasi::async_trait; +use wasmtime_wasi::p2::Pollable; +use wasmtime_wasi::p2::{DynInputStream, DynOutputStream, DynPollable, IoError}; + +use crate::{ + WasiTlsCtx, bindings, + io::{ + AsyncReadStream, AsyncWriteStream, FutureOutput, WasiFuture, WasiStreamReader, + WasiStreamWriter, + }, +}; + +impl<'a> bindings::types::Host for WasiTlsCtx<'a> {} + +/// The underlying transport. Typically, this is a TCP input+output stream. +type Transport = tokio::io::Join; + +/// Represents the ClientHandshake which will be used to configure the handshake +pub struct HostClientHandshake(crate::client::Handshake); + +impl<'a> bindings::types::HostClientHandshake for WasiTlsCtx<'a> { + fn new( + &mut self, + server_name: String, + input: Resource, + output: Resource, + ) -> wasmtime::Result> { + let input = self.table.delete(input)?; + let output = self.table.delete(output)?; + + let reader = WasiStreamReader::new(input); + let writer = WasiStreamWriter::new(output); + let transport = tokio::io::join(reader, writer); + let handshake = crate::client::Handshake::new(server_name, transport); + + Ok(self.table.push(HostClientHandshake(handshake))?) + } + + fn finish( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + let handshake = self.table.delete(this)?; + + let future = HostFutureClientStreams(WasiFuture::spawn(async move { + let tls_stream = handshake.0.finish().await?; + + let (rx, tx) = tokio::io::split(tls_stream); + let write_stream = AsyncWriteStream::new(tx); + let client = HostClientConnection(write_stream.clone()); + + let input = Box::new(AsyncReadStream::new(rx)) as DynInputStream; + let output = Box::new(write_stream) as DynOutputStream; + + Ok((client, input, output)) + })); + + Ok(self.table.push(future)?) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table.delete(this)?; + Ok(()) + } +} + +/// Future streams provides the tls streams after the handshake is completed +pub struct HostFutureClientStreams( + WasiFuture>, +); + +#[async_trait] +impl Pollable for HostFutureClientStreams { + async fn ready(&mut self) { + self.0.ready().await + } +} + +impl<'a> bindings::types::HostFutureClientStreams for WasiTlsCtx<'a> { + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + wasmtime_wasi::p2::subscribe(self.table, this) + } + + fn get( + &mut self, + this: Resource, + ) -> wasmtime::Result< + Option< + Result< + Result< + ( + Resource, + Resource, + Resource, + ), + Resource, + >, + (), + >, + >, + > { + let future = self.table.get_mut(&this)?; + + let result = match future.0.get() { + FutureOutput::Ready(Ok((client, input, output))) => { + let client = self.table.push(client)?; + let input = self.table.push_child(input, &client)?; + let output = self.table.push_child(output, &client)?; + + Some(Ok(Ok((client, input, output)))) + } + FutureOutput::Ready(Err(io_error)) => { + let io_error = self.table.push(io_error)?; + + Some(Ok(Err(io_error))) + } + FutureOutput::Consumed => Some(Err(())), + FutureOutput::Pending => None, + }; + + Ok(result) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table.delete(this)?; + Ok(()) + } +} + +/// Represents the client connection and used to shut down the tls stream +pub struct HostClientConnection( + crate::io::AsyncWriteStream>>, +); + +impl<'a> bindings::types::HostClientConnection for WasiTlsCtx<'a> { + fn close_output(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table.get_mut(&this)?.0.close() + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table.delete(this)?; + Ok(()) + } +} diff --git a/crates/wasi-tls/src/io.rs b/crates/wasi-tls/src/io.rs new file mode 100644 index 000000000000..399d23eb6763 --- /dev/null +++ b/crates/wasi-tls/src/io.rs @@ -0,0 +1,384 @@ +//! Utility types for converting Rust & Tokio I/O types into WASI I/O types, +//! and vice versa. + +use anyhow::Result; +use bytes::Bytes; +use std::io; +use std::sync::Arc; +use std::task::{Poll, ready}; +use std::{future::Future, mem, pin::Pin}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::sync::Mutex; +use wasmtime_wasi::async_trait; +use wasmtime_wasi::p2::{ + DynInputStream, DynOutputStream, OutputStream, Pollable, StreamError, StreamResult, +}; +use wasmtime_wasi::runtime::AbortOnDropJoinHandle; + +enum FutureState { + Pending(Pin + Send>>), + Ready(T), + Consumed, +} + +pub(crate) enum FutureOutput { + Pending, + Ready(T), + Consumed, +} + +pub(crate) struct WasiFuture(FutureState); + +impl WasiFuture +where + T: Send + 'static, +{ + pub(crate) fn spawn(fut: F) -> Self + where + F: Future + Send + 'static, + { + Self(FutureState::Pending(Box::pin( + wasmtime_wasi::runtime::spawn(async move { fut.await }), + ))) + } + + pub(crate) fn get(&mut self) -> FutureOutput { + match &self.0 { + FutureState::Pending(_) => return FutureOutput::Pending, + FutureState::Consumed => return FutureOutput::Consumed, + FutureState::Ready(_) => (), + } + + let FutureState::Ready(value) = mem::replace(&mut self.0, FutureState::Consumed) else { + unreachable!() + }; + + FutureOutput::Ready(value) + } +} + +#[async_trait] +impl Pollable for WasiFuture +where + T: Send + 'static, +{ + async fn ready(&mut self) { + match &mut self.0 { + FutureState::Ready(_) | FutureState::Consumed => return, + FutureState::Pending(task) => self.0 = FutureState::Ready(task.as_mut().await), + } + } +} + +pub(crate) struct WasiStreamReader(FutureState); +impl WasiStreamReader { + pub(crate) fn new(stream: DynInputStream) -> Self { + Self(FutureState::Ready(stream)) + } +} +impl AsyncRead for WasiStreamReader { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + loop { + let stream = match &mut self.0 { + FutureState::Ready(stream) => stream, + FutureState::Pending(fut) => { + let stream = ready!(fut.as_mut().poll(cx)); + self.0 = FutureState::Ready(stream); + if let FutureState::Ready(stream) = &mut self.0 { + stream + } else { + unreachable!() + } + } + FutureState::Consumed => { + return Poll::Ready(Ok(())); + } + }; + match stream.read(buf.remaining()) { + Ok(bytes) if bytes.is_empty() => { + let FutureState::Ready(mut stream) = + std::mem::replace(&mut self.0, FutureState::Consumed) + else { + unreachable!() + }; + + self.0 = FutureState::Pending(Box::pin(async move { + stream.ready().await; + stream + })); + } + Ok(bytes) => { + buf.put_slice(&bytes); + + return Poll::Ready(Ok(())); + } + Err(StreamError::Closed) => { + self.0 = FutureState::Consumed; + return Poll::Ready(Ok(())); + } + Err(e) => { + self.0 = FutureState::Consumed; + return Poll::Ready(Err(std::io::Error::other(e))); + } + } + } + } +} + +pub(crate) struct WasiStreamWriter(FutureState); +impl WasiStreamWriter { + pub(crate) fn new(stream: DynOutputStream) -> Self { + Self(FutureState::Ready(stream)) + } +} +impl AsyncWrite for WasiStreamWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + loop { + match &mut self.as_mut().0 { + FutureState::Consumed => unreachable!(), + FutureState::Pending(future) => { + let value = ready!(future.as_mut().poll(cx)); + self.as_mut().0 = FutureState::Ready(value); + } + FutureState::Ready(output) => { + match output.check_write() { + Ok(0) => { + let FutureState::Ready(mut output) = + mem::replace(&mut self.as_mut().0, FutureState::Consumed) + else { + unreachable!() + }; + self.as_mut().0 = FutureState::Pending(Box::pin(async move { + output.ready().await; + output + })); + } + Ok(count) => { + let count = count.min(buf.len()); + return match output.write(Bytes::copy_from_slice(&buf[..count])) { + Ok(()) => Poll::Ready(Ok(count)), + Err(StreamError::Closed) => Poll::Ready(Ok(0)), + Err(e) => Poll::Ready(Err(std::io::Error::other(e))), + }; + } + Err(StreamError::Closed) => { + // Our current version of tokio-rustls does not handle returning `Ok(0)` well. + // See: https://github.com/rustls/tokio-rustls/issues/92 + return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into())); + } + Err(e) => return Poll::Ready(Err(std::io::Error::other(e))), + }; + } + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.poll_write(cx, &[]).map(|v| v.map(drop)) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.poll_flush(cx) + } +} + +pub(crate) use wasmtime_wasi::p2::pipe::AsyncReadStream; + +pub(crate) struct AsyncWriteStream(Arc>>); + +impl AsyncWriteStream +where + IO: AsyncWrite + Send + Unpin + 'static, +{ + pub(crate) fn new(io: IO) -> Self { + AsyncWriteStream(Arc::new(Mutex::new(WriteState::new(io)))) + } + + pub(crate) fn close(&mut self) -> wasmtime::Result<()> { + self.try_lock()?.close(); + Ok(()) + } + + async fn lock(&self) -> tokio::sync::MutexGuard<'_, WriteState> { + self.0.lock().await + } + + fn try_lock(&self) -> Result>, StreamError> { + self.0 + .try_lock() + .map_err(|_| StreamError::trap("concurrent access to resource not supported")) + } +} +impl Clone for AsyncWriteStream { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } +} + +#[async_trait] +impl OutputStream for AsyncWriteStream +where + IO: AsyncWrite + Send + Unpin + 'static, +{ + fn write(&mut self, bytes: bytes::Bytes) -> StreamResult<()> { + self.try_lock()?.write(bytes) + } + + fn flush(&mut self) -> StreamResult<()> { + self.try_lock()?.flush() + } + + fn check_write(&mut self) -> StreamResult { + self.try_lock()?.check_write() + } + + async fn cancel(&mut self) { + self.lock().await.cancel().await + } +} + +#[async_trait] +impl Pollable for AsyncWriteStream +where + IO: AsyncWrite + Send + Unpin + 'static, +{ + async fn ready(&mut self) { + self.lock().await.ready().await + } +} + +enum WriteState { + Ready(IO), + Writing(AbortOnDropJoinHandle>), + Closing(AbortOnDropJoinHandle>), + Closed, + Error(io::Error), +} +const READY_SIZE: usize = 1024 * 1024 * 1024; + +impl WriteState +where + IO: AsyncWrite + Send + Unpin + 'static, +{ + fn new(stream: IO) -> Self { + Self::Ready(stream) + } + + fn write(&mut self, mut bytes: bytes::Bytes) -> StreamResult<()> { + let WriteState::Ready(_) = self else { + return Err(StreamError::Trap(anyhow::anyhow!( + "unpermitted: must call check_write first" + ))); + }; + + if bytes.is_empty() { + return Ok(()); + } + + let WriteState::Ready(mut stream) = std::mem::replace(self, WriteState::Closed) else { + unreachable!() + }; + + *self = WriteState::Writing(wasmtime_wasi::runtime::spawn(async move { + while !bytes.is_empty() { + let n = stream.write(&bytes).await?; + let _ = bytes.split_to(n); + } + + Ok(stream) + })); + + Ok(()) + } + + fn flush(&mut self) -> StreamResult<()> { + // `flush` is a no-op here, as we're not managing any internal buffer. + match self { + WriteState::Ready(_) + | WriteState::Writing(_) + | WriteState::Closing(_) + | WriteState::Error(_) => Ok(()), + WriteState::Closed => Err(StreamError::Closed), + } + } + + fn check_write(&mut self) -> StreamResult { + match self { + WriteState::Ready(_) => Ok(READY_SIZE), + WriteState::Writing(_) => Ok(0), + WriteState::Closing(_) => Ok(0), + WriteState::Closed => Err(StreamError::Closed), + WriteState::Error(_) => { + let WriteState::Error(e) = std::mem::replace(self, WriteState::Closed) else { + unreachable!() + }; + + Err(StreamError::LastOperationFailed(e.into())) + } + } + } + + fn close(&mut self) { + match std::mem::replace(self, WriteState::Closed) { + // No write in progress, immediately shut down: + WriteState::Ready(mut stream) => { + *self = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move { + stream.shutdown().await + })); + } + + // Schedule the shutdown after the current write has finished: + WriteState::Writing(write) => { + *self = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move { + let mut stream = write.await?; + stream.shutdown().await + })); + } + + WriteState::Closing(t) => { + *self = WriteState::Closing(t); + } + WriteState::Closed | WriteState::Error(_) => {} + } + } + + async fn cancel(&mut self) { + match std::mem::replace(self, WriteState::Closed) { + WriteState::Writing(task) => _ = task.cancel().await, + WriteState::Closing(task) => _ = task.cancel().await, + _ => {} + } + } + + async fn ready(&mut self) { + match self { + WriteState::Writing(task) => { + *self = match task.await { + Ok(s) => WriteState::Ready(s), + Err(e) => WriteState::Error(e), + } + } + WriteState::Closing(task) => { + *self = match task.await { + Ok(()) => WriteState::Closed, + Err(e) => WriteState::Error(e), + } + } + _ => {} + } + } +} diff --git a/crates/wasi-tls/src/lib.rs b/crates/wasi-tls/src/lib.rs index 2e8733c9c3ff..22f2e16ad295 100644 --- a/crates/wasi-tls/src/lib.rs +++ b/crates/wasi-tls/src/lib.rs @@ -71,60 +71,17 @@ #![doc(test(attr(deny(warnings))))] #![doc(test(attr(allow(dead_code, unused_variables, unused_mut))))] -use anyhow::Result; -use bytes::Bytes; -use rustls::pki_types::ServerName; -use std::io; -use std::sync::Arc; -use std::task::{Poll, ready}; -use std::{future::Future, mem, pin::Pin, sync::LazyLock}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::sync::Mutex; -use tokio_rustls::client::TlsStream; -use wasmtime::component::{HasData, Resource, ResourceTable}; -use wasmtime_wasi::async_trait; -use wasmtime_wasi::p2::bindings::io::{ - error::Error as HostIoError, - poll::Pollable as HostPollable, - streams::{InputStream as BoxInputStream, OutputStream as BoxOutputStream}, -}; -use wasmtime_wasi::p2::pipe::AsyncReadStream; -use wasmtime_wasi::p2::{OutputStream, Pollable, StreamError}; -use wasmtime_wasi::runtime::AbortOnDropJoinHandle; +use wasmtime::component::{HasData, ResourceTable}; -mod gen_ { - wasmtime::component::bindgen!({ - path: "wit", - world: "wasi:tls/imports", - with: { - "wasi:io": wasmtime_wasi::p2::bindings::io, - "wasi:tls/types/client-connection": super::ClientConnection, - "wasi:tls/types/client-handshake": super::ClientHandShake, - "wasi:tls/types/future-client-streams": super::FutureClientStreams, - }, - trappable_imports: true, - async: { - only_imports: [], - } - }); -} -pub use gen_::wasi::tls::types::LinkOptions; -use gen_::wasi::tls::{self as generated}; +pub mod bindings; +mod client; +mod host; +mod io; -fn default_client_config() -> Arc { - static CONFIG: LazyLock> = LazyLock::new(|| { - let roots = rustls::RootCertStore { - roots: webpki_roots::TLS_SERVER_ROOTS.into(), - }; - let config = rustls::ClientConfig::builder() - .with_root_certificates(roots) - .with_no_client_auth(); - Arc::new(config) - }); - Arc::clone(&CONFIG) -} +pub use bindings::types::LinkOptions; +pub use host::{HostClientConnection, HostClientHandshake, HostFutureClientStreams}; -/// Wasi TLS context needed fro internal `wasi-tls`` state +/// Wasi TLS context needed for internal `wasi-tls` state pub struct WasiTlsCtx<'a> { table: &'a mut ResourceTable, } @@ -136,15 +93,13 @@ impl<'a> WasiTlsCtx<'a> { } } -impl<'a> generated::types::Host for WasiTlsCtx<'a> {} - /// Add the `wasi-tls` world's types to a [`wasmtime::component::Linker`]. pub fn add_to_linker( l: &mut wasmtime::component::Linker, opts: &mut LinkOptions, f: fn(&mut T) -> WasiTlsCtx<'_>, -) -> Result<()> { - generated::types::add_to_linker::<_, WasiTls>(l, &opts, f)?; +) -> anyhow::Result<()> { + bindings::types::add_to_linker::<_, WasiTls>(l, &opts, f)?; Ok(()) } @@ -153,559 +108,3 @@ struct WasiTls; impl HasData for WasiTls { type Data<'a> = WasiTlsCtx<'a>; } - -enum TlsError { - /// The component should trap. Under normal circumstances, this only occurs - /// when the underlying transport stream returns [`StreamError::Trap`]. - Trap(anyhow::Error), - - /// A failure indicated by the underlying transport stream as - /// [`StreamError::LastOperationFailed`]. - Io(wasmtime_wasi::p2::IoError), - - /// A TLS protocol error occurred. - Tls(rustls::Error), -} - -impl TlsError { - /// Create a [`TlsError::Tls`] error from a simple message. - fn msg(msg: &str) -> Self { - // (Ab)using rustls' error type to synthesize our own TLS errors: - Self::Tls(rustls::Error::General(msg.to_string())) - } -} - -impl From for TlsError { - fn from(error: io::Error) -> Self { - // Report unexpected EOFs as an error to prevent truncation attacks. - // See: https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read - if let io::ErrorKind::WriteZero | io::ErrorKind::UnexpectedEof = error.kind() { - return Self::msg("underlying transport closed abruptly"); - } - - // Errors from underlying transport. - // These have been wrapped inside `io::Error`s by our wasi-to-tokio stream transformer below. - let error = match error.downcast::() { - Ok(StreamError::LastOperationFailed(e)) => return Self::Io(e), - Ok(StreamError::Trap(e)) => return Self::Trap(e), - Ok(StreamError::Closed) => unreachable!( - "our wasi-to-tokio stream transformer should have translated this to a 0-sized read" - ), - Err(e) => e, - }; - - // Errors from `rustls`. - // These have been wrapped inside `io::Error`s by `tokio-rustls`. - let error = match error.downcast::() { - Ok(e) => return Self::Tls(e), - Err(e) => e, - }; - - // All errors should have been handled by the clauses above. - Self::Trap(anyhow::Error::new(error).context("unknown wasi-tls error")) - } -} - -/// Represents the ClientHandshake which will be used to configure the handshake -pub struct ClientHandShake { - server_name: String, - streams: WasiStreams, -} - -impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> { - fn new( - &mut self, - server_name: String, - input: Resource, - output: Resource, - ) -> wasmtime::Result> { - let input = self.table.delete(input)?; - let output = self.table.delete(output)?; - Ok(self.table.push(ClientHandShake { - server_name, - streams: WasiStreams { - input: StreamState::Ready(input), - output: StreamState::Ready(output), - }, - })?) - } - - fn finish( - &mut self, - this: wasmtime::component::Resource, - ) -> wasmtime::Result> { - let handshake = self.table.delete(this)?; - let server_name = handshake.server_name; - let streams = handshake.streams; - - Ok(self - .table - .push(FutureStreams(StreamState::Pending(Box::pin(async move { - let domain = ServerName::try_from(server_name) - .map_err(|_| TlsError::msg("invalid server name"))?; - - let stream = tokio_rustls::TlsConnector::from(default_client_config()) - .connect(domain, streams) - .await?; - Ok(stream) - }))))?) - } - - fn drop( - &mut self, - this: wasmtime::component::Resource, - ) -> wasmtime::Result<()> { - self.table.delete(this)?; - Ok(()) - } -} - -/// Future streams provides the tls streams after the handshake is completed -pub struct FutureStreams(StreamState>); - -/// Library specific version of TLS connection after the handshake is completed. -/// This alias allows it to use with wit-bindgen component generator which won't take generic types -pub type FutureClientStreams = FutureStreams>; - -#[async_trait] -impl Pollable for FutureStreams { - async fn ready(&mut self) { - match &mut self.0 { - StreamState::Ready(_) | StreamState::Closed => return, - StreamState::Pending(task) => self.0 = StreamState::Ready(task.as_mut().await), - } - } -} - -impl<'a> generated::types::HostFutureClientStreams for WasiTlsCtx<'a> { - fn subscribe( - &mut self, - this: wasmtime::component::Resource, - ) -> wasmtime::Result> { - wasmtime_wasi::p2::subscribe(self.table, this) - } - - fn get( - &mut self, - this: wasmtime::component::Resource, - ) -> wasmtime::Result< - Option< - Result< - Result< - ( - Resource, - Resource, - Resource, - ), - Resource, - >, - (), - >, - >, - > { - let this = &mut self.table.get_mut(&this)?.0; - match this { - StreamState::Pending(_) => return Ok(None), - StreamState::Closed => return Ok(Some(Err(()))), - StreamState::Ready(_) => (), - } - - let StreamState::Ready(result) = mem::replace(this, StreamState::Closed) else { - unreachable!() - }; - - let tls_stream = match result { - Ok(s) => s, - Err(TlsError::Trap(e)) => return Err(e), - Err(TlsError::Io(e)) => { - let error = self.table.push(e)?; - return Ok(Some(Ok(Err(error)))); - } - Err(TlsError::Tls(e)) => { - let error = self.table.push(wasmtime_wasi::p2::IoError::new(e))?; - return Ok(Some(Ok(Err(error)))); - } - }; - - let (rx, tx) = tokio::io::split(tls_stream); - let write_stream = AsyncTlsWriteStream::new(TlsWriter::new(tx)); - let client = ClientConnection { - writer: write_stream.clone(), - }; - - let input = Box::new(AsyncReadStream::new(rx)) as BoxInputStream; - let output = Box::new(write_stream) as BoxOutputStream; - - let client = self.table.push(client)?; - let input = self.table.push_child(input, &client)?; - let output = self.table.push_child(output, &client)?; - - Ok(Some(Ok(Ok((client, input, output))))) - } - - fn drop( - &mut self, - this: wasmtime::component::Resource, - ) -> wasmtime::Result<()> { - self.table.delete(this)?; - Ok(()) - } -} - -/// Represents the client connection and used to shut down the tls stream -pub struct ClientConnection { - writer: AsyncTlsWriteStream, -} - -impl<'a> generated::types::HostClientConnection for WasiTlsCtx<'a> { - fn close_output(&mut self, this: Resource) -> wasmtime::Result<()> { - self.table.get_mut(&this)?.writer.close() - } - - fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { - self.table.delete(this)?; - Ok(()) - } -} - -enum StreamState { - Ready(T), - Pending(Pin + Send>>), - Closed, -} - -/// Wrapper around Input and Output wasi IO Stream that provides Async Read/Write -pub struct WasiStreams { - input: StreamState, - output: StreamState, -} - -impl AsyncWrite for WasiStreams { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - loop { - match &mut self.as_mut().output { - StreamState::Closed => unreachable!(), - StreamState::Pending(future) => { - let value = ready!(future.as_mut().poll(cx)); - self.as_mut().output = StreamState::Ready(value); - } - StreamState::Ready(output) => { - match output.check_write() { - Ok(0) => { - let StreamState::Ready(mut output) = - mem::replace(&mut self.as_mut().output, StreamState::Closed) - else { - unreachable!() - }; - self.as_mut().output = StreamState::Pending(Box::pin(async move { - output.ready().await; - output - })); - } - Ok(count) => { - let count = count.min(buf.len()); - return match output.write(Bytes::copy_from_slice(&buf[..count])) { - Ok(()) => Poll::Ready(Ok(count)), - Err(StreamError::Closed) => Poll::Ready(Ok(0)), - Err(e) => Poll::Ready(Err(std::io::Error::other(e))), - }; - } - Err(StreamError::Closed) => { - // Our current version of tokio-rustls does not handle returning `Ok(0)` well. - // See: https://github.com/rustls/tokio-rustls/issues/92 - return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into())); - } - Err(e) => return Poll::Ready(Err(std::io::Error::other(e))), - }; - } - } - } - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.poll_write(cx, &[]).map(|v| v.map(drop)) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.poll_flush(cx) - } -} - -impl AsyncRead for WasiStreams { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - loop { - let stream = match &mut self.input { - StreamState::Ready(stream) => stream, - StreamState::Pending(fut) => { - let stream = ready!(fut.as_mut().poll(cx)); - self.input = StreamState::Ready(stream); - if let StreamState::Ready(stream) = &mut self.input { - stream - } else { - unreachable!() - } - } - StreamState::Closed => { - return Poll::Ready(Ok(())); - } - }; - match stream.read(buf.remaining()) { - Ok(bytes) if bytes.is_empty() => { - let StreamState::Ready(mut stream) = - std::mem::replace(&mut self.input, StreamState::Closed) - else { - unreachable!() - }; - - self.input = StreamState::Pending(Box::pin(async move { - stream.ready().await; - stream - })); - } - Ok(bytes) => { - buf.put_slice(&bytes); - - return Poll::Ready(Ok(())); - } - Err(StreamError::Closed) => { - self.input = StreamState::Closed; - return Poll::Ready(Ok(())); - } - Err(e) => { - self.input = StreamState::Closed; - return Poll::Ready(Err(std::io::Error::other(e))); - } - } - } - } -} - -type TlsWriteHalf = tokio::io::WriteHalf>; - -struct TlsWriter { - state: WriteState, -} - -enum WriteState { - Ready(TlsWriteHalf), - Writing(AbortOnDropJoinHandle>), - Closing(AbortOnDropJoinHandle>), - Closed, - Error(io::Error), -} -const READY_SIZE: usize = 1024 * 1024 * 1024; - -impl TlsWriter { - fn new(stream: TlsWriteHalf) -> Self { - Self { - state: WriteState::Ready(stream), - } - } - - fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> { - let WriteState::Ready(_) = self.state else { - return Err(StreamError::Trap(anyhow::anyhow!( - "unpermitted: must call check_write first" - ))); - }; - - if bytes.is_empty() { - return Ok(()); - } - - let WriteState::Ready(mut stream) = std::mem::replace(&mut self.state, WriteState::Closed) - else { - unreachable!() - }; - - self.state = WriteState::Writing(wasmtime_wasi::runtime::spawn(async move { - while !bytes.is_empty() { - let n = stream.write(&bytes).await?; - let _ = bytes.split_to(n); - } - - Ok(stream) - })); - - Ok(()) - } - - fn flush(&mut self) -> Result<(), StreamError> { - // `flush` is a no-op here, as we're not managing any internal buffer. - match self.state { - WriteState::Ready(_) - | WriteState::Writing(_) - | WriteState::Closing(_) - | WriteState::Error(_) => Ok(()), - WriteState::Closed => Err(StreamError::Closed), - } - } - - fn check_write(&mut self) -> Result { - match &mut self.state { - WriteState::Ready(_) => Ok(READY_SIZE), - WriteState::Writing(_) => Ok(0), - WriteState::Closing(_) => Ok(0), - WriteState::Closed => Err(StreamError::Closed), - WriteState::Error(_) => { - let WriteState::Error(e) = std::mem::replace(&mut self.state, WriteState::Closed) - else { - unreachable!() - }; - - Err(StreamError::LastOperationFailed(e.into())) - } - } - } - - fn close(&mut self) { - match std::mem::replace(&mut self.state, WriteState::Closed) { - // No write in progress, immediately shut down: - WriteState::Ready(mut stream) => { - self.state = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move { - stream.shutdown().await - })); - } - - // Schedule the shutdown after the current write has finished: - WriteState::Writing(write) => { - self.state = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move { - let mut stream = write.await?; - stream.shutdown().await - })); - } - - WriteState::Closing(t) => { - self.state = WriteState::Closing(t); - } - WriteState::Closed | WriteState::Error(_) => {} - } - } - - async fn cancel(&mut self) { - match std::mem::replace(&mut self.state, WriteState::Closed) { - WriteState::Writing(task) => _ = task.cancel().await, - WriteState::Closing(task) => _ = task.cancel().await, - _ => {} - } - } - - async fn ready(&mut self) { - match &mut self.state { - WriteState::Writing(task) => { - self.state = match task.await { - Ok(s) => WriteState::Ready(s), - Err(e) => WriteState::Error(e), - } - } - WriteState::Closing(task) => { - self.state = match task.await { - Ok(()) => WriteState::Closed, - Err(e) => WriteState::Error(e), - } - } - _ => {} - } - } -} - -#[derive(Clone)] -struct AsyncTlsWriteStream(Arc>); - -impl AsyncTlsWriteStream { - fn new(writer: TlsWriter) -> Self { - AsyncTlsWriteStream(Arc::new(Mutex::new(writer))) - } - - fn close(&mut self) -> wasmtime::Result<()> { - try_lock_for_stream(&self.0)?.close(); - Ok(()) - } -} - -#[async_trait] -impl OutputStream for AsyncTlsWriteStream { - fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> { - try_lock_for_stream(&self.0)?.write(bytes) - } - - fn flush(&mut self) -> Result<(), StreamError> { - try_lock_for_stream(&self.0)?.flush() - } - - fn check_write(&mut self) -> Result { - try_lock_for_stream(&self.0)?.check_write() - } - - async fn cancel(&mut self) { - self.0.lock().await.cancel().await - } -} - -#[async_trait] -impl Pollable for AsyncTlsWriteStream { - async fn ready(&mut self) { - self.0.lock().await.ready().await - } -} - -fn try_lock_for_stream( - mutex: &Mutex, -) -> Result, StreamError> { - mutex - .try_lock() - .map_err(|_| StreamError::trap("concurrent access to resource not supported")) -} - -#[cfg(test)] -mod tests { - use super::*; - use std::task::Waker; - use tokio::sync::oneshot; - - #[tokio::test] - async fn test_future_client_streams_ready_can_be_canceled() { - let (tx1, rx1) = oneshot::channel::<()>(); - - let mut future_streams = FutureStreams(StreamState::Pending(Box::pin(async move { - rx1.await - .map_err(|_| TlsError::Trap(anyhow::anyhow!("oneshot canceled"))) - }))); - - let mut fut = future_streams.ready(); - - let mut cx = std::task::Context::from_waker(Waker::noop()); - assert!(fut.as_mut().poll(&mut cx).is_pending()); - - //cancel the readiness check - drop(fut); - - match future_streams.0 { - StreamState::Closed => panic!("First future should be in Pending/ready state"), - _ => (), - } - - // make it ready and wait for it to progress - tx1.send(()).unwrap(); - future_streams.ready().await; - - match future_streams.0 { - StreamState::Ready(Ok(())) => (), - _ => panic!("First future should be in Ready(Err) state"), - } - } -} From 181d138de7d147576e880278646333c1264e78ce Mon Sep 17 00:00:00 2001 From: badeend Date: Sun, 15 Jun 2025 22:32:37 +0200 Subject: [PATCH 02/10] Add `native-tls` backend --- Cargo.lock | 137 ++++++++++++++++++ Cargo.toml | 2 + .../src/bin/tls_sample_application.rs | 2 +- crates/wasi-tls/Cargo.toml | 16 +- crates/wasi-tls/src/client_nativetls.rs | 43 ++++++ .../src/{client.rs => client_rustls.rs} | 0 crates/wasi-tls/src/lib.rs | 13 +- 7 files changed, 207 insertions(+), 6 deletions(-) create mode 100644 crates/wasi-tls/src/client_nativetls.rs rename crates/wasi-tls/src/{client.rs => client_rustls.rs} (100%) diff --git a/Cargo.lock b/Cargo.lock index 6704ca9b9436..e5554312eb2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -627,6 +627,16 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.6" @@ -1370,6 +1380,21 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -2317,6 +2342,23 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "685a9ac4b61f4e728e1d2c6a7844609c16527aeb5e6c865915c08e619c16410f" +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "ndarray" version = "0.15.6" @@ -2452,6 +2494,50 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "openssl" +version = "0.10.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90096e2e47630d78b7d1c20952dc621f957103f8bc2c8359ec81290d75238571" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "openvino" version = "0.8.0" @@ -3005,6 +3091,38 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.17" @@ -3493,6 +3611,16 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.25.0" @@ -3758,6 +3886,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "veri_engine" version = "0.1.0" @@ -4787,10 +4921,13 @@ version = "35.0.0" dependencies = [ "anyhow", "bytes", + "cfg-if", "futures", + "native-tls", "rustls 0.22.4", "test-programs-artifacts", "tokio", + "tokio-native-tls", "tokio-rustls", "wasmtime", "wasmtime-wasi", diff --git a/Cargo.toml b/Cargo.toml index e351b655021d..e27110321f11 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -398,6 +398,8 @@ ittapi = "0.4.0" libm = "0.2.7" tokio-rustls = "0.25.0" rustls = "0.22.0" +tokio-native-tls = "0.3.1" +native-tls = "0.2.11" webpki-roots = "0.26.0" itertools = "0.14.0" base64 = "0.22.1" diff --git a/crates/test-programs/src/bin/tls_sample_application.rs b/crates/test-programs/src/bin/tls_sample_application.rs index 2c570fddedff..622996e1be97 100644 --- a/crates/test-programs/src/bin/tls_sample_application.rs +++ b/crates/test-programs/src/bin/tls_sample_application.rs @@ -55,7 +55,7 @@ fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> { match ClientHandshake::new(BAD_DOMAIN, tcp_input, tcp_output).blocking_finish() { // We're expecting an error regarding the "certificate" is some form or - // another. When we add more TLS backends other than rustls, this naive + // another. When we add more TLS backends this naive // check will likely need to be revisited/expanded: Err(e) if e.to_debug_string().contains("certificate") => Ok(()), diff --git a/crates/wasi-tls/Cargo.toml b/crates/wasi-tls/Cargo.toml index b5f32292aa16..967b422f02ad 100644 --- a/crates/wasi-tls/Cargo.toml +++ b/crates/wasi-tls/Cargo.toml @@ -11,6 +11,11 @@ description = "Wasmtime implementation of the wasi-tls API" [lints] workspace = true +[features] +default = ["rustls"] +rustls = ["dep:rustls", "dep:tokio-rustls", "dep:webpki-roots"] +native-tls = ["dep:native-tls", "dep:tokio-native-tls"] + [dependencies] anyhow = { workspace = true } bytes = { workspace = true } @@ -18,14 +23,17 @@ tokio = { workspace = true, features = [ "net", "rt-multi-thread", "time", - "io-util" + "io-util", ] } wasmtime = { workspace = true, features = ["runtime", "component-model"] } wasmtime-wasi = { workspace = true } -tokio-rustls = { workspace = true } -rustls = { workspace = true } -webpki-roots = { workspace = true } +cfg-if = { workspace = true } +tokio-rustls = { workspace = true, optional = true } +rustls = { workspace = true, optional = true } +webpki-roots = { workspace = true, optional = true } +tokio-native-tls = { workspace = true, optional = true } +native-tls = { workspace = true, optional = true } [dev-dependencies] test-programs-artifacts = { workspace = true } diff --git a/crates/wasi-tls/src/client_nativetls.rs b/crates/wasi-tls/src/client_nativetls.rs new file mode 100644 index 000000000000..a67dff8324dc --- /dev/null +++ b/crates/wasi-tls/src/client_nativetls.rs @@ -0,0 +1,43 @@ +//! A uniform TLS client interface, abstracting away the differences between the +//! `rustls` and `native-tls` implementations. + +use std::io; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// A client TLS handshake configuration object. +/// +/// At the time of writing, there's nothing to configure (yet). +pub struct Handshake { + transport: IO, + server_name: String, +} +impl Handshake +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + /// Create a new handshake. + pub fn new(server_name: String, transport: IO) -> Self { + Self { + server_name, + transport, + } + } + + /// Run the handshake to completion. + pub async fn finish(self) -> io::Result> { + self.finish_core().await.map_err(|e| io::Error::other(e)) + } + + /// Finish the handshake, failing with a native-tls error. + async fn finish_core(self) -> Result, native_tls::Error> { + let connector = native_tls::TlsConnector::new()?; + + let stream = tokio_native_tls::TlsConnector::from(connector) + .connect(&self.server_name, self.transport) + .await?; + Ok(stream) + } +} + +/// A TLS client connection. +pub type Connection = tokio_native_tls::TlsStream; diff --git a/crates/wasi-tls/src/client.rs b/crates/wasi-tls/src/client_rustls.rs similarity index 100% rename from crates/wasi-tls/src/client.rs rename to crates/wasi-tls/src/client_rustls.rs diff --git a/crates/wasi-tls/src/lib.rs b/crates/wasi-tls/src/lib.rs index 22f2e16ad295..52911dd9904a 100644 --- a/crates/wasi-tls/src/lib.rs +++ b/crates/wasi-tls/src/lib.rs @@ -74,10 +74,21 @@ use wasmtime::component::{HasData, ResourceTable}; pub mod bindings; -mod client; mod host; mod io; +cfg_if::cfg_if! { + if #[cfg(feature = "native-tls")] { + mod client_nativetls; + pub(crate) use client_nativetls as client; + } else if #[cfg(feature = "rustls")] { + mod client_rustls; + pub(crate) use client_rustls as client; + } else { + compile_error!("Either the `rustls` or `native-tls` feature must be enabled."); + } +} + pub use bindings::types::LinkOptions; pub use host::{HostClientConnection, HostClientHandshake, HostFutureClientStreams}; From 708b7eb29e385e4aed5fc82dbcc8d2a7c605f805 Mon Sep 17 00:00:00 2001 From: Dave Bakker Date: Mon, 16 Jun 2025 12:03:10 +0200 Subject: [PATCH 03/10] Improve compatibility of `test_tls_sample_application`: It used to perform a "half-close" after sending the HTTP request. This is a TLS1.3+ feature, though Rustls & OpenSSL already supported it for TLS1.2 and lower. Technically, that makes them non-spec compliant, but they chose to align with the semantics of the underlying TCP connection. I suspect the TLS1.3 spec was updated to match what was already happening in reality. Anyhow, SChannel does not support half-closed connections and so the `read` call after the `close_output+shutdown` failed. I've reordered the test to first perform the HTTP conversation and _then_ do the TLS+TCP teardown. --- crates/test-programs/src/bin/tls_sample_application.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/test-programs/src/bin/tls_sample_application.rs b/crates/test-programs/src/bin/tls_sample_application.rs index 622996e1be97..a5d322896089 100644 --- a/crates/test-programs/src/bin/tls_sample_application.rs +++ b/crates/test-programs/src/bin/tls_sample_application.rs @@ -8,7 +8,7 @@ const PORT: u16 = 443; fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> { let request = - format!("GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\n\r\n"); + format!("GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n"); let net = Network::default(); @@ -25,13 +25,13 @@ fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> { tls_output .blocking_write_util(request.as_bytes()) .context("writing http request failed")?; - client_connection - .blocking_close_output(&tls_output) - .context("closing tls connection failed")?; - socket.shutdown(ShutdownType::Send)?; let response = tls_input .blocking_read_to_end() .context("reading http response failed")?; + client_connection + .blocking_close_output(&tls_output) + .context("closing tls connection failed")?; + socket.shutdown(ShutdownType::Both)?; if String::from_utf8(response)?.contains("HTTP/1.1 200 OK") { Ok(()) From d6f4f59b9228bcbe3b3e9e9f4f8b1ea3258c8225 Mon Sep 17 00:00:00 2001 From: Dave Bakker Date: Mon, 16 Jun 2025 12:42:11 +0200 Subject: [PATCH 04/10] Implement flushing on the `AsyncWriteStream`. The `AsyncWriteStream` implementation was copied from the TCP equivalent, which doesn't need flushing. The TLS implementations _do_ maintain an internal buffer, so the `flush` call need to be hooked up. --- crates/wasi-tls/src/io.rs | 46 +++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/crates/wasi-tls/src/io.rs b/crates/wasi-tls/src/io.rs index 399d23eb6763..f2e408b7ef80 100644 --- a/crates/wasi-tls/src/io.rs +++ b/crates/wasi-tls/src/io.rs @@ -264,6 +264,7 @@ where enum WriteState { Ready(IO), Writing(AbortOnDropJoinHandle>), + Flushing(AbortOnDropJoinHandle>), Closing(AbortOnDropJoinHandle>), Closed, Error(io::Error), @@ -306,20 +307,43 @@ where } fn flush(&mut self) -> StreamResult<()> { - // `flush` is a no-op here, as we're not managing any internal buffer. match self { - WriteState::Ready(_) - | WriteState::Writing(_) - | WriteState::Closing(_) - | WriteState::Error(_) => Ok(()), - WriteState::Closed => Err(StreamError::Closed), + // Immediately flush: + WriteState::Ready(_) => { + let WriteState::Ready(mut stream) = std::mem::replace(self, WriteState::Closed) + else { + unreachable!() + }; + *self = WriteState::Flushing(wasmtime_wasi::runtime::spawn(async move { + stream.flush().await?; + Ok(stream) + })); + } + + // Schedule the flush after the current write has finished: + WriteState::Writing(_) => { + let WriteState::Writing(write) = std::mem::replace(self, WriteState::Closed) else { + unreachable!() + }; + *self = WriteState::Flushing(wasmtime_wasi::runtime::spawn(async move { + let mut stream = write.await?; + stream.flush().await?; + Ok(stream) + })); + } + + WriteState::Flushing(_) | WriteState::Closing(_) | WriteState::Error(_) => {} + WriteState::Closed => return Err(StreamError::Closed), } + + Ok(()) } fn check_write(&mut self) -> StreamResult { match self { WriteState::Ready(_) => Ok(READY_SIZE), WriteState::Writing(_) => Ok(0), + WriteState::Flushing(_) => Ok(0), WriteState::Closing(_) => Ok(0), WriteState::Closed => Err(StreamError::Closed), WriteState::Error(_) => { @@ -341,10 +365,10 @@ where })); } - // Schedule the shutdown after the current write has finished: - WriteState::Writing(write) => { + // Schedule the shutdown after the current operation has finished: + WriteState::Writing(op) | WriteState::Flushing(op) => { *self = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move { - let mut stream = write.await?; + let mut stream = op.await?; stream.shutdown().await })); } @@ -358,7 +382,7 @@ where async fn cancel(&mut self) { match std::mem::replace(self, WriteState::Closed) { - WriteState::Writing(task) => _ = task.cancel().await, + WriteState::Writing(task) | WriteState::Flushing(task) => _ = task.cancel().await, WriteState::Closing(task) => _ = task.cancel().await, _ => {} } @@ -366,7 +390,7 @@ where async fn ready(&mut self) { match self { - WriteState::Writing(task) => { + WriteState::Writing(task) | WriteState::Flushing(task) => { *self = match task.await { Ok(s) => WriteState::Ready(s), Err(e) => WriteState::Error(e), From 97f4217445473beb85edc7594865938457d8df8d Mon Sep 17 00:00:00 2001 From: badeend Date: Tue, 17 Jun 2025 13:02:22 +0200 Subject: [PATCH 05/10] Switch to a model that uses runtime configuration. --- Cargo.toml | 5 +- crates/cli-flags/src/lib.rs | 2 + .../src/bin/tls_sample_application.rs | 5 +- crates/wasi-tls/Cargo.toml | 2 +- crates/wasi-tls/src/client_nativetls.rs | 43 -------- crates/wasi-tls/src/client_rustls.rs | 55 ---------- crates/wasi-tls/src/host.rs | 33 +++--- crates/wasi-tls/src/lib.rs | 102 +++++++++++++----- crates/wasi-tls/src/providers/mod.rs | 20 ++++ crates/wasi-tls/src/providers/native_tls.rs | 46 ++++++++ crates/wasi-tls/src/providers/rustls.rs | 51 +++++++++ crates/wasi-tls/tests/main.rs | 66 ++++++++---- src/commands/run.rs | 36 ++++++- 13 files changed, 304 insertions(+), 162 deletions(-) delete mode 100644 crates/wasi-tls/src/client_nativetls.rs delete mode 100644 crates/wasi-tls/src/client_rustls.rs create mode 100644 crates/wasi-tls/src/providers/mod.rs create mode 100644 crates/wasi-tls/src/providers/native_tls.rs create mode 100644 crates/wasi-tls/src/providers/rustls.rs diff --git a/Cargo.toml b/Cargo.toml index e27110321f11..3670e0d7a39f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -234,7 +234,7 @@ wasmtime-wasi-nn = { path = "crates/wasi-nn", version = "35.0.0" } wasmtime-wasi-config = { path = "crates/wasi-config", version = "35.0.0" } wasmtime-wasi-keyvalue = { path = "crates/wasi-keyvalue", version = "35.0.0" } wasmtime-wasi-threads = { path = "crates/wasi-threads", version = "35.0.0" } -wasmtime-wasi-tls = { path = "crates/wasi-tls", version = "35.0.0" } +wasmtime-wasi-tls = { path = "crates/wasi-tls", version = "35.0.0", default-features = false } wasmtime-wast = { path = "crates/wast", version = "=35.0.0" } # Internal Wasmtime-specific crates. @@ -438,6 +438,7 @@ default = [ "wasi-config", "wasi-keyvalue", "wasi-tls", + "wasi-tls-rustls", # Most features of Wasmtime are enabled by default. "wat", @@ -478,6 +479,7 @@ trace-log = ["wasmtime/trace-log"] memory-protection-keys = ["wasmtime-cli-flags/memory-protection-keys"] profile-pulley = ["wasmtime/profile-pulley"] component-model-async = ["wasmtime-cli-flags/component-model-async", "component-model"] +wasi-tls-nativetls = ["wasi-tls", "wasmtime-wasi-tls/nativetls"] # This feature, when enabled, will statically compile out all logging statements # throughout Wasmtime and its dependencies. @@ -490,6 +492,7 @@ disable-logging = ["log/max_level_off", "tracing/max_level_off"] # the internal mapping for what they enable in Wasmtime itself. wasi-nn = ["dep:wasmtime-wasi-nn"] wasi-tls = ["dep:wasmtime-wasi-tls"] +wasi-tls-rustls = ["wasi-tls", "wasmtime-wasi-tls/rustls"] wasi-threads = ["dep:wasmtime-wasi-threads", "threads"] wasi-http = ["component-model", "dep:wasmtime-wasi-http", "dep:tokio", "dep:hyper"] wasi-config = ["dep:wasmtime-wasi-config"] diff --git a/crates/cli-flags/src/lib.rs b/crates/cli-flags/src/lib.rs index ada037e26343..02f1f1e025fd 100644 --- a/crates/cli-flags/src/lib.rs +++ b/crates/cli-flags/src/lib.rs @@ -438,6 +438,8 @@ wasmtime_option_group! { pub tcplisten: Vec, /// Enable support for WASI TLS (Transport Layer Security) imports (experimental) pub tls: Option, + /// Which TLS provider to use for the wasi-tls interface. Either `rustls` or `nativetls`. + pub tls_provider: Option, /// Implement WASI Preview1 using new Preview2 implementation (true, default) or legacy /// implementation (false) pub preview2: Option, diff --git a/crates/test-programs/src/bin/tls_sample_application.rs b/crates/test-programs/src/bin/tls_sample_application.rs index a5d322896089..6fa7a8344262 100644 --- a/crates/test-programs/src/bin/tls_sample_application.rs +++ b/crates/test-programs/src/bin/tls_sample_application.rs @@ -7,8 +7,9 @@ use test_programs::wasi::tls::types::ClientHandshake; const PORT: u16 = 443; fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> { - let request = - format!("GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n"); + let request = format!( + "GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n" + ); let net = Network::default(); diff --git a/crates/wasi-tls/Cargo.toml b/crates/wasi-tls/Cargo.toml index 967b422f02ad..bd09020b5998 100644 --- a/crates/wasi-tls/Cargo.toml +++ b/crates/wasi-tls/Cargo.toml @@ -14,7 +14,7 @@ workspace = true [features] default = ["rustls"] rustls = ["dep:rustls", "dep:tokio-rustls", "dep:webpki-roots"] -native-tls = ["dep:native-tls", "dep:tokio-native-tls"] +nativetls = ["dep:native-tls", "dep:tokio-native-tls"] [dependencies] anyhow = { workspace = true } diff --git a/crates/wasi-tls/src/client_nativetls.rs b/crates/wasi-tls/src/client_nativetls.rs deleted file mode 100644 index a67dff8324dc..000000000000 --- a/crates/wasi-tls/src/client_nativetls.rs +++ /dev/null @@ -1,43 +0,0 @@ -//! A uniform TLS client interface, abstracting away the differences between the -//! `rustls` and `native-tls` implementations. - -use std::io; -use tokio::io::{AsyncRead, AsyncWrite}; - -/// A client TLS handshake configuration object. -/// -/// At the time of writing, there's nothing to configure (yet). -pub struct Handshake { - transport: IO, - server_name: String, -} -impl Handshake -where - IO: AsyncRead + AsyncWrite + Unpin, -{ - /// Create a new handshake. - pub fn new(server_name: String, transport: IO) -> Self { - Self { - server_name, - transport, - } - } - - /// Run the handshake to completion. - pub async fn finish(self) -> io::Result> { - self.finish_core().await.map_err(|e| io::Error::other(e)) - } - - /// Finish the handshake, failing with a native-tls error. - async fn finish_core(self) -> Result, native_tls::Error> { - let connector = native_tls::TlsConnector::new()?; - - let stream = tokio_native_tls::TlsConnector::from(connector) - .connect(&self.server_name, self.transport) - .await?; - Ok(stream) - } -} - -/// A TLS client connection. -pub type Connection = tokio_native_tls::TlsStream; diff --git a/crates/wasi-tls/src/client_rustls.rs b/crates/wasi-tls/src/client_rustls.rs deleted file mode 100644 index f89691700b70..000000000000 --- a/crates/wasi-tls/src/client_rustls.rs +++ /dev/null @@ -1,55 +0,0 @@ -//! A uniform TLS client interface, abstracting away the differences between the -//! `rustls` and `native-tls` implementations. - -use rustls::pki_types::ServerName; -use std::io; -use std::sync::Arc; -use std::sync::LazyLock; -use tokio::io::{AsyncRead, AsyncWrite}; - -/// A client TLS handshake configuration object. -/// -/// At the time of writing, there's nothing to configure (yet). -pub struct Handshake { - transport: IO, - server_name: String, -} -impl Handshake -where - IO: AsyncRead + AsyncWrite + Unpin, -{ - /// Create a new handshake. - pub fn new(server_name: String, transport: IO) -> Self { - Self { - server_name, - transport, - } - } - - /// Run the handshake to completion. - pub async fn finish(self) -> io::Result> { - let domain = ServerName::try_from(self.server_name) - .map_err(|_| io::Error::other("invalid server name"))?; - - let stream = tokio_rustls::TlsConnector::from(Self::client_config()) - .connect(domain, self.transport) - .await?; - Ok(stream) - } - - fn client_config() -> Arc { - static CONFIG: LazyLock> = LazyLock::new(|| { - let roots = rustls::RootCertStore { - roots: webpki_roots::TLS_SERVER_ROOTS.into(), - }; - let config = rustls::ClientConfig::builder() - .with_root_certificates(roots) - .with_no_client_auth(); - Arc::new(config) - }); - Arc::clone(&CONFIG) - } -} - -/// A TLS client connection. -pub type Connection = tokio_rustls::client::TlsStream; diff --git a/crates/wasi-tls/src/host.rs b/crates/wasi-tls/src/host.rs index 784fa3257ad9..24b21ac14631 100644 --- a/crates/wasi-tls/src/host.rs +++ b/crates/wasi-tls/src/host.rs @@ -5,22 +5,22 @@ use wasmtime_wasi::p2::Pollable; use wasmtime_wasi::p2::{DynInputStream, DynOutputStream, DynPollable, IoError}; use crate::{ - WasiTlsCtx, bindings, + TlsStream, TlsTransport, WasiTls, bindings, io::{ AsyncReadStream, AsyncWriteStream, FutureOutput, WasiFuture, WasiStreamReader, WasiStreamWriter, }, }; -impl<'a> bindings::types::Host for WasiTlsCtx<'a> {} - -/// The underlying transport. Typically, this is a TCP input+output stream. -type Transport = tokio::io::Join; +impl<'a> bindings::types::Host for WasiTls<'a> {} /// Represents the ClientHandshake which will be used to configure the handshake -pub struct HostClientHandshake(crate::client::Handshake); +pub struct HostClientHandshake { + server_name: String, + transport: Box, +} -impl<'a> bindings::types::HostClientHandshake for WasiTlsCtx<'a> { +impl<'a> bindings::types::HostClientHandshake for WasiTls<'a> { fn new( &mut self, server_name: String, @@ -33,9 +33,11 @@ impl<'a> bindings::types::HostClientHandshake for WasiTlsCtx<'a> { let reader = WasiStreamReader::new(input); let writer = WasiStreamWriter::new(output); let transport = tokio::io::join(reader, writer); - let handshake = crate::client::Handshake::new(server_name, transport); - Ok(self.table.push(HostClientHandshake(handshake))?) + Ok(self.table.push(HostClientHandshake { + server_name, + transport: Box::new(transport) as Box, + })?) } fn finish( @@ -44,8 +46,13 @@ impl<'a> bindings::types::HostClientHandshake for WasiTlsCtx<'a> { ) -> wasmtime::Result> { let handshake = self.table.delete(this)?; + let connect = self + .ctx + .provider + .connect(handshake.server_name, handshake.transport); + let future = HostFutureClientStreams(WasiFuture::spawn(async move { - let tls_stream = handshake.0.finish().await?; + let tls_stream = connect.await?; let (rx, tx) = tokio::io::split(tls_stream); let write_stream = AsyncWriteStream::new(tx); @@ -78,7 +85,7 @@ impl Pollable for HostFutureClientStreams { } } -impl<'a> bindings::types::HostFutureClientStreams for WasiTlsCtx<'a> { +impl<'a> bindings::types::HostFutureClientStreams for WasiTls<'a> { fn subscribe( &mut self, this: Resource, @@ -134,10 +141,10 @@ impl<'a> bindings::types::HostFutureClientStreams for WasiTlsCtx<'a> { /// Represents the client connection and used to shut down the tls stream pub struct HostClientConnection( - crate::io::AsyncWriteStream>>, + crate::io::AsyncWriteStream>>, ); -impl<'a> bindings::types::HostClientConnection for WasiTlsCtx<'a> { +impl<'a> bindings::types::HostClientConnection for WasiTls<'a> { fn close_output(&mut self, this: Resource) -> wasmtime::Result<()> { self.table.get_mut(&this)?.0.close() } diff --git a/crates/wasi-tls/src/lib.rs b/crates/wasi-tls/src/lib.rs index 52911dd9904a..feddd6c65217 100644 --- a/crates/wasi-tls/src/lib.rs +++ b/crates/wasi-tls/src/lib.rs @@ -13,11 +13,12 @@ //! component::{Linker, ResourceTable}, //! Store, Engine, Result, Config //! }; -//! use wasmtime_wasi_tls::{LinkOptions, WasiTlsCtx}; +//! use wasmtime_wasi_tls::{LinkOptions, WasiTls, WasiTlsCtx, WasiTlsCtxBuilder}; //! //! struct Ctx { //! table: ResourceTable, //! wasi_ctx: WasiCtx, +//! wasi_tls_ctx: WasiTlsCtx, //! } //! //! impl IoView for Ctx { @@ -41,6 +42,11 @@ //! .inherit_network() //! .allow_ip_name_lookup(true) //! .build(), +//! wasi_tls_ctx: WasiTlsCtxBuilder::new() +//! // Optionally, configure a specific TLS provider: +//! // .provider(Box::new(wasmtime_wasi_tls::RustlsProvider::default())) +//! // .provider(Box::new(wasmtime_wasi_tls::NativeTlsProvider::default())) +//! .build(), //! }; //! //! let mut config = Config::new(); @@ -56,7 +62,7 @@ //! let mut opts = LinkOptions::default(); //! opts.tls(true); //! wasmtime_wasi_tls::add_to_linker(&mut linker, &mut opts, |h: &mut Ctx| { -//! WasiTlsCtx::new(&mut h.table) +//! WasiTls::new(&h.wasi_tls_ctx, &mut h.table) //! })?; //! //! // ... use `linker` to instantiate within `store` ... @@ -71,36 +77,28 @@ #![doc(test(attr(deny(warnings))))] #![doc(test(attr(allow(dead_code, unused_variables, unused_mut))))] +use tokio::io::{AsyncRead, AsyncWrite}; use wasmtime::component::{HasData, ResourceTable}; pub mod bindings; mod host; mod io; - -cfg_if::cfg_if! { - if #[cfg(feature = "native-tls")] { - mod client_nativetls; - pub(crate) use client_nativetls as client; - } else if #[cfg(feature = "rustls")] { - mod client_rustls; - pub(crate) use client_rustls as client; - } else { - compile_error!("Either the `rustls` or `native-tls` feature must be enabled."); - } -} +mod providers; pub use bindings::types::LinkOptions; pub use host::{HostClientConnection, HostClientHandshake, HostFutureClientStreams}; +pub use providers::*; -/// Wasi TLS context needed for internal `wasi-tls` state -pub struct WasiTlsCtx<'a> { +/// Capture the state necessary for use in the `wasi-tls` API implementation. +pub struct WasiTls<'a> { + ctx: &'a WasiTlsCtx, table: &'a mut ResourceTable, } -impl<'a> WasiTlsCtx<'a> { +impl<'a> WasiTls<'a> { /// Create a new Wasi TLS context - pub fn new(table: &'a mut ResourceTable) -> Self { - Self { table } + pub fn new(ctx: &'a WasiTlsCtx, table: &'a mut ResourceTable) -> Self { + Self { ctx, table } } } @@ -108,14 +106,70 @@ impl<'a> WasiTlsCtx<'a> { pub fn add_to_linker( l: &mut wasmtime::component::Linker, opts: &mut LinkOptions, - f: fn(&mut T) -> WasiTlsCtx<'_>, + f: fn(&mut T) -> WasiTls<'_>, ) -> anyhow::Result<()> { - bindings::types::add_to_linker::<_, WasiTls>(l, &opts, f)?; + bindings::types::add_to_linker::<_, HasWasiTls>(l, &opts, f)?; Ok(()) } -struct WasiTls; +struct HasWasiTls; +impl HasData for HasWasiTls { + type Data<'a> = WasiTls<'a>; +} -impl HasData for WasiTls { - type Data<'a> = WasiTlsCtx<'a>; +/// Builder-style structure used to create a [`WasiTlsCtx`]. +pub struct WasiTlsCtxBuilder { + provider: Box, } + +impl WasiTlsCtxBuilder { + /// Creates a builder for a new context with default parameters set. + pub fn new() -> Self { + Default::default() + } + + /// Sets the TLS provider to use for this context. + pub fn provider(mut self, provider: Box) -> Self { + self.provider = provider; + self + } + + /// Uses the configured context so far to construct the final [`WasiTlsCtx`]. + pub fn build(self) -> WasiTlsCtx { + WasiTlsCtx { + provider: self.provider, + } + } +} +impl Default for WasiTlsCtxBuilder { + fn default() -> Self { + Self { + provider: Box::new(DefaultProvider::default()), + } + } +} + +/// Wasi TLS context needed for internal `wasi-tls` state. +pub struct WasiTlsCtx { + pub(crate) provider: Box, +} + +/// The data stream that carries the encrypted TLS data. +/// Typically this is a TCP stream. +pub trait TlsTransport: AsyncRead + AsyncWrite + Send + Unpin + 'static {} +impl TlsTransport for T {} + +/// A TLS connection. +pub trait TlsStream: AsyncRead + AsyncWrite + Send + Unpin + 'static {} + +/// A TLS implementation. +pub trait TlsProvider: Send + Sync + 'static { + /// Set up a client TLS connection using the provided `server_name` and `transport`. + fn connect( + &self, + server_name: String, + transport: Box, + ) -> BoxFuture>>; +} + +pub(crate) type BoxFuture = std::pin::Pin + Send>>; diff --git a/crates/wasi-tls/src/providers/mod.rs b/crates/wasi-tls/src/providers/mod.rs new file mode 100644 index 000000000000..97d239a1906a --- /dev/null +++ b/crates/wasi-tls/src/providers/mod.rs @@ -0,0 +1,20 @@ +//! The available TLS providers. + +#[cfg(feature = "nativetls")] +mod native_tls; +#[cfg(feature = "nativetls")] +pub use native_tls::*; +#[cfg(feature = "rustls")] +mod rustls; +#[cfg(feature = "rustls")] +pub use rustls::*; + +cfg_if::cfg_if! { + if #[cfg(feature = "rustls")] { + pub use RustlsProvider as DefaultProvider; + } else if #[cfg(feature = "nativetls")] { + pub use NativeTlsProvider as DefaultProvider; + } else { + compile_error!("At least one TLS provider must be enabled."); + } +} diff --git a/crates/wasi-tls/src/providers/native_tls.rs b/crates/wasi-tls/src/providers/native_tls.rs new file mode 100644 index 000000000000..6404aa3b2d36 --- /dev/null +++ b/crates/wasi-tls/src/providers/native_tls.rs @@ -0,0 +1,46 @@ +//! The `native_tls` provider. + +use std::io; + +use crate::{BoxFuture, TlsProvider, TlsStream, TlsTransport}; + +type NativeTlsStream = tokio_native_tls::TlsStream>; + +impl crate::TlsStream for NativeTlsStream {} + +/// The `native_tls` provider. +pub struct NativeTlsProvider { + _priv: (), +} + +impl TlsProvider for NativeTlsProvider { + fn connect( + &self, + server_name: String, + transport: Box, + ) -> BoxFuture>> { + async fn connect_impl( + server_name: String, + transport: Box, + ) -> Result { + let connector = native_tls::TlsConnector::new()?; + let stream = tokio_native_tls::TlsConnector::from(connector) + .connect(&server_name, transport) + .await?; + Ok(stream) + } + + Box::pin(async move { + let stream = connect_impl(server_name, transport) + .await + .map_err(|e| io::Error::other(e))?; + Ok(Box::new(stream) as Box) + }) + } +} + +impl Default for NativeTlsProvider { + fn default() -> Self { + Self { _priv: () } + } +} diff --git a/crates/wasi-tls/src/providers/rustls.rs b/crates/wasi-tls/src/providers/rustls.rs new file mode 100644 index 000000000000..0d8fcd2cf22a --- /dev/null +++ b/crates/wasi-tls/src/providers/rustls.rs @@ -0,0 +1,51 @@ +//! The `rustls` provider. + +use rustls::pki_types::ServerName; +use std::io; +use std::sync::{Arc, LazyLock}; + +use crate::{BoxFuture, TlsProvider, TlsStream, TlsTransport}; + +impl crate::TlsStream for tokio_rustls::client::TlsStream> {} + +/// The `rustls` provider. +pub struct RustlsProvider { + client_config: Arc, +} + +impl TlsProvider for RustlsProvider { + fn connect( + &self, + server_name: String, + transport: Box, + ) -> BoxFuture>> { + let client_config = Arc::clone(&self.client_config); + Box::pin(async move { + let domain = ServerName::try_from(server_name) + .map_err(|_| io::Error::other("invalid server name"))?; + + let stream = tokio_rustls::TlsConnector::from(client_config) + .connect(domain, transport) + .await?; + Ok(Box::new(stream) as Box) + }) + } +} + +impl Default for RustlsProvider { + fn default() -> Self { + static CONFIG: LazyLock> = LazyLock::new(|| { + let roots = rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.into(), + }; + let config = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth(); + Arc::new(config) + }); + + Self { + client_config: Arc::clone(&CONFIG), + } + } +} diff --git a/crates/wasi-tls/tests/main.rs b/crates/wasi-tls/tests/main.rs index 21fedacb6492..25e60251bc14 100644 --- a/crates/wasi-tls/tests/main.rs +++ b/crates/wasi-tls/tests/main.rs @@ -1,15 +1,15 @@ use anyhow::{Result, anyhow}; -use test_programs_artifacts::{TLS_SAMPLE_APPLICATION_COMPONENT, foreach_tls}; use wasmtime::{ Store, component::{Component, Linker, ResourceTable}, }; use wasmtime_wasi::p2::{IoView, WasiCtx, WasiCtxBuilder, WasiView, bindings::Command}; -use wasmtime_wasi_tls::{LinkOptions, WasiTlsCtx}; +use wasmtime_wasi_tls::{LinkOptions, TlsProvider, WasiTls, WasiTlsCtx, WasiTlsCtxBuilder}; struct Ctx { table: ResourceTable, wasi_ctx: WasiCtx, + wasi_tls_ctx: WasiTlsCtx, } impl IoView for Ctx { @@ -23,7 +23,17 @@ impl WasiView for Ctx { } } -async fn run_wasi(path: &str, ctx: Ctx) -> Result<()> { +async fn run_test(provider: Box, path: &str) -> Result<()> { + let ctx = Ctx { + table: ResourceTable::new(), + wasi_ctx: WasiCtxBuilder::new() + .inherit_stderr() + .inherit_network() + .allow_ip_name_lookup(true) + .build(), + wasi_tls_ctx: WasiTlsCtxBuilder::new().provider(provider).build(), + }; + let engine = test_programs_artifacts::engine(|config| { config.async_support(true); }); @@ -35,7 +45,7 @@ async fn run_wasi(path: &str, ctx: Ctx) -> Result<()> { let mut opts = LinkOptions::default(); opts.tls(true); wasmtime_wasi_tls::add_to_linker(&mut linker, &mut opts, |h: &mut Ctx| { - WasiTlsCtx::new(&mut h.table) + WasiTls::new(&h.wasi_tls_ctx, &mut h.table) })?; let command = Command::instantiate_async(&mut store, &component, &linker).await?; @@ -46,27 +56,41 @@ async fn run_wasi(path: &str, ctx: Ctx) -> Result<()> { .map_err(|()| anyhow!("command returned with failing exit status")) } -macro_rules! assert_test_exists { - ($name:ident) => { - #[expect(unused_imports, reason = "just here to assert it exists")] - use self::$name as _; +macro_rules! test_case { + ($provider:ident, $name:ident) => { + #[tokio::test(flavor = "multi_thread")] + async fn $name() -> anyhow::Result<()> { + super::$name(Box::new(wasmtime_wasi_tls::$provider::default())).await + } }; } -foreach_tls!(assert_test_exists); +#[cfg(feature = "rustls")] +mod rustls { + macro_rules! rustls_test_case { + ($name:ident) => { + test_case!(RustlsProvider, $name); + }; + } + + test_programs_artifacts::foreach_tls!(rustls_test_case); +} + +#[cfg(feature = "nativetls")] +mod native_tls { + macro_rules! native_tls_test_case { + ($name:ident) => { + test_case!(NativeTlsProvider, $name); + }; + } + + test_programs_artifacts::foreach_tls!(native_tls_test_case); +} -#[tokio::test(flavor = "multi_thread")] -async fn tls_sample_application() -> Result<()> { - run_wasi( - TLS_SAMPLE_APPLICATION_COMPONENT, - Ctx { - table: ResourceTable::new(), - wasi_ctx: WasiCtxBuilder::new() - .inherit_stderr() - .inherit_network() - .allow_ip_name_lookup(true) - .build(), - }, +async fn tls_sample_application(provider: Box) -> Result<()> { + run_test( + provider, + test_programs_artifacts::TLS_SAMPLE_APPLICATION_COMPONENT, ) .await } diff --git a/src/commands/run.rs b/src/commands/run.rs index d71a685b72a2..d431cafdca48 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -33,7 +33,7 @@ use wasmtime_wasi_http::{ use wasmtime_wasi_keyvalue::{WasiKeyValue, WasiKeyValueCtx, WasiKeyValueCtxBuilder}; #[cfg(feature = "wasi-tls")] -use wasmtime_wasi_tls::WasiTlsCtx; +use wasmtime_wasi_tls::{WasiTls, WasiTlsCtx}; fn parse_preloads(s: &str) -> Result<(String, PathBuf)> { let parts: Vec<&str> = s.splitn(2, '=').collect(); @@ -993,11 +993,17 @@ impl RunCommand { h.preview2_ctx.as_mut().expect("wasip2 is not configured"); let preview2_ctx = Arc::get_mut(preview2_ctx).unwrap().get_mut().unwrap(); - WasiTlsCtx::new(preview2_ctx.table()) + WasiTls::new( + Arc::get_mut(h.wasi_tls.as_mut().unwrap()).unwrap(), + preview2_ctx.table(), + ) })?; + self.set_wasi_tls_ctx(store)?; } } } + } else if self.run.common.wasi.tls_provider.is_some() { + bail!("`tls-provider` option requires `tls` to be enabled."); } Ok(()) @@ -1059,6 +1065,30 @@ impl RunCommand { Ok(()) } + #[cfg(all(feature = "wasi-tls", feature = "component-model",))] + fn set_wasi_tls_ctx(&self, store: &mut Store) -> Result<()> { + use wasmtime_wasi_tls::*; + + let provider_name = self.run.common.wasi.tls_provider.as_deref(); + let provider: Box = match provider_name { + None => Box::new(DefaultProvider::default()), + #[cfg(feature = "wasi-tls-rustls")] + Some("rustls") => Box::new(RustlsProvider::default()), + #[cfg(feature = "wasi-tls-nativetls")] + Some("nativetls") => Box::new(NativeTlsProvider::default()), + Some(p) => { + bail!( + "Unknown TLS provider: {p}. Either the option does not exist or the binary is not compiled with this feature.", + ); + } + }; + + let ctx = WasiTlsCtxBuilder::new().provider(provider).build(); + + store.data_mut().wasi_tls = Some(Arc::new(ctx)); + Ok(()) + } + #[cfg(feature = "wasi-nn")] fn collect_preloaded_nn_graphs( &self, @@ -1105,6 +1135,8 @@ struct Host { wasi_config: Option>, #[cfg(feature = "wasi-keyvalue")] wasi_keyvalue: Option>, + #[cfg(feature = "wasi-tls")] + wasi_tls: Option>, } impl Host { From 4040567f21929a1d672ffdb1a587c74d5603a035 Mon Sep 17 00:00:00 2001 From: badeend Date: Thu, 19 Jun 2025 10:50:41 +0200 Subject: [PATCH 06/10] Check wasmtime-wasi-tls features in isolation. --- .github/workflows/main.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 11fb15c4bc37..7cd3c3855557 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -417,6 +417,11 @@ jobs: -p wasmtime-c-api --no-default-features -p wasmtime-c-api --no-default-features --features wat -p wasmtime-c-api --no-default-features --features wasi + + - name: wasmtime-wasi-tls + checks: | + -p wasmtime-wasi-tls --no-default-features --features rustls + -p wasmtime-wasi-tls --no-default-features --features nativetls runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 From 240f5619541ae8ad0fa61cbe89831cd0d45c2995 Mon Sep 17 00:00:00 2001 From: badeend Date: Thu, 19 Jun 2025 21:08:14 +0200 Subject: [PATCH 07/10] Clarify intended use of `WasiTlsCtxBuilder::provider` --- crates/wasi-tls/src/lib.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/crates/wasi-tls/src/lib.rs b/crates/wasi-tls/src/lib.rs index feddd6c65217..21d44398c821 100644 --- a/crates/wasi-tls/src/lib.rs +++ b/crates/wasi-tls/src/lib.rs @@ -129,6 +129,11 @@ impl WasiTlsCtxBuilder { } /// Sets the TLS provider to use for this context. + /// + /// By default, this is set to the [`DefaultProvider`] which is picked at + /// compile time based on feature flags. If this crate is compiled with + /// multiple TLS providers, this method can be used to specify the provider + /// at runtime. pub fn provider(mut self, provider: Box) -> Self { self.provider = provider; self From 2f235183d5c53dc0cd61a1dddb4b4c3839b877a2 Mon Sep 17 00:00:00 2001 From: badeend Date: Thu, 26 Jun 2025 21:26:44 +0200 Subject: [PATCH 08/10] Split off native-tls implementation into separate crate --- .github/workflows/main.yml | 5 -- Cargo.lock | 19 ++++- Cargo.toml | 7 +- crates/cli-flags/src/lib.rs | 2 - crates/wasi-tls-nativetls/Cargo.toml | 26 ++++++ crates/wasi-tls-nativetls/src/lib.rs | 82 +++++++++++++++++++ crates/wasi-tls-nativetls/tests/main.rs | 72 ++++++++++++++++ crates/wasi-tls/Cargo.toml | 14 +--- crates/wasi-tls/src/lib.rs | 18 ++-- crates/wasi-tls/src/providers/mod.rs | 20 ----- crates/wasi-tls/src/providers/native_tls.rs | 46 ----------- crates/wasi-tls/src/{providers => }/rustls.rs | 0 crates/wasi-tls/tests/main.rs | 47 +++-------- scripts/publish.rs | 2 + src/commands/run.rs | 30 +------ 15 files changed, 225 insertions(+), 165 deletions(-) create mode 100644 crates/wasi-tls-nativetls/Cargo.toml create mode 100644 crates/wasi-tls-nativetls/src/lib.rs create mode 100644 crates/wasi-tls-nativetls/tests/main.rs delete mode 100644 crates/wasi-tls/src/providers/mod.rs delete mode 100644 crates/wasi-tls/src/providers/native_tls.rs rename crates/wasi-tls/src/{providers => }/rustls.rs (100%) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ba9787a7eff8..9baac1eb0bde 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -393,11 +393,6 @@ jobs: -p wasmtime-c-api --no-default-features -p wasmtime-c-api --no-default-features --features wat -p wasmtime-c-api --no-default-features --features wasi - - - name: wasmtime-wasi-tls - checks: | - -p wasmtime-wasi-tls --no-default-features --features rustls - -p wasmtime-wasi-tls --no-default-features --features nativetls runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/Cargo.lock b/Cargo.lock index 25122fef163e..825c72d74a57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4450,6 +4450,7 @@ dependencies = [ "wasmtime-wasi-nn", "wasmtime-wasi-threads", "wasmtime-wasi-tls", + "wasmtime-wasi-tls-nativetls", "wasmtime-wast", "wast 235.0.0", "wat", @@ -4930,19 +4931,31 @@ version = "35.0.0" dependencies = [ "anyhow", "bytes", - "cfg-if", "futures", - "native-tls", "rustls 0.22.4", "test-programs-artifacts", "tokio", - "tokio-native-tls", "tokio-rustls", "wasmtime", "wasmtime-wasi", "webpki-roots", ] +[[package]] +name = "wasmtime-wasi-tls-nativetls" +version = "35.0.0" +dependencies = [ + "anyhow", + "futures", + "native-tls", + "test-programs-artifacts", + "tokio", + "tokio-native-tls", + "wasmtime", + "wasmtime-wasi", + "wasmtime-wasi-tls", +] + [[package]] name = "wasmtime-wast" version = "35.0.0" diff --git a/Cargo.toml b/Cargo.toml index 199c729ba4fd..6ee50b1ec63c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ wasmtime-wasi = { workspace = true, default-features = true, optional = true } wasmtime-wasi-nn = { workspace = true, optional = true } wasmtime-wasi-config = { workspace = true, optional = true } wasmtime-wasi-tls = { workspace = true, optional = true } +wasmtime-wasi-tls-nativetls = { workspace = true, optional = true } wasmtime-wasi-keyvalue = { workspace = true, optional = true } wasmtime-wasi-threads = { workspace = true, optional = true } wasmtime-wasi-http = { workspace = true, optional = true } @@ -234,7 +235,8 @@ wasmtime-wasi-nn = { path = "crates/wasi-nn", version = "35.0.0" } wasmtime-wasi-config = { path = "crates/wasi-config", version = "35.0.0" } wasmtime-wasi-keyvalue = { path = "crates/wasi-keyvalue", version = "35.0.0" } wasmtime-wasi-threads = { path = "crates/wasi-threads", version = "35.0.0" } -wasmtime-wasi-tls = { path = "crates/wasi-tls", version = "35.0.0", default-features = false } +wasmtime-wasi-tls = { path = "crates/wasi-tls", version = "35.0.0" } +wasmtime-wasi-tls-nativetls = { path = "crates/wasi-tls-nativetls", version = "35.0.0" } wasmtime-wast = { path = "crates/wast", version = "=35.0.0" } # Internal Wasmtime-specific crates. @@ -439,7 +441,6 @@ default = [ "wasi-config", "wasi-keyvalue", "wasi-tls", - "wasi-tls-rustls", # Most features of Wasmtime are enabled by default. "wat", @@ -480,7 +481,6 @@ trace-log = ["wasmtime/trace-log"] memory-protection-keys = ["wasmtime-cli-flags/memory-protection-keys"] profile-pulley = ["wasmtime/profile-pulley"] component-model-async = ["wasmtime-cli-flags/component-model-async", "component-model"] -wasi-tls-nativetls = ["wasi-tls", "wasmtime-wasi-tls/nativetls"] # This feature, when enabled, will statically compile out all logging statements # throughout Wasmtime and its dependencies. @@ -493,7 +493,6 @@ disable-logging = ["log/max_level_off", "tracing/max_level_off"] # the internal mapping for what they enable in Wasmtime itself. wasi-nn = ["dep:wasmtime-wasi-nn"] wasi-tls = ["dep:wasmtime-wasi-tls"] -wasi-tls-rustls = ["wasi-tls", "wasmtime-wasi-tls/rustls"] wasi-threads = ["dep:wasmtime-wasi-threads", "threads"] wasi-http = ["component-model", "dep:wasmtime-wasi-http", "dep:tokio", "dep:hyper"] wasi-config = ["dep:wasmtime-wasi-config"] diff --git a/crates/cli-flags/src/lib.rs b/crates/cli-flags/src/lib.rs index 02f1f1e025fd..ada037e26343 100644 --- a/crates/cli-flags/src/lib.rs +++ b/crates/cli-flags/src/lib.rs @@ -438,8 +438,6 @@ wasmtime_option_group! { pub tcplisten: Vec, /// Enable support for WASI TLS (Transport Layer Security) imports (experimental) pub tls: Option, - /// Which TLS provider to use for the wasi-tls interface. Either `rustls` or `nativetls`. - pub tls_provider: Option, /// Implement WASI Preview1 using new Preview2 implementation (true, default) or legacy /// implementation (false) pub preview2: Option, diff --git a/crates/wasi-tls-nativetls/Cargo.toml b/crates/wasi-tls-nativetls/Cargo.toml new file mode 100644 index 000000000000..6449ec8a7fa5 --- /dev/null +++ b/crates/wasi-tls-nativetls/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "wasmtime-wasi-tls-nativetls" +version.workspace = true +authors.workspace = true +edition.workspace = true +rust-version.workspace = true +repository = "https://github.com/bytecodealliance/wasmtime" +license = "Apache-2.0 WITH LLVM-exception" +description = "Wasmtime implementation of the wasi-tls API, using native-tls for TLS support." + +[lints] +workspace = true + +[dependencies] +wasmtime-wasi-tls = { workspace = true } +tokio = { workspace = true } +tokio-native-tls = { workspace = true } +native-tls = { workspace = true } + +[dev-dependencies] +anyhow = { workspace = true } +test-programs-artifacts = { workspace = true } +wasmtime = { workspace = true, features = ["runtime", "component-model"] } +wasmtime-wasi = { workspace = true } +tokio = { workspace = true, features = ["macros"] } +futures = { workspace = true } diff --git a/crates/wasi-tls-nativetls/src/lib.rs b/crates/wasi-tls-nativetls/src/lib.rs new file mode 100644 index 000000000000..488614512dcb --- /dev/null +++ b/crates/wasi-tls-nativetls/src/lib.rs @@ -0,0 +1,82 @@ +//! The `native_tls` provider. + +use std::{io, pin::pin}; + +use wasmtime_wasi_tls::{TlsProvider, TlsStream, TlsTransport}; + +type BoxFuture = std::pin::Pin + Send>>; + +/// The `native_tls` provider. +pub struct NativeTlsProvider { + _priv: (), +} + +impl TlsProvider for NativeTlsProvider { + fn connect( + &self, + server_name: String, + transport: Box, + ) -> BoxFuture>> { + async fn connect_impl( + server_name: String, + transport: Box, + ) -> Result { + let connector = native_tls::TlsConnector::new()?; + let stream = tokio_native_tls::TlsConnector::from(connector) + .connect(&server_name, transport) + .await?; + Ok(NativeTlsStream(stream)) + } + + Box::pin(async move { + let stream = connect_impl(server_name, transport) + .await + .map_err(|e| io::Error::other(e))?; + Ok(Box::new(stream) as Box) + }) + } +} + +impl Default for NativeTlsProvider { + fn default() -> Self { + Self { _priv: () } + } +} + +struct NativeTlsStream(tokio_native_tls::TlsStream>); + +impl TlsStream for NativeTlsStream {} + +impl tokio::io::AsyncRead for NativeTlsStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + pin!(&mut self.as_mut().0).poll_read(cx, buf) + } +} + +impl tokio::io::AsyncWrite for NativeTlsStream { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + pin!(&mut self.as_mut().0).poll_write(cx, buf) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + pin!(&mut self.as_mut().0).poll_flush(cx) + } + + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + pin!(&mut self.as_mut().0).poll_shutdown(cx) + } +} diff --git a/crates/wasi-tls-nativetls/tests/main.rs b/crates/wasi-tls-nativetls/tests/main.rs new file mode 100644 index 000000000000..d86fe08f6371 --- /dev/null +++ b/crates/wasi-tls-nativetls/tests/main.rs @@ -0,0 +1,72 @@ +use anyhow::{Result, anyhow}; +use wasmtime::{ + Store, + component::{Component, Linker, ResourceTable}, +}; +use wasmtime_wasi::p2::{IoView, WasiCtx, WasiCtxBuilder, WasiView, bindings::Command}; +use wasmtime_wasi_tls::{LinkOptions, WasiTls, WasiTlsCtx, WasiTlsCtxBuilder}; + +struct Ctx { + table: ResourceTable, + wasi_ctx: WasiCtx, + wasi_tls_ctx: WasiTlsCtx, +} + +impl IoView for Ctx { + fn table(&mut self) -> &mut ResourceTable { + &mut self.table + } +} +impl WasiView for Ctx { + fn ctx(&mut self) -> &mut WasiCtx { + &mut self.wasi_ctx + } +} + +async fn run_test(path: &str) -> Result<()> { + let provider = Box::new(wasmtime_wasi_tls_nativetls::NativeTlsProvider::default()); + let ctx = Ctx { + table: ResourceTable::new(), + wasi_ctx: WasiCtxBuilder::new() + .inherit_stderr() + .inherit_network() + .allow_ip_name_lookup(true) + .build(), + wasi_tls_ctx: WasiTlsCtxBuilder::new().provider(provider).build(), + }; + + let engine = test_programs_artifacts::engine(|config| { + config.async_support(true); + }); + let mut store = Store::new(&engine, ctx); + let component = Component::from_file(&engine, path)?; + + let mut linker = Linker::new(&engine); + wasmtime_wasi::p2::add_to_linker_async(&mut linker)?; + let mut opts = LinkOptions::default(); + opts.tls(true); + wasmtime_wasi_tls::add_to_linker(&mut linker, &mut opts, |h: &mut Ctx| { + WasiTls::new(&h.wasi_tls_ctx, &mut h.table) + })?; + + let command = Command::instantiate_async(&mut store, &component, &linker).await?; + command + .wasi_cli_run() + .call_run(&mut store) + .await? + .map_err(|()| anyhow!("command returned with failing exit status")) +} + +macro_rules! assert_test_exists { + ($name:ident) => { + #[expect(unused_imports, reason = "just here to assert it exists")] + use self::$name as _; + }; +} + +test_programs_artifacts::foreach_tls!(assert_test_exists); + +#[tokio::test(flavor = "multi_thread")] +async fn tls_sample_application() -> Result<()> { + run_test(test_programs_artifacts::TLS_SAMPLE_APPLICATION_COMPONENT).await +} diff --git a/crates/wasi-tls/Cargo.toml b/crates/wasi-tls/Cargo.toml index bd09020b5998..be715c5b6d1b 100644 --- a/crates/wasi-tls/Cargo.toml +++ b/crates/wasi-tls/Cargo.toml @@ -11,11 +11,6 @@ description = "Wasmtime implementation of the wasi-tls API" [lints] workspace = true -[features] -default = ["rustls"] -rustls = ["dep:rustls", "dep:tokio-rustls", "dep:webpki-roots"] -nativetls = ["dep:native-tls", "dep:tokio-native-tls"] - [dependencies] anyhow = { workspace = true } bytes = { workspace = true } @@ -27,13 +22,10 @@ tokio = { workspace = true, features = [ ] } wasmtime = { workspace = true, features = ["runtime", "component-model"] } wasmtime-wasi = { workspace = true } -cfg-if = { workspace = true } -tokio-rustls = { workspace = true, optional = true } -rustls = { workspace = true, optional = true } -webpki-roots = { workspace = true, optional = true } -tokio-native-tls = { workspace = true, optional = true } -native-tls = { workspace = true, optional = true } +tokio-rustls = { workspace = true } +rustls = { workspace = true } +webpki-roots = { workspace = true } [dev-dependencies] test-programs-artifacts = { workspace = true } diff --git a/crates/wasi-tls/src/lib.rs b/crates/wasi-tls/src/lib.rs index 21d44398c821..28c96d8a4626 100644 --- a/crates/wasi-tls/src/lib.rs +++ b/crates/wasi-tls/src/lib.rs @@ -43,9 +43,8 @@ //! .allow_ip_name_lookup(true) //! .build(), //! wasi_tls_ctx: WasiTlsCtxBuilder::new() -//! // Optionally, configure a specific TLS provider: -//! // .provider(Box::new(wasmtime_wasi_tls::RustlsProvider::default())) -//! // .provider(Box::new(wasmtime_wasi_tls::NativeTlsProvider::default())) +//! // Optionally, configure a different TLS provider: +//! // .provider(Box::new(wasmtime_wasi_tls_nativetls::NativeTlsProvider::default())) //! .build(), //! }; //! @@ -83,11 +82,11 @@ use wasmtime::component::{HasData, ResourceTable}; pub mod bindings; mod host; mod io; -mod providers; +mod rustls; pub use bindings::types::LinkOptions; pub use host::{HostClientConnection, HostClientHandshake, HostFutureClientStreams}; -pub use providers::*; +pub use rustls::RustlsProvider; /// Capture the state necessary for use in the `wasi-tls` API implementation. pub struct WasiTls<'a> { @@ -128,12 +127,9 @@ impl WasiTlsCtxBuilder { Default::default() } - /// Sets the TLS provider to use for this context. + /// Configure the TLS provider to use for this context. /// - /// By default, this is set to the [`DefaultProvider`] which is picked at - /// compile time based on feature flags. If this crate is compiled with - /// multiple TLS providers, this method can be used to specify the provider - /// at runtime. + /// By default, this is set to the [`RustlsProvider`]. pub fn provider(mut self, provider: Box) -> Self { self.provider = provider; self @@ -149,7 +145,7 @@ impl WasiTlsCtxBuilder { impl Default for WasiTlsCtxBuilder { fn default() -> Self { Self { - provider: Box::new(DefaultProvider::default()), + provider: Box::new(RustlsProvider::default()), } } } diff --git a/crates/wasi-tls/src/providers/mod.rs b/crates/wasi-tls/src/providers/mod.rs deleted file mode 100644 index 97d239a1906a..000000000000 --- a/crates/wasi-tls/src/providers/mod.rs +++ /dev/null @@ -1,20 +0,0 @@ -//! The available TLS providers. - -#[cfg(feature = "nativetls")] -mod native_tls; -#[cfg(feature = "nativetls")] -pub use native_tls::*; -#[cfg(feature = "rustls")] -mod rustls; -#[cfg(feature = "rustls")] -pub use rustls::*; - -cfg_if::cfg_if! { - if #[cfg(feature = "rustls")] { - pub use RustlsProvider as DefaultProvider; - } else if #[cfg(feature = "nativetls")] { - pub use NativeTlsProvider as DefaultProvider; - } else { - compile_error!("At least one TLS provider must be enabled."); - } -} diff --git a/crates/wasi-tls/src/providers/native_tls.rs b/crates/wasi-tls/src/providers/native_tls.rs deleted file mode 100644 index 6404aa3b2d36..000000000000 --- a/crates/wasi-tls/src/providers/native_tls.rs +++ /dev/null @@ -1,46 +0,0 @@ -//! The `native_tls` provider. - -use std::io; - -use crate::{BoxFuture, TlsProvider, TlsStream, TlsTransport}; - -type NativeTlsStream = tokio_native_tls::TlsStream>; - -impl crate::TlsStream for NativeTlsStream {} - -/// The `native_tls` provider. -pub struct NativeTlsProvider { - _priv: (), -} - -impl TlsProvider for NativeTlsProvider { - fn connect( - &self, - server_name: String, - transport: Box, - ) -> BoxFuture>> { - async fn connect_impl( - server_name: String, - transport: Box, - ) -> Result { - let connector = native_tls::TlsConnector::new()?; - let stream = tokio_native_tls::TlsConnector::from(connector) - .connect(&server_name, transport) - .await?; - Ok(stream) - } - - Box::pin(async move { - let stream = connect_impl(server_name, transport) - .await - .map_err(|e| io::Error::other(e))?; - Ok(Box::new(stream) as Box) - }) - } -} - -impl Default for NativeTlsProvider { - fn default() -> Self { - Self { _priv: () } - } -} diff --git a/crates/wasi-tls/src/providers/rustls.rs b/crates/wasi-tls/src/rustls.rs similarity index 100% rename from crates/wasi-tls/src/providers/rustls.rs rename to crates/wasi-tls/src/rustls.rs diff --git a/crates/wasi-tls/tests/main.rs b/crates/wasi-tls/tests/main.rs index 25e60251bc14..3105cee3a517 100644 --- a/crates/wasi-tls/tests/main.rs +++ b/crates/wasi-tls/tests/main.rs @@ -4,7 +4,7 @@ use wasmtime::{ component::{Component, Linker, ResourceTable}, }; use wasmtime_wasi::p2::{IoView, WasiCtx, WasiCtxBuilder, WasiView, bindings::Command}; -use wasmtime_wasi_tls::{LinkOptions, TlsProvider, WasiTls, WasiTlsCtx, WasiTlsCtxBuilder}; +use wasmtime_wasi_tls::{LinkOptions, WasiTls, WasiTlsCtx, WasiTlsCtxBuilder}; struct Ctx { table: ResourceTable, @@ -23,7 +23,7 @@ impl WasiView for Ctx { } } -async fn run_test(provider: Box, path: &str) -> Result<()> { +async fn run_test(path: &str) -> Result<()> { let ctx = Ctx { table: ResourceTable::new(), wasi_ctx: WasiCtxBuilder::new() @@ -31,7 +31,7 @@ async fn run_test(provider: Box, path: &str) -> Result<()> { .inherit_network() .allow_ip_name_lookup(true) .build(), - wasi_tls_ctx: WasiTlsCtxBuilder::new().provider(provider).build(), + wasi_tls_ctx: WasiTlsCtxBuilder::new().build(), }; let engine = test_programs_artifacts::engine(|config| { @@ -56,41 +56,16 @@ async fn run_test(provider: Box, path: &str) -> Result<()> { .map_err(|()| anyhow!("command returned with failing exit status")) } -macro_rules! test_case { - ($provider:ident, $name:ident) => { - #[tokio::test(flavor = "multi_thread")] - async fn $name() -> anyhow::Result<()> { - super::$name(Box::new(wasmtime_wasi_tls::$provider::default())).await - } +macro_rules! assert_test_exists { + ($name:ident) => { + #[expect(unused_imports, reason = "just here to assert it exists")] + use self::$name as _; }; } -#[cfg(feature = "rustls")] -mod rustls { - macro_rules! rustls_test_case { - ($name:ident) => { - test_case!(RustlsProvider, $name); - }; - } - - test_programs_artifacts::foreach_tls!(rustls_test_case); -} - -#[cfg(feature = "nativetls")] -mod native_tls { - macro_rules! native_tls_test_case { - ($name:ident) => { - test_case!(NativeTlsProvider, $name); - }; - } - - test_programs_artifacts::foreach_tls!(native_tls_test_case); -} +test_programs_artifacts::foreach_tls!(assert_test_exists); -async fn tls_sample_application(provider: Box) -> Result<()> { - run_test( - provider, - test_programs_artifacts::TLS_SAMPLE_APPLICATION_COMPONENT, - ) - .await +#[tokio::test(flavor = "multi_thread")] +async fn tls_sample_application() -> Result<()> { + run_test(test_programs_artifacts::TLS_SAMPLE_APPLICATION_COMPONENT).await } diff --git a/scripts/publish.rs b/scripts/publish.rs index dc72e2c3d965..5a9de27bc43e 100644 --- a/scripts/publish.rs +++ b/scripts/publish.rs @@ -80,6 +80,7 @@ const CRATES_TO_PUBLISH: &[&str] = &[ "wasmtime-wasi-keyvalue", "wasmtime-wasi-threads", "wasmtime-wasi-tls", + "wasmtime-wasi-tls-nativetls", "wasmtime-wast", "wasmtime-internal-c-api-macros", "wasmtime-c-api-impl", @@ -99,6 +100,7 @@ const PUBLIC_CRATES: &[&str] = &[ "wasmtime-wasi-io", "wasmtime-wasi", "wasmtime-wasi-tls", + "wasmtime-wasi-tls-nativetls", "wasmtime-wasi-http", "wasmtime-wasi-nn", "wasmtime-wasi-config", diff --git a/src/commands/run.rs b/src/commands/run.rs index d431cafdca48..e4e13d23b7ad 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -998,12 +998,12 @@ impl RunCommand { preview2_ctx.table(), ) })?; - self.set_wasi_tls_ctx(store)?; + + let ctx = wasmtime_wasi_tls::WasiTlsCtxBuilder::new().build(); + store.data_mut().wasi_tls = Some(Arc::new(ctx)); } } } - } else if self.run.common.wasi.tls_provider.is_some() { - bail!("`tls-provider` option requires `tls` to be enabled."); } Ok(()) @@ -1065,30 +1065,6 @@ impl RunCommand { Ok(()) } - #[cfg(all(feature = "wasi-tls", feature = "component-model",))] - fn set_wasi_tls_ctx(&self, store: &mut Store) -> Result<()> { - use wasmtime_wasi_tls::*; - - let provider_name = self.run.common.wasi.tls_provider.as_deref(); - let provider: Box = match provider_name { - None => Box::new(DefaultProvider::default()), - #[cfg(feature = "wasi-tls-rustls")] - Some("rustls") => Box::new(RustlsProvider::default()), - #[cfg(feature = "wasi-tls-nativetls")] - Some("nativetls") => Box::new(NativeTlsProvider::default()), - Some(p) => { - bail!( - "Unknown TLS provider: {p}. Either the option does not exist or the binary is not compiled with this feature.", - ); - } - }; - - let ctx = WasiTlsCtxBuilder::new().provider(provider).build(); - - store.data_mut().wasi_tls = Some(Arc::new(ctx)); - Ok(()) - } - #[cfg(feature = "wasi-nn")] fn collect_preloaded_nn_graphs( &self, From 4f2932ecfde2381ad9b748d6ca0f1bcb57f78d82 Mon Sep 17 00:00:00 2001 From: badeend Date: Thu, 26 Jun 2025 21:43:01 +0200 Subject: [PATCH 09/10] Add a dedicated Ci job to test the `native-tls` backend. prtest:full --- .github/workflows/main.yml | 18 ++++++++++++++++++ Cargo.lock | 1 - Cargo.toml | 2 +- ci/run-tests.py | 4 ++++ 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9baac1eb0bde..8419af629750 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -812,6 +812,23 @@ jobs: # Run the tests! - run: cargo test -p wasmtime-wasi-nn --features ${{ matrix.feature }} + # Test `wasmtime-wasi-tls-nativetls` in its own job. This is because it + # depends on OpenSSL, which is not easily available on all platforms. + test_wasi_tls_nativetls: + name: Test wasi-tls using native-tls provider + needs: determine + if: needs.determine.outputs.run-full + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - uses: ./.github/actions/install-rust + - run: cargo test -p wasmtime-wasi-tls-nativetls + # Test the `wasmtime-fuzzing` crate. Split out from the main tests because # `--all-features` brings in OCaml, which is a pain to get setup for all # targets. @@ -1114,6 +1131,7 @@ jobs: - doc - micro_checks - special_tests + - test_wasi_tls_nativetls - clippy - monolith_checks - platform_checks diff --git a/Cargo.lock b/Cargo.lock index 825c72d74a57..455d40185c33 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4450,7 +4450,6 @@ dependencies = [ "wasmtime-wasi-nn", "wasmtime-wasi-threads", "wasmtime-wasi-tls", - "wasmtime-wasi-tls-nativetls", "wasmtime-wast", "wast 235.0.0", "wat", diff --git a/Cargo.toml b/Cargo.toml index 6ee50b1ec63c..c25bbf33c6c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,7 +55,6 @@ wasmtime-wasi = { workspace = true, default-features = true, optional = true } wasmtime-wasi-nn = { workspace = true, optional = true } wasmtime-wasi-config = { workspace = true, optional = true } wasmtime-wasi-tls = { workspace = true, optional = true } -wasmtime-wasi-tls-nativetls = { workspace = true, optional = true } wasmtime-wasi-keyvalue = { workspace = true, optional = true } wasmtime-wasi-threads = { workspace = true, optional = true } wasmtime-wasi-http = { workspace = true, optional = true } @@ -159,6 +158,7 @@ members = [ "crates/test-programs", "crates/wasi-preview1-component-adapter", "crates/wasi-preview1-component-adapter/verify", + "crates/wasi-tls-nativetls", "examples/fib-debug/wasm", "examples/wasm", "examples/tokio/wasm", diff --git a/ci/run-tests.py b/ci/run-tests.py index 3ddb5983a983..8f354788d63a 100755 --- a/ci/run-tests.py +++ b/ci/run-tests.py @@ -7,6 +7,9 @@ # - wasmtime-wasi-nn: mutually-exclusive features that aren't available for all # targets, needs its own CI job. # +# - wasmtime-wasi-tls-nativetls: the openssl dependency does not play nice with +# cross compilation. This crate is tested in a separate CI job. +# # - wasmtime-fuzzing: enabling all features brings in OCaml which is a pain to # configure for all targets, so it has its own CI job. # @@ -21,6 +24,7 @@ args = ['cargo', 'test', '--workspace', '--all-features'] args.append('--exclude=test-programs') args.append('--exclude=wasmtime-wasi-nn') +args.append('--exclude=wasmtime-wasi-tls-nativetls') args.append('--exclude=wasmtime-fuzzing') args.append('--exclude=wasm-spec-interpreter') args.append('--exclude=veri_engine') From cc23e74c2036acf6583f08084d9f9614a3c99743 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Fri, 27 Jun 2025 07:31:08 -0700 Subject: [PATCH 10/10] Update vets --- supply-chain/imports.lock | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index c172179e5b93..ae8f8f94fcd8 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -622,6 +622,13 @@ user-id = 6743 user-login = "epage" user-name = "Ed Page" +[[publisher.core-foundation]] +version = "0.9.3" +when = "2022-02-07" +user-id = 5946 +user-login = "jrmuizel" +user-name = "Jeff Muizelaar" + [[publisher.core-foundation-sys]] version = "0.8.4" when = "2023-04-03" @@ -931,6 +938,13 @@ user-id = 189 user-login = "BurntSushi" user-name = "Andrew Gallant" +[[publisher.openssl-probe]] +version = "0.1.6" +when = "2025-01-23" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[publisher.prettyplease]] version = "0.2.31" when = "2025-03-13"