diff --git a/rust/Cargo.lock b/rust/Cargo.lock index c8f59c0..4e8673b 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -127,7 +127,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chainlink-data-streams-report" -version = "1.2.1" +version = "1.2.2" dependencies = [ "hex", "num-bigint", @@ -139,7 +139,7 @@ dependencies = [ [[package]] name = "chainlink-data-streams-sdk" -version = "1.2.1" +version = "1.2.2" dependencies = [ "byteorder", "chainlink-data-streams-report", diff --git a/rust/README.md b/rust/README.md index defcdad..e6f465b 100644 --- a/rust/README.md +++ b/rust/README.md @@ -20,8 +20,8 @@ Add the following to your `Cargo.toml`: ```toml [dependencies] -chainlink-data-streams-report = "1.2.1" -chainlink-data-streams-sdk = { version = "1.2.1", features = ["full"] } +chainlink-data-streams-report = "1.2.2" +chainlink-data-streams-sdk = { version = "1.2.2", features = ["full"] } ``` #### Features @@ -110,8 +110,8 @@ async fn main() -> Result<(), Box> { let api_key = "YOUR_API_KEY_GOES_HERE"; let user_secret = "YOUR_USER_SECRET_GOES_HERE"; - let rest_url = "https://api.testnet-dataengine.chain.link"; - let ws_url = "wss://ws.testnet-dataengine.chain.link"; + let rest_url = "https://api.dataengine.chain.link"; + let ws_url = "wss://ws.dataengine.chain.link"; let eth_usd_feed_id = ID::from_hex_str("0x000359843a543ee2fe414dc14c7e7920ef10f4372990b79d6361cdc0dd1ba782") diff --git a/rust/crates/report/Cargo.toml b/rust/crates/report/Cargo.toml index 718a87d..deb3510 100644 --- a/rust/crates/report/Cargo.toml +++ b/rust/crates/report/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chainlink-data-streams-report" -version = "1.2.1" +version = "1.2.2" edition = "2021" description = "Chainlink Data Streams Report" license = "MIT" diff --git a/rust/crates/sdk/Cargo.toml b/rust/crates/sdk/Cargo.toml index a2fc2d6..0320947 100644 --- a/rust/crates/sdk/Cargo.toml +++ b/rust/crates/sdk/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chainlink-data-streams-sdk" -version = "1.2.1" +version = "1.2.2" edition = "2021" rust-version = "1.70" description = "Chainlink Data Streams client SDK" @@ -11,7 +11,7 @@ exclude = ["/target/*", "examples/*", "tests/*", "docs/*", "book/*"] keywords = ["chainlink"] [dependencies] -chainlink-data-streams-report = { path = "../report", version = "1.2.1" } +chainlink-data-streams-report = { path = "../report", version = "1.2.2" } reqwest = { version = "0.11.20", features = ["json", "rustls-tls"] } tokio = { version = "1.29.1", features = ["full"] } tokio-tungstenite = { version = "0.20.1", features = [ diff --git a/rust/crates/sdk/examples/wss_multiple.rs b/rust/crates/sdk/examples/wss_multiple.rs index fd19dc8..63b6026 100644 --- a/rust/crates/sdk/examples/wss_multiple.rs +++ b/rust/crates/sdk/examples/wss_multiple.rs @@ -14,14 +14,14 @@ async fn main() -> Result<(), Box> { let api_key = "YOUR_API_KEY_GOES_HERE"; let user_secret = "YOUR_USER_SECRET_GOES_HERE"; - let rest_url = "https://api.testnet-dataengine.chain.link"; - let ws_url = "wss://ws.testnet-dataengine.chain.link,wss://ws.testnet-dataengine.chain.link"; + let rest_url = "https://api.dataengine.chain.link"; + let ws_url = "wss://ws.dataengine.chain.link"; let eth_usd_feed_id = - ID::from_hex_str("0x000359843a543ee2fe414dc14c7e7920ef10f4372990b79d6361cdc0dd1ba782") + ID::from_hex_str("0x000362205e10b3a147d02792eccee483dca6c7b44ecce7012cb8c6e0b68b3ae9") .unwrap(); let btc_usd_feed_id: ID = - ID::from_hex_str("0x00037da06d56d083fe599397a4769a042d63aa73dc4ef57709d31e9971a5b439") + ID::from_hex_str("0x00039d9e45394f473ab1f050a1b963e6b05351e52d71e507509ada0c95ed75b8") .unwrap(); let feed_ids = vec![eth_usd_feed_id, btc_usd_feed_id]; diff --git a/rust/crates/sdk/src/config.rs b/rust/crates/sdk/src/config.rs index e55a551..0b841e7 100644 --- a/rust/crates/sdk/src/config.rs +++ b/rust/crates/sdk/src/config.rs @@ -108,13 +108,13 @@ impl Config { /// .build()?; /// /// // If you want to customize the configuration further, use the builder pattern - /// let ws_urls_multiple = "wss://api.testnet-dataengine.chain.link/ws,wss://api.testnet-dataengine.chain.link/ws"; - /// + /// // In HA mode, provide a single WebSocket URL — origins are discovered automatically + /// // via a HEAD request to the server (X-Cll-Available-Origins header). /// let config_custom = Config::new( /// api_key.to_string(), /// user_secret.to_string(), /// rest_url.to_string(), - /// ws_urls_multiple.to_string(), + /// ws_url.to_string(), /// ) /// .with_ws_ha(WebSocketHighAvailability::Enabled) // Enable WebSocket High Availability Mode /// .with_ws_max_reconnect(10) // Set maximum reconnection attempts to 10, instead of the default 5. diff --git a/rust/crates/sdk/src/endpoints.rs b/rust/crates/sdk/src/endpoints.rs index 43e46db..c61cd7e 100644 --- a/rust/crates/sdk/src/endpoints.rs +++ b/rust/crates/sdk/src/endpoints.rs @@ -19,11 +19,9 @@ impl CtxKey { } /// HTTP Header constants using `HeaderName` with `OnceLock` for lazy initialization -#[allow(dead_code)] // Currently unused static CLL_AVAIL_ORIGINS_HEADER: OnceLock = OnceLock::new(); -#[allow(dead_code)] // Currently unused static CLL_ORIGIN_HEADER: OnceLock = OnceLock::new(); -#[allow(dead_code)] // Currently unused +#[allow(dead_code)] static CLL_INT_HEADER: OnceLock = OnceLock::new(); static AUTHZ_HEADER: OnceLock = OnceLock::new(); static AUTHZ_TS_HEADER: OnceLock = OnceLock::new(); @@ -33,7 +31,6 @@ static HOST_HEADER: OnceLock = OnceLock::new(); /// Functions to retrieve header constants, initializing them on first access -#[allow(dead_code)] // Currently unused /// "X-Cll-Available-Origins" pub fn get_cll_avail_origins_header() -> &'static HeaderName { CLL_AVAIL_ORIGINS_HEADER.get_or_init(|| { @@ -42,7 +39,6 @@ pub fn get_cll_avail_origins_header() -> &'static HeaderName { }) } -#[allow(dead_code)] // Currently unused /// "X-Cll-Origin" pub fn get_cll_origin_header() -> &'static HeaderName { CLL_ORIGIN_HEADER.get_or_init(|| { @@ -50,8 +46,8 @@ pub fn get_cll_origin_header() -> &'static HeaderName { }) } -#[allow(dead_code)] // Currently unused /// "X-Cll-Eng-Int" +#[allow(dead_code)] pub fn get_cll_int_header() -> &'static HeaderName { CLL_INT_HEADER.get_or_init(|| { HeaderName::from_str("X-Cll-Eng-Int").expect("Invalid header name: X-Cll-Eng-Int") @@ -86,3 +82,20 @@ pub fn get_authz_sig_header() -> &'static HeaderName { pub fn get_host_header() -> &'static HeaderName { HOST_HEADER.get_or_init(|| HeaderName::from_str("Host").expect("Invalid header name: Host")) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cll_origin_header_name() { + let h = get_cll_origin_header(); + assert_eq!(h.as_str(), "x-cll-origin"); + } + + #[test] + fn test_cll_avail_origins_header_name() { + let h = get_cll_avail_origins_header(); + assert_eq!(h.as_str(), "x-cll-available-origins"); + } +} diff --git a/rust/crates/sdk/src/stream.rs b/rust/crates/sdk/src/stream.rs index 4c4006d..941b37a 100644 --- a/rust/crates/sdk/src/stream.rs +++ b/rust/crates/sdk/src/stream.rs @@ -4,11 +4,14 @@ mod monitor_connection; use establish_connection::connect; use monitor_connection::run_stream; -use crate::config::Config; +use crate::auth::generate_auth_headers; +use crate::config::{Config, WebSocketHighAvailability}; +use crate::endpoints::get_cll_avail_origins_header; use chainlink_data_streams_report::feed_id::ID; use chainlink_data_streams_report::report::Report; +use reqwest::Client as HttpClient; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, @@ -16,6 +19,7 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, }, + time::{SystemTime, UNIX_EPOCH}, }; use tokio::{ net::TcpStream, @@ -23,7 +27,7 @@ use tokio::{ time::{sleep, Duration}, }; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream as TungsteniteWebSocketStream}; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, warn}; pub const DEFAULT_WS_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); pub const MIN_WS_RECONNECT_INTERVAL: Duration = Duration::from_millis(1000); @@ -70,7 +74,7 @@ struct Stats { #[derive(Debug)] pub enum WebSocketConnection { Single(TungsteniteWebSocketStream>), - Multiple(Vec>>), + Multiple(Vec<(TungsteniteWebSocketStream>, String)>), } /// Stream represents a realtime report stream. @@ -141,7 +145,26 @@ impl Stream { active_connections: AtomicUsize::new(0), }); - let conn = connect(config, &feed_ids, stats.clone()).await?; + let origins: Vec = if config.ws_ha == WebSocketHighAvailability::Enabled { + match fetch_ha_origins(config).await { + Ok(o) if !o.is_empty() => { + info!("HA mode: discovered {} origins", o.len()); + o + } + Ok(_) => { + warn!("HA mode: no origins returned from HEAD request, degrading to single connection"); + vec![] + } + Err(e) => { + warn!("HA mode: origin discovery failed ({}), degrading to single connection", e); + vec![] + } + } + } else { + vec![] + }; + + let conn = connect(config, &origins, &feed_ids, stats.clone()).await?; let water_mark = Arc::new(Mutex::new(HashMap::new())); @@ -176,6 +199,7 @@ impl Stream { tokio::spawn(run_stream( stream, + String::new(), // no X-Cll-Origin header for non-HA connections report_sender, shutdown_receiver, stats, @@ -185,7 +209,7 @@ impl Stream { )); } WebSocketConnection::Multiple(streams) => { - for stream in streams { + for (stream, origin) in streams { let report_sender = self.report_sender.clone(); let shutdown_receiver = self.shutdown_sender.subscribe(); let stats = self.stats.clone(); @@ -195,6 +219,7 @@ impl Stream { tokio::spawn(run_stream( stream, + origin, report_sender, shutdown_receiver, stats, @@ -284,3 +309,139 @@ pub struct StatsSnapshot { /// Current number of active connections pub active_connections: usize, } + +fn parse_origins_from_header(header_value: &str) -> Vec { + let inner = header_value + .strip_prefix('{') + .and_then(|s| s.strip_suffix('}')) + .unwrap_or(header_value); + if inner.is_empty() { + return vec![]; + } + inner + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect() +} + +fn convert_ws_to_http_scheme(ws_url: &str) -> String { + if let Some(rest) = ws_url.strip_prefix("wss://") { + format!("https://{}", rest) + } else if let Some(rest) = ws_url.strip_prefix("ws://") { + format!("http://{}", rest) + } else { + ws_url.to_string() + } +} + +async fn fetch_ha_origins(config: &Config) -> Result, StreamError> { + let http = HttpClient::builder() + .danger_accept_invalid_certs(config.insecure_skip_verify.to_bool()) + .build() + .map_err(|e| StreamError::ConnectionError(e.to_string()))?; + + // Parse URL, normalize path to "/", keep scheme+host+port so the HMAC-signed + // path "/" matches the actual request path even when ws_url carries a subpath. + let http_url = { + let mut u = reqwest::Url::parse(&convert_ws_to_http_scheme(&config.ws_url)) + .map_err(|e| StreamError::ConnectionError(format!("Invalid ws_url: {}", e)))?; + u.set_path("/"); + u.set_query(None); + u.to_string() + }; + + let request_timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("System time error") + .as_millis(); + + let auth_headers = generate_auth_headers( + "HEAD", + "/", + b"", + &config.api_key, + &config.api_secret, + request_timestamp, + )?; + + let response = http + .head(&http_url) + .headers(auth_headers) + .send() + .await + .map_err(|e| StreamError::ConnectionError(format!("HA origin discovery request failed: {}", e)))?; + + if !response.status().is_success() { + return Err(StreamError::ConnectionError(format!( + "HA origin discovery HEAD request returned status {}", + response.status() + ))); + } + + let header_value = response + .headers() + .get(get_cll_avail_origins_header()) + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_string(); + + Ok(parse_origins_from_header(&header_value)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_origins_from_header_empty() { + assert_eq!(parse_origins_from_header(""), Vec::::new()); + } + + #[test] + fn test_parse_origins_from_header_with_braces() { + let result = parse_origins_from_header("{001,002}"); + assert_eq!(result, vec!["001".to_string(), "002".to_string()]); + } + + #[test] + fn test_parse_origins_from_header_without_braces() { + let result = parse_origins_from_header("001,002"); + assert_eq!(result, vec!["001".to_string(), "002".to_string()]); + } + + #[test] + fn test_parse_origins_from_header_single_origin() { + let result = parse_origins_from_header("{001}"); + assert_eq!(result, vec!["001".to_string()]); + } + + #[test] + fn test_parse_origins_from_header_empty_braces() { + assert_eq!(parse_origins_from_header("{}"), Vec::::new()); + } + + #[test] + fn test_convert_ws_scheme_wss() { + assert_eq!( + convert_ws_to_http_scheme("wss://ws.dataengine.chain.link"), + "https://ws.dataengine.chain.link" + ); + } + + #[test] + fn test_convert_ws_scheme_ws() { + assert_eq!( + convert_ws_to_http_scheme("ws://127.0.0.1:8080"), + "http://127.0.0.1:8080" + ); + } + + #[test] + fn test_convert_ws_scheme_passthrough() { + assert_eq!( + convert_ws_to_http_scheme("https://already.https.com"), + "https://already.https.com" + ); + } +} diff --git a/rust/crates/sdk/src/stream/establish_connection.rs b/rust/crates/sdk/src/stream/establish_connection.rs index a9828ea..7edcf6b 100644 --- a/rust/crates/sdk/src/stream/establish_connection.rs +++ b/rust/crates/sdk/src/stream/establish_connection.rs @@ -3,7 +3,7 @@ use super::{Stats, StreamError, WebSocketConnection}; use crate::{ auth::generate_auth_headers, config::{Config, WebSocketHighAvailability}, - endpoints::API_V1_WS, + endpoints::{get_cll_origin_header, API_V1_WS}, stream::{DEFAULT_WS_CONNECT_TIMEOUT, MAX_WS_RECONNECT_INTERVAL, MIN_WS_RECONNECT_INTERVAL}, }; @@ -23,20 +23,13 @@ use tokio_tungstenite::{ }; use tracing::{error, info}; -fn parse_origins(ws_url: &str) -> Vec { - ws_url - .split(',') - .map(|url| url.trim().to_string()) - .collect() -} - async fn connect_to_origin( config: &Config, - origin: &str, + cll_origin: &str, // X-Cll-Origin header value; empty string = no header feed_ids: &[ID], ) -> Result>, StreamError> { - let feed_ids: Vec = feed_ids.iter().map(|id| id.to_hex_string()).collect(); - let feed_ids_joined = feed_ids.join(","); + let feed_ids_str: Vec = feed_ids.iter().map(|id| id.to_hex_string()).collect(); + let feed_ids_joined = feed_ids_str.join(","); let method = "GET"; let path = format!("{}?feedIDs={}", API_V1_WS, feed_ids_joined.as_str()); @@ -48,7 +41,7 @@ async fn connect_to_origin( .expect("System time error") .as_millis(); - let headers = generate_auth_headers( + let mut headers = generate_auth_headers( method, &path, body, @@ -57,7 +50,17 @@ async fn connect_to_origin( request_timestamp, )?; - let url = format!("{}{}", origin, path); + if !cll_origin.is_empty() { + headers.insert( + get_cll_origin_header(), + reqwest::header::HeaderValue::from_str(cll_origin).map_err(|e| { + StreamError::ConnectionError(format!("Invalid X-Cll-Origin header value: {}", e)) + })?, + ); + } + + // Always connect to config.ws_url — cll_origin is a routing hint header, not a URL + let url = format!("{}{}", config.ws_url, path); let mut request = url.into_client_request().map_err(|e| { StreamError::ConnectionError(format!("Failed to create client request: {}", e)) })?; @@ -77,18 +80,21 @@ async fn connect_to_origin( pub(crate) async fn connect( config: &Config, + origins: &[String], // empty = single non-HA connection; populated = HA mode feed_ids: &[ID], stats: Arc, ) -> Result { - let origins = parse_origins(&config.ws_url); + if config.ws_ha == WebSocketHighAvailability::Enabled && origins.len() == 1 { + info!("HA mode enabled but only 1 origin discovered; connection will not be redundant"); + } - if config.ws_ha == WebSocketHighAvailability::Enabled && origins.len() > 1 { + if config.ws_ha == WebSocketHighAvailability::Enabled && !origins.is_empty() { let mut streams = Vec::new(); for origin in origins { - match connect_to_origin(config, &origin, feed_ids).await { + match connect_to_origin(config, origin, feed_ids).await { Ok(stream) => { - streams.push(stream); + streams.push((stream, origin.clone())); stats.configured_connections.fetch_add(1, Ordering::SeqCst); stats.active_connections.fetch_add(1, Ordering::SeqCst); } @@ -100,20 +106,15 @@ pub(crate) async fn connect( if streams.is_empty() { return Err(StreamError::ConnectionError( - "Failed to connect to any WebSocket origins".into(), + "Failed to connect to any WebSocket origins in HA mode".into(), )); } Ok(WebSocketConnection::Multiple(streams)) } else { - let origin = origins.first().ok_or_else(|| { - StreamError::ConnectionError("No WebSocket origin found in config".into()) - })?; - - let stream = connect_to_origin(config, origin, feed_ids).await?; + let stream = connect_to_origin(config, "", feed_ids).await?; stats.configured_connections.fetch_add(1, Ordering::SeqCst); stats.active_connections.fetch_add(1, Ordering::SeqCst); - Ok(WebSocketConnection::Single(stream)) } } @@ -121,15 +122,15 @@ pub(crate) async fn connect( pub(crate) async fn try_to_reconnect( stats: Arc, config: &Config, + origin: &str, // the X-Cll-Origin header value for this connection (empty = non-HA) feed_ids: &[ID], ) -> Result>, StreamError> { let mut reconnect_attempts = 0; let max_reconnect_attempts = config.ws_max_reconnect; - let origin = config.ws_url.split(',').next().unwrap(); let mut backoff = MIN_WS_RECONNECT_INTERVAL; loop { - info!("Attempting to reconnect to origin: {}", origin); + info!("Attempting to reconnect (origin: {})", origin); reconnect_attempts += 1; match connect_to_origin(config, origin, feed_ids).await { Ok(new_stream) => { @@ -150,7 +151,6 @@ pub(crate) async fn try_to_reconnect( } error!("Retrying in {:?}.", backoff); - sleep(backoff).await; backoff = (backoff * 2).min(MAX_WS_RECONNECT_INTERVAL); } diff --git a/rust/crates/sdk/src/stream/monitor_connection.rs b/rust/crates/sdk/src/stream/monitor_connection.rs index 1cd497b..14c4d6e 100644 --- a/rust/crates/sdk/src/stream/monitor_connection.rs +++ b/rust/crates/sdk/src/stream/monitor_connection.rs @@ -24,6 +24,7 @@ use tracing::{error, info, warn}; pub(crate) async fn run_stream( mut stream: TungsteniteWebSocketStream>, + origin: String, // X-Cll-Origin value for this connection; empty = non-HA report_sender: mpsc::Sender, mut shutdown_receiver: broadcast::Receiver<()>, stats: Arc, @@ -92,7 +93,7 @@ pub(crate) async fn run_stream( error!("Error receiving message: {:?}", e); stats.active_connections.fetch_sub(1, Ordering::SeqCst); - stream = handle_reconnection(stats.clone(), &config, &feed_ids).await?; + stream = handle_reconnection(stats.clone(), &config, &origin, &feed_ids).await?; } None => { info!("WebSocket stream closed."); @@ -102,7 +103,7 @@ pub(crate) async fn run_stream( info!("Stream closed gracefully after shutdown signal."); return Ok(()); } else { - stream = handle_reconnection(stats.clone(), &config, &feed_ids).await?; + stream = handle_reconnection(stats.clone(), &config, &origin, &feed_ids).await?; } } } @@ -126,6 +127,7 @@ pub(crate) async fn run_stream( async fn handle_reconnection( stats: Arc, config: &Config, + origin: &str, // X-Cll-Origin value; passed through to try_to_reconnect feed_ids: &[ID], ) -> Result>, StreamError> { if stats.active_connections.load(Ordering::SeqCst) == 0 { @@ -134,6 +136,6 @@ async fn handle_reconnection( stats.partial_reconnects.fetch_add(1, Ordering::SeqCst); } - let new_stream = try_to_reconnect(stats.clone(), config, feed_ids).await?; + let new_stream = try_to_reconnect(stats.clone(), config, origin, feed_ids).await?; Ok(new_stream) } diff --git a/rust/crates/sdk/tests/stream_integration_tests.rs b/rust/crates/sdk/tests/stream_integration_tests.rs index c0fc473..f8c02e4 100644 --- a/rust/crates/sdk/tests/stream_integration_tests.rs +++ b/rust/crates/sdk/tests/stream_integration_tests.rs @@ -19,11 +19,15 @@ async fn prepare_scenario() -> (MockWebSocketServer, Stream, Vec) { let mock_server_address = "127.0.0.1:0"; let mock_server = MockWebSocketServer::new(mock_server_address).await; - let origins = repeat(format!("ws://{}", mock_server.address())) - .take(NUMBER_OF_CONNECTIONS) - .collect::>(); + let ws_url = format!("ws://{}", mock_server.address()); - let ws_url = origins.join(","); + // Configure mock server to return N origins in HEAD response (all pointing to same mock). + // In production these would be distinct backends; in tests we use the same address + // to validate deduplication, reconnect, and HA connection-count behavior. + let origins: Vec = repeat(ws_url.clone()) + .take(NUMBER_OF_CONNECTIONS) + .collect(); + mock_server.set_ha_origins(origins).await; let config = Config::new( "mock_key".to_string(), @@ -259,6 +263,59 @@ async fn test_stream_ha_reconnect_merge() { assert_eq!(stats.deduplicated, expected_deduplicated); } +#[tokio::test] +async fn test_stream_ha_x_cll_origin_header() { + let mock_server = MockWebSocketServer::new("127.0.0.1:0").await; + let ws_url = format!("ws://{}", mock_server.address()); + + // Use distinct opaque origin IDs matching the real protocol format (e.g. {001,002}). + mock_server + .set_ha_origins(vec!["001".to_string(), "002".to_string()]) + .await; + + let config = Config::new( + "mock_key".to_string(), + "mock_secret".to_string(), + "mock_rest_url".to_string(), + ws_url, + ) + .with_ws_ha(WebSocketHighAvailability::Enabled) + .with_ws_max_reconnect(MAX_RECONNECT_ATTEMPTS) + .build() + .expect("Failed to build config"); + + let mut stream = Stream::new(&config, vec![]) + .await + .expect("Failed to create stream"); + + stream.listen().await.expect("Failed to start listening"); + + // Allow time for both WebSocket connections to be established. + sleep(Duration::from_millis(500)).await; + + let received = mock_server.get_received_cll_origins().await; + + // Assert 1: X-Cll-Origin header is present on every HA WebSocket connection. + assert_eq!(received.len(), 2, "Expected 2 WebSocket connections in HA mode"); + for origin in &received { + assert!( + origin.is_some(), + "X-Cll-Origin header was missing on a WebSocket connection" + ); + } + + // Assert 2: The server sees distinct origin values matching the configured origins. + let mut actual: Vec = received.into_iter().flatten().collect(); + actual.sort(); + assert_eq!( + actual, + vec!["001".to_string(), "002".to_string()], + "X-Cll-Origin header values did not match the configured origins" + ); + + stream.close().await.expect("Failed to close stream"); +} + #[tokio::test] #[ignore] // Ignored because it takes a while to complete. To run it, use this command: cargo test -- --ignored async fn test_stream_ha_max_reconnection_attempts() { diff --git a/rust/crates/sdk/tests/utils/mock_websocket_server.rs b/rust/crates/sdk/tests/utils/mock_websocket_server.rs index 9f3b3d6..1d6ef58 100644 --- a/rust/crates/sdk/tests/utils/mock_websocket_server.rs +++ b/rust/crates/sdk/tests/utils/mock_websocket_server.rs @@ -1,10 +1,14 @@ use futures::{SinkExt, StreamExt}; use std::sync::Arc; use tokio::{ - net::TcpListener, - sync::{mpsc, Mutex, Notify}, + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream}, + sync::{mpsc, oneshot, Mutex, Notify}, +}; +use tokio_tungstenite::{ + accept_hdr_async, + tungstenite::{handshake::server::Request as WsRequest, protocol::Message}, }; -use tokio_tungstenite::{accept_async, tungstenite::protocol::Message}; enum ServerCommand { Send(Vec), @@ -16,6 +20,12 @@ pub struct MockWebSocketServer { address: String, command_sender: mpsc::Sender, shutdown_notify: Arc, + /// Origins returned in X-Cll-Available-Origins HEAD response. + /// When None, defaults to two copies of the server's own ws:// address. + ha_origins: Arc>>>, + /// X-Cll-Origin header values captured from incoming WebSocket upgrade requests. + /// Some(value) if the header was present, None if absent. + received_cll_origins: Arc>>>, } impl MockWebSocketServer { @@ -25,46 +35,37 @@ impl MockWebSocketServer { .expect("Failed to bind address"); let address = listener.local_addr().unwrap().to_string(); - println!("Mock WebSocket server started at: {}", address); let (command_sender, mut command_receiver) = mpsc::channel::(100); let clients = Arc::new(Mutex::new(Vec::new())); let shutdown_notify = Arc::new(Notify::new()); + let ha_origins: Arc>>> = Arc::new(Mutex::new(None)); + let received_cll_origins: Arc>>> = + Arc::new(Mutex::new(Vec::new())); let clients_accept = clients.clone(); let shutdown_accept = shutdown_notify.clone(); + let ha_origins_accept = ha_origins.clone(); + let received_accept = received_cll_origins.clone(); + let server_address = address.clone(); + tokio::spawn(async move { loop { tokio::select! { accept_result = listener.accept() => { match accept_result { Ok((stream, _)) => { - let ws_stream = accept_async(stream) - .await - .expect("Failed to accept connection"); - println!( - "Client connected: {}", - ws_stream.get_ref().peer_addr().unwrap() - ); - - let (mut ws_sender, _) = ws_stream.split(); - let (client_sender, mut client_receiver) = - mpsc::channel::(100); - - clients_accept.lock().await.push(client_sender); - - // Spawn a task to forward messages from the server to the client. - tokio::spawn(async move { - while let Some(message) = client_receiver.recv().await { - if ws_sender.send(message).await.is_err() { - break; - } - } - println!("Client connection closed"); - }); - - // Ignore messages from the client. There will none in this test. + let origins = { + let guard = ha_origins_accept.lock().await; + guard.clone().unwrap_or_else(|| vec![ + format!("ws://{}", server_address), + format!("ws://{}", server_address), + ]) + }; + let clients_clone = clients_accept.clone(); + let received_clone = received_accept.clone(); + tokio::spawn(handle_connection(stream, origins, clients_clone, received_clone)); } Err(e) => { println!("Error accepting connection: {:?}", e); @@ -72,11 +73,9 @@ impl MockWebSocketServer { } } } - // Listen for shutdown signal. _ = shutdown_accept.notified() => { println!("Shutting down"); - let mut clients = clients_accept.lock().await; - clients.clear(); + clients_accept.lock().await.clear(); break; } } @@ -95,8 +94,7 @@ impl MockWebSocketServer { } ServerCommand::DropConnections => { println!("Dropping all client connections"); - let mut clients = clients_command.lock().await; - clients.clear(); + clients_command.lock().await.clear(); } } } @@ -106,6 +104,8 @@ impl MockWebSocketServer { address, command_sender, shutdown_notify, + ha_origins, + received_cll_origins, } } @@ -127,4 +127,90 @@ impl MockWebSocketServer { pub async fn shutdown(&self) { self.shutdown_notify.notify_waiters(); } + + /// Configure the origins returned in the X-Cll-Available-Origins HEAD response. + /// Call this before Stream::new() is invoked in tests that exercise HA discovery. + pub async fn set_ha_origins(&self, origins: Vec) { + *self.ha_origins.lock().await = Some(origins); + } + + /// Returns the X-Cll-Origin header values captured from all WebSocket upgrade requests. + /// Some(value) means the header was present; None means it was absent. + pub async fn get_received_cll_origins(&self) -> Vec> { + self.received_cll_origins.lock().await.clone() + } +} + +async fn handle_connection( + mut stream: TcpStream, + ha_origins: Vec, + clients: Arc>>>, + received_cll_origins: Arc>>>, +) { + // Peek at first 4 bytes to distinguish HTTP HEAD from WebSocket upgrade. + // peek() does not consume data, so the full request remains readable by accept_hdr_async. + let mut peek_buf = [0u8; 4]; + let n = match stream.peek(&mut peek_buf).await { + Ok(n) => n, + Err(e) => { + println!("Peek error: {:?}", e); + return; + } + }; + + if n >= 4 && &peek_buf[..4] == b"HEAD" { + // Consume the HTTP request (drain until blank line) + let mut buf = [0u8; 4096]; + let _ = stream.read(&mut buf).await; + + let origins_str = ha_origins.join(","); + let response = format!( + "HTTP/1.1 200 OK\r\nX-Cll-Available-Origins: {{{}}}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n", + origins_str + ); + let _ = stream.write_all(response.as_bytes()).await; + } else { + // WebSocket upgrade — capture the X-Cll-Origin header from the upgrade request. + let (origin_tx, mut origin_rx) = oneshot::channel::>(); + + let ws_stream = match accept_hdr_async(stream, move |req: &WsRequest, resp| { + let origin = req + .headers() + .get("x-cll-origin") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + let _ = origin_tx.send(origin); + Ok(resp) + }) + .await + { + Ok(s) => s, + Err(e) => { + println!("WebSocket accept error: {:?}", e); + return; + } + }; + + // origin_tx.send() has already run by the time accept_hdr_async resolves. + let origin = origin_rx.try_recv().unwrap_or(None); + received_cll_origins.lock().await.push(origin); + + println!( + "Client connected: {}", + ws_stream.get_ref().peer_addr().unwrap() + ); + + let (mut ws_sender, _) = ws_stream.split(); + let (client_sender, mut client_receiver) = mpsc::channel::(100); + clients.lock().await.push(client_sender); + + tokio::spawn(async move { + while let Some(message) = client_receiver.recv().await { + if ws_sender.send(message).await.is_err() { + break; + } + } + println!("Client connection closed"); + }); + } }