From 8626debcb78a90363ba5ade61b23fb6bcd9e1898 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 28 Apr 2026 16:33:24 -0700 Subject: [PATCH 1/4] feat(relay): route forwarding through ForwardTcp --- .../skills/debug-openshell-cluster/SKILL.md | 4 + architecture/gateway.md | 42 +- crates/openshell-cli/Cargo.toml | 1 + crates/openshell-cli/src/main.rs | 72 ++- crates/openshell-cli/src/run.rs | 309 +++++++++- crates/openshell-cli/src/ssh.rs | 259 +++------ crates/openshell-cli/src/tls.rs | 3 + .../tests/ensure_providers_integration.rs | 11 + .../openshell-cli/tests/mtls_integration.rs | 11 + .../tests/provider_commands_integration.rs | 11 + .../sandbox_create_lifecycle_integration.rs | 19 +- .../sandbox_name_fallback_integration.rs | 11 + crates/openshell-core/src/config.rs | 16 - crates/openshell-core/src/forward.rs | 111 ++-- crates/openshell-ocsf/src/format/shorthand.rs | 49 +- crates/openshell-sandbox/src/lib.rs | 2 +- .../src/supervisor_session.rs | 425 ++++++++++++-- crates/openshell-server/src/cli.rs | 9 - crates/openshell-server/src/grpc/mod.rs | 15 +- crates/openshell-server/src/grpc/sandbox.rs | 428 +++++++++++++- crates/openshell-server/src/http.rs | 4 +- crates/openshell-server/src/lib.rs | 4 +- crates/openshell-server/src/multiplex.rs | 14 - crates/openshell-server/src/ssh_sessions.rs | 185 ++++++ crates/openshell-server/src/ssh_tunnel.rs | 541 ------------------ .../src/supervisor_session.rs | 205 ++++--- .../tests/auth_endpoint_integration.rs | 15 + .../tests/edge_tunnel_auth.rs | 20 +- .../tests/multiplex_integration.rs | 20 +- .../tests/multiplex_tls_integration.rs | 20 +- .../tests/supervisor_relay_integration.rs | 21 +- .../tests/ws_tunnel_integration.rs | 20 +- crates/openshell-tui/src/lib.rs | 18 +- docs/sandboxes/manage-sandboxes.mdx | 8 + proto/openshell.proto | 62 +- 35 files changed, 1923 insertions(+), 1042 deletions(-) create mode 100644 crates/openshell-server/src/ssh_sessions.rs delete mode 100644 crates/openshell-server/src/ssh_tunnel.rs diff --git a/.agents/skills/debug-openshell-cluster/SKILL.md b/.agents/skills/debug-openshell-cluster/SKILL.md index 16158c0dc..64f8bd83d 100644 --- a/.agents/skills/debug-openshell-cluster/SKILL.md +++ b/.agents/skills/debug-openshell-cluster/SKILL.md @@ -128,6 +128,10 @@ helm -n openshell get values openshell | grep -E 'repository|tag|supervisorImage The gateway image and `server.supervisorImage` should use the same build tag in branch and E2E deploys. A stale supervisor image can make sandbox behavior lag behind gateway policy or proto changes. +For local/external pull mode (the default local path via `mise run cluster`), local images are tagged to the configured local registry base, pushed to that registry, and pulled by k3s via the `registries.yaml` mirror endpoint. The `cluster` task rebuilds the local gateway image before tagging and pushing it, so a fresh bootstrap should not reuse stale `openshell/gateway:dev` bits from a previous source revision. + +Gateway image builds stage a partial Rust workspace from `deploy/docker/Dockerfile.images`. If cargo fails with a missing manifest under `/build/crates/...`, or an imported symbol exists locally but is missing in the image build, verify that every current gateway dependency crate is copied into the staged workspace there. + For plaintext local evaluation, confirm the chart has: ```bash diff --git a/architecture/gateway.md b/architecture/gateway.md index d89706e64..a1320cfaa 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -9,11 +9,12 @@ workloads. - Authenticate clients and sandbox callbacks. - Serve gRPC APIs for sandbox lifecycle, provider management, policy updates, - settings, inference configuration, logs, and watch streams. -- Serve HTTP endpoints for health, SSH tunnel upgrades, and edge-auth flows. + settings, inference configuration, logs, watch streams, and relay forwarding. +- Serve HTTP endpoints for health, WebSocket tunnels, and edge-auth flows. - Persist domain objects in SQLite or Postgres. - Resolve provider credentials and inference bundles for sandbox supervisors. -- Coordinate supervisor relay sessions for connect, exec, and file sync. +- Coordinate supervisor relay sessions for connect, exec, file sync, and + service forwarding. The gateway does not enforce agent network policy at request time. That happens inside each sandbox, where the supervisor and proxy can observe local process @@ -44,7 +45,7 @@ The gateway API is organized around platform objects and operational streams: | Area | Examples | |---|---| -| Sandbox lifecycle | Create, list, delete, watch, exec, SSH session bootstrap. | +| Sandbox lifecycle | Create, list, delete, watch, exec, SSH session bootstrap, ForwardTcp service forwarding. | | Providers | Store provider records, discover credentials, resolve runtime environment. | | Policy and settings | Get effective sandbox config, update sandbox policy, manage global settings. | | Inference | Set gateway-level model/provider config and resolve sandbox route bundles. | @@ -115,22 +116,35 @@ sequenceDiagram participant CLI participant GW as Gateway participant SUP as Sandbox supervisor - participant SSH as Sandbox SSH socket + participant Target as Sandbox target SUP->>GW: ConnectSupervisor stream - CLI->>GW: connect / exec / sync request - GW->>SUP: RelayOpen(channel) + CLI->>GW: ForwardTcp / exec / sync request + GW->>SUP: RelayOpen(channel, target) + SUP->>Target: Dial SSH socket or loopback service SUP->>GW: RelayStream(channel) - SUP->>SSH: Bridge bytes to Unix socket CLI->>GW: Client bytes GW-->>CLI: Client bytes GW->>SUP: Relay bytes SUP-->>GW: Relay bytes ``` -The same relay pattern backs interactive SSH, command execution, and file sync. -The gateway tracks live sessions in memory and persists session records so -tokens can expire or be revoked. +The same relay pattern backs interactive SSH, command execution, file sync, and +local service forwarding. The gateway tracks live sessions in memory and +persists session records so tokens can expire or be revoked. + +`ForwardTcp` is the client-facing byte stream for SSH and service forwarding. +The first frame is a `TcpForwardInit` that carries the sandbox ID, an +authorization token from `CreateSshSession`, and an explicit target: +`target.ssh` for the sandbox SSH socket or `target.tcp` for a loopback service +inside the sandbox. The gateway validates the token and sandbox readiness, +sends a targeted `RelayOpen` to the supervisor, then bridges +`TcpForwardFrame::Data` to `RelayFrame::Data` until either side closes. + +For `target.tcp`, the gateway only accepts loopback destinations such as +`localhost`, `127.0.0.0/8`, or `::1`. The gateway never needs to know or dial a +sandbox pod IP; supervisors connect outbound and bridge only the explicit target +requested for that relay. ## PKI Bootstrap @@ -143,13 +157,13 @@ created. Both deployment paths use it: | Filesystem | `--output-dir ` | `/{ca.crt, ca.key, server/tls.{crt,key}, client/tls.{crt,key}}`. Also copies client materials to `$XDG_CONFIG_HOME/openshell/gateways/openshell/mtls/` for CLI auto-discovery. | On Kubernetes, the Helm chart runs the command via a pre-install/pre-upgrade -hook Job using the gateway image itself — no separate cert-generation image, +hook Job using the gateway image itself -- no separate cert-generation image, no extra mirror burden in air-gapped environments. On the RPM gateway, the same command runs from the systemd unit's `ExecStartPre` to bootstrap PKI into the user's state directory on first start. -Both modes share the same idempotency contract: all targets present → skip; -partial state → fail with a recovery hint; nothing present → generate and +Both modes share the same idempotency contract: all targets present -> skip; +partial state -> fail with a recovery hint; nothing present -> generate and write. This guards mTLS continuity across restarts and upgrades while still recovering cleanly if an operator deletes everything and starts over. diff --git a/crates/openshell-cli/Cargo.toml b/crates/openshell-cli/Cargo.toml index 8b86544b7..21068ad99 100644 --- a/crates/openshell-cli/Cargo.toml +++ b/crates/openshell-cli/Cargo.toml @@ -68,6 +68,7 @@ tokio-tungstenite = { workspace = true } # Streams futures = { workspace = true } +tokio-stream = { workspace = true } nix = { workspace = true } # URL parsing diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index e25fb7576..4dc5c588b 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -8,6 +8,7 @@ use clap_complete::engine::ArgValueCompleter; use clap_complete::env::CompleteEnv; use miette::Result; use owo_colors::OwoColorize; +use std::collections::HashMap; use std::io::Write; use std::path::PathBuf; @@ -198,6 +199,7 @@ const HELP_TEMPLATE: &str = "\ \x1b[1mSANDBOX COMMANDS\x1b[0m sandbox: Manage sandboxes forward: Manage port forwarding to a sandbox + service: Forward sandbox services over gRPC logs: View sandbox logs policy: Manage sandbox policy settings: Manage sandbox and global settings @@ -270,6 +272,12 @@ const FORWARD_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m $ openshell forward list "; +const SERVICE_EXAMPLES: &str = "\x1b[1mEXAMPLES\x1b[0m + $ openshell service forward my-sandbox --target-port 8080 + $ openshell service forward my-sandbox --target-port 5432 --local 15432 + $ openshell service forward my-sandbox --target-port 3000 --local 127.0.0.1:0 +"; + const LOGS_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m lg @@ -407,6 +415,13 @@ enum Commands { command: Option, }, + /// Forward sandbox services over gRPC. + #[command(after_help = SERVICE_EXAMPLES, help_template = SUBCOMMAND_HELP_TEMPLATE)] + Service { + #[command(subcommand)] + command: Option, + }, + /// View sandbox logs. #[command(alias = "lg", after_help = LOGS_EXAMPLES, help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] Logs { @@ -1614,6 +1629,29 @@ enum ForwardCommands { List, } +#[derive(Subcommand, Debug)] +enum ServiceCommands { + /// Forward a local TCP port to a loopback service inside a sandbox over gRPC. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Forward { + /// Sandbox name (defaults to last-used sandbox). + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + name: Option, + + /// Target service port inside the sandbox. + #[arg(long)] + target_port: u16, + + /// Target service host inside the sandbox. Phase 1 accepts loopback only. + #[arg(long, default_value = "127.0.0.1")] + target_host: String, + + /// Local bind address and port: [bind_address:]port. Use port 0 for dynamic assignment. + #[arg(long)] + local: Option, + }, +} + #[tokio::main] #[allow(clippy::large_stack_frames)] // CLI dispatch holds many futures; OK at top level. async fn main() -> Result<()> { @@ -1778,6 +1816,38 @@ async fn main() -> Result<()> { } } + Some(Commands::Service { + command: + Some(ServiceCommands::Forward { + name, + target_port, + target_host, + local, + }), + }) => { + let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?; + let mut tls = tls.with_gateway_name(&ctx.name); + apply_edge_auth(&mut tls, &ctx.name); + let name = resolve_sandbox_name(name, &ctx.name)?; + run::service_forward_tcp( + &ctx.endpoint, + &name, + local.as_deref(), + &target_host, + target_port, + &tls, + ) + .await?; + } + + Some(Commands::Service { command: None }) => { + Cli::command() + .find_subcommand_mut("service") + .expect("service subcommand exists") + .print_help() + .expect("Failed to print help"); + } + // ----------------------------------------------------------- // Top-level forward (was `sandbox forward`) // ----------------------------------------------------------- @@ -2237,7 +2307,7 @@ async fn main() -> Result<()> { }; // Parse --label flags into a HashMap. - let mut labels_map = std::collections::HashMap::new(); + let mut labels_map = HashMap::new(); for label_str in &labels { let parts: Vec<&str> = label_str.splitn(2, '=').collect(); if parts.len() != 2 { diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 165713b6e..e0888a9cf 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -27,17 +27,19 @@ use openshell_core::proto::ProviderProfileCategory; use openshell_core::proto::{ ApproveAllDraftChunksRequest, ApproveDraftChunkRequest, AttachSandboxProviderRequest, ClearDraftChunksRequest, CreateProviderRequest, CreateSandboxRequest, - DeleteProviderProfileRequest, DeleteProviderRequest, DeleteSandboxRequest, - DetachSandboxProviderRequest, ExecSandboxRequest, GetClusterInferenceRequest, - GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, - GetProviderProfileRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, - GetSandboxPolicyStatusRequest, GetSandboxRequest, HealthRequest, ImportProviderProfilesRequest, + CreateSshSessionRequest, DeleteProviderProfileRequest, DeleteProviderRequest, + DeleteSandboxRequest, DetachSandboxProviderRequest, ExecSandboxRequest, + GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, + GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRequest, + GetSandboxConfigRequest, GetSandboxLogsRequest, GetSandboxPolicyStatusRequest, + GetSandboxRequest, HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, ListSandboxesRequest, PolicySource, PolicyStatus, Provider, ProviderProfile, ProviderProfileDiagnostic, ProviderProfileImportItem, - RejectDraftChunkRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, - SetClusterInferenceRequest, SettingScope, SettingValue, UpdateConfigRequest, - UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, setting_value, + RejectDraftChunkRequest, RevokeSshSessionRequest, Sandbox, SandboxPhase, SandboxPolicy, + SandboxSpec, SandboxTemplate, SetClusterInferenceRequest, SettingScope, SettingValue, + TcpForwardFrame, TcpForwardInit, TcpRelayTarget, UpdateConfigRequest, UpdateProviderRequest, + WatchSandboxRequest, exec_sandbox_event, setting_value, tcp_forward_init, }; use openshell_core::settings::{self, SettingValueKind}; use openshell_core::{ObjectId, ObjectName}; @@ -1554,7 +1556,7 @@ pub async fn sandbox_create( status.message() )); } - Err(status) => return Err(status).into_diagnostic(), + Err(status) => return Err(miette::miette!(status.to_string())), }; let sandbox = response .into_inner() @@ -2438,6 +2440,295 @@ pub async fn sandbox_exec_grpc( Ok(exit_code) } +pub async fn service_forward_tcp( + server: &str, + name: &str, + local: Option<&str>, + target_host: &str, + target_port: u16, + tls: &TlsOptions, +) -> Result<()> { + let (bind_addr, bind_port) = parse_tcp_forward_spec(local, target_port)?; + let mut client = grpc_client(server, tls).await?; + + let sandbox = fetch_ready_sandbox_for_forward(&mut client, name).await?; + + let listener = tokio::net::TcpListener::bind((bind_addr.as_str(), bind_port)) + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to bind local forward on {bind_addr}:{bind_port}"))?; + let local_addr = listener + .local_addr() + .into_diagnostic() + .wrap_err("failed to read local forward address")?; + eprintln!( + "{} Forwarding {} -> {}:{} in sandbox {} via gRPC", + "✓".green().bold(), + local_addr, + target_host, + target_port, + name, + ); + + let sandbox_id = sandbox.object_id().to_string(); + let (fatal_tx, mut fatal_rx) = tokio::sync::mpsc::channel::(1); + let mut health_check = tokio::time::interval(Duration::from_secs(2)); + health_check.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + loop { + tokio::select! { + Some(reason) = fatal_rx.recv() => { + return Err(miette::miette!("service forward stopped: {reason}")); + } + + _ = health_check.tick() => { + fetch_ready_sandbox_for_forward(&mut client, name).await?; + } + + accepted = listener.accept() => { + let (socket, peer) = accepted + .into_diagnostic() + .wrap_err("failed to accept local forward connection")?; + let mut client = client.clone(); + let sandbox_id = sandbox_id.clone(); + let target_host = target_host.to_string(); + let service_id = format!("service-forward:{name}:{target_host}:{target_port}"); + let fatal_tx = fatal_tx.clone(); + tokio::spawn(async move { + let token = match create_forward_session_token(&mut client, &sandbox_id).await { + Ok(token) => token, + Err(err) => { + tracing::warn!(peer = %peer, error = %err, "service forward session creation failed"); + if err.fatal { + let _ = fatal_tx.send(err.message).await; + } + return; + } + }; + if let Err(err) = forward_one_tcp_connection( + &mut client, + socket, + sandbox_id, + target_host, + target_port, + service_id, + token.clone(), + ) + .await + { + tracing::warn!(peer = %peer, error = %err, "service forward connection failed"); + if err.fatal { + let _ = fatal_tx.send(err.message).await; + } + } + let _ = client + .revoke_ssh_session(RevokeSshSessionRequest { token }) + .await; + }); + } + } + } +} + +async fn create_forward_session_token( + client: &mut crate::tls::GrpcClient, + sandbox_id: &str, +) -> std::result::Result { + let response = client + .create_ssh_session(CreateSshSessionRequest { + sandbox_id: sandbox_id.to_string(), + }) + .await + .map_err(ForwardTcpConnectionError::from_status)?; + Ok(response.into_inner().token) +} + +async fn fetch_ready_sandbox_for_forward( + client: &mut crate::tls::GrpcClient, + name: &str, +) -> Result { + let response = match client + .get_sandbox(GetSandboxRequest { + name: name.to_string(), + }) + .await + { + Ok(response) => response, + Err(status) if status.code() == Code::NotFound => { + return Err(miette::miette!( + "sandbox '{name}' no longer exists; stopping service forward" + )); + } + Err(status) => return Err(status).into_diagnostic(), + }; + + let sandbox = response + .into_inner() + .sandbox + .ok_or_else(|| miette::miette!("sandbox '{name}' not found"))?; + + if SandboxPhase::try_from(sandbox.phase) != Ok(SandboxPhase::Ready) { + return Err(miette::miette!( + "sandbox '{}' is no longer ready (phase: {}); stopping service forward", + name, + phase_name(sandbox.phase) + )); + } + + Ok(sandbox) +} + +#[derive(Debug)] +struct ForwardTcpConnectionError { + message: String, + fatal: bool, +} + +impl ForwardTcpConnectionError { + fn transient(message: impl Into) -> Self { + Self { + message: message.into(), + fatal: false, + } + } + + fn from_status(status: Status) -> Self { + let fatal = matches!(status.code(), Code::NotFound | Code::FailedPrecondition); + Self { + message: status.to_string(), + fatal, + } + } +} + +impl std::fmt::Display for ForwardTcpConnectionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.message) + } +} + +impl std::error::Error for ForwardTcpConnectionError {} + +fn parse_tcp_forward_spec(local: Option<&str>, default_port: u16) -> Result<(String, u16)> { + let Some(spec) = local else { + return Ok(("127.0.0.1".to_string(), default_port)); + }; + + if let Some(pos) = spec.rfind(':') { + let addr = &spec[..pos]; + let port_str = &spec[pos + 1..]; + if let Ok(port) = port_str.parse::() { + if addr.is_empty() { + return Err(miette::miette!("bind address is required before ':'")); + } + return Ok((addr.to_string(), port)); + } + } + + let port: u16 = spec.parse().map_err(|_| { + miette::miette!("invalid local forward spec '{spec}': expected [bind_address:]port") + })?; + Ok(("127.0.0.1".to_string(), port)) +} + +async fn forward_one_tcp_connection( + client: &mut crate::tls::GrpcClient, + socket: tokio::net::TcpStream, + sandbox_id: String, + target_host: String, + target_port: u16, + service_id: String, + authorization_token: String, +) -> std::result::Result<(), ForwardTcpConnectionError> { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_stream::wrappers::ReceiverStream; + + let (tx, rx) = tokio::sync::mpsc::channel::(16); + tx.send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Init( + TcpForwardInit { + sandbox_id, + service_id, + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: target_host, + port: u32::from(target_port), + })), + authorization_token, + }, + )), + }) + .await + .map_err(|_| ForwardTcpConnectionError::transient("failed to initialize forward stream"))?; + + let mut response = match client.forward_tcp(ReceiverStream::new(rx)).await { + Ok(response) => response.into_inner(), + Err(status) => { + let err = ForwardTcpConnectionError::from_status(status); + drain_and_shutdown_local_socket(socket).await; + return Err(err); + } + }; + + let (mut local_read, mut local_write) = socket.into_split(); + + let to_gateway = tokio::spawn(async move { + let mut buf = vec![0u8; 64 * 1024]; + loop { + let n = local_read.read(&mut buf).await?; + if n == 0 { + break; + } + if tx + .send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Data( + buf[..n].to_vec(), + )), + }) + .await + .is_err() + { + break; + } + } + Ok::<(), std::io::Error>(()) + }); + + while let Some(frame) = response + .message() + .await + .map_err(ForwardTcpConnectionError::from_status)? + { + let Some(openshell_core::proto::tcp_forward_frame::Payload::Data(data)) = frame.payload + else { + continue; + }; + if data.is_empty() { + continue; + } + local_write + .write_all(&data) + .await + .map_err(|err| ForwardTcpConnectionError::transient(err.to_string()))?; + } + + let _ = local_write.shutdown().await; + to_gateway.abort(); + Ok(()) +} + +async fn drain_and_shutdown_local_socket(mut socket: tokio::net::TcpStream) { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let mut buf = [0u8; 4096]; + loop { + match tokio::time::timeout(Duration::from_millis(25), socket.read(&mut buf)).await { + Ok(Ok(0)) | Err(_) => break, + Ok(Ok(_)) => continue, + Ok(Err(_)) => break, + } + } + let _ = socket.shutdown().await; +} + /// Print a single YAML line with dimmed keys and regular values. fn print_yaml_line(line: &str) { // Find leading whitespace diff --git a/crates/openshell-cli/src/ssh.rs b/crates/openshell-cli/src/ssh.rs index 89e9071e1..65e8605f1 100644 --- a/crates/openshell-cli/src/ssh.rs +++ b/crates/openshell-cli/src/ssh.rs @@ -3,30 +3,30 @@ //! SSH connection and proxy utilities. -use crate::tls::{TlsOptions, build_rustls_config, grpc_client, require_tls_materials}; +use crate::tls::{TlsOptions, grpc_client}; use miette::{IntoDiagnostic, Result, WrapErr}; #[cfg(unix)] use nix::sys::signal::{SaFlags, SigAction, SigHandler, SigSet, Signal, sigaction}; use openshell_core::ObjectId; use openshell_core::forward::{ - build_proxy_command, find_ssh_forward_pid, resolve_ssh_gateway, shell_escape, - validate_ssh_session_response, write_forward_pid, + build_proxy_command, find_ssh_forward_pid, format_gateway_url, resolve_ssh_gateway, + shell_escape, validate_ssh_session_response, write_forward_pid, +}; +use openshell_core::proto::{ + CreateSshSessionRequest, GetSandboxRequest, SshRelayTarget, TcpForwardFrame, TcpForwardInit, + tcp_forward_init, }; -use openshell_core::proto::{CreateSshSessionRequest, GetSandboxRequest}; use owo_colors::OwoColorize; -use rustls::pki_types::ServerName; use std::fs; use std::io::{IsTerminal, Write}; #[cfg(unix)] use std::os::unix::process::CommandExt; use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; -use std::sync::Arc; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}; -use tokio::net::TcpStream; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::process::Command as TokioCommand; -use tokio_rustls::TlsConnector; +use tokio_stream::wrappers::ReceiverStream; const FOREGROUND_FORWARD_STARTUP_GRACE_PERIOD: Duration = Duration::from_secs(2); @@ -100,8 +100,7 @@ async fn ssh_session_config( // external tunnel endpoint (the cluster URL), not the server's internal // scheme/host/port which may be plaintext HTTP on 127.0.0.1. let gateway_url = if tls.is_bearer_auth() { - let base = server.trim_end_matches('/'); - format!("{base}{}", session.connect_path) + server.trim_end_matches('/').to_string() } else { // If the server returned a loopback gateway address, override it with the // cluster endpoint's host. This handles the case where the server defaults @@ -110,10 +109,7 @@ async fn ssh_session_config( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, server); - format!( - "{}://{}:{}{}", - session.gateway_scheme, gateway_host, gateway_port, session.connect_path - ) + format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port) }; let gateway_name = tls .gateway_name() @@ -821,18 +817,86 @@ pub async fn sandbox_ssh_proxy( token: &str, tls: &TlsOptions, ) -> Result<()> { - // The gateway returns 412 (Precondition Failed) when the sandbox pod - // exists but hasn't reached Ready phase yet. This is a transient state - // after sandbox allocation — retry with backoff instead of failing - // immediately. - const MAX_CONNECT_WAIT: Duration = Duration::from_secs(60); - const INITIAL_BACKOFF: Duration = Duration::from_secs(1); + let server = grpc_server_from_ssh_gateway_url(gateway_url)?; + let mut client = grpc_client(&server, tls).await?; + + let (tx, rx) = tokio::sync::mpsc::channel::(16); + tx.send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Init( + TcpForwardInit { + sandbox_id: sandbox_id.to_string(), + service_id: format!("ssh-proxy:{sandbox_id}"), + target: Some(tcp_forward_init::Target::Ssh(SshRelayTarget {})), + authorization_token: token.to_string(), + }, + )), + }) + .await + .map_err(|_| miette::miette!("failed to initialize SSH forward stream"))?; + + let mut response = client + .forward_tcp(ReceiverStream::new(rx)) + .await + .into_diagnostic()? + .into_inner(); + let stdin = tokio::io::stdin(); + let stdout = tokio::io::stdout(); + + let to_remote = tokio::spawn(async move { + let mut stdin = stdin; + let mut buf = vec![0u8; 64 * 1024]; + loop { + let Ok(n) = stdin.read(&mut buf).await else { + break; + }; + if n == 0 { + break; + } + if tx + .send(TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Data( + buf[..n].to_vec(), + )), + }) + .await + .is_err() + { + break; + } + } + }); + let from_remote = tokio::spawn(async move { + let mut stdout = stdout; + loop { + let frame = match response.message().await { + Ok(Some(frame)) => frame, + Ok(None) | Err(_) => break, + }; + let Some(openshell_core::proto::tcp_forward_frame::Payload::Data(data)) = frame.payload + else { + continue; + }; + if data.is_empty() { + continue; + } + if stdout.write_all(&data).await.is_err() { + break; + } + let _ = stdout.flush().await; + } + }); + let _ = from_remote.await; + to_remote.abort(); + + Ok(()) +} + +fn grpc_server_from_ssh_gateway_url(gateway_url: &str) -> Result { let url: url::Url = gateway_url .parse() .into_diagnostic() .wrap_err("invalid gateway URL")?; - let scheme = url.scheme(); let gateway_host = url .host_str() @@ -840,69 +904,7 @@ pub async fn sandbox_ssh_proxy( let gateway_port = url .port_or_known_default() .ok_or_else(|| miette::miette!("gateway URL missing port"))?; - let connect_path = url.path(); - - let request = format!( - "CONNECT {connect_path} HTTP/1.1\r\nHost: {gateway_host}\r\nX-Sandbox-Id: {sandbox_id}\r\nX-Sandbox-Token: {token}\r\n\r\n" - ); - - let start = std::time::Instant::now(); - let mut backoff = INITIAL_BACKOFF; - let mut buf_stream; - - loop { - let mut stream: Box = - connect_gateway(scheme, gateway_host, gateway_port, tls).await?; - stream - .write_all(request.as_bytes()) - .await - .into_diagnostic()?; - - // Wrap in a BufReader **before** reading the HTTP response. The gateway - // may send the 200 OK response and the first SSH protocol bytes in the - // same TCP segment / WebSocket frame. A plain `read()` would consume - // those SSH bytes into our buffer and discard them, causing SSH to see a - // truncated protocol banner and exit with code 255. BufReader ensures - // any bytes read past the `\r\n\r\n` header boundary stay buffered and - // are returned by subsequent reads during the bidirectional copy phase. - buf_stream = BufReader::new(stream); - let status = read_connect_status(&mut buf_stream).await?; - if status == 200 { - break; - } - if status == 412 && start.elapsed() < MAX_CONNECT_WAIT { - tracing::debug!( - elapsed = ?start.elapsed(), - "sandbox not yet ready (HTTP 412), retrying in {backoff:?}" - ); - tokio::time::sleep(backoff).await; - backoff = (backoff * 2).min(Duration::from_secs(8)); - continue; - } - return Err(miette::miette!( - "gateway CONNECT failed with status {status}" - )); - } - - let (reader, writer) = tokio::io::split(buf_stream); - let stdin = tokio::io::stdin(); - let stdout = tokio::io::stdout(); - - // Spawn both copy directions as independent tasks. Using separate spawned - // tasks (instead of try_join!/select!) ensures that when one direction - // completes or errors, the other continues independently until it also - // finishes. This is critical: when the remote side closes the connection, - // we must keep the stdin→gateway copy alive so SSH can finish sending its - // protocol-close packets, and vice-versa. - let to_remote = tokio::spawn(copy_ignoring_errors(stdin, writer)); - let from_remote = tokio::spawn(copy_ignoring_errors(reader, stdout)); - let _ = from_remote.await; - // Once the remote→stdout direction is done, SSH has received all the data - // it needs. Drop the stdin→gateway task – SSH will close its pipe when - // it's done regardless. - to_remote.abort(); - - Ok(()) + Ok(format_gateway_url(scheme, gateway_host, gateway_port)) } /// Run the SSH proxy in "name mode": create a session on the fly, then proxy. @@ -1122,93 +1124,6 @@ pub fn print_ssh_config(gateway: &str, name: &str) { print!("{}", render_ssh_config(gateway, name)); } -/// Copy all bytes from `reader` to `writer`, flushing on completion. -/// Errors are intentionally discarded – connection teardown errors are -/// expected during normal SSH session shutdown. -async fn copy_ignoring_errors(mut reader: R, mut writer: W) -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - let _ = tokio::io::copy(&mut reader, &mut writer).await; - let _ = AsyncWriteExt::flush(&mut writer).await; - let _ = AsyncWriteExt::shutdown(&mut writer).await; -} - -async fn connect_gateway( - scheme: &str, - host: &str, - port: u16, - tls: &TlsOptions, -) -> Result> { - // When using Cloudflare edge bearer auth, route through the WebSocket - // tunnel proxy regardless of the origin scheme. The proxy handles edge - // auth headers and TLS termination at the edge; the origin may be - // plaintext HTTP behind the tunnel. OIDC tokens bypass the tunnel. - if let Some(token) = tls.edge_token.as_deref() { - let gateway_url = format!("https://{host}:{port}"); - let proxy = crate::edge_tunnel::start_tunnel_proxy(&gateway_url, token).await?; - let tcp = TcpStream::connect(proxy.local_addr) - .await - .into_diagnostic()?; - tcp.set_nodelay(true).into_diagnostic()?; - return Ok(Box::new(tcp)); - } - - let tcp = TcpStream::connect((host, port)).await.into_diagnostic()?; - tcp.set_nodelay(true).into_diagnostic()?; - if scheme.eq_ignore_ascii_case("https") { - let materials = require_tls_materials(&format!("https://{host}:{port}"), tls)?; - let config = build_rustls_config(&materials)?; - let connector = TlsConnector::from(Arc::new(config)); - let server_name = ServerName::try_from(host.to_string()) - .map_err(|_| miette::miette!("invalid server name: {host}"))?; - let tls = connector - .connect(server_name, tcp) - .await - .into_diagnostic()?; - Ok(Box::new(tls)) - } else { - Ok(Box::new(tcp)) - } -} - -/// Read exactly the HTTP response status line and headers up to `\r\n\r\n`. -/// -/// Uses byte-at-a-time reads so that the caller's `BufReader` retains any -/// bytes that arrived after the header boundary (e.g. the SSH protocol -/// banner that the gateway may send in the same TCP segment). -async fn read_connect_status(stream: &mut R) -> Result { - let mut buf = Vec::new(); - let mut byte = [0u8; 1]; - loop { - let n = stream.read(&mut byte).await.into_diagnostic()?; - if n == 0 { - break; - } - buf.push(byte[0]); - if buf.len() >= 4 && &buf[buf.len() - 4..] == b"\r\n\r\n" { - break; - } - if buf.len() > 8192 { - break; - } - } - let text = String::from_utf8_lossy(&buf); - let line = text.lines().next().unwrap_or(""); - let status = line - .split_whitespace() - .nth(1) - .unwrap_or("0") - .parse::() - .unwrap_or(0); - Ok(status) -} - -trait ProxyStream: AsyncRead + AsyncWrite + Unpin + Send {} - -impl ProxyStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/openshell-cli/src/tls.rs b/crates/openshell-cli/src/tls.rs index c733d3db3..9c5de1773 100644 --- a/crates/openshell-cli/src/tls.rs +++ b/crates/openshell-cli/src/tls.rs @@ -342,6 +342,7 @@ pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { let endpoint = Endpoint::from_shared(server.to_string()) .into_diagnostic()? .connect_timeout(Duration::from_secs(10)) + .http2_adaptive_window(true) .http2_keep_alive_interval(Duration::from_secs(10)) .keep_alive_while_idle(true); return endpoint.connect().await.into_diagnostic(); @@ -362,6 +363,7 @@ pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { let endpoint = Endpoint::from_shared(local_url) .into_diagnostic()? .connect_timeout(Duration::from_secs(10)) + .http2_adaptive_window(true) .http2_keep_alive_interval(Duration::from_secs(10)) .keep_alive_while_idle(true); return endpoint.connect().await.into_diagnostic(); @@ -389,6 +391,7 @@ pub async fn build_channel(server: &str, tls: &TlsOptions) -> Result { let mut endpoint = Endpoint::from_shared(server.to_string()) .into_diagnostic()? .connect_timeout(Duration::from_secs(10)) + .http2_adaptive_window(true) .http2_keep_alive_interval(Duration::from_secs(10)) .keep_alive_while_idle(true); diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index fec161c53..5d8d8f1b3 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -516,6 +516,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented in test")) + } } // ── TLS helpers ────────────────────────────────────────────────────── diff --git a/crates/openshell-cli/tests/mtls_integration.rs b/crates/openshell-cli/tests/mtls_integration.rs index e833e7af9..a728643a8 100644 --- a/crates/openshell-cli/tests/mtls_integration.rs +++ b/crates/openshell-cli/tests/mtls_integration.rs @@ -407,6 +407,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented in test")) + } } fn build_ca() -> (Certificate, KeyPair) { diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 3902bda34..16c0b97b1 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -625,6 +625,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented in test")) + } } fn install_rustls_provider() { diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index eb28a18b3..da18e79d1 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -222,7 +222,6 @@ impl OpenShell for TestOpenShell { gateway_scheme: "https".to_string(), gateway_host: "localhost".to_string(), gateway_port: 443, - connect_path: "/connect/ssh".to_string(), ..CreateSshSessionResponse::default() })) } @@ -491,6 +490,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented in test")) + } } fn install_rustls_provider() { @@ -782,10 +792,9 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { let _env = test_env(&fake_ssh_dir, &xdg_dir); let tls = test_tls(&server); install_fake_ssh(&fake_ssh_dir); - let forward_port = { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - listener.local_addr().unwrap().port() - }; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let forward_port = listener.local_addr().unwrap().port(); + drop(listener); run::sandbox_create( &server.endpoint, diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index 7e6ea68b8..629421f59 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -428,6 +428,17 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented in test")) + } } // ── helpers ─────────────────────────────────────────────────────────── diff --git a/crates/openshell-core/src/config.rs b/crates/openshell-core/src/config.rs index a2a973011..4922f5355 100644 --- a/crates/openshell-core/src/config.rs +++ b/crates/openshell-core/src/config.rs @@ -253,10 +253,6 @@ pub struct Config { #[serde(default = "default_ssh_gateway_port")] pub ssh_gateway_port: u16, - /// Path for SSH CONNECT/upgrade requests. - #[serde(default = "default_ssh_connect_path")] - pub ssh_connect_path: String, - /// SSH listen port inside sandbox containers that expose a TCP endpoint. #[serde(default = "default_sandbox_ssh_port")] pub sandbox_ssh_port: u16, @@ -410,7 +406,6 @@ impl Config { grpc_endpoint: String::new(), ssh_gateway_host: default_ssh_gateway_host(), ssh_gateway_port: default_ssh_gateway_port(), - ssh_connect_path: default_ssh_connect_path(), sandbox_ssh_port: default_sandbox_ssh_port(), sandbox_ssh_socket_path: default_sandbox_ssh_socket_path(), ssh_handshake_secret: String::new(), @@ -520,13 +515,6 @@ impl Config { self } - /// Create a new configuration with the SSH connect path. - #[must_use] - pub fn with_ssh_connect_path(mut self, path: impl Into) -> Self { - self.ssh_connect_path = path.into(); - self - } - /// Create a new configuration with the sandbox SSH port. #[must_use] pub const fn with_sandbox_ssh_port(mut self, port: u16) -> Self { @@ -601,10 +589,6 @@ const fn default_ssh_gateway_port() -> u16 { DEFAULT_SERVER_PORT } -fn default_ssh_connect_path() -> String { - "/connect/ssh".to_string() -} - fn default_sandbox_ssh_socket_path() -> String { "/run/openshell/ssh.sock".to_string() } diff --git a/crates/openshell-core/src/forward.rs b/crates/openshell-core/src/forward.rs index b48e5594a..82fe0114c 100644 --- a/crates/openshell-core/src/forward.rs +++ b/crates/openshell-core/src/forward.rs @@ -469,6 +469,20 @@ pub fn resolve_ssh_gateway( (gateway_host.to_string(), gateway_port) } +/// Format a gateway URL, bracketing IPv6 literals when needed. +pub fn format_gateway_url(scheme: &str, host: &str, port: u16) -> String { + let host = if host + .parse::() + .is_ok_and(|ip| ip.is_ipv6()) + && !host.starts_with('[') + { + format!("[{host}]") + } else { + host.to_string() + }; + format!("{scheme}://{host}:{port}") +} + /// Shell-escape a value for use inside a `ProxyCommand` string. pub fn shell_escape(value: &str) -> String { if value.is_empty() { @@ -525,14 +539,11 @@ pub enum SshSessionResponseError { InvalidScheme, #[error("gateway_port must be in range 1..=65535")] InvalidPort, - #[error("connect_path must start with '/'")] - ConnectPathNotAbsolute, } const MAX_SANDBOX_ID_LEN: usize = 128; const MAX_TOKEN_LEN: usize = 4096; const MAX_GATEWAY_HOST_LEN: usize = 253; -const MAX_CONNECT_PATH_LEN: usize = 2048; const MAX_FINGERPRINT_LEN: usize = 256; fn is_sandbox_id_byte(b: u8) -> bool { @@ -551,33 +562,6 @@ fn is_gateway_host_byte(b: u8) -> bool { b.is_ascii_alphanumeric() || matches!(b, b'.' | b'-' | b':' | b'[' | b']') } -fn is_connect_path_byte(b: u8) -> bool { - // RFC 3986 path charset (pchar) without `?`, `#`, space, backtick, or - // backslash. `%` is permitted so percent-encoded segments round-trip. - b.is_ascii_alphanumeric() - || matches!( - b, - b'-' | b'.' - | b'_' - | b'~' - | b'!' - | b'$' - | b'&' - | b'\'' - | b'(' - | b')' - | b'*' - | b'+' - | b',' - | b';' - | b'=' - | b':' - | b'@' - | b'/' - | b'%' - ) -} - fn is_fingerprint_byte(b: u8) -> bool { b.is_ascii_alphanumeric() || matches!(b, b':' | b'+' | b'/' | b'=' | b'-') } @@ -612,25 +596,6 @@ pub fn validate_ssh_session_response( if resp.gateway_port == 0 || resp.gateway_port > u32::from(u16::MAX) { return Err(SshSessionResponseError::InvalidPort); } - if resp.connect_path.is_empty() { - return Err(SshSessionResponseError::Empty { - field: "connect_path", - }); - } - if !resp.connect_path.starts_with('/') { - return Err(SshSessionResponseError::ConnectPathNotAbsolute); - } - if resp.connect_path.len() > MAX_CONNECT_PATH_LEN { - return Err(SshSessionResponseError::TooLong { - field: "connect_path", - max: MAX_CONNECT_PATH_LEN, - }); - } - if !resp.connect_path.bytes().all(is_connect_path_byte) { - return Err(SshSessionResponseError::InvalidChars { - field: "connect_path", - }); - } if !resp.host_key_fingerprint.is_empty() { if resp.host_key_fingerprint.len() > MAX_FINGERPRINT_LEN { return Err(SshSessionResponseError::TooLong { @@ -735,6 +700,26 @@ mod tests { assert_eq!(port, 8080); } + #[test] + fn format_gateway_url_brackets_ipv6_literals() { + assert_eq!( + format_gateway_url("https", "::1", 8080), + "https://[::1]:8080" + ); + } + + #[test] + fn format_gateway_url_leaves_dns_and_bracketed_ipv6_unchanged() { + assert_eq!( + format_gateway_url("https", "gateway.example.com", 443), + "https://gateway.example.com:443" + ); + assert_eq!( + format_gateway_url("https", "[::1]", 8080), + "https://[::1]:8080" + ); + } + #[test] fn shell_escape_empty() { assert_eq!(shell_escape(""), "''"); @@ -757,7 +742,6 @@ mod tests { gateway_scheme: "https".to_string(), gateway_host: "gateway.example.com".to_string(), gateway_port: 443, - connect_path: "/connect/ssh".to_string(), host_key_fingerprint: String::new(), expires_at_ms: 0, } @@ -857,33 +841,6 @@ mod tests { } } - #[test] - fn validate_ssh_session_response_rejects_connect_path_without_leading_slash() { - let mut r = valid_session_response(); - r.connect_path = "connect/ssh".to_string(); - assert!(matches!( - validate_ssh_session_response(&r), - Err(SshSessionResponseError::ConnectPathNotAbsolute) - )); - } - - #[test] - fn validate_ssh_session_response_rejects_injected_connect_path() { - // `$`, `(`, `)` are valid RFC 3986 sub-delims (pchar) so the validator - // permits them; shell_escape is the second defensive layer. The - // following characters are rejected at the validator boundary because - // they are either unambiguously hostile in a shell context or invalid - // per RFC 3986 in the path component. - for bad in ["/x`id`y", "/x y", "/x\nb", "/x\\b", "/x?q=1", "/x#frag"] { - let mut r = valid_session_response(); - r.connect_path = bad.to_string(); - assert!( - validate_ssh_session_response(&r).is_err(), - "expected reject for connect_path={bad:?}" - ); - } - } - #[test] fn build_proxy_command_escapes_shell_metacharacters() { // Attacker-controlled values in every escapable position. diff --git a/crates/openshell-ocsf/src/format/shorthand.rs b/crates/openshell-ocsf/src/format/shorthand.rs index 42b30fbae..08b413429 100644 --- a/crates/openshell-ocsf/src/format/shorthand.rs +++ b/crates/openshell-ocsf/src/format/shorthand.rs @@ -61,6 +61,7 @@ pub fn severity_tag(severity_id: u8) -> &'static str { /// Max length for the reason text in `[reason:...]` before truncation. const MAX_REASON_LEN: usize = 80; +const MAX_MESSAGE_LEN: usize = 120; /// Format a `[reason:...]` tag from `status_detail` (or `message` fallback) /// for denied events. Returns an empty string if neither field is set. @@ -80,6 +81,19 @@ fn reason_tag(base: &BaseEventData) -> String { } } +fn message_tag(base: &BaseEventData) -> String { + let text = base.message.as_deref().unwrap_or(""); + if text.is_empty() { + return String::new(); + } + let text = text.replace(['\n', '\r'], " "); + if text.len() > MAX_MESSAGE_LEN { + format!(" [msg:{}...]", &text[..MAX_MESSAGE_LEN]) + } else { + format!(" [msg:{text}]") + } +} + impl OcsfEvent { /// Produce the single-line shorthand for `openshell.log` and gRPC log push. /// @@ -140,7 +154,13 @@ impl OcsfEvent { (false, true) => format!(" {action}"), (false, false) => format!(" {action}{arrow}"), }; - format!("NET:{activity} {sev}{detail}{rule_ctx}{reason_ctx}") + let message_ctx = + if detail.is_empty() && rule_ctx.is_empty() && reason_ctx.is_empty() { + message_tag(&e.base) + } else { + String::new() + }; + format!("NET:{activity} {sev}{detail}{rule_ctx}{reason_ctx}{message_ctx}") } Self::HttpActivity(e) => { @@ -541,6 +561,33 @@ mod tests { ); } + #[test] + fn test_network_activity_shorthand_shows_message_when_no_key_fields() { + let event = OcsfEvent::NetworkActivity(NetworkActivityEvent { + base: { + let mut b = base(4001, "Network Activity", 4, "Network Activity", 1, "Open"); + b.set_message("relay open (channel_id=ch-42)"); + b + }, + src_endpoint: None, + dst_endpoint: None, + proxy_endpoint: None, + actor: None, + firewall_rule: None, + connection_info: None, + action: None, + disposition: None, + observation_point_id: None, + is_src_dst_assignment_known: None, + }); + + let shorthand = event.format_shorthand(); + assert_eq!( + shorthand, + "NET:OPEN [INFO] [msg:relay open (channel_id=ch-42)]" + ); + } + #[test] fn test_http_activity_shorthand_denied_shows_reason() { let mut b = base(4002, "HTTP Activity", 4, "Network Activity", 99, "Other"); diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 25a28af54..81c575cfa 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -817,7 +817,7 @@ pub async fn run_sandbox( sandbox_id.as_ref(), ssh_socket_path.as_ref(), ) { - supervisor_session::spawn(endpoint.clone(), id.clone(), socket.clone()); + supervisor_session::spawn(endpoint.clone(), id.clone(), socket.clone(), ssh_netns_fd); info!("supervisor session task spawned"); } diff --git a/crates/openshell-sandbox/src/supervisor_session.rs b/crates/openshell-sandbox/src/supervisor_session.rs index 490a0cba7..49c52f9c2 100644 --- a/crates/openshell-sandbox/src/supervisor_session.rs +++ b/crates/openshell-sandbox/src/supervisor_session.rs @@ -4,24 +4,28 @@ //! Persistent supervisor-to-gateway session. //! //! Maintains a long-lived `ConnectSupervisor` bidirectional gRPC stream to the -//! gateway. When the gateway sends `RelayOpen`, the supervisor initiates a -//! `RelayStream` gRPC call (a new HTTP/2 stream multiplexed over the same -//! TCP+TLS connection as the control stream) and bridges it to the local SSH -//! daemon. The supervisor is a dumb byte bridge — it has no protocol awareness -//! of the SSH or NSSH1 bytes flowing through. - +//! gateway. When the gateway sends `RelayOpen`, the supervisor dials the +//! requested local target, initiates a `RelayStream` gRPC call (a new HTTP/2 +//! stream multiplexed over the same TCP+TLS connection as the control stream), +//! and bridges bytes. The supervisor is a dumb byte bridge after target +//! selection — it has no protocol awareness of the bytes flowing through. + +use std::net::IpAddr; +#[cfg(target_os = "linux")] +use std::os::fd::RawFd; use std::time::Duration; use openshell_core::proto::open_shell_client::OpenShellClient; use openshell_core::proto::{ - GatewayMessage, RelayFrame, RelayInit, SupervisorHeartbeat, SupervisorHello, SupervisorMessage, - gateway_message, supervisor_message, + GatewayMessage, RelayFrame, RelayInit, RelayOpen, RelayOpenResult, SupervisorHeartbeat, + SupervisorHello, SupervisorMessage, TcpRelayTarget, gateway_message, relay_open, + supervisor_message, }; use openshell_ocsf::{ - ActivityId, Endpoint, NetworkActivityBuilder, OcsfEvent, SandboxContext, SeverityId, StatusId, - ocsf_emit, + ActivityId, ConnectionInfo, Endpoint, NetworkActivityBuilder, OcsfEvent, SandboxContext, + SeverityId, StatusId, ocsf_emit, }; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc; use tokio_stream::StreamExt; use tonic::transport::Channel; @@ -91,33 +95,103 @@ fn session_failed_event( .build() } -fn relay_open_event(ctx: &SandboxContext, channel_id: &str) -> OcsfEvent { - NetworkActivityBuilder::new(ctx) +fn relay_target_endpoint(open: &RelayOpen) -> Option { + let relay_open::Target::Tcp(target) = open.target.as_ref()? else { + return None; + }; + let host = target.host.trim(); + let port = u16::try_from(target.port).ok()?; + if let Ok(ip) = host.parse() { + Some(Endpoint::from_ip(ip, port)) + } else { + Some(Endpoint::from_domain(host, port)) + } +} + +fn relay_target_kind(open: &RelayOpen) -> &'static str { + match open.target.as_ref() { + Some(relay_open::Target::Tcp(_)) => "tcp relay", + Some(relay_open::Target::Ssh(_)) | None => "ssh relay", + } +} + +fn relay_target_message( + open: &RelayOpen, + state: &str, + ssh_socket_path: &std::path::Path, +) -> String { + let target = match open.target.as_ref() { + Some(relay_open::Target::Tcp(target)) => { + format!("{}:{}", target.host.trim(), target.port) + } + Some(relay_open::Target::Ssh(_)) | None => { + format!("unix:{}", ssh_socket_path.display()) + } + }; + + format!( + "{} {state} (channel_id={}, target={target})", + relay_target_kind(open), + open.channel_id + ) +} + +fn relay_open_event( + ctx: &SandboxContext, + open: &RelayOpen, + ssh_socket_path: &std::path::Path, +) -> OcsfEvent { + let mut builder = NetworkActivityBuilder::new(ctx) .activity(ActivityId::Open) .severity(SeverityId::Informational) .status(StatusId::Success) - .message(format!("relay open (channel_id={channel_id})")) - .build() + .message(relay_target_message(open, "open", ssh_socket_path)); + if let Some(endpoint) = relay_target_endpoint(open) { + builder = builder + .dst_endpoint(endpoint) + .connection_info(ConnectionInfo::new("tcp")); + } + builder.build() } -fn relay_closed_event(ctx: &SandboxContext, channel_id: &str) -> OcsfEvent { - NetworkActivityBuilder::new(ctx) +fn relay_closed_event( + ctx: &SandboxContext, + open: &RelayOpen, + ssh_socket_path: &std::path::Path, +) -> OcsfEvent { + let mut builder = NetworkActivityBuilder::new(ctx) .activity(ActivityId::Close) .severity(SeverityId::Informational) .status(StatusId::Success) - .message(format!("relay closed (channel_id={channel_id})")) - .build() + .message(relay_target_message(open, "closed", ssh_socket_path)); + if let Some(endpoint) = relay_target_endpoint(open) { + builder = builder + .dst_endpoint(endpoint) + .connection_info(ConnectionInfo::new("tcp")); + } + builder.build() } -fn relay_failed_event(ctx: &SandboxContext, channel_id: &str, error: &str) -> OcsfEvent { - NetworkActivityBuilder::new(ctx) +fn relay_failed_event( + ctx: &SandboxContext, + open: &RelayOpen, + ssh_socket_path: &std::path::Path, + error: &str, +) -> OcsfEvent { + let mut builder = NetworkActivityBuilder::new(ctx) .activity(ActivityId::Fail) .severity(SeverityId::Low) .status(StatusId::Failure) .message(format!( - "relay bridge failed (channel_id={channel_id}): {error}" - )) - .build() + "{}: {error}", + relay_target_message(open, "bridge failed", ssh_socket_path) + )); + if let Some(endpoint) = relay_target_endpoint(open) { + builder = builder + .dst_endpoint(endpoint) + .connection_info(ConnectionInfo::new("tcp")); + } + builder.build() } fn relay_close_from_gateway_event( @@ -139,6 +213,10 @@ fn relay_close_from_gateway_event( /// HTTP/2 frame size so each `RelayFrame::data` fits in one frame. const RELAY_CHUNK_SIZE: usize = 16 * 1024; +trait TargetStream: AsyncRead + AsyncWrite + Send + Unpin {} + +impl TargetStream for T where T: AsyncRead + AsyncWrite + Send + Unpin {} + fn map_stream_message( message: Result, tonic::Status>, eof_error: &'static str, @@ -158,14 +236,21 @@ pub fn spawn( endpoint: String, sandbox_id: String, ssh_socket_path: std::path::PathBuf, + netns_fd: Option, ) -> tokio::task::JoinHandle<()> { - tokio::spawn(run_session_loop(endpoint, sandbox_id, ssh_socket_path)) + tokio::spawn(run_session_loop( + endpoint, + sandbox_id, + ssh_socket_path, + netns_fd, + )) } async fn run_session_loop( endpoint: String, sandbox_id: String, ssh_socket_path: std::path::PathBuf, + netns_fd: Option, ) { let mut backoff = INITIAL_BACKOFF; let mut attempt: u64 = 0; @@ -173,7 +258,7 @@ async fn run_session_loop( loop { attempt += 1; - match run_single_session(&endpoint, &sandbox_id, &ssh_socket_path).await { + match run_single_session(&endpoint, &sandbox_id, &ssh_socket_path, netns_fd).await { Ok(()) => { let event = session_closed_event(crate::ocsf_ctx(), &endpoint, &sandbox_id); ocsf_emit!(event); @@ -194,6 +279,7 @@ async fn run_single_session( endpoint: &str, sandbox_id: &str, ssh_socket_path: &std::path::Path, + netns_fd: Option, ) -> Result<(), Box> { // Connect to the gateway. The same `Channel` is used for both the // long-lived control stream and all data-plane `RelayStream` calls, so @@ -262,7 +348,9 @@ async fn run_single_session( &msg, sandbox_id, ssh_socket_path, + netns_fd, &channel, + &tx, ); } _ = heartbeat_interval.tick() => { @@ -283,7 +371,9 @@ fn handle_gateway_message( msg: &GatewayMessage, sandbox_id: &str, ssh_socket_path: &std::path::Path, + netns_fd: Option, channel: &Channel, + tx: &mpsc::Sender, ) { match &msg.payload { Some(gateway_message::Payload::Heartbeat(_)) => { @@ -291,22 +381,30 @@ fn handle_gateway_message( } Some(gateway_message::Payload::RelayOpen(open)) => { let channel_id = open.channel_id.clone(); + let relay_open = open.clone(); let sandbox_id = sandbox_id.to_string(); let channel = channel.clone(); let ssh_socket_path = ssh_socket_path.to_path_buf(); + let tx = tx.clone(); - let event = relay_open_event(crate::ocsf_ctx(), &channel_id); + let event = relay_open_event(crate::ocsf_ctx(), &relay_open, &ssh_socket_path); ocsf_emit!(event); tokio::spawn(async move { - match handle_relay_open(&channel_id, &ssh_socket_path, channel).await { + let event_open = relay_open.clone(); + match handle_relay_open(relay_open, &ssh_socket_path, netns_fd, channel, tx).await { Ok(()) => { - let event = relay_closed_event(crate::ocsf_ctx(), &channel_id); + let event = + relay_closed_event(crate::ocsf_ctx(), &event_open, &ssh_socket_path); ocsf_emit!(event); } Err(e) => { - let event = - relay_failed_event(crate::ocsf_ctx(), &channel_id, &e.to_string()); + let event = relay_failed_event( + crate::ocsf_ctx(), + &event_open, + &ssh_socket_path, + &e.to_string(), + ); ocsf_emit!(event); warn!( sandbox_id = %sandbox_id, @@ -336,10 +434,23 @@ fn handle_gateway_message( /// TLS handshake. The first `RelayFrame` we send is a `RelayInit`; subsequent /// frames carry raw SSH bytes in `data`. async fn handle_relay_open( - channel_id: &str, + relay_open: RelayOpen, ssh_socket_path: &std::path::Path, + netns_fd: Option, channel: Channel, + tx: mpsc::Sender, ) -> Result<(), Box> { + let channel_id = relay_open.channel_id.clone(); + let target = match open_target(&relay_open, ssh_socket_path, netns_fd).await { + Ok(target) => target, + Err(err) => { + send_relay_open_result(&tx, &channel_id, false, err.to_string()).await; + return Err(err); + } + }; + + send_relay_open_result(&tx, &channel_id, true, String::new()).await; + let mut client = OpenShellClient::new(channel); // Outbound chunks to the gateway. @@ -351,7 +462,7 @@ async fn handle_relay_open( .send(RelayFrame { payload: Some(openshell_core::proto::relay_frame::Payload::Init( RelayInit { - channel_id: channel_id.to_string(), + channel_id: channel_id.clone(), }, )), }) @@ -366,21 +477,19 @@ async fn handle_relay_open( let mut inbound = response.into_inner(); // Connect to the local SSH daemon on its Unix socket. - let ssh = tokio::net::UnixStream::connect(ssh_socket_path).await?; - let (mut ssh_r, mut ssh_w) = ssh.into_split(); + let (mut target_r, mut target_w) = tokio::io::split(target); debug!( channel_id = %channel_id, - socket = %ssh_socket_path.display(), - "relay bridge: connected to local SSH daemon" + "relay bridge: connected to local target" ); - // SSH → gRPC (out_tx): read local SSH, forward as `RelayFrame::data`. + // Target → gRPC (out_tx): read local target, forward as `RelayFrame::data`. let out_tx_writer = out_tx.clone(); - let ssh_to_grpc = tokio::spawn(async move { + let target_to_grpc = tokio::spawn(async move { let mut buf = vec![0u8; RELAY_CHUNK_SIZE]; loop { - match ssh_r.read(&mut buf).await { + match target_r.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => { let chunk = RelayFrame { @@ -396,7 +505,7 @@ async fn handle_relay_open( } }); - // gRPC (inbound) → SSH: drain inbound chunks into the local SSH socket. + // gRPC (inbound) → target: drain inbound chunks into the local target socket. let mut inbound_err: Option = None; while let Some(next) = inbound.next().await { match next { @@ -409,8 +518,8 @@ async fn handle_relay_open( if data.is_empty() { continue; } - if let Err(e) = ssh_w.write_all(&data).await { - inbound_err = Some(format!("write to ssh failed: {e}")); + if let Err(e) = target_w.write_all(&data).await { + inbound_err = Some(format!("write to target failed: {e}")); break; } } @@ -421,13 +530,13 @@ async fn handle_relay_open( } } - // Half-close the SSH socket's write side so the daemon sees EOF. - let _ = ssh_w.shutdown().await; + // Half-close the target socket's write side so the service sees EOF. + let _ = target_w.shutdown().await; // Dropping out_tx closes the outbound gRPC stream, letting the gateway // observe EOF on its side too. drop(out_tx); - let _ = ssh_to_grpc.await; + let _ = target_to_grpc.await; if let Some(e) = inbound_err { return Err(e.into()); @@ -435,6 +544,165 @@ async fn handle_relay_open( Ok(()) } +async fn send_relay_open_result( + tx: &mpsc::Sender, + channel_id: &str, + success: bool, + error: String, +) { + let _ = tx + .send(SupervisorMessage { + payload: Some(supervisor_message::Payload::RelayOpenResult( + RelayOpenResult { + channel_id: channel_id.to_string(), + success, + error, + }, + )), + }) + .await; +} + +async fn open_target( + relay_open: &RelayOpen, + ssh_socket_path: &std::path::Path, + netns_fd: Option, +) -> Result, Box> { + match relay_open.target.as_ref() { + Some(relay_open::Target::Tcp(target)) => open_tcp_target(target, netns_fd).await, + Some(relay_open::Target::Ssh(_)) | None => { + let stream = tokio::net::UnixStream::connect(ssh_socket_path).await?; + Ok(Box::new(stream)) + } + } +} + +async fn open_tcp_target( + target: &TcpRelayTarget, + netns_fd: Option, +) -> Result, Box> { + let host = normalize_tcp_target_host(target)?; + let port = u16::try_from(target.port).map_err(|_| "tcp target port must fit in u16")?; + let stream = connect_tcp_target(host, port, netns_fd).await?; + Ok(Box::new(stream)) +} + +#[cfg(target_os = "linux")] +async fn connect_tcp_target( + host: String, + port: u16, + netns_fd: Option, +) -> Result> { + if let Some(fd) = netns_fd { + let (tx, rx) = tokio::sync::oneshot::channel(); + std::thread::spawn(move || { + let result = (|| -> std::io::Result { + #[allow(unsafe_code)] + let rc = unsafe { libc::setns(fd, libc::CLONE_NEWNET) }; + if rc != 0 { + return Err(std::io::Error::last_os_error()); + } + std::net::TcpStream::connect((host.as_str(), port)) + })(); + let _ = tx.send(result); + }); + + let stream = rx + .await + .map_err(|_| "netns tcp connect thread panicked")??; + stream.set_nonblocking(true)?; + return Ok(tokio::net::TcpStream::from_std(stream)?); + } + + Ok(tokio::net::TcpStream::connect((host.as_str(), port)).await?) +} + +#[cfg(not(target_os = "linux"))] +async fn connect_tcp_target( + host: String, + port: u16, + _netns_fd: Option, +) -> Result> { + Ok(tokio::net::TcpStream::connect((host.as_str(), port)).await?) +} + +#[cfg(test)] +fn validate_tcp_target(target: &TcpRelayTarget) -> Result<(), String> { + normalize_tcp_target_host(target).map(|_| ()) +} + +fn normalize_tcp_target_host(target: &TcpRelayTarget) -> Result { + if target.port == 0 || target.port > u32::from(u16::MAX) { + return Err("tcp target port must be between 1 and 65535".to_string()); + } + + let host = target.host.trim(); + if host.is_empty() { + return Err("tcp target host is required".to_string()); + } + if host.eq_ignore_ascii_case("localhost") { + return Ok("127.0.0.1".to_string()); + } + + let ip: IpAddr = host + .parse() + .map_err(|_| "tcp target host must be loopback".to_string())?; + if ip.is_loopback() { + Ok(ip.to_string()) + } else { + Err("tcp target host must be loopback".to_string()) + } +} + +#[cfg(test)] +mod target_tests { + use super::*; + + fn tcp(host: &str, port: u32) -> TcpRelayTarget { + TcpRelayTarget { + host: host.to_string(), + port, + } + } + + #[test] + fn tcp_target_allows_loopback_hosts() { + validate_tcp_target(&tcp("127.0.0.1", 8080)).expect("ipv4 loopback"); + validate_tcp_target(&tcp("::1", 8080)).expect("ipv6 loopback"); + validate_tcp_target(&tcp("localhost", 8080)).expect("localhost"); + } + + #[test] + fn tcp_target_normalizes_localhost_before_dialing() { + assert_eq!( + normalize_tcp_target_host(&tcp("localhost", 8080)).expect("localhost"), + "127.0.0.1" + ); + assert_eq!( + normalize_tcp_target_host(&tcp("LOCALHOST", 8080)).expect("localhost"), + "127.0.0.1" + ); + } + + #[test] + fn tcp_target_rejects_non_loopback_hosts() { + let err = validate_tcp_target(&tcp("10.0.0.1", 8080)).expect_err("private ip rejected"); + assert_eq!(err, "tcp target host must be loopback"); + + let err = validate_tcp_target(&tcp("example.com", 8080)).expect_err("hostname rejected"); + assert_eq!(err, "tcp target host must be loopback"); + } + + #[test] + fn tcp_target_rejects_invalid_ports() { + let err = validate_tcp_target(&tcp("127.0.0.1", 0)).expect_err("zero rejected"); + assert_eq!(err, "tcp target port must be between 1 and 65535"); + + let err = validate_tcp_target(&tcp("127.0.0.1", 70000)).expect_err("too large rejected"); + assert_eq!(err, "tcp target port must be between 1 and 65535"); + } +} + #[cfg(test)] mod ocsf_event_tests { use super::*; @@ -479,6 +747,29 @@ mod ocsf_event_tests { } } + fn ssh_relay_open(channel_id: &str) -> RelayOpen { + RelayOpen { + channel_id: channel_id.to_string(), + target: Some(relay_open::Target::Ssh(Default::default())), + service_id: String::new(), + } + } + + fn tcp_relay_open(channel_id: &str, host: &str, port: u32) -> RelayOpen { + RelayOpen { + channel_id: channel_id.to_string(), + target: Some(relay_open::Target::Tcp(TcpRelayTarget { + host: host.to_string(), + port, + })), + service_id: String::new(), + } + } + + fn ssh_socket_path() -> &'static std::path::Path { + std::path::Path::new("/run/openshell/ssh.sock") + } + #[test] fn session_established_emits_network_open_success() { let event = session_established_event(&ctx(), "https://gw:443", "sess-1", 30); @@ -518,22 +809,43 @@ mod ocsf_event_tests { #[test] fn relay_open_emits_network_open_success() { - let event = relay_open_event(&ctx(), "ch-42"); + let event = relay_open_event(&ctx(), &ssh_relay_open("ch-42"), ssh_socket_path()); let na = network_activity(&event); assert_eq!(na.base.activity_id, ActivityId::Open.as_u8()); assert_eq!(na.base.severity, SeverityId::Informational); + let msg = na.base.message.as_deref().unwrap_or_default(); + assert!(msg.contains("ch-42"), "message: {msg}"); assert!( - na.base - .message - .as_deref() - .unwrap_or_default() - .contains("ch-42") + msg.contains("target=unix:/run/openshell/ssh.sock"), + "message: {msg}" + ); + } + + #[test] + fn tcp_relay_open_emits_target_endpoint() { + let event = relay_open_event( + &ctx(), + &tcp_relay_open("ch-42", "127.0.0.1", 8765), + ssh_socket_path(), + ); + let na = network_activity(&event); + assert_eq!(na.base.activity_id, ActivityId::Open.as_u8()); + assert_eq!( + na.dst_endpoint.as_ref().and_then(|e| e.ip.as_deref()), + Some("127.0.0.1") + ); + assert_eq!(na.dst_endpoint.as_ref().and_then(|e| e.port), Some(8765)); + assert_eq!( + na.connection_info + .as_ref() + .map(|c| c.protocol_name.as_str()), + Some("tcp") ); } #[test] fn relay_closed_emits_network_close_success() { - let event = relay_closed_event(&ctx(), "ch-42"); + let event = relay_closed_event(&ctx(), &ssh_relay_open("ch-42"), ssh_socket_path()); let na = network_activity(&event); assert_eq!(na.base.activity_id, ActivityId::Close.as_u8()); assert_eq!(na.base.status, Some(StatusId::Success)); @@ -541,7 +853,12 @@ mod ocsf_event_tests { #[test] fn relay_failed_emits_network_fail_low() { - let event = relay_failed_event(&ctx(), "ch-42", "write to ssh failed"); + let event = relay_failed_event( + &ctx(), + &ssh_relay_open("ch-42"), + ssh_socket_path(), + "write to ssh failed", + ); let na = network_activity(&event); assert_eq!(na.base.activity_id, ActivityId::Fail.as_u8()); assert_eq!(na.base.severity, SeverityId::Low); diff --git a/crates/openshell-server/src/cli.rs b/crates/openshell-server/src/cli.rs index 534e3da37..a3098c1cf 100644 --- a/crates/openshell-server/src/cli.rs +++ b/crates/openshell-server/src/cli.rs @@ -127,14 +127,6 @@ struct RunArgs { #[arg(long, env = "OPENSHELL_SSH_GATEWAY_PORT", default_value_t = DEFAULT_SERVER_PORT)] ssh_gateway_port: u16, - /// HTTP path for SSH CONNECT/upgrade. - #[arg( - long, - env = "OPENSHELL_SSH_CONNECT_PATH", - default_value = "/connect/ssh" - )] - ssh_connect_path: String, - /// SSH port inside sandbox pods. #[arg(long, env = "OPENSHELL_SANDBOX_SSH_PORT", default_value_t = DEFAULT_SSH_PORT)] sandbox_ssh_port: u16, @@ -400,7 +392,6 @@ async fn run_from_args(args: RunArgs) -> Result<()> { .with_sandbox_namespace(args.sandbox_namespace) .with_ssh_gateway_host(args.ssh_gateway_host) .with_ssh_gateway_port(args.ssh_gateway_port) - .with_ssh_connect_path(args.ssh_connect_path) .with_sandbox_ssh_port(args.sandbox_ssh_port) .with_ssh_handshake_skew_secs(args.ssh_handshake_skew_secs); diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index ebb8b1021..16f016081 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -31,8 +31,9 @@ use openshell_core::proto::{ RejectDraftChunkRequest, RejectDraftChunkResponse, RelayFrame, ReportPolicyStatusRequest, ReportPolicyStatusResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, - SupervisorMessage, UndoDraftChunkRequest, UndoDraftChunkResponse, UpdateConfigRequest, - UpdateConfigResponse, UpdateProviderRequest, WatchSandboxRequest, open_shell_server::OpenShell, + SupervisorMessage, TcpForwardFrame, UndoDraftChunkRequest, UndoDraftChunkResponse, + UpdateConfigRequest, UpdateConfigResponse, UpdateProviderRequest, WatchSandboxRequest, + open_shell_server::OpenShell, }; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; @@ -240,6 +241,16 @@ impl OpenShell for OpenShellService { sandbox::handle_exec_sandbox(&self.state, request).await } + type ForwardTcpStream = + Pin> + Send + 'static>>; + + async fn forward_tcp( + &self, + request: Request>, + ) -> Result, Status> { + sandbox::handle_forward_tcp(&self.state, request).await + } + // --- SSH sessions --- async fn create_ssh_session( diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 65ac69acb..7786a0b13 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -12,6 +12,7 @@ use crate::ServerState; use crate::persistence::{ObjectType, generate_name}; use futures::future; +use openshell_core::ObjectId; use openshell_core::proto::{ AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteSandboxRequest, DeleteSandboxResponse, @@ -19,10 +20,13 @@ use openshell_core::proto::{ ExecSandboxRequest, ExecSandboxStderr, ExecSandboxStdout, GetSandboxRequest, ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, Provider, RevokeSshSessionRequest, RevokeSshSessionResponse, - SandboxResponse, SandboxStreamEvent, WatchSandboxRequest, + SandboxStreamEvent, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, WatchSandboxRequest, + relay_open, tcp_forward_init, }; use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; use prost::Message; +use std::net::IpAddr; +use std::pin::Pin; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc; @@ -40,6 +44,8 @@ use super::validation::{ }; use super::{MAX_PAGE_SIZE, MAX_PROVIDERS, clamp_limit, current_time_ms}; +const TCP_FORWARD_CHUNK_SIZE: usize = 64 * 1024; + // --------------------------------------------------------------------------- // Sandbox lifecycle handlers // --------------------------------------------------------------------------- @@ -646,9 +652,8 @@ pub(super) async fn handle_exec_sandbox( } // Open a relay channel through the supervisor session. Use a 15s - // session-wait timeout — enough to cover a transient supervisor - // reconnect, but shorter than `/connect/ssh` since `ExecSandbox` is - // typically called during normal operation (not right after create). + // session-wait timeout, enough to cover a transient supervisor reconnect + // while still failing quickly during normal operation. let (channel_id, relay_rx) = state .supervisor_sessions .open_relay(sandbox.object_id(), std::time::Duration::from_secs(15)) @@ -669,7 +674,12 @@ pub(super) async fn handle_exec_sandbox( let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx) .await { - Ok(Ok(stream)) => stream, + Ok(Ok(Ok(stream))) => stream, + Ok(Ok(Err(status))) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ExecSandbox: relay target open failed"); + let _ = tx.send(Err(status)).await; + return; + } Ok(Err(_)) => { warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay channel dropped"); let _ = tx @@ -706,6 +716,329 @@ pub(super) async fn handle_exec_sandbox( Ok(Response::new(ReceiverStream::new(rx))) } +pub(super) async fn handle_forward_tcp( + state: &Arc, + request: Request>, +) -> Result< + Response< + Pin> + Send + 'static>>, + >, + Status, +> { + let mut inbound = request.into_inner(); + let first = inbound + .message() + .await? + .ok_or_else(|| Status::invalid_argument("empty ForwardTcp stream"))?; + let init = match first.payload { + Some(openshell_core::proto::tcp_forward_frame::Payload::Init(init)) => init, + _ => { + return Err(Status::invalid_argument( + "first TcpForwardFrame must be init", + )); + } + }; + + let target = validate_tcp_forward_init(&init)?; + + let sandbox = state + .store + .get_message::(&init.sandbox_id) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + .ok_or_else(|| Status::not_found("sandbox not found"))?; + + if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { + return Err(Status::failed_precondition("sandbox is not ready")); + } + + let connection_guard = acquire_forward_connection_guard(state, &init, &sandbox).await?; + let (channel_id, relay_rx) = state + .supervisor_sessions + .open_relay_with_target( + sandbox.object_id(), + target, + init.service_id.clone(), + std::time::Duration::from_secs(15), + ) + .await + .map_err(|e| Status::unavailable(format!("supervisor relay failed: {e}")))?; + + let sandbox_id = sandbox.object_id().to_string(); + let (tx, rx) = mpsc::channel::>(256); + tokio::spawn(async move { + let _connection_guard = connection_guard; + let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx) + .await + { + Ok(Ok(Ok(stream))) => stream, + Ok(Ok(Err(status))) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ForwardTcp: relay target open failed"); + let _ = tx.send(Err(status)).await; + return; + } + Ok(Err(_)) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ForwardTcp: relay channel dropped"); + let _ = tx + .send(Err(Status::unavailable("relay channel dropped"))) + .await; + return; + } + Err(_) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ForwardTcp: relay open timed out"); + let _ = tx + .send(Err(Status::deadline_exceeded("relay open timed out"))) + .await; + return; + } + }; + + bridge_forward_tcp_stream(inbound, relay_stream, tx, &sandbox_id, &channel_id).await; + }); + + let stream: Pin< + Box> + Send + 'static>, + > = Box::pin(ReceiverStream::new(rx)); + Ok(Response::new(stream)) +} + +struct ForwardConnectionGuard { + state: Arc, + token: Option, + sandbox_id: String, +} + +impl Drop for ForwardConnectionGuard { + fn drop(&mut self) { + if let Some(token) = self.token.as_deref() { + decrement_ssh_connection_count(&self.state.ssh_connections_by_token, token); + decrement_ssh_connection_count( + &self.state.ssh_connections_by_sandbox, + &self.sandbox_id, + ); + } + } +} + +async fn acquire_forward_connection_guard( + state: &Arc, + init: &TcpForwardInit, + sandbox: &Sandbox, +) -> Result { + let sandbox_id = sandbox.object_id().to_string(); + let token = init.authorization_token.trim(); + if token.is_empty() { + return Err(Status::unauthenticated( + "authorization_token is required for ForwardTcp", + )); + } + + validate_ssh_forward_token(state, token, &sandbox_id).await?; + acquire_ssh_connection_slots( + &state.ssh_connections_by_token, + &state.ssh_connections_by_sandbox, + token, + &sandbox_id, + )?; + + Ok(ForwardConnectionGuard { + state: state.clone(), + token: Some(token.to_string()), + sandbox_id, + }) +} + +async fn validate_ssh_forward_token( + state: &Arc, + token: &str, + sandbox_id: &str, +) -> Result<(), Status> { + let session = state + .store + .get_message::(token) + .await + .map_err(|e| Status::internal(format!("fetch SSH session failed: {e}")))? + .ok_or_else(|| Status::unauthenticated("SSH session token not found"))?; + + if session.revoked || session.sandbox_id != sandbox_id { + return Err(Status::unauthenticated("SSH session token is not valid")); + } + + if session.expires_at_ms > 0 { + let now_ms = current_time_ms() + .map_err(|e| Status::internal(format!("timestamp generation failed: {e}")))?; + if now_ms > session.expires_at_ms { + return Err(Status::unauthenticated("SSH session token expired")); + } + } + + Ok(()) +} + +fn acquire_ssh_connection_slots( + token_counts: &std::sync::Mutex>, + sandbox_counts: &std::sync::Mutex>, + token: &str, + sandbox_id: &str, +) -> Result<(), Status> { + const MAX_CONNECTIONS_PER_TOKEN: u32 = 3; + const MAX_CONNECTIONS_PER_SANDBOX: u32 = 20; + + { + let mut counts = token_counts.lock().unwrap(); + let count = counts.entry(token.to_string()).or_insert(0); + if *count >= MAX_CONNECTIONS_PER_TOKEN { + return Err(Status::resource_exhausted( + "SSH session connection limit reached", + )); + } + *count += 1; + } + + { + let mut counts = sandbox_counts.lock().unwrap(); + let count = counts.entry(sandbox_id.to_string()).or_insert(0); + if *count >= MAX_CONNECTIONS_PER_SANDBOX { + decrement_ssh_connection_count(token_counts, token); + return Err(Status::resource_exhausted( + "sandbox SSH connection limit reached", + )); + } + *count += 1; + } + + Ok(()) +} + +fn decrement_ssh_connection_count( + counts: &std::sync::Mutex>, + key: &str, +) { + let mut counts = counts.lock().unwrap(); + if let Some(count) = counts.get_mut(key) { + *count = count.saturating_sub(1); + if *count == 0 { + counts.remove(key); + } + } +} + +fn validate_tcp_forward_init(init: &TcpForwardInit) -> Result { + if init.sandbox_id.is_empty() { + return Err(Status::invalid_argument("sandbox_id is required")); + } + + if let Some(target) = init.target.as_ref() { + return match target { + tcp_forward_init::Target::Ssh(_) => Ok(relay_open::Target::Ssh(Default::default())), + tcp_forward_init::Target::Tcp(target) => Ok(relay_open::Target::Tcp( + validate_tcp_forward_target(target)?, + )), + }; + } + + Err(Status::invalid_argument("tcp forward target is required")) +} + +fn validate_tcp_forward_target(target: &TcpRelayTarget) -> Result { + if target.port == 0 || target.port > u32::from(u16::MAX) { + return Err(Status::invalid_argument( + "tcp target port must be between 1 and 65535", + )); + } + + validate_tcp_target_parts(target.host.trim(), target.port).map(|host| TcpRelayTarget { + host, + port: target.port, + }) +} + +fn validate_tcp_target_parts(host: &str, _port: u32) -> Result { + if host.is_empty() { + return Err(Status::invalid_argument("tcp target host is required")); + } + if host.eq_ignore_ascii_case("localhost") { + return Ok("127.0.0.1".to_string()); + } + + let ip: IpAddr = host + .parse() + .map_err(|_| Status::invalid_argument("tcp target host must be loopback"))?; + if ip.is_loopback() { + Ok(ip.to_string()) + } else { + Err(Status::invalid_argument("tcp target host must be loopback")) + } +} + +async fn bridge_forward_tcp_stream( + mut inbound: tonic::Streaming, + relay_stream: tokio::io::DuplexStream, + tx: mpsc::Sender>, + sandbox_id: &str, + channel_id: &str, +) { + let (mut relay_read, mut relay_write) = tokio::io::split(relay_stream); + + let sandbox_id_in = sandbox_id.to_string(); + let channel_id_in = channel_id.to_string(); + tokio::spawn(async move { + loop { + match inbound.message().await { + Ok(Some(frame)) => { + let Some(openshell_core::proto::tcp_forward_frame::Payload::Data(data)) = + frame.payload + else { + warn!(sandbox_id = %sandbox_id_in, channel_id = %channel_id_in, "ForwardTcp: received non-data frame after init"); + break; + }; + if data.is_empty() { + continue; + } + if let Err(err) = + tokio::io::AsyncWriteExt::write_all(&mut relay_write, &data).await + { + warn!(sandbox_id = %sandbox_id_in, channel_id = %channel_id_in, error = %err, "ForwardTcp: write to relay failed"); + break; + } + } + Ok(None) => break, + Err(err) => { + warn!(sandbox_id = %sandbox_id_in, channel_id = %channel_id_in, error = %err, "ForwardTcp: inbound stream failed"); + break; + } + } + } + let _ = tokio::io::AsyncWriteExt::shutdown(&mut relay_write).await; + }); + + let mut buf = vec![0u8; TCP_FORWARD_CHUNK_SIZE]; + loop { + match tokio::io::AsyncReadExt::read(&mut relay_read, &mut buf).await { + Ok(0) => break, + Ok(n) => { + let frame = TcpForwardFrame { + payload: Some(openshell_core::proto::tcp_forward_frame::Payload::Data( + buf[..n].to_vec(), + )), + }; + if tx.send(Ok(frame)).await.is_err() { + break; + } + } + Err(err) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %err, "ForwardTcp: read from relay failed"); + let _ = tx + .send(Err(Status::unavailable(format!( + "relay read failed: {err}" + )))) + .await; + break; + } + } + } +} + // --------------------------------------------------------------------------- // SSH session handlers // --------------------------------------------------------------------------- @@ -773,7 +1106,6 @@ pub(super) async fn handle_create_ssh_session( gateway_host, gateway_port: gateway_port.into(), gateway_scheme: scheme.to_string(), - connect_path: state.config.ssh_connect_path.clone(), host_key_fingerprint: String::new(), expires_at_ms, })) @@ -882,8 +1214,7 @@ fn build_remote_exec_command(req: &ExecSandboxRequest) -> Result /// /// This is the relay equivalent of `stream_exec_over_ssh`. Instead of dialing a /// sandbox endpoint directly, the SSH transport runs over a `DuplexStream` that -/// is bridged to the supervisor's local SSH daemon via a reverse HTTP CONNECT -/// tunnel. +/// is bridged to the supervisor's local SSH daemon via `RelayStream`. #[allow(clippy::too_many_arguments)] async fn stream_exec_over_relay( tx: mpsc::Sender>, @@ -1219,6 +1550,87 @@ mod tests { assert!(build_remote_exec_command(&req).is_err()); } + #[test] + fn tcp_forward_init_allows_loopback_targets() { + for host in ["127.0.0.1", "::1", "localhost"] { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + service_id: String::new(), + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: host.to_string(), + port: 8080, + })), + authorization_token: String::new(), + }; + validate_tcp_forward_init(&init).expect("loopback target should pass"); + } + } + + #[test] + fn tcp_forward_init_allows_ssh_target() { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + target: Some(tcp_forward_init::Target::Ssh(Default::default())), + ..Default::default() + }; + match validate_tcp_forward_init(&init).expect("ssh target should pass") { + relay_open::Target::Ssh(_) => {} + other => panic!("expected SSH target, got {other:?}"), + } + } + + #[test] + fn tcp_forward_init_rejects_non_loopback_targets() { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + service_id: String::new(), + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: "example.com".to_string(), + port: 8080, + })), + authorization_token: String::new(), + }; + assert_eq!( + validate_tcp_forward_init(&init) + .expect_err("hostname rejected") + .message(), + "tcp target host must be loopback" + ); + } + + #[test] + fn tcp_forward_init_rejects_invalid_port() { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + service_id: String::new(), + target: Some(tcp_forward_init::Target::Tcp(TcpRelayTarget { + host: "127.0.0.1".to_string(), + port: 0, + })), + authorization_token: String::new(), + }; + assert_eq!( + validate_tcp_forward_init(&init) + .expect_err("zero port rejected") + .message(), + "tcp target port must be between 1 and 65535" + ); + } + + #[test] + fn tcp_forward_init_requires_target() { + let init = TcpForwardInit { + sandbox_id: "sbx".to_string(), + ..Default::default() + }; + assert_eq!( + validate_tcp_forward_init(&init) + .expect_err("missing target rejected") + .message(), + "tcp forward target is required" + ); + } + // ---- petname / generate_name ---- #[test] diff --git a/crates/openshell-server/src/http.rs b/crates/openshell-server/src/http.rs index 7650c2339..7ca9cb8bf 100644 --- a/crates/openshell-server/src/http.rs +++ b/crates/openshell-server/src/http.rs @@ -59,7 +59,5 @@ async fn render_metrics(State(handle): State) -> impl IntoResp /// Create the HTTP router. pub fn http_router(state: Arc) -> Router { - crate::ssh_tunnel::router(state.clone()) - .merge(crate::ws_tunnel::router(state.clone())) - .merge(crate::auth::router(state)) + crate::ws_tunnel::router(state.clone()).merge(crate::auth::router(state)) } diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index 07c3cef5c..bca6e44aa 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -31,7 +31,7 @@ mod persistence; pub(crate) mod policy_store; mod sandbox_index; mod sandbox_watch; -mod ssh_tunnel; +mod ssh_sessions; pub mod supervisor_session; mod tls; pub mod tracing_bus; @@ -220,7 +220,7 @@ pub async fn run_server( } state.compute.spawn_watchers(); - ssh_tunnel::spawn_session_reaper(store.clone(), Duration::from_secs(3600)); + ssh_sessions::spawn_session_reaper(store.clone(), Duration::from_secs(3600)); supervisor_session::spawn_relay_reaper(state.clone(), Duration::from_secs(30)); // Create the multiplexed service diff --git a/crates/openshell-server/src/multiplex.rs b/crates/openshell-server/src/multiplex.rs index 93e58d202..bca9a2171 100644 --- a/crates/openshell-server/src/multiplex.rs +++ b/crates/openshell-server/src/multiplex.rs @@ -470,7 +470,6 @@ fn grpc_status_from_response(res: &Response) -> String { fn normalize_http_path(path: &str) -> &'static str { match path { - p if p.starts_with("/connect/ssh") => "/connect/ssh", p if p.starts_with("/_ws_tunnel") => "/_ws_tunnel", p if p.starts_with("/auth/") => "/auth", _ => "unknown", @@ -724,19 +723,6 @@ mod tests { assert_eq!(grpc_method_from_path(""), ""); } - #[test] - fn normalize_ssh_path() { - assert_eq!(normalize_http_path("/connect/ssh"), "/connect/ssh"); - } - - #[test] - fn normalize_ssh_path_with_trailing_segments() { - assert_eq!( - normalize_http_path("/connect/ssh?token=abc"), - "/connect/ssh" - ); - } - #[test] fn normalize_ws_tunnel() { assert_eq!(normalize_http_path("/_ws_tunnel"), "/_ws_tunnel"); diff --git a/crates/openshell-server/src/ssh_sessions.rs b/crates/openshell-server/src/ssh_sessions.rs new file mode 100644 index 000000000..f8d85033d --- /dev/null +++ b/crates/openshell-server/src/ssh_sessions.rs @@ -0,0 +1,185 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! SSH session token storage and cleanup. + +use openshell_core::ObjectId; +use openshell_core::proto::SshSession; +use prost::Message; +use std::sync::Arc; +use std::time::Duration; +use tracing::{info, warn}; + +use crate::persistence::{ObjectType, Store}; + +impl ObjectType for SshSession { + fn object_type() -> &'static str { + "ssh_session" + } +} + +/// Spawn a background task that periodically reaps expired and revoked SSH sessions. +pub fn spawn_session_reaper(store: Arc, interval: Duration) { + tokio::spawn(async move { + tokio::time::sleep(interval).await; + + loop { + if let Err(e) = reap_expired_sessions(&store).await { + warn!(error = %e, "SSH session reaper sweep failed"); + } + tokio::time::sleep(interval).await; + } + }); +} + +async fn reap_expired_sessions(store: &Store) -> Result<(), String> { + let now_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as i64; + + let records = store + .list(SshSession::object_type(), 1000, 0) + .await + .map_err(|e| e.to_string())?; + + let mut reaped = 0u32; + for record in records { + let session: SshSession = match Message::decode(record.payload.as_slice()) { + Ok(s) => s, + Err(_) => continue, + }; + + let should_delete = + (session.expires_at_ms > 0 && now_ms > session.expires_at_ms) || session.revoked; + + if should_delete { + if let Err(e) = store + .delete(SshSession::object_type(), session.object_id()) + .await + { + warn!(session_id = %session.object_id(), error = %e, "Failed to reap SSH session"); + } else { + reaped += 1; + } + } + } + + if reaped > 0 { + info!(count = reaped, "SSH session reaper: cleaned up sessions"); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + fn make_session(id: &str, sandbox_id: &str, expires_at_ms: i64, revoked: bool) -> SshSession { + SshSession { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: id.to_string(), + name: format!("session-{id}"), + created_at_ms: 1000, + labels: HashMap::new(), + }), + sandbox_id: sandbox_id.to_string(), + token: id.to_string(), + expires_at_ms, + revoked, + } + } + + fn now_ms() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as i64 + } + + #[tokio::test] + async fn reaper_deletes_expired_sessions() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let expired = make_session("expired1", "sbx1", now_ms() - 60_000, false); + store.put_message(&expired).await.unwrap(); + + let valid = make_session("valid1", "sbx1", now_ms() + 3_600_000, false); + store.put_message(&valid).await.unwrap(); + + reap_expired_sessions(&store).await.unwrap(); + + assert!( + store + .get_message::("expired1") + .await + .unwrap() + .is_none(), + "expired session should be reaped" + ); + assert!( + store + .get_message::("valid1") + .await + .unwrap() + .is_some(), + "valid session should be kept" + ); + } + + #[tokio::test] + async fn reaper_deletes_revoked_sessions() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let revoked = make_session("revoked1", "sbx1", 0, true); + store.put_message(&revoked).await.unwrap(); + + let active = make_session("active1", "sbx1", 0, false); + store.put_message(&active).await.unwrap(); + + reap_expired_sessions(&store).await.unwrap(); + + assert!( + store + .get_message::("revoked1") + .await + .unwrap() + .is_none(), + "revoked session should be reaped" + ); + assert!( + store + .get_message::("active1") + .await + .unwrap() + .is_some(), + "active session should be kept" + ); + } + + #[tokio::test] + async fn reaper_preserves_zero_expiry_sessions() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let no_expiry = make_session("noexpiry1", "sbx1", 0, false); + store.put_message(&no_expiry).await.unwrap(); + + reap_expired_sessions(&store).await.unwrap(); + + assert!( + store + .get_message::("noexpiry1") + .await + .unwrap() + .is_some(), + "session with no expiry should be preserved" + ); + } +} diff --git a/crates/openshell-server/src/ssh_tunnel.rs b/crates/openshell-server/src/ssh_tunnel.rs deleted file mode 100644 index bd317d53f..000000000 --- a/crates/openshell-server/src/ssh_tunnel.rs +++ /dev/null @@ -1,541 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! SSH tunnel handler for the multiplexed gateway. - -use axum::{Router, extract::State, http::Method, response::IntoResponse, routing::any}; -use http::StatusCode; -use hyper::Request; -use hyper_util::rt::TokioIo; -use openshell_core::proto::{Sandbox, SandboxPhase, SshSession}; -use prost::Message; -use std::sync::Arc; -use std::time::Duration; -use tokio::io::AsyncWriteExt; -use tracing::{info, warn}; - -use crate::ServerState; -use crate::persistence::{ObjectType, Store}; - -const HEADER_SANDBOX_ID: &str = "x-sandbox-id"; -const HEADER_TOKEN: &str = "x-sandbox-token"; - -/// Maximum concurrent SSH tunnel connections per session token. -const MAX_CONNECTIONS_PER_TOKEN: u32 = 3; - -/// Redact a bearer token for safe logging — show only the last 4 characters. -fn redact_token(token: &str) -> String { - if token.len() <= 4 { - "****".to_string() - } else { - format!("****{}", &token[token.len() - 4..]) - } -} - -/// Maximum concurrent SSH tunnel connections per sandbox. -const MAX_CONNECTIONS_PER_SANDBOX: u32 = 20; - -fn acquire_connection_slots( - token_counts: &std::sync::Mutex>, - sandbox_counts: &std::sync::Mutex>, - token: &str, - sandbox_id: &str, -) -> Result<(), ConnectionLimit> { - { - let mut counts = token_counts.lock().unwrap(); - let count = counts.entry(token.to_string()).or_insert(0); - if *count >= MAX_CONNECTIONS_PER_TOKEN { - return Err(ConnectionLimit::PerToken); - } - *count += 1; - } - - { - let mut counts = sandbox_counts.lock().unwrap(); - let count = counts.entry(sandbox_id.to_string()).or_insert(0); - if *count >= MAX_CONNECTIONS_PER_SANDBOX { - decrement_connection_count(token_counts, token); - return Err(ConnectionLimit::PerSandbox); - } - *count += 1; - } - - Ok(()) -} - -enum ConnectionLimit { - PerToken, - PerSandbox, -} - -pub fn router(state: Arc) -> Router { - Router::new() - .route("/connect/ssh", any(ssh_connect)) - .with_state(state) -} - -async fn ssh_connect( - State(state): State>, - req: Request, -) -> impl IntoResponse { - if req.method() != Method::CONNECT { - return StatusCode::METHOD_NOT_ALLOWED.into_response(); - } - - let sandbox_id = match header_value(req.headers(), HEADER_SANDBOX_ID) { - Ok(value) => value, - Err(status) => return status.into_response(), - }; - let token = match header_value(req.headers(), HEADER_TOKEN) { - Ok(value) => value, - Err(status) => return status.into_response(), - }; - - let session = match state.store.get_message::(&token).await { - Ok(Some(session)) => session, - Ok(None) => return StatusCode::UNAUTHORIZED.into_response(), - Err(err) => { - warn!(error = %err, "Failed to fetch SSH session"); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - - if session.revoked || session.sandbox_id != sandbox_id { - return StatusCode::UNAUTHORIZED.into_response(); - } - - // Check token expiry (0 means no expiry for backward compatibility). - if session.expires_at_ms > 0 { - let now_ms = i64::try_from( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis(), - ) - .unwrap_or(i64::MAX); - if now_ms > session.expires_at_ms { - return StatusCode::UNAUTHORIZED.into_response(); - } - } - - let sandbox = match state.store.get_message::(&sandbox_id).await { - Ok(Some(sandbox)) => sandbox, - Ok(None) => return StatusCode::NOT_FOUND.into_response(), - Err(err) => { - warn!(error = %err, "Failed to fetch sandbox"); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - - if SandboxPhase::try_from(sandbox.phase).ok() != Some(SandboxPhase::Ready) { - return StatusCode::PRECONDITION_FAILED.into_response(); - } - - // Enforce connection caps *before* opening a relay — otherwise denied - // calls churn pending relay slots and wake the supervisor until the relay - // timeout elapses. - if let Err(limit) = acquire_connection_slots( - &state.ssh_connections_by_token, - &state.ssh_connections_by_sandbox, - &token, - &sandbox_id, - ) { - match limit { - ConnectionLimit::PerToken => { - warn!(token = %redact_token(&token), "SSH tunnel: per-token connection limit reached"); - } - ConnectionLimit::PerSandbox => { - warn!(sandbox_id = %sandbox_id, "SSH tunnel: per-sandbox connection limit reached"); - } - } - return StatusCode::TOO_MANY_REQUESTS.into_response(); - } - - // Open a relay channel through the supervisor session. Use a generous - // 30s session-wait timeout because `/connect/ssh` is typically called - // immediately after `sandbox create`, so we need to cover the supervisor's - // initial TLS + gRPC handshake on a cold-started pod. The old - // direct-connect path tolerated ~34s here for similar reasons. - let (channel_id, relay_rx) = match state - .supervisor_sessions - .open_relay(&sandbox_id, Duration::from_secs(30)) - .await - { - Ok(pair) => pair, - Err(status) => { - warn!(sandbox_id = %sandbox_id, error = %status.message(), "SSH tunnel: supervisor session not available"); - decrement_connection_count(&state.ssh_connections_by_token, &token); - decrement_connection_count(&state.ssh_connections_by_sandbox, &sandbox_id); - return StatusCode::BAD_GATEWAY.into_response(); - } - }; - - let sandbox_id_clone = sandbox_id.clone(); - let token_clone = token.clone(); - let state_clone = state.clone(); - - let upgrade = hyper::upgrade::on(req); - tokio::spawn(async move { - // Wait for the supervisor to open its `RelayStream` and deliver the - // bridge half of the relay. - let mut relay = match tokio::time::timeout(Duration::from_secs(10), relay_rx).await { - Ok(Ok(stream)) => stream, - Ok(Err(_)) => { - warn!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay channel dropped"); - decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); - decrement_connection_count( - &state_clone.ssh_connections_by_sandbox, - &sandbox_id_clone, - ); - return; - } - Err(_) => { - warn!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay open timed out"); - decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); - decrement_connection_count( - &state_clone.ssh_connections_by_sandbox, - &sandbox_id_clone, - ); - return; - } - }; - - info!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay established, bridging client"); - - match upgrade.await { - Ok(upgraded) => { - let mut upgraded = TokioIo::new(upgraded); - let _ = tokio::io::copy_bidirectional(&mut upgraded, &mut relay).await; - let _ = AsyncWriteExt::shutdown(&mut upgraded).await; - } - Err(err) => { - warn!(error = %err, "SSH upgrade failed"); - } - } - - // Decrement connection counts on tunnel completion. - decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); - decrement_connection_count(&state_clone.ssh_connections_by_sandbox, &sandbox_id_clone); - }); - - StatusCode::OK.into_response() -} - -fn header_value(headers: &http::HeaderMap, name: &str) -> Result { - let value = headers - .get(name) - .ok_or(StatusCode::UNAUTHORIZED)? - .to_str() - .map_err(|_| StatusCode::BAD_REQUEST)? - .trim() - .to_string(); - if value.is_empty() { - return Err(StatusCode::BAD_REQUEST); - } - Ok(value) -} - -impl ObjectType for SshSession { - fn object_type() -> &'static str { - "ssh_session" - } -} - -/// Decrement a connection count entry, removing it if it reaches zero. -fn decrement_connection_count( - counts: &std::sync::Mutex>, - key: &str, -) { - let mut map = counts.lock().unwrap(); - if let Some(count) = map.get_mut(key) { - *count = count.saturating_sub(1); - if *count == 0 { - map.remove(key); - } - } -} - -/// Spawn a background task that periodically reaps expired and revoked SSH sessions. -pub fn spawn_session_reaper(store: Arc, interval: Duration) { - tokio::spawn(async move { - // Initial delay to let startup settle. - tokio::time::sleep(interval).await; - - loop { - if let Err(e) = reap_expired_sessions(&store).await { - warn!(error = %e, "SSH session reaper sweep failed"); - } - tokio::time::sleep(interval).await; - } - }); -} - -async fn reap_expired_sessions(store: &Store) -> Result<(), String> { - let now_ms = i64::try_from( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis(), - ) - .unwrap_or(i64::MAX); - - let records = store - .list(SshSession::object_type(), 1000, 0) - .await - .map_err(|e| e.to_string())?; - - let mut reaped = 0u32; - for record in records { - let session: SshSession = match Message::decode(record.payload.as_slice()) { - Ok(s) => s, - Err(_) => continue, - }; - - let should_delete = - // Expired sessions (expires_at_ms > 0 means expiry is set). - (session.expires_at_ms > 0 && now_ms > session.expires_at_ms) - // Revoked sessions — already invalidated, just cleaning up storage. - || session.revoked; - - if should_delete { - use openshell_core::ObjectId; - if let Err(e) = store - .delete(SshSession::object_type(), session.object_id()) - .await - { - warn!(session_id = %session.object_id(), error = %e, "Failed to reap SSH session"); - } else { - reaped += 1; - } - } - } - - if reaped > 0 { - info!(count = reaped, "SSH session reaper: cleaned up sessions"); - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::persistence::Store; - use std::collections::HashMap; - use std::sync::Mutex; - - fn make_session(id: &str, sandbox_id: &str, expires_at_ms: i64, revoked: bool) -> SshSession { - SshSession { - metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { - id: id.to_string(), - name: format!("session-{id}"), - created_at_ms: 1000, - labels: HashMap::new(), - }), - sandbox_id: sandbox_id.to_string(), - token: id.to_string(), - expires_at_ms, - revoked, - } - } - - fn now_ms() -> i64 { - i64::try_from( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis(), - ) - .unwrap_or(i64::MAX) - } - - // ---- Connection limit tests ---- - - #[test] - fn decrement_removes_entry_at_zero() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts.lock().unwrap().insert("tok1".to_string(), 1); - decrement_connection_count(&counts, "tok1"); - assert!(counts.lock().unwrap().is_empty()); - } - - #[test] - fn decrement_reduces_count() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts.lock().unwrap().insert("tok1".to_string(), 5); - decrement_connection_count(&counts, "tok1"); - assert_eq!(*counts.lock().unwrap().get("tok1").unwrap(), 4); - } - - #[test] - fn decrement_missing_key_is_noop() { - let counts: Mutex> = Mutex::new(HashMap::new()); - decrement_connection_count(&counts, "nonexistent"); - assert!(counts.lock().unwrap().is_empty()); - } - - #[test] - fn per_token_connection_limit_enforced() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts - .lock() - .unwrap() - .insert("tok1".to_string(), MAX_CONNECTIONS_PER_TOKEN); - let current = *counts.lock().unwrap().get("tok1").unwrap(); - assert!(current >= MAX_CONNECTIONS_PER_TOKEN); - } - - #[test] - fn per_sandbox_connection_limit_enforced() { - let counts: Mutex> = Mutex::new(HashMap::new()); - counts - .lock() - .unwrap() - .insert("sbx1".to_string(), MAX_CONNECTIONS_PER_SANDBOX); - let current = *counts.lock().unwrap().get("sbx1").unwrap(); - assert!(current >= MAX_CONNECTIONS_PER_SANDBOX); - } - - #[test] - fn acquire_connection_slots_rejects_per_token_limit_without_touching_sandbox() { - let token_counts: Mutex> = Mutex::new(HashMap::new()); - let sandbox_counts: Mutex> = Mutex::new(HashMap::new()); - token_counts - .lock() - .unwrap() - .insert("tok1".to_string(), MAX_CONNECTIONS_PER_TOKEN); - - let result = acquire_connection_slots(&token_counts, &sandbox_counts, "tok1", "sbx1"); - - assert!(matches!(result, Err(ConnectionLimit::PerToken))); - assert!(sandbox_counts.lock().unwrap().is_empty()); - } - - #[test] - fn acquire_connection_slots_rolls_back_token_increment_on_sandbox_limit() { - let token_counts: Mutex> = Mutex::new(HashMap::new()); - let sandbox_counts: Mutex> = Mutex::new(HashMap::new()); - sandbox_counts - .lock() - .unwrap() - .insert("sbx1".to_string(), MAX_CONNECTIONS_PER_SANDBOX); - - let result = acquire_connection_slots(&token_counts, &sandbox_counts, "tok1", "sbx1"); - - assert!(matches!(result, Err(ConnectionLimit::PerSandbox))); - assert!(token_counts.lock().unwrap().is_empty()); - } - - // ---- Session reaper tests ---- - - #[tokio::test] - async fn reaper_deletes_expired_sessions() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); - - let expired = make_session("expired1", "sbx1", now_ms() - 60_000, false); - store.put_message(&expired).await.unwrap(); - - let valid = make_session("valid1", "sbx1", now_ms() + 3_600_000, false); - store.put_message(&valid).await.unwrap(); - - reap_expired_sessions(&store).await.unwrap(); - - assert!( - store - .get_message::("expired1") - .await - .unwrap() - .is_none(), - "expired session should be reaped" - ); - assert!( - store - .get_message::("valid1") - .await - .unwrap() - .is_some(), - "valid session should be kept" - ); - } - - #[tokio::test] - async fn reaper_deletes_revoked_sessions() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); - - let revoked = make_session("revoked1", "sbx1", 0, true); - store.put_message(&revoked).await.unwrap(); - - let active = make_session("active1", "sbx1", 0, false); - store.put_message(&active).await.unwrap(); - - reap_expired_sessions(&store).await.unwrap(); - - assert!( - store - .get_message::("revoked1") - .await - .unwrap() - .is_none(), - "revoked session should be reaped" - ); - assert!( - store - .get_message::("active1") - .await - .unwrap() - .is_some(), - "active session should be kept" - ); - } - - #[tokio::test] - async fn reaper_preserves_zero_expiry_sessions() { - let store = Store::connect("sqlite::memory:?cache=shared") - .await - .unwrap(); - - // expires_at_ms = 0 means no expiry (backward compatible). - let no_expiry = make_session("noexpiry1", "sbx1", 0, false); - store.put_message(&no_expiry).await.unwrap(); - - reap_expired_sessions(&store).await.unwrap(); - - assert!( - store - .get_message::("noexpiry1") - .await - .unwrap() - .is_some(), - "session with no expiry should be preserved" - ); - } - - // ---- Expiry validation logic tests ---- - - #[test] - fn expired_session_is_detected() { - let session = make_session("tok1", "sbx1", now_ms() - 1000, false); - let is_expired = session.expires_at_ms > 0 && now_ms() > session.expires_at_ms; - assert!(is_expired, "session in the past should be expired"); - } - - #[test] - fn future_session_is_not_expired() { - let session = make_session("tok1", "sbx1", now_ms() + 3_600_000, false); - let is_expired = session.expires_at_ms > 0 && now_ms() > session.expires_at_ms; - assert!(!is_expired, "session in the future should not be expired"); - } - - #[test] - fn zero_expiry_is_not_expired() { - let session = make_session("tok1", "sbx1", 0, false); - let is_expired = session.expires_at_ms > 0 && now_ms() > session.expires_at_ms; - assert!( - !is_expired, - "session with zero expiry should never be expired" - ); - } -} diff --git a/crates/openshell-server/src/supervisor_session.rs b/crates/openshell-server/src/supervisor_session.rs index 94c352ba5..4e943bcac 100644 --- a/crates/openshell-server/src/supervisor_session.rs +++ b/crates/openshell-server/src/supervisor_session.rs @@ -13,8 +13,8 @@ use tracing::{info, warn}; use uuid::Uuid; use openshell_core::proto::{ - GatewayMessage, RelayFrame, RelayInit, RelayOpen, Sandbox, SessionAccepted, SupervisorMessage, - gateway_message, supervisor_message, + GatewayMessage, RelayFrame, RelayInit, RelayOpen, Sandbox, SessionAccepted, SshRelayTarget, + SupervisorMessage, gateway_message, relay_open, supervisor_message, }; use crate::ServerState; @@ -58,8 +58,9 @@ struct LiveSession { connected_at: Instant, } -/// Holds a oneshot sender that will deliver the upgraded relay stream. -type RelayStreamSender = oneshot::Sender; +/// Holds a oneshot sender that will deliver the upgraded relay stream or a +/// target-open failure reported by the supervisor. +type RelayStreamSender = oneshot::Sender>; impl openshell_driver_docker::SupervisorReadiness for SupervisorSessionRegistry { fn is_supervisor_connected(&self, sandbox_id: &str) -> bool { @@ -79,6 +80,7 @@ pub struct SupervisorSessionRegistry { struct PendingRelay { sender: RelayStreamSender, sandbox_id: String, + relay_open: RelayOpen, created_at: Instant, } @@ -234,12 +236,45 @@ impl SupervisorSessionRegistry { &self, sandbox_id: &str, session_wait_timeout: Duration, - ) -> Result<(String, oneshot::Receiver), Status> { + ) -> Result< + ( + String, + oneshot::Receiver>, + ), + Status, + > { + self.open_relay_with_target( + sandbox_id, + relay_open::Target::Ssh(SshRelayTarget {}), + "".to_string(), + session_wait_timeout, + ) + .await + } + + pub async fn open_relay_with_target( + &self, + sandbox_id: &str, + target: relay_open::Target, + service_id: String, + session_wait_timeout: Duration, + ) -> Result< + ( + String, + oneshot::Receiver>, + ), + Status, + > { let tx = self .wait_for_session(sandbox_id, session_wait_timeout) .await?; let channel_id = Uuid::new_v4().to_string(); + let relay_open = RelayOpen { + channel_id: channel_id.clone(), + target: Some(target), + service_id, + }; // Register the pending relay before sending RelayOpen to avoid a race. // Both caps are checked and the insert happens under a single lock hold @@ -267,15 +302,14 @@ impl SupervisorSessionRegistry { PendingRelay { sender: relay_tx, sandbox_id: sandbox_id.to_string(), + relay_open: relay_open.clone(), created_at: Instant::now(), }, ); } let msg = GatewayMessage { - payload: Some(gateway_message::Payload::RelayOpen(RelayOpen { - channel_id: channel_id.clone(), - })), + payload: Some(gateway_message::Payload::RelayOpen(relay_open)), }; if tx.send(msg).await.is_err() { @@ -287,7 +321,17 @@ impl SupervisorSessionRegistry { Ok((channel_id, relay_rx)) } - /// Claim a pending relay channel. Called by the `/relay/{channel_id}` HTTP handler + pub fn fail_pending_relay(&self, channel_id: &str, error: String) -> bool { + let pending = self.pending_relays.lock().unwrap().remove(channel_id); + if let Some(pending) = pending { + let _ = pending.sender.send(Err(Status::unavailable(error))); + true + } else { + false + } + } + + /// Claim a pending relay channel. Called by the /relay/{channel_id} HTTP handler /// when the supervisor's reverse CONNECT arrives. /// /// Returns the `DuplexStream` half that the supervisor side should read/write. @@ -308,8 +352,8 @@ impl SupervisorSessionRegistry { // the supervisor HTTP CONNECT handler. let (gateway_stream, supervisor_stream) = tokio::io::duplex(64 * 1024); - // Send the gateway-side stream to the waiter (ssh_tunnel or exec handler). - if pending.sender.send(gateway_stream).is_err() { + // Send the gateway-side stream to the waiter (exec handler or forward handler). + if pending.sender.send(Ok(gateway_stream)).is_err() { return Err(Status::internal("relay requester dropped")); } @@ -329,10 +373,17 @@ impl SupervisorSessionRegistry { pub async fn replay_pending_relays(&self, sandbox_id: &str, tx: &mpsc::Sender) { for channel_id in self.pending_channel_ids(sandbox_id) { + let relay_open = { + let pending = self.pending_relays.lock().unwrap(); + pending + .get(&channel_id) + .map(|pending| pending.relay_open.clone()) + }; + let Some(relay_open) = relay_open else { + continue; + }; let msg = GatewayMessage { - payload: Some(gateway_message::Payload::RelayOpen(RelayOpen { - channel_id: channel_id.clone(), - })), + payload: Some(gateway_message::Payload::RelayOpen(relay_open)), }; if tx.send(msg).await.is_err() { warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "supervisor session: failed to replay pending relay to superseding session"); @@ -626,7 +677,7 @@ pub async fn handle_connect_supervisor( } async fn run_session_loop( - _state: &Arc, + state: &Arc, sandbox_id: &str, session_id: &str, tx: &mpsc::Sender, @@ -647,7 +698,7 @@ async fn run_session_loop( msg = inbound.message() => { match msg { Ok(Some(msg)) => { - handle_supervisor_message(sandbox_id, session_id, msg); + handle_supervisor_message(state, sandbox_id, session_id, msg); } Ok(None) => { info!(sandbox_id = %sandbox_id, session_id = %session_id, "supervisor session: stream closed by supervisor"); @@ -674,7 +725,12 @@ async fn run_session_loop( } } -fn handle_supervisor_message(sandbox_id: &str, session_id: &str, msg: SupervisorMessage) { +fn handle_supervisor_message( + state: &Arc, + sandbox_id: &str, + session_id: &str, + msg: SupervisorMessage, +) { match msg.payload { Some(supervisor_message::Payload::Heartbeat(_)) => { // Heartbeat received — nothing to do for now. @@ -688,11 +744,15 @@ fn handle_supervisor_message(sandbox_id: &str, session_id: &str, msg: Supervisor "supervisor session: relay opened successfully" ); } else { + let failed = state + .supervisor_sessions + .fail_pending_relay(&result.channel_id, result.error.clone()); warn!( sandbox_id = %sandbox_id, session_id = %session_id, channel_id = %result.channel_id, error = %result.error, + pending_relay_failed = failed, "supervisor session: relay open failed" ); } @@ -745,6 +805,23 @@ mod tests { } } + fn pending_relay( + sandbox_id: &str, + relay_tx: RelayStreamSender, + created_at: Instant, + ) -> PendingRelay { + PendingRelay { + sender: relay_tx, + sandbox_id: sandbox_id.to_string(), + relay_open: RelayOpen { + channel_id: "ch-test".to_string(), + target: Some(relay_open::Target::Ssh(SshRelayTarget {})), + service_id: String::new(), + }, + created_at, + } + } + // ---- registry: register / remove ---- #[test] @@ -863,6 +940,7 @@ mod tests { match msg.payload { Some(gateway_message::Payload::RelayOpen(open)) => { assert_eq!(open.channel_id, channel_id); + assert!(matches!(open.target, Some(relay_open::Target::Ssh(_)))); } other => panic!("expected RelayOpen, got {other:?}"), } @@ -944,11 +1022,7 @@ mod tests { let sandbox_id = if i % 2 == 0 { "sbx-a" } else { "sbx-b" }; pending.insert( format!("channel-{i}"), - PendingRelay { - sender: oneshot_tx, - sandbox_id: sandbox_id.to_string(), - created_at: Instant::now(), - }, + pending_relay(sandbox_id, oneshot_tx, Instant::now()), ); } } @@ -973,11 +1047,7 @@ mod tests { let (oneshot_tx, _) = oneshot::channel(); pending.insert( format!("channel-{i}"), - PendingRelay { - sender: oneshot_tx, - sandbox_id: "sbx".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx", oneshot_tx, Instant::now()), ); } } @@ -1174,11 +1244,7 @@ mod tests { let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-1".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); let result = registry.claim_relay("ch-1"); @@ -1186,19 +1252,41 @@ mod tests { assert!(!registry.pending_relays.lock().unwrap().contains_key("ch-1")); } + #[tokio::test] + async fn relay_open_failure_completes_pending_waiter() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-fail".to_string(), + pending_relay("sbx-test", relay_tx, Instant::now()), + ); + + assert!(registry.fail_pending_relay("ch-fail", "target refused".to_string())); + assert!( + !registry + .pending_relays + .lock() + .unwrap() + .contains_key("ch-fail") + ); + + let result = relay_rx.await.expect("failure should wake waiter"); + let status = result.expect_err("waiter should receive status failure"); + assert_eq!(status.code(), tonic::Code::Unavailable); + assert_eq!(status.message(), "target refused"); + } + #[test] fn claim_relay_expired_returns_deadline_exceeded() { let registry = SupervisorSessionRegistry::new(); let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-old".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now() - .checked_sub(Duration::from_secs(60)) - .expect("test instant subtraction underflow"), - }, + pending_relay( + "sbx-test", + relay_tx, + Instant::now() - Duration::from_secs(60), + ), ); let err = registry @@ -1218,15 +1306,11 @@ mod tests { #[test] fn claim_relay_receiver_dropped_returns_internal() { let registry = SupervisorSessionRegistry::new(); - let (relay_tx, relay_rx) = oneshot::channel::(); + let (relay_tx, relay_rx) = oneshot::channel::>(); drop(relay_rx); // Gateway-side waiter has given up already. registry.pending_relays.lock().unwrap().insert( "ch-1".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); let err = registry @@ -1238,18 +1322,17 @@ mod tests { #[tokio::test] async fn claim_relay_connects_both_ends() { let registry = SupervisorSessionRegistry::new(); - let (relay_tx, relay_rx) = oneshot::channel::(); + let (relay_tx, relay_rx) = oneshot::channel::>(); registry.pending_relays.lock().unwrap().insert( "ch-io".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); let mut supervisor_side = registry.claim_relay("ch-io").expect("claim should succeed"); - let mut gateway_side = relay_rx.await.expect("gateway side should receive stream"); + let mut gateway_side = relay_rx + .await + .expect("gateway side should receive result") + .expect("gateway side should receive stream"); // Supervisor side writes → gateway side reads. supervisor_side.write_all(b"hello").await.unwrap(); @@ -1272,13 +1355,11 @@ mod tests { let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-old".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now() - .checked_sub(Duration::from_secs(60)) - .expect("test instant subtraction underflow"), - }, + pending_relay( + "sbx-test", + relay_tx, + Instant::now() - Duration::from_secs(60), + ), ); registry.reap_expired_relays(); @@ -1297,11 +1378,7 @@ mod tests { let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( "ch-fresh".to_string(), - PendingRelay { - sender: relay_tx, - sandbox_id: "sbx-test".to_string(), - created_at: Instant::now(), - }, + pending_relay("sbx-test", relay_tx, Instant::now()), ); registry.reap_expired_relays(); diff --git a/crates/openshell-server/tests/auth_endpoint_integration.rs b/crates/openshell-server/tests/auth_endpoint_integration.rs index c66f2ad6b..f160f98b8 100644 --- a/crates/openshell-server/tests/auth_endpoint_integration.rs +++ b/crates/openshell-server/tests/auth_endpoint_integration.rs @@ -754,6 +754,21 @@ impl openshell_core::proto::open_shell_server::OpenShell for TestOpenShell { ) -> Result, tonic::Status> { Err(tonic::Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = std::pin::Pin< + Box< + dyn tokio_stream::Stream< + Item = Result, + > + Send, + >, + >; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented in test")) + } } /// Test 7: Plaintext server (no TLS) accepts both gRPC and HTTP. diff --git a/crates/openshell-server/tests/edge_tunnel_auth.rs b/crates/openshell-server/tests/edge_tunnel_auth.rs index 706967d1f..689cfcf59 100644 --- a/crates/openshell-server/tests/edge_tunnel_auth.rs +++ b/crates/openshell-server/tests/edge_tunnel_auth.rs @@ -42,9 +42,9 @@ use openshell_core::proto::{ GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, + ProviderResponse, RelayFrame, RevokeSshSessionRequest, RevokeSshSessionResponse, + SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, TcpForwardFrame, + UpdateProviderRequest, WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -379,14 +379,24 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } - type RelayStreamStream = ReceiverStream>; + type RelayStreamStream = ReceiverStream>; async fn relay_stream( &self, - _request: tonic::Request>, + _request: tonic::Request>, ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = + std::pin::Pin> + Send>>; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // --------------------------------------------------------------------------- diff --git a/crates/openshell-server/tests/multiplex_integration.rs b/crates/openshell-server/tests/multiplex_integration.rs index d5631319d..9cab950db 100644 --- a/crates/openshell-server/tests/multiplex_integration.rs +++ b/crates/openshell-server/tests/multiplex_integration.rs @@ -16,9 +16,9 @@ use openshell_core::proto::{ GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, + ProviderResponse, RelayFrame, RevokeSshSessionRequest, RevokeSshSessionResponse, + SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, TcpForwardFrame, + UpdateProviderRequest, WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -347,14 +347,24 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } - type RelayStreamStream = ReceiverStream>; + type RelayStreamStream = ReceiverStream>; async fn relay_stream( &self, - _request: tonic::Request>, + _request: tonic::Request>, ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = + std::pin::Pin> + Send>>; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } #[tokio::test] diff --git a/crates/openshell-server/tests/multiplex_tls_integration.rs b/crates/openshell-server/tests/multiplex_tls_integration.rs index c4f68eaf4..21b75c12c 100644 --- a/crates/openshell-server/tests/multiplex_tls_integration.rs +++ b/crates/openshell-server/tests/multiplex_tls_integration.rs @@ -18,9 +18,9 @@ use openshell_core::proto::{ GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, + ProviderResponse, RelayFrame, RevokeSshSessionRequest, RevokeSshSessionResponse, + SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, TcpForwardFrame, + UpdateProviderRequest, WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -360,14 +360,24 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } - type RelayStreamStream = ReceiverStream>; + type RelayStreamStream = ReceiverStream>; async fn relay_stream( &self, - _request: tonic::Request>, + _request: tonic::Request>, ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = + std::pin::Pin> + Send>>; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } /// PKI bundle: CA cert, server cert+key, client cert+key. diff --git a/crates/openshell-server/tests/supervisor_relay_integration.rs b/crates/openshell-server/tests/supervisor_relay_integration.rs index 8f5cac03a..2d722b051 100644 --- a/crates/openshell-server/tests/supervisor_relay_integration.rs +++ b/crates/openshell-server/tests/supervisor_relay_integration.rs @@ -23,7 +23,7 @@ use hyper_util::{ server::conn::auto::Builder, }; use openshell_core::proto::{ - GatewayMessage, RelayFrame, RelayInit, SupervisorMessage, + GatewayMessage, RelayFrame, RelayInit, SupervisorMessage, TcpForwardFrame, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -87,6 +87,15 @@ impl OpenShell for RelayGateway { Err(Status::unimplemented("unused")) } + type ForwardTcpStream = + std::pin::Pin> + Send>>; + async fn forward_tcp( + &self, + _: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn health( &self, _: tonic::Request, @@ -439,7 +448,7 @@ async fn relay_round_trips_bytes() { tokio::spawn(run_echo_supervisor(channel, channel_id)); - let relay = relay_rx.await.expect("relay duplex"); + let relay = relay_rx.await.expect("relay result").expect("relay duplex"); let (mut read_half, mut write_half) = tokio::io::split(relay); write_half.write_all(b"hello relay").await.expect("write"); @@ -464,7 +473,7 @@ async fn relay_closes_cleanly_when_gateway_drops() { let supervisor = tokio::spawn(run_echo_supervisor(channel, channel_id)); - let relay = relay_rx.await.expect("relay duplex"); + let relay = relay_rx.await.expect("relay result").expect("relay duplex"); drop(relay); // The supervisor's inbound stream should terminate shortly after the @@ -509,7 +518,7 @@ async fn relay_sees_eof_when_supervisor_closes() { }) }; - let relay = relay_rx.await.expect("relay duplex"); + let relay = relay_rx.await.expect("relay result").expect("relay duplex"); let (mut read_half, _write_half) = tokio::io::split(relay); let mut buf = [0u8; 16]; let n = tokio::time::timeout(Duration::from_secs(5), read_half.read(&mut buf)) @@ -555,8 +564,8 @@ async fn concurrent_relays_multiplex_independently() { tokio::spawn(run_echo_supervisor(channel.clone(), id_a)); tokio::spawn(run_echo_supervisor(channel, id_b)); - let relay_a = rx_a.await.expect("relay a"); - let relay_b = rx_b.await.expect("relay b"); + let relay_a = rx_a.await.expect("relay a result").expect("relay a"); + let relay_b = rx_b.await.expect("relay b result").expect("relay b"); let (mut ra, mut wa) = tokio::io::split(relay_a); let (mut rb, mut wb) = tokio::io::split(relay_b); diff --git a/crates/openshell-server/tests/ws_tunnel_integration.rs b/crates/openshell-server/tests/ws_tunnel_integration.rs index 8212b1085..14d5e9bb7 100644 --- a/crates/openshell-server/tests/ws_tunnel_integration.rs +++ b/crates/openshell-server/tests/ws_tunnel_integration.rs @@ -45,9 +45,9 @@ use openshell_core::proto::{ GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, + ProviderResponse, RelayFrame, RevokeSshSessionRequest, RevokeSshSessionResponse, + SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, TcpForwardFrame, + UpdateProviderRequest, WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -373,14 +373,24 @@ impl OpenShell for TestOpenShell { Err(Status::unimplemented("not implemented in test")) } - type RelayStreamStream = ReceiverStream>; + type RelayStreamStream = ReceiverStream>; async fn relay_stream( &self, - _request: tonic::Request>, + _request: tonic::Request>, ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + type ForwardTcpStream = + std::pin::Pin> + Send>>; + + async fn forward_tcp( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // --------------------------------------------------------------------------- diff --git a/crates/openshell-tui/src/lib.rs b/crates/openshell-tui/src/lib.rs index 8571ebbe1..b96c0abbf 100644 --- a/crates/openshell-tui/src/lib.rs +++ b/crates/openshell-tui/src/lib.rs @@ -839,10 +839,7 @@ async fn handle_shell_connect( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, &app.endpoint); - let gateway_url = format!( - "{}://{}:{gateway_port}{}", - session.gateway_scheme, gateway_host, session.connect_path - ); + let gateway_url = format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port); // Step 4: Build the ProxyCommand using our own binary. let exe = match std::env::current_exe() { @@ -988,10 +985,7 @@ async fn handle_exec_command( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, &app.endpoint); - let gateway_url = format!( - "{}://{}:{gateway_port}{}", - session.gateway_scheme, gateway_host, session.connect_path - ); + let gateway_url = format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port); let exe = match std::env::current_exe() { Ok(p) => p, @@ -1080,7 +1074,8 @@ async fn handle_exec_command( // SSH utility functions are shared via openshell_core::forward. use openshell_core::forward::{ - build_proxy_command, resolve_ssh_gateway, shell_escape, validate_ssh_session_response, + build_proxy_command, format_gateway_url, resolve_ssh_gateway, shell_escape, + validate_ssh_session_response, }; /// Convert a `SandboxPolicy` proto into styled ratatui lines for the policy viewer. @@ -1424,10 +1419,7 @@ async fn start_port_forwards( let gateway_port_u16 = session.gateway_port as u16; let (gateway_host, gateway_port) = resolve_ssh_gateway(&session.gateway_host, gateway_port_u16, endpoint); - let gateway_url = format!( - "{}://{}:{gateway_port}{}", - session.gateway_scheme, gateway_host, session.connect_path - ); + let gateway_url = format_gateway_url(&session.gateway_scheme, &gateway_host, gateway_port); // Build ProxyCommand. let exe = match std::env::current_exe() { diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index b43a6846a..6763217d6 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -200,6 +200,14 @@ openshell forward list openshell forward stop 8000 my-sandbox ``` +Use the gRPC service-forwarding path when you want to test the OS-88 service relay path without SSH port forwarding: + +```shell +openshell service forward my-sandbox --target-port 8000 --local 8000 +``` + +This binds a local listener and opens one authenticated gRPC stream to the gateway for each accepted local TCP connection. The target must be a loopback TCP service inside the sandbox. Use `--local 127.0.0.1:0` to let OpenShell choose a free local port. + You can also forward a port at creation time with `--forward`: diff --git a/proto/openshell.proto b/proto/openshell.proto index b0291254a..883c1576c 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -54,6 +54,9 @@ service OpenShell { // Execute a command in a ready sandbox and stream output. rpc ExecSandbox(ExecSandboxRequest) returns (stream ExecSandboxEvent); + // Forward one CLI-side TCP connection to a loopback TCP target in a sandbox. + rpc ForwardTcp(stream TcpForwardFrame) returns (stream TcpForwardFrame); + // Create a provider. rpc CreateProvider(CreateProviderRequest) returns (ProviderResponse); @@ -127,8 +130,9 @@ service OpenShell { // // The supervisor opens this stream at startup and keeps it alive for the // sandbox lifetime. The gateway uses it to coordinate relay channels for - // SSH connect and ExecSandbox. Raw SSH bytes flow over RelayStream calls - // (separate HTTP/2 streams on the same connection), not over this stream. + // SSH connect, ExecSandbox, and targetable sandbox services. Raw service + // bytes flow over RelayStream calls (separate HTTP/2 streams on the same + // connection), not over this stream. rpc ConnectSupervisor(stream SupervisorMessage) returns (stream GatewayMessage); // Raw byte relay between supervisor and gateway. @@ -137,8 +141,8 @@ service OpenShell { // on its ConnectSupervisor stream. The first RelayFrame carries a // RelayInit with the channel_id to associate the new HTTP/2 stream with // the pending relay slot on the gateway. Subsequent frames carry raw bytes in either - // direction between the gateway-side waiter (ssh_tunnel / exec handler) - // and the supervisor-side local SSH daemon bridge. + // direction between the gateway-side waiter (ForwardTcp / exec handler) + // and the supervisor-side target bridge. // // This rides the same TCP+TLS+HTTP/2 connection as ConnectSupervisor — // no new TLS handshake, no reverse HTTP CONNECT. @@ -446,11 +450,6 @@ message CreateSshSessionResponse { // Gateway scheme. Must be exactly "http" or "https". string gateway_scheme = 5; - // HTTP path for the CONNECT/upgrade endpoint. Must begin with `/`. RFC - // 3986 path charset only ([A-Za-z0-9._~!$&'()*+,;=:@/-] plus %HH). - // Must not contain `?`, `#`, whitespace, backtick, or backslash. - string connect_path = 6; - // Optional host key fingerprint. If non-empty, [A-Za-z0-9:+/=-] only. string host_key_fingerprint = 7; @@ -518,6 +517,30 @@ message ExecSandboxEvent { } } +// Initial frame for one TCP forward stream. +message TcpForwardInit { + // Sandbox id. + string sandbox_id = 1; + // Optional service identifier for audit/correlation. + string service_id = 4; + // Target the gateway should request from the supervisor. + oneof target { + SshRelayTarget ssh = 5; + TcpRelayTarget tcp = 6; + } + // Optional target-specific authorization token. SSH targets use this as the + // short-lived SSH session token issued by CreateSshSession. + string authorization_token = 7; +} + +// A single frame on the CLI-to-gateway TCP forward stream. +message TcpForwardFrame { + oneof payload { + TcpForwardInit init = 1; + bytes data = 2; + } +} + // SSH session record stored in persistence. message SshSession { // Kubernetes-style metadata (id, name, labels, timestamps, resource version). @@ -1030,10 +1053,29 @@ message GatewayHeartbeat {} // On receiving this, the supervisor should initiate a RelayStream RPC to // the gateway, sending a RelayInit in the first RelayFrame to associate // the new HTTP/2 stream with the pending relay slot. The supervisor -// bridges that stream to the local SSH daemon. +// bridges that stream to the requested local target. message RelayOpen { // Gateway-allocated channel identifier (UUID). string channel_id = 1; + // Target the supervisor should dial inside the sandbox. + // If absent, supervisors treat the relay as SSH for compatibility. + oneof target { + SshRelayTarget ssh = 2; + TcpRelayTarget tcp = 3; + } + // Optional service identifier for audit/correlation. + string service_id = 5; +} + +// Built-in SSH relay target. +message SshRelayTarget {} + +// TCP target dialed by the supervisor from inside the sandbox. +message TcpRelayTarget { + // Phase 1 accepts loopback only: 127.0.0.1, ::1, or localhost. + string host = 1; + // Target port. Must fit in u16 and be non-zero. + uint32 port = 2; } // Initial RelayStream frame sent by the supervisor to claim a pending relay. From 9af8de3547b64c66adbc2add93665a8448c1a16e Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Wed, 29 Apr 2026 20:32:24 -0700 Subject: [PATCH 2/4] fix(cli): use forward command for service forwarding --- .../skills/debug-openshell-cluster/SKILL.md | 4 +- crates/openshell-cli/src/main.rs | 75 ++++++------------- crates/openshell-cli/src/run.rs | 13 ++-- crates/openshell-server/src/grpc/sandbox.rs | 4 +- docs/sandboxes/manage-sandboxes.mdx | 2 +- 5 files changed, 35 insertions(+), 63 deletions(-) diff --git a/.agents/skills/debug-openshell-cluster/SKILL.md b/.agents/skills/debug-openshell-cluster/SKILL.md index 64f8bd83d..b7b2c898c 100644 --- a/.agents/skills/debug-openshell-cluster/SKILL.md +++ b/.agents/skills/debug-openshell-cluster/SKILL.md @@ -128,9 +128,9 @@ helm -n openshell get values openshell | grep -E 'repository|tag|supervisorImage The gateway image and `server.supervisorImage` should use the same build tag in branch and E2E deploys. A stale supervisor image can make sandbox behavior lag behind gateway policy or proto changes. -For local/external pull mode (the default local path via `mise run cluster`), local images are tagged to the configured local registry base, pushed to that registry, and pulled by k3s via the `registries.yaml` mirror endpoint. The `cluster` task rebuilds the local gateway image before tagging and pushing it, so a fresh bootstrap should not reuse stale `openshell/gateway:dev` bits from a previous source revision. +For local/external pull mode (the default local path via `mise run cluster`), local images are tagged to the configured local registry base, pushed to that registry, and pulled by k3s via the `registries.yaml` mirror endpoint. The `cluster` task pushes prebuilt local tags (`openshell/*:dev`, falling back to `localhost:5000/openshell/*:dev` or `127.0.0.1:5000/openshell/*:dev`). -Gateway image builds stage a partial Rust workspace from `deploy/docker/Dockerfile.images`. If cargo fails with a missing manifest under `/build/crates/...`, or an imported symbol exists locally but is missing in the image build, verify that every current gateway dependency crate is copied into the staged workspace there. +Gateway image builds stage a partial Rust workspace from `deploy/docker/Dockerfile.images`. If cargo fails with a missing manifest under `/build/crates/...`, or an imported symbol exists locally but is missing in the image build, verify that every current gateway dependency crate, including `openshell-driver-docker`, `openshell-driver-kubernetes`, and `openshell-ocsf`, is copied into the staged workspace there. For plaintext local evaluation, confirm the chart has: diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 4dc5c588b..5fa2dec02 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -199,7 +199,6 @@ const HELP_TEMPLATE: &str = "\ \x1b[1mSANDBOX COMMANDS\x1b[0m sandbox: Manage sandboxes forward: Manage port forwarding to a sandbox - service: Forward sandbox services over gRPC logs: View sandbox logs policy: Manage sandbox policy settings: Manage sandbox and global settings @@ -268,16 +267,11 @@ const FORWARD_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m \x1b[1mEXAMPLES\x1b[0m $ openshell forward start 8080 $ openshell forward start 3000 my-sandbox + $ openshell forward service my-sandbox --target-port 8000 --local 8000 $ openshell forward stop 8080 $ openshell forward list "; -const SERVICE_EXAMPLES: &str = "\x1b[1mEXAMPLES\x1b[0m - $ openshell service forward my-sandbox --target-port 8080 - $ openshell service forward my-sandbox --target-port 5432 --local 15432 - $ openshell service forward my-sandbox --target-port 3000 --local 127.0.0.1:0 -"; - const LOGS_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m lg @@ -415,13 +409,6 @@ enum Commands { command: Option, }, - /// Forward sandbox services over gRPC. - #[command(after_help = SERVICE_EXAMPLES, help_template = SUBCOMMAND_HELP_TEMPLATE)] - Service { - #[command(subcommand)] - command: Option, - }, - /// View sandbox logs. #[command(alias = "lg", after_help = LOGS_EXAMPLES, help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] Logs { @@ -1627,13 +1614,10 @@ enum ForwardCommands { /// List active port forwards. #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] List, -} -#[derive(Subcommand, Debug)] -enum ServiceCommands { /// Forward a local TCP port to a loopback service inside a sandbox over gRPC. #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] - Forward { + Service { /// Sandbox name (defaults to last-used sandbox). #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] name: Option, @@ -1646,7 +1630,7 @@ enum ServiceCommands { #[arg(long, default_value = "127.0.0.1")] target_host: String, - /// Local bind address and port: [bind_address:]port. Use port 0 for dynamic assignment. + /// Local bind address and port: [bind_address:]port. Defaults to the target port. Use port 0 for dynamic assignment. #[arg(long)] local: Option, }, @@ -1816,38 +1800,6 @@ async fn main() -> Result<()> { } } - Some(Commands::Service { - command: - Some(ServiceCommands::Forward { - name, - target_port, - target_host, - local, - }), - }) => { - let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?; - let mut tls = tls.with_gateway_name(&ctx.name); - apply_edge_auth(&mut tls, &ctx.name); - let name = resolve_sandbox_name(name, &ctx.name)?; - run::service_forward_tcp( - &ctx.endpoint, - &name, - local.as_deref(), - &target_host, - target_port, - &tls, - ) - .await?; - } - - Some(Commands::Service { command: None }) => { - Cli::command() - .find_subcommand_mut("service") - .expect("service subcommand exists") - .print_help() - .expect("Failed to print help"); - } - // ----------------------------------------------------------- // Top-level forward (was `sandbox forward`) // ----------------------------------------------------------- @@ -1924,6 +1876,27 @@ async fn main() -> Result<()> { } } } + ForwardCommands::Service { + name, + target_port, + target_host, + local, + } => { + let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?; + let mut tls = tls.with_gateway_name(&ctx.name); + apply_auth(&mut tls, &ctx.name); + let name = resolve_sandbox_name(name, &ctx.name)?; + let local = local.unwrap_or_else(|| target_port.to_string()); + run::service_forward_tcp( + &ctx.endpoint, + &name, + Some(&local), + &target_host, + target_port, + &tls, + ) + .await?; + } ForwardCommands::Start { port, name, diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index e0888a9cf..d4c8aab9e 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -26,13 +26,12 @@ use openshell_bootstrap::{ use openshell_core::proto::ProviderProfileCategory; use openshell_core::proto::{ ApproveAllDraftChunksRequest, ApproveDraftChunkRequest, AttachSandboxProviderRequest, - ClearDraftChunksRequest, CreateProviderRequest, CreateSandboxRequest, - CreateSshSessionRequest, DeleteProviderProfileRequest, DeleteProviderRequest, - DeleteSandboxRequest, DetachSandboxProviderRequest, ExecSandboxRequest, - GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, - GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRequest, - GetSandboxConfigRequest, GetSandboxLogsRequest, GetSandboxPolicyStatusRequest, - GetSandboxRequest, HealthRequest, ImportProviderProfilesRequest, + ClearDraftChunksRequest, CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, + DeleteProviderProfileRequest, DeleteProviderRequest, DeleteSandboxRequest, + DetachSandboxProviderRequest, ExecSandboxRequest, GetClusterInferenceRequest, + GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, + GetProviderProfileRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, + GetSandboxPolicyStatusRequest, GetSandboxRequest, HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, ListSandboxesRequest, PolicySource, PolicyStatus, Provider, ProviderProfile, ProviderProfileDiagnostic, ProviderProfileImportItem, diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 7786a0b13..9417b2971 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -20,8 +20,8 @@ use openshell_core::proto::{ ExecSandboxRequest, ExecSandboxStderr, ExecSandboxStdout, GetSandboxRequest, ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, Provider, RevokeSshSessionRequest, RevokeSshSessionResponse, - SandboxStreamEvent, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, WatchSandboxRequest, - relay_open, tcp_forward_init, + SandboxResponse, SandboxStreamEvent, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, + WatchSandboxRequest, relay_open, tcp_forward_init, }; use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; use prost::Message; diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 6763217d6..66d9ef690 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -203,7 +203,7 @@ openshell forward stop 8000 my-sandbox Use the gRPC service-forwarding path when you want to test the OS-88 service relay path without SSH port forwarding: ```shell -openshell service forward my-sandbox --target-port 8000 --local 8000 +openshell forward service my-sandbox --target-port 8000 --local 8000 ``` This binds a local listener and opens one authenticated gRPC stream to the gateway for each accepted local TCP connection. The target must be a loopback TCP service inside the sandbox. Use `--local 127.0.0.1:0` to let OpenShell choose a free local port. From 95422ee9637b8773b466e574a38893d2c5336b05 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Mon, 11 May 2026 20:34:42 -0700 Subject: [PATCH 3/4] fix(auth): authorize ForwardTcp scoped access --- crates/openshell-cli/src/main.rs | 2 +- crates/openshell-cli/src/run.rs | 11 ++++----- crates/openshell-cli/src/ssh.rs | 10 +++----- .../tests/ensure_providers_integration.rs | 6 ++--- .../openshell-cli/tests/mtls_integration.rs | 6 ++--- .../tests/provider_commands_integration.rs | 6 ++--- .../sandbox_create_lifecycle_integration.rs | 6 ++--- .../sandbox_name_fallback_integration.rs | 6 ++--- .../openshell-driver-kubernetes/src/driver.rs | 14 ++++++++--- .../src/supervisor_session.rs | 13 ++++++----- crates/openshell-server/src/auth/authz.rs | 6 +++++ crates/openshell-server/src/grpc/sandbox.rs | 23 +++++++++---------- crates/openshell-server/src/ssh_sessions.rs | 20 +++++++++------- .../src/supervisor_session.rs | 12 ++++++---- 14 files changed, 78 insertions(+), 63 deletions(-) diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 5fa2dec02..e370d1f27 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1630,7 +1630,7 @@ enum ForwardCommands { #[arg(long, default_value = "127.0.0.1")] target_host: String, - /// Local bind address and port: [bind_address:]port. Defaults to the target port. Use port 0 for dynamic assignment. + /// Local bind address and port: `[bind_address:]port`. Defaults to the target port. Use port 0 for dynamic assignment. #[arg(long)] local: Option, }, diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index d4c8aab9e..2797bd66c 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -2718,13 +2718,10 @@ async fn drain_and_shutdown_local_socket(mut socket: tokio::net::TcpStream) { use tokio::io::{AsyncReadExt, AsyncWriteExt}; let mut buf = [0u8; 4096]; - loop { - match tokio::time::timeout(Duration::from_millis(25), socket.read(&mut buf)).await { - Ok(Ok(0)) | Err(_) => break, - Ok(Ok(_)) => continue, - Ok(Err(_)) => break, - } - } + while matches!( + tokio::time::timeout(Duration::from_millis(25), socket.read(&mut buf)).await, + Ok(Ok(n)) if n != 0 + ) {} let _ = socket.shutdown().await; } diff --git a/crates/openshell-cli/src/ssh.rs b/crates/openshell-cli/src/ssh.rs index 65e8605f1..e99e6ee15 100644 --- a/crates/openshell-cli/src/ssh.rs +++ b/crates/openshell-cli/src/ssh.rs @@ -846,10 +846,7 @@ pub async fn sandbox_ssh_proxy( let to_remote = tokio::spawn(async move { let mut stdin = stdin; let mut buf = vec![0u8; 64 * 1024]; - loop { - let Ok(n) = stdin.read(&mut buf).await else { - break; - }; + while let Ok(n) = stdin.read(&mut buf).await { if n == 0 { break; } @@ -869,9 +866,8 @@ pub async fn sandbox_ssh_proxy( let from_remote = tokio::spawn(async move { let mut stdout = stdout; loop { - let frame = match response.message().await { - Ok(Some(frame)) => frame, - Ok(None) | Err(_) => break, + let Ok(Some(frame)) = response.message().await else { + break; }; let Some(openshell_core::proto::tcp_forward_frame::Payload::Data(data)) = frame.payload else { diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index 5d8d8f1b3..f1a11e661 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -518,14 +518,14 @@ impl OpenShell for TestOpenShell { } type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< - Result, + Result, >; async fn forward_tcp( &self, _request: tonic::Request>, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented in test")) + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) } } diff --git a/crates/openshell-cli/tests/mtls_integration.rs b/crates/openshell-cli/tests/mtls_integration.rs index a728643a8..c95e2cf98 100644 --- a/crates/openshell-cli/tests/mtls_integration.rs +++ b/crates/openshell-cli/tests/mtls_integration.rs @@ -409,14 +409,14 @@ impl OpenShell for TestOpenShell { } type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< - Result, + Result, >; async fn forward_tcp( &self, _request: tonic::Request>, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented in test")) + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) } } diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 16c0b97b1..fbe824cbf 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -627,14 +627,14 @@ impl OpenShell for TestOpenShell { } type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< - Result, + Result, >; async fn forward_tcp( &self, _request: tonic::Request>, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented in test")) + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) } } diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index da18e79d1..a2fedab82 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -492,14 +492,14 @@ impl OpenShell for TestOpenShell { } type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< - Result, + Result, >; async fn forward_tcp( &self, _request: tonic::Request>, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented in test")) + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) } } diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index 629421f59..94d5b3cfa 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -430,14 +430,14 @@ impl OpenShell for TestOpenShell { } type ForwardTcpStream = tokio_stream::wrappers::ReceiverStream< - Result, + Result, >; async fn forward_tcp( &self, _request: tonic::Request>, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented in test")) + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) } } diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index a6107f907..cde1f4b22 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -2308,10 +2308,18 @@ mod tests { #[test] fn log_level_propagates_as_env_var_to_sandbox_pod() { - let spec = SandboxSpec { log_level: "debug".to_string(), ..SandboxSpec::default() }; + let spec = SandboxSpec { + log_level: "debug".to_string(), + ..SandboxSpec::default() + }; let cr = sandbox_to_k8s_spec(Some(&spec), &SandboxPodParams::default()); - let env = cr["spec"]["podTemplate"]["spec"]["containers"][0]["env"].as_array().unwrap(); - assert!(env.iter().any(|e| e["name"] == "OPENSHELL_LOG_LEVEL" && e["value"] == "debug")); + let env = cr["spec"]["podTemplate"]["spec"]["containers"][0]["env"] + .as_array() + .unwrap(); + assert!( + env.iter() + .any(|e| e["name"] == "OPENSHELL_LOG_LEVEL" && e["value"] == "debug") + ); assert!(cr["spec"].get("logLevel").is_none()); } } diff --git a/crates/openshell-sandbox/src/supervisor_session.rs b/crates/openshell-sandbox/src/supervisor_session.rs index 49c52f9c2..6485dddf0 100644 --- a/crates/openshell-sandbox/src/supervisor_session.rs +++ b/crates/openshell-sandbox/src/supervisor_session.rs @@ -101,11 +101,10 @@ fn relay_target_endpoint(open: &RelayOpen) -> Option { }; let host = target.host.trim(); let port = u16::try_from(target.port).ok()?; - if let Ok(ip) = host.parse() { - Some(Endpoint::from_ip(ip, port)) - } else { - Some(Endpoint::from_domain(host, port)) - } + host.parse().map_or_else( + |_| Some(Endpoint::from_domain(host, port)), + |ip| Some(Endpoint::from_ip(ip, port)), + ) } fn relay_target_kind(open: &RelayOpen) -> &'static str { @@ -750,7 +749,9 @@ mod ocsf_event_tests { fn ssh_relay_open(channel_id: &str) -> RelayOpen { RelayOpen { channel_id: channel_id.to_string(), - target: Some(relay_open::Target::Ssh(Default::default())), + target: Some(relay_open::Target::Ssh( + openshell_core::proto::SshRelayTarget::default(), + )), service_id: String::new(), } } diff --git a/crates/openshell-server/src/auth/authz.rs b/crates/openshell-server/src/auth/authz.rs index 7e69b1cd8..70d9d738c 100644 --- a/crates/openshell-server/src/auth/authz.rs +++ b/crates/openshell-server/src/auth/authz.rs @@ -59,6 +59,7 @@ const SCOPED_METHODS: &[(&str, &str)] = &[ ("/openshell.v1.OpenShell/CreateSandbox", "sandbox:write"), ("/openshell.v1.OpenShell/DeleteSandbox", "sandbox:write"), ("/openshell.v1.OpenShell/ExecSandbox", "sandbox:write"), + ("/openshell.v1.OpenShell/ForwardTcp", "sandbox:write"), ("/openshell.v1.OpenShell/CreateSshSession", "sandbox:write"), ("/openshell.v1.OpenShell/RevokeSshSession", "sandbox:write"), ( @@ -420,6 +421,11 @@ mod tests { .check(&id, "/openshell.v1.OpenShell/CreateSandbox") .is_ok() ); + assert!( + policy + .check(&id, "/openshell.v1.OpenShell/ForwardTcp") + .is_ok() + ); assert!( policy .check(&id, "/openshell.v1.OpenShell/AttachSandboxProvider") diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 9417b2971..ad37a5482 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -20,8 +20,8 @@ use openshell_core::proto::{ ExecSandboxRequest, ExecSandboxStderr, ExecSandboxStdout, GetSandboxRequest, ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, Provider, RevokeSshSessionRequest, RevokeSshSessionResponse, - SandboxResponse, SandboxStreamEvent, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, - WatchSandboxRequest, relay_open, tcp_forward_init, + SandboxResponse, SandboxStreamEvent, SshRelayTarget, TcpForwardFrame, TcpForwardInit, + TcpRelayTarget, WatchSandboxRequest, relay_open, tcp_forward_init, }; use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; use prost::Message; @@ -730,13 +730,10 @@ pub(super) async fn handle_forward_tcp( .message() .await? .ok_or_else(|| Status::invalid_argument("empty ForwardTcp stream"))?; - let init = match first.payload { - Some(openshell_core::proto::tcp_forward_frame::Payload::Init(init)) => init, - _ => { - return Err(Status::invalid_argument( - "first TcpForwardFrame must be init", - )); - } + let Some(openshell_core::proto::tcp_forward_frame::Payload::Init(init)) = first.payload else { + return Err(Status::invalid_argument( + "first TcpForwardFrame must be init", + )); }; let target = validate_tcp_forward_init(&init)?; @@ -930,7 +927,9 @@ fn validate_tcp_forward_init(init: &TcpForwardInit) -> Result Ok(relay_open::Target::Ssh(Default::default())), + tcp_forward_init::Target::Ssh(_) => { + Ok(relay_open::Target::Ssh(SshRelayTarget::default())) + } tcp_forward_init::Target::Tcp(target) => Ok(relay_open::Target::Tcp( validate_tcp_forward_target(target)?, )), @@ -1570,12 +1569,12 @@ mod tests { fn tcp_forward_init_allows_ssh_target() { let init = TcpForwardInit { sandbox_id: "sbx".to_string(), - target: Some(tcp_forward_init::Target::Ssh(Default::default())), + target: Some(tcp_forward_init::Target::Ssh(SshRelayTarget::default())), ..Default::default() }; match validate_tcp_forward_init(&init).expect("ssh target should pass") { relay_open::Target::Ssh(_) => {} - other => panic!("expected SSH target, got {other:?}"), + other @ relay_open::Target::Tcp(_) => panic!("expected SSH target, got {other:?}"), } } diff --git a/crates/openshell-server/src/ssh_sessions.rs b/crates/openshell-server/src/ssh_sessions.rs index f8d85033d..c3294b361 100644 --- a/crates/openshell-server/src/ssh_sessions.rs +++ b/crates/openshell-server/src/ssh_sessions.rs @@ -33,10 +33,7 @@ pub fn spawn_session_reaper(store: Arc, interval: Duration) { } async fn reap_expired_sessions(store: &Store) -> Result<(), String> { - let now_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis() as i64; + let now_ms = unix_epoch_millis(); let records = store .list(SshSession::object_type(), 1000, 0) @@ -71,6 +68,16 @@ async fn reap_expired_sessions(store: &Store) -> Result<(), String> { Ok(()) } +fn unix_epoch_millis() -> i64 { + i64::try_from( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(), + ) + .unwrap_or(i64::MAX) +} + #[cfg(test)] mod tests { use super::*; @@ -92,10 +99,7 @@ mod tests { } fn now_ms() -> i64 { - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis() as i64 + unix_epoch_millis() } #[tokio::test] diff --git a/crates/openshell-server/src/supervisor_session.rs b/crates/openshell-server/src/supervisor_session.rs index 4e943bcac..19d358826 100644 --- a/crates/openshell-server/src/supervisor_session.rs +++ b/crates/openshell-server/src/supervisor_session.rs @@ -246,7 +246,7 @@ impl SupervisorSessionRegistry { self.open_relay_with_target( sandbox_id, relay_open::Target::Ssh(SshRelayTarget {}), - "".to_string(), + String::new(), session_wait_timeout, ) .await @@ -331,7 +331,7 @@ impl SupervisorSessionRegistry { } } - /// Claim a pending relay channel. Called by the /relay/{channel_id} HTTP handler + /// Claim a pending relay channel. Called by the `/relay/{channel_id}` HTTP handler /// when the supervisor's reverse CONNECT arrives. /// /// Returns the `DuplexStream` half that the supervisor side should read/write. @@ -1285,7 +1285,9 @@ mod tests { pending_relay( "sbx-test", relay_tx, - Instant::now() - Duration::from_secs(60), + Instant::now() + .checked_sub(Duration::from_secs(60)) + .expect("test duration should be before now"), ), ); @@ -1358,7 +1360,9 @@ mod tests { pending_relay( "sbx-test", relay_tx, - Instant::now() - Duration::from_secs(60), + Instant::now() + .checked_sub(Duration::from_secs(60)) + .expect("test duration should be before now"), ), ); From b59c6edd0f24a1f1e1b3d567e4c79593324d6d45 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Mon, 11 May 2026 21:20:59 -0700 Subject: [PATCH 4/4] docs(sandboxes): reset manage sandboxes page --- docs/sandboxes/manage-sandboxes.mdx | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 66d9ef690..b43a6846a 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -200,14 +200,6 @@ openshell forward list openshell forward stop 8000 my-sandbox ``` -Use the gRPC service-forwarding path when you want to test the OS-88 service relay path without SSH port forwarding: - -```shell -openshell forward service my-sandbox --target-port 8000 --local 8000 -``` - -This binds a local listener and opens one authenticated gRPC stream to the gateway for each accepted local TCP connection. The target must be a loopback TCP service inside the sandbox. Use `--local 127.0.0.1:0` to let OpenShell choose a free local port. - You can also forward a port at creation time with `--forward`: