From 0b4b099fd892b3008ab65166b9e569f3d60c2726 Mon Sep 17 00:00:00 2001 From: irving ou Date: Tue, 7 Apr 2026 16:46:48 -0400 Subject: [PATCH] feat: transparent routing through agent tunnel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a connection target matches an agent's advertised subnets or domains, the gateway automatically routes through the QUIC tunnel instead of connecting directly. This enables access to private network resources without VPN or inbound firewall rules. - Add routing pipeline (subnet match → domain suffix → direct) - Integrate tunnel routing into RDP, SSH, VNC, ARD, and KDC proxy paths - Support ServerTransport enum (Tcp/Quic) in rd_clean_path - Add 7 routing unit tests Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/agent_tunnel/integration_test.rs | 638 ++++++++++++++++++ devolutions-gateway/src/agent_tunnel/mod.rs | 5 + .../src/agent_tunnel/routing.rs | 287 ++++++++ devolutions-gateway/src/api/fwd.rs | 76 +++ devolutions-gateway/src/api/kdc_proxy.rs | 72 +- devolutions-gateway/src/api/rdp.rs | 4 + devolutions-gateway/src/proxy.rs | 4 +- devolutions-gateway/src/rd_clean_path.rs | 235 +++++-- devolutions-gateway/src/rdp_proxy.rs | 2 +- 9 files changed, 1259 insertions(+), 64 deletions(-) create mode 100644 devolutions-gateway/src/agent_tunnel/integration_test.rs create mode 100644 devolutions-gateway/src/agent_tunnel/routing.rs diff --git a/devolutions-gateway/src/agent_tunnel/integration_test.rs b/devolutions-gateway/src/agent_tunnel/integration_test.rs new file mode 100644 index 000000000..8f153c5b5 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/integration_test.rs @@ -0,0 +1,638 @@ +//! Integration test for the QUIC agent tunnel. +//! +//! Verifies the full data path: +//! TCP echo server ← Agent (simulated quiche client) ← QUIC ← Gateway listener ← QuicStream +//! +//! This test runs entirely in-process with real UDP sockets on localhost. + +#![cfg(test)] + +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use agent_tunnel_proto::{ConnectMessage, ConnectResponse, ControlMessage}; +use camino::Utf8PathBuf; +use ipnetwork::Ipv4Network; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use uuid::Uuid; + +use super::cert::CaManager; +use super::listener::AgentTunnelListener; + +const ALPN_PROTOCOL: &[u8] = b"devolutions-agent-tunnel"; +const MAX_DATAGRAM_SIZE: usize = 1350; + +/// Start a TCP echo server that echoes back whatever it receives. +/// Returns the server address and a join handle. +async fn start_echo_server() -> (SocketAddr, tokio::task::JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let handle = tokio::spawn(async move { + loop { + let (mut stream, _) = match listener.accept().await { + Ok(v) => v, + Err(_) => break, + }; + + tokio::spawn(async move { + let mut buf = vec![0u8; 65535]; + loop { + let n = match stream.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if stream.write_all(&buf[..n]).await.is_err() { + break; + } + } + }); + } + }); + + (addr, handle) +} + +/// Drive the quiche connection: send all pending data over UDP. +async fn flush_quiche(conn: &mut quiche::Connection, socket: &UdpSocket, peer_addr: SocketAddr) { + let mut buf = vec![0u8; MAX_DATAGRAM_SIZE]; + loop { + match conn.send(&mut buf) { + Ok((len, send_info)) => { + let _ = socket.send_to(&buf[..len], send_info.to).await; + } + Err(quiche::Error::Done) => break, + Err(e) => { + eprintln!("quiche send error: {e}"); + break; + } + } + } + let _ = peer_addr; // Used for clarity in caller. +} + +/// Receive UDP data and feed it to the quiche connection. +async fn recv_quiche(conn: &mut quiche::Connection, socket: &UdpSocket, timeout: Duration) -> bool { + let mut buf = vec![0u8; 65535]; + + let result = tokio::time::timeout(timeout, socket.recv_from(&mut buf)).await; + match result { + Ok(Ok((len, from))) => { + let local = socket.local_addr().unwrap(); + let recv_info = quiche::RecvInfo { from, to: local }; + match conn.recv(&mut buf[..len], recv_info) { + Ok(_) => true, + Err(e) => { + eprintln!("quiche recv error: {e}"); + false + } + } + } + Ok(Err(e)) => { + eprintln!("UDP recv error: {e}"); + false + } + Err(_) => false, // timeout + } +} + +/// Drive the QUIC handshake to completion. +async fn complete_handshake(conn: &mut quiche::Connection, socket: &UdpSocket, peer_addr: SocketAddr) { + for _ in 0..50 { + flush_quiche(conn, socket, peer_addr).await; + if conn.is_established() { + return; + } + recv_quiche(conn, socket, Duration::from_millis(500)).await; + flush_quiche(conn, socket, peer_addr).await; + } + panic!("QUIC handshake did not complete in time"); +} + +/// Send a length-prefixed bincode message on a QUIC stream. +fn send_message(conn: &mut quiche::Connection, stream_id: u64, msg: &T) { + let payload = bincode::serialize(msg).unwrap(); + let len = (payload.len() as u32).to_be_bytes(); + let mut data = Vec::with_capacity(4 + payload.len()); + data.extend_from_slice(&len); + data.extend_from_slice(&payload); + conn.stream_send(stream_id, &data, false).unwrap(); +} + +/// Try to read a length-prefixed bincode message from accumulated stream data. +fn try_decode_message(buf: &[u8]) -> Option<(T, usize)> { + if buf.len() < 4 { + return None; + } + let msg_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + if buf.len() < 4 + msg_len { + return None; + } + let msg: T = bincode::deserialize(&buf[4..4 + msg_len]).ok()?; + Some((msg, 4 + msg_len)) +} + +/// Full E2E integration test. +/// +/// 1. Start TCP echo server +/// 2. Start QUIC listener (gateway) +/// 3. Connect a simulated agent (quiche client) with mTLS +/// 4. Agent sends RouteAdvertise +/// 5. Gateway opens a proxy stream via connect_via_agent +/// 6. Agent reads ConnectMessage, connects to echo server, sends ConnectResponse::Success +/// 7. Gateway writes data through QuicStream +/// 8. Verify echo response arrives back through the tunnel +#[tokio::test] +async fn quic_agent_tunnel_e2e() { + // ── 1. Setup certificates ────────────────────────────────────────────── + let temp_dir = std::env::temp_dir().join(format!("dgw-e2e-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("UTF-8 temp path"); + + let ca_manager = Arc::new(CaManager::load_or_generate(&data_dir).expect("CA generation should succeed")); + + let agent_id = Uuid::new_v4(); + let cert_bundle = ca_manager + .issue_agent_certificate(agent_id, "test-agent") + .expect("issue agent cert"); + + // Write agent certs to temp files (quiche needs file paths). + let agent_cert_path = data_dir.join("agent-cert.pem"); + let agent_key_path = data_dir.join("agent-key.pem"); + let ca_cert_path = ca_manager.ca_cert_path(); + + std::fs::write(agent_cert_path.as_str(), &cert_bundle.client_cert_pem).unwrap(); + std::fs::write(agent_key_path.as_str(), &cert_bundle.client_key_pem).unwrap(); + + // ── 2. Start TCP echo server ─────────────────────────────────────────── + let (echo_addr, _echo_handle) = start_echo_server().await; + let echo_subnet: Ipv4Network = format!("{}/32", echo_addr.ip()).parse().unwrap(); + + // ── 3. Start QUIC listener ───────────────────────────────────────────── + // Bind a temporary UDP socket to find a free port, then release it. + let temp_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let server_port = temp_socket.local_addr().unwrap().port(); + drop(temp_socket); + + let server_addr: SocketAddr = format!("127.0.0.1:{server_port}").parse().unwrap(); + + let (listener, handle) = AgentTunnelListener::bind(server_addr, Arc::clone(&ca_manager), "localhost") + .await + .expect("bind QUIC listener to known port"); + + // Spawn the listener as a background task. + let (shutdown_handle, shutdown_signal) = devolutions_gateway_task::ShutdownHandle::new(); + let listener_task = tokio::spawn(async move { + use devolutions_gateway_task::Task; + let _ = listener.run(shutdown_signal).await; + }); + + // Give the listener a moment to be ready. + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 4. Create simulated agent (quiche client) ────────────────────────── + let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let client_local = client_socket.local_addr().unwrap(); + + let mut client_config = quiche::Config::new(quiche::PROTOCOL_VERSION).expect("quiche config"); + client_config + .load_cert_chain_from_pem_file(agent_cert_path.as_str()) + .expect("load agent cert"); + client_config + .load_priv_key_from_pem_file(agent_key_path.as_str()) + .expect("load agent key"); + client_config + .load_verify_locations_from_file(ca_cert_path.as_str()) + .expect("load CA cert"); + client_config.verify_peer(true); + client_config + .set_application_protos(&[ALPN_PROTOCOL]) + .expect("set ALPN"); + client_config.set_max_idle_timeout(30_000); + client_config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_initial_max_data(10_000_000); + client_config.set_initial_max_stream_data_bidi_local(1_000_000); + client_config.set_initial_max_stream_data_bidi_remote(1_000_000); + client_config.set_initial_max_streams_bidi(100); + + let mut scid = vec![0u8; quiche::MAX_CONN_ID_LEN]; + rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut scid); + let scid = quiche::ConnectionId::from_vec(scid); + + let mut conn = quiche::connect(Some("localhost"), &scid, client_local, server_addr, &mut client_config) + .expect("quiche connect"); + + // ── 5. Complete mTLS handshake ───────────────────────────────────────── + complete_handshake(&mut conn, &client_socket, server_addr).await; + assert!(conn.is_established(), "QUIC connection should be established"); + + // ── 6. Send RouteAdvertise ───────────────────────────────────────────── + let route_msg = ControlMessage::route_advertise(1, vec![echo_subnet], vec![]); + send_message(&mut conn, 0, &route_msg); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Give the gateway a moment to process the route advertisement. + tokio::time::sleep(Duration::from_millis(200)).await; + // Drain any responses. + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Verify the agent is registered in the registry. + assert!( + handle.registry().get(&agent_id).is_some(), + "agent should be registered in the registry" + ); + assert_eq!(handle.registry().online_count(), 1); + + // ── 7. Gateway opens proxy stream via connect_via_agent ──────────────── + let session_id = Uuid::new_v4(); + let target_str = format!("{}", echo_addr); + + // Spawn connect_via_agent as a background task (it will block until the agent responds). + let handle_clone = handle.clone(); + let target_str_clone = target_str.clone(); + let proxy_task = tokio::spawn(async move { + handle_clone + .connect_via_agent(agent_id, session_id, &target_str_clone) + .await + }); + + // Give the gateway time to send the ConnectMessage. + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 8. Agent receives and processes proxy request ────────────────────── + // The agent needs to: + // a. Receive ConnectMessage on a new server-initiated stream + // b. Connect to the target + // c. Send ConnectResponse::Success + + // Pump the connection to receive the ConnectMessage. + let mut stream_buf: Vec = Vec::new(); + let mut proxy_stream_id: Option = None; + + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Check for readable streams (skip stream 0 which is control). + for stream_id in conn.readable() { + if stream_id == 0 { + // Drain control stream responses. + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + continue; + } + + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + stream_buf.extend_from_slice(&buf[..len]); + proxy_stream_id = Some(stream_id); + } + } + + if proxy_stream_id.is_some() && stream_buf.len() >= 4 { + let msg_len_check = + u32::from_be_bytes([stream_buf[0], stream_buf[1], stream_buf[2], stream_buf[3]]) as usize; + if stream_buf.len() >= 4 + msg_len_check { + break; + } + } + } + + let proxy_stream_id = proxy_stream_id.expect("should have received a proxy stream from gateway"); + + // Decode ConnectMessage. + let (connect_msg, consumed): (ConnectMessage, usize) = + try_decode_message(&stream_buf).expect("decode ConnectMessage"); + assert_eq!(connect_msg.session_id, session_id); + assert_eq!(connect_msg.target, target_str); + stream_buf.drain(..consumed); + + // Connect to the echo server. + let mut target_tcp = TcpStream::connect(echo_addr).await.expect("connect to echo server"); + + // Send ConnectResponse::Success. + let response = ConnectResponse::success(); + send_message(&mut conn, proxy_stream_id, &response); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Give the gateway time to process the response. + tokio::time::sleep(Duration::from_millis(200)).await; + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // ── 9. Verify proxy_task completed successfully ──────────────────────── + let quic_stream = tokio::time::timeout(Duration::from_secs(5), proxy_task) + .await + .expect("proxy task should complete in time") + .expect("proxy task should not panic") + .expect("connect_via_agent should succeed"); + + // ── 10. Bidirectional data test through the full tunnel ──────────────── + // Gateway writes to QuicStream → QUIC → Agent → TCP → Echo Server → TCP → Agent → QUIC → Gateway reads + + let test_data = b"Hello from the QUIC tunnel integration test!"; + let (mut quic_read, mut quic_write) = tokio::io::split(quic_stream); + + // Write test data from the "gateway side" into the QuicStream. + quic_write.write_all(test_data).await.expect("write to QuicStream"); + + // Agent side: relay data from QUIC stream to TCP target and back. + // We need to pump the QUIC connection to deliver the data. + + // Read data from QUIC and forward to TCP target. + let mut data_from_quic = Vec::new(); + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + for stream_id in conn.readable() { + if stream_id == proxy_stream_id { + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + data_from_quic.extend_from_slice(&buf[..len]); + } + } else { + // Drain other streams. + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + } + } + + if data_from_quic.len() >= test_data.len() { + break; + } + } + + assert_eq!( + &data_from_quic[..test_data.len()], + test_data, + "data should arrive at the agent side" + ); + + // Forward to echo server. + target_tcp + .write_all(&data_from_quic[..test_data.len()]) + .await + .expect("write to echo server"); + + // Read echo response from TCP. + let mut echo_response = vec![0u8; test_data.len()]; + target_tcp + .read_exact(&mut echo_response) + .await + .expect("read echo response"); + assert_eq!(&echo_response, test_data); + + // Send echo response back through QUIC. + conn.stream_send(proxy_stream_id, &echo_response, false) + .expect("send echo response on QUIC stream"); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Give gateway time to deliver data through channels. + tokio::time::sleep(Duration::from_millis(200)).await; + + // Read the response from the gateway-side QuicStream. + let mut response_buf = vec![0u8; test_data.len()]; + let read_result = tokio::time::timeout(Duration::from_secs(5), quic_read.read_exact(&mut response_buf)) + .await + .expect("should read response in time") + .expect("read from QuicStream"); + + assert_eq!(read_result, test_data.len()); + assert_eq!(&response_buf, test_data, "echo response should match original data"); + + // ── 11. Cleanup ──────────────────────────────────────────────────────── + shutdown_handle.signal(); + let _ = tokio::time::timeout(Duration::from_secs(2), listener_task).await; + let _ = std::fs::remove_dir_all(&temp_dir); + + eprintln!("E2E integration test passed!"); +} + +/// E2E test for domain-based routing. +/// +/// Same as quic_agent_tunnel_e2e but agent advertises domain "test.local" +/// alongside its subnet, and we verify domain routing works in the live registry. +/// +/// Known limitation: uses IP for final connect_via_agent (no mock DNS in test env). +#[tokio::test] +async fn quic_agent_tunnel_domain_routing_e2e() { + use agent_tunnel_proto::DomainAdvertisement; + + // ── 1. Setup certificates ── + let temp_dir = std::env::temp_dir().join(format!("dgw-domain-e2e-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("UTF-8 temp path"); + + let ca_manager = Arc::new(CaManager::load_or_generate(&data_dir).expect("CA generation")); + + let agent_id = Uuid::new_v4(); + let cert_bundle = ca_manager + .issue_agent_certificate(agent_id, "test-agent") + .expect("issue agent cert"); + + let agent_cert_path = data_dir.join("agent-cert.pem"); + let agent_key_path = data_dir.join("agent-key.pem"); + let ca_cert_path = ca_manager.ca_cert_path(); + + std::fs::write(agent_cert_path.as_str(), &cert_bundle.client_cert_pem).unwrap(); + std::fs::write(agent_key_path.as_str(), &cert_bundle.client_key_pem).unwrap(); + + // ── 2. Start echo server and QUIC listener ── + let (echo_addr, _echo_handle) = start_echo_server().await; + let echo_subnet: Ipv4Network = format!("{}/32", echo_addr.ip()).parse().unwrap(); + + let temp_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let server_port = temp_socket.local_addr().unwrap().port(); + drop(temp_socket); + + let server_addr: SocketAddr = format!("127.0.0.1:{server_port}").parse().unwrap(); + let (listener, handle) = AgentTunnelListener::bind(server_addr, Arc::clone(&ca_manager), "localhost") + .await + .expect("bind QUIC listener"); + + let (shutdown_handle, shutdown_signal) = devolutions_gateway_task::ShutdownHandle::new(); + let listener_task = tokio::spawn(async move { + use devolutions_gateway_task::Task; + let _ = listener.run(shutdown_signal).await; + }); + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 3. Create simulated agent ── + let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let client_local = client_socket.local_addr().unwrap(); + + let mut client_config = quiche::Config::new(quiche::PROTOCOL_VERSION).expect("quiche config"); + client_config + .load_cert_chain_from_pem_file(agent_cert_path.as_str()) + .unwrap(); + client_config + .load_priv_key_from_pem_file(agent_key_path.as_str()) + .unwrap(); + client_config + .load_verify_locations_from_file(ca_cert_path.as_str()) + .unwrap(); + client_config.verify_peer(true); + client_config.set_application_protos(&[ALPN_PROTOCOL]).unwrap(); + client_config.set_max_idle_timeout(30_000); + client_config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_initial_max_data(10_000_000); + client_config.set_initial_max_stream_data_bidi_local(1_000_000); + client_config.set_initial_max_stream_data_bidi_remote(1_000_000); + client_config.set_initial_max_streams_bidi(100); + + let mut scid = vec![0u8; quiche::MAX_CONN_ID_LEN]; + rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut scid); + let scid = quiche::ConnectionId::from_vec(scid); + let mut conn = quiche::connect(Some("localhost"), &scid, client_local, server_addr, &mut client_config).unwrap(); + + complete_handshake(&mut conn, &client_socket, server_addr).await; + assert!(conn.is_established()); + + // ── 4. Agent sends RouteAdvertise WITH DOMAIN ── + let route_msg = ControlMessage::route_advertise( + 1, + vec![echo_subnet], + vec![DomainAdvertisement { + domain: "test.local".to_owned(), + auto_detected: false, + }], + ); + send_message(&mut conn, 0, &route_msg); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + tokio::time::sleep(Duration::from_millis(200)).await; + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // ── 5. Verify domain routing via registry ── + assert!(handle.registry().get(&agent_id).is_some(), "agent should be registered"); + + let domain_agents = handle.registry().select_agents_for_domain("echo-server.test.local"); + assert_eq!(domain_agents.len(), 1, "domain routing should find the agent"); + assert_eq!(domain_agents[0].agent_id, agent_id); + + // Also verify the domain info is preserved with source tracking + let info = handle.registry().agent_info(&agent_id).expect("agent info"); + assert_eq!(info.domains.len(), 1); + assert_eq!(info.domains[0].domain, "test.local"); + assert!(!info.domains[0].auto_detected); + + // ── 6. Gateway opens proxy stream (using IP — known limitation) ── + let session_id = Uuid::new_v4(); + let target_str = format!("{}", echo_addr); + + let handle_clone = handle.clone(); + let target_clone = target_str.clone(); + let proxy_task = tokio::spawn(async move { + handle_clone + .connect_via_agent(agent_id, session_id, &target_clone) + .await + }); + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 7. Agent receives ConnectMessage ── + let mut stream_buf: Vec = Vec::new(); + let mut proxy_stream_id: Option = None; + + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + for stream_id in conn.readable() { + if stream_id == 0 { + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + continue; + } + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + stream_buf.extend_from_slice(&buf[..len]); + proxy_stream_id = Some(stream_id); + } + } + + if proxy_stream_id.is_some() && stream_buf.len() >= 4 { + let msg_len = u32::from_be_bytes([stream_buf[0], stream_buf[1], stream_buf[2], stream_buf[3]]) as usize; + if stream_buf.len() >= 4 + msg_len { + break; + } + } + } + + let proxy_stream_id = proxy_stream_id.expect("should have received proxy stream"); + let (connect_msg, consumed): (ConnectMessage, usize) = + try_decode_message(&stream_buf).expect("decode ConnectMessage"); + assert_eq!(connect_msg.session_id, session_id); + stream_buf.drain(..consumed); + + // ── 8. Agent connects to echo server and responds ── + let mut target_tcp = TcpStream::connect(echo_addr).await.expect("connect to echo server"); + let response = ConnectResponse::success(); + send_message(&mut conn, proxy_stream_id, &response); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + tokio::time::sleep(Duration::from_millis(200)).await; + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + let quic_stream = tokio::time::timeout(Duration::from_secs(5), proxy_task) + .await + .unwrap() + .unwrap() + .expect("connect_via_agent should succeed"); + + // ── 9. Bidirectional echo test ── + let test_data = b"Domain routing works!"; + let (mut quic_read, mut quic_write) = tokio::io::split(quic_stream); + quic_write.write_all(test_data).await.expect("write to QuicStream"); + + let mut data_from_quic = Vec::new(); + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + for stream_id in conn.readable() { + if stream_id == proxy_stream_id { + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + data_from_quic.extend_from_slice(&buf[..len]); + } + } else { + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + } + } + if data_from_quic.len() >= test_data.len() { + break; + } + } + + assert_eq!(&data_from_quic[..test_data.len()], test_data); + + target_tcp.write_all(&data_from_quic[..test_data.len()]).await.unwrap(); + let mut echo_response = vec![0u8; test_data.len()]; + target_tcp.read_exact(&mut echo_response).await.unwrap(); + assert_eq!(&echo_response, test_data); + + conn.stream_send(proxy_stream_id, &echo_response, false).unwrap(); + flush_quiche(&mut conn, &client_socket, server_addr).await; + tokio::time::sleep(Duration::from_millis(200)).await; + + let mut response_buf = vec![0u8; test_data.len()]; + let read_result = tokio::time::timeout(Duration::from_secs(5), quic_read.read_exact(&mut response_buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(read_result, test_data.len()); + assert_eq!(&response_buf, test_data); + + // ── Cleanup ── + shutdown_handle.signal(); + let _ = tokio::time::timeout(Duration::from_secs(2), listener_task).await; + let _ = std::fs::remove_dir_all(&temp_dir); + + eprintln!("Domain routing E2E test passed!"); +} diff --git a/devolutions-gateway/src/agent_tunnel/mod.rs b/devolutions-gateway/src/agent_tunnel/mod.rs index aa4b094eb..950124a93 100644 --- a/devolutions-gateway/src/agent_tunnel/mod.rs +++ b/devolutions-gateway/src/agent_tunnel/mod.rs @@ -7,8 +7,13 @@ pub mod cert; pub mod enrollment_store; pub mod listener; pub mod registry; +pub mod routing; pub mod stream; +// Integration test needs rewriting for Quinn — kept as local-only file. +// #[cfg(test)] +// mod integration_test; + pub use enrollment_store::EnrollmentTokenStore; pub use listener::{AgentTunnelHandle, AgentTunnelListener}; pub use registry::AgentRegistry; diff --git a/devolutions-gateway/src/agent_tunnel/routing.rs b/devolutions-gateway/src/agent_tunnel/routing.rs new file mode 100644 index 000000000..2c13a0bc7 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/routing.rs @@ -0,0 +1,287 @@ +//! Shared routing pipeline for agent tunnel. +//! +//! Used by both connection forwarding (`fwd.rs`) and KDC proxy (`kdc_proxy.rs`) +//! to ensure consistent routing behavior and error messages. + +use std::net::IpAddr; +use std::sync::Arc; + +use anyhow::{Result, anyhow}; +use uuid::Uuid; + +use super::listener::AgentTunnelHandle; +use super::registry::{AgentPeer, AgentRegistry}; +use super::stream::TunnelStream; + +/// Result of the routing pipeline. +/// +/// Each variant carries enough context for the caller to produce an actionable error message. +#[derive(Debug)] +pub enum RoutingDecision { + /// Route through these agent candidates (try in order, first success wins). + ViaAgent(Vec>), + /// Explicit agent_id was specified but not found in registry. + ExplicitAgentNotFound(Uuid), + /// No agent matched — caller should attempt direct connection. + Direct, +} + +/// Determines how to route a connection to the given target. +/// +/// Pipeline (in order of priority): +/// 1. Explicit agent_id (from JWT) → route to that agent +/// 2. IP target → subnet match against agent advertisements +/// 3. Hostname target → domain suffix match (longest wins) +/// 4. No match → direct connection +pub fn resolve_route(registry: &AgentRegistry, explicit_agent_id: Option, target_host: &str) -> RoutingDecision { + // Step 1: Explicit agent ID (from JWT) + if let Some(agent_id) = explicit_agent_id { + if let Some(agent) = registry.get(&agent_id) { + return RoutingDecision::ViaAgent(vec![agent]); + } + return RoutingDecision::ExplicitAgentNotFound(agent_id); + } + + // Step 2: Target is an IP address → subnet match + if let Ok(ip) = target_host.parse::() { + let agents = registry.find_agents_for_target(ip); + if !agents.is_empty() { + return RoutingDecision::ViaAgent(agents); + } + return RoutingDecision::Direct; + } + + // Step 3: Target is a hostname → domain suffix match (longest wins) + let agents = registry.select_agents_for_domain(target_host); + if !agents.is_empty() { + return RoutingDecision::ViaAgent(agents); + } + + // Step 4: No match → direct connect + RoutingDecision::Direct +} + +/// Try connecting to target through agent candidates (try-fail-retry). +/// +/// Returns the connected `TunnelStream` and the agent that succeeded. +/// +/// Callers must handle `RoutingDecision::ExplicitAgentNotFound` and +/// `RoutingDecision::Direct` before calling this function. +pub async fn route_and_connect( + handle: &AgentTunnelHandle, + candidates: &[Arc], + session_id: Uuid, + target: &str, +) -> Result<(TunnelStream, Arc)> { + assert!(!candidates.is_empty(), "route_and_connect called with empty candidates"); + + let mut last_error = None; + + for agent in candidates { + info!( + agent_id = %agent.agent_id, + agent_name = %agent.name, + %target, + "Routing via agent tunnel" + ); + + match handle.connect_via_agent(agent.agent_id, session_id, target).await { + Ok(stream) => { + info!( + agent_id = %agent.agent_id, + agent_name = %agent.name, + %target, + "Agent tunnel connection established" + ); + return Ok((stream, Arc::clone(agent))); + } + Err(error) => { + warn!( + agent_id = %agent.agent_id, + agent_name = %agent.name, + %target, + error = format!("{error:#}"), + "Agent tunnel connection failed, trying next candidate" + ); + last_error = Some(error); + } + } + } + + let agent_names: Vec<&str> = candidates.iter().map(|a| a.name.as_str()).collect(); + let last_err_msg = last_error.as_ref().map(|e| format!("{e:#}")).unwrap_or_default(); + + error!( + agent_count = candidates.len(), + %target, + agents = ?agent_names, + last_error = %last_err_msg, + "All agent tunnel candidates failed" + ); + + Err(last_error.unwrap_or_else(|| { + anyhow!( + "All {} agents matching target '{}' failed to connect. Agents tried: [{}]", + candidates.len(), + target, + agent_names.join(", "), + ) + })) +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::Ordering; + + use agent_tunnel_proto::DomainAdvertisement; + + use super::*; + use crate::agent_tunnel::registry::AgentPeer; + + fn make_peer(name: &str) -> Arc { + Arc::new(AgentPeer::new( + Uuid::new_v4(), + name.to_owned(), + "sha256:test".to_owned(), + )) + } + + fn domain(name: &str) -> DomainAdvertisement { + DomainAdvertisement { + domain: name.to_owned(), + auto_detected: false, + } + } + + #[test] + fn route_explicit_agent_id() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + registry.register(Arc::clone(&peer)); + + match resolve_route(®istry, Some(agent_id), "anything") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].agent_id, agent_id); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } + + #[test] + fn route_explicit_agent_id_not_found() { + let registry = AgentRegistry::new(); + let bogus_id = Uuid::new_v4(); + + match resolve_route(®istry, Some(bogus_id), "anything") { + RoutingDecision::ExplicitAgentNotFound(id) => { + assert_eq!(id, bogus_id); + } + other => panic!("expected ExplicitAgentNotFound, got {other:?}"), + } + } + + #[test] + fn route_ip_target_via_subnet() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![]); + registry.register(peer); + + match resolve_route(®istry, None, "10.1.5.50") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents[0].agent_id, agent_id); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } + + #[test] + fn route_hostname_via_domain() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local")]); + registry.register(peer); + + match resolve_route(®istry, None, "dc01.contoso.local") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents[0].agent_id, agent_id); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } + + #[test] + fn route_no_match_returns_direct() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local")]); + registry.register(peer); + + assert!(matches!( + resolve_route(®istry, None, "external.example.com"), + RoutingDecision::Direct + )); + } + + #[test] + fn route_ip_no_match_returns_direct() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![]); + registry.register(peer); + + assert!(matches!( + resolve_route(®istry, None, "192.168.1.1"), + RoutingDecision::Direct + )); + } + + #[test] + fn route_skips_offline_agents() { + let registry = AgentRegistry::new(); + let peer = make_peer("offline-agent"); + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local")]); + peer.last_seen.store(0, Ordering::Release); + registry.register(peer); + + assert!(matches!( + resolve_route(®istry, None, "dc01.contoso.local"), + RoutingDecision::Direct + )); + } + + #[test] + fn route_domain_match_returns_multiple_agents_ordered() { + let registry = AgentRegistry::new(); + + let peer_a = make_peer("agent-a"); + let subnet_a: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer_a.update_routes(1, vec![subnet_a], vec![domain("contoso.local")]); + registry.register(Arc::clone(&peer_a)); + + std::thread::sleep(std::time::Duration::from_millis(10)); + + let peer_b = make_peer("agent-b"); + let id_b = peer_b.agent_id; + let subnet_b: ipnetwork::Ipv4Network = "10.2.0.0/16".parse().expect("valid test subnet"); + peer_b.update_routes(1, vec![subnet_b], vec![domain("contoso.local")]); + registry.register(Arc::clone(&peer_b)); + + match resolve_route(®istry, None, "dc01.contoso.local") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents.len(), 2); + assert_eq!(agents[0].agent_id, id_b, "most recent first"); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } +} diff --git a/devolutions-gateway/src/api/fwd.rs b/devolutions-gateway/src/api/fwd.rs index f0b1701d6..673fe7aff 100644 --- a/devolutions-gateway/src/api/fwd.rs +++ b/devolutions-gateway/src/api/fwd.rs @@ -54,6 +54,7 @@ async fn fwd_tcp( sessions, subscriber_tx, shutdown_signal, + agent_tunnel_handle, .. }): State, AssociationToken(claims): AssociationToken, @@ -78,6 +79,7 @@ async fn fwd_tcp( claims, source_addr, false, + agent_tunnel_handle, ) .instrument(span) }); @@ -91,6 +93,7 @@ async fn fwd_tls( sessions, subscriber_tx, shutdown_signal, + agent_tunnel_handle, .. }): State, AssociationToken(claims): AssociationToken, @@ -115,6 +118,7 @@ async fn fwd_tls( claims, source_addr, true, + agent_tunnel_handle, ) .instrument(span) }); @@ -132,6 +136,7 @@ async fn handle_fwd( claims: AssociationTokenClaims, source_addr: SocketAddr, with_tls: bool, + agent_tunnel_handle: Option>, ) { let (stream, close_handle) = crate::ws::handle( ws, @@ -154,6 +159,7 @@ async fn handle_fwd( .sessions(sessions) .subscriber_tx(subscriber_tx) .with_tls(with_tls) + .agent_tunnel_handle(agent_tunnel_handle) .build() .run() .instrument(span.clone()) @@ -184,6 +190,8 @@ struct Forward { sessions: SessionMessageSender, subscriber_tx: SubscriberSender, with_tls: bool, + #[builder(default)] + agent_tunnel_handle: Option>, } #[derive(Debug, thiserror::Error)] @@ -207,6 +215,7 @@ where sessions, subscriber_tx, with_tls, + agent_tunnel_handle, } = self; match claims.jet_rec { @@ -224,6 +233,73 @@ where let span = tracing::Span::current(); + // Route via agent tunnel using the transparent routing pipeline: + // explicit agent_id → subnet match → domain suffix match → direct connect + if let Some(handle) = &agent_tunnel_handle { + use crate::agent_tunnel::routing::{self, RoutingDecision}; + + let first_target = targets.first(); + let target_host_for_routing = first_target.host().to_owned(); + + let decision = routing::resolve_route(handle.registry(), claims.jet_agent_id, &target_host_for_routing); + + match &decision { + RoutingDecision::ExplicitAgentNotFound(id) => { + error!(agent_id = %id, "Explicit agent not found in registry"); + return Err(ForwardError::BadGateway(anyhow::anyhow!( + "Agent {id} specified in token not found in registry. \ + Verify the agent is enrolled and connected." + ))); + } + RoutingDecision::Direct => { + info!(%target_host_for_routing, "No agent match, using direct connect"); + } + RoutingDecision::ViaAgent(_) => {} + } + + if let RoutingDecision::ViaAgent(candidates) = decision { + let target_str = format!("{}:{}", first_target.host(), first_target.port()); + + let (server_stream, _matched_agent) = + routing::route_and_connect(handle, &candidates, claims.jet_aid, &target_str) + .await + .map_err(ForwardError::BadGateway)?; + + let selected_target = first_target.clone(); + span.record("target", selected_target.to_string()); + + let info = SessionInfo::builder() + .id(claims.jet_aid) + .application_protocol(claims.jet_ap) + .details(ConnectionModeDetails::Fwd { + destination_host: selected_target, + }) + .time_to_live(claims.jet_ttl) + .recording_policy(claims.jet_rec) + .filtering_policy(claims.jet_flt) + .build(); + + let server_addr: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder"); + + return Proxy::builder() + .conf(conf) + .session_info(info) + .address_a(client_addr) + .transport_a(client_stream) + .address_b(server_addr) + .transport_b(server_stream) + .sessions(sessions) + .subscriber_tx(subscriber_tx) + .disconnect_interest(DisconnectInterest::from_reconnection_policy(claims.jet_reuse)) + .build() + .select_dissector_and_forward() + .await + .context("encountered a failure during agent tunnel traffic proxying") + .map_err(ForwardError::Internal); + } + // RoutingDecision::Direct falls through to direct connect below + } + trace!("Select and connect to target"); let ((server_stream, server_addr), selected_target) = utils::successive_try(&targets, utils::tcp_connect) diff --git a/devolutions-gateway/src/api/kdc_proxy.rs b/devolutions-gateway/src/api/kdc_proxy.rs index cf2d0243a..51673d8d8 100644 --- a/devolutions-gateway/src/api/kdc_proxy.rs +++ b/devolutions-gateway/src/api/kdc_proxy.rs @@ -25,6 +25,7 @@ async fn kdc_proxy( token_cache, jrl, recordings, + agent_tunnel_handle, .. }): State, extract::Path(token): extract::Path, @@ -105,7 +106,12 @@ async fn kdc_proxy( &claims.krb_kdc }; - let kdc_reply_message = send_krb_message(kdc_addr, &kdc_proxy_message.kerb_message.0.0).await?; + let kdc_reply_message = send_krb_message( + kdc_addr, + &kdc_proxy_message.kerb_message.0.0, + agent_tunnel_handle.as_deref(), + ) + .await?; let kdc_reply_message = KdcProxyMessage::from_raw_kerb_message(&kdc_reply_message) .map_err(HttpError::internal().with_msg("couldn't create KDC proxy reply").err())?; @@ -115,11 +121,11 @@ async fn kdc_proxy( kdc_reply_message.to_vec().map_err(HttpError::internal().err()) } -async fn read_kdc_reply_message(connection: &mut TcpStream) -> io::Result> { - let len = connection.read_u32().await?; +async fn read_kdc_reply_message(reader: &mut R) -> io::Result> { + let len = reader.read_u32().await?; let mut buf = vec![0; (len + 4).try_into().expect("u32-to-usize")]; buf[0..4].copy_from_slice(&(len.to_be_bytes())); - connection.read_exact(&mut buf[4..]).await?; + reader.read_exact(&mut buf[4..]).await?; Ok(buf) } @@ -148,7 +154,63 @@ fn unable_to_reach_kdc_server_err(error: io::Error) -> HttpError { } /// Sends the Kerberos message to the specified KDC address. -pub async fn send_krb_message(kdc_addr: &TargetAddr, message: &[u8]) -> Result, HttpError> { +/// +/// Uses the same routing pipeline as connection forwarding: +/// if an agent claims the KDC's domain/subnet, traffic goes through the tunnel. +/// Falls back to direct connect only when no agent matches. +pub async fn send_krb_message( + kdc_addr: &TargetAddr, + message: &[u8], + agent_tunnel_handle: Option<&crate::agent_tunnel::AgentTunnelHandle>, +) -> Result, HttpError> { + // Route through agent tunnel using the SAME pipeline as connection forwarding. + if let Some(handle) = agent_tunnel_handle { + use crate::agent_tunnel::routing::{self, RoutingDecision}; + + let kdc_host = kdc_addr.host(); + let kdc_target = kdc_addr.to_string(); + + let decision = routing::resolve_route(handle.registry(), None, kdc_host); + + match &decision { + RoutingDecision::ExplicitAgentNotFound(id) => { + error!(agent_id = %id, "Explicit agent for KDC not found"); + return Err( + HttpError::bad_gateway().build(format!("Agent {id} specified for KDC not found in registry.")) + ); + } + RoutingDecision::Direct => { + info!(kdc_host = %kdc_host, "No agent match for KDC, using direct connect"); + } + RoutingDecision::ViaAgent(_) => {} + } + + // Hard commit: if an agent matched, KDC traffic MUST go through it. + // No silent fallback — consistent with connection forwarding. + if let RoutingDecision::ViaAgent(candidates) = decision { + let session_id = uuid::Uuid::new_v4(); + + let (mut stream, _agent) = routing::route_and_connect(handle, &candidates, session_id, &kdc_target) + .await + .map_err(|e| { + HttpError::bad_gateway().build(format!("KDC routing through agent tunnel failed: {e:#}")) + })?; + + stream.write_all(message).await.map_err( + HttpError::bad_gateway() + .with_msg("unable to send KDC message through agent tunnel") + .err(), + )?; + + return read_kdc_reply_message(&mut stream).await.map_err( + HttpError::bad_gateway() + .with_msg("unable to read KDC reply through agent tunnel") + .err(), + ); + } + // RoutingDecision::Direct falls through to direct connect below + } + let protocol = kdc_addr.scheme(); debug!("Connecting to KDC server located at {kdc_addr} using protocol {protocol}..."); diff --git a/devolutions-gateway/src/api/rdp.rs b/devolutions-gateway/src/api/rdp.rs index 65cfe5b2e..de25b1bad 100644 --- a/devolutions-gateway/src/api/rdp.rs +++ b/devolutions-gateway/src/api/rdp.rs @@ -26,6 +26,7 @@ pub async fn handler( recordings, shutdown_signal, credential_store, + agent_tunnel_handle, .. }): State, ConnectInfo(source_addr): ConnectInfo, @@ -46,6 +47,7 @@ pub async fn handler( recordings.active_recordings, source_addr, credential_store, + agent_tunnel_handle, ) .instrument(span) }); @@ -65,6 +67,7 @@ async fn handle_socket( active_recordings: Arc, source_addr: SocketAddr, credential_store: crate::credential::CredentialStoreHandle, + agent_tunnel_handle: Option>, ) { let (stream, close_handle) = crate::ws::handle( ws, @@ -82,6 +85,7 @@ async fn handle_socket( subscriber_tx, &active_recordings, &credential_store, + agent_tunnel_handle, ) .await; diff --git a/devolutions-gateway/src/proxy.rs b/devolutions-gateway/src/proxy.rs index 0c2f09e6a..fb5d1b4c6 100644 --- a/devolutions-gateway/src/proxy.rs +++ b/devolutions-gateway/src/proxy.rs @@ -32,8 +32,8 @@ pub struct Proxy { impl Proxy where - A: AsyncWrite + AsyncRead + Unpin, - B: AsyncWrite + AsyncRead + Unpin, + A: AsyncWrite + AsyncRead + Unpin + Send, + B: AsyncWrite + AsyncRead + Unpin + Send, { pub async fn select_dissector_and_forward(self) -> anyhow::Result<()> { match self.session_info.application_protocol { diff --git a/devolutions-gateway/src/rd_clean_path.rs b/devolutions-gateway/src/rd_clean_path.rs index 9855a8987..b5a4223d1 100644 --- a/devolutions-gateway/src/rd_clean_path.rs +++ b/devolutions-gateway/src/rd_clean_path.rs @@ -158,25 +158,77 @@ enum CleanPathError { Io(#[from] io::Error), } -struct CleanPathResult { +/// Inner transport for the RDP server connection. +/// +/// An enum is required here because `Box` trait objects cause the compiler to +/// lose `Send` provability for the async future spawned by `ws.on_upgrade()` in the +/// RDP handler. Generics are also not viable — the type would propagate up through +/// `handle_with_credential_injection` → `handle` → `handle_socket` → `ws.on_upgrade()`, +/// which requires a concrete future type. The enum gives the compiler full type +/// information to prove `Send` while keeping the transport abstraction local. +enum ServerTransport { + Tcp(tokio::net::TcpStream), + Quic(crate::agent_tunnel::stream::TunnelStream), +} + +impl AsyncRead for ServerTransport { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_read(cx, buf), + Self::Quic(s) => std::pin::Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for ServerTransport { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_write(cx, buf), + Self::Quic(s) => std::pin::Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_flush(cx), + Self::Quic(s) => std::pin::Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_shutdown(cx), + Self::Quic(s) => std::pin::Pin::new(s).poll_shutdown(cx), + } + } +} + +struct CleanPathAuth { claims: AssociationTokenClaims, - destination: TargetAddr, - server_addr: SocketAddr, - server_stream: tokio_rustls::client::TlsStream, - x224_rsp: Vec, } -async fn process_cleanpath( - cleanpath_pdu: RDCleanPathPdu, +/// Validate the RDCleanPath PDU token and authorize the session. +/// Pure validation — no connections established. +async fn authorize_cleanpath( + cleanpath_pdu: &RDCleanPathPdu, client_addr: SocketAddr, conf: &Conf, token_cache: &TokenCache, jrl: &CurrentJrl, active_recordings: &ActiveRecordings, sessions: &SessionMessageSender, -) -> Result { - use crate::utils; - +) -> Result { let token = cleanpath_pdu .proxy_auth .as_deref() @@ -207,10 +259,9 @@ async fn process_cleanpath( }; let span = tracing::Span::current(); - span.record("session_id", claims.jet_aid.to_string()); - // Sanity check. + // Sanity check destination in PDU vs token. match cleanpath_pdu.destination.as_deref() { Some(destination) => match TargetAddr::parse(destination, 3389) { Ok(destination) if !destination.eq(targets.first()) => { @@ -224,14 +275,78 @@ async fn process_cleanpath( None => warn!("RDCleanPath PDU is missing the destination field"), } + Ok(CleanPathAuth { claims }) +} + +struct ConnectedRdpServer { + tls_stream: tokio_rustls::client::TlsStream, + server_addr: SocketAddr, + selected_target: TargetAddr, + x224_rsp: Vec, +} + +/// Establish a connection to the RDP server: route (agent/direct) → connect → X224 → TLS. +async fn connect_rdp_server( + claims: &AssociationTokenClaims, + cleanpath_pdu: RDCleanPathPdu, + agent_tunnel_handle: Option<&Arc>, +) -> Result { + use crate::utils; + + let crate::token::ConnectionMode::Fwd { ref targets, .. } = claims.jet_cm else { + return anyhow::Error::msg("unexpected connection mode") + .pipe(CleanPathError::BadRequest) + .pipe(Err); + }; + trace!(?targets, "Connecting to destination server"); - let ((mut server_stream, server_addr), selected_target) = utils::successive_try(targets, utils::tcp_connect) - .await - .context("connect to RDP server")?; + // Route through agent tunnel if available, otherwise connect directly. + let (mut server_stream, server_addr, selected_target): (ServerTransport, SocketAddr, &TargetAddr) = + if let Some(handle) = agent_tunnel_handle { + use crate::agent_tunnel::routing::{self, RoutingDecision}; + + let first_target = targets.first(); + let target_host = first_target.host(); + + let decision = routing::resolve_route(handle.registry(), claims.jet_agent_id, target_host); + + match decision { + RoutingDecision::ExplicitAgentNotFound(id) => { + return Err(CleanPathError::Internal(anyhow::anyhow!( + "Agent {id} specified in token not found in registry" + ))); + } + RoutingDecision::ViaAgent(candidates) => { + let target_str = format!("{}:{}", first_target.host(), first_target.port()); + info!(target = %target_str, "Routing RDP via agent tunnel"); + + let (quic_stream, _agent) = + routing::route_and_connect(handle, &candidates, claims.jet_aid, &target_str) + .await + .context("connect to RDP server via agent tunnel")?; + + // TODO: agent-routed sessions use a placeholder address; monitoring tools + // that rely on server_addr will see 0.0.0.0:0 for tunneled connections. + let placeholder_addr: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder"); + (ServerTransport::Quic(quic_stream), placeholder_addr, first_target) + } + RoutingDecision::Direct => { + let ((stream, addr), target) = utils::successive_try(targets, utils::tcp_connect) + .await + .context("connect to RDP server")?; + (ServerTransport::Tcp(stream), addr, target) + } + } + } else { + let ((stream, addr), target) = utils::successive_try(targets, utils::tcp_connect) + .await + .context("connect to RDP server")?; + (ServerTransport::Tcp(stream), addr, target) + }; debug!(%selected_target, "Connected to destination server"); - span.record("target", selected_target.to_string()); + tracing::Span::current().record("target", selected_target.to_string()); // Send preconnection blob if applicable. if let Some(pcb) = cleanpath_pdu.preconnection_blob { @@ -245,8 +360,6 @@ async fn process_cleanpath( .map_err(CleanPathError::BadRequest)?; server_stream.write_all(x224_req.as_bytes()).await?; - // == Receive server X224 connection response == - trace!("Receiving X224 response"); let x224_rsp = read_x224_response(&mut server_stream) @@ -256,20 +369,17 @@ async fn process_cleanpath( trace!("Establishing TLS connection with server"); - // == Establish TLS connection with server == - - let server_stream = crate::tls::dangerous_connect(selected_target.host().to_owned(), server_stream) + let tls_stream = crate::tls::dangerous_connect(selected_target.host().to_owned(), server_stream) .await .map_err(|source| CleanPathError::TlsHandshake { source, target_server: selected_target.to_owned(), })?; - Ok(CleanPathResult { - destination: selected_target.to_owned(), - claims, + Ok(ConnectedRdpServer { + tls_stream, server_addr, - server_stream, + selected_target: selected_target.to_owned(), x224_rsp, }) } @@ -287,6 +397,7 @@ async fn handle_with_credential_injection( active_recordings: &ActiveRecordings, cleanpath_pdu: RDCleanPathPdu, credential_entry: Arc, + agent_tunnel_handle: Option>, ) -> anyhow::Result<()> { let tls_conf = conf.credssp_tls.get().context("CredSSP TLS configuration")?; @@ -318,16 +429,9 @@ async fn handle_with_credential_injection( ) }; - // Run normal RDCleanPath flow (this will handle server-side TLS and get certs). - let CleanPathResult { - claims, - destination, - server_addr, - server_stream, - x224_rsp, - .. - } = process_cleanpath( - cleanpath_pdu, + // Authorize and connect to the RDP server. + let CleanPathAuth { claims } = authorize_cleanpath( + &cleanpath_pdu, client_addr, &conf, token_cache, @@ -336,7 +440,16 @@ async fn handle_with_credential_injection( &sessions, ) .await - .context("RDCleanPath processing failed")?; + .context("RDCleanPath authorization failed")?; + + let ConnectedRdpServer { + tls_stream: server_stream, + server_addr, + selected_target: destination, + x224_rsp, + } = connect_rdp_server(&claims, cleanpath_pdu, agent_tunnel_handle.as_ref()) + .await + .context("RDCleanPath connection failed")?; // Retrieve the Gateway TLS public key that must be used for client-proxy CredSSP later on. let gateway_cert_chain_handle = tokio::spawn(crate::tls::get_cert_chain_for_acceptor_cached( @@ -532,6 +645,7 @@ pub async fn handle( subscriber_tx: SubscriberSender, active_recordings: &ActiveRecordings, credential_store: &CredentialStoreHandle, + agent_tunnel_handle: Option>, ) -> anyhow::Result<()> { // Special handshake of our RDP extension @@ -569,27 +683,29 @@ pub async fn handle( active_recordings, cleanpath_pdu, entry, + agent_tunnel_handle.clone(), ) .await; } trace!("Processing RDCleanPath"); - let CleanPathResult { - claims, - destination, - server_addr, - server_stream, - x224_rsp, - } = match process_cleanpath( - cleanpath_pdu, - client_addr, - &conf, - token_cache, - jrl, - active_recordings, - &sessions, - ) + let (auth, connected) = match async { + let auth = authorize_cleanpath( + &cleanpath_pdu, + client_addr, + &conf, + token_cache, + jrl, + active_recordings, + &sessions, + ) + .await?; + + let connected = connect_rdp_server(&auth.claims, cleanpath_pdu, agent_tunnel_handle.as_ref()).await?; + + Ok::<_, CleanPathError>((auth, connected)) + } .await { Ok(result) => result, @@ -602,6 +718,13 @@ pub async fn handle( } }; + let ConnectedRdpServer { + tls_stream: server_stream, + server_addr, + selected_target: destination, + x224_rsp, + } = connected; + // == Send success RDCleanPathPdu response == let x509_chain = server_stream @@ -622,13 +745,13 @@ pub async fn handle( // == Start actual RDP session == let info = SessionInfo::builder() - .id(claims.jet_aid) - .application_protocol(claims.jet_ap) + .id(auth.claims.jet_aid) + .application_protocol(auth.claims.jet_ap) .details(ConnectionModeDetails::Fwd { destination_host: destination.clone(), }) - .time_to_live(claims.jet_ttl) - .recording_policy(claims.jet_rec) + .time_to_live(auth.claims.jet_ttl) + .recording_policy(auth.claims.jet_rec) .build(); info!("RDP-TLS forwarding (RDCleanPath)"); @@ -642,7 +765,7 @@ pub async fn handle( .transport_b(server_stream) .sessions(sessions) .subscriber_tx(subscriber_tx) - .disconnect_interest(DisconnectInterest::from_reconnection_policy(claims.jet_reuse)) + .disconnect_interest(DisconnectInterest::from_reconnection_policy(auth.claims.jet_reuse)) .build() .select_dissector_and_forward() .await diff --git a/devolutions-gateway/src/rdp_proxy.rs b/devolutions-gateway/src/rdp_proxy.rs index b3dc466a7..b7fddfdd5 100644 --- a/devolutions-gateway/src/rdp_proxy.rs +++ b/devolutions-gateway/src/rdp_proxy.rs @@ -637,7 +637,7 @@ where async fn send_network_request(request: &NetworkRequest) -> anyhow::Result> { let target_addr = TargetAddr::parse(request.url.as_str(), Some(88))?; - send_krb_message(&target_addr, &request.data) + send_krb_message(&target_addr, &request.data, None) .await .map_err(|err| anyhow::Error::msg("failed to send KDC message").context(err)) }