diff --git a/Cargo.lock b/Cargo.lock index 7d19ba15..aa2a9959 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2753,6 +2753,7 @@ dependencies = [ "protobuf-src", "serde", "serde_json", + "socket2 0.6.2", "thiserror 2.0.18", "tonic", "tonic-build", diff --git a/Cargo.toml b/Cargo.toml index 1f4096f3..be4800b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -78,6 +78,7 @@ pin-project-lite = "0.2" tokio-stream = "0.1" protobuf-src = "1.1.0" url = "2" +socket2 = "0.6" # Database sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "sqlite", "migrate"] } diff --git a/architecture/sandbox-connect.md b/architecture/sandbox-connect.md index eb834c6d..2027935a 100644 --- a/architecture/sandbox-connect.md +++ b/architecture/sandbox-connect.md @@ -138,11 +138,22 @@ sequenceDiagram 5. When SSH starts, it spawns the `ssh-proxy` subprocess as its `ProxyCommand`. 6. `crates/navigator-cli/src/ssh.rs` -- `sandbox_ssh_proxy()`: - Parses the gateway URL, connects via TCP (plain) or TLS (mTLS) + - Enables TCP keepalive on the gateway socket - Sends a raw HTTP CONNECT request with `X-Sandbox-Id` and `X-Sandbox-Token` headers - Reads the response status line; proceeds if 200 - Spawns two `tokio::spawn` tasks for bidirectional copy between stdin/stdout and the gateway stream - When the remote-to-stdout direction completes, aborts the stdin-to-remote task (SSH has all the data it needs) +### Connection stability + +Recent SSH stability hardening is split across the client, gateway, sandbox, and edge tunnel paths: + +- **OpenSSH keepalives**: the CLI now sets `ServerAliveInterval=30` and `ServerAliveCountMax=3` on every SSH invocation so idle sessions still emit SSH traffic. +- **TCP keepalive**: the CLI-to-gateway and gateway-to-sandbox TCP sockets enable 30-second keepalive probes to reduce drops from NAT, load balancers, and other idle-sensitive middleboxes. +- **Sandbox SSH daemon**: the embedded `russh` server disables its default 10-minute inactivity timeout and instead sends protocol keepalives every 30 seconds. This prevents quiet shells from being garbage-collected while still detecting dead peers. +- **Edge WebSocket tunnel**: the WebSocket bridge now lets both copy directions observe shutdown instead of aborting the peer task immediately, which reduces abrupt closes and truncated tail data. +- **Limit diagnostics**: when the gateway rejects a connection because the per-session or per-sandbox cap is reached, it now logs the active count and configured limit to make 429s easier to diagnose. + ### Command Execution (CLI) The `sandbox exec` path is identical to interactive connect except: diff --git a/crates/navigator-cli/src/edge_tunnel.rs b/crates/navigator-cli/src/edge_tunnel.rs index 814e245f..fcbe2371 100644 --- a/crates/navigator-cli/src/edge_tunnel.rs +++ b/crates/navigator-cli/src/edge_tunnel.rs @@ -36,6 +36,8 @@ use tokio_tungstenite::tungstenite::http::HeaderValue; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use tracing::{debug, error, warn}; +const EDGE_TUNNEL_WS_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + /// A running edge-authenticated tunnel proxy. /// /// The proxy listens on a local TCP port and tunnels each connection over a @@ -124,24 +126,18 @@ async fn handle_connection(tcp_stream: TcpStream, config: &TunnelConfig) -> Resu let (ws_sink, ws_source) = ws_stream.split(); let (tcp_read, tcp_write) = tokio::io::split(tcp_stream); - // Two tasks: TCP->WS and WS->TCP. Abort the peer task once either - // direction finishes so the connection tears down promptly. - let mut tcp_to_ws = tokio::spawn(copy_tcp_to_ws(tcp_read, ws_sink)); - let mut ws_to_tcp = tokio::spawn(copy_ws_to_tcp(ws_source, tcp_write)); + // Keep both directions alive until each side has observed shutdown. This + // avoids cutting off trailing bytes when one half closes slightly earlier + // than the other. + let tcp_to_ws = tokio::spawn(copy_tcp_to_ws(tcp_read, ws_sink)); + let ws_to_tcp = tokio::spawn(copy_ws_to_tcp(ws_source, tcp_write)); - tokio::select! { - res = &mut tcp_to_ws => { - if let Err(e) = res { - debug!(error = %e, "tcp->ws task panicked"); - } - ws_to_tcp.abort(); - } - res = &mut ws_to_tcp => { - if let Err(e) = res { - debug!(error = %e, "ws->tcp task panicked"); - } - tcp_to_ws.abort(); - } + let (tcp_to_ws_res, ws_to_tcp_res) = tokio::join!(tcp_to_ws, ws_to_tcp); + if let Err(e) = tcp_to_ws_res { + debug!(error = %e, "tcp->ws task panicked"); + } + if let Err(e) = ws_to_tcp_res { + debug!(error = %e, "ws->tcp task panicked"); } Ok(()) @@ -170,9 +166,13 @@ async fn open_ws(config: &TunnelConfig) -> Result Command { .arg("-o") .arg("GlobalKnownHostsFile=/dev/null") .arg("-o") + .arg(format!( + "ServerAliveInterval={SSH_SERVER_ALIVE_INTERVAL_SECS}" + )) + .arg("-o") + .arg(format!("ServerAliveCountMax={SSH_SERVER_ALIVE_COUNT_MAX}")) + .arg("-o") .arg("LogLevel=ERROR"); command } @@ -530,10 +543,20 @@ pub async fn sandbox_ssh_proxy( // any bytes read past the `\r\n\r\n` header boundary stay buffered and // are returned by subsequent reads during the bidirectional copy phase. let mut buf_stream = BufReader::new(stream); - let status = read_connect_status(&mut buf_stream).await?; + let status = tokio::time::timeout( + SSH_PROXY_STATUS_TIMEOUT, + read_connect_status(&mut buf_stream), + ) + .await + .map_err(|_| miette::miette!("timed out waiting for gateway CONNECT response"))??; if status != 200 { + let reason = match status { + 401 => " (SSH session expired, was revoked, or is invalid)", + 429 => " (too many concurrent SSH connections for this session or sandbox)", + _ => "", + }; return Err(miette::miette!( - "gateway CONNECT failed with status {status}" + "gateway CONNECT failed with status {status}{reason}" )); } @@ -594,6 +617,8 @@ pub fn print_ssh_config(gateway: &str, name: &str) { println!(" StrictHostKeyChecking no"); println!(" UserKnownHostsFile /dev/null"); println!(" GlobalKnownHostsFile /dev/null"); + println!(" ServerAliveInterval {SSH_SERVER_ALIVE_INTERVAL_SECS}"); + println!(" ServerAliveCountMax {SSH_SERVER_ALIVE_COUNT_MAX}"); println!(" LogLevel ERROR"); println!(" ProxyCommand {proxy_cmd}"); } @@ -628,25 +653,37 @@ async fn connect_gateway( .ok_or_else(|| miette::miette!("edge token required for tunnel"))?; let gateway_url = format!("https://{host}:{port}"); let proxy = crate::edge_tunnel::start_tunnel_proxy(&gateway_url, token).await?; - let tcp = TcpStream::connect(proxy.local_addr) - .await - .into_diagnostic()?; + let tcp = tokio::time::timeout( + SSH_PROXY_CONNECT_TIMEOUT, + TcpStream::connect(proxy.local_addr), + ) + .await + .map_err(|_| miette::miette!("timed out connecting to edge tunnel proxy"))? + .into_diagnostic()?; tcp.set_nodelay(true).into_diagnostic()?; + let _ = enable_tcp_keepalive(&tcp); return Ok(Box::new(tcp)); } - let tcp = TcpStream::connect((host, port)).await.into_diagnostic()?; + let tcp = tokio::time::timeout(SSH_PROXY_CONNECT_TIMEOUT, TcpStream::connect((host, port))) + .await + .map_err(|_| miette::miette!("timed out connecting to SSH gateway"))? + .into_diagnostic()?; tcp.set_nodelay(true).into_diagnostic()?; + let _ = enable_tcp_keepalive(&tcp); if scheme.eq_ignore_ascii_case("https") { let materials = require_tls_materials(&format!("https://{host}:{port}"), tls)?; let config = build_rustls_config(&materials)?; let connector = TlsConnector::from(Arc::new(config)); let server_name = ServerName::try_from(host.to_string()) .map_err(|_| miette::miette!("invalid server name: {host}"))?; - let tls = connector - .connect(server_name, tcp) - .await - .into_diagnostic()?; + let tls = tokio::time::timeout( + SSH_PROXY_CONNECT_TIMEOUT, + connector.connect(server_name, tcp), + ) + .await + .map_err(|_| miette::miette!("timed out establishing TLS to SSH gateway"))? + .into_diagnostic()?; Ok(Box::new(tls)) } else { Ok(Box::new(tcp)) @@ -688,3 +725,22 @@ async fn read_connect_status(stream: &mut R) -> Result ProxyStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ssh_base_command_enables_server_keepalives() { + let command = ssh_base_command("openshell ssh-proxy"); + let args = command + .get_args() + .map(|arg| arg.to_string_lossy().into_owned()) + .collect::>(); + + assert!(args.contains(&format!( + "ServerAliveInterval={SSH_SERVER_ALIVE_INTERVAL_SECS}" + ))); + assert!(args.contains(&format!("ServerAliveCountMax={SSH_SERVER_ALIVE_COUNT_MAX}"))); + } +} diff --git a/crates/navigator-core/Cargo.toml b/crates/navigator-core/Cargo.toml index 138ef0a7..68f4fc5c 100644 --- a/crates/navigator-core/Cargo.toml +++ b/crates/navigator-core/Cargo.toml @@ -19,6 +19,7 @@ miette = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } url = { workspace = true } +socket2 = { workspace = true } [build-dependencies] tonic-build = { workspace = true } diff --git a/crates/navigator-core/src/lib.rs b/crates/navigator-core/src/lib.rs index a84c586b..721f3ecb 100644 --- a/crates/navigator-core/src/lib.rs +++ b/crates/navigator-core/src/lib.rs @@ -12,6 +12,7 @@ pub mod config; pub mod error; pub mod forward; pub mod inference; +pub mod net; pub mod proto; pub use config::{Config, TlsConfig}; diff --git a/crates/navigator-core/src/net.rs b/crates/navigator-core/src/net.rs new file mode 100644 index 00000000..130d6bdf --- /dev/null +++ b/crates/navigator-core/src/net.rs @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared socket configuration helpers. + +use socket2::{SockRef, TcpKeepalive}; +use std::io; +use std::time::Duration; + +/// Idle time before TCP keepalive probes start. +pub const TCP_KEEPALIVE_IDLE: Duration = Duration::from_secs(30); + +/// Interval between TCP keepalive probes on supported platforms. +pub const TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30); + +fn default_keepalive() -> TcpKeepalive { + let keepalive = TcpKeepalive::new().with_time(TCP_KEEPALIVE_IDLE); + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "visionos", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "windows", + target_os = "cygwin", + ))] + let keepalive = keepalive.with_interval(TCP_KEEPALIVE_INTERVAL); + keepalive +} + +/// Enable aggressive TCP keepalive on a socket. +#[cfg(unix)] +pub fn enable_tcp_keepalive(socket: &S) -> io::Result<()> +where + S: std::os::fd::AsFd, +{ + SockRef::from(socket).set_tcp_keepalive(&default_keepalive()) +} + +/// Enable aggressive TCP keepalive on a socket. +#[cfg(windows)] +pub fn enable_tcp_keepalive(socket: &S) -> io::Result<()> +where + S: std::os::windows::io::AsSocket, +{ + SockRef::from(socket).set_tcp_keepalive(&default_keepalive()) +} + +/// Enable aggressive TCP keepalive on a socket. +#[cfg(not(any(unix, windows)))] +pub fn enable_tcp_keepalive(_socket: &S) -> io::Result<()> { + Ok(()) +} diff --git a/crates/navigator-sandbox/src/ssh.rs b/crates/navigator-sandbox/src/ssh.rs index 3013542d..dbfc4d89 100644 --- a/crates/navigator-sandbox/src/ssh.rs +++ b/crates/navigator-sandbox/src/ssh.rs @@ -9,6 +9,7 @@ use crate::sandbox; #[cfg(target_os = "linux")] use crate::{register_managed_child, unregister_managed_child}; use miette::{IntoDiagnostic, Result}; +use navigator_core::net::enable_tcp_keepalive; use nix::pty::{Winsize, openpty}; use nix::unistd::setsid; use rand_core::OsRng; @@ -28,6 +29,21 @@ use tokio::net::TcpListener; use tracing::{info, warn}; const PREFACE_MAGIC: &str = "NSSH1"; +const SSH_PREFACE_TIMEOUT: Duration = Duration::from_secs(10); +const SSH_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30); +const SSH_KEEPALIVE_MAX_MISSES: usize = 3; + +fn build_ssh_server_config(host_key: PrivateKey) -> russh::server::Config { + russh::server::Config { + auth_rejection_time: Duration::from_secs(1), + inactivity_timeout: None, + keepalive_interval: Some(SSH_KEEPALIVE_INTERVAL), + keepalive_max: SSH_KEEPALIVE_MAX_MISSES, + nodelay: true, + keys: vec![host_key], + ..Default::default() + } +} /// A time-bounded set of nonces used to detect replayed NSSH1 handshakes. /// Each entry records the `Instant` it was inserted; a background reaper task @@ -47,14 +63,7 @@ async fn ssh_server_init( )> { let mut rng = OsRng; let host_key = PrivateKey::random(&mut rng, Algorithm::Ed25519).into_diagnostic()?; - - let mut config = russh::server::Config { - auth_rejection_time: Duration::from_secs(1), - ..Default::default() - }; - config.keys.push(host_key); - - let config = Arc::new(config); + let config = Arc::new(build_ssh_server_config(host_key)); let ca_paths = ca_file_paths.as_ref().map(|p| Arc::new(p.clone())); let listener = TcpListener::bind(listen_addr).await.into_diagnostic()?; info!(addr = %listen_addr, "SSH server listening"); @@ -158,8 +167,13 @@ async fn handle_connection( nonce_cache: &NonceCache, ) -> Result<()> { info!(peer = %peer, "SSH connection: reading handshake preface"); + if let Err(err) = enable_tcp_keepalive(&stream) { + warn!(peer = %peer, error = %err, "SSH connection: failed to enable TCP keepalive"); + } let mut line = String::new(); - read_line(&mut stream, &mut line).await?; + tokio::time::timeout(SSH_PREFACE_TIMEOUT, read_line(&mut stream, &mut line)) + .await + .map_err(|_| miette::miette!("timed out waiting for SSH handshake preface"))??; info!(peer = %peer, preface_len = line.len(), "SSH connection: preface received, verifying"); if !verify_preface(&line, secret, handshake_skew_secs, nonce_cache)? { warn!(peer = %peer, "SSH connection: handshake verification failed"); @@ -267,6 +281,11 @@ struct SshHandler { proxy_url: Option, ca_file_paths: Option>, provider_env: HashMap, + channels: HashMap, +} + +#[derive(Default)] +struct ChannelState { input_sender: Option>>, pty_master: Option, pty_request: Option, @@ -288,11 +307,73 @@ impl SshHandler { proxy_url, ca_file_paths, provider_env, - input_sender: None, - pty_master: None, - pty_request: None, + channels: HashMap::new(), + } + } + + fn channel_state(&mut self, channel: ChannelId) -> &mut ChannelState { + self.channels.entry(channel).or_default() + } + + fn record_pty_request( + &mut self, + channel: ChannelId, + term: &str, + col_width: u32, + row_height: u32, + ) { + self.channel_state(channel).pty_request = Some(PtyRequest { + term: term.to_string(), + col_width, + row_height, + pixel_width: 0, + pixel_height: 0, + }); + } + + fn resize_channel_pty( + &mut self, + channel: ChannelId, + col_width: u32, + row_height: u32, + pixel_width: u32, + pixel_height: u32, + ) { + if let Some(master) = self + .channels + .get(&channel) + .and_then(|state| state.pty_master.as_ref()) + { + let winsize = Winsize { + ws_row: to_u16(row_height.max(1)), + ws_col: to_u16(col_width.max(1)), + ws_xpixel: to_u16(pixel_width), + ws_ypixel: to_u16(pixel_height), + }; + let _ = unsafe_pty::set_winsize(master.as_raw_fd(), winsize); + } + } + + fn forward_channel_data(&mut self, channel: ChannelId, data: &[u8]) { + if let Some(sender) = self + .channels + .get(&channel) + .and_then(|state| state.input_sender.as_ref()) + .cloned() + { + let _ = sender.send(data.to_vec()); } } + + fn close_channel_input(&mut self, channel: ChannelId) { + if let Some(state) = self.channels.get_mut(&channel) { + state.input_sender.take(); + } + } + + fn cleanup_channel(&mut self, channel: ChannelId) { + self.channels.remove(&channel); + } } impl russh::server::Handler for SshHandler { @@ -377,35 +458,21 @@ impl russh::server::Handler for SshHandler { _modes: &[(russh::Pty, u32)], session: &mut Session, ) -> Result<(), Self::Error> { - self.pty_request = Some(PtyRequest { - term: term.to_string(), - col_width, - row_height, - pixel_width: 0, - pixel_height: 0, - }); + self.record_pty_request(channel, term, col_width, row_height); session.channel_success(channel)?; Ok(()) } async fn window_change_request( &mut self, - _channel: ChannelId, + channel: ChannelId, col_width: u32, row_height: u32, pixel_width: u32, pixel_height: u32, _session: &mut Session, ) -> Result<(), Self::Error> { - if let Some(master) = self.pty_master.as_ref() { - let winsize = Winsize { - ws_row: to_u16(row_height.max(1)), - ws_col: to_u16(col_width.max(1)), - ws_xpixel: to_u16(pixel_width), - ws_ypixel: to_u16(pixel_height), - }; - let _ = unsafe_pty::set_winsize(master.as_raw_fd(), winsize); - } + self.resize_channel_pty(channel, col_width, row_height, pixel_width, pixel_height); Ok(()) } @@ -457,26 +524,33 @@ impl russh::server::Handler for SshHandler { async fn data( &mut self, - _channel: ChannelId, + channel: ChannelId, data: &[u8], _session: &mut Session, ) -> Result<(), Self::Error> { - if let Some(sender) = self.input_sender.as_ref() { - let _ = sender.send(data.to_vec()); - } + self.forward_channel_data(channel, data); Ok(()) } async fn channel_eof( &mut self, - _channel: ChannelId, + channel: ChannelId, _session: &mut Session, ) -> Result<(), Self::Error> { // Drop the input sender so the stdin writer thread sees a // disconnected channel and closes the child's stdin pipe. This // is essential for commands like `cat | tar xf -` which need // stdin EOF to know the input stream is complete. - self.input_sender.take(); + self.close_channel_input(channel); + Ok(()) + } + + async fn channel_close( + &mut self, + channel: ChannelId, + _session: &mut Session, + ) -> Result<(), Self::Error> { + self.cleanup_channel(channel); Ok(()) } } @@ -488,7 +562,12 @@ impl SshHandler { handle: Handle, command: Option, ) -> anyhow::Result<()> { - if let Some(pty) = self.pty_request.take() { + let pty_request = self + .channels + .get_mut(&channel) + .and_then(|state| state.pty_request.take()); + + if let Some(pty) = pty_request { // PTY was requested — allocate a real PTY (interactive shell or // exec that explicitly asked for a terminal). let (pty_master, input_sender) = spawn_pty_shell( @@ -503,8 +582,9 @@ impl SshHandler { self.ca_file_paths.clone(), &self.provider_env, )?; - self.pty_master = Some(pty_master); - self.input_sender = Some(input_sender); + let state = self.channel_state(channel); + state.pty_master = Some(pty_master); + state.input_sender = Some(input_sender); } else { // No PTY requested — use plain pipes so stdout/stderr are // separate and output has clean LF line endings. This is the @@ -520,7 +600,7 @@ impl SshHandler { self.ca_file_paths.clone(), &self.provider_env, )?; - self.input_sender = Some(input_sender); + self.channel_state(channel).input_sender = Some(input_sender); } Ok(()) } @@ -787,7 +867,9 @@ fn spawn_pty_shell( // `nohup daemon &`) may hold the PTY slave open indefinitely, // preventing the reader from reaching EOF. Two seconds is enough // for any remaining buffered data to drain. - let _ = reader_done_rx.recv_timeout(Duration::from_secs(2)); + if reader_done_rx.recv_timeout(Duration::from_secs(2)).is_err() { + warn!(channel = %channel, "PTY reader did not drain before timeout"); + } drop(runtime_exit.spawn(async move { let _ = handle_exit.exit_status_request(channel, code).await; let _ = handle_exit.close(channel).await; @@ -958,8 +1040,12 @@ fn spawn_pipe_exec( unregister_managed_child(child_pid); let code = status.and_then(|s| s.code()).unwrap_or(1).unsigned_abs(); // Wait for both reader threads. - let _ = reader_done_rx.recv_timeout(Duration::from_secs(2)); - let _ = reader_done_rx.recv_timeout(Duration::from_secs(1)); + if reader_done_rx.recv_timeout(Duration::from_secs(2)).is_err() { + warn!(channel = %channel, "stdout reader did not drain before timeout"); + } + if reader_done_rx.recv_timeout(Duration::from_secs(1)).is_err() { + warn!(channel = %channel, "stderr reader did not drain before timeout"); + } drop(runtime_exit.spawn(async move { let _ = handle_exit.eof(channel).await; let _ = handle_exit.exit_status_request(channel, code).await; @@ -1063,6 +1149,31 @@ fn to_u16(value: u32) -> u16 { #[cfg(test)] mod tests { use super::*; + use crate::policy::{FilesystemPolicy, LandlockPolicy, NetworkPolicy, ProcessPolicy}; + + fn test_channel_id(id: u32) -> ChannelId { + #[allow(unsafe_code)] + unsafe { + std::mem::transmute(id) + } + } + + fn empty_handler() -> SshHandler { + SshHandler::new( + SandboxPolicy { + version: 1, + filesystem: FilesystemPolicy::default(), + network: NetworkPolicy::default(), + landlock: LandlockPolicy::default(), + process: ProcessPolicy::default(), + }, + None, + None, + None, + None, + HashMap::new(), + ) + } /// Verify that dropping the input sender (the operation `channel_eof` /// performs) causes the stdin writer loop to exit and close the child's @@ -1167,6 +1278,60 @@ mod tests { ); } + #[test] + fn channel_data_routes_only_to_matching_channel() { + let mut handler = empty_handler(); + let channel1 = test_channel_id(1); + let channel2 = test_channel_id(2); + let (sender1, receiver1) = mpsc::channel::>(); + let (sender2, receiver2) = mpsc::channel::>(); + handler.channel_state(channel1).input_sender = Some(sender1); + handler.channel_state(channel2).input_sender = Some(sender2); + + handler.forward_channel_data(channel1, b"hello"); + + assert_eq!(receiver1.recv().unwrap(), b"hello"); + assert!( + receiver2.try_recv().is_err(), + "unexpected data on second channel" + ); + } + + #[test] + fn channel_eof_only_closes_matching_channel_input() { + let mut handler = empty_handler(); + let channel1 = test_channel_id(1); + let channel2 = test_channel_id(2); + let (sender1, receiver1) = mpsc::channel::>(); + let (sender2, receiver2) = mpsc::channel::>(); + handler.channel_state(channel1).input_sender = Some(sender1); + handler.channel_state(channel2).input_sender = Some(sender2.clone()); + + handler.close_channel_input(channel1); + handler.forward_channel_data(channel2, b"still-open"); + + assert!( + receiver1.try_recv().is_err(), + "closed channel should not receive more data" + ); + assert_eq!(receiver2.recv().unwrap(), b"still-open"); + drop(sender2); + } + + #[test] + fn cleanup_channel_removes_only_matching_state() { + let mut handler = empty_handler(); + let channel1 = test_channel_id(1); + let channel2 = test_channel_id(2); + handler.record_pty_request(channel1, "xterm-256color", 80, 24); + handler.record_pty_request(channel2, "xterm-256color", 120, 40); + + handler.cleanup_channel(channel1); + + assert!(!handler.channels.contains_key(&channel1)); + assert!(handler.channels.contains_key(&channel2)); + } + // ----------------------------------------------------------------------- // verify_preface tests // ----------------------------------------------------------------------- @@ -1265,4 +1430,17 @@ mod tests { assert!(verify_preface(&line1, secret, 300, &cache).unwrap()); assert!(verify_preface(&line2, secret, 300, &cache).unwrap()); } + + #[test] + fn ssh_server_config_keeps_idle_sessions_alive() { + let mut rng = OsRng; + let host_key = PrivateKey::random(&mut rng, Algorithm::Ed25519).unwrap(); + let config = build_ssh_server_config(host_key); + + assert_eq!(config.inactivity_timeout, None); + assert_eq!(config.keepalive_interval, Some(SSH_KEEPALIVE_INTERVAL)); + assert_eq!(config.keepalive_max, SSH_KEEPALIVE_MAX_MISSES); + assert!(config.nodelay); + assert_eq!(config.keys.len(), 1); + } } diff --git a/crates/navigator-server/src/grpc.rs b/crates/navigator-server/src/grpc.rs index 4ea378a7..65e693ed 100644 --- a/crates/navigator-server/src/grpc.rs +++ b/crates/navigator-server/src/grpc.rs @@ -1853,6 +1853,9 @@ const SSH_CONNECT_INITIAL_BACKOFF: std::time::Duration = std::time::Duration::fr /// Maximum backoff duration between SSH connection retries (caps exponential growth). const SSH_CONNECT_MAX_BACKOFF: std::time::Duration = std::time::Duration::from_secs(2); +const SSH_PROXY_ACCEPT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5); +const SSH_PROXY_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); +const SSH_PROXY_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); /// Returns `true` if the gRPC status represents a transient SSH connection error /// that is worth retrying (e.g. the sandbox SSH server is not yet listening). @@ -2002,9 +2005,13 @@ async fn run_exec_with_russh( stdin_payload: Vec, tx: mpsc::Sender>, ) -> Result { - let stream = TcpStream::connect(("127.0.0.1", local_proxy_port)) - .await - .map_err(|e| Status::internal(format!("failed to connect to ssh proxy: {e}")))?; + let stream = tokio::time::timeout( + SSH_PROXY_ACCEPT_TIMEOUT, + TcpStream::connect(("127.0.0.1", local_proxy_port)), + ) + .await + .map_err(|_| Status::deadline_exceeded("timed out connecting to ssh proxy"))? + .map_err(|e| Status::internal(format!("failed to connect to ssh proxy: {e}")))?; let config = Arc::new(russh::client::Config::default()); let mut client = russh::client::connect_stream(config, stream, SandboxSshClientHandler) @@ -2099,12 +2106,26 @@ async fn start_single_use_ssh_proxy( let handshake_secret = handshake_secret.to_string(); let task = tokio::spawn(async move { - let Ok((mut client_conn, _)) = listener.accept().await else { + let Ok(accept_result) = + tokio::time::timeout(SSH_PROXY_ACCEPT_TIMEOUT, listener.accept()).await + else { + warn!("SSH proxy: timed out waiting for local connection"); + return; + }; + let Ok((mut client_conn, _)) = accept_result else { warn!("SSH proxy: failed to accept local connection"); return; }; - let Ok(mut sandbox_conn) = TcpStream::connect((target_host.as_str(), target_port)).await + let Ok(connect_result) = tokio::time::timeout( + SSH_PROXY_CONNECT_TIMEOUT, + TcpStream::connect((target_host.as_str(), target_port)), + ) + .await else { + warn!(target_host = %target_host, target_port, "SSH proxy: timed out connecting to sandbox"); + return; + }; + let Ok(mut sandbox_conn) = connect_result else { warn!(target_host = %target_host, target_port, "SSH proxy: failed to connect to sandbox"); return; }; @@ -2113,12 +2134,36 @@ async fn start_single_use_ssh_proxy( warn!("SSH proxy: failed to build handshake preface"); return; }; - if let Err(e) = sandbox_conn.write_all(preface.as_bytes()).await { + if let Err(e) = tokio::time::timeout( + SSH_PROXY_HANDSHAKE_TIMEOUT, + sandbox_conn.write_all(preface.as_bytes()), + ) + .await + .map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::TimedOut, + "timed out sending handshake preface", + ) + }) + .and_then(|result| result) + { warn!(error = %e, "SSH proxy: failed to send handshake preface"); return; } let mut response = String::new(); - if let Err(e) = read_line(&mut sandbox_conn, &mut response).await { + let read_response = match tokio::time::timeout( + SSH_PROXY_HANDSHAKE_TIMEOUT, + read_line(&mut sandbox_conn, &mut response), + ) + .await + { + Ok(result) => result.map_err(|e| std::io::Error::other(e.to_string())), + Err(_) => Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "timed out waiting for handshake response", + )), + }; + if let Err(e) = read_response { warn!(error = %e, "SSH proxy: failed to read handshake response"); return; } diff --git a/crates/navigator-server/src/ssh_tunnel.rs b/crates/navigator-server/src/ssh_tunnel.rs index 7fea3573..6e539edc 100644 --- a/crates/navigator-server/src/ssh_tunnel.rs +++ b/crates/navigator-server/src/ssh_tunnel.rs @@ -8,6 +8,7 @@ use http::StatusCode; use hyper::Request; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; +use navigator_core::net::enable_tcp_keepalive; use navigator_core::proto::{Sandbox, SandboxPhase, SshSession}; use prost::Message; use std::net::SocketAddr; @@ -30,6 +31,8 @@ const MAX_CONNECTIONS_PER_TOKEN: u32 = 3; /// Maximum concurrent SSH tunnel connections per sandbox. const MAX_CONNECTIONS_PER_SANDBOX: u32 = 20; +const SSH_UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); +const SSH_UPSTREAM_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10); pub fn router(state: Arc) -> Router { Router::new() @@ -116,7 +119,7 @@ async fn ssh_connect( let mut counts = state.ssh_connections_by_token.lock().unwrap(); let count = counts.entry(token.clone()).or_insert(0); if *count >= MAX_CONNECTIONS_PER_TOKEN { - warn!(token = %token, "SSH tunnel: per-token connection limit reached"); + warn!(sandbox_id = %sandbox_id, active_connections = *count, limit = MAX_CONNECTIONS_PER_TOKEN, "SSH tunnel: per-token connection limit reached"); return StatusCode::TOO_MANY_REQUESTS.into_response(); } *count += 1; @@ -135,30 +138,32 @@ async fn ssh_connect( token_counts.remove(&token); } } - warn!(sandbox_id = %sandbox_id, "SSH tunnel: per-sandbox connection limit reached"); + warn!(sandbox_id = %sandbox_id, active_connections = *count, limit = MAX_CONNECTIONS_PER_SANDBOX, "SSH tunnel: per-sandbox connection limit reached"); return StatusCode::TOO_MANY_REQUESTS.into_response(); } *count += 1; } + let upgrade = hyper::upgrade::on(req); let handshake_secret = state.config.ssh_handshake_secret.clone(); let sandbox_id_clone = sandbox_id.clone(); let token_clone = token.clone(); let state_clone = state.clone(); + let upstream = + match establish_upstream(&connect_target, &token, &handshake_secret, &sandbox_id).await { + Ok(upstream) => upstream, + Err(err) => { + warn!(sandbox_id = %sandbox_id, error = %err, "SSH tunnel setup failed"); + decrement_connection_count(&state.ssh_connections_by_token, &token); + decrement_connection_count(&state.ssh_connections_by_sandbox, &sandbox_id); + return err.status_code().into_response(); + } + }; - let upgrade = hyper::upgrade::on(req); tokio::spawn(async move { match upgrade.await { Ok(mut upgraded) => { - if let Err(err) = handle_tunnel( - &mut upgraded, - connect_target, - &token_clone, - &handshake_secret, - &sandbox_id_clone, - ) - .await - { + if let Err(err) = bridge_tunnel(&mut upgraded, upstream, &sandbox_id_clone).await { warn!(error = %err, "SSH tunnel failure"); } } @@ -175,18 +180,15 @@ async fn ssh_connect( StatusCode::OK.into_response() } -async fn handle_tunnel( - upgraded: &mut Upgraded, - target: ConnectTarget, +async fn establish_upstream( + target: &ConnectTarget, token: &str, secret: &str, sandbox_id: &str, -) -> Result<(), Box> { +) -> Result { // The sandbox pod may not be network-reachable immediately after the CRD // reports Ready (DNS propagation, pod IP assignment, SSH server startup). // Retry the TCP connection with exponential backoff. - let mut upstream = None; - let mut last_err = None; let delays = [ Duration::from_millis(100), Duration::from_millis(250), @@ -202,6 +204,33 @@ async fn handle_tunnel( ConnectTarget::Host(host, port) => format!("{host}:{port}"), }; info!(sandbox_id = %sandbox_id, target = %target_desc, "SSH tunnel: connecting to sandbox"); + establish_upstream_with_timeouts( + target, + token, + secret, + sandbox_id, + SSH_UPSTREAM_CONNECT_TIMEOUT, + SSH_UPSTREAM_HANDSHAKE_TIMEOUT, + &delays, + ) + .await +} + +async fn establish_upstream_with_timeouts( + target: &ConnectTarget, + token: &str, + secret: &str, + sandbox_id: &str, + connect_timeout: Duration, + handshake_timeout: Duration, + delays: &[Duration], +) -> Result { + let mut upstream = None; + let mut last_err = None; + let target_desc = match target { + ConnectTarget::Ip(addr) => format!("{addr}"), + ConnectTarget::Host(host, port) => format!("{host}:{port}"), + }; for (attempt, delay) in std::iter::once(&Duration::ZERO) .chain(delays.iter()) .enumerate() @@ -210,10 +239,7 @@ async fn handle_tunnel( info!(sandbox_id = %sandbox_id, attempt = attempt + 1, delay_ms = delay.as_millis() as u64, "SSH tunnel: retrying TCP connect"); tokio::time::sleep(*delay).await; } - let result = match &target { - ConnectTarget::Ip(addr) => TcpStream::connect(addr).await, - ConnectTarget::Host(host, port) => TcpStream::connect((host.as_str(), *port)).await, - }; + let result = connect_target_with_timeout(target, connect_timeout).await; match result { Ok(stream) => { info!( @@ -231,27 +257,69 @@ async fn handle_tunnel( } } let mut upstream = upstream.ok_or_else(|| { - let err = last_err.unwrap(); - format!("failed to connect to sandbox after retries: {err}") + last_err.unwrap_or_else(|| { + TunnelSetupError::Other(format!( + "failed to connect to sandbox after retries: {target_desc}" + )) + }) })?; - upstream.set_nodelay(true)?; + upstream + .set_nodelay(true) + .map_err(|err| TunnelSetupError::Other(err.to_string()))?; + if let Err(err) = enable_tcp_keepalive(&upstream) { + warn!(sandbox_id = %sandbox_id, error = %err, "SSH tunnel: failed to enable upstream TCP keepalive"); + } info!(sandbox_id = %sandbox_id, "SSH tunnel: sending NSSH1 handshake preface"); - let preface = build_preface(token, secret)?; - upstream.write_all(preface.as_bytes()).await?; + let preface = + build_preface(token, secret).map_err(|err| TunnelSetupError::Other(err.to_string()))?; + tokio::time::timeout(handshake_timeout, upstream.write_all(preface.as_bytes())) + .await + .map_err(|_| TunnelSetupError::Timeout("timed out sending sandbox handshake preface"))? + .map_err(|err| TunnelSetupError::Other(err.to_string()))?; info!(sandbox_id = %sandbox_id, "SSH tunnel: waiting for handshake response"); let mut response = String::new(); - read_line(&mut upstream, &mut response).await?; + tokio::time::timeout(handshake_timeout, read_line(&mut upstream, &mut response)) + .await + .map_err(|_| TunnelSetupError::Timeout("timed out waiting for sandbox handshake response"))? + .map_err(|err| TunnelSetupError::Other(err.to_string()))?; info!(sandbox_id = %sandbox_id, response = %response.trim(), "SSH tunnel: handshake response received"); if response.trim() != "OK" { - return Err("sandbox handshake rejected".into()); + return Err(TunnelSetupError::Other( + "sandbox handshake rejected".to_string(), + )); } + Ok(upstream) +} + +async fn bridge_tunnel( + upgraded: &mut Upgraded, + mut upstream: TcpStream, + sandbox_id: &str, +) -> Result<(), Box> { info!(sandbox_id = %sandbox_id, "SSH tunnel established"); let mut upgraded = TokioIo::new(upgraded); - // Discard the result entirely – connection-close errors are expected when - // the SSH session ends and do not represent a failure worth propagating. - let _ = tokio::io::copy_bidirectional(&mut upgraded, &mut upstream).await; + match tokio::io::copy_bidirectional(&mut upgraded, &mut upstream).await { + Ok((client_to_sandbox, sandbox_to_client)) => { + info!(sandbox_id = %sandbox_id, client_to_sandbox_bytes = client_to_sandbox, sandbox_to_client_bytes = sandbox_to_client, "SSH tunnel closed"); + } + Err(err) + if matches!( + err.kind(), + std::io::ErrorKind::BrokenPipe + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::NotConnected + | std::io::ErrorKind::UnexpectedEof + ) => + { + info!(sandbox_id = %sandbox_id, error = %err, "SSH tunnel closed during shutdown"); + } + Err(err) => { + warn!(sandbox_id = %sandbox_id, error = %err, "SSH tunnel I/O failure"); + } + } // Gracefully shut down the write-half of the upgraded connection so the // client receives a clean EOF instead of a TCP RST. This gives SSH time // to read any remaining protocol data (e.g. exit-status) from its buffer. @@ -259,6 +327,24 @@ async fn handle_tunnel( Ok(()) } +async fn connect_target_with_timeout( + target: &ConnectTarget, + timeout: Duration, +) -> Result { + let connect = match target { + ConnectTarget::Ip(addr) => tokio::time::timeout(timeout, TcpStream::connect(addr)).await, + ConnectTarget::Host(host, port) => { + tokio::time::timeout(timeout, TcpStream::connect((host.as_str(), *port))).await + } + }; + + match connect { + Ok(Ok(stream)) => Ok(stream), + Ok(Err(err)) => Err(TunnelSetupError::Other(err.to_string())), + Err(_) => Err(TunnelSetupError::Timeout("timed out connecting to sandbox")), + } +} + fn header_value(headers: &http::HeaderMap, name: &str) -> Result { let value = headers .get(name) @@ -348,6 +434,30 @@ enum ConnectTarget { Host(String, u16), } +#[derive(Debug, Clone)] +enum TunnelSetupError { + Timeout(&'static str), + Other(String), +} + +impl TunnelSetupError { + const fn status_code(&self) -> StatusCode { + match self { + Self::Timeout(_) => StatusCode::GATEWAY_TIMEOUT, + Self::Other(_) => StatusCode::BAD_GATEWAY, + } + } +} + +impl std::fmt::Display for TunnelSetupError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Timeout(msg) => write!(f, "{msg}"), + Self::Other(msg) => write!(f, "{msg}"), + } + } +} + /// Decrement a connection count entry, removing it if it reaches zero. fn decrement_connection_count( counts: &std::sync::Mutex>, @@ -422,6 +532,7 @@ mod tests { use crate::persistence::Store; use std::collections::HashMap; use std::sync::Mutex; + use tokio::net::TcpListener; fn make_session(id: &str, sandbox_id: &str, expires_at_ms: i64, revoked: bool) -> SshSession { SshSession { @@ -602,4 +713,56 @@ mod tests { "session with zero expiry should never be expired" ); } + + #[tokio::test] + async fn establish_upstream_times_out_waiting_for_handshake_response() { + let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap(); + let addr = listener.local_addr().unwrap(); + let server = tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + tokio::time::sleep(Duration::from_millis(50)).await; + }); + + let err = establish_upstream_with_timeouts( + &ConnectTarget::Ip(addr), + "token", + "secret", + "sandbox-1", + Duration::from_millis(20), + Duration::from_millis(20), + &[], + ) + .await + .unwrap_err(); + + assert!(matches!(err, TunnelSetupError::Timeout(_))); + let _ = server.await; + } + + #[tokio::test] + async fn establish_upstream_rejects_non_ok_handshake_response() { + let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap(); + let addr = listener.local_addr().unwrap(); + let server = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = [0u8; 256]; + let _ = stream.read(&mut buf).await.unwrap(); + stream.write_all(b"ERR\n").await.unwrap(); + }); + + let err = establish_upstream_with_timeouts( + &ConnectTarget::Ip(addr), + "token", + "secret", + "sandbox-1", + Duration::from_millis(20), + Duration::from_millis(20), + &[], + ) + .await + .unwrap_err(); + + assert!(matches!(err, TunnelSetupError::Other(_))); + let _ = server.await; + } } diff --git a/crates/navigator-server/src/ws_tunnel.rs b/crates/navigator-server/src/ws_tunnel.rs index 56073ff4..88936d7d 100644 --- a/crates/navigator-server/src/ws_tunnel.rs +++ b/crates/navigator-server/src/ws_tunnel.rs @@ -67,24 +67,22 @@ async fn handle_ws_tunnel( debug!(error = %e, "WS tunnel: multiplex service error"); } }); - let mut tunnel_to_ws = tokio::spawn(copy_reader_to_ws(tunnel_read, ws_sink)); - let mut ws_to_tunnel = tokio::spawn(copy_ws_to_writer(ws_source, tunnel_write)); + let tunnel_to_ws = tokio::spawn(copy_reader_to_ws(tunnel_read, ws_sink)); + let ws_to_tunnel = tokio::spawn(copy_ws_to_writer(ws_source, tunnel_write)); - tokio::select! { - res = &mut tunnel_to_ws => { - if let Ok(Err(e)) = res { - debug!(error = %e, "WS tunnel: tunnel->ws error"); - } - ws_to_tunnel.abort(); - } - res = &mut ws_to_tunnel => { - if let Ok(Err(e)) = res { - debug!(error = %e, "WS tunnel: ws->tunnel error"); - } - tunnel_to_ws.abort(); - } + let (tunnel_to_ws_res, ws_to_tunnel_res) = tokio::join!(tunnel_to_ws, ws_to_tunnel); + match tunnel_to_ws_res { + Ok(Ok(())) => {} + Ok(Err(e)) => debug!(error = %e, "WS tunnel: tunnel->ws error"), + Err(e) => debug!(error = %e, "WS tunnel: tunnel->ws task panicked"), + } + match ws_to_tunnel_res { + Ok(Ok(())) => {} + Ok(Err(e)) => debug!(error = %e, "WS tunnel: ws->tunnel error"), + Err(e) => debug!(error = %e, "WS tunnel: ws->tunnel task panicked"), } service_task.abort(); + let _ = service_task.await; Ok(()) } diff --git a/crates/navigator-tui/src/lib.rs b/crates/navigator-tui/src/lib.rs index 5181cbdb..12f27eb0 100644 --- a/crates/navigator-tui/src/lib.rs +++ b/crates/navigator-tui/src/lib.rs @@ -27,6 +27,8 @@ use event::{Event, EventHandler}; /// Duration to show the splash screen before auto-dismissing. const SPLASH_DURATION: Duration = Duration::from_secs(3); +const SSH_SERVER_ALIVE_INTERVAL_SECS: u64 = 30; +const SSH_SERVER_ALIVE_COUNT_MAX: u64 = 3; /// Launch the OpenShell TUI. /// @@ -750,6 +752,12 @@ async fn handle_shell_connect( .arg("-o") .arg("GlobalKnownHostsFile=/dev/null") .arg("-o") + .arg(format!( + "ServerAliveInterval={SSH_SERVER_ALIVE_INTERVAL_SECS}" + )) + .arg("-o") + .arg(format!("ServerAliveCountMax={SSH_SERVER_ALIVE_COUNT_MAX}")) + .arg("-o") .arg("LogLevel=ERROR") .arg("-tt") .arg("-o") @@ -899,6 +907,12 @@ async fn handle_exec_command( .arg("-o") .arg("GlobalKnownHostsFile=/dev/null") .arg("-o") + .arg(format!( + "ServerAliveInterval={SSH_SERVER_ALIVE_INTERVAL_SECS}" + )) + .arg("-o") + .arg(format!("ServerAliveCountMax={SSH_SERVER_ALIVE_COUNT_MAX}")) + .arg("-o") .arg("LogLevel=ERROR") .arg("-tt") .arg("-o") @@ -1320,6 +1334,12 @@ async fn start_port_forwards( .arg("-o") .arg("GlobalKnownHostsFile=/dev/null") .arg("-o") + .arg(format!( + "ServerAliveInterval={SSH_SERVER_ALIVE_INTERVAL_SECS}" + )) + .arg("-o") + .arg(format!("ServerAliveCountMax={SSH_SERVER_ALIVE_COUNT_MAX}")) + .arg("-o") .arg("LogLevel=ERROR") .arg("-o") .arg("ConnectTimeout=15")