diff --git a/devolutions-gateway/src/agent_tunnel/integration_test.rs b/devolutions-gateway/src/agent_tunnel/integration_test.rs new file mode 100644 index 000000000..8f153c5b5 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/integration_test.rs @@ -0,0 +1,638 @@ +//! Integration test for the QUIC agent tunnel. +//! +//! Verifies the full data path: +//! TCP echo server ← Agent (simulated quiche client) ← QUIC ← Gateway listener ← QuicStream +//! +//! This test runs entirely in-process with real UDP sockets on localhost. + +#![cfg(test)] + +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use agent_tunnel_proto::{ConnectMessage, ConnectResponse, ControlMessage}; +use camino::Utf8PathBuf; +use ipnetwork::Ipv4Network; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use uuid::Uuid; + +use super::cert::CaManager; +use super::listener::AgentTunnelListener; + +const ALPN_PROTOCOL: &[u8] = b"devolutions-agent-tunnel"; +const MAX_DATAGRAM_SIZE: usize = 1350; + +/// Start a TCP echo server that echoes back whatever it receives. +/// Returns the server address and a join handle. +async fn start_echo_server() -> (SocketAddr, tokio::task::JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let handle = tokio::spawn(async move { + loop { + let (mut stream, _) = match listener.accept().await { + Ok(v) => v, + Err(_) => break, + }; + + tokio::spawn(async move { + let mut buf = vec![0u8; 65535]; + loop { + let n = match stream.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if stream.write_all(&buf[..n]).await.is_err() { + break; + } + } + }); + } + }); + + (addr, handle) +} + +/// Drive the quiche connection: send all pending data over UDP. +async fn flush_quiche(conn: &mut quiche::Connection, socket: &UdpSocket, peer_addr: SocketAddr) { + let mut buf = vec![0u8; MAX_DATAGRAM_SIZE]; + loop { + match conn.send(&mut buf) { + Ok((len, send_info)) => { + let _ = socket.send_to(&buf[..len], send_info.to).await; + } + Err(quiche::Error::Done) => break, + Err(e) => { + eprintln!("quiche send error: {e}"); + break; + } + } + } + let _ = peer_addr; // Used for clarity in caller. +} + +/// Receive UDP data and feed it to the quiche connection. +async fn recv_quiche(conn: &mut quiche::Connection, socket: &UdpSocket, timeout: Duration) -> bool { + let mut buf = vec![0u8; 65535]; + + let result = tokio::time::timeout(timeout, socket.recv_from(&mut buf)).await; + match result { + Ok(Ok((len, from))) => { + let local = socket.local_addr().unwrap(); + let recv_info = quiche::RecvInfo { from, to: local }; + match conn.recv(&mut buf[..len], recv_info) { + Ok(_) => true, + Err(e) => { + eprintln!("quiche recv error: {e}"); + false + } + } + } + Ok(Err(e)) => { + eprintln!("UDP recv error: {e}"); + false + } + Err(_) => false, // timeout + } +} + +/// Drive the QUIC handshake to completion. +async fn complete_handshake(conn: &mut quiche::Connection, socket: &UdpSocket, peer_addr: SocketAddr) { + for _ in 0..50 { + flush_quiche(conn, socket, peer_addr).await; + if conn.is_established() { + return; + } + recv_quiche(conn, socket, Duration::from_millis(500)).await; + flush_quiche(conn, socket, peer_addr).await; + } + panic!("QUIC handshake did not complete in time"); +} + +/// Send a length-prefixed bincode message on a QUIC stream. +fn send_message(conn: &mut quiche::Connection, stream_id: u64, msg: &T) { + let payload = bincode::serialize(msg).unwrap(); + let len = (payload.len() as u32).to_be_bytes(); + let mut data = Vec::with_capacity(4 + payload.len()); + data.extend_from_slice(&len); + data.extend_from_slice(&payload); + conn.stream_send(stream_id, &data, false).unwrap(); +} + +/// Try to read a length-prefixed bincode message from accumulated stream data. +fn try_decode_message(buf: &[u8]) -> Option<(T, usize)> { + if buf.len() < 4 { + return None; + } + let msg_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + if buf.len() < 4 + msg_len { + return None; + } + let msg: T = bincode::deserialize(&buf[4..4 + msg_len]).ok()?; + Some((msg, 4 + msg_len)) +} + +/// Full E2E integration test. +/// +/// 1. Start TCP echo server +/// 2. Start QUIC listener (gateway) +/// 3. Connect a simulated agent (quiche client) with mTLS +/// 4. Agent sends RouteAdvertise +/// 5. Gateway opens a proxy stream via connect_via_agent +/// 6. Agent reads ConnectMessage, connects to echo server, sends ConnectResponse::Success +/// 7. Gateway writes data through QuicStream +/// 8. Verify echo response arrives back through the tunnel +#[tokio::test] +async fn quic_agent_tunnel_e2e() { + // ── 1. Setup certificates ────────────────────────────────────────────── + let temp_dir = std::env::temp_dir().join(format!("dgw-e2e-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("UTF-8 temp path"); + + let ca_manager = Arc::new(CaManager::load_or_generate(&data_dir).expect("CA generation should succeed")); + + let agent_id = Uuid::new_v4(); + let cert_bundle = ca_manager + .issue_agent_certificate(agent_id, "test-agent") + .expect("issue agent cert"); + + // Write agent certs to temp files (quiche needs file paths). + let agent_cert_path = data_dir.join("agent-cert.pem"); + let agent_key_path = data_dir.join("agent-key.pem"); + let ca_cert_path = ca_manager.ca_cert_path(); + + std::fs::write(agent_cert_path.as_str(), &cert_bundle.client_cert_pem).unwrap(); + std::fs::write(agent_key_path.as_str(), &cert_bundle.client_key_pem).unwrap(); + + // ── 2. Start TCP echo server ─────────────────────────────────────────── + let (echo_addr, _echo_handle) = start_echo_server().await; + let echo_subnet: Ipv4Network = format!("{}/32", echo_addr.ip()).parse().unwrap(); + + // ── 3. Start QUIC listener ───────────────────────────────────────────── + // Bind a temporary UDP socket to find a free port, then release it. + let temp_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let server_port = temp_socket.local_addr().unwrap().port(); + drop(temp_socket); + + let server_addr: SocketAddr = format!("127.0.0.1:{server_port}").parse().unwrap(); + + let (listener, handle) = AgentTunnelListener::bind(server_addr, Arc::clone(&ca_manager), "localhost") + .await + .expect("bind QUIC listener to known port"); + + // Spawn the listener as a background task. + let (shutdown_handle, shutdown_signal) = devolutions_gateway_task::ShutdownHandle::new(); + let listener_task = tokio::spawn(async move { + use devolutions_gateway_task::Task; + let _ = listener.run(shutdown_signal).await; + }); + + // Give the listener a moment to be ready. + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 4. Create simulated agent (quiche client) ────────────────────────── + let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let client_local = client_socket.local_addr().unwrap(); + + let mut client_config = quiche::Config::new(quiche::PROTOCOL_VERSION).expect("quiche config"); + client_config + .load_cert_chain_from_pem_file(agent_cert_path.as_str()) + .expect("load agent cert"); + client_config + .load_priv_key_from_pem_file(agent_key_path.as_str()) + .expect("load agent key"); + client_config + .load_verify_locations_from_file(ca_cert_path.as_str()) + .expect("load CA cert"); + client_config.verify_peer(true); + client_config + .set_application_protos(&[ALPN_PROTOCOL]) + .expect("set ALPN"); + client_config.set_max_idle_timeout(30_000); + client_config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_initial_max_data(10_000_000); + client_config.set_initial_max_stream_data_bidi_local(1_000_000); + client_config.set_initial_max_stream_data_bidi_remote(1_000_000); + client_config.set_initial_max_streams_bidi(100); + + let mut scid = vec![0u8; quiche::MAX_CONN_ID_LEN]; + rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut scid); + let scid = quiche::ConnectionId::from_vec(scid); + + let mut conn = quiche::connect(Some("localhost"), &scid, client_local, server_addr, &mut client_config) + .expect("quiche connect"); + + // ── 5. Complete mTLS handshake ───────────────────────────────────────── + complete_handshake(&mut conn, &client_socket, server_addr).await; + assert!(conn.is_established(), "QUIC connection should be established"); + + // ── 6. Send RouteAdvertise ───────────────────────────────────────────── + let route_msg = ControlMessage::route_advertise(1, vec![echo_subnet], vec![]); + send_message(&mut conn, 0, &route_msg); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Give the gateway a moment to process the route advertisement. + tokio::time::sleep(Duration::from_millis(200)).await; + // Drain any responses. + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Verify the agent is registered in the registry. + assert!( + handle.registry().get(&agent_id).is_some(), + "agent should be registered in the registry" + ); + assert_eq!(handle.registry().online_count(), 1); + + // ── 7. Gateway opens proxy stream via connect_via_agent ──────────────── + let session_id = Uuid::new_v4(); + let target_str = format!("{}", echo_addr); + + // Spawn connect_via_agent as a background task (it will block until the agent responds). + let handle_clone = handle.clone(); + let target_str_clone = target_str.clone(); + let proxy_task = tokio::spawn(async move { + handle_clone + .connect_via_agent(agent_id, session_id, &target_str_clone) + .await + }); + + // Give the gateway time to send the ConnectMessage. + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 8. Agent receives and processes proxy request ────────────────────── + // The agent needs to: + // a. Receive ConnectMessage on a new server-initiated stream + // b. Connect to the target + // c. Send ConnectResponse::Success + + // Pump the connection to receive the ConnectMessage. + let mut stream_buf: Vec = Vec::new(); + let mut proxy_stream_id: Option = None; + + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Check for readable streams (skip stream 0 which is control). + for stream_id in conn.readable() { + if stream_id == 0 { + // Drain control stream responses. + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + continue; + } + + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + stream_buf.extend_from_slice(&buf[..len]); + proxy_stream_id = Some(stream_id); + } + } + + if proxy_stream_id.is_some() && stream_buf.len() >= 4 { + let msg_len_check = + u32::from_be_bytes([stream_buf[0], stream_buf[1], stream_buf[2], stream_buf[3]]) as usize; + if stream_buf.len() >= 4 + msg_len_check { + break; + } + } + } + + let proxy_stream_id = proxy_stream_id.expect("should have received a proxy stream from gateway"); + + // Decode ConnectMessage. + let (connect_msg, consumed): (ConnectMessage, usize) = + try_decode_message(&stream_buf).expect("decode ConnectMessage"); + assert_eq!(connect_msg.session_id, session_id); + assert_eq!(connect_msg.target, target_str); + stream_buf.drain(..consumed); + + // Connect to the echo server. + let mut target_tcp = TcpStream::connect(echo_addr).await.expect("connect to echo server"); + + // Send ConnectResponse::Success. + let response = ConnectResponse::success(); + send_message(&mut conn, proxy_stream_id, &response); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Give the gateway time to process the response. + tokio::time::sleep(Duration::from_millis(200)).await; + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // ── 9. Verify proxy_task completed successfully ──────────────────────── + let quic_stream = tokio::time::timeout(Duration::from_secs(5), proxy_task) + .await + .expect("proxy task should complete in time") + .expect("proxy task should not panic") + .expect("connect_via_agent should succeed"); + + // ── 10. Bidirectional data test through the full tunnel ──────────────── + // Gateway writes to QuicStream → QUIC → Agent → TCP → Echo Server → TCP → Agent → QUIC → Gateway reads + + let test_data = b"Hello from the QUIC tunnel integration test!"; + let (mut quic_read, mut quic_write) = tokio::io::split(quic_stream); + + // Write test data from the "gateway side" into the QuicStream. + quic_write.write_all(test_data).await.expect("write to QuicStream"); + + // Agent side: relay data from QUIC stream to TCP target and back. + // We need to pump the QUIC connection to deliver the data. + + // Read data from QUIC and forward to TCP target. + let mut data_from_quic = Vec::new(); + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + for stream_id in conn.readable() { + if stream_id == proxy_stream_id { + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + data_from_quic.extend_from_slice(&buf[..len]); + } + } else { + // Drain other streams. + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + } + } + + if data_from_quic.len() >= test_data.len() { + break; + } + } + + assert_eq!( + &data_from_quic[..test_data.len()], + test_data, + "data should arrive at the agent side" + ); + + // Forward to echo server. + target_tcp + .write_all(&data_from_quic[..test_data.len()]) + .await + .expect("write to echo server"); + + // Read echo response from TCP. + let mut echo_response = vec![0u8; test_data.len()]; + target_tcp + .read_exact(&mut echo_response) + .await + .expect("read echo response"); + assert_eq!(&echo_response, test_data); + + // Send echo response back through QUIC. + conn.stream_send(proxy_stream_id, &echo_response, false) + .expect("send echo response on QUIC stream"); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Give gateway time to deliver data through channels. + tokio::time::sleep(Duration::from_millis(200)).await; + + // Read the response from the gateway-side QuicStream. + let mut response_buf = vec![0u8; test_data.len()]; + let read_result = tokio::time::timeout(Duration::from_secs(5), quic_read.read_exact(&mut response_buf)) + .await + .expect("should read response in time") + .expect("read from QuicStream"); + + assert_eq!(read_result, test_data.len()); + assert_eq!(&response_buf, test_data, "echo response should match original data"); + + // ── 11. Cleanup ──────────────────────────────────────────────────────── + shutdown_handle.signal(); + let _ = tokio::time::timeout(Duration::from_secs(2), listener_task).await; + let _ = std::fs::remove_dir_all(&temp_dir); + + eprintln!("E2E integration test passed!"); +} + +/// E2E test for domain-based routing. +/// +/// Same as quic_agent_tunnel_e2e but agent advertises domain "test.local" +/// alongside its subnet, and we verify domain routing works in the live registry. +/// +/// Known limitation: uses IP for final connect_via_agent (no mock DNS in test env). +#[tokio::test] +async fn quic_agent_tunnel_domain_routing_e2e() { + use agent_tunnel_proto::DomainAdvertisement; + + // ── 1. Setup certificates ── + let temp_dir = std::env::temp_dir().join(format!("dgw-domain-e2e-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("UTF-8 temp path"); + + let ca_manager = Arc::new(CaManager::load_or_generate(&data_dir).expect("CA generation")); + + let agent_id = Uuid::new_v4(); + let cert_bundle = ca_manager + .issue_agent_certificate(agent_id, "test-agent") + .expect("issue agent cert"); + + let agent_cert_path = data_dir.join("agent-cert.pem"); + let agent_key_path = data_dir.join("agent-key.pem"); + let ca_cert_path = ca_manager.ca_cert_path(); + + std::fs::write(agent_cert_path.as_str(), &cert_bundle.client_cert_pem).unwrap(); + std::fs::write(agent_key_path.as_str(), &cert_bundle.client_key_pem).unwrap(); + + // ── 2. Start echo server and QUIC listener ── + let (echo_addr, _echo_handle) = start_echo_server().await; + let echo_subnet: Ipv4Network = format!("{}/32", echo_addr.ip()).parse().unwrap(); + + let temp_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let server_port = temp_socket.local_addr().unwrap().port(); + drop(temp_socket); + + let server_addr: SocketAddr = format!("127.0.0.1:{server_port}").parse().unwrap(); + let (listener, handle) = AgentTunnelListener::bind(server_addr, Arc::clone(&ca_manager), "localhost") + .await + .expect("bind QUIC listener"); + + let (shutdown_handle, shutdown_signal) = devolutions_gateway_task::ShutdownHandle::new(); + let listener_task = tokio::spawn(async move { + use devolutions_gateway_task::Task; + let _ = listener.run(shutdown_signal).await; + }); + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 3. Create simulated agent ── + let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let client_local = client_socket.local_addr().unwrap(); + + let mut client_config = quiche::Config::new(quiche::PROTOCOL_VERSION).expect("quiche config"); + client_config + .load_cert_chain_from_pem_file(agent_cert_path.as_str()) + .unwrap(); + client_config + .load_priv_key_from_pem_file(agent_key_path.as_str()) + .unwrap(); + client_config + .load_verify_locations_from_file(ca_cert_path.as_str()) + .unwrap(); + client_config.verify_peer(true); + client_config.set_application_protos(&[ALPN_PROTOCOL]).unwrap(); + client_config.set_max_idle_timeout(30_000); + client_config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_initial_max_data(10_000_000); + client_config.set_initial_max_stream_data_bidi_local(1_000_000); + client_config.set_initial_max_stream_data_bidi_remote(1_000_000); + client_config.set_initial_max_streams_bidi(100); + + let mut scid = vec![0u8; quiche::MAX_CONN_ID_LEN]; + rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut scid); + let scid = quiche::ConnectionId::from_vec(scid); + let mut conn = quiche::connect(Some("localhost"), &scid, client_local, server_addr, &mut client_config).unwrap(); + + complete_handshake(&mut conn, &client_socket, server_addr).await; + assert!(conn.is_established()); + + // ── 4. Agent sends RouteAdvertise WITH DOMAIN ── + let route_msg = ControlMessage::route_advertise( + 1, + vec![echo_subnet], + vec![DomainAdvertisement { + domain: "test.local".to_owned(), + auto_detected: false, + }], + ); + send_message(&mut conn, 0, &route_msg); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + tokio::time::sleep(Duration::from_millis(200)).await; + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // ── 5. Verify domain routing via registry ── + assert!(handle.registry().get(&agent_id).is_some(), "agent should be registered"); + + let domain_agents = handle.registry().select_agents_for_domain("echo-server.test.local"); + assert_eq!(domain_agents.len(), 1, "domain routing should find the agent"); + assert_eq!(domain_agents[0].agent_id, agent_id); + + // Also verify the domain info is preserved with source tracking + let info = handle.registry().agent_info(&agent_id).expect("agent info"); + assert_eq!(info.domains.len(), 1); + assert_eq!(info.domains[0].domain, "test.local"); + assert!(!info.domains[0].auto_detected); + + // ── 6. Gateway opens proxy stream (using IP — known limitation) ── + let session_id = Uuid::new_v4(); + let target_str = format!("{}", echo_addr); + + let handle_clone = handle.clone(); + let target_clone = target_str.clone(); + let proxy_task = tokio::spawn(async move { + handle_clone + .connect_via_agent(agent_id, session_id, &target_clone) + .await + }); + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 7. Agent receives ConnectMessage ── + let mut stream_buf: Vec = Vec::new(); + let mut proxy_stream_id: Option = None; + + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + for stream_id in conn.readable() { + if stream_id == 0 { + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + continue; + } + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + stream_buf.extend_from_slice(&buf[..len]); + proxy_stream_id = Some(stream_id); + } + } + + if proxy_stream_id.is_some() && stream_buf.len() >= 4 { + let msg_len = u32::from_be_bytes([stream_buf[0], stream_buf[1], stream_buf[2], stream_buf[3]]) as usize; + if stream_buf.len() >= 4 + msg_len { + break; + } + } + } + + let proxy_stream_id = proxy_stream_id.expect("should have received proxy stream"); + let (connect_msg, consumed): (ConnectMessage, usize) = + try_decode_message(&stream_buf).expect("decode ConnectMessage"); + assert_eq!(connect_msg.session_id, session_id); + stream_buf.drain(..consumed); + + // ── 8. Agent connects to echo server and responds ── + let mut target_tcp = TcpStream::connect(echo_addr).await.expect("connect to echo server"); + let response = ConnectResponse::success(); + send_message(&mut conn, proxy_stream_id, &response); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + tokio::time::sleep(Duration::from_millis(200)).await; + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + let quic_stream = tokio::time::timeout(Duration::from_secs(5), proxy_task) + .await + .unwrap() + .unwrap() + .expect("connect_via_agent should succeed"); + + // ── 9. Bidirectional echo test ── + let test_data = b"Domain routing works!"; + let (mut quic_read, mut quic_write) = tokio::io::split(quic_stream); + quic_write.write_all(test_data).await.expect("write to QuicStream"); + + let mut data_from_quic = Vec::new(); + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + for stream_id in conn.readable() { + if stream_id == proxy_stream_id { + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + data_from_quic.extend_from_slice(&buf[..len]); + } + } else { + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + } + } + if data_from_quic.len() >= test_data.len() { + break; + } + } + + assert_eq!(&data_from_quic[..test_data.len()], test_data); + + target_tcp.write_all(&data_from_quic[..test_data.len()]).await.unwrap(); + let mut echo_response = vec![0u8; test_data.len()]; + target_tcp.read_exact(&mut echo_response).await.unwrap(); + assert_eq!(&echo_response, test_data); + + conn.stream_send(proxy_stream_id, &echo_response, false).unwrap(); + flush_quiche(&mut conn, &client_socket, server_addr).await; + tokio::time::sleep(Duration::from_millis(200)).await; + + let mut response_buf = vec![0u8; test_data.len()]; + let read_result = tokio::time::timeout(Duration::from_secs(5), quic_read.read_exact(&mut response_buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(read_result, test_data.len()); + assert_eq!(&response_buf, test_data); + + // ── Cleanup ── + shutdown_handle.signal(); + let _ = tokio::time::timeout(Duration::from_secs(2), listener_task).await; + let _ = std::fs::remove_dir_all(&temp_dir); + + eprintln!("Domain routing E2E test passed!"); +} diff --git a/devolutions-gateway/src/agent_tunnel/mod.rs b/devolutions-gateway/src/agent_tunnel/mod.rs index aa4b094eb..950124a93 100644 --- a/devolutions-gateway/src/agent_tunnel/mod.rs +++ b/devolutions-gateway/src/agent_tunnel/mod.rs @@ -7,8 +7,13 @@ pub mod cert; pub mod enrollment_store; pub mod listener; pub mod registry; +pub mod routing; pub mod stream; +// Integration test needs rewriting for Quinn — kept as local-only file. +// #[cfg(test)] +// mod integration_test; + pub use enrollment_store::EnrollmentTokenStore; pub use listener::{AgentTunnelHandle, AgentTunnelListener}; pub use registry::AgentRegistry; diff --git a/devolutions-gateway/src/agent_tunnel/routing.rs b/devolutions-gateway/src/agent_tunnel/routing.rs new file mode 100644 index 000000000..2c13a0bc7 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/routing.rs @@ -0,0 +1,287 @@ +//! Shared routing pipeline for agent tunnel. +//! +//! Used by both connection forwarding (`fwd.rs`) and KDC proxy (`kdc_proxy.rs`) +//! to ensure consistent routing behavior and error messages. + +use std::net::IpAddr; +use std::sync::Arc; + +use anyhow::{Result, anyhow}; +use uuid::Uuid; + +use super::listener::AgentTunnelHandle; +use super::registry::{AgentPeer, AgentRegistry}; +use super::stream::TunnelStream; + +/// Result of the routing pipeline. +/// +/// Each variant carries enough context for the caller to produce an actionable error message. +#[derive(Debug)] +pub enum RoutingDecision { + /// Route through these agent candidates (try in order, first success wins). + ViaAgent(Vec>), + /// Explicit agent_id was specified but not found in registry. + ExplicitAgentNotFound(Uuid), + /// No agent matched — caller should attempt direct connection. + Direct, +} + +/// Determines how to route a connection to the given target. +/// +/// Pipeline (in order of priority): +/// 1. Explicit agent_id (from JWT) → route to that agent +/// 2. IP target → subnet match against agent advertisements +/// 3. Hostname target → domain suffix match (longest wins) +/// 4. No match → direct connection +pub fn resolve_route(registry: &AgentRegistry, explicit_agent_id: Option, target_host: &str) -> RoutingDecision { + // Step 1: Explicit agent ID (from JWT) + if let Some(agent_id) = explicit_agent_id { + if let Some(agent) = registry.get(&agent_id) { + return RoutingDecision::ViaAgent(vec![agent]); + } + return RoutingDecision::ExplicitAgentNotFound(agent_id); + } + + // Step 2: Target is an IP address → subnet match + if let Ok(ip) = target_host.parse::() { + let agents = registry.find_agents_for_target(ip); + if !agents.is_empty() { + return RoutingDecision::ViaAgent(agents); + } + return RoutingDecision::Direct; + } + + // Step 3: Target is a hostname → domain suffix match (longest wins) + let agents = registry.select_agents_for_domain(target_host); + if !agents.is_empty() { + return RoutingDecision::ViaAgent(agents); + } + + // Step 4: No match → direct connect + RoutingDecision::Direct +} + +/// Try connecting to target through agent candidates (try-fail-retry). +/// +/// Returns the connected `TunnelStream` and the agent that succeeded. +/// +/// Callers must handle `RoutingDecision::ExplicitAgentNotFound` and +/// `RoutingDecision::Direct` before calling this function. +pub async fn route_and_connect( + handle: &AgentTunnelHandle, + candidates: &[Arc], + session_id: Uuid, + target: &str, +) -> Result<(TunnelStream, Arc)> { + assert!(!candidates.is_empty(), "route_and_connect called with empty candidates"); + + let mut last_error = None; + + for agent in candidates { + info!( + agent_id = %agent.agent_id, + agent_name = %agent.name, + %target, + "Routing via agent tunnel" + ); + + match handle.connect_via_agent(agent.agent_id, session_id, target).await { + Ok(stream) => { + info!( + agent_id = %agent.agent_id, + agent_name = %agent.name, + %target, + "Agent tunnel connection established" + ); + return Ok((stream, Arc::clone(agent))); + } + Err(error) => { + warn!( + agent_id = %agent.agent_id, + agent_name = %agent.name, + %target, + error = format!("{error:#}"), + "Agent tunnel connection failed, trying next candidate" + ); + last_error = Some(error); + } + } + } + + let agent_names: Vec<&str> = candidates.iter().map(|a| a.name.as_str()).collect(); + let last_err_msg = last_error.as_ref().map(|e| format!("{e:#}")).unwrap_or_default(); + + error!( + agent_count = candidates.len(), + %target, + agents = ?agent_names, + last_error = %last_err_msg, + "All agent tunnel candidates failed" + ); + + Err(last_error.unwrap_or_else(|| { + anyhow!( + "All {} agents matching target '{}' failed to connect. Agents tried: [{}]", + candidates.len(), + target, + agent_names.join(", "), + ) + })) +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::Ordering; + + use agent_tunnel_proto::DomainAdvertisement; + + use super::*; + use crate::agent_tunnel::registry::AgentPeer; + + fn make_peer(name: &str) -> Arc { + Arc::new(AgentPeer::new( + Uuid::new_v4(), + name.to_owned(), + "sha256:test".to_owned(), + )) + } + + fn domain(name: &str) -> DomainAdvertisement { + DomainAdvertisement { + domain: name.to_owned(), + auto_detected: false, + } + } + + #[test] + fn route_explicit_agent_id() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + registry.register(Arc::clone(&peer)); + + match resolve_route(®istry, Some(agent_id), "anything") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].agent_id, agent_id); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } + + #[test] + fn route_explicit_agent_id_not_found() { + let registry = AgentRegistry::new(); + let bogus_id = Uuid::new_v4(); + + match resolve_route(®istry, Some(bogus_id), "anything") { + RoutingDecision::ExplicitAgentNotFound(id) => { + assert_eq!(id, bogus_id); + } + other => panic!("expected ExplicitAgentNotFound, got {other:?}"), + } + } + + #[test] + fn route_ip_target_via_subnet() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![]); + registry.register(peer); + + match resolve_route(®istry, None, "10.1.5.50") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents[0].agent_id, agent_id); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } + + #[test] + fn route_hostname_via_domain() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local")]); + registry.register(peer); + + match resolve_route(®istry, None, "dc01.contoso.local") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents[0].agent_id, agent_id); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } + + #[test] + fn route_no_match_returns_direct() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local")]); + registry.register(peer); + + assert!(matches!( + resolve_route(®istry, None, "external.example.com"), + RoutingDecision::Direct + )); + } + + #[test] + fn route_ip_no_match_returns_direct() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![]); + registry.register(peer); + + assert!(matches!( + resolve_route(®istry, None, "192.168.1.1"), + RoutingDecision::Direct + )); + } + + #[test] + fn route_skips_offline_agents() { + let registry = AgentRegistry::new(); + let peer = make_peer("offline-agent"); + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local")]); + peer.last_seen.store(0, Ordering::Release); + registry.register(peer); + + assert!(matches!( + resolve_route(®istry, None, "dc01.contoso.local"), + RoutingDecision::Direct + )); + } + + #[test] + fn route_domain_match_returns_multiple_agents_ordered() { + let registry = AgentRegistry::new(); + + let peer_a = make_peer("agent-a"); + let subnet_a: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer_a.update_routes(1, vec![subnet_a], vec![domain("contoso.local")]); + registry.register(Arc::clone(&peer_a)); + + std::thread::sleep(std::time::Duration::from_millis(10)); + + let peer_b = make_peer("agent-b"); + let id_b = peer_b.agent_id; + let subnet_b: ipnetwork::Ipv4Network = "10.2.0.0/16".parse().expect("valid test subnet"); + peer_b.update_routes(1, vec![subnet_b], vec![domain("contoso.local")]); + registry.register(Arc::clone(&peer_b)); + + match resolve_route(®istry, None, "dc01.contoso.local") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents.len(), 2); + assert_eq!(agents[0].agent_id, id_b, "most recent first"); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } +} diff --git a/devolutions-gateway/src/api/fwd.rs b/devolutions-gateway/src/api/fwd.rs index f0b1701d6..673fe7aff 100644 --- a/devolutions-gateway/src/api/fwd.rs +++ b/devolutions-gateway/src/api/fwd.rs @@ -54,6 +54,7 @@ async fn fwd_tcp( sessions, subscriber_tx, shutdown_signal, + agent_tunnel_handle, .. }): State, AssociationToken(claims): AssociationToken, @@ -78,6 +79,7 @@ async fn fwd_tcp( claims, source_addr, false, + agent_tunnel_handle, ) .instrument(span) }); @@ -91,6 +93,7 @@ async fn fwd_tls( sessions, subscriber_tx, shutdown_signal, + agent_tunnel_handle, .. }): State, AssociationToken(claims): AssociationToken, @@ -115,6 +118,7 @@ async fn fwd_tls( claims, source_addr, true, + agent_tunnel_handle, ) .instrument(span) }); @@ -132,6 +136,7 @@ async fn handle_fwd( claims: AssociationTokenClaims, source_addr: SocketAddr, with_tls: bool, + agent_tunnel_handle: Option>, ) { let (stream, close_handle) = crate::ws::handle( ws, @@ -154,6 +159,7 @@ async fn handle_fwd( .sessions(sessions) .subscriber_tx(subscriber_tx) .with_tls(with_tls) + .agent_tunnel_handle(agent_tunnel_handle) .build() .run() .instrument(span.clone()) @@ -184,6 +190,8 @@ struct Forward { sessions: SessionMessageSender, subscriber_tx: SubscriberSender, with_tls: bool, + #[builder(default)] + agent_tunnel_handle: Option>, } #[derive(Debug, thiserror::Error)] @@ -207,6 +215,7 @@ where sessions, subscriber_tx, with_tls, + agent_tunnel_handle, } = self; match claims.jet_rec { @@ -224,6 +233,73 @@ where let span = tracing::Span::current(); + // Route via agent tunnel using the transparent routing pipeline: + // explicit agent_id → subnet match → domain suffix match → direct connect + if let Some(handle) = &agent_tunnel_handle { + use crate::agent_tunnel::routing::{self, RoutingDecision}; + + let first_target = targets.first(); + let target_host_for_routing = first_target.host().to_owned(); + + let decision = routing::resolve_route(handle.registry(), claims.jet_agent_id, &target_host_for_routing); + + match &decision { + RoutingDecision::ExplicitAgentNotFound(id) => { + error!(agent_id = %id, "Explicit agent not found in registry"); + return Err(ForwardError::BadGateway(anyhow::anyhow!( + "Agent {id} specified in token not found in registry. \ + Verify the agent is enrolled and connected." + ))); + } + RoutingDecision::Direct => { + info!(%target_host_for_routing, "No agent match, using direct connect"); + } + RoutingDecision::ViaAgent(_) => {} + } + + if let RoutingDecision::ViaAgent(candidates) = decision { + let target_str = format!("{}:{}", first_target.host(), first_target.port()); + + let (server_stream, _matched_agent) = + routing::route_and_connect(handle, &candidates, claims.jet_aid, &target_str) + .await + .map_err(ForwardError::BadGateway)?; + + let selected_target = first_target.clone(); + span.record("target", selected_target.to_string()); + + let info = SessionInfo::builder() + .id(claims.jet_aid) + .application_protocol(claims.jet_ap) + .details(ConnectionModeDetails::Fwd { + destination_host: selected_target, + }) + .time_to_live(claims.jet_ttl) + .recording_policy(claims.jet_rec) + .filtering_policy(claims.jet_flt) + .build(); + + let server_addr: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder"); + + return Proxy::builder() + .conf(conf) + .session_info(info) + .address_a(client_addr) + .transport_a(client_stream) + .address_b(server_addr) + .transport_b(server_stream) + .sessions(sessions) + .subscriber_tx(subscriber_tx) + .disconnect_interest(DisconnectInterest::from_reconnection_policy(claims.jet_reuse)) + .build() + .select_dissector_and_forward() + .await + .context("encountered a failure during agent tunnel traffic proxying") + .map_err(ForwardError::Internal); + } + // RoutingDecision::Direct falls through to direct connect below + } + trace!("Select and connect to target"); let ((server_stream, server_addr), selected_target) = utils::successive_try(&targets, utils::tcp_connect) diff --git a/devolutions-gateway/src/api/kdc_proxy.rs b/devolutions-gateway/src/api/kdc_proxy.rs index cf2d0243a..51673d8d8 100644 --- a/devolutions-gateway/src/api/kdc_proxy.rs +++ b/devolutions-gateway/src/api/kdc_proxy.rs @@ -25,6 +25,7 @@ async fn kdc_proxy( token_cache, jrl, recordings, + agent_tunnel_handle, .. }): State, extract::Path(token): extract::Path, @@ -105,7 +106,12 @@ async fn kdc_proxy( &claims.krb_kdc }; - let kdc_reply_message = send_krb_message(kdc_addr, &kdc_proxy_message.kerb_message.0.0).await?; + let kdc_reply_message = send_krb_message( + kdc_addr, + &kdc_proxy_message.kerb_message.0.0, + agent_tunnel_handle.as_deref(), + ) + .await?; let kdc_reply_message = KdcProxyMessage::from_raw_kerb_message(&kdc_reply_message) .map_err(HttpError::internal().with_msg("couldn't create KDC proxy reply").err())?; @@ -115,11 +121,11 @@ async fn kdc_proxy( kdc_reply_message.to_vec().map_err(HttpError::internal().err()) } -async fn read_kdc_reply_message(connection: &mut TcpStream) -> io::Result> { - let len = connection.read_u32().await?; +async fn read_kdc_reply_message(reader: &mut R) -> io::Result> { + let len = reader.read_u32().await?; let mut buf = vec![0; (len + 4).try_into().expect("u32-to-usize")]; buf[0..4].copy_from_slice(&(len.to_be_bytes())); - connection.read_exact(&mut buf[4..]).await?; + reader.read_exact(&mut buf[4..]).await?; Ok(buf) } @@ -148,7 +154,63 @@ fn unable_to_reach_kdc_server_err(error: io::Error) -> HttpError { } /// Sends the Kerberos message to the specified KDC address. -pub async fn send_krb_message(kdc_addr: &TargetAddr, message: &[u8]) -> Result, HttpError> { +/// +/// Uses the same routing pipeline as connection forwarding: +/// if an agent claims the KDC's domain/subnet, traffic goes through the tunnel. +/// Falls back to direct connect only when no agent matches. +pub async fn send_krb_message( + kdc_addr: &TargetAddr, + message: &[u8], + agent_tunnel_handle: Option<&crate::agent_tunnel::AgentTunnelHandle>, +) -> Result, HttpError> { + // Route through agent tunnel using the SAME pipeline as connection forwarding. + if let Some(handle) = agent_tunnel_handle { + use crate::agent_tunnel::routing::{self, RoutingDecision}; + + let kdc_host = kdc_addr.host(); + let kdc_target = kdc_addr.to_string(); + + let decision = routing::resolve_route(handle.registry(), None, kdc_host); + + match &decision { + RoutingDecision::ExplicitAgentNotFound(id) => { + error!(agent_id = %id, "Explicit agent for KDC not found"); + return Err( + HttpError::bad_gateway().build(format!("Agent {id} specified for KDC not found in registry.")) + ); + } + RoutingDecision::Direct => { + info!(kdc_host = %kdc_host, "No agent match for KDC, using direct connect"); + } + RoutingDecision::ViaAgent(_) => {} + } + + // Hard commit: if an agent matched, KDC traffic MUST go through it. + // No silent fallback — consistent with connection forwarding. + if let RoutingDecision::ViaAgent(candidates) = decision { + let session_id = uuid::Uuid::new_v4(); + + let (mut stream, _agent) = routing::route_and_connect(handle, &candidates, session_id, &kdc_target) + .await + .map_err(|e| { + HttpError::bad_gateway().build(format!("KDC routing through agent tunnel failed: {e:#}")) + })?; + + stream.write_all(message).await.map_err( + HttpError::bad_gateway() + .with_msg("unable to send KDC message through agent tunnel") + .err(), + )?; + + return read_kdc_reply_message(&mut stream).await.map_err( + HttpError::bad_gateway() + .with_msg("unable to read KDC reply through agent tunnel") + .err(), + ); + } + // RoutingDecision::Direct falls through to direct connect below + } + let protocol = kdc_addr.scheme(); debug!("Connecting to KDC server located at {kdc_addr} using protocol {protocol}..."); diff --git a/devolutions-gateway/src/api/rdp.rs b/devolutions-gateway/src/api/rdp.rs index 65cfe5b2e..de25b1bad 100644 --- a/devolutions-gateway/src/api/rdp.rs +++ b/devolutions-gateway/src/api/rdp.rs @@ -26,6 +26,7 @@ pub async fn handler( recordings, shutdown_signal, credential_store, + agent_tunnel_handle, .. }): State, ConnectInfo(source_addr): ConnectInfo, @@ -46,6 +47,7 @@ pub async fn handler( recordings.active_recordings, source_addr, credential_store, + agent_tunnel_handle, ) .instrument(span) }); @@ -65,6 +67,7 @@ async fn handle_socket( active_recordings: Arc, source_addr: SocketAddr, credential_store: crate::credential::CredentialStoreHandle, + agent_tunnel_handle: Option>, ) { let (stream, close_handle) = crate::ws::handle( ws, @@ -82,6 +85,7 @@ async fn handle_socket( subscriber_tx, &active_recordings, &credential_store, + agent_tunnel_handle, ) .await; diff --git a/devolutions-gateway/src/proxy.rs b/devolutions-gateway/src/proxy.rs index 0c2f09e6a..fb5d1b4c6 100644 --- a/devolutions-gateway/src/proxy.rs +++ b/devolutions-gateway/src/proxy.rs @@ -32,8 +32,8 @@ pub struct Proxy { impl Proxy where - A: AsyncWrite + AsyncRead + Unpin, - B: AsyncWrite + AsyncRead + Unpin, + A: AsyncWrite + AsyncRead + Unpin + Send, + B: AsyncWrite + AsyncRead + Unpin + Send, { pub async fn select_dissector_and_forward(self) -> anyhow::Result<()> { match self.session_info.application_protocol { diff --git a/devolutions-gateway/src/rd_clean_path.rs b/devolutions-gateway/src/rd_clean_path.rs index 9855a8987..b5a4223d1 100644 --- a/devolutions-gateway/src/rd_clean_path.rs +++ b/devolutions-gateway/src/rd_clean_path.rs @@ -158,25 +158,77 @@ enum CleanPathError { Io(#[from] io::Error), } -struct CleanPathResult { +/// Inner transport for the RDP server connection. +/// +/// An enum is required here because `Box` trait objects cause the compiler to +/// lose `Send` provability for the async future spawned by `ws.on_upgrade()` in the +/// RDP handler. Generics are also not viable — the type would propagate up through +/// `handle_with_credential_injection` → `handle` → `handle_socket` → `ws.on_upgrade()`, +/// which requires a concrete future type. The enum gives the compiler full type +/// information to prove `Send` while keeping the transport abstraction local. +enum ServerTransport { + Tcp(tokio::net::TcpStream), + Quic(crate::agent_tunnel::stream::TunnelStream), +} + +impl AsyncRead for ServerTransport { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_read(cx, buf), + Self::Quic(s) => std::pin::Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for ServerTransport { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_write(cx, buf), + Self::Quic(s) => std::pin::Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_flush(cx), + Self::Quic(s) => std::pin::Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_shutdown(cx), + Self::Quic(s) => std::pin::Pin::new(s).poll_shutdown(cx), + } + } +} + +struct CleanPathAuth { claims: AssociationTokenClaims, - destination: TargetAddr, - server_addr: SocketAddr, - server_stream: tokio_rustls::client::TlsStream, - x224_rsp: Vec, } -async fn process_cleanpath( - cleanpath_pdu: RDCleanPathPdu, +/// Validate the RDCleanPath PDU token and authorize the session. +/// Pure validation — no connections established. +async fn authorize_cleanpath( + cleanpath_pdu: &RDCleanPathPdu, client_addr: SocketAddr, conf: &Conf, token_cache: &TokenCache, jrl: &CurrentJrl, active_recordings: &ActiveRecordings, sessions: &SessionMessageSender, -) -> Result { - use crate::utils; - +) -> Result { let token = cleanpath_pdu .proxy_auth .as_deref() @@ -207,10 +259,9 @@ async fn process_cleanpath( }; let span = tracing::Span::current(); - span.record("session_id", claims.jet_aid.to_string()); - // Sanity check. + // Sanity check destination in PDU vs token. match cleanpath_pdu.destination.as_deref() { Some(destination) => match TargetAddr::parse(destination, 3389) { Ok(destination) if !destination.eq(targets.first()) => { @@ -224,14 +275,78 @@ async fn process_cleanpath( None => warn!("RDCleanPath PDU is missing the destination field"), } + Ok(CleanPathAuth { claims }) +} + +struct ConnectedRdpServer { + tls_stream: tokio_rustls::client::TlsStream, + server_addr: SocketAddr, + selected_target: TargetAddr, + x224_rsp: Vec, +} + +/// Establish a connection to the RDP server: route (agent/direct) → connect → X224 → TLS. +async fn connect_rdp_server( + claims: &AssociationTokenClaims, + cleanpath_pdu: RDCleanPathPdu, + agent_tunnel_handle: Option<&Arc>, +) -> Result { + use crate::utils; + + let crate::token::ConnectionMode::Fwd { ref targets, .. } = claims.jet_cm else { + return anyhow::Error::msg("unexpected connection mode") + .pipe(CleanPathError::BadRequest) + .pipe(Err); + }; + trace!(?targets, "Connecting to destination server"); - let ((mut server_stream, server_addr), selected_target) = utils::successive_try(targets, utils::tcp_connect) - .await - .context("connect to RDP server")?; + // Route through agent tunnel if available, otherwise connect directly. + let (mut server_stream, server_addr, selected_target): (ServerTransport, SocketAddr, &TargetAddr) = + if let Some(handle) = agent_tunnel_handle { + use crate::agent_tunnel::routing::{self, RoutingDecision}; + + let first_target = targets.first(); + let target_host = first_target.host(); + + let decision = routing::resolve_route(handle.registry(), claims.jet_agent_id, target_host); + + match decision { + RoutingDecision::ExplicitAgentNotFound(id) => { + return Err(CleanPathError::Internal(anyhow::anyhow!( + "Agent {id} specified in token not found in registry" + ))); + } + RoutingDecision::ViaAgent(candidates) => { + let target_str = format!("{}:{}", first_target.host(), first_target.port()); + info!(target = %target_str, "Routing RDP via agent tunnel"); + + let (quic_stream, _agent) = + routing::route_and_connect(handle, &candidates, claims.jet_aid, &target_str) + .await + .context("connect to RDP server via agent tunnel")?; + + // TODO: agent-routed sessions use a placeholder address; monitoring tools + // that rely on server_addr will see 0.0.0.0:0 for tunneled connections. + let placeholder_addr: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder"); + (ServerTransport::Quic(quic_stream), placeholder_addr, first_target) + } + RoutingDecision::Direct => { + let ((stream, addr), target) = utils::successive_try(targets, utils::tcp_connect) + .await + .context("connect to RDP server")?; + (ServerTransport::Tcp(stream), addr, target) + } + } + } else { + let ((stream, addr), target) = utils::successive_try(targets, utils::tcp_connect) + .await + .context("connect to RDP server")?; + (ServerTransport::Tcp(stream), addr, target) + }; debug!(%selected_target, "Connected to destination server"); - span.record("target", selected_target.to_string()); + tracing::Span::current().record("target", selected_target.to_string()); // Send preconnection blob if applicable. if let Some(pcb) = cleanpath_pdu.preconnection_blob { @@ -245,8 +360,6 @@ async fn process_cleanpath( .map_err(CleanPathError::BadRequest)?; server_stream.write_all(x224_req.as_bytes()).await?; - // == Receive server X224 connection response == - trace!("Receiving X224 response"); let x224_rsp = read_x224_response(&mut server_stream) @@ -256,20 +369,17 @@ async fn process_cleanpath( trace!("Establishing TLS connection with server"); - // == Establish TLS connection with server == - - let server_stream = crate::tls::dangerous_connect(selected_target.host().to_owned(), server_stream) + let tls_stream = crate::tls::dangerous_connect(selected_target.host().to_owned(), server_stream) .await .map_err(|source| CleanPathError::TlsHandshake { source, target_server: selected_target.to_owned(), })?; - Ok(CleanPathResult { - destination: selected_target.to_owned(), - claims, + Ok(ConnectedRdpServer { + tls_stream, server_addr, - server_stream, + selected_target: selected_target.to_owned(), x224_rsp, }) } @@ -287,6 +397,7 @@ async fn handle_with_credential_injection( active_recordings: &ActiveRecordings, cleanpath_pdu: RDCleanPathPdu, credential_entry: Arc, + agent_tunnel_handle: Option>, ) -> anyhow::Result<()> { let tls_conf = conf.credssp_tls.get().context("CredSSP TLS configuration")?; @@ -318,16 +429,9 @@ async fn handle_with_credential_injection( ) }; - // Run normal RDCleanPath flow (this will handle server-side TLS and get certs). - let CleanPathResult { - claims, - destination, - server_addr, - server_stream, - x224_rsp, - .. - } = process_cleanpath( - cleanpath_pdu, + // Authorize and connect to the RDP server. + let CleanPathAuth { claims } = authorize_cleanpath( + &cleanpath_pdu, client_addr, &conf, token_cache, @@ -336,7 +440,16 @@ async fn handle_with_credential_injection( &sessions, ) .await - .context("RDCleanPath processing failed")?; + .context("RDCleanPath authorization failed")?; + + let ConnectedRdpServer { + tls_stream: server_stream, + server_addr, + selected_target: destination, + x224_rsp, + } = connect_rdp_server(&claims, cleanpath_pdu, agent_tunnel_handle.as_ref()) + .await + .context("RDCleanPath connection failed")?; // Retrieve the Gateway TLS public key that must be used for client-proxy CredSSP later on. let gateway_cert_chain_handle = tokio::spawn(crate::tls::get_cert_chain_for_acceptor_cached( @@ -532,6 +645,7 @@ pub async fn handle( subscriber_tx: SubscriberSender, active_recordings: &ActiveRecordings, credential_store: &CredentialStoreHandle, + agent_tunnel_handle: Option>, ) -> anyhow::Result<()> { // Special handshake of our RDP extension @@ -569,27 +683,29 @@ pub async fn handle( active_recordings, cleanpath_pdu, entry, + agent_tunnel_handle.clone(), ) .await; } trace!("Processing RDCleanPath"); - let CleanPathResult { - claims, - destination, - server_addr, - server_stream, - x224_rsp, - } = match process_cleanpath( - cleanpath_pdu, - client_addr, - &conf, - token_cache, - jrl, - active_recordings, - &sessions, - ) + let (auth, connected) = match async { + let auth = authorize_cleanpath( + &cleanpath_pdu, + client_addr, + &conf, + token_cache, + jrl, + active_recordings, + &sessions, + ) + .await?; + + let connected = connect_rdp_server(&auth.claims, cleanpath_pdu, agent_tunnel_handle.as_ref()).await?; + + Ok::<_, CleanPathError>((auth, connected)) + } .await { Ok(result) => result, @@ -602,6 +718,13 @@ pub async fn handle( } }; + let ConnectedRdpServer { + tls_stream: server_stream, + server_addr, + selected_target: destination, + x224_rsp, + } = connected; + // == Send success RDCleanPathPdu response == let x509_chain = server_stream @@ -622,13 +745,13 @@ pub async fn handle( // == Start actual RDP session == let info = SessionInfo::builder() - .id(claims.jet_aid) - .application_protocol(claims.jet_ap) + .id(auth.claims.jet_aid) + .application_protocol(auth.claims.jet_ap) .details(ConnectionModeDetails::Fwd { destination_host: destination.clone(), }) - .time_to_live(claims.jet_ttl) - .recording_policy(claims.jet_rec) + .time_to_live(auth.claims.jet_ttl) + .recording_policy(auth.claims.jet_rec) .build(); info!("RDP-TLS forwarding (RDCleanPath)"); @@ -642,7 +765,7 @@ pub async fn handle( .transport_b(server_stream) .sessions(sessions) .subscriber_tx(subscriber_tx) - .disconnect_interest(DisconnectInterest::from_reconnection_policy(claims.jet_reuse)) + .disconnect_interest(DisconnectInterest::from_reconnection_policy(auth.claims.jet_reuse)) .build() .select_dissector_and_forward() .await diff --git a/devolutions-gateway/src/rdp_proxy.rs b/devolutions-gateway/src/rdp_proxy.rs index b3dc466a7..b7fddfdd5 100644 --- a/devolutions-gateway/src/rdp_proxy.rs +++ b/devolutions-gateway/src/rdp_proxy.rs @@ -637,7 +637,7 @@ where async fn send_network_request(request: &NetworkRequest) -> anyhow::Result> { let target_addr = TargetAddr::parse(request.url.as_str(), Some(88))?; - send_krb_message(&target_addr, &request.data) + send_krb_message(&target_addr, &request.data, None) .await .map_err(|err| anyhow::Error::msg("failed to send KDC message").context(err)) }