diff --git a/Cargo.lock b/Cargo.lock index a33b9d2..f6617c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -332,6 +332,7 @@ dependencies = [ "thiserror 2.0.17", "tokio", "tokio-rustls", + "tokio-vsock", "tower-http", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index e966fba..dc53022 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/ attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/nitro" } pccs = { git = "https://github.com/flashbots/attested-tls", branch = "peg/nitro" } tokio = { version = "1.50.0", features = ["full"] } +tokio-vsock = "0.7.2" tokio-rustls = { version = "0.26.4", default-features = false, features = ["aws_lc_rs"] } x509-parser = { version = "0.18.0", features = ["verify"] } x509-parser-016 = { package = "x509-parser", version = "0.16", features = ["verify"] } diff --git a/README.md b/README.md index e613714..3078f13 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,13 @@ Proxy-client to proxy-server connections use TLS 1.3. The server can expose two - `--inner-listen-addr` exposes the inner attested TLS listener. - `--outer-listen-addr` exposes an optional outer nested-TLS listener that wraps the inner session with a regular PKI TLS session. -At least one of these listeners must be configured. If TLS certificate and key files are provided, they apply only to the outer listener, and `--outer-listen-addr` is required. +The same listeners can be exposed over AWS Nitro VSOCK instead of TCP: + +- `--inner-vsock-port` exposes the inner listener over VSOCK. +- `--outer-vsock-port` exposes the outer listener over VSOCK. +- `--inner-vsock-cid` and `--outer-vsock-cid` default to `VMADDR_CID_ANY`. + +At least one of these listeners must be configured. If TLS certificate and key files are provided, they apply only to the outer listener, and an outer TCP or VSOCK listener is required. When the server runs without an outer listener, the inner attested certificate still needs a DNS identity. In that case, use `--inner-certificate-name` to control the certificate name embedded into the inner attested certificate. If an outer certificate is present, the server derives that identity from the outer certificate instead. @@ -86,6 +92,8 @@ On the client side: - default mode connects to the server's outer listener and verifies the outer PKI certificate before entering the inner attested TLS session - `--inner-session-only` connects directly to the inner attested TLS listener +- `--listen-transport vsock --listen-vsock-port ` makes the local client ingress listener use VSOCK instead of TCP +- `--target-transport vsock --target-vsock-cid --target-vsock-port ` makes the client connect to the proxy server over VSOCK; the positional target remains the TLS server name In both modes, attestation is taken from the peer certificate on the inner TLS session, then enforced against the configured measurement policy. @@ -187,7 +195,50 @@ cargo run -- client \ localhost:7001 ``` -In inner-only mode the client does not accept `--tls-ca-certificate`, `--tls-private-key-path`, or `--tls-certificate-path`. +In inner-only mode the client does not accept `--tls-ca-certificate`. `--tls-certificate-path` and `--tls-private-key-path` may be supplied when the server requires client authentication; the certificate identity is used for the generated inner attested client certificate. + +### Nitro VSOCK Examples + +Expose a server inner listener over VSOCK: + +```bash +cargo run -- server \ + --inner-vsock-port 7001 \ + --inner-certificate-name localhost \ + --server-attestation-type none \ + --allowed-remote-attestation-type none \ + 127.0.0.1:8000 +``` + +Connect a client to that VSOCK inner listener: + +```bash +cargo run -- client \ + --listen-addr 127.0.0.1:6000 \ + --target-transport vsock \ + --target-vsock-cid \ + --target-vsock-port 7001 \ + --inner-session-only \ + --client-attestation-type none \ + --allowed-remote-attestation-type none \ + localhost +``` + +Run the proxy client itself with VSOCK ingress, useful when the client runs inside a Nitro enclave and receives requests from the parent instance: + +```bash +cargo run -- client \ + --listen-transport vsock \ + --listen-vsock-port 6000 \ + --target-transport vsock \ + --target-vsock-cid \ + --target-vsock-port 7001 \ + --inner-session-only \ + --tls-private-key-path client.key \ + --tls-certificate-path client.crt \ + --allowed-remote-attestation-type none \ + localhost +``` ## CLI Differences from `cvm-reverse-proxy` @@ -195,7 +246,7 @@ This aims to have a similar command line interface to `cvm-reverse-proxy`, but t - The measurements file path is specified with `--measurements-file` rather than `--server-measurements` or `--client-measurements`. - If no measurements file is specified, `--allowed-remote-attestation-type` must be given. -- The server splits listener configuration into `--inner-listen-addr` and optional `--outer-listen-addr`. +- The server splits listener configuration into inner and outer listeners, each using either TCP `--*-listen-addr` or VSOCK `--*-vsock-port`. - `--log-dcap-quote` logs remote DCAP quotes into `quotes/`. ## Docker diff --git a/flake.nix b/flake.nix index 09781a0..602384c 100644 --- a/flake.nix +++ b/flake.nix @@ -8,43 +8,133 @@ system = "x86_64-linux"; pkgs = import nixpkgs { inherit system; }; + # Both workspace members share a single Cargo.lock, so their dependency + # hashes are identical. Keeping this in one place means a lockfile bump + # only requires updating hashes here. + # + # Note: mock-tdx-0.0.1 appears twice in the lockfile (peg/nitro transitive + # dep and main-branch dev-dep). The peg/nitro rev is shared with + # attestation-0.0.1 (same SHA → same hash). The main-branch entry is + # stripped by cleanedLockFile below, so no hash entry is needed for it. + sharedOutputHashes = { + "attestation-0.0.1" = "sha256-4wa8gP9xQCZZL4JUnb1fNfpwxcahec5SgYZamdqX2h8="; + "attested-tls-0.0.1" = "sha256-4wa8gP9xQCZZL4JUnb1fNfpwxcahec5SgYZamdqX2h8="; + "cc-eventlog-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; + "cc-eventlog-0.5.8" = "sha256-KEauakj53LrhKTc0yYp5SM8ec0cFNm4YVuHCJYiPQjw="; + "dcap-qvl-0.3.12" = "sha256-rLTp5wIhXRAcBtJb7lfd1TAg7yPRnwa0cBa1YT4LwKU="; + "dstack-attest-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; + "dstack-types-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; + "nested-tls-0.0.1" = "sha256-4wa8gP9xQCZZL4JUnb1fNfpwxcahec5SgYZamdqX2h8="; + "pccs-0.0.1" = "sha256-4wa8gP9xQCZZL4JUnb1fNfpwxcahec5SgYZamdqX2h8="; + "ra-tls-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; + "size-parser-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; + "tdx-attest-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; + "tdx-attest-0.5.8" = "sha256-KEauakj53LrhKTc0yYp5SM8ec0cFNm4YVuHCJYiPQjw="; + }; + + # nixpkgs importCargoLock creates one symlink per package keyed by + # "-". When two git crates share the same name+version + # (here: mock-tdx-0.0.1 from peg/nitro and from main), the second ln + # follows the first symlink into a read-only store path and fails. + # + # The main-branch entry is a dev-dep of attested-tls-proxy only; since + # doCheck = false it is never compiled. Strip it from the lockfile at + # evaluation time so importCargoLock only ever sees the peg/nitro + # transitive dep (already covered by the attestation-0.0.1 hash above). + cleanedLockFile = builtins.toFile "Cargo.lock" ( + builtins.replaceStrings + [ + # [[package]] block for mock-tdx (main branch). + # The leading \n eats the blank separator line before the block; + # the trailing blank line before peg/nitro mock-tdx is preserved. + "\n[[package]]\nname = \"mock-tdx\"\nversion = \"0.0.1\"\nsource = \"git+https://github.com/flashbots/attested-tls?branch=main#eaa10f0528c8c561273717913596de65cff807b3\"\ndependencies = [\n \"axum\",\n \"dcap-qvl\",\n \"hex\",\n \"p256\",\n \"parity-scale-codec\",\n \"rcgen 0.14.7\",\n \"serde\",\n \"serde-saphyr\",\n \"serde_bytes\",\n \"serde_json\",\n \"sha2\",\n \"time\",\n \"tokio\",\n \"urlencoding\",\n \"x509-parser 0.18.1\",\n \"yasna 0.5.2\",\n]\n" + # Dep reference in the attested-tls-proxy package entry + # (Cargo.lock dep references omit the #rev suffix) + " \"mock-tdx 0.0.1 (git+https://github.com/flashbots/attested-tls?branch=main)\",\n" + ] + [ "" "" ] + (builtins.readFile ./Cargo.lock) + ); + + # Vendor directory built from the cleaned lockfile (no mock-tdx main branch). + sharedCargoDeps = pkgs.rustPlatform.importCargoLock { + lockFile = cleanedLockFile; + outputHashes = sharedOutputHashes; + }; + + # Patch the unpacked source to match the vendor dir: + # cargo reads the source's Cargo.lock at build time and requires it to + # be consistent with what is vendored. + sharedPostUnpack = '' + cp ${cleanedLockFile} "$sourceRoot/Cargo.lock" + chmod u+w "$sourceRoot/Cargo.lock" + sed -i '/^mock-tdx/d' "$sourceRoot/Cargo.toml" + ''; + + sharedBuildInputs = [ pkgs.openssl pkgs.tpm2-tss ]; + sharedNativeBuildInputs = [ pkgs.pkg-config ]; + server = pkgs.rustPlatform.buildRustPackage { pname = "attestation-provider-server"; version = "1.1.1"; src = ./.; - cargoLock = { - lockFile = ./Cargo.lock; - outputHashes = { - "attestation-0.0.1" = "sha256-1I9iQcFNt02fHs8Q18LK2+f8U0TzhfdFz7JvV0mKJUw="; - "attested-tls-0.0.1" = "sha256-1I9iQcFNt02fHs8Q18LK2+f8U0TzhfdFz7JvV0mKJUw="; - "cc-eventlog-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; - "cc-eventlog-0.5.8" = "sha256-KEauakj53LrhKTc0yYp5SM8ec0cFNm4YVuHCJYiPQjw="; - "dcap-qvl-0.3.12" = "sha256-rLTp5wIhXRAcBtJb7lfd1TAg7yPRnwa0cBa1YT4LwKU="; - "dstack-attest-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; - "dstack-types-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; - "nested-tls-0.0.1" = "sha256-1I9iQcFNt02fHs8Q18LK2+f8U0TzhfdFz7JvV0mKJUw="; - "pccs-0.0.1" = "sha256-1I9iQcFNt02fHs8Q18LK2+f8U0TzhfdFz7JvV0mKJUw="; - "ra-tls-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; - "size-parser-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; - "tdx-attest-0.5.11" = "sha256-q6Vrlx4N7Ce2EQTQH+0HCSEzFZmY8PzDHxrO8L3kMsQ="; - "tdx-attest-0.5.8" = "sha256-KEauakj53LrhKTc0yYp5SM8ec0cFNm4YVuHCJYiPQjw="; - }; - }; + cargoDeps = sharedCargoDeps; + postUnpack = sharedPostUnpack; cargoBuildFlags = [ "-p" "attestation-provider-server" ]; - cargoHash = "sha256-rLTp5wIhXRAcBtJb7lfd1TAg7yPRnwa0cBa1YT4LwKU="; - nativeBuildInputs = [ pkgs.pkg-config ]; - buildInputs = [ pkgs.openssl pkgs.tpm2-tss ]; + nativeBuildInputs = sharedNativeBuildInputs; + buildInputs = sharedBuildInputs; + + doCheck = false; + }; + + proxy = pkgs.rustPlatform.buildRustPackage { + pname = "attested-tls-proxy"; + version = "1.1.1"; + src = ./.; + + cargoDeps = sharedCargoDeps; + postUnpack = sharedPostUnpack; + cargoBuildFlags = [ "-p" "attested-tls-proxy" ]; + + nativeBuildInputs = sharedNativeBuildInputs; + buildInputs = sharedBuildInputs; doCheck = false; }; - imageRoot = pkgs.buildEnv { + serverImageRoot = pkgs.buildEnv { name = "attestation-provider-server-image-root"; paths = [ server pkgs.cacert ]; pathsToLink = [ "/bin" "/etc/ssl/certs" ]; }; + + proxyImageRoot = pkgs.buildEnv { + name = "attested-tls-proxy-image-root"; + paths = [ proxy pkgs.cacert ]; + pathsToLink = [ "/bin" "/etc/ssl/certs" ]; + }; + + # A single text file at /srv/hello.txt for testing the file server image. + # writeTextDir "srv/hello.txt" produces $out/srv/hello.txt, which buildEnv + # links into /srv/hello.txt inside the image. + testContent = pkgs.writeTextDir "srv/hello.txt" + "Hello from attested-file-server!\n"; + + # Nitro enclaves don't bring up the loopback interface by default. + # The file server binds axum on 127.0.0.1 and the proxy connects back to + # it over loopback, so lo must be up before the binary starts. + fileServerEntrypoint = pkgs.writeShellScriptBin "attested-file-server-start" '' + ${pkgs.iproute2}/bin/ip link set lo up + exec ${proxy}/bin/attested-tls-proxy "$@" + ''; + + fileServerImageRoot = pkgs.buildEnv { + name = "attested-file-server-image-root"; + paths = [ fileServerEntrypoint pkgs.cacert testContent ]; + pathsToLink = [ "/bin" "/etc/ssl/certs" "/srv" ]; + }; in { packages.${system} = { @@ -52,7 +142,7 @@ attestation-provider-server-image = pkgs.dockerTools.buildLayeredImage { name = "attestation-provider-server"; tag = "latest"; - contents = [ imageRoot ]; + contents = [ serverImageRoot ]; config = { Cmd = [ "/bin/attestation-provider-server" @@ -66,6 +156,58 @@ ]; }; }; + + attested-tls-proxy = proxy; + attested-tls-proxy-server-image = pkgs.dockerTools.buildLayeredImage { + name = "attested-tls-proxy-server"; + tag = "latest"; + contents = [ proxyImageRoot ]; + config = { + # Global flags must precede the subcommand. + # --allowed-remote-attestation-type satisfies the mandatory CLI + # requirement; it only takes effect when --client-auth is passed. + # target_addr is a required positional arg supplied via Cmd so it + # can be overridden at runtime: + # docker run attested-tls-proxy-server 127.0.0.1:8080 + Entrypoint = [ + "/bin/attested-tls-proxy" + "--allowed-remote-attestation-type" + "aws-nitro" + "server" + "--server-attestation-type" + "aws-nitro" + "--inner-vsock-port" + "8001" + ]; + Cmd = [ "127.0.0.1:3000" ]; + }; + }; + + attested-file-server-image = pkgs.dockerTools.buildLayeredImage { + name = "attested-file-server"; + tag = "latest"; + contents = [ fileServerImageRoot ]; + config = { + # attested-file-server starts an internal HTTP server on a random + # loopback port, then wraps it with an attested TLS listener. + # path_to_serve (/srv) is the positional arg after the subcommand. + # --inner-listen-addr is TCP-only (no vsock on this subcommand). + # Retrieve the test file with: + # attested-tls-proxy ... attested-get :8002 --url-path /hello.txt + Entrypoint = [ + "/bin/attested-file-server-start" + "--allowed-remote-attestation-type" + "aws-nitro" + "attested-file-server" + "/srv" + "--server-attestation-type" + "aws-nitro" + "--inner-vsock-port" + "8002" + ]; + }; + }; + default = self.packages.${system}.attestation-provider-server-image; }; diff --git a/src/file_server.rs b/src/file_server.rs index 6d55b21..4f1c4e1 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -1,7 +1,7 @@ //! Static HTTP file server provided by an attested TLS proxy server use crate::{ AttestationGenerator, AttestationVerifier, OuterTlsConfig, OuterTlsMode, ProxyError, - ProxyServer, TlsCertAndKey, + ProxyListenAddr, ProxyServer, TlsCertAndKey, }; use std::{net::SocketAddr, path::PathBuf}; use tokio::net::ToSocketAddrs; @@ -13,10 +13,10 @@ pub struct AttestedFileServerConfig { pub path_to_serve: PathBuf, /// TLS certificate and key for the optional outer listener pub outer_cert_and_key: Option, - /// Bind address for the optional outer nested-TLS listener - pub outer_listen_addr: Option, - /// Bind address for the optional inner attested-TLS listener - pub inner_listen_addr: Option, + /// Bind address (TCP or vsock) for the optional outer nested-TLS listener + pub outer_listen_addr: Option>, + /// Bind address (TCP or vsock) for the optional inner attested-TLS listener + pub inner_listen_addr: Option>, /// Certificate name to embed in the inner attested certificate pub inner_certificate_name: Option, /// Attestation generator used by the proxy server @@ -55,7 +55,7 @@ where (None, None) => None, }; - let server = ProxyServer::new( + let server = ProxyServer::new_with_listeners( outer_session, inner_listen_addr, inner_certificate_name, diff --git a/src/http_version.rs b/src/http_version.rs index 157a6d9..23c3814 100644 --- a/src/http_version.rs +++ b/src/http_version.rs @@ -8,8 +8,8 @@ pub const ALPN_H2: &[u8] = b"h2"; pub const ALPN_HTTP11: &[u8] = b"http/1.1"; type ProxyClientTlsStream = - tokio_rustls::client::TlsStream>; -type ProxyClientInnerOnlyTlsStream = tokio_rustls::client::TlsStream; + tokio_rustls::client::TlsStream>; +type ProxyClientInnerOnlyTlsStream = tokio_rustls::client::TlsStream; /// Supported HTTP versions #[derive(Debug)] diff --git a/src/lib.rs b/src/lib.rs index 87f3423..2227293 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,9 +22,17 @@ use hyper_util::rt::TokioIo; use nested_tls::{ client::NestingTlsConnector, server::NestingTlsAcceptor, server::NestingTlsStream, }; -use std::{net::SocketAddr, num::TryFromIntError, sync::Arc, time::Duration}; +use std::{ + fmt, + net::SocketAddr, + num::TryFromIntError, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; use thiserror::Error; -use tokio::io::{self, AsyncWriteExt}; +use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; @@ -33,6 +41,7 @@ use tokio_rustls::rustls::{ pki_types::{CertificateDer, PrivateKeyDer, ServerName}, }; use tokio_rustls::{TlsAcceptor, TlsConnector}; +use tokio_vsock::{VsockAddr, VsockListener, VsockStream}; use tracing::{debug, error, warn}; use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion}; @@ -62,8 +71,271 @@ type RequestWithResponseSender = ( oneshot::Sender>, hyper::Error>>, ); -type OuterProxySession = (Arc, NestingTlsAcceptor); -type InnerProxySession = (Arc, TlsAcceptor); +type OuterProxySession = (Arc, NestingTlsAcceptor); +type InnerProxySession = (Arc, TlsAcceptor); + +/// Address to bind for an incoming proxy listener. +#[derive(Debug, Clone, Copy)] +pub enum ProxyListenAddr { + /// Bind a TCP listener. + Tcp(A), + /// Bind an AF_VSOCK listener. + Vsock { + /// Local CID to bind. Use `tokio_vsock::VMADDR_CID_ANY` for the usual Nitro listener case. + cid: u32, + /// Local VSOCK port to bind. + port: u32, + }, +} + +impl ProxyListenAddr { + pub fn tcp(addr: A) -> Self { + Self::Tcp(addr) + } + + pub fn vsock(cid: u32, port: u32) -> Self { + Self::Vsock { cid, port } + } +} + +/// Remote proxy endpoint for the proxy client to connect to. +#[derive(Debug, Clone)] +pub enum ProxyConnectTarget { + /// Connect to a TCP host:port endpoint. + Tcp(A), + /// Connect to a VSOCK endpoint and use `server_name` for TLS certificate verification. + Vsock { + /// Remote VSOCK CID. + cid: u32, + /// Remote VSOCK port. + port: u32, + /// TLS server name to verify. + server_name: String, + }, +} + +impl ProxyConnectTarget { + pub fn tcp(addr: A) -> Self { + Self::Tcp(addr) + } + + pub fn vsock(cid: u32, port: u32, server_name: impl Into) -> Self { + Self::Vsock { + cid, + port, + server_name: server_name.into(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ProxyAddr { + Tcp(SocketAddr), + Vsock { cid: u32, port: u32 }, +} + +impl ProxyAddr { + fn as_tcp(self) -> io::Result { + match self { + Self::Tcp(addr) => Ok(addr), + Self::Vsock { .. } => Err(io::Error::other("listener is not a TCP listener")), + } + } +} + +impl fmt::Display for ProxyAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Tcp(addr) => write!(f, "{addr}"), + Self::Vsock { cid, port } => write!(f, "vsock:{cid}:{port}"), + } + } +} + +impl From for ProxyAddr { + fn from(addr: SocketAddr) -> Self { + Self::Tcp(addr) + } +} + +impl From for ProxyAddr { + fn from(addr: VsockAddr) -> Self { + Self::Vsock { + cid: addr.cid(), + port: addr.port(), + } + } +} + +#[derive(Debug, Clone, Copy)] +enum PeerAddr { + Tcp(SocketAddr), + Vsock, +} + +impl PeerAddr { + fn forwarded_ip(self) -> Option { + match self { + Self::Tcp(addr) => Some(addr.ip().to_string()), + Self::Vsock => None, + } + } +} + +impl From for PeerAddr { + fn from(addr: SocketAddr) -> Self { + Self::Tcp(addr) + } +} + +impl From for PeerAddr { + fn from(_addr: VsockAddr) -> Self { + Self::Vsock + } +} + +#[derive(Debug)] +enum ProxyListener { + Tcp(TcpListener), + Vsock(VsockListener), +} + +impl ProxyListener { + async fn bind(addr: ProxyListenAddr) -> io::Result + where + A: ToSocketAddrs, + { + match addr { + ProxyListenAddr::Tcp(addr) => TcpListener::bind(addr).await.map(Self::Tcp), + ProxyListenAddr::Vsock { cid, port } => { + VsockListener::bind(VsockAddr::new(cid, port)).map(Self::Vsock) + } + } + } + + async fn accept(&self) -> io::Result<(TransportStream, PeerAddr)> { + match self { + Self::Tcp(listener) => { + let (stream, addr) = listener.accept().await?; + Ok((TransportStream::Tcp { inner: stream }, addr.into())) + } + Self::Vsock(listener) => { + let (stream, addr) = listener.accept().await?; + Ok((TransportStream::Vsock { inner: stream }, addr.into())) + } + } + } + + fn local_addr(&self) -> io::Result { + match self { + Self::Tcp(listener) => listener.local_addr().map(ProxyAddr::from), + Self::Vsock(listener) => listener.local_addr().map(ProxyAddr::from), + } + } +} + +pin_project_lite::pin_project! { + #[project = TransportStreamProj] + #[derive(Debug)] + pub(crate) enum TransportStream { + Tcp { #[pin] inner: TcpStream }, + Vsock { #[pin] inner: VsockStream }, + } +} + +impl TransportStream { + async fn connect(target: &ProxyConnectAddr) -> io::Result { + match target { + ProxyConnectAddr::Tcp(target) => TcpStream::connect(target) + .await + .map(|inner| Self::Tcp { inner }), + ProxyConnectAddr::Vsock { cid, port, .. } => { + VsockStream::connect(VsockAddr::new(*cid, *port)) + .await + .map(|inner| Self::Vsock { inner }) + } + } + } +} + +impl AsyncRead for TransportStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.project() { + TransportStreamProj::Tcp { inner } => inner.poll_read(cx, buf), + TransportStreamProj::Vsock { inner } => inner.poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TransportStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.project() { + TransportStreamProj::Tcp { inner } => inner.poll_write(cx, buf), + TransportStreamProj::Vsock { inner } => inner.poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + TransportStreamProj::Tcp { inner } => inner.poll_flush(cx), + TransportStreamProj::Vsock { inner } => inner.poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + TransportStreamProj::Tcp { inner } => inner.poll_shutdown(cx), + TransportStreamProj::Vsock { inner } => inner.poll_shutdown(cx), + } + } +} + +#[derive(Debug, Clone)] +enum ProxyConnectAddr { + Tcp(String), + Vsock { + cid: u32, + port: u32, + server_name: String, + }, +} + +impl ProxyConnectAddr { + fn from_target(target: ProxyConnectTarget) -> Self + where + A: ToString, + { + match target { + ProxyConnectTarget::Tcp(target) => { + Self::Tcp(host_to_host_with_port(&target.to_string())) + } + ProxyConnectTarget::Vsock { + cid, + port, + server_name, + } => Self::Vsock { + cid, + port, + server_name: host_to_host_with_port(&server_name), + }, + } + } + + fn server_name(&self) -> &str { + match self { + Self::Tcp(target) => target, + Self::Vsock { server_name, .. } => server_name, + } + } +} #[derive(Clone)] enum ProxyTlsConnector { @@ -102,10 +374,7 @@ pub enum OuterTlsMode { }, } -impl OuterTlsConfig -where - A: ToSocketAddrs, -{ +impl OuterTlsConfig { fn certificate_name(&self) -> Result { match &self.tls { OuterTlsMode::CertAndKey(cert_and_key) => { @@ -117,11 +386,23 @@ where } } + fn map_listen_addr(self, map: impl FnOnce(A) -> B) -> OuterTlsConfig { + OuterTlsConfig { + listen_addr: map(self.listen_addr), + tls: self.tls, + } + } +} + +impl OuterTlsConfig> +where + A: ToSocketAddrs, +{ async fn into_listener_and_acceptor( self, inner_server_config: Arc, client_auth: bool, - ) -> Result<(Arc, NestingTlsAcceptor), ProxyError> { + ) -> Result<(Arc, NestingTlsAcceptor), ProxyError> { let listen_addr = self.listen_addr; let outer_server_config = match self.tls { OuterTlsMode::CertAndKey(cert_and_key) => { @@ -148,7 +429,7 @@ where OuterTlsMode::Preconfigured { server_config, .. } => server_config, }; - let outer_listener = Arc::new(TcpListener::bind(listen_addr).await?); + let outer_listener = Arc::new(ProxyListener::bind(listen_addr).await?); let outer_tls_acceptor = NestingTlsAcceptor::new(Arc::new(outer_server_config), inner_server_config); @@ -245,6 +526,32 @@ impl ProxyServer { attestation_verifier: AttestationVerifier, client_auth: bool, ) -> Result + where + O: ToSocketAddrs, + I: ToSocketAddrs, + { + Self::new_with_listeners( + outer_session.map(|outer_session| outer_session.map_listen_addr(ProxyListenAddr::Tcp)), + inner_local.map(ProxyListenAddr::Tcp), + inner_certificate_name, + target, + attestation_generator, + attestation_verifier, + client_auth, + ) + .await + } + + /// Start with dual listeners, each of which can be TCP or VSOCK. + pub async fn new_with_listeners( + outer_session: Option>>, + inner_local: Option>, + inner_certificate_name: Option, + target: String, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + client_auth: bool, + ) -> Result where O: ToSocketAddrs, I: ToSocketAddrs, @@ -269,7 +576,7 @@ impl ProxyServer { ); let inner = match inner_local { Some(inner_local) => { - let inner_listener = Arc::new(TcpListener::bind(inner_local).await?); + let inner_listener = Arc::new(ProxyListener::bind(inner_local).await?); let inner_tls_acceptor = TlsAcceptor::from(inner_server_config.clone()); Some((inner_listener, inner_tls_acceptor)) } @@ -383,8 +690,8 @@ impl ProxyServer { Ok(join_handle) } - /// Helper to get the socket address of either underlying TCP listener - pub fn local_addr(&self) -> std::io::Result { + /// Helper to get the transport address of either underlying listener + pub fn local_proxy_addr(&self) -> std::io::Result { match &self.outer { Some((listener, _)) => listener.local_addr(), None => self @@ -396,26 +703,47 @@ impl ProxyServer { } } - /// Helper to get the socket address of the underlying outer TCP listener if present - pub fn outer_local_addr(&self) -> std::io::Result> { + /// Helper to get the TCP socket address of either underlying TCP listener. + /// + /// Returns an error when the selected listener is VSOCK. + pub fn local_addr(&self) -> std::io::Result { + self.local_proxy_addr()?.as_tcp() + } + + /// Helper to get the transport address of the underlying outer listener if present + pub fn outer_local_proxy_addr(&self) -> std::io::Result> { self.outer .as_ref() .map(|(listener, _)| listener.local_addr()) .transpose() } - /// Helper to get the socket address of the underlying inner TCP listener if present - pub fn inner_local_addr(&self) -> std::io::Result> { + /// Helper to get the socket address of the underlying outer TCP listener if present + pub fn outer_local_addr(&self) -> std::io::Result> { + self.outer_local_proxy_addr()? + .map(ProxyAddr::as_tcp) + .transpose() + } + + /// Helper to get the transport address of the underlying inner listener if present + pub fn inner_local_proxy_addr(&self) -> std::io::Result> { self.inner .as_ref() .map(|(listener, _)| listener.local_addr()) .transpose() } + /// Helper to get the socket address of the underlying inner TCP listener if present + pub fn inner_local_addr(&self) -> std::io::Result> { + self.inner_local_proxy_addr()? + .map(ProxyAddr::as_tcp) + .transpose() + } + async fn handle_outer_connection( - tls_stream: NestingTlsStream, + tls_stream: NestingTlsStream, target: String, - client_addr: SocketAddr, + client_addr: PeerAddr, ) -> Result<(), ProxyError> { debug!("[proxy-server] accepted connection"); @@ -446,9 +774,9 @@ impl ProxyServer { } async fn handle_inner_connection( - tls_stream: tokio_rustls::server::TlsStream, + tls_stream: tokio_rustls::server::TlsStream, target: String, - client_addr: SocketAddr, + client_addr: PeerAddr, ) -> Result<(), ProxyError> { debug!("[proxy-server] accepted inner-only connection"); @@ -478,7 +806,7 @@ impl ProxyServer { tls_stream: IO, http_version: HttpVersion, target: String, - client_addr: SocketAddr, + client_addr: PeerAddr, attestation: Option, ) -> Result<(), ProxyError> where @@ -507,20 +835,21 @@ impl ProxyServer { let old_value = update_header(headers, &http::header::HOST, &target); debug!("Updating Host header - old value: {old_value:?} new value: {target}",); - // Add the x-real-ip header - let client_ip = client_addr.ip().to_string(); - update_header(headers, &X_REAL_IP, &client_ip); + if let Some(client_ip) = client_addr.forwarded_ip() { + // Add the x-real-ip header + update_header(headers, &X_REAL_IP, &client_ip); - // Add or update the x-forwarded-for header - let new_x_forwarded_for = - match headers.get(&X_FORWARDED_FOR).and_then(|v| v.to_str().ok()) { - Some(existing) if !existing.trim().is_empty() => { - format!("{}, {}", existing.trim(), client_ip) - } - _ => client_ip.clone(), - }; + // Add or update the x-forwarded-for header + let new_x_forwarded_for = + match headers.get(&X_FORWARDED_FOR).and_then(|v| v.to_str().ok()) { + Some(existing) if !existing.trim().is_empty() => { + format!("{}, {}", existing.trim(), client_ip) + } + _ => client_ip.clone(), + }; - update_header(headers, &X_FORWARDED_FOR, &new_x_forwarded_for); + update_header(headers, &X_FORWARDED_FOR, &new_x_forwarded_for); + } // Strip any caller-provided attestation metadata before injecting authenticated values. headers.remove(ATTESTATION_TYPE_HEADER); @@ -634,8 +963,8 @@ fn full>(chunk: T) -> BoxBody { /// A proxy client which forwards http traffic to a proxy-server #[derive(Debug)] pub struct ProxyClient { - /// The underlying TCP listener - listener: TcpListener, + /// The underlying local listener. + listener: ProxyListener, /// A channel for sending requests to the connection to the proxy-server requests_tx: mpsc::Sender, } @@ -650,6 +979,30 @@ impl ProxyClient { attestation_verifier: AttestationVerifier, remote_certificate: Option>, ) -> Result { + Self::new_with_transport( + cert_and_key, + ProxyListenAddr::Tcp(address), + ProxyConnectTarget::Tcp(server_name), + attestation_generator, + attestation_verifier, + remote_certificate, + ) + .await + } + + /// Start with optional TLS client auth and TCP or VSOCK transports. + pub async fn new_with_transport( + cert_and_key: Option, + listen_addr: ProxyListenAddr, + target: ProxyConnectTarget, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + remote_certificate: Option>, + ) -> Result + where + L: ToSocketAddrs, + T: ToString, + { let root_store = match remote_certificate.as_ref() { Some(remote_certificate) => { let mut root_store = RootCertStore::empty(); @@ -672,10 +1025,10 @@ impl ProxyClient { .with_no_client_auth() }; - Self::new_with_tls_config( + Self::new_with_transport_tls_config( outer_client_config, - address, - server_name, + listen_addr, + target, attestation_generator, attestation_verifier, cert_and_key.map(|cert_and_key| cert_and_key.cert_chain), @@ -723,10 +1076,62 @@ impl ProxyClient { let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - Self::new_with_connector( - address, + Self::new_with_transport_connector( + ProxyListenAddr::Tcp(address), + ProxyConnectTarget::Tcp(target_name), + ProxyTlsConnector::Nested(nesting_tls_connector), + ) + .await + } + + /// Create a new proxy client with given TLS configuration and TCP or VSOCK transports. + pub async fn new_with_transport_tls_config( + outer_client_config: ClientConfig, + listen_addr: ProxyListenAddr, + target: ProxyConnectTarget, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + cert_chain: Option>>, + ) -> Result + where + L: ToSocketAddrs, + T: ToString, + { + let outer_has_client_auth = outer_client_config.client_auth_cert_resolver.has_certs(); + let inner_has_client_auth = cert_chain.is_some(); + + if outer_has_client_auth != inner_has_client_auth { + return Err(ProxyError::ClientAuthMisconfigured); + } + + let attested_cert_verifier = + AttestedCertificateVerifier::try_default(attestation_verifier)?; + + let mut inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() { + let inner_cert_resolver = build_attested_cert_resolver( + attestation_generator, + certificate_identity_from_chain(cert_chain)?, + ) + .await?; + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_client_cert_resolver(Arc::new(inner_cert_resolver)) + } else { + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth() + }; + ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); + + let nesting_tls_connector = + NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); + + Self::new_with_transport_connector( + listen_addr, + target, ProxyTlsConnector::Nested(nesting_tls_connector), - &target_name, ) .await } @@ -782,24 +1187,69 @@ impl ProxyClient { }; ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); - Self::new_with_connector( - address, + Self::new_with_transport_connector( + ProxyListenAddr::Tcp(address), + ProxyConnectTarget::Tcp(target_name), + ProxyTlsConnector::InnerOnly(TlsConnector::from(Arc::new(inner_client_config))), + ) + .await + } + + /// Create a new inner-only proxy client with TCP or VSOCK transports. + pub async fn new_inner_only_with_transport_tls_config( + listen_addr: ProxyListenAddr, + target: ProxyConnectTarget, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + cert_chain: Option>>, + ) -> Result + where + L: ToSocketAddrs, + T: ToString, + { + let attested_cert_verifier = + AttestedCertificateVerifier::try_default(attestation_verifier)?; + + let mut inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() { + let inner_cert_resolver = build_attested_cert_resolver( + attestation_generator, + certificate_identity_from_chain(cert_chain)?, + ) + .await?; + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_client_cert_resolver(Arc::new(inner_cert_resolver)) + } else { + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth() + }; + ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); + + Self::new_with_transport_connector( + listen_addr, + target, ProxyTlsConnector::InnerOnly(TlsConnector::from(Arc::new(inner_client_config))), - &target_name, ) .await } /// Create a new proxy client with a configured TLS connector. - async fn new_with_connector( - address: impl ToSocketAddrs, + async fn new_with_transport_connector( + listen_addr: ProxyListenAddr, + target: ProxyConnectTarget, tls_connector: ProxyTlsConnector, - target_name: &str, - ) -> Result { - let listener = TcpListener::bind(address).await?; + ) -> Result + where + L: ToSocketAddrs, + T: ToString, + { + let listener = ProxyListener::bind(listen_addr).await?; // Process the hostname / port provided by the user - let target = host_to_host_with_port(target_name); + let target = ProxyConnectAddr::from_target(target); // Channel for getting incoming requests from the source client let (requests_tx, mut requests_rx) = mpsc::channel::<( @@ -932,11 +1382,18 @@ impl ProxyClient { } } - /// Helper to return the local socket address from the underlying TCP listener - pub fn local_addr(&self) -> std::io::Result { + /// Helper to return the local transport address from the underlying listener. + pub fn local_proxy_addr(&self) -> std::io::Result { self.listener.local_addr() } + /// Helper to return the local socket address from the underlying TCP listener. + /// + /// Returns an error when the local listener is VSOCK. + pub fn local_addr(&self) -> std::io::Result { + self.local_proxy_addr()?.as_tcp() + } + /// Accept an incoming connection and handle it in a separate task pub async fn accept(&self) -> io::Result> { let (inbound, _client_addr) = self.listener.accept().await?; @@ -954,7 +1411,7 @@ impl ProxyClient { /// Handle an incoming connection from the source client async fn handle_connection( - inbound: TcpStream, + inbound: TransportStream, requests_tx: mpsc::Sender, ) -> Result<(), ProxyError> { tracing::debug!("proxy-client accepted connection"); @@ -987,7 +1444,7 @@ impl ProxyClient { // Attempt connection and handshake with the proxy-server // If it fails retry with a backoff (indefinately) async fn setup_connection_with_backoff( - target: &str, + target: &ProxyConnectAddr, tls_connector: &ProxyTlsConnector, should_bail: bool, ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { @@ -1018,11 +1475,11 @@ impl ProxyClient { /// Connect to the proxy-server, do TLS handshake and remote attestation async fn setup_connection( tls_connector: &ProxyTlsConnector, - target: &str, + target: &ProxyConnectAddr, ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { - let outbound_stream = tokio::net::TcpStream::connect(target).await?; + let outbound_stream = TransportStream::connect(target).await?; - let domain = server_name_from_host(target)?; + let domain = server_name_from_host(target.server_name())?; match tls_connector { ProxyTlsConnector::Nested(connector) => { let tls_stream = connector.connect(domain, outbound_stream).await?; @@ -1787,7 +2244,10 @@ mod tests { let (sender, conn, _attestation) = ProxyClient::setup_connection( &ProxyTlsConnector::Nested(nesting_tls_connector), - &format!("localhost:{}", proxy_addr.port()), + &ProxyConnectAddr::from_target(ProxyConnectTarget::Tcp(format!( + "localhost:{}", + proxy_addr.port() + ))), ) .await .unwrap(); diff --git a/src/main.rs b/src/main.rs index 4c8cc4f..ef2deb4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,16 @@ use anyhow::{anyhow, ensure}; use attestation::{AttestationType, AttestationVerifier, measurements::MeasurementPolicy}; -use clap::{Parser, Subcommand}; +use clap::{Parser, Subcommand, ValueEnum}; use pccs::Pccs; use std::{fs::File, net::SocketAddr, path::PathBuf}; use tokio::io::AsyncWriteExt; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use tokio_vsock::VMADDR_CID_ANY; use tracing::level_filters::LevelFilter; use attested_tls_proxy::{ - AttestationGenerator, OuterTlsConfig, OuterTlsMode, ProxyClient, ProxyServer, TlsCertAndKey, + AttestationGenerator, OuterTlsConfig, OuterTlsMode, ProxyClient, ProxyConnectTarget, + ProxyListenAddr, ProxyServer, TlsCertAndKey, attested_get::attested_get, file_server::{AttestedFileServerConfig, attested_file_server}, get_inner_tls_cert, health_check, @@ -30,6 +32,12 @@ const DEBUG_LOG_TARGETS: &[&str] = &[ "pccs", ]; +#[derive(ValueEnum, Debug, Clone, Copy, PartialEq, Eq)] +enum NetworkTransport { + Tcp, + Vsock, +} + #[derive(Parser, Debug, Clone)] #[command(version = GIT_REV, about, long_about = None)] struct Cli { @@ -62,22 +70,40 @@ struct Cli { enum CliCommand { /// Run a proxy client Client { + /// Network transport to use for the local client listener + #[arg(long, value_enum, default_value_t = NetworkTransport::Tcp, env = "LISTEN_TRANSPORT")] + listen_transport: NetworkTransport, /// Socket address to listen on #[arg(short, long, default_value = "0.0.0.0:0", env = "LISTEN_ADDR")] listen_addr: SocketAddr, + /// Local VSOCK CID to bind when using `--listen-transport vsock` + #[arg(long, default_value_t = VMADDR_CID_ANY, env = "LISTEN_VSOCK_CID")] + listen_vsock_cid: u32, + /// Local VSOCK port to bind when using `--listen-transport vsock` + #[arg(long, env = "LISTEN_VSOCK_PORT")] + listen_vsock_port: Option, + /// Network transport to use when connecting to the proxy server + #[arg(long, value_enum, default_value_t = NetworkTransport::Tcp, env = "TARGET_TRANSPORT")] + target_transport: NetworkTransport, + /// Remote VSOCK CID for the proxy server when using `--target-transport vsock` + #[arg(long, env = "TARGET_VSOCK_CID")] + target_vsock_cid: Option, + /// Remote VSOCK port for the proxy server when using `--target-transport vsock` + #[arg(long, env = "TARGET_VSOCK_PORT")] + target_vsock_port: Option, /// Connect directly to the server's inner attested TLS listener instead of nested TLS #[arg(long)] inner_session_only: bool, - /// The hostname:port or ip:port of the proxy server (port defaults to 443) + /// The proxy server hostname:port for TCP, or TLS server name for VSOCK target_addr: String, - /// Type of attestation to present (dafaults to 'auto' for automatic detection) - /// If other than None, a TLS key and certicate must also be given + /// Type of attestation to present (defaults to automatic detection) + /// Client certificate material enables client authentication. #[arg(long, env = "CLIENT_ATTESTATION_TYPE")] client_attestation_type: Option, - /// The path to a PEM encoded private key for client authentication in nested-TLS mode + /// The path to a PEM encoded private key for outer client authentication in nested-TLS mode #[arg(long, env = "TLS_PRIVATE_KEY_PATH")] tls_private_key_path: Option, - /// The path to a PEM encoded certificate chain for client authentication in nested-TLS mode + /// The path to a PEM encoded certificate chain for client authentication #[arg(long, env = "TLS_CERTIFICATE_PATH")] tls_certificate_path: Option, /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. @@ -96,9 +122,21 @@ enum CliCommand { /// Socket address to listen on for the outer nested-TLS listener, if enabled #[arg(long)] outer_listen_addr: Option, + /// VSOCK CID to bind for the outer nested-TLS listener + #[arg(long, default_value_t = VMADDR_CID_ANY, env = "OUTER_VSOCK_CID")] + outer_vsock_cid: u32, + /// VSOCK port to bind for the outer nested-TLS listener, if enabled + #[arg(long, env = "OUTER_VSOCK_PORT")] + outer_vsock_port: Option, /// Socket address to listen on for the inner-only attested TLS listener #[arg(long)] inner_listen_addr: Option, + /// VSOCK CID to bind for the inner-only attested TLS listener + #[arg(long, default_value_t = VMADDR_CID_ANY, env = "INNER_VSOCK_CID")] + inner_vsock_cid: u32, + /// VSOCK port to bind for the inner-only attested TLS listener, if enabled + #[arg(long, env = "INNER_VSOCK_PORT")] + inner_vsock_port: Option, /// DNS name to embed into the inner attested certificate when no outer listener is used #[arg(long)] inner_certificate_name: Option, @@ -144,9 +182,21 @@ enum CliCommand { /// Socket address to listen on for the outer nested-TLS listener, if enabled #[arg(long)] outer_listen_addr: Option, + /// VSOCK CID to bind for the outer nested-TLS listener + #[arg(long, default_value_t = VMADDR_CID_ANY, env = "OUTER_VSOCK_CID")] + outer_vsock_cid: u32, + /// VSOCK port to bind for the outer nested-TLS listener, if enabled + #[arg(long, env = "OUTER_VSOCK_PORT")] + outer_vsock_port: Option, /// Socket address to listen on for the inner-only attested TLS listener #[arg(long)] inner_listen_addr: Option, + /// VSOCK CID to bind for the inner-only attested TLS listener + #[arg(long, default_value_t = VMADDR_CID_ANY, env = "INNER_VSOCK_CID")] + inner_vsock_cid: u32, + /// VSOCK port to bind for the inner-only attested TLS listener, if enabled + #[arg(long, env = "INNER_VSOCK_PORT")] + inner_vsock_port: Option, /// DNS name to embed into the inner attested certificate when no outer listener is used #[arg(long)] inner_certificate_name: Option, @@ -249,7 +299,13 @@ async fn main() -> anyhow::Result<()> { match cli.command { CliCommand::Client { + listen_transport, listen_addr, + listen_vsock_cid, + listen_vsock_port, + target_transport, + target_vsock_cid, + target_vsock_port, inner_session_only, target_addr, client_attestation_type, @@ -263,6 +319,18 @@ async fn main() -> anyhow::Result<()> { .strip_prefix("https://") .unwrap_or(&target_addr) .to_string(); + let listen_endpoint = client_listen_endpoint( + listen_transport, + listen_addr, + listen_vsock_cid, + listen_vsock_port, + )?; + let target_endpoint = client_target_endpoint( + target_transport, + target_addr.clone(), + target_vsock_cid, + target_vsock_port, + )?; if let Some(listen_addr_healthcheck) = listen_addr_healthcheck { health_check::server(listen_addr_healthcheck).await?; @@ -303,19 +371,19 @@ async fn main() -> anyhow::Result<()> { AttestationGenerator::new_with_detection(client_attestation_type, dev_dummy_dcap)?; let client = if inner_session_only { - ProxyClient::new_inner_only( - tls_cert_and_chain, - listen_addr, - target_addr, + ProxyClient::new_inner_only_with_transport_tls_config( + listen_endpoint, + target_endpoint, client_attestation_generator, attestation_verifier, + tls_cert_and_chain.map(|cert_and_key| cert_and_key.cert_chain), ) .await? } else { - ProxyClient::new( + ProxyClient::new_with_transport( tls_cert_and_chain, - listen_addr, - target_addr, + listen_endpoint, + target_endpoint, client_attestation_generator, attestation_verifier, remote_tls_cert, @@ -331,7 +399,11 @@ async fn main() -> anyhow::Result<()> { } CliCommand::Server { outer_listen_addr, + outer_vsock_cid, + outer_vsock_port, inner_listen_addr, + inner_vsock_cid, + inner_vsock_port, inner_certificate_name, target_addr, tls_private_key_path, @@ -347,23 +419,35 @@ async fn main() -> anyhow::Result<()> { let tls_cert_and_chain = load_tls_cert_and_key_server(tls_certificate_path, tls_private_key_path)?; - validate_listener_args( - inner_listen_addr, + let outer_listen = optional_listen_endpoint( + "outer", outer_listen_addr, + outer_vsock_cid, + outer_vsock_port, + )?; + let inner_listen = optional_listen_endpoint( + "inner", + inner_listen_addr, + inner_vsock_cid, + inner_vsock_port, + )?; + validate_listener_args( + inner_listen.is_some(), + outer_listen.is_some(), tls_cert_and_chain.is_some(), )?; let local_attestation_generator = AttestationGenerator::new_with_detection(server_attestation_type, dev_dummy_dcap)?; - let server = ProxyServer::new( + let server = ProxyServer::new_with_listeners( tls_cert_and_chain - .zip(outer_listen_addr) + .zip(outer_listen) .map(|(cert_and_key, listen_addr)| OuterTlsConfig { listen_addr, tls: OuterTlsMode::CertAndKey(cert_and_key), }), - inner_listen_addr, + inner_listen, inner_certificate_name, target_addr, local_attestation_generator, @@ -410,7 +494,11 @@ async fn main() -> anyhow::Result<()> { CliCommand::AttestedFileServer { path_to_serve, outer_listen_addr, + outer_vsock_cid, + outer_vsock_port, inner_listen_addr, + inner_vsock_cid, + inner_vsock_port, inner_certificate_name, server_attestation_type, tls_private_key_path, @@ -419,9 +507,21 @@ async fn main() -> anyhow::Result<()> { } => { let tls_cert_and_chain = load_tls_cert_and_key_server(tls_certificate_path, tls_private_key_path)?; - validate_listener_args( - inner_listen_addr, + let outer_listen = optional_listen_endpoint( + "outer", outer_listen_addr, + outer_vsock_cid, + outer_vsock_port, + )?; + let inner_listen = optional_listen_endpoint( + "inner", + inner_listen_addr, + inner_vsock_cid, + inner_vsock_port, + )?; + validate_listener_args( + inner_listen.is_some(), + outer_listen.is_some(), tls_cert_and_chain.is_some(), )?; @@ -435,8 +535,8 @@ async fn main() -> anyhow::Result<()> { attested_file_server(AttestedFileServerConfig { path_to_serve, outer_cert_and_key: tls_cert_and_chain, - outer_listen_addr, - inner_listen_addr, + outer_listen_addr: outer_listen, + inner_listen_addr: inner_listen, inner_certificate_name, attestation_generator, attestation_verifier, @@ -495,26 +595,82 @@ fn load_tls_cert_and_key_server( } } +fn client_listen_endpoint( + listen_transport: NetworkTransport, + listen_addr: SocketAddr, + listen_vsock_cid: u32, + listen_vsock_port: Option, +) -> anyhow::Result> { + match listen_transport { + NetworkTransport::Tcp => Ok(ProxyListenAddr::Tcp(listen_addr)), + NetworkTransport::Vsock => Ok(ProxyListenAddr::Vsock { + cid: listen_vsock_cid, + port: listen_vsock_port.ok_or_else(|| { + anyhow!("--listen-vsock-port is required with --listen-transport vsock") + })?, + }), + } +} + +fn client_target_endpoint( + target_transport: NetworkTransport, + target_addr: String, + target_vsock_cid: Option, + target_vsock_port: Option, +) -> anyhow::Result> { + match target_transport { + NetworkTransport::Tcp => Ok(ProxyConnectTarget::Tcp(target_addr)), + NetworkTransport::Vsock => Ok(ProxyConnectTarget::Vsock { + cid: target_vsock_cid.ok_or_else(|| { + anyhow!("--target-vsock-cid is required with --target-transport vsock") + })?, + port: target_vsock_port.ok_or_else(|| { + anyhow!("--target-vsock-port is required with --target-transport vsock") + })?, + server_name: target_addr, + }), + } +} + +fn optional_listen_endpoint( + name: &str, + tcp_addr: Option, + vsock_cid: u32, + vsock_port: Option, +) -> anyhow::Result>> { + match (tcp_addr, vsock_port) { + (Some(_), Some(_)) => Err(anyhow!( + "--{name}-listen-addr and --{name}-vsock-port are mutually exclusive" + )), + (Some(addr), None) => Ok(Some(ProxyListenAddr::Tcp(addr))), + (None, Some(port)) => Ok(Some(ProxyListenAddr::Vsock { + cid: vsock_cid, + port, + })), + (None, None) => Ok(None), + } +} + fn validate_listener_args( - inner_listen_addr: Option, - outer_listen_addr: Option, + inner_listener_configured: bool, + outer_listener_configured: bool, has_outer_tls: bool, ) -> anyhow::Result<()> { - if inner_listen_addr.is_none() && outer_listen_addr.is_none() { + if !inner_listener_configured && !outer_listener_configured { return Err(anyhow!( - "At least one of --inner-listen-addr or --outer-listen-addr must be provided" + "At least one inner or outer listener must be configured" )); } - if has_outer_tls && outer_listen_addr.is_none() { + if has_outer_tls && !outer_listener_configured { return Err(anyhow!( - "--outer-listen-addr is required when TLS certificate and key are provided" + "An outer listener is required when TLS certificate and key are provided" )); } - if !has_outer_tls && outer_listen_addr.is_some() { + if !has_outer_tls && outer_listener_configured { return Err(anyhow!( - "--outer-listen-addr requires TLS certificate and key" + "An outer listener requires TLS certificate and key" )); } @@ -523,8 +679,8 @@ fn validate_listener_args( fn validate_client_args( inner_session_only: bool, - tls_private_key_path: Option<&PathBuf>, - tls_certificate_path: Option<&PathBuf>, + _tls_private_key_path: Option<&PathBuf>, + _tls_certificate_path: Option<&PathBuf>, tls_ca_certificate: Option<&PathBuf>, ) -> anyhow::Result<()> { if inner_session_only && tls_ca_certificate.is_some() { @@ -533,12 +689,6 @@ fn validate_client_args( )); } - if inner_session_only && (tls_private_key_path.is_some() || tls_certificate_path.is_some()) { - return Err(anyhow!( - "--tls-private-key-path and --tls-certificate-path are not supported with --inner-session-only" - )); - } - Ok(()) } @@ -590,12 +740,51 @@ mod tests { } #[test] - fn client_rejects_tls_client_auth_in_inner_only_mode() { + fn client_allows_tls_client_auth_in_inner_only_mode() { let cert_path = PathBuf::from("client.crt"); let key_path = PathBuf::from("client.key"); - let err = validate_client_args(true, Some(&key_path), Some(&cert_path), None) - .unwrap_err() - .to_string(); - assert!(err.contains("--tls-private-key-path")); + validate_client_args(true, Some(&key_path), Some(&cert_path), None).unwrap(); + } + + #[test] + fn client_requires_vsock_listen_port_when_listening_on_vsock() { + let err = client_listen_endpoint( + NetworkTransport::Vsock, + "127.0.0.1:0".parse().unwrap(), + VMADDR_CID_ANY, + None, + ) + .unwrap_err() + .to_string(); + + assert!(err.contains("--listen-vsock-port")); + } + + #[test] + fn client_requires_vsock_target_when_connecting_over_vsock() { + let err = client_target_endpoint( + NetworkTransport::Vsock, + "localhost".to_string(), + Some(3), + None, + ) + .unwrap_err() + .to_string(); + + assert!(err.contains("--target-vsock-port")); + } + + #[test] + fn server_rejects_tcp_and_vsock_for_same_listener() { + let err = optional_listen_endpoint( + "inner", + Some("127.0.0.1:7001".parse().unwrap()), + VMADDR_CID_ANY, + Some(7001), + ) + .unwrap_err() + .to_string(); + + assert!(err.contains("mutually exclusive")); } }