Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 241 additions & 51 deletions Cargo.lock

Large diffs are not rendered by default.

15 changes: 9 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
[workspace]
members = [".", "attestation-provider-server"]

[patch.crates-io]
dcap-qvl = { git = "https://github.com/Phala-Network/dcap-qvl.git", rev = "f1dcc65371e941a7b83e3234833d23a1fb232ab1" }

[package]
name = "attested-tls-proxy"
version = "1.1.1"
Expand All @@ -11,10 +14,10 @@ repository = "https://github.com/flashbots/attested-tls-proxy"
keywords = ["attested-TLS", "CVM", "TDX"]

[dependencies]
attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "main" }
nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "main" }
attestation = { git = "https://github.com/flashbots/attested-tls", branch = "main" }
pccs = { git = "https://github.com/flashbots/attested-tls", branch = "main" }
attested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/nitro" }
nested-tls = { git = "https://github.com/flashbots/attested-tls", branch = "peg/nitro" }
attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/nitro" }
pccs = { git = "https://github.com/flashbots/attested-tls", branch = "peg/nitro" }
tokio = { version = "1.50.0", features = ["full"] }
tokio-rustls = { version = "0.26.4", default-features = false, features = ["aws_lc_rs"] }
x509-parser = { version = "0.18.0", features = ["verify"] }
Expand All @@ -37,7 +40,7 @@ reqwest = { version = "0.12.24", default-features = false, features = [
webpki-roots = "1.0.7"
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
axum = "0.8.8"
axum = "0.8.9"
tower-http = { version = "0.6.7", features = ["fs"] }
rsa = { version = "0.9", default-features = false }
p256 = { version = "0.13.2", features = ["pkcs8"] }
Expand All @@ -49,7 +52,7 @@ pin-project-lite = "0.2.16"
[dev-dependencies]
tempfile = "3.23.0"
tdx-quote = { version = "0.0.5", features = ["mock"] }
attestation = { git = "https://github.com/flashbots/attested-tls", branch = "main", features = ["mock"] }
attestation = { git = "https://github.com/flashbots/attested-tls", branch = "peg/nitro", features = ["mock"] }
tokio = { version = "1.48.0", features = ["full"] }
jsonrpsee = { version = "0.26.0", features = ["server"] }
mock-tdx = { git = "https://github.com/flashbots/attested-tls", branch = "main" }
Expand Down
8 changes: 7 additions & 1 deletion attestation-provider-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,16 @@ repository = "https://github.com/flashbots/attested-tls-proxy"
[dependencies]
attested-tls-proxy = { path = ".." }
tokio = { version = "1.48.0", features = ["full"] }
axum = "0.8.8"
axum = "0.8.9"
clap = { version = "4.5.51", features = ["derive", "env"] }
anyhow = "1.0.100"
bytes = "1.11.1"
hex = "0.4.3"
http = "1.3.1"
http-body-util = "0.1.3"
hyper = "1.7.0"
hyper-util = { version = "0.1.17", features = ["tokio"] }
tokio-vsock = { version = "0.7.2", features = ["axum08"] }
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
parity-scale-codec = "3.7.5"
Expand Down
73 changes: 58 additions & 15 deletions attestation-provider-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,40 @@ use std::net::SocketAddr;

use anyhow::anyhow;
use attested_tls_proxy::attestation::{AttestationExchangeMessage, AttestationVerifier};
use axum::serve::Listener;
use axum::{
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use bytes::Bytes;
use http_body_util::BodyExt;
use hyper::Request;
use hyper::client::conn::http1;
use hyper_util::rt::TokioIo;
use parity_scale_codec::{Decode, Encode};
use tokio::net::TcpListener;
use tokio_vsock::{VsockAddr, VsockStream};

#[derive(Debug, Clone, Copy)]
pub enum AttestationProviderEndpoint {
Tcp(SocketAddr),
Vsock { cid: u32, port: u32 },
}

#[derive(Clone)]
struct SharedState {
attestation_generator: AttestationGenerator,
}

/// An HTTP server which provides attestations
pub async fn attestation_provider_server(
listener: TcpListener,
pub async fn attestation_provider_server<L>(
listener: L,
attestation_generator: AttestationGenerator,
) -> anyhow::Result<()> {
) -> anyhow::Result<()>
where
L: Listener,
L::Addr: std::fmt::Debug,
{
let app = axum::Router::new()
.route("/attest/{input_data}", axum::routing::get(get_attest))
.with_state(SharedState {
Expand Down Expand Up @@ -52,17 +68,40 @@ async fn get_attest(

/// A client helper which makes a request to `/attest`
pub async fn attestation_provider_client(
server_addr: SocketAddr,
server_endpoint: AttestationProviderEndpoint,
attestation_verifier: AttestationVerifier,
) -> anyhow::Result<AttestationExchangeMessage> {
let input_data = [0; 64];
let response = reqwest::get(format!(
"http://{server_addr}/attest/{}",
hex::encode(input_data)
))
.await?
.bytes()
.await?;
let response = match server_endpoint {
AttestationProviderEndpoint::Tcp(server_addr) => reqwest::get(format!(
"http://{server_addr}/attest/{}",
hex::encode(input_data)
))
.await?
.bytes()
.await?
.to_vec(),
AttestationProviderEndpoint::Vsock { cid, port } => {
let stream = VsockStream::connect(VsockAddr::new(cid, port)).await?;
let io = TokioIo::new(stream);
let (mut sender, connection) = http1::handshake(io).await?;

tokio::spawn(async move {
if let Err(err) = connection.await {
eprintln!("vsock HTTP connection error: {err}");
}
});

let request = Request::builder()
.method(http::Method::GET)
.uri(format!("/attest/{}", hex::encode(input_data)))
.header(http::header::HOST, format!("{cid}:{port}"))
.body(http_body_util::Empty::<Bytes>::new())?;

let response = sender.send_request(request).await?;
response.into_body().collect().await?.to_bytes().to_vec()
}
};

let remote_attestation_message = AttestationExchangeMessage::decode(&mut &response[..])?;
let remote_attestation_type = remote_attestation_message.attestation_type;
Expand Down Expand Up @@ -97,6 +136,7 @@ impl IntoResponse for ServerError {
#[cfg(test)]
mod tests {
use super::*;
use tokio::net::TcpListener;

#[tokio::test]
async fn test_attestation_provider_server() {
Expand All @@ -110,8 +150,11 @@ mod tests {
.await
.unwrap();
});
attestation_provider_client(server_addr, AttestationVerifier::expect_none())
.await
.unwrap();
attestation_provider_client(
AttestationProviderEndpoint::Tcp(server_addr),
AttestationVerifier::expect_none(),
)
.await
.unwrap();
}
}
63 changes: 57 additions & 6 deletions attestation-provider-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use attestation_provider_server::{attestation_provider_client, attestation_provi
use attested_tls_proxy::attestation::{
AttestationGenerator, AttestationVerifier, measurements::MeasurementPolicy,
};
use clap::{Parser, Subcommand};
use clap::{Parser, Subcommand, ValueEnum};
use std::{net::SocketAddr, path::PathBuf};
use tokio::net::TcpListener;
use tokio_vsock::{VMADDR_CID_ANY, VsockAddr, VsockListener};
use tracing::level_filters::LevelFilter;

const GIT_REV: &str = match option_env!("GIT_REV") {
Expand All @@ -30,22 +31,44 @@ struct Cli {
#[derive(Subcommand, Debug, Clone)]
enum CliCommand {
Server {
/// Network transport to use for the server listener
#[arg(long, value_enum, default_value_t = NetworkTransport::Tcp)]
listen_transport: NetworkTransport,
/// Socket address to listen on
#[arg(short, long, default_value = "0.0.0.0:0", env = "LISTEN_ADDR")]
listen_addr: SocketAddr,
/// Vsock port to listen on when using `--listen-transport vsock`
#[arg(long, default_value_t = 8000, env = "VSOCK_PORT")]
vsock_port: u32,
/// Type of attestation to present (will attempt to detect if not given)
#[arg(long)]
server_attestation_type: Option<String>,
},
Client {
/// Network transport to use for the attestation provider server
#[arg(long, value_enum, default_value_t = NetworkTransport::Tcp)]
server_transport: NetworkTransport,
/// Socket address of a attestation provider server
#[arg(short, long, default_value = "127.0.0.1:8000", env = "SERVER_ADDR")]
server_addr: SocketAddr,
/// Vsock CID of the attestation provider server when using `--server-transport vsock`
#[arg(long, default_value_t = 10, env = "SERVER_CID")]
server_cid: u32,
/// Vsock port of the attestation provider server when using `--server-transport vsock`
#[arg(long, default_value_t = 8000, env = "SERVER_VSOCK_PORT")]
server_vsock_port: u32,
/// Optional path to file containing JSON measurements to be enforced on the remote party
#[arg(long, global = true, env = "MEASUREMENTS_FILE")]
measurements_file: Option<PathBuf>,
},
}

#[derive(ValueEnum, Debug, Clone, Copy, PartialEq, Eq)]
enum NetworkTransport {
Tcp,
Vsock,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
let cli = Cli::parse();
Expand Down Expand Up @@ -74,19 +97,35 @@ async fn main() -> anyhow::Result<()> {

match cli.command {
CliCommand::Server {
listen_transport,
listen_addr,
vsock_port,
server_attestation_type,
} => {
let attestation_generator =
AttestationGenerator::new_with_detection(server_attestation_type, None)?;

let listener = TcpListener::bind(listen_addr).await?;

println!("Listening on {}", listener.local_addr()?);
attestation_provider_server(listener, attestation_generator).await?;
match listen_transport {
NetworkTransport::Tcp => {
let listener = TcpListener::bind(listen_addr).await?;
println!("Listening on {}", listener.local_addr()?);
attestation_provider_server(listener, attestation_generator).await?;
}
NetworkTransport::Vsock => {
let listener = VsockListener::bind(VsockAddr::new(VMADDR_CID_ANY, vsock_port))?;
println!(
"Listening on vsock cid={} port={}",
VMADDR_CID_ANY, vsock_port
);
attestation_provider_server(listener, attestation_generator).await?;
}
}
}
CliCommand::Client {
server_transport,
server_addr,
server_cid,
server_vsock_port,
measurements_file,
} => {
let measurement_policy = match measurements_file {
Expand All @@ -102,8 +141,20 @@ async fn main() -> anyhow::Result<()> {
internal_pccs: None,
};

let server_endpoint = match server_transport {
NetworkTransport::Tcp => {
attestation_provider_server::AttestationProviderEndpoint::Tcp(server_addr)
}
NetworkTransport::Vsock => {
attestation_provider_server::AttestationProviderEndpoint::Vsock {
cid: server_cid,
port: server_vsock_port,
}
}
};

let attestation_message =
attestation_provider_client(server_addr, attestation_verifier).await?;
attestation_provider_client(server_endpoint, attestation_verifier).await?;

println!("{attestation_message:?}")
}
Expand Down
27 changes: 27 additions & 0 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading