diff --git a/dash-spv/Cargo.toml b/dash-spv/Cargo.toml index 7af44f2fe..848598d93 100644 --- a/dash-spv/Cargo.toml +++ b/dash-spv/Cargo.toml @@ -18,6 +18,9 @@ key-wallet-manager = { path = "../key-wallet-manager" } # BLS signatures blsful = { git = "https://github.com/dashpay/agora-blsful", rev = "0c34a7a488a0bd1c9a9a2196e793b303ad35c900" } +# BIP324 v2 P2P encrypted transport +bip324 = { git = "https://github.com/rust-bitcoin/bip324", rev = "8c469432", features = ["std", "tokio"] } + # CLI clap = { version = "4.0", features = ["derive", "env"] } diff --git a/dash-spv/src/client/config.rs b/dash-spv/src/client/config.rs index 11ccfb568..4ed27cc6f 100644 --- a/dash-spv/src/client/config.rs +++ b/dash-spv/src/client/config.rs @@ -7,6 +7,7 @@ use std::time::Duration; use dashcore::Network; // Serialization removed due to complex Address types +use crate::network::transport::TransportPreference; use crate::types::ValidationMode; /// Strategy for handling mempool (unconfirmed) transactions. @@ -152,6 +153,10 @@ pub struct ClientConfig { /// Timeout for QRInfo requests (default: 30 seconds). pub qr_info_timeout: Duration, + + /// Transport preference for peer connections (V1, V2, or V2 with fallback). + /// Default is V2Preferred: try V2 encrypted transport first, fall back to V1. + pub transport_preference: TransportPreference, } impl Default for ClientConfig { @@ -201,6 +206,8 @@ impl Default for ClientConfig { // QRInfo defaults (simplified per plan) qr_info_extra_share: false, // Matches DMLviewer.patch default qr_info_timeout: Duration::from_secs(30), + // Transport preference (BIP324 v2 encrypted by default with v1 fallback) + transport_preference: TransportPreference::default(), } } } @@ -342,6 +349,16 @@ impl ClientConfig { self } + /// Set transport preference for peer connections. + /// + /// - `V2Preferred` (default): Try BIP324 v2 encrypted transport first, fall back to v1 + /// - `V2Only`: Require BIP324 v2 encrypted transport, fail if peer doesn't support it + /// - `V1Only`: Use traditional unencrypted v1 transport only + pub fn with_transport_preference(mut self, preference: TransportPreference) -> Self { + self.transport_preference = preference; + self + } + /// Validate the configuration. pub fn validate(&self) -> Result<(), String> { // Note: Empty peers list is now valid - DNS discovery will be used automatically diff --git a/dash-spv/src/error.rs b/dash-spv/src/error.rs index 5e411449d..7f745fbde 100644 --- a/dash-spv/src/error.rs +++ b/dash-spv/src/error.rs @@ -104,6 +104,19 @@ pub enum NetworkError { #[error("System time error: {0}")] SystemTime(String), + + // BIP324 V2 transport errors + #[error("V2 handshake failed: {0}")] + V2HandshakeFailed(String), + + #[error("V2 decryption failed: {0}")] + V2DecryptionFailed(String), + + #[error("V2 encryption failed: {0}")] + V2EncryptionFailed(String), + + #[error("V2 not supported by peer")] + V2NotSupported, } /// Storage-related errors. diff --git a/dash-spv/src/network/manager.rs b/dash-spv/src/network/manager.rs index c0dc87ff2..cd76bc4c6 100644 --- a/dash-spv/src/network/manager.rs +++ b/dash-spv/src/network/manager.rs @@ -27,6 +27,7 @@ use crate::network::pool::PeerPool; use crate::network::reputation::{ misbehavior_scores, positive_scores, PeerReputationManager, ReputationAware, }; +use crate::network::transport::TransportPreference; use crate::network::{HandshakeManager, NetworkManager, Peer}; use crate::types::PeerInfo; @@ -71,6 +72,8 @@ pub struct PeerNetworkManager { exclusive_mode: bool, /// Cached count of currently connected peers for fast, non-blocking queries connected_peer_count: Arc, + /// Transport preference for peer connections (V1, V2, or V2 with fallback) + transport_preference: TransportPreference, } impl PeerNetworkManager { @@ -124,6 +127,7 @@ impl PeerNetworkManager { user_agent: config.user_agent.clone(), exclusive_mode, connected_peer_count: Arc::new(AtomicUsize::new(0)), + transport_preference: config.transport_preference, }) } @@ -210,13 +214,16 @@ impl PeerNetworkManager { let mempool_strategy = self.mempool_strategy; let user_agent = self.user_agent.clone(); let connected_peer_count = self.connected_peer_count.clone(); + let transport_preference = self.transport_preference; // Spawn connection task let mut tasks = self.tasks.lock().await; tasks.spawn(async move { log::debug!("Attempting to connect to {}", addr); - match Peer::connect(addr, CONNECTION_TIMEOUT.as_secs(), network).await { + match Peer::connect(addr, CONNECTION_TIMEOUT.as_secs(), network, transport_preference) + .await + { Ok(mut peer) => { // Perform handshake let mut handshake_manager = @@ -1069,6 +1076,7 @@ impl Clone for PeerNetworkManager { user_agent: self.user_agent.clone(), exclusive_mode: self.exclusive_mode, connected_peer_count: self.connected_peer_count.clone(), + transport_preference: self.transport_preference, } } } diff --git a/dash-spv/src/network/mod.rs b/dash-spv/src/network/mod.rs index 89e8bde78..2b358c8e9 100644 --- a/dash-spv/src/network/mod.rs +++ b/dash-spv/src/network/mod.rs @@ -9,6 +9,7 @@ pub mod peer; pub mod persist; pub mod pool; pub mod reputation; +pub mod transport; #[cfg(test)] mod tests; @@ -25,6 +26,7 @@ use dashcore::BlockHash; pub use handshake::{HandshakeManager, HandshakeState}; pub use manager::PeerNetworkManager; pub use peer::Peer; +pub use transport::{Transport, TransportPreference, V1Transport}; /// Network manager trait for abstracting network operations. #[async_trait] diff --git a/dash-spv/src/network/peer.rs b/dash-spv/src/network/peer.rs index 1147a663b..cb0107c4c 100644 --- a/dash-spv/src/network/peer.rs +++ b/dash-spv/src/network/peer.rs @@ -2,36 +2,26 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::Arc; use std::time::{Duration, SystemTime}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; -use tokio::sync::Mutex; -use dashcore::consensus::{encode, Decodable}; -use dashcore::network::message::{NetworkMessage, RawNetworkMessage}; +use dashcore::network::message::NetworkMessage; use dashcore::Network; use crate::error::{NetworkError, NetworkResult}; use crate::network::constants::PING_INTERVAL; +use crate::network::transport::{ + Transport, TransportPreference, V1Transport, V2HandshakeManager, V2HandshakeResult, V2Transport, +}; use crate::types::PeerInfo; -/// Internal state for the TCP connection -struct ConnectionState { - stream: TcpStream, - // Stateful message framing buffer to ensure full frames before decoding - framing_buffer: Vec, -} - /// Dash P2P peer pub struct Peer { address: SocketAddr, - // Use a single mutex to protect both the write stream and read buffer - // This ensures no concurrent access to the underlying socket - state: Option>>, + /// The transport layer (V1 or V2) + transport: Option>, timeout: Duration, connected_at: Option, - bytes_sent: u64, network: Network, // Ping/pong state last_ping_sent: Option, @@ -45,8 +35,8 @@ pub struct Peer { relay: Option, prefers_headers2: bool, sent_sendheaders2: bool, - // Basic telemetry for resync events - consecutive_resyncs: u32, + // Transport protocol version used (1 or 2) + transport_version: u8, } impl Peer { @@ -54,14 +44,19 @@ impl Peer { pub fn address(&self) -> SocketAddr { self.address } - /// Create a new peer. + + /// Get the transport protocol version (1 = unencrypted, 2 = BIP324 encrypted). + pub fn transport_version(&self) -> u8 { + self.transport_version + } + + /// Create a new peer (not connected). pub fn new(address: SocketAddr, timeout: Duration, network: Network) -> Self { Self { address, - state: None, + transport: None, timeout, connected_at: None, - bytes_sent: 0, network, last_ping_sent: None, last_pong_received: None, @@ -73,42 +68,62 @@ impl Peer { relay: None, prefers_headers2: false, sent_sendheaders2: false, - consecutive_resyncs: 0, + transport_version: 1, } } - /// Connect to a peer and return a connected instance. + /// Connect to a peer with the specified transport preference. + /// + /// # Arguments + /// * `address` - The peer's socket address + /// * `timeout_secs` - Connection timeout in seconds + /// * `network` - The Dash network (mainnet, testnet, etc.) + /// * `transport_pref` - V1Only, V2Only, or V2Preferred (default) + /// + /// # Returns + /// A connected Peer instance using the appropriate transport. pub async fn connect( address: SocketAddr, timeout_secs: u64, network: Network, + transport_pref: TransportPreference, ) -> NetworkResult { let timeout = Duration::from_secs(timeout_secs); - let stream = tokio::time::timeout(timeout, TcpStream::connect(address)) - .await - .map_err(|_| { - NetworkError::ConnectionFailed(format!("Connection to {} timed out", address)) - })? - .map_err(|e| { - NetworkError::ConnectionFailed(format!("Failed to connect to {}: {}", address, e)) - })?; - - stream.set_nodelay(true).map_err(|e| { - NetworkError::ConnectionFailed(format!("Failed to set TCP_NODELAY: {}", e)) - })?; - - let state = ConnectionState { - stream, - framing_buffer: Vec::new(), + let (transport, transport_version): (Box, u8) = match transport_pref { + TransportPreference::V1Only => { + tracing::info!("Connecting to {} using V1 transport (unencrypted)", address); + let transport = Self::establish_v1_transport(address, timeout, network).await?; + (Box::new(transport), 1) + } + TransportPreference::V2Only => { + tracing::info!( + "Connecting to {} using V2 transport (BIP324 encrypted, no fallback)", + address + ); + let transport = Self::establish_v2_transport(address, timeout, network).await?; + (Box::new(transport), 2) + } + TransportPreference::V2Preferred => { + tracing::info!( + "Connecting to {} using V2 transport (BIP324 encrypted, with V1 fallback)", + address + ); + Self::try_v2_with_fallback(address, timeout, network).await? + } }; + tracing::info!( + "Successfully connected to {} using V{} transport", + address, + transport_version + ); + Ok(Self { address, - state: Some(Arc::new(Mutex::new(state))), + transport: Some(transport), timeout, connected_at: Some(SystemTime::now()), - bytes_sent: 0, network, last_ping_sent: None, last_pong_received: None, @@ -120,49 +135,142 @@ impl Peer { relay: None, prefers_headers2: false, sent_sendheaders2: false, - consecutive_resyncs: 0, + transport_version, }) } - /// Connect to the peer (instance method for compatibility). - pub async fn connect_instance(&mut self) -> NetworkResult<()> { - let stream = tokio::time::timeout(self.timeout, TcpStream::connect(self.address)) + /// Establish a V1 (unencrypted) transport connection. + async fn establish_v1_transport( + address: SocketAddr, + timeout: Duration, + network: Network, + ) -> NetworkResult { + let stream = tokio::time::timeout(timeout, TcpStream::connect(address)) .await .map_err(|_| { - NetworkError::ConnectionFailed(format!("Connection to {} timed out", self.address)) + NetworkError::ConnectionFailed(format!("Connection to {} timed out", address)) + })? + .map_err(|e| { + NetworkError::ConnectionFailed(format!("Failed to connect to {}: {}", address, e)) + })?; + + stream.set_nodelay(true).map_err(|e| { + NetworkError::ConnectionFailed(format!("Failed to set TCP_NODELAY: {}", e)) + })?; + + Ok(V1Transport::new(stream, network, address)) + } + + /// Establish a V2 (BIP324 encrypted) transport connection. + /// Fails if peer doesn't support V2. + async fn establish_v2_transport( + address: SocketAddr, + timeout: Duration, + network: Network, + ) -> NetworkResult { + let stream = tokio::time::timeout(timeout, TcpStream::connect(address)) + .await + .map_err(|_| { + NetworkError::ConnectionFailed(format!("Connection to {} timed out", address)) })? .map_err(|e| { - NetworkError::ConnectionFailed(format!( - "Failed to connect to {}: {}", - self.address, e - )) + NetworkError::ConnectionFailed(format!("Failed to connect to {}: {}", address, e)) })?; - // Disable Nagle's algorithm for lower latency stream.set_nodelay(true).map_err(|e| { NetworkError::ConnectionFailed(format!("Failed to set TCP_NODELAY: {}", e)) })?; - let state = ConnectionState { - stream, - framing_buffer: Vec::new(), + let handshake_manager = V2HandshakeManager::new_initiator(network, address); + match handshake_manager.perform_handshake(stream).await? { + V2HandshakeResult::Success(session) => { + Ok(V2Transport::new(session.stream, session.cipher, session.session_id, address)) + } + V2HandshakeResult::FallbackToV1 => Err(NetworkError::V2NotSupported), + } + } + + /// Try V2 transport first, fall back to V1 if peer doesn't support V2. + async fn try_v2_with_fallback( + address: SocketAddr, + timeout: Duration, + network: Network, + ) -> NetworkResult<(Box, u8)> { + // First, try to establish TCP connection + let stream = tokio::time::timeout(timeout, TcpStream::connect(address)) + .await + .map_err(|_| { + NetworkError::ConnectionFailed(format!("Connection to {} timed out", address)) + })? + .map_err(|e| { + NetworkError::ConnectionFailed(format!("Failed to connect to {}: {}", address, e)) + })?; + + stream.set_nodelay(true).map_err(|e| { + NetworkError::ConnectionFailed(format!("Failed to set TCP_NODELAY: {}", e)) + })?; + + // Try V2 handshake + let handshake_manager = V2HandshakeManager::new_initiator(network, address); + match handshake_manager.perform_handshake(stream).await { + Ok(V2HandshakeResult::Success(session)) => { + tracing::info!("V2 handshake succeeded with {}", address); + let transport = + V2Transport::new(session.stream, session.cipher, session.session_id, address); + Ok((Box::new(transport), 2)) + } + Ok(V2HandshakeResult::FallbackToV1) => { + tracing::info!( + "V2 handshake detected V1-only peer {}, reconnecting with V1 transport", + address + ); + // Need to reconnect since the stream was consumed + let transport = Self::establish_v1_transport(address, timeout, network).await?; + Ok((Box::new(transport), 1)) + } + Err(e) => { + tracing::warn!("V2 handshake failed with {}: {}, falling back to V1", address, e); + // Try V1 as fallback + let transport = Self::establish_v1_transport(address, timeout, network).await?; + Ok((Box::new(transport), 1)) + } + } + } + + /// Connect to the peer (instance method for compatibility). + pub async fn connect_instance( + &mut self, + transport_pref: TransportPreference, + ) -> NetworkResult<()> { + let (transport, transport_version): (Box, u8) = match transport_pref { + TransportPreference::V1Only => { + let t = + Self::establish_v1_transport(self.address, self.timeout, self.network).await?; + (Box::new(t), 1) + } + TransportPreference::V2Only => { + let t = + Self::establish_v2_transport(self.address, self.timeout, self.network).await?; + (Box::new(t), 2) + } + TransportPreference::V2Preferred => { + Self::try_v2_with_fallback(self.address, self.timeout, self.network).await? + } }; - self.state = Some(Arc::new(Mutex::new(state))); + self.transport = Some(transport); + self.transport_version = transport_version; self.connected_at = Some(SystemTime::now()); - tracing::info!("Connected to peer {}", self.address); + tracing::info!("Connected to peer {} using V{} transport", self.address, transport_version); Ok(()) } /// Disconnect from the peer. pub async fn disconnect(&mut self) -> NetworkResult<()> { - if let Some(state_arc) = self.state.take() { - if let Ok(state_mutex) = Arc::try_unwrap(state_arc) { - let mut state = state_mutex.into_inner(); - let _ = state.stream.shutdown().await; - } + if let Some(mut transport) = self.transport.take() { + transport.shutdown().await?; } self.connected_at = None; @@ -272,372 +380,28 @@ impl Peer { ); } - /// Helper function to read some bytes into the framing buffer. - async fn read_some(state: &mut ConnectionState) -> std::io::Result { - let mut tmp = [0u8; 8192]; - match state.stream.read(&mut tmp).await { - Ok(0) => Ok(0), - Ok(n) => { - state.framing_buffer.extend_from_slice(&tmp[..n]); - Ok(n) - } - Err(e) => Err(e), - } - } - /// Send a message to the peer. pub async fn send_message(&mut self, message: NetworkMessage) -> NetworkResult<()> { - let state_arc = self - .state - .as_ref() + let transport = self + .transport + .as_mut() .ok_or_else(|| NetworkError::ConnectionFailed("Not connected".to_string()))?; - let raw_message = RawNetworkMessage { - magic: self.network.magic(), - payload: message, - }; - - let serialized = encode::serialize(&raw_message); - - // Log details for debugging headers2 issues - if matches!( - raw_message.payload, - NetworkMessage::GetHeaders2(_) | NetworkMessage::GetHeaders(_) - ) { - let msg_type = match raw_message.payload { - NetworkMessage::GetHeaders2(_) => "GetHeaders2", - NetworkMessage::GetHeaders(_) => "GetHeaders", - _ => "Unknown", - }; - tracing::debug!( - "Sending {} raw bytes (len={}): {:02x?}", - msg_type, - serialized.len(), - &serialized[..std::cmp::min(100, serialized.len())] - ); - } - - // Lock the state for the entire write operation - let mut state = state_arc.lock().await; - - // Write with error handling - match state.stream.write_all(&serialized).await { - Ok(_) => { - // Flush to ensure data is sent immediately - if let Err(e) = state.stream.flush().await { - tracing::warn!("Failed to flush socket {}: {}", self.address, e); - } - self.bytes_sent += serialized.len() as u64; - tracing::debug!("Sent message to {}: {:?}", self.address, raw_message.payload); - Ok(()) - } - Err(e) => { - tracing::warn!("Disconnecting {} due to write error: {}", self.address, e); - // Drop the lock before clearing connection state - drop(state); - // Clear connection state on write error - self.state = None; - self.connected_at = None; - Err(NetworkError::ConnectionFailed(format!("Write failed: {}", e))) - } - } + transport.send_message(message).await } /// Receive a message from the peer. pub async fn receive_message(&mut self) -> NetworkResult> { - // First check if we have a state - let state_arc = self - .state - .as_ref() + let transport = self + .transport + .as_mut() .ok_or_else(|| NetworkError::ConnectionFailed("Not connected".to_string()))?; - // Lock the state for the entire read operation - // This ensures no concurrent access to the socket - let mut state = state_arc.lock().await; - - // Buffered, stateful framing - const HEADER_LEN: usize = 24; // magic[4] + cmd[12] + length[4] + checksum[4] - const MAX_RESYNC_STEPS_PER_CALL: usize = 64; - - let result = async { - let magic_bytes = self.network.magic().to_le_bytes(); - let mut resync_steps = 0usize; - - loop { - // Ensure header availability - if state.framing_buffer.len() < HEADER_LEN { - match Self::read_some(&mut state).await { - Ok(0) => { - tracing::info!("Peer {} closed connection (EOF)", self.address); - return Err(NetworkError::PeerDisconnected); - } - Ok(_) => {} - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - return Ok(None); - } - Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { - return Ok(None); - } - Err(ref e) - if e.kind() == std::io::ErrorKind::ConnectionAborted - || e.kind() == std::io::ErrorKind::ConnectionReset => - { - tracing::info!("Peer {} connection reset/aborted", self.address); - return Err(NetworkError::PeerDisconnected); - } - Err(e) => { - return Err(NetworkError::ConnectionFailed(format!( - "Read failed: {}", - e - ))); - } - } - } - - // Align to magic - if state.framing_buffer.len() >= 4 && state.framing_buffer[..4] != magic_bytes { - if let Some(pos) = - state.framing_buffer.windows(4).position(|w| w == magic_bytes) - { - if pos > 0 { - tracing::warn!( - "{}: stream desync: skipping {} stray bytes before magic", - self.address, - pos - ); - self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); - state.framing_buffer.drain(0..pos); - resync_steps += 1; - if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { - return Ok(None); - } - continue; - } - } else { - // Keep last 3 bytes of potential magic prefix - if state.framing_buffer.len() > 3 { - let dropped = state.framing_buffer.len() - 3; - tracing::warn!( - "{}: stream desync: dropping {} bytes (no magic found)", - self.address, - dropped - ); - self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); - state.framing_buffer.drain(0..dropped); - resync_steps += 1; - if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { - return Ok(None); - } - } - // Need more data - match Self::read_some(&mut state).await { - Ok(0) => { - tracing::info!("Peer {} closed connection (EOF)", self.address); - return Err(NetworkError::PeerDisconnected); - } - Ok(_) => {} - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - return Ok(None); - } - Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { - return Ok(None); - } - Err(e) => { - return Err(NetworkError::ConnectionFailed(format!( - "Read failed: {}", - e - ))); - } - } - continue; - } - } - - // Ensure full header - if state.framing_buffer.len() < HEADER_LEN { - match Self::read_some(&mut state).await { - Ok(0) => { - tracing::info!("Peer {} closed connection (EOF)", self.address); - return Err(NetworkError::PeerDisconnected); - } - Ok(_) => {} - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - return Ok(None); - } - Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { - return Ok(None); - } - Err(e) => { - return Err(NetworkError::ConnectionFailed(format!( - "Read failed: {}", - e - ))); - } - } - continue; - } - - // Parse header fields - let length_le = u32::from_le_bytes([ - state.framing_buffer[16], - state.framing_buffer[17], - state.framing_buffer[18], - state.framing_buffer[19], - ]) as usize; - let header_checksum = [ - state.framing_buffer[20], - state.framing_buffer[21], - state.framing_buffer[22], - state.framing_buffer[23], - ]; - // Validate announced length to prevent unbounded accumulation or overflow - if length_le > dashcore::network::message::MAX_MSG_SIZE { - return Err(NetworkError::ProtocolError(format!( - "Declared payload length {} exceeds MAX_MSG_SIZE {}", - length_le, - dashcore::network::message::MAX_MSG_SIZE - ))); - } - let total_len = match HEADER_LEN.checked_add(length_le) { - Some(v) => v, - None => { - return Err(NetworkError::ProtocolError( - "Message length overflow".to_string(), - )); - } - }; - - // Ensure full frame available - if state.framing_buffer.len() < total_len { - match Self::read_some(&mut state).await { - Ok(0) => { - tracing::info!("Peer {} closed connection (EOF)", self.address); - return Err(NetworkError::PeerDisconnected); - } - Ok(_) => {} - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - return Ok(None); - } - Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { - return Ok(None); - } - Err(e) => { - return Err(NetworkError::ConnectionFailed(format!( - "Read failed: {}", - e - ))); - } - } - continue; - } - - // Verify checksum - let payload_slice = &state.framing_buffer[HEADER_LEN..total_len]; - let expected = { - let checksum = ::hash( - payload_slice, - ); - [checksum[0], checksum[1], checksum[2], checksum[3]] - }; - if expected != header_checksum { - tracing::warn!( - "Skipping message with invalid checksum from {}: expected {:02x?}, actual {:02x?}", - self.address, - expected, - header_checksum - ); - if header_checksum == [0, 0, 0, 0] { - tracing::warn!( - "All-zeros checksum detected from {}, likely corrupted stream - resyncing", - self.address - ); - } - // Resync by dropping a byte and retrying - state.framing_buffer.drain(0..1); - self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); - resync_steps += 1; - if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { - return Ok(None); - } - continue; - } - - // Decode full RawNetworkMessage from the frame using existing decoder - let mut cursor = std::io::Cursor::new(&state.framing_buffer[..total_len]); - match RawNetworkMessage::consensus_decode(&mut cursor) { - Ok(raw_message) => { - // Consume bytes - state.framing_buffer.drain(0..total_len); - self.consecutive_resyncs = 0; - - // Validate magic matches our network - if raw_message.magic != self.network.magic() { - tracing::warn!( - "Received message with wrong magic bytes: expected {:#x}, got {:#x}", - self.network.magic(), - raw_message.magic - ); - return Err(NetworkError::ProtocolError(format!( - "Wrong magic bytes: expected {:#x}, got {:#x}", - self.network.magic(), - raw_message.magic - ))); - } - - tracing::trace!( - "Successfully decoded message from {}: {:?}", - self.address, - raw_message.payload.cmd() - ); - - if raw_message.payload.cmd() == "headers2" { - tracing::info!("🎉 Received Headers2 message from {}!", self.address); - } - - if let NetworkMessage::Block(ref block) = raw_message.payload { - let block_hash = block.block_hash(); - tracing::info!( - "Successfully decoded block {} from {}", - block_hash, - self.address - ); - } - - if let NetworkMessage::Headers2(ref headers2) = raw_message.payload { - tracing::info!( - "Successfully decoded Headers2 message from {} with {} compressed headers", - self.address, - headers2.headers.len() - ); - } - - return Ok(Some(raw_message.payload)); - } - Err(e) => { - tracing::warn!( - "{}: decode error after framing ({}), attempting resync", - self.address, - e - ); - state.framing_buffer.drain(0..1); - self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); - resync_steps += 1; - if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { - return Ok(None); - } - continue; - } - } - } - } - .await; + let result = transport.receive_message().await; - // Drop the lock before disconnecting - drop(state); - - // Handle disconnection if needed + // Handle disconnection if let Err(NetworkError::PeerDisconnected) = &result { - self.state = None; + self.transport = None; self.connected_at = None; } @@ -646,7 +410,7 @@ impl Peer { /// Check if the connection is active. pub fn is_connected(&self) -> bool { - self.state.is_some() + self.transport.as_ref().map(|t| t.is_connected()).unwrap_or(false) } /// Check if connection appears healthy (not just connected). @@ -701,7 +465,11 @@ impl Peer { /// Get connection statistics. pub fn stats(&self) -> (u64, u64) { - (self.bytes_sent, 0) // TODO: Track bytes received + if let Some(transport) = &self.transport { + (transport.bytes_sent(), transport.bytes_received()) + } else { + (0, 0) + } } /// Send a ping message with a random nonce. diff --git a/dash-spv/src/network/transport/message_ids.rs b/dash-spv/src/network/transport/message_ids.rs new file mode 100644 index 000000000..874238fa0 --- /dev/null +++ b/dash-spv/src/network/transport/message_ids.rs @@ -0,0 +1,334 @@ +//! BIP324 short message IDs for Dash. +//! +//! BIP324 uses 1-byte short IDs for common messages to reduce bandwidth. +//! Less common messages use extended format: 0x00 + 12-byte ASCII command. +//! +//! Dash extends BIP324 with its own message IDs in the 128-255 range: +//! - IDs 0-32: Standard BIP324 (Bitcoin) messages +//! - IDs 128-167: Dash-specific messages +//! +//! ## Design Notes +//! +//! There's intentional asymmetry between the two main functions: +//! - `short_id_to_command`: Handles ALL short IDs (for receiving messages) +//! - `network_message_to_short_id`: Only handles NetworkMessage variants that exist +//! +//! This means we can decode incoming messages with short IDs even if dashcore +//! doesn't have a dedicated NetworkMessage variant for them (they'll be decoded +//! as Unknown messages via the extended format fallback in decode_by_command). + +use dashcore::network::message::NetworkMessage; + +/// Extended message marker (12-byte ASCII command follows). +pub const MSG_ID_EXTENDED: u8 = 0; + +// ============================================================================= +// Standard BIP324 short message IDs (1-28) +// Matches Dash Core's V2_BITCOIN_IDS array +// ============================================================================= +pub const MSG_ID_ADDR: u8 = 1; +pub const MSG_ID_BLOCK: u8 = 2; +pub const MSG_ID_BLOCKTXN: u8 = 3; +pub const MSG_ID_CMPCTBLOCK: u8 = 4; +// ID 5 is reserved for FEEFILTER but not implemented in Dash +pub const MSG_ID_FILTERADD: u8 = 6; +pub const MSG_ID_FILTERCLEAR: u8 = 7; +pub const MSG_ID_FILTERLOAD: u8 = 8; +pub const MSG_ID_GETBLOCKS: u8 = 9; +pub const MSG_ID_GETBLOCKTXN: u8 = 10; +pub const MSG_ID_GETDATA: u8 = 11; +pub const MSG_ID_GETHEADERS: u8 = 12; +pub const MSG_ID_HEADERS: u8 = 13; +pub const MSG_ID_INV: u8 = 14; +pub const MSG_ID_MEMPOOL: u8 = 15; +pub const MSG_ID_MERKLEBLOCK: u8 = 16; +pub const MSG_ID_NOTFOUND: u8 = 17; +pub const MSG_ID_PING: u8 = 18; +pub const MSG_ID_PONG: u8 = 19; +pub const MSG_ID_SENDCMPCT: u8 = 20; +pub const MSG_ID_TX: u8 = 21; +pub const MSG_ID_GETCFILTERS: u8 = 22; +pub const MSG_ID_CFILTER: u8 = 23; +pub const MSG_ID_GETCFHEADERS: u8 = 24; +pub const MSG_ID_CFHEADERS: u8 = 25; +pub const MSG_ID_GETCFCHECKPT: u8 = 26; +pub const MSG_ID_CFCHECKPT: u8 = 27; +pub const MSG_ID_ADDRV2: u8 = 28; +// IDs 29-32 are reserved but unimplemented in BIP324 + +// ============================================================================= +// Dash-specific short message IDs (128-167) +// Matches Dash Core's V2_DASH_IDS array +// ============================================================================= +pub const MSG_ID_SPORK: u8 = 128; +pub const MSG_ID_GETSPORKS: u8 = 129; +pub const MSG_ID_SENDDSQUEUE: u8 = 130; +pub const MSG_ID_DSACCEPT: u8 = 131; +pub const MSG_ID_DSVIN: u8 = 132; +pub const MSG_ID_DSFINALTX: u8 = 133; +pub const MSG_ID_DSSIGNFINALTX: u8 = 134; +pub const MSG_ID_DSCOMPLETE: u8 = 135; +pub const MSG_ID_DSSTATUSUPDATE: u8 = 136; +pub const MSG_ID_DSTX: u8 = 137; +pub const MSG_ID_DSQUEUE: u8 = 138; +pub const MSG_ID_SYNCSTATUSCOUNT: u8 = 139; +pub const MSG_ID_MNGOVERNANCESYNC: u8 = 140; +pub const MSG_ID_MNGOVERNANCEOBJECT: u8 = 141; +pub const MSG_ID_MNGOVERNANCEOBJECTVOTE: u8 = 142; +pub const MSG_ID_GETMNLISTDIFF: u8 = 143; +pub const MSG_ID_MNLISTDIFF: u8 = 144; +pub const MSG_ID_QSENDRECSIGS: u8 = 145; +pub const MSG_ID_QFCOMMITMENT: u8 = 146; +pub const MSG_ID_QCONTRIB: u8 = 147; +pub const MSG_ID_QCOMPLAINT: u8 = 148; +pub const MSG_ID_QJUSTIFICATION: u8 = 149; +pub const MSG_ID_QPCOMMITMENT: u8 = 150; +pub const MSG_ID_QWATCH: u8 = 151; +pub const MSG_ID_QSIGSESANN: u8 = 152; +pub const MSG_ID_QSIGSHARESINV: u8 = 153; +pub const MSG_ID_QGETSIGSHARES: u8 = 154; +pub const MSG_ID_QBSIGSHARES: u8 = 155; +pub const MSG_ID_QSIGREC: u8 = 156; +pub const MSG_ID_QSIGSHARE: u8 = 157; +pub const MSG_ID_QGETDATA: u8 = 158; +pub const MSG_ID_QDATA: u8 = 159; +pub const MSG_ID_CLSIG: u8 = 160; +pub const MSG_ID_ISDLOCK: u8 = 161; +pub const MSG_ID_MNAUTH: u8 = 162; +pub const MSG_ID_GETHEADERS2: u8 = 163; +pub const MSG_ID_SENDHEADERS2: u8 = 164; +pub const MSG_ID_HEADERS2: u8 = 165; +pub const MSG_ID_GETQUORUMROTATIONINFO: u8 = 166; +pub const MSG_ID_QUORUMROTATIONINFO: u8 = 167; + +/// Get the short message ID for a NetworkMessage, if one exists. +/// +/// Returns `Some(id)` for common messages that have short IDs, +/// or `None` for messages that require extended format. +pub fn network_message_to_short_id(msg: &NetworkMessage) -> Option { + match msg { + // Standard BIP324 messages + NetworkMessage::Addr(_) => Some(MSG_ID_ADDR), + NetworkMessage::Block(_) => Some(MSG_ID_BLOCK), + NetworkMessage::BlockTxn(_) => Some(MSG_ID_BLOCKTXN), + NetworkMessage::CmpctBlock(_) => Some(MSG_ID_CMPCTBLOCK), + // Note: FeeFilter is ID 5 in BIP324 but not implemented in Dash + NetworkMessage::FilterAdd(_) => Some(MSG_ID_FILTERADD), + NetworkMessage::FilterClear => Some(MSG_ID_FILTERCLEAR), + NetworkMessage::FilterLoad(_) => Some(MSG_ID_FILTERLOAD), + NetworkMessage::GetBlocks(_) => Some(MSG_ID_GETBLOCKS), + NetworkMessage::GetBlockTxn(_) => Some(MSG_ID_GETBLOCKTXN), + NetworkMessage::GetData(_) => Some(MSG_ID_GETDATA), + NetworkMessage::GetHeaders(_) => Some(MSG_ID_GETHEADERS), + NetworkMessage::Headers(_) => Some(MSG_ID_HEADERS), + NetworkMessage::Inv(_) => Some(MSG_ID_INV), + NetworkMessage::MemPool => Some(MSG_ID_MEMPOOL), + NetworkMessage::MerkleBlock(_) => Some(MSG_ID_MERKLEBLOCK), + NetworkMessage::NotFound(_) => Some(MSG_ID_NOTFOUND), + NetworkMessage::Ping(_) => Some(MSG_ID_PING), + NetworkMessage::Pong(_) => Some(MSG_ID_PONG), + NetworkMessage::SendCmpct(_) => Some(MSG_ID_SENDCMPCT), + NetworkMessage::Tx(_) => Some(MSG_ID_TX), + NetworkMessage::GetCFilters(_) => Some(MSG_ID_GETCFILTERS), + NetworkMessage::CFilter(_) => Some(MSG_ID_CFILTER), + NetworkMessage::GetCFHeaders(_) => Some(MSG_ID_GETCFHEADERS), + NetworkMessage::CFHeaders(_) => Some(MSG_ID_CFHEADERS), + NetworkMessage::GetCFCheckpt(_) => Some(MSG_ID_GETCFCHECKPT), + NetworkMessage::CFCheckpt(_) => Some(MSG_ID_CFCHECKPT), + NetworkMessage::AddrV2(_) => Some(MSG_ID_ADDRV2), + + // Dash-specific messages (only variants that exist in dashcore) + NetworkMessage::SendDsq(_) => Some(MSG_ID_SENDDSQUEUE), + NetworkMessage::GetMnListD(_) => Some(MSG_ID_GETMNLISTDIFF), + NetworkMessage::MnListDiff(_) => Some(MSG_ID_MNLISTDIFF), + NetworkMessage::CLSig(_) => Some(MSG_ID_CLSIG), + NetworkMessage::ISLock(_) => Some(MSG_ID_ISDLOCK), + NetworkMessage::GetHeaders2(_) => Some(MSG_ID_GETHEADERS2), + NetworkMessage::SendHeaders2 => Some(MSG_ID_SENDHEADERS2), + NetworkMessage::Headers2(_) => Some(MSG_ID_HEADERS2), + NetworkMessage::GetQRInfo(_) => Some(MSG_ID_GETQUORUMROTATIONINFO), + NetworkMessage::QRInfo(_) => Some(MSG_ID_QUORUMROTATIONINFO), + + // All other messages use extended format + _ => None, + } +} + +/// Get the command string for a short message ID. +/// +/// Returns `Some(command)` for valid short IDs, +/// or `None` for unknown IDs. +pub fn short_id_to_command(id: u8) -> Option<&'static str> { + match id { + // Standard BIP324 messages + MSG_ID_ADDR => Some("addr"), + MSG_ID_BLOCK => Some("block"), + MSG_ID_BLOCKTXN => Some("blocktxn"), + MSG_ID_CMPCTBLOCK => Some("cmpctblock"), + MSG_ID_FILTERADD => Some("filteradd"), + MSG_ID_FILTERCLEAR => Some("filterclear"), + MSG_ID_FILTERLOAD => Some("filterload"), + MSG_ID_GETBLOCKS => Some("getblocks"), + MSG_ID_GETBLOCKTXN => Some("getblocktxn"), + MSG_ID_GETDATA => Some("getdata"), + MSG_ID_GETHEADERS => Some("getheaders"), + MSG_ID_HEADERS => Some("headers"), + MSG_ID_INV => Some("inv"), + MSG_ID_MEMPOOL => Some("mempool"), + MSG_ID_MERKLEBLOCK => Some("merkleblock"), + MSG_ID_NOTFOUND => Some("notfound"), + MSG_ID_PING => Some("ping"), + MSG_ID_PONG => Some("pong"), + MSG_ID_SENDCMPCT => Some("sendcmpct"), + MSG_ID_TX => Some("tx"), + MSG_ID_GETCFILTERS => Some("getcfilters"), + MSG_ID_CFILTER => Some("cfilter"), + MSG_ID_GETCFHEADERS => Some("getcfheaders"), + MSG_ID_CFHEADERS => Some("cfheaders"), + MSG_ID_GETCFCHECKPT => Some("getcfcheckpt"), + MSG_ID_CFCHECKPT => Some("cfcheckpt"), + MSG_ID_ADDRV2 => Some("addrv2"), + + // Dash-specific messages + MSG_ID_SPORK => Some("spork"), + MSG_ID_GETSPORKS => Some("getsporks"), + MSG_ID_SENDDSQUEUE => Some("senddsq"), + MSG_ID_DSACCEPT => Some("dsa"), + MSG_ID_DSVIN => Some("dsi"), + MSG_ID_DSFINALTX => Some("dsf"), + MSG_ID_DSSIGNFINALTX => Some("dss"), + MSG_ID_DSCOMPLETE => Some("dsc"), + MSG_ID_DSSTATUSUPDATE => Some("dssu"), + MSG_ID_DSTX => Some("dstx"), + MSG_ID_DSQUEUE => Some("dsq"), + MSG_ID_SYNCSTATUSCOUNT => Some("ssc"), + MSG_ID_MNGOVERNANCESYNC => Some("govsync"), + MSG_ID_MNGOVERNANCEOBJECT => Some("govobj"), + MSG_ID_MNGOVERNANCEOBJECTVOTE => Some("govobjvote"), + MSG_ID_GETMNLISTDIFF => Some("getmnlistd"), + MSG_ID_MNLISTDIFF => Some("mnlistdiff"), + MSG_ID_QSENDRECSIGS => Some("qsendrecsigs"), + MSG_ID_QFCOMMITMENT => Some("qfcommit"), + MSG_ID_QCONTRIB => Some("qcontrib"), + MSG_ID_QCOMPLAINT => Some("qcomplaint"), + MSG_ID_QJUSTIFICATION => Some("qjustify"), + MSG_ID_QPCOMMITMENT => Some("qpcommit"), + MSG_ID_QWATCH => Some("qwatch"), + MSG_ID_QSIGSESANN => Some("qsigsesann"), + MSG_ID_QSIGSHARESINV => Some("qsigsinv"), + MSG_ID_QGETSIGSHARES => Some("qgetsigs"), + MSG_ID_QBSIGSHARES => Some("qbsigs"), + MSG_ID_QSIGREC => Some("qsigrec"), + MSG_ID_QSIGSHARE => Some("qsigshare"), + MSG_ID_QGETDATA => Some("qgetdata"), + MSG_ID_QDATA => Some("qdata"), + MSG_ID_CLSIG => Some("clsig"), + MSG_ID_ISDLOCK => Some("isdlock"), + MSG_ID_MNAUTH => Some("mnauth"), + MSG_ID_GETHEADERS2 => Some("getheaders2"), + MSG_ID_SENDHEADERS2 => Some("sendheaders2"), + MSG_ID_HEADERS2 => Some("headers2"), + MSG_ID_GETQUORUMROTATIONINFO => Some("getqrinfo"), + MSG_ID_QUORUMROTATIONINFO => Some("qrinfo"), + + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ping_pong_ids() { + assert_eq!(network_message_to_short_id(&NetworkMessage::Ping(0)), Some(MSG_ID_PING)); + assert_eq!(network_message_to_short_id(&NetworkMessage::Pong(0)), Some(MSG_ID_PONG)); + } + + #[test] + fn test_short_id_to_command() { + assert_eq!(short_id_to_command(MSG_ID_PING), Some("ping")); + assert_eq!(short_id_to_command(MSG_ID_PONG), Some("pong")); + assert_eq!(short_id_to_command(MSG_ID_BLOCK), Some("block")); + assert_eq!(short_id_to_command(255), None); + } + + #[test] + fn test_dash_short_ids() { + // Test Dash-specific short IDs + assert_eq!(short_id_to_command(MSG_ID_SPORK), Some("spork")); + assert_eq!(short_id_to_command(MSG_ID_SENDDSQUEUE), Some("senddsq")); + assert_eq!(short_id_to_command(MSG_ID_CLSIG), Some("clsig")); + assert_eq!(short_id_to_command(MSG_ID_ISDLOCK), Some("isdlock")); + assert_eq!(short_id_to_command(MSG_ID_MNLISTDIFF), Some("mnlistdiff")); + assert_eq!(short_id_to_command(MSG_ID_HEADERS2), Some("headers2")); + } + + #[test] + fn test_extended_format_for_non_short_id_messages() { + // Version is not a short ID message + use dashcore::network::address::Address; + use dashcore::network::constants::ServiceFlags; + use dashcore::network::message_network::VersionMessage; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + let addr = Address::new( + &SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8333), + ServiceFlags::NONE, + ); + + let version = VersionMessage { + version: 70015, + services: ServiceFlags::NONE, + timestamp: 0, + receiver: addr.clone(), + sender: addr, + nonce: 0, + user_agent: "/test/".to_string(), + start_height: 0, + relay: false, + mn_auth_challenge: [0u8; 32], + masternode_connection: false, + }; + + assert!(network_message_to_short_id(&NetworkMessage::Version(version)).is_none()); + } + + #[test] + fn test_short_id_to_command_bidirectional_consistency() { + // For messages that have short IDs, verify the command string matches + // what dashcore returns via cmd() + let test_cases: Vec<(NetworkMessage, u8)> = vec![ + (NetworkMessage::Ping(0), MSG_ID_PING), + (NetworkMessage::Pong(0), MSG_ID_PONG), + (NetworkMessage::Inv(vec![]), MSG_ID_INV), + (NetworkMessage::GetData(vec![]), MSG_ID_GETDATA), + (NetworkMessage::MemPool, MSG_ID_MEMPOOL), + (NetworkMessage::FilterClear, MSG_ID_FILTERCLEAR), + (NetworkMessage::SendHeaders2, MSG_ID_SENDHEADERS2), + (NetworkMessage::SendDsq(false), MSG_ID_SENDDSQUEUE), + ]; + + for (msg, expected_id) in test_cases { + // Verify network_message_to_short_id returns the expected ID + let short_id = network_message_to_short_id(&msg); + assert_eq!( + short_id, + Some(expected_id), + "Message {} should have short ID {}", + msg.cmd(), + expected_id + ); + + // Verify short_id_to_command returns the correct command + let cmd = short_id_to_command(expected_id); + assert_eq!( + cmd, + Some(msg.cmd()), + "Short ID {} should map to command '{}'", + expected_id, + msg.cmd() + ); + } + } +} diff --git a/dash-spv/src/network/transport/mod.rs b/dash-spv/src/network/transport/mod.rs new file mode 100644 index 000000000..9af05d2fc --- /dev/null +++ b/dash-spv/src/network/transport/mod.rs @@ -0,0 +1,71 @@ +//! Transport layer abstraction for Dash P2P connections. +//! +//! This module provides a `Transport` trait that abstracts the underlying +//! communication protocol (V1 unencrypted or V2 BIP324 encrypted). + +pub mod message_ids; +pub mod v1; +pub mod v2; +pub mod v2_handshake; + +use async_trait::async_trait; +use dashcore::network::message::NetworkMessage; + +use crate::error::NetworkResult; + +pub use v1::V1Transport; +pub use v2::V2Transport; +pub use v2_handshake::{V2HandshakeManager, V2HandshakeResult, V2Session}; + +/// Transport preference for peer connections. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum TransportPreference { + /// Use V2 encrypted transport only (fail if peer doesn't support). + V2Only, + /// Prefer V2 encrypted transport, fallback to V1 if needed (default). + #[default] + V2Preferred, + /// Use V1 unencrypted transport only (for compatibility testing). + V1Only, +} + +/// Abstract transport layer for P2P communication. +/// +/// This trait is implemented by both V1Transport (unencrypted) and +/// V2Transport (BIP324 encrypted) to provide a unified interface +/// for message exchange. +#[async_trait] +pub trait Transport: Send + Sync { + /// Send a network message over the transport. + /// + /// # Arguments + /// * `message` - The network message to send + /// + /// # Returns + /// * `Ok(())` on success + /// * `Err(NetworkError)` on failure + async fn send_message(&mut self, message: NetworkMessage) -> NetworkResult<()>; + + /// Receive a network message from the transport. + /// + /// # Returns + /// * `Ok(Some(message))` if a complete message was received + /// * `Ok(None)` if no complete message is available yet (non-blocking) + /// * `Err(NetworkError)` on failure or disconnection + async fn receive_message(&mut self) -> NetworkResult>; + + /// Check if the transport is connected. + fn is_connected(&self) -> bool; + + /// Get the transport protocol version (1 or 2). + fn protocol_version(&self) -> u8; + + /// Get the number of bytes sent over this transport. + fn bytes_sent(&self) -> u64; + + /// Get the number of bytes received over this transport. + fn bytes_received(&self) -> u64; + + /// Shutdown the transport connection. + async fn shutdown(&mut self) -> NetworkResult<()>; +} diff --git a/dash-spv/src/network/transport/v1.rs b/dash-spv/src/network/transport/v1.rs new file mode 100644 index 000000000..0d331c462 --- /dev/null +++ b/dash-spv/src/network/transport/v1.rs @@ -0,0 +1,478 @@ +//! V1 Transport - Unencrypted Dash P2P protocol transport. +//! +//! This implements the traditional Bitcoin/Dash P2P message framing: +//! - 4 bytes: Network magic +//! - 12 bytes: Command string +//! - 4 bytes: Payload length (little-endian) +//! - 4 bytes: Checksum (first 4 bytes of SHA256d of payload) +//! - Variable: Payload + +use std::net::SocketAddr; + +use async_trait::async_trait; +use dashcore::consensus::{encode, Decodable}; +use dashcore::network::message::{NetworkMessage, RawNetworkMessage, MAX_MSG_SIZE}; +use dashcore::Network; +use dashcore_hashes::Hash; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +use super::Transport; +use crate::error::{NetworkError, NetworkResult}; + +/// Header length for V1 protocol: magic(4) + command(12) + length(4) + checksum(4) +const HEADER_LEN: usize = 24; + +/// Maximum resync steps per receive call to prevent infinite loops. +const MAX_RESYNC_STEPS_PER_CALL: usize = 64; + +/// Read buffer size for TCP reads. +const READ_BUFFER_SIZE: usize = 8192; + +/// V1 Transport implementation for unencrypted P2P communication. +pub struct V1Transport { + /// The underlying TCP stream. + stream: TcpStream, + /// Stateful message framing buffer. + framing_buffer: Vec, + /// Network for magic byte validation. + network: Network, + /// Remote peer address (for logging). + peer_address: SocketAddr, + /// Bytes sent counter. + bytes_sent: u64, + /// Bytes received counter. + bytes_received: u64, + /// Whether the connection is active. + connected: bool, + /// Consecutive resync counter (for telemetry). + consecutive_resyncs: u32, +} + +impl V1Transport { + /// Create a new V1 transport from an established TCP stream. + /// + /// # Arguments + /// * `stream` - An already-connected TCP stream + /// * `network` - The Dash network (for magic byte validation) + /// * `peer_address` - Remote peer address (for logging) + pub fn new(stream: TcpStream, network: Network, peer_address: SocketAddr) -> Self { + Self { + stream, + framing_buffer: Vec::with_capacity(READ_BUFFER_SIZE), + network, + peer_address, + bytes_sent: 0, + bytes_received: 0, + connected: true, + consecutive_resyncs: 0, + } + } + + /// Helper function to read some bytes into the framing buffer. + async fn read_some(&mut self) -> std::io::Result { + let mut tmp = [0u8; READ_BUFFER_SIZE]; + match self.stream.read(&mut tmp).await { + Ok(0) => Ok(0), + Ok(n) => { + self.framing_buffer.extend_from_slice(&tmp[..n]); + self.bytes_received += n as u64; + Ok(n) + } + Err(e) => Err(e), + } + } + + /// Get the consecutive resync count (for telemetry). + pub fn consecutive_resyncs(&self) -> u32 { + self.consecutive_resyncs + } +} + +#[async_trait] +impl Transport for V1Transport { + async fn send_message(&mut self, message: NetworkMessage) -> NetworkResult<()> { + if !self.connected { + return Err(NetworkError::ConnectionFailed("Not connected".to_string())); + } + + let raw_message = RawNetworkMessage { + magic: self.network.magic(), + payload: message, + }; + + let serialized = encode::serialize(&raw_message); + + // Log details for debugging headers2 issues + if matches!( + raw_message.payload, + NetworkMessage::GetHeaders2(_) | NetworkMessage::GetHeaders(_) + ) { + let msg_type = match raw_message.payload { + NetworkMessage::GetHeaders2(_) => "GetHeaders2", + NetworkMessage::GetHeaders(_) => "GetHeaders", + _ => "Unknown", + }; + tracing::debug!( + "V1Transport: Sending {} raw bytes (len={}): {:02x?}", + msg_type, + serialized.len(), + &serialized[..std::cmp::min(100, serialized.len())] + ); + } + + // Write with error handling + match self.stream.write_all(&serialized).await { + Ok(_) => { + // Flush to ensure data is sent immediately + if let Err(e) = self.stream.flush().await { + tracing::warn!( + "V1Transport: Failed to flush socket {}: {}", + self.peer_address, + e + ); + } + self.bytes_sent += serialized.len() as u64; + tracing::debug!( + "V1Transport: Sent message to {}: {:?}", + self.peer_address, + raw_message.payload + ); + Ok(()) + } + Err(e) => { + tracing::warn!( + "V1Transport: Disconnecting {} due to write error: {}", + self.peer_address, + e + ); + self.connected = false; + Err(NetworkError::ConnectionFailed(format!("Write failed: {}", e))) + } + } + } + + async fn receive_message(&mut self) -> NetworkResult> { + if !self.connected { + return Err(NetworkError::ConnectionFailed("Not connected".to_string())); + } + + let magic_bytes = self.network.magic().to_le_bytes(); + let mut resync_steps = 0usize; + + loop { + // Ensure header availability + if self.framing_buffer.len() < HEADER_LEN { + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V1Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(ref e) + if e.kind() == std::io::ErrorKind::ConnectionAborted + || e.kind() == std::io::ErrorKind::ConnectionReset => + { + tracing::info!( + "V1Transport: Peer {} connection reset/aborted", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e))); + } + } + } + + // Align to magic + if self.framing_buffer.len() >= 4 && self.framing_buffer[..4] != magic_bytes { + if let Some(pos) = self.framing_buffer.windows(4).position(|w| w == magic_bytes) { + if pos > 0 { + tracing::warn!( + "V1Transport {}: stream desync: skipping {} stray bytes before magic", + self.peer_address, + pos + ); + self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); + self.framing_buffer.drain(0..pos); + resync_steps += 1; + if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { + return Ok(None); + } + continue; + } + } else { + // Keep last 3 bytes of potential magic prefix + if self.framing_buffer.len() > 3 { + let dropped = self.framing_buffer.len() - 3; + tracing::warn!( + "V1Transport {}: stream desync: dropping {} bytes (no magic found)", + self.peer_address, + dropped + ); + self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); + self.framing_buffer.drain(0..dropped); + resync_steps += 1; + if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { + return Ok(None); + } + } + // Need more data + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V1Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!( + "Read failed: {}", + e + ))); + } + } + continue; + } + } + + // Ensure full header + if self.framing_buffer.len() < HEADER_LEN { + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V1Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e))); + } + } + continue; + } + + // Parse header fields + let length_le = u32::from_le_bytes([ + self.framing_buffer[16], + self.framing_buffer[17], + self.framing_buffer[18], + self.framing_buffer[19], + ]) as usize; + let header_checksum = [ + self.framing_buffer[20], + self.framing_buffer[21], + self.framing_buffer[22], + self.framing_buffer[23], + ]; + + // Validate announced length to prevent unbounded accumulation or overflow + if length_le > MAX_MSG_SIZE { + return Err(NetworkError::ProtocolError(format!( + "Declared payload length {} exceeds MAX_MSG_SIZE {}", + length_le, MAX_MSG_SIZE + ))); + } + let total_len = match HEADER_LEN.checked_add(length_le) { + Some(v) => v, + None => { + return Err(NetworkError::ProtocolError("Message length overflow".to_string())); + } + }; + + // Ensure full frame available + if self.framing_buffer.len() < total_len { + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V1Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e))); + } + } + continue; + } + + // Verify checksum + let payload_slice = &self.framing_buffer[HEADER_LEN..total_len]; + let expected = { + let checksum = dashcore_hashes::sha256d::Hash::hash(payload_slice); + [checksum[0], checksum[1], checksum[2], checksum[3]] + }; + if expected != header_checksum { + tracing::warn!( + "V1Transport: Skipping message with invalid checksum from {}: expected {:02x?}, actual {:02x?}", + self.peer_address, + expected, + header_checksum + ); + if header_checksum == [0, 0, 0, 0] { + tracing::warn!( + "V1Transport: All-zeros checksum detected from {}, likely corrupted stream - resyncing", + self.peer_address + ); + } + // Resync by dropping a byte and retrying + self.framing_buffer.drain(0..1); + self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); + resync_steps += 1; + if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { + return Ok(None); + } + continue; + } + + // Decode full RawNetworkMessage from the frame using existing decoder + let mut cursor = std::io::Cursor::new(&self.framing_buffer[..total_len]); + match RawNetworkMessage::consensus_decode(&mut cursor) { + Ok(raw_message) => { + // Consume bytes + self.framing_buffer.drain(0..total_len); + self.consecutive_resyncs = 0; + + // Validate magic matches our network + if raw_message.magic != self.network.magic() { + tracing::warn!( + "V1Transport: Received message with wrong magic bytes: expected {:#x}, got {:#x}", + self.network.magic(), + raw_message.magic + ); + return Err(NetworkError::ProtocolError(format!( + "Wrong magic bytes: expected {:#x}, got {:#x}", + self.network.magic(), + raw_message.magic + ))); + } + + tracing::trace!( + "V1Transport: Successfully decoded message from {}: {:?}", + self.peer_address, + raw_message.payload.cmd() + ); + + if raw_message.payload.cmd() == "headers2" { + tracing::info!( + "V1Transport: Received Headers2 message from {}!", + self.peer_address + ); + } + + if let NetworkMessage::Block(ref block) = raw_message.payload { + let block_hash = block.block_hash(); + tracing::info!( + "V1Transport: Successfully decoded block {} from {}", + block_hash, + self.peer_address + ); + } + + if let NetworkMessage::Headers2(ref headers2) = raw_message.payload { + tracing::info!( + "V1Transport: Successfully decoded Headers2 message from {} with {} compressed headers", + self.peer_address, + headers2.headers.len() + ); + } + + return Ok(Some(raw_message.payload)); + } + Err(e) => { + tracing::warn!( + "V1Transport {}: decode error after framing ({}), attempting resync", + self.peer_address, + e + ); + self.framing_buffer.drain(0..1); + self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); + resync_steps += 1; + if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { + return Ok(None); + } + continue; + } + } + } + } + + fn is_connected(&self) -> bool { + self.connected + } + + fn protocol_version(&self) -> u8 { + 1 + } + + fn bytes_sent(&self) -> u64 { + self.bytes_sent + } + + fn bytes_received(&self) -> u64 { + self.bytes_received + } + + async fn shutdown(&mut self) -> NetworkResult<()> { + if self.connected { + let _ = self.stream.shutdown().await; + self.connected = false; + tracing::info!("V1Transport: Shutdown connection to {}", self.peer_address); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_header_len() { + // Verify our header length constant is correct + assert_eq!(HEADER_LEN, 4 + 12 + 4 + 4); // magic + command + length + checksum + } +} diff --git a/dash-spv/src/network/transport/v2.rs b/dash-spv/src/network/transport/v2.rs new file mode 100644 index 000000000..0018af23c --- /dev/null +++ b/dash-spv/src/network/transport/v2.rs @@ -0,0 +1,577 @@ +//! V2 Transport - BIP324 encrypted Dash P2P protocol transport. +//! +//! This implements the BIP324 encrypted transport protocol: +//! - 3 bytes: Encrypted length +//! - 1 byte: Header (flags, short message ID or 0x00 for extended) +//! - Variable: Contents (for extended format: 12-byte command + payload) +//! - 16 bytes: Authentication tag (ChaCha20-Poly1305) + +use std::net::SocketAddr; + +use async_trait::async_trait; +use bip324::{CipherSession, PacketType, NUM_LENGTH_BYTES}; +use dashcore::network::message::{NetworkMessage, MAX_MSG_SIZE}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +use super::message_ids::{network_message_to_short_id, short_id_to_command, MSG_ID_EXTENDED}; +use super::Transport; +use crate::error::{NetworkError, NetworkResult}; + +/// Read buffer size for TCP reads. +const READ_BUFFER_SIZE: usize = 8192; + +/// Extended command length in bytes. +const COMMAND_LEN: usize = 12; + +/// V2 Transport implementation for BIP324 encrypted P2P communication. +pub struct V2Transport { + /// The underlying TCP stream. + stream: TcpStream, + /// The cipher session for encryption/decryption. + cipher: CipherSession, + /// Session ID for optional MitM verification. + session_id: [u8; 32], + /// Stateful receive buffer for partial reads. + receive_buffer: Vec, + /// Remote peer address (for logging). + peer_address: SocketAddr, + /// Bytes sent counter. + bytes_sent: u64, + /// Bytes received counter. + bytes_received: u64, + /// Whether the connection is active. + connected: bool, + /// Cached decrypted packet length (to avoid re-decrypting on partial reads). + /// This is needed because `decrypt_packet_len` advances the cipher state. + pending_packet_len: Option, +} + +impl V2Transport { + /// Create a new V2 transport from a successful handshake. + /// + /// # Arguments + /// * `stream` - The TCP stream (ownership transferred from handshake) + /// * `cipher` - The cipher session for encryption/decryption + /// * `session_id` - Session ID for optional MitM verification + /// * `peer_address` - Remote peer address (for logging) + pub fn new( + stream: TcpStream, + cipher: CipherSession, + session_id: [u8; 32], + peer_address: SocketAddr, + ) -> Self { + Self { + stream, + cipher, + session_id, + receive_buffer: Vec::with_capacity(READ_BUFFER_SIZE), + peer_address, + bytes_sent: 0, + bytes_received: 0, + connected: true, + pending_packet_len: None, + } + } + + /// Get the session ID for optional out-of-band MitM verification. + pub fn session_id(&self) -> &[u8; 32] { + &self.session_id + } + + /// Encode a NetworkMessage into V2 plaintext format. + /// + /// Format: + /// - Short format (common messages): payload bytes (header byte added by cipher) + /// - Extended format (Dash-specific): 12-byte command + payload bytes + fn encode_message(&self, message: &NetworkMessage) -> NetworkResult> { + // Serialize the message payload using dashcore's canonical serialization + let payload = message.consensus_encode_payload(); + + // Check for short message ID + if let Some(short_id) = network_message_to_short_id(message) { + // Short format: just the short ID byte followed by payload + // The short ID will be put in the header byte by the cipher + // So we return: [short_id] + payload + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(short_id); + plaintext.extend_from_slice(&payload); + Ok(plaintext) + } else { + // Extended format: 0x00 header + 12-byte command + payload + let cmd = message.cmd(); + let cmd_bytes = cmd.as_bytes(); + + // Create 12-byte null-padded command + let mut command = [0u8; COMMAND_LEN]; + let copy_len = std::cmp::min(cmd_bytes.len(), COMMAND_LEN); + command[..copy_len].copy_from_slice(&cmd_bytes[..copy_len]); + + let mut plaintext = Vec::with_capacity(1 + COMMAND_LEN + payload.len()); + plaintext.push(MSG_ID_EXTENDED); // 0x00 marker for extended format + plaintext.extend_from_slice(&command); + plaintext.extend_from_slice(&payload); + Ok(plaintext) + } + } + + /// Decode a V2 plaintext into a NetworkMessage. + /// + /// # Arguments + /// * `plaintext` - Decrypted plaintext (header byte + optional command + payload) + fn decode_message(&self, plaintext: &[u8]) -> NetworkResult { + // The bip324 crate prepends a "packet type" byte (0 for Genuine, 128 for Decoy) + // Our actual message ID/content starts at byte 1 + if plaintext.len() < 2 { + return Err(NetworkError::ProtocolError("V2 message too short".to_string())); + } + + // Byte 0 is the crate's packet type indicator (always 0 for genuine messages) + // Byte 1 is our actual message ID (short ID or 0 for extended format) + let _crate_header = plaintext[0]; // Should be 0 for genuine, 128 for decoy + let message_id = plaintext[1]; + + // Trace: log first bytes of decrypted plaintext (verbose, for debugging only) + let preview_len = std::cmp::min(20, plaintext.len()); + tracing::trace!( + "V2Transport: Decrypted message preview ({} bytes total): {:02x?}, message_id={}", + plaintext.len(), + &plaintext[..preview_len], + message_id + ); + + let (cmd, payload) = if message_id == MSG_ID_EXTENDED { + // Extended format: 12-byte command + payload (starting at byte 2) + if plaintext.len() < 2 + COMMAND_LEN { + return Err(NetworkError::ProtocolError( + "V2 extended message too short".to_string(), + )); + } + + let command_bytes = &plaintext[2..2 + COMMAND_LEN]; + let payload = &plaintext[2 + COMMAND_LEN..]; + + // Find null terminator in command + let cmd_end = command_bytes.iter().position(|&b| b == 0).unwrap_or(COMMAND_LEN); + let cmd = std::str::from_utf8(&command_bytes[..cmd_end]).map_err(|_| { + NetworkError::ProtocolError("Invalid UTF-8 in V2 command".to_string()) + })?; + + tracing::trace!( + "V2Transport: Decoding extended format message '{}' ({} bytes payload) from {}", + cmd, + payload.len(), + self.peer_address + ); + + (cmd, payload) + } else { + // Short format: message_id is the short message ID, payload starts at byte 2 + let payload = &plaintext[2..]; + + let cmd = short_id_to_command(message_id).ok_or_else(|| { + NetworkError::ProtocolError(format!("Unknown V2 short message ID: {}", message_id)) + })?; + + tracing::trace!( + "V2Transport: Decoding short format message '{}' (ID={}, {} bytes payload) from {}", + cmd, + message_id, + payload.len(), + self.peer_address + ); + + (cmd, payload) + }; + + // Decode the NetworkMessage using dashcore's canonical decoder + NetworkMessage::consensus_decode_payload(cmd, payload) + .map_err(|e| NetworkError::ProtocolError(format!("Failed to decode '{}': {}", cmd, e))) + } + + /// Helper function to read some bytes into the receive buffer. + async fn read_some(&mut self) -> std::io::Result { + let mut tmp = [0u8; READ_BUFFER_SIZE]; + match self.stream.read(&mut tmp).await { + Ok(0) => Ok(0), + Ok(n) => { + self.receive_buffer.extend_from_slice(&tmp[..n]); + self.bytes_received += n as u64; + Ok(n) + } + Err(e) => Err(e), + } + } +} + +#[async_trait] +impl Transport for V2Transport { + async fn send_message(&mut self, message: NetworkMessage) -> NetworkResult<()> { + if !self.connected { + return Err(NetworkError::ConnectionFailed("Not connected".to_string())); + } + + // Encode the message to V2 plaintext format + let plaintext = self.encode_message(&message)?; + + tracing::debug!( + "V2Transport: Encoding message {:?} ({} bytes plaintext) for {}", + message.cmd(), + plaintext.len(), + self.peer_address + ); + + // Encrypt the message + // Note: The bip324 crate handles the header byte internally, but we're + // putting our message type in the plaintext, so we use Genuine packet type + let encrypted = + self.cipher.outbound().encrypt_to_vec(&plaintext, PacketType::Genuine, None); + + // Write the encrypted packet + match self.stream.write_all(&encrypted).await { + Ok(_) => { + // Flush to ensure data is sent immediately + if let Err(e) = self.stream.flush().await { + tracing::warn!( + "V2Transport: Failed to flush socket {}: {}", + self.peer_address, + e + ); + } + self.bytes_sent += encrypted.len() as u64; + tracing::debug!( + "V2Transport: Sent encrypted message to {}: {:?} ({} bytes)", + self.peer_address, + message.cmd(), + encrypted.len() + ); + Ok(()) + } + Err(e) => { + tracing::warn!( + "V2Transport: Disconnecting {} due to write error: {}", + self.peer_address, + e + ); + self.connected = false; + Err(NetworkError::ConnectionFailed(format!("Write failed: {}", e))) + } + } + } + + async fn receive_message(&mut self) -> NetworkResult> { + if !self.connected { + return Err(NetworkError::ConnectionFailed("Not connected".to_string())); + } + + loop { + // Step 1: Ensure we have at least 3 bytes for the length + while self.receive_buffer.len() < NUM_LENGTH_BYTES { + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V2Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(ref e) + if e.kind() == std::io::ErrorKind::ConnectionAborted + || e.kind() == std::io::ErrorKind::ConnectionReset => + { + tracing::info!( + "V2Transport: Peer {} connection reset/aborted", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e))); + } + } + } + + // Step 2: Decrypt the length (only if we haven't already for this packet) + // IMPORTANT: decrypt_packet_len advances the cipher state, so we must + // cache the result if we don't have enough bytes for the full packet yet. + let packet_len = if let Some(cached_len) = self.pending_packet_len { + cached_len + } else { + let len_bytes: [u8; NUM_LENGTH_BYTES] = + self.receive_buffer[..NUM_LENGTH_BYTES].try_into().expect("3 bytes for length"); + + // Note: decrypt_packet_len returns the length of remaining data to read + // (header + contents + tag), NOT just the contents length + let decrypted_len = self.cipher.inbound().decrypt_packet_len(len_bytes); + + // Validate packet length + if decrypted_len > MAX_MSG_SIZE + 1 + 16 { + // MAX_MSG_SIZE + header + tag + return Err(NetworkError::ProtocolError(format!( + "V2 packet too large: {} bytes", + decrypted_len + ))); + } + + // Cache the length in case we need to return early + self.pending_packet_len = Some(decrypted_len); + decrypted_len + }; + + let total_len = NUM_LENGTH_BYTES + packet_len; + + // Step 3: Ensure we have the complete packet + while self.receive_buffer.len() < total_len { + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V2Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e))); + } + } + } + + // Step 4: Extract and decrypt the packet (excluding length bytes which are already consumed) + let ciphertext = &self.receive_buffer[NUM_LENGTH_BYTES..total_len]; + + let (packet_type, plaintext) = + self.cipher.inbound().decrypt_to_vec(ciphertext, None).map_err(|e| { + NetworkError::V2DecryptionFailed(format!("Decryption failed: {}", e)) + })?; + + // Consume the packet from the buffer and clear cached length + self.receive_buffer.drain(0..total_len); + self.pending_packet_len = None; + + // Step 5: Handle decoy packets + if packet_type == PacketType::Decoy { + tracing::debug!( + "V2Transport: Received decoy packet from {}, ignoring", + self.peer_address + ); + continue; // Read next packet + } + + // Step 6: Decode the message + // Note: plaintext includes the header byte at position 0 + let message = self.decode_message(&plaintext)?; + + tracing::trace!( + "V2Transport: Successfully decoded message from {}: {:?}", + self.peer_address, + message.cmd() + ); + + return Ok(Some(message)); + } + } + + fn is_connected(&self) -> bool { + self.connected + } + + fn protocol_version(&self) -> u8 { + 2 + } + + fn bytes_sent(&self) -> u64 { + self.bytes_sent + } + + fn bytes_received(&self) -> u64 { + self.bytes_received + } + + async fn shutdown(&mut self) -> NetworkResult<()> { + if self.connected { + let _ = self.stream.shutdown().await; + self.connected = false; + tracing::info!("V2Transport: Shutdown connection to {}", self.peer_address); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_command_len() { + // Verify command length constant + assert_eq!(COMMAND_LEN, 12); + } + + #[test] + fn test_short_id_encoding() { + // Verify ping/pong use short IDs + assert!(network_message_to_short_id(&NetworkMessage::Ping(0)).is_some()); + assert!(network_message_to_short_id(&NetworkMessage::Pong(0)).is_some()); + } + + /// Helper: Encode a message the same way V2Transport::encode_message does + fn test_encode_v2_message(message: &NetworkMessage) -> Vec { + let payload = message.consensus_encode_payload(); + + if let Some(short_id) = network_message_to_short_id(message) { + // Short format: [short_id] + payload + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(short_id); + plaintext.extend_from_slice(&payload); + plaintext + } else { + // Extended format: [0x00] + [12-byte command] + payload + let cmd = message.cmd(); + let cmd_bytes = cmd.as_bytes(); + let mut command = [0u8; COMMAND_LEN]; + let copy_len = std::cmp::min(cmd_bytes.len(), COMMAND_LEN); + command[..copy_len].copy_from_slice(&cmd_bytes[..copy_len]); + + let mut plaintext = Vec::with_capacity(1 + COMMAND_LEN + payload.len()); + plaintext.push(MSG_ID_EXTENDED); + plaintext.extend_from_slice(&command); + plaintext.extend_from_slice(&payload); + plaintext + } + } + + /// Helper: Decode a V2 message with simulated cipher header byte + fn test_decode_v2_message(plaintext: &[u8]) -> Result { + // Simulate: prepend a packet type byte (0 for genuine) like the cipher does + let mut with_header = vec![0u8]; // Packet type = genuine + with_header.extend_from_slice(plaintext); + + if with_header.len() < 2 { + return Err(NetworkError::ProtocolError("V2 message too short".to_string())); + } + + let message_id = with_header[1]; + + if message_id == MSG_ID_EXTENDED { + // Extended format + if with_header.len() < 2 + COMMAND_LEN { + return Err(NetworkError::ProtocolError( + "Extended format message too short".to_string(), + )); + } + + let command_bytes = &with_header[2..2 + COMMAND_LEN]; + let cmd = std::str::from_utf8(command_bytes) + .map_err(|_| NetworkError::ProtocolError("Invalid UTF-8 in command".to_string()))? + .trim_end_matches('\0'); + + let payload = &with_header[2 + COMMAND_LEN..]; + NetworkMessage::consensus_decode_payload(cmd, payload) + .map_err(|e| NetworkError::ProtocolError(format!("Failed to decode: {}", e))) + } else { + // Short format + let cmd = short_id_to_command(message_id).ok_or_else(|| { + NetworkError::ProtocolError(format!("Unknown short ID: {}", message_id)) + })?; + + let payload = &with_header[2..]; + NetworkMessage::consensus_decode_payload(cmd, payload) + .map_err(|e| NetworkError::ProtocolError(format!("Failed to decode: {}", e))) + } + } + + #[test] + fn test_short_id_round_trip_common_messages() { + // Messages that should use short format (1 byte ID) + let short_format_messages: Vec = vec![ + NetworkMessage::Ping(0x1234567890abcdef), + NetworkMessage::Pong(0xfedcba0987654321), + NetworkMessage::Inv(vec![]), + NetworkMessage::GetData(vec![]), + NetworkMessage::NotFound(vec![]), + NetworkMessage::MemPool, + NetworkMessage::FilterClear, + NetworkMessage::SendHeaders2, + NetworkMessage::SendDsq(true), + ]; + + for original in &short_format_messages { + // Verify it uses short format (first byte is the short ID, not 0x00) + let encoded = test_encode_v2_message(original); + assert_ne!( + encoded[0], + MSG_ID_EXTENDED, + "{} should use short format, not extended", + original.cmd() + ); + + // Verify round-trip + let decoded = test_decode_v2_message(&encoded) + .expect(&format!("Failed to decode {}", original.cmd())); + assert_eq!(original, &decoded, "Round-trip failed for {} message", original.cmd()); + } + } + + #[test] + fn test_extended_format_round_trip() { + use dashcore::network::address::Address; + use dashcore::network::constants::ServiceFlags; + use dashcore::network::message_network::VersionMessage; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + let addr = Address::new( + &SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8333), + ServiceFlags::NONE, + ); + + let version = VersionMessage { + version: 70015, + services: ServiceFlags::NONE, + timestamp: 0, + receiver: addr.clone(), + sender: addr, + nonce: 0, + user_agent: "/test/".to_string(), + start_height: 0, + relay: false, + mn_auth_challenge: [0u8; 32], + masternode_connection: false, + }; + + // Version message should use extended format (no short ID) + let original = NetworkMessage::Version(version); + + let encoded = test_encode_v2_message(&original); + + // Verify extended format: first byte should be 0x00 + assert_eq!(encoded[0], MSG_ID_EXTENDED, "Version message should use extended format"); + + // Verify command is in bytes 1-12 + let cmd_bytes = &encoded[1..1 + COMMAND_LEN]; + let cmd = std::str::from_utf8(cmd_bytes).unwrap().trim_end_matches('\0'); + assert_eq!(cmd, "version", "Command should be 'version'"); + + // Verify round-trip + let decoded = test_decode_v2_message(&encoded).expect("Failed to decode version message"); + assert_eq!(original, decoded, "Version round-trip failed"); + } +} diff --git a/dash-spv/src/network/transport/v2_handshake.rs b/dash-spv/src/network/transport/v2_handshake.rs new file mode 100644 index 000000000..e9cf2aae2 --- /dev/null +++ b/dash-spv/src/network/transport/v2_handshake.rs @@ -0,0 +1,437 @@ +//! V2 Handshake implementation for BIP324 encrypted transport. +//! +//! This module implements the BIP324 handshake protocol: +//! 1. Key Exchange: ElligatorSwift-encoded public keys + garbage data +//! 2. Version Negotiation: Encrypted version packets confirm mutual v2 support +//! +//! The handshake detects v1-only peers by checking if the first bytes +//! received match the network magic (indicating v1 protocol). +//! +//! ## Why Not Use `bip324::futures::handshake()`? +//! +//! The bip324 crate provides a high-level `futures::handshake()` function, but +//! it doesn't meet dash-spv's requirements: +//! +//! 1. **V1 Detection Strategy**: bip324 detects V1-only peers *after* reading the +//! 64-byte remote key (consuming the bytes). dash-spv uses `stream.peek()` to +//! detect V1 magic *without* consuming bytes, allowing the same TCP connection +//! to be reused for V1 fallback. +//! +//! 2. **Return Type Mismatch**: bip324 returns split ciphers and a wrapped +//! `ProtocolSessionReader`. dash-spv needs the original `TcpStream` back +//! plus a `CipherSession` for the transport layer. +//! +//! 3. **Timeout Handling**: bip324's async handshake has no built-in timeouts. +//! dash-spv needs per-operation and cumulative timeout handling. +//! +//! Therefore, we use bip324's low-level `Handshake` state machine with custom +//! async I/O wrappers that provide the control we need. + +use std::net::SocketAddr; +use std::time::Duration; + +use bip324::{CipherSession, GarbageResult, Handshake, Role, VersionResult}; +use dashcore::Network; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +use crate::error::{NetworkError, NetworkResult}; + +/// Maximum garbage data size per BIP324 spec. +const MAX_GARBAGE_LEN: usize = 4095; + +/// Timeout for handshake operations. +const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10); + +/// Size of ElligatorSwift public key. +const ELLIGATOR_SWIFT_KEY_SIZE: usize = 64; + +/// Size of garbage terminator. +const GARBAGE_TERMINATOR_SIZE: usize = 16; + +/// Result of the V2 handshake attempt. +pub enum V2HandshakeResult { + /// Successfully completed V2 handshake. + Success(Box), + /// Detected V1-only peer (first bytes matched network magic). + FallbackToV1, +} + +/// Session data from a successful V2 handshake. +pub struct V2Session { + /// The TCP stream (ownership transferred from handshake). + pub stream: TcpStream, + /// The cipher session for encryption/decryption. + pub cipher: CipherSession, + /// Session ID for optional out-of-band MitM verification. + pub session_id: [u8; 32], +} + +/// V2 Handshake manager for BIP324 encrypted connections. +pub struct V2HandshakeManager { + /// Network magic bytes for key derivation. + magic: [u8; 4], + /// Our role in the handshake (initiator or responder). + role: Role, + /// Peer address (for logging). + peer_address: SocketAddr, +} + +impl V2HandshakeManager { + /// Create a new handshake manager for initiating connections. + /// + /// The initiator sends the first message (their ElligatorSwift pubkey). + pub fn new_initiator(network: Network, peer_address: SocketAddr) -> Self { + Self { + magic: network.magic().to_le_bytes(), + role: Role::Initiator, + peer_address, + } + } + + /// Create a new handshake manager for responding to connections. + /// + /// The responder waits for the initiator's pubkey first. + pub fn new_responder(network: Network, peer_address: SocketAddr) -> Self { + Self { + magic: network.magic().to_le_bytes(), + role: Role::Responder, + peer_address, + } + } + + /// Perform the V2 handshake on the given TCP stream. + /// + /// # Arguments + /// * `stream` - A connected TCP stream + /// + /// # Returns + /// * `V2HandshakeResult::Success(session)` - Handshake completed successfully + /// * `V2HandshakeResult::FallbackToV1` - Detected v1-only peer + /// + /// # Errors + /// Returns `NetworkError` if the handshake fails (e.g., timeout, protocol error). + pub async fn perform_handshake( + self, + mut stream: TcpStream, + ) -> NetworkResult { + tracing::debug!("V2 handshake: Starting as {:?} with {}", self.role, self.peer_address); + + // Create the handshake state machine + let handshake = Handshake::new(self.magic, self.role).map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to create handshake: {}", e)) + })?; + + // Step 1: Send our public key (no garbage for simplicity) + let mut send_key_buffer = vec![0u8; Handshake::send_key_len(None)]; + let handshake = handshake.send_key(None, &mut send_key_buffer).map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to prepare key: {}", e)) + })?; + + tracing::debug!( + "V2 handshake: Sending our ElligatorSwift pubkey ({} bytes) to {}", + send_key_buffer.len(), + self.peer_address + ); + + tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.write_all(&send_key_buffer)) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to send pubkey: {}", e)) + })?; + + stream + .flush() + .await + .map_err(|e| NetworkError::V2HandshakeFailed(format!("Failed to flush: {}", e)))?; + + // Step 2: Read the remote's public key (64 bytes) + // First, peek at the initial bytes to detect v1 magic + let mut peek_buf = [0u8; 4]; + let peek_result = tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.peek(&mut peek_buf)).await; + + match peek_result { + Ok(Ok(n)) if n >= 4 => { + if peek_buf == self.magic { + tracing::info!( + "V2 handshake: Detected V1-only peer {} (received magic bytes)", + self.peer_address + ); + return Ok(V2HandshakeResult::FallbackToV1); + } + } + Ok(Ok(_)) => { + // Not enough bytes to determine, continue with v2 + } + Ok(Err(e)) => { + return Err(NetworkError::V2HandshakeFailed(format!( + "Failed to peek for v1 detection: {}", + e + ))); + } + Err(_) => { + return Err(NetworkError::Timeout); + } + } + + // Read the full remote pubkey (64 bytes) + let mut remote_pubkey = [0u8; ELLIGATOR_SWIFT_KEY_SIZE]; + tracing::debug!( + "V2 handshake: Reading remote ElligatorSwift pubkey from {}", + self.peer_address + ); + + tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut remote_pubkey)) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to read remote pubkey: {}", e)) + })?; + + // Step 3: Process the remote's public key and derive session keys + let handshake = handshake.receive_key(remote_pubkey).map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to process remote pubkey: {}", e)) + })?; + + tracing::debug!("V2 handshake: Derived session keys with {}", self.peer_address); + + // Step 4: Send garbage terminator + version packet + let mut send_version_buffer = vec![0u8; Handshake::send_version_len(None)]; + let handshake = handshake.send_version(&mut send_version_buffer, None).map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to prepare version: {}", e)) + })?; + + tracing::debug!( + "V2 handshake: Sending garbage terminator + version ({} bytes) to {}", + send_version_buffer.len(), + self.peer_address + ); + + tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.write_all(&send_version_buffer)) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to send version: {}", e)) + })?; + + stream + .flush() + .await + .map_err(|e| NetworkError::V2HandshakeFailed(format!("Failed to flush: {}", e)))?; + + // Step 5: Receive remote garbage + terminator + // Read up to MAX_GARBAGE_LEN + GARBAGE_TERMINATOR_SIZE bytes + let mut garbage_buffer = Vec::with_capacity(MAX_GARBAGE_LEN + GARBAGE_TERMINATOR_SIZE); + let mut handshake_state = handshake; + + tracing::debug!( + "V2 handshake: Scanning for remote garbage terminator from {}", + self.peer_address + ); + + let scan_start = std::time::Instant::now(); + loop { + // Check timeout + if scan_start.elapsed() > HANDSHAKE_TIMEOUT { + return Err(NetworkError::Timeout); + } + + // Read a chunk + let mut chunk = [0u8; 256]; + let n = match tokio::time::timeout( + HANDSHAKE_TIMEOUT.saturating_sub(scan_start.elapsed()), + stream.read(&mut chunk), + ) + .await + { + Ok(Ok(0)) => { + return Err(NetworkError::V2HandshakeFailed( + "Connection closed during garbage scan".to_string(), + )); + } + Ok(Ok(n)) => n, + Ok(Err(e)) => { + return Err(NetworkError::V2HandshakeFailed(format!( + "Failed to read garbage: {}", + e + ))); + } + Err(_) => { + return Err(NetworkError::Timeout); + } + }; + + garbage_buffer.extend_from_slice(&chunk[..n]); + + // Try to find the garbage terminator + match handshake_state.receive_garbage(&garbage_buffer) { + Ok(GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + }) => { + tracing::debug!( + "V2 handshake: Found garbage terminator after {} bytes from {}", + consumed_bytes, + self.peer_address + ); + + // Keep any remaining bytes after the garbage + let remaining = garbage_buffer[consumed_bytes..].to_vec(); + + // Step 6: Receive version packet (may be preceded by decoy packets) + // Loop until we receive the genuine version packet + let mut handshake = handshake; + let mut leftover_data = remaining; + + loop { + // Read at least 3 bytes for the length prefix + while leftover_data.len() < 3 { + let mut more = [0u8; 64]; + let n = tokio::time::timeout( + HANDSHAKE_TIMEOUT.saturating_sub(scan_start.elapsed()), + stream.read(&mut more), + ) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!( + "Failed to read packet length: {}", + e + )) + })?; + if n == 0 { + return Err(NetworkError::V2HandshakeFailed( + "Connection closed before version packet".to_string(), + )); + } + leftover_data.extend_from_slice(&more[..n]); + } + + // Decrypt the packet length (first 3 bytes) + let length_bytes: [u8; 3] = + leftover_data[..3].try_into().map_err(|_| { + NetworkError::V2HandshakeFailed( + "Failed to extract length bytes".to_string(), + ) + })?; + let packet_len = + handshake.decrypt_packet_len(length_bytes).map_err(|e| { + NetworkError::V2HandshakeFailed(format!( + "Failed to decrypt packet length: {}", + e + )) + })?; + + tracing::debug!( + "V2 handshake: Packet length is {} bytes from {}", + packet_len, + self.peer_address + ); + + // Read more data if needed to have the full packet + let total_needed = 3 + packet_len; // length prefix + packet content + while leftover_data.len() < total_needed { + let mut more = [0u8; 64]; + let n = tokio::time::timeout( + HANDSHAKE_TIMEOUT.saturating_sub(scan_start.elapsed()), + stream.read(&mut more), + ) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!( + "Failed to read packet content: {}", + e + )) + })?; + if n == 0 { + return Err(NetworkError::V2HandshakeFailed( + "Connection closed before packet complete".to_string(), + )); + } + leftover_data.extend_from_slice(&more[..n]); + } + + // Extract just the packet content (excluding the 3-byte length prefix) + let mut packet_content = leftover_data[3..3 + packet_len].to_vec(); + + // Keep any data after this packet for the next iteration + leftover_data = leftover_data[3 + packet_len..].to_vec(); + + // Process packet + match handshake.receive_version(&mut packet_content) { + Ok(VersionResult::Complete { + cipher, + }) => { + tracing::info!( + "V2 handshake: Completed successfully with {}", + self.peer_address + ); + + let session_id = *cipher.id(); + return Ok(V2HandshakeResult::Success(Box::new(V2Session { + stream, + cipher, + session_id, + }))); + } + Ok(VersionResult::Decoy(next_handshake)) => { + // Received a decoy packet, continue reading for version packet + tracing::debug!( + "V2 handshake: Received decoy packet from {}, continuing", + self.peer_address + ); + handshake = next_handshake; + // Continue loop to read next packet + } + Err(e) => { + return Err(NetworkError::V2HandshakeFailed(format!( + "Failed to process packet: {}", + e + ))); + } + } + } + } + Ok(GarbageResult::NeedMoreData(hs)) => { + // Continue reading more data; bip324 enforces the max garbage limit + // internally and will return NoGarbageTerminator if exceeded + handshake_state = hs; + } + Err(e) => { + return Err(NetworkError::V2HandshakeFailed(format!( + "Failed to process garbage: {}", + e + ))); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_handshake_manager_creation() { + let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + + let initiator = V2HandshakeManager::new_initiator(Network::Dash, addr); + assert_eq!(initiator.role, Role::Initiator); + // Dash mainnet magic: 0xBD6B0CBF in little-endian + assert_eq!(initiator.magic, [0xbf, 0x0c, 0x6b, 0xbd]); + + let responder = V2HandshakeManager::new_responder(Network::Dash, addr); + assert_eq!(responder.role, Role::Responder); + } + + #[test] + fn test_testnet_magic() { + let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + let manager = V2HandshakeManager::new_initiator(Network::Testnet, addr); + // Dash testnet magic: 0xFFCAE2CE in little-endian + assert_eq!(manager.magic, [0xce, 0xe2, 0xca, 0xff]); + } +} diff --git a/dash-spv/tests/handshake_test.rs b/dash-spv/tests/handshake_test.rs index d8cb6579f..e23ce8d2c 100644 --- a/dash-spv/tests/handshake_test.rs +++ b/dash-spv/tests/handshake_test.rs @@ -4,6 +4,7 @@ use std::net::SocketAddr; use std::time::Duration; use dash_spv::client::config::MempoolStrategy; +use dash_spv::network::transport::TransportPreference; use dash_spv::network::{HandshakeManager, NetworkManager, Peer, PeerNetworkManager}; use dash_spv::{ClientConfig, Network}; @@ -13,7 +14,7 @@ async fn test_handshake_with_mainnet_peer() { let _ = env_logger::builder().filter_level(log::LevelFilter::Debug).is_test(true).try_init(); let peer_addr: SocketAddr = "127.0.0.1:9999".parse().expect("Valid peer address"); - let result = Peer::connect(peer_addr, 10, Network::Dash).await; + let result = Peer::connect(peer_addr, 10, Network::Dash, TransportPreference::V1Only).await; match result { Ok(mut connection) => { @@ -54,7 +55,7 @@ async fn test_handshake_timeout() { // Using a non-routable IP that will cause the connection to hang let peer_addr: SocketAddr = "10.255.255.1:9999".parse().expect("Valid peer address"); let start = std::time::Instant::now(); - let result = Peer::connect(peer_addr, 2, Network::Dash).await; + let result = Peer::connect(peer_addr, 2, Network::Dash, TransportPreference::V1Only).await; let elapsed = start.elapsed(); assert!(result.is_err(), "Connection should fail for non-routable peer"); @@ -92,7 +93,7 @@ async fn test_multiple_connect_disconnect_cycles() { for i in 1..=3 { println!("Attempt {} to connect to {}", i, peer_addr); - match connection.connect_instance().await { + match connection.connect_instance(TransportPreference::V1Only).await { Ok(_) => { assert!(connection.is_connected(), "Should be connected after successful connect"); @@ -114,3 +115,61 @@ async fn test_multiple_connect_disconnect_cycles() { } } } + +// ============================================================================= +// BIP324 V2 Transport Integration Tests +// ============================================================================= + +/// Test V2Preferred mode which tries V2 first then falls back to V1. +/// This test verifies the fallback mechanism works correctly. +#[tokio::test] +async fn test_v2preferred_fallback_to_v1() { + let _ = env_logger::builder().filter_level(log::LevelFilter::Debug).is_test(true).try_init(); + + let peer_addr: SocketAddr = "127.0.0.1:9999".parse().expect("Valid peer address"); + let result = + Peer::connect(peer_addr, 10, Network::Dash, TransportPreference::V2Preferred).await; + + match result { + Ok(mut connection) => { + let transport_version = connection.transport_version(); + println!( + "✓ Connected to {} using V{} transport (V2Preferred mode)", + peer_addr, transport_version + ); + + // V2Preferred should use V2 if supported, V1 otherwise + // Most current nodes are V1-only, so we typically expect V1 + assert!( + transport_version == 1 || transport_version == 2, + "Transport version should be 1 or 2" + ); + + // Perform application-level handshake to verify transport works + let mut handshake_manager = HandshakeManager::new( + Network::Dash, + MempoolStrategy::BloomFilter, + Some("v2pref_test".parse().unwrap()), + ); + handshake_manager + .perform_handshake(&mut connection) + .await + .expect("Application handshake failed"); + + assert!(connection.is_connected(), "Should be connected after handshake"); + + // Verify peer info is populated + let peer_info = connection.peer_info(); + assert_eq!(peer_info.address, peer_addr); + assert!(peer_info.connected); + + connection.disconnect().await.expect("Failed to disconnect"); + println!("✓ V2Preferred test passed (used V{} transport)", transport_version); + } + Err(e) => { + println!("✗ Connection failed: {}", e); + println!("Note: This test requires a Dash Core node running at 127.0.0.1:9999"); + // Don't fail - node might not be available in CI + } + } +} diff --git a/dash/src/network/message.rs b/dash/src/network/message.rs index 786b664fa..8c9c7efa5 100644 --- a/dash/src/network/message.rs +++ b/dash/src/network/message.rs @@ -352,45 +352,21 @@ impl NetworkMessage { _ => CommandString::try_from_static(self.cmd()).expect("cmd returns valid commands"), } } -} -impl RawNetworkMessage { - /// Return the message command as a static string reference. + /// Serialize the message payload without V1 framing (magic/command/checksum). /// - /// This returns `"unknown"` for [NetworkMessage::Unknown], - /// regardless of the actual command in the unknown message. - /// Use the [Self::command] method to get the command for unknown messages. - pub fn cmd(&self) -> &'static str { - self.payload.cmd() - } - - /// Return the CommandString for the message command. - pub fn command(&self) -> CommandString { - self.payload.command() - } -} - -struct HeaderSerializationWrapper<'a>(&'a Vec); - -impl<'a> Encodable for HeaderSerializationWrapper<'a> { - #[inline] - fn consensus_encode(&self, w: &mut W) -> Result { - let mut len = 0; - len += VarInt(self.0.len() as u64).consensus_encode(w)?; - for header in self.0.iter() { - len += header.consensus_encode(w)?; - len += 0u8.consensus_encode(w)?; - } - Ok(len) - } -} - -impl Encodable for RawNetworkMessage { - fn consensus_encode(&self, w: &mut W) -> Result { - let mut len = 0; - len += self.magic.consensus_encode(w)?; - len += self.command().consensus_encode(w)?; - len += CheckedData(match self.payload { + /// This method returns the raw serialized bytes of the message payload, + /// suitable for use with V2 (BIP324) transport or other protocols that + /// handle framing separately. + /// + /// # Note on Headers serialization + /// + /// The `Headers` message is serialized with a trailing zero byte after each + /// header, representing an empty transaction count (VarInt). This matches + /// the Bitcoin/Dash protocol where headers messages reuse the block + /// serialization format but with no transactions. + pub fn consensus_encode_payload(&self) -> Vec { + match *self { NetworkMessage::Version(ref dat) => serialize(dat), NetworkMessage::Addr(ref dat) => serialize(dat), NetworkMessage::Inv(ref dat) => serialize(dat), @@ -420,8 +396,19 @@ impl Encodable for RawNetworkMessage { NetworkMessage::BlockTxn(ref dat) => serialize(dat), NetworkMessage::Alert(ref dat) => serialize(dat), NetworkMessage::Reject(ref dat) => serialize(dat), - NetworkMessage::FeeFilter(ref data) => serialize(data), + NetworkMessage::FeeFilter(ref dat) => serialize(dat), NetworkMessage::AddrV2(ref dat) => serialize(dat), + NetworkMessage::GetMnListD(ref dat) => serialize(dat), + NetworkMessage::MnListDiff(ref dat) => serialize(dat), + NetworkMessage::GetQRInfo(ref dat) => serialize(dat), + NetworkMessage::QRInfo(ref dat) => serialize(dat), + NetworkMessage::CLSig(ref dat) => serialize(dat), + NetworkMessage::ISLock(ref dat) => serialize(dat), + NetworkMessage::SendDsq(wants_dsq) => serialize(&(wants_dsq as u8)), + NetworkMessage::Unknown { + payload: ref data, + .. + } => data.clone(), NetworkMessage::Verack | NetworkMessage::SendHeaders | NetworkMessage::SendHeaders2 @@ -430,81 +417,29 @@ impl Encodable for RawNetworkMessage { | NetworkMessage::WtxidRelay | NetworkMessage::FilterClear | NetworkMessage::SendAddrV2 => vec![], - NetworkMessage::Unknown { - payload: ref data, - .. - } => serialize(data), - NetworkMessage::GetMnListD(ref dat) => serialize(dat), - NetworkMessage::MnListDiff(ref dat) => serialize(dat), - NetworkMessage::GetQRInfo(ref dat) => serialize(dat), - NetworkMessage::QRInfo(ref dat) => serialize(dat), - NetworkMessage::CLSig(ref dat) => serialize(dat), - NetworkMessage::ISLock(ref dat) => serialize(dat), - NetworkMessage::SendDsq(wants_dsq) => serialize(&(wants_dsq as u8)), - }) - .consensus_encode(w)?; - Ok(len) - } -} - -struct HeaderDeserializationWrapper(Vec); - -impl Decodable for HeaderDeserializationWrapper { - #[inline] - fn consensus_decode_from_finite_reader( - r: &mut R, - ) -> Result { - let len = VarInt::consensus_decode(r)?.0; - // should be above usual number of items to avoid - // allocation - let mut ret = Vec::with_capacity(core::cmp::min(1024 * 16, len as usize)); - for _ in 0..len { - ret.push(Decodable::consensus_decode(r)?); - if u8::consensus_decode(r)? != 0u8 { - return Err(encode::Error::ParseFailed( - "Headers message should not contain transactions", - )); - } } - Ok(HeaderDeserializationWrapper(ret)) } - #[inline] - fn consensus_decode(r: &mut R) -> Result { - Self::consensus_decode_from_finite_reader(r.take(MAX_MSG_SIZE as u64).by_ref()) - } -} - -impl Decodable for RawNetworkMessage { - fn consensus_decode_from_finite_reader( - r: &mut R, - ) -> Result { - let magic = Decodable::consensus_decode_from_finite_reader(r)?; - let cmd = CommandString::consensus_decode_from_finite_reader(r)?; - let raw_payload = match CheckedData::consensus_decode_from_finite_reader(r) { - Ok(cd) => cd.0, - Err(encode::Error::InvalidChecksum { - expected, - actual, - }) => { - // Include message command and magic in logging to aid diagnostics - log::warn!( - "Invalid payload checksum for network message '{}' (magic {:#x}): expected {:02x?}, actual {:02x?}", - cmd.0, - magic, - expected, - actual - ); - return Err(encode::Error::InvalidChecksum { - expected, - actual, - }); - } - Err(e) => return Err(e), - }; - - let mut mem_d = io::Cursor::new(raw_payload); - let payload = match &cmd.0[..] { + /// Decode a message payload from raw bytes given a command string. + /// + /// This method decodes the raw payload bytes into a `NetworkMessage` variant + /// based on the command string. It handles all standard Bitcoin and Dash-specific + /// message types, including special cases like `Headers` which has trailing + /// transaction count bytes. + /// + /// This is the inverse of [`consensus_encode_payload`], suitable for use with + /// V2 (BIP324) transport or other protocols that handle framing separately. + /// + /// # Arguments + /// * `cmd` - The command string identifying the message type (e.g., "version", "headers") + /// * `payload` - The raw payload bytes to decode + /// + /// # Returns + /// * `Ok(NetworkMessage)` - Successfully decoded message + /// * `Err(encode::Error)` - Decoding failed + pub fn consensus_decode_payload(cmd: &str, payload: &[u8]) -> Result { + let mut mem_d = io::Cursor::new(payload); + let message = match cmd { "version" => { NetworkMessage::Version(Decodable::consensus_decode_from_finite_reader(&mut mem_d)?) } @@ -650,10 +585,113 @@ impl Decodable for RawNetworkMessage { NetworkMessage::SendDsq(byte != 0) } _ => NetworkMessage::Unknown { - command: cmd, - payload: mem_d.into_inner(), + command: CommandString::try_from(cmd.to_string()) + .map_err(|_| encode::Error::ParseFailed("Invalid command string"))?, + payload: payload.to_vec(), }, }; + Ok(message) + } +} + +impl RawNetworkMessage { + /// Return the message command as a static string reference. + /// + /// This returns `"unknown"` for [NetworkMessage::Unknown], + /// regardless of the actual command in the unknown message. + /// Use the [Self::command] method to get the command for unknown messages. + pub fn cmd(&self) -> &'static str { + self.payload.cmd() + } + + /// Return the CommandString for the message command. + pub fn command(&self) -> CommandString { + self.payload.command() + } +} + +struct HeaderSerializationWrapper<'a>(&'a Vec); + +impl<'a> Encodable for HeaderSerializationWrapper<'a> { + #[inline] + fn consensus_encode(&self, w: &mut W) -> Result { + let mut len = 0; + len += VarInt(self.0.len() as u64).consensus_encode(w)?; + for header in self.0.iter() { + len += header.consensus_encode(w)?; + len += 0u8.consensus_encode(w)?; + } + Ok(len) + } +} + +impl Encodable for RawNetworkMessage { + fn consensus_encode(&self, w: &mut W) -> Result { + let mut len = 0; + len += self.magic.consensus_encode(w)?; + len += self.command().consensus_encode(w)?; + len += CheckedData(self.payload.consensus_encode_payload()).consensus_encode(w)?; + Ok(len) + } +} + +struct HeaderDeserializationWrapper(Vec); + +impl Decodable for HeaderDeserializationWrapper { + #[inline] + fn consensus_decode_from_finite_reader( + r: &mut R, + ) -> Result { + let len = VarInt::consensus_decode(r)?.0; + // should be above usual number of items to avoid + // allocation + let mut ret = Vec::with_capacity(core::cmp::min(1024 * 16, len as usize)); + for _ in 0..len { + ret.push(Decodable::consensus_decode(r)?); + if u8::consensus_decode(r)? != 0u8 { + return Err(encode::Error::ParseFailed( + "Headers message should not contain transactions", + )); + } + } + Ok(HeaderDeserializationWrapper(ret)) + } + + #[inline] + fn consensus_decode(r: &mut R) -> Result { + Self::consensus_decode_from_finite_reader(r.take(MAX_MSG_SIZE as u64).by_ref()) + } +} + +impl Decodable for RawNetworkMessage { + fn consensus_decode_from_finite_reader( + r: &mut R, + ) -> Result { + let magic = Decodable::consensus_decode_from_finite_reader(r)?; + let cmd = CommandString::consensus_decode_from_finite_reader(r)?; + let raw_payload = match CheckedData::consensus_decode_from_finite_reader(r) { + Ok(cd) => cd.0, + Err(encode::Error::InvalidChecksum { + expected, + actual, + }) => { + // Include message command and magic in logging to aid diagnostics + log::warn!( + "Invalid payload checksum for network message '{}' (magic {:#x}): expected {:02x?}, actual {:02x?}", + cmd.0, + magic, + expected, + actual + ); + return Err(encode::Error::InvalidChecksum { + expected, + actual, + }); + } + Err(e) => return Err(e), + }; + + let payload = NetworkMessage::consensus_decode_payload(&cmd.0, &raw_payload)?; Ok(RawNetworkMessage { magic, payload, @@ -1046,4 +1084,211 @@ mod test { let msg = NetworkMessage::SendDsq(true); assert_eq!(msg.cmd(), "senddsq"); } + + // ========================================================================= + // V2 Transport Payload Encode/Decode Tests + // These tests verify consensus_encode_payload and consensus_decode_payload + // for use with BIP324 V2 encrypted transport. + // ========================================================================= + + /// Helper to test round-trip encoding/decoding for a message + fn test_payload_round_trip(msg: &NetworkMessage) { + let encoded = msg.consensus_encode_payload(); + let decoded = NetworkMessage::consensus_decode_payload(msg.cmd(), &encoded) + .expect(&format!("Failed to decode {} message", msg.cmd())); + assert_eq!(msg, &decoded, "Round-trip failed for {} message", msg.cmd()); + } + + #[test] + #[cfg(feature = "core-block-hash-use-x11")] + fn test_encode_decode_standard_bitcoin_messages() { + // Use deserialized test data to avoid complex construction + let tx: Transaction = deserialize(&hex!("0100000001a15d57094aa7a21a28cb20b59aab8fc7d1149a3bdbcddba9c622e4f5f6a99ece010000006c493046022100f93bb0e7d8db7bd46e40132d1f8242026e045f03a0efe71bbb8e3f475e970d790221009337cd7f1f929f00cc6ff01f03729b069a7c21b59b1736ddfee5db5946c5da8c0121033b9b137ee87d5a812d6f506efdd37f0affa7ffc310711c06c7f3e097c9447c52ffffffff0100e1f505000000001976a9140389035a9225b3839e2bbf32d826a1e222031fd888ac00000000")).unwrap(); + let block: Block = deserialize(&include_bytes!("../../tests/data/testnet_block_000000000000045e0b1660b6445b5e5c5ab63c9a4f956be7e1e69be04fa4497b.raw")[..]).unwrap(); + let header: block::Header = deserialize(&hex!("010000004ddccd549d28f385ab457e98d1b11ce80bfea2c5ab93015ade4973e400000000bf4473e53794beae34e64fccc471dace6ae544180816f89591894e0f417a914cd74d6e49ffff001d323b3a7b")).unwrap(); + + let inv = vec![Inventory::Transaction(hash([3u8; 32]).into())]; + + let messages: Vec = vec![ + NetworkMessage::Ping(0x1234567890abcdef), + NetworkMessage::Pong(0xfedcba0987654321), + NetworkMessage::Inv(inv.clone()), + NetworkMessage::GetData(inv.clone()), + NetworkMessage::NotFound(inv), + NetworkMessage::GetBlocks(GetBlocksMessage { + version: 70015, + locator_hashes: vec![hash_x11([4u8; 32]).into()], + stop_hash: hash_x11([5u8; 32]).into(), + }), + NetworkMessage::GetHeaders(GetHeadersMessage { + version: 70015, + locator_hashes: vec![hash_x11([6u8; 32]).into()], + stop_hash: hash_x11([7u8; 32]).into(), + }), + NetworkMessage::Headers(vec![header]), + NetworkMessage::Tx(tx), + NetworkMessage::Block(block), + NetworkMessage::FilterLoad(FilterLoad { + filter: vec![0x01, 0x02, 0x03], + hash_funcs: 11, + tweak: 0x12345678, + flags: BloomFlags::All, + }), + NetworkMessage::FilterAdd(FilterAdd { + data: vec![0xaa, 0xbb, 0xcc], + }), + NetworkMessage::SendCmpct(SendCmpct { + send_compact: true, + version: 1, + }), + NetworkMessage::GetCFilters(GetCFilters { + filter_type: 0, + start_height: 100, + stop_hash: hash_x11([8u8; 32]).into(), + }), + NetworkMessage::GetCFHeaders(GetCFHeaders { + filter_type: 0, + start_height: 100, + stop_hash: hash_x11([9u8; 32]).into(), + }), + NetworkMessage::GetCFCheckpt(GetCFCheckpt { + filter_type: 0, + stop_hash: hash_x11([10u8; 32]).into(), + }), + ]; + + for msg in &messages { + test_payload_round_trip(msg); + } + } + + #[test] + #[cfg(feature = "core-block-hash-use-x11")] + fn test_encode_decode_dash_specific_messages() { + use crate::bls_sig_utils::BLSSignature; + use crate::hash_types::CycleHash; + use crate::network::message_sml::GetMnListDiff; + use crate::{ChainLock, InstantLock}; + + let messages: Vec = vec![ + NetworkMessage::SendDsq(true), + NetworkMessage::SendDsq(false), + NetworkMessage::GetMnListD(GetMnListDiff { + base_block_hash: hash_x11([1u8; 32]).into(), + block_hash: hash_x11([2u8; 32]).into(), + }), + NetworkMessage::CLSig(ChainLock { + block_height: 123456, + block_hash: hash_x11([3u8; 32]).into(), + signature: BLSSignature::from([0u8; 96]), + }), + NetworkMessage::ISLock(InstantLock { + version: 1, + inputs: vec![], + txid: hash([4u8; 32]).into(), + cyclehash: CycleHash::from([5u8; 32]), + signature: BLSSignature::from([0u8; 96]), + }), + NetworkMessage::GetHeaders2(GetHeadersMessage { + version: 70015, + locator_hashes: vec![hash_x11([6u8; 32]).into()], + stop_hash: hash_x11([7u8; 32]).into(), + }), + NetworkMessage::SendHeaders2, + ]; + + for msg in &messages { + test_payload_round_trip(msg); + } + } + + #[test] + fn test_encode_decode_empty_payload_messages() { + let empty_payload_messages: Vec = vec![ + NetworkMessage::Verack, + NetworkMessage::SendHeaders, + NetworkMessage::SendHeaders2, + NetworkMessage::MemPool, + NetworkMessage::GetAddr, + NetworkMessage::WtxidRelay, + NetworkMessage::FilterClear, + NetworkMessage::SendAddrV2, + ]; + + for msg in &empty_payload_messages { + // Verify encoding produces empty payload + let encoded = msg.consensus_encode_payload(); + assert!( + encoded.is_empty(), + "{} should have empty payload, got {} bytes", + msg.cmd(), + encoded.len() + ); + + // Verify decoding works with empty payload + let decoded = NetworkMessage::consensus_decode_payload(msg.cmd(), &[]) + .expect(&format!("Failed to decode empty {} message", msg.cmd())); + assert_eq!(msg, &decoded, "Empty payload round-trip failed for {}", msg.cmd()); + } + } + + #[test] + #[cfg(feature = "core-block-hash-use-x11")] + fn test_headers_message_special_encoding() { + let header = block::Header { + version: block::Version::from_consensus(1), + prev_blockhash: hash_x11([1u8; 32]).into(), + merkle_root: hash([2u8; 32]).into(), + time: 1234567890, + bits: crate::pow::CompactTarget::from_consensus(0x1d00ffff), + nonce: 42, + }; + + // Test empty headers + let empty_headers = NetworkMessage::Headers(vec![]); + let encoded = empty_headers.consensus_encode_payload(); + assert_eq!(encoded, vec![0x00], "Empty headers should encode to single 0x00 varint"); + test_payload_round_trip(&empty_headers); + + // Test single header + let single_header = NetworkMessage::Headers(vec![header.clone()]); + let encoded = single_header.consensus_encode_payload(); + // Should be: varint(1) + header_bytes + 0x00 (tx count) + // Header is 80 bytes, so total should be 1 + 80 + 1 = 82 bytes + assert_eq!(encoded.len(), 82, "Single header should be 82 bytes"); + assert_eq!(encoded[81], 0x00, "Header should have trailing zero tx count"); + test_payload_round_trip(&single_header); + + // Test multiple headers + let multi_headers = NetworkMessage::Headers(vec![header.clone(), header.clone(), header]); + let encoded = multi_headers.consensus_encode_payload(); + // Should be: varint(3) + 3 * (header + 0x00) = 1 + 3*81 = 244 bytes + assert_eq!(encoded.len(), 244, "Three headers should be 244 bytes"); + test_payload_round_trip(&multi_headers); + } + + #[test] + fn test_encode_decode_unknown_message() { + // Create an Unknown message with a custom command and payload + let unknown_msg = NetworkMessage::Unknown { + command: CommandString::try_from_static("custom").unwrap(), + payload: vec![0xaa, 0xbb, 0xcc, 0xdd], + }; + + // Test encoding - should return raw payload bytes without length prefix + let encoded = unknown_msg.consensus_encode_payload(); + assert_eq!( + encoded, + vec![0xaa, 0xbb, 0xcc, 0xdd], + "Unknown message should encode to raw payload bytes without length prefix" + ); + + // Test decoding with the actual command (not "unknown") + // Note: We must use command() not cmd() because cmd() returns "unknown" for Unknown variants + let cmd = unknown_msg.command(); + let decoded = NetworkMessage::consensus_decode_payload(cmd.as_ref(), &encoded) + .expect("Failed to decode unknown message"); + + assert_eq!(unknown_msg, decoded, "Round-trip failed for unknown message"); + } }