diff --git a/Cargo.lock b/Cargo.lock index c16f0570..a2506158 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2292,16 +2292,6 @@ dependencies = [ "cipher", ] -[[package]] -name = "ctrlc" -version = "3.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697b5419f348fd5ae2478e8018cb016c00a5881c7f46c717de98ffd135a5651c" -dependencies = [ - "nix 0.29.0", - "windows-sys 0.59.0", -] - [[package]] name = "curve25519-dalek" version = "4.1.3" @@ -3578,7 +3568,7 @@ dependencies = [ "futures-channel", "futures-io", "futures-util", - "idna 1.0.3", + "idna", "ipnet", "once_cell", "rand 0.8.5", @@ -3603,7 +3593,7 @@ dependencies = [ "futures-channel", "futures-io", "futures-util", - "idna 1.0.3", + "idna", "ipnet", "once_cell", "rand 0.9.1", @@ -4077,16 +4067,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" -[[package]] -name = "idna" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" -dependencies = [ - "unicode-bidi", - "unicode-normalization", -] - [[package]] name = "idna" version = "1.0.3" @@ -4141,12 +4121,6 @@ dependencies = [ "windows 0.52.0", ] -[[package]] -name = "if_chain" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed" - [[package]] name = "igd-next" version = "0.14.3" @@ -4241,19 +4215,6 @@ dependencies = [ "serde", ] -[[package]] -name = "indicatif" -version = "0.17.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" -dependencies = [ - "console", - "number_prefix", - "portable-atomic", - "unicode-width", - "web-time", -] - [[package]] name = "inout" version = "0.1.4" @@ -6077,12 +6038,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "number_prefix" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" - [[package]] name = "nvml-wrapper" version = "0.10.0" @@ -6756,30 +6711,6 @@ dependencies = [ "toml_edit", ] -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro-error-attr2" version = "2.0.0" @@ -9240,12 +9171,6 @@ version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" -[[package]] -name = "unicode-bidi" -version = "0.3.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" - [[package]] name = "unicode-ident" version = "1.0.18" @@ -9339,7 +9264,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", - "idna 1.0.3", + "idna", "percent-encoding", "serde", ] @@ -9470,48 +9395,6 @@ dependencies = [ "url", ] -[[package]] -name = "validator" -version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b92f40481c04ff1f4f61f304d61793c7b56ff76ac1469f1beb199b1445b253bd" -dependencies = [ - "idna 0.4.0", - "lazy_static", - "regex", - "serde", - "serde_derive", - "serde_json", - "url", - "validator_derive", -] - -[[package]] -name = "validator_derive" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc44ca3088bb3ba384d9aecf40c6a23a676ce23e09bdaca2073d99c207f864af" -dependencies = [ - "if_chain", - "lazy_static", - "proc-macro-error", - "proc-macro2", - "quote", - "regex", - "syn 1.0.109", - "validator_types", -] - -[[package]] -name = "validator_types" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "111abfe30072511849c5910134e8baf8dc05de4c0e5903d681cbd5c9c4d611e3" -dependencies = [ - "proc-macro2", - "syn 1.0.109", -] - [[package]] name = "valuable" version = "0.1.1" @@ -10325,32 +10208,24 @@ dependencies = [ "alloy", "anyhow", "bollard", - "bytes", "chrono", "cid", "clap", "colored", "console", - "ctrlc", - "dashmap", "directories", "env_logger", "futures", - "futures-core", "futures-util", "hex", "homedir", - "indicatif", - "iroh", "lazy_static", "libc", "log", - "nalgebra", "nvml-wrapper", + "p2p", "rand 0.8.5", "rand 0.9.1", - "rand_core 0.6.4", - "regex", "reqwest", "rust-ipfs", "serde", @@ -10367,15 +10242,12 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", - "toml", "tracing", - "tracing-log", "tracing-loki", "tracing-subscriber", "unicode-width", "url", "uuid", - "validator 0.16.1", ] [[package]] diff --git a/crates/worker/Cargo.toml b/crates/worker/Cargo.toml index 0f08e404..eb041cad 100644 --- a/crates/worker/Cargo.toml +++ b/crates/worker/Cargo.toml @@ -8,43 +8,41 @@ workspace = true [dependencies] shared = { workspace = true } +p2p = { workspace = true } + actix-web = { workspace = true } -bollard = "0.18.1" +alloy = { workspace = true } +anyhow = { workspace = true } +cid = { workspace = true } clap = { workspace = true } -colored = "2.0" -lazy_static = "1.4" -regex = "1.10" +chrono = { workspace = true } +directories = { workspace = true } +env_logger = { workspace = true } +futures = { workspace = true } +futures-util = { workspace = true } +hex = { workspace = true } +log = { workspace = true } +rand_v8 = { workspace = true } +reqwest = { workspace = true, features = ["blocking"] } +rust-ipfs = { workspace = true } serde = { workspace = true } +serde_json = { workspace = true } +stun = { workspace = true } tokio = { workspace = true, features = ["full", "macros"] } +tokio-util = { workspace = true, features = ["rt"] } +url = { workspace = true } uuid = { workspace = true } -validator = { version = "0.16", features = ["derive"] } + +bollard = "0.18.1" +colored = "2.0" +lazy_static = "1.4" sysinfo = "0.30" libc = "0.2" nvml-wrapper = "0.10.0" -log = { workspace = true } -env_logger = { workspace = true } -futures-core = "0.3" -futures-util = { workspace = true } -alloy = { workspace = true } -url = { workspace = true } -serde_json = { workspace = true } -reqwest = { workspace = true, features = ["blocking"] } -hex = { workspace = true } console = "0.15.10" -indicatif = "0.17.9" -bytes = "1.9.0" -anyhow = { workspace = true } thiserror = "2.0.11" -toml = { workspace = true } -ctrlc = "3.4.5" -tokio-util = { workspace = true, features = ["rt"] } -futures = { workspace = true } -chrono = { workspace = true } serial_test = "0.5.1" -directories = { workspace = true } strip-ansi-escapes = "0.2.1" -nalgebra = { workspace = true } -stun = { workspace = true } sha2 = "0.10.8" unicode-width = "0.2.0" rand = "0.9.0" @@ -52,13 +50,6 @@ tempfile = "3.14.0" tracing-loki = "0.2.6" tracing = { workspace = true } tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } -tracing-log = "0.2.0" time = "0.3.41" -iroh = { workspace = true } -rand_v8 = { workspace = true } -rand_core_v6 = { workspace = true } -dashmap = "6.1.0" tokio-stream = { version = "0.1.17", features = ["net"] } -rust-ipfs = { workspace = true } -cid = { workspace = true } homedir = "0.3" diff --git a/crates/worker/src/cli/command.rs b/crates/worker/src/cli/command.rs index 92de379e..8f358252 100644 --- a/crates/worker/src/cli/command.rs +++ b/crates/worker/src/cli/command.rs @@ -9,13 +9,12 @@ use crate::metrics::store::MetricsStore; use crate::operations::compute_node::ComputeNodeOperations; use crate::operations::heartbeat::service::HeartbeatService; use crate::operations::provider::ProviderOperations; -use crate::p2p::P2PContext; -use crate::p2p::P2PService; use crate::services::discovery::DiscoveryService; use crate::services::discovery_updater::DiscoveryUpdater; use crate::state::system_state::SystemState; use crate::TaskHandles; use alloy::primitives::utils::format_ether; +use alloy::primitives::Address; use alloy::primitives::U256; use alloy::signers::local::PrivateKeySigner; use alloy::signers::Signer; @@ -24,8 +23,10 @@ use log::{error, info}; use shared::models::node::ComputeRequirements; use shared::models::node::Node; use shared::web3::contracts::core::builder::ContractBuilder; +use shared::web3::contracts::core::builder::Contracts; use shared::web3::contracts::structs::compute_pool::PoolStatus; use shared::web3::wallet::Wallet; +use shared::web3::wallet::WalletProvider; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; @@ -56,13 +57,17 @@ pub enum Commands { #[arg(long, default_value = "8080")] port: u16, + /// Port for libp2p service + #[arg(long, default_value = "4002")] + libp2p_port: u16, + /// External IP address for the worker to advertise #[arg(long)] external_ip: Option, /// Compute pool ID #[arg(long)] - compute_pool_id: u64, + compute_pool_id: u32, /// Dry run the command without starting the worker #[arg(long, default_value = "false")] @@ -176,7 +181,7 @@ pub enum Commands { /// Compute pool ID #[arg(long)] - compute_pool_id: u64, + compute_pool_id: u32, }, } @@ -188,6 +193,7 @@ pub async fn execute_command( match command { Commands::Run { port: _, + libp2p_port, external_ip, compute_pool_id, dry_run: _, @@ -217,7 +223,7 @@ pub async fn execute_command( let state = Arc::new(SystemState::new( state_dir_overwrite.clone(), *disable_state_storing, - Some(compute_pool_id.to_string()), + *compute_pool_id, )); let private_key_provider = if let Some(key) = private_key_provider { @@ -296,7 +302,7 @@ pub async fn execute_command( let discovery_state = state.clone(); let discovery_updater = DiscoveryUpdater::new(discovery_service.clone(), discovery_state.clone()); - let pool_id = U256::from(*compute_pool_id as u32); + let pool_id = U256::from(*compute_pool_id); let pool_info = loop { match contracts.compute_pool.get_pool_info(pool_id).await { @@ -338,7 +344,7 @@ pub async fn execute_command( .address() .to_string(), compute_specs: None, - compute_pool_id: *compute_pool_id as u32, + compute_pool_id: *compute_pool_id, worker_p2p_id: None, worker_p2p_addresses: None, }; @@ -515,7 +521,6 @@ pub async fn execute_command( .default_signer() .address() .to_string(), - state.get_p2p_seed(), *disable_host_network_mode, )); @@ -701,15 +706,6 @@ pub async fn execute_command( } }; - let p2p_context = P2PContext { - docker_service: docker_service.clone(), - heartbeat_service: heartbeat.clone(), - system_state: state.clone(), - contracts: contracts.clone(), - node_wallet: node_wallet_instance.clone(), - provider_wallet: provider_wallet_instance.clone(), - }; - let validators = match contracts.prime_network.get_validator_role().await { Ok(validators) => validators, Err(e) => { @@ -728,15 +724,19 @@ pub async fn execute_command( let mut allowed_addresses = vec![pool_info.creator, pool_info.compute_manager_key]; allowed_addresses.extend(validators); - let p2p_service = match P2PService::new( - state.worker_p2p_seed, - cancellation_token.clone(), - Some(p2p_context), + let validator_addresses = std::collections::HashSet::from_iter(allowed_addresses); + let p2p_service = match crate::p2p::Service::new( + state.get_p2p_keypair().clone(), + *libp2p_port, node_wallet_instance.clone(), - allowed_addresses, - ) - .await - { + validator_addresses, + docker_service.clone(), + heartbeat.clone(), + state.clone(), + contracts.clone(), + provider_wallet_instance.clone(), + cancellation_token.clone(), + ) { Ok(service) => service, Err(e) => { error!("❌ Failed to start P2P service: {e}"); @@ -744,23 +744,18 @@ pub async fn execute_command( } }; - if let Err(e) = p2p_service.start() { - error!("❌ Failed to start P2P listener: {e}"); - std::process::exit(1); - } - - node_config.worker_p2p_id = Some(p2p_service.node_id().to_string()); + let peer_id = p2p_service.peer_id(); + node_config.worker_p2p_id = Some(peer_id.to_string()); node_config.worker_p2p_addresses = Some( p2p_service - .listening_addresses() + .listen_addrs() .iter() .map(|addr| addr.to_string()) .collect(), ); - Console::success(&format!( - "P2P service started with ID: {}", - p2p_service.node_id() - )); + tokio::task::spawn(p2p_service.run()); + + Console::success(&format!("P2P service started with ID: {peer_id}",)); let mut attempts = 0; let max_attempts = 100; @@ -814,7 +809,7 @@ pub async fn execute_command( // Start monitoring compute node status on chain provider_ops.start_monitoring(provider_ops_cancellation); - let pool_id = state.compute_pool_id.clone().unwrap_or("0".to_string()); + let pool_id = state.get_compute_pool_id(); if let Err(err) = compute_node_ops.start_monitoring(cancellation_token.clone(), pool_id) { error!("❌ Failed to start node monitoring: {err}"); @@ -1021,7 +1016,7 @@ pub async fn execute_command( std::process::exit(1); } }; - let state = Arc::new(SystemState::new(None, true, None)); + /* Initialize dependencies - services, contracts, operations */ @@ -1035,25 +1030,25 @@ pub async fn execute_command( .build() .unwrap(); - let compute_node_ops = ComputeNodeOperations::new( - &provider_wallet_instance, - &node_wallet_instance, - contracts.clone(), - state.clone(), - ); + let provider_address = provider_wallet_instance.wallet.default_signer().address(); + let node_address = node_wallet_instance.wallet.default_signer().address(); let provider_ops = ProviderOperations::new(provider_wallet_instance.clone(), contracts.clone(), false); - let compute_node_exists = match compute_node_ops.check_compute_node_exists().await { - Ok(exists) => exists, + let compute_node_exists = match contracts + .compute_registry + .get_node(provider_address, node_address) + .await + { + Ok(_) => true, Err(e) => { Console::user_error(&format!("❌ Failed to check if compute node exists: {e}")); std::process::exit(1); } }; - let pool_id = U256::from(*compute_pool_id as u32); + let pool_id = U256::from(*compute_pool_id); if compute_node_exists { match contracts @@ -1073,7 +1068,7 @@ pub async fn execute_command( std::process::exit(1); } } - match compute_node_ops.remove_compute_node().await { + match remove_compute_node(contracts, provider_address, node_address).await { Ok(_removed_node) => { Console::success("Compute node removed"); match provider_ops.reclaim_stake(U256::from(0)).await { @@ -1099,3 +1094,17 @@ pub async fn execute_command( } } } + +async fn remove_compute_node( + contracts: Contracts, + provider_address: Address, + node_address: Address, +) -> Result> { + Console::title("🔄 Removing compute node"); + let remove_node_tx = contracts + .prime_network + .remove_compute_node(provider_address, node_address) + .await?; + Console::success(&format!("Remove node tx: {remove_node_tx:?}")); + Ok(true) +} diff --git a/crates/worker/src/docker/service.rs b/crates/worker/src/docker/service.rs index 63425e2d..da15b88e 100644 --- a/crates/worker/src/docker/service.rs +++ b/crates/worker/src/docker/service.rs @@ -24,7 +24,6 @@ pub(crate) struct DockerService { system_memory_mb: Option, task_bridge_socket_path: String, node_address: String, - p2p_seed: Option, } const TASK_PREFIX: &str = "prime-task"; @@ -39,7 +38,6 @@ impl DockerService { task_bridge_socket_path: String, storage_path: String, node_address: String, - p2p_seed: Option, disable_host_network_mode: bool, ) -> Self { let docker_manager = @@ -52,7 +50,6 @@ impl DockerService { system_memory_mb, task_bridge_socket_path, node_address, - p2p_seed, } } @@ -177,7 +174,6 @@ impl DockerService { let system_memory_mb = self.system_memory_mb; let task_bridge_socket_path = self.task_bridge_socket_path.clone(); let node_address = self.node_address.clone(); - let p2p_seed = self.p2p_seed; let handle = tokio::spawn(async move { let Some(payload) = state_clone.get_current_task().await else { return; @@ -185,11 +181,7 @@ impl DockerService { let cmd = match payload.cmd { Some(cmd_vec) => { cmd_vec.into_iter().map(|arg| { - let mut processed_arg = arg.replace("${SOCKET_PATH}", &task_bridge_socket_path); - if let Some(seed) = p2p_seed { - processed_arg = processed_arg.replace("${WORKER_P2P_SEED}", &seed.to_string()); - } - processed_arg + arg.replace("${SOCKET_PATH}", &task_bridge_socket_path) }).collect() } None => vec!["sleep".to_string(), "infinity".to_string()], @@ -199,10 +191,7 @@ impl DockerService { if let Some(env) = &payload.env_vars { // Clone env vars and replace ${SOCKET_PATH} in values for (key, value) in env.iter() { - let mut processed_value = value.replace("${SOCKET_PATH}", &task_bridge_socket_path); - if let Some(seed) = p2p_seed { - processed_value = processed_value.replace("${WORKER_P2P_SEED}", &seed.to_string()); - } + let processed_value = value.replace("${SOCKET_PATH}", &task_bridge_socket_path); env_vars.insert(key.clone(), processed_value); } } @@ -432,7 +421,6 @@ mod tests { "/tmp/com.prime.miner/metrics.sock".to_string(), "/tmp/test-storage".to_string(), Address::ZERO.to_string(), - None, false, ); let task = Task { @@ -481,7 +469,6 @@ mod tests { test_socket_path.to_string(), "/tmp/test-storage".to_string(), Address::ZERO.to_string(), - Some(12345), // p2p_seed for testing false, ); diff --git a/crates/worker/src/docker/taskbridge/bridge.rs b/crates/worker/src/docker/taskbridge/bridge.rs index 65a28f76..80b8aee7 100644 --- a/crates/worker/src/docker/taskbridge/bridge.rs +++ b/crates/worker/src/docker/taskbridge/bridge.rs @@ -473,7 +473,7 @@ mod tests { let temp_dir = tempdir()?; let socket_path = temp_dir.path().join("test.sock"); let metrics_store = Arc::new(MetricsStore::new()); - let state = Arc::new(SystemState::new(None, false, None)); + let state = Arc::new(SystemState::new(None, false, 0)); let bridge = TaskBridge::new( Some(socket_path.to_str().unwrap()), metrics_store.clone(), @@ -506,7 +506,7 @@ mod tests { let temp_dir = tempdir()?; let socket_path = temp_dir.path().join("test.sock"); let metrics_store = Arc::new(MetricsStore::new()); - let state = Arc::new(SystemState::new(None, false, None)); + let state = Arc::new(SystemState::new(None, false, 0)); let bridge = TaskBridge::new( Some(socket_path.to_str().unwrap()), metrics_store.clone(), @@ -541,7 +541,7 @@ mod tests { let temp_dir = tempdir()?; let socket_path = temp_dir.path().join("test.sock"); let metrics_store = Arc::new(MetricsStore::new()); - let state = Arc::new(SystemState::new(None, false, None)); + let state = Arc::new(SystemState::new(None, false, 0)); let bridge = TaskBridge::new( Some(socket_path.to_str().unwrap()), metrics_store.clone(), @@ -590,7 +590,7 @@ mod tests { let temp_dir = tempdir()?; let socket_path = temp_dir.path().join("test.sock"); let metrics_store = Arc::new(MetricsStore::new()); - let state = Arc::new(SystemState::new(None, false, None)); + let state = Arc::new(SystemState::new(None, false, 0)); let bridge = TaskBridge::new( Some(socket_path.to_str().unwrap()), metrics_store.clone(), @@ -639,7 +639,7 @@ mod tests { let temp_dir = tempdir()?; let socket_path = temp_dir.path().join("test.sock"); let metrics_store = Arc::new(MetricsStore::new()); - let state = Arc::new(SystemState::new(None, false, None)); + let state = Arc::new(SystemState::new(None, false, 0)); let bridge = TaskBridge::new( Some(socket_path.to_str().unwrap()), metrics_store.clone(), diff --git a/crates/worker/src/operations/compute_node.rs b/crates/worker/src/operations/compute_node.rs index 39b18c29..00f147a7 100644 --- a/crates/worker/src/operations/compute_node.rs +++ b/crates/worker/src/operations/compute_node.rs @@ -32,7 +32,7 @@ impl<'c> ComputeNodeOperations<'c> { pub(crate) fn start_monitoring( &self, cancellation_token: CancellationToken, - pool_id: String, + pool_id: u32, ) -> Result<()> { let provider_address = self.provider_wallet.wallet.default_signer().address(); let node_address = self.node_wallet.wallet.default_signer().address(); @@ -81,9 +81,8 @@ impl<'c> ComputeNodeOperations<'c> { } // Check rewards for the current compute pool - if let Ok(pool_id_u32) = pool_id.parse::() { match contracts.compute_pool.calculate_node_rewards( - U256::from(pool_id_u32), + U256::from(pool_id), node_address, ).await { Ok((claimable, locked)) => { @@ -96,9 +95,9 @@ impl<'c> ComputeNodeOperations<'c> { } } Err(e) => { - log::debug!("Failed to check rewards for pool {pool_id_u32}: {e}"); + log::debug!("Failed to check rewards for pool {pool_id}: {e}"); } - } + } first_check = false; @@ -165,23 +164,4 @@ impl<'c> ComputeNodeOperations<'c> { Console::success(&format!("Add node tx: {add_node_tx:?}")); Ok(true) } - - pub(crate) async fn remove_compute_node(&self) -> Result> { - Console::title("🔄 Removing compute node"); - - if !self.check_compute_node_exists().await? { - return Ok(false); - } - - Console::progress("Removing compute node"); - let provider_address = self.provider_wallet.wallet.default_signer().address(); - let node_address = self.node_wallet.wallet.default_signer().address(); - let remove_node_tx = self - .contracts - .prime_network - .remove_compute_node(provider_address, node_address) - .await?; - Console::success(&format!("Remove node tx: {remove_node_tx:?}")); - Ok(true) - } } diff --git a/crates/worker/src/operations/heartbeat/service.rs b/crates/worker/src/operations/heartbeat/service.rs index 0d77d783..1b002cae 100644 --- a/crates/worker/src/operations/heartbeat/service.rs +++ b/crates/worker/src/operations/heartbeat/service.rs @@ -143,7 +143,7 @@ async fn send_heartbeat( wallet: Wallet, docker_service: Arc, metrics_store: Arc, - p2p_id: Option, + p2p_id: p2p::PeerId, ) -> Result { if endpoint.is_none() { return Err(HeartbeatError::RequestFailed); @@ -176,7 +176,7 @@ async fn send_heartbeat( .to_string(), ), timestamp: Some(ts), - p2p_id, + p2p_id: Some(p2p_id.to_string()), // TODO: this should always be `Some` task_details, } } else { @@ -188,7 +188,7 @@ async fn send_heartbeat( .to_string(), ), timestamp: Some(ts), - p2p_id, + p2p_id: Some(p2p_id.to_string()), // TODO: this should always be `Some` ..Default::default() } }; diff --git a/crates/worker/src/p2p/mod.rs b/crates/worker/src/p2p/mod.rs index 9393f985..748d1d54 100644 --- a/crates/worker/src/p2p/mod.rs +++ b/crates/worker/src/p2p/mod.rs @@ -1,4 +1,495 @@ -pub(crate) mod service; +use anyhow::Context as _; +use anyhow::Result; +use futures::stream::FuturesUnordered; +use p2p::InviteRequestUrl; +use p2p::Node; +use p2p::NodeBuilder; +use p2p::PeerId; +use p2p::Response; +use p2p::{IncomingMessage, Libp2pIncomingMessage, OutgoingMessage}; +use shared::web3::contracts::core::builder::Contracts; +use shared::web3::wallet::Wallet; +use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Arc; +use std::time::SystemTime; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::RwLock; +use tokio_util::sync::CancellationToken; -pub(crate) use service::P2PContext; -pub(crate) use service::P2PService; +use crate::docker::DockerService; +use crate::operations::heartbeat::service::HeartbeatService; +use crate::state::system_state::SystemState; +use shared::web3::wallet::WalletProvider; + +pub(crate) struct Service { + node: Node, + incoming_messages: Receiver, + cancellation_token: CancellationToken, + context: Context, +} + +impl Service { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + keypair: p2p::Keypair, + port: u16, + wallet: Wallet, + validator_addresses: HashSet, + docker_service: Arc, + heartbeat_service: Arc, + system_state: Arc, + contracts: Contracts, + provider_wallet: Wallet, + cancellation_token: CancellationToken, + ) -> Result { + let (node, incoming_messages, outgoing_messages) = + build_p2p_node(keypair, port, cancellation_token.clone()) + .context("failed to build p2p node")?; + Ok(Self { + node, + incoming_messages, + cancellation_token, + context: Context::new( + wallet, + outgoing_messages, + validator_addresses, + docker_service, + heartbeat_service, + system_state, + contracts, + provider_wallet, + ), + }) + } + + pub(crate) fn peer_id(&self) -> PeerId { + self.node.peer_id() + } + + pub(crate) fn listen_addrs(&self) -> &[p2p::Multiaddr] { + self.node.listen_addrs() + } + + pub(crate) async fn run(self) { + use futures::StreamExt as _; + + let Self { + node: _, + mut incoming_messages, + cancellation_token, + context, + } = self; + + let mut message_handlers = FuturesUnordered::new(); + + loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + break; + } + Some(message) = incoming_messages.recv() => { + let context = context.clone(); + let handle = tokio::task::spawn( + handle_incoming_message(message, context) + ); + message_handlers.push(handle); + } + Some(res) = message_handlers.next() => { + if let Err(e) = res { + tracing::error!("failed to handle incoming message: {e}"); + } + } + } + } + } +} + +fn build_p2p_node( + keypair: p2p::Keypair, + port: u16, + cancellation_token: CancellationToken, +) -> Result<(Node, Receiver, Sender)> { + let (node, _, incoming_message_rx, outgoing_message_tx) = NodeBuilder::new() + .with_keypair(keypair) + .with_port(port) + .with_validator_authentication() + .with_hardware_challenge() + .with_invite() + .with_get_task_logs() + .with_restart() + .with_cancellation_token(cancellation_token) + .try_build() + .context("failed to build p2p node")?; + Ok((node, incoming_message_rx, outgoing_message_tx)) +} + +#[derive(Clone)] +struct Context { + authorized_peers: Arc>>, + wallet: Wallet, + validator_addresses: Arc>, + + // for validator authentication requests + ongoing_auth_challenges: Arc>>, // use request_id? + nonce_cache: Arc>>, + outgoing_messages: Sender, + + // for get_task_logs and restart requests + docker_service: Arc, + + // for invite requests + heartbeat_service: Arc, + system_state: Arc, + contracts: Contracts, + provider_wallet: Wallet, +} + +impl Context { + #[allow(clippy::too_many_arguments)] + fn new( + wallet: Wallet, + outgoing_messages: Sender, + validator_addresses: HashSet, + docker_service: Arc, + heartbeat_service: Arc, + system_state: Arc, + contracts: Contracts, + provider_wallet: Wallet, + ) -> Self { + Self { + authorized_peers: Arc::new(RwLock::new(HashSet::new())), + ongoing_auth_challenges: Arc::new(RwLock::new(HashMap::new())), + nonce_cache: Arc::new(RwLock::new(HashMap::new())), + wallet, + outgoing_messages, + validator_addresses: Arc::new(validator_addresses), + docker_service, + heartbeat_service, + system_state, + contracts, + provider_wallet, + } + } +} + +async fn handle_incoming_message(message: IncomingMessage, context: Context) -> Result<()> { + match message.message { + Libp2pIncomingMessage::Request { + request_id: _, + request, + channel, + } => { + tracing::debug!("received incoming request {request:?}"); + handle_incoming_request(message.peer, request, channel, context).await?; + } + Libp2pIncomingMessage::Response { + request_id: _, + response, + } => { + tracing::debug!("received incoming response {response:?}"); + handle_incoming_response(response); + } + } + Ok(()) +} + +async fn handle_incoming_request( + from: PeerId, + request: p2p::Request, + channel: p2p::ResponseChannel, + context: Context, +) -> Result<()> { + let resp = match request { + p2p::Request::ValidatorAuthentication(req) => { + tracing::debug!("handling ValidatorAuthentication request"); + match req { + p2p::ValidatorAuthenticationRequest::Initiation(req) => { + handle_validator_authentication_initiation_request(from, req, &context) + .await + .context("failed to handle ValidatorAuthenticationInitiationRequest")? + } + p2p::ValidatorAuthenticationRequest::Solution(req) => { + match handle_validator_authentication_solution_request(from, req, &context) + .await + { + Ok(()) => p2p::ValidatorAuthenticationSolutionResponse::Granted.into(), + Err(e) => { + tracing::error!( + "failed to handle ValidatorAuthenticationSolutionRequest: {e:?}" + ); + p2p::ValidatorAuthenticationSolutionResponse::Rejected.into() + } + } + } + } + } + p2p::Request::HardwareChallenge(req) => { + tracing::debug!("handling HardwareChallenge request"); + handle_hardware_challenge_request(from, req, &context) + .await + .context("failed to handle HardwareChallenge request")? + } + p2p::Request::Invite(req) => { + tracing::debug!("handling Invite request"); + match handle_invite_request(from, req, &context).await { + Ok(()) => p2p::InviteResponse::Ok.into(), + Err(e) => p2p::InviteResponse::Error(e.to_string()).into(), + } + } + p2p::Request::GetTaskLogs => { + tracing::debug!("handling GetTaskLogs request"); + handle_get_task_logs_request(from, &context).await + } + p2p::Request::Restart => { + tracing::debug!("handling Restart request"); + handle_restart_request(from, &context).await + } + p2p::Request::General(_) => { + todo!() + } + }; + + let outgoing_message = resp.into_outgoing_message(channel); + context + .outgoing_messages + .send(outgoing_message) + .await + .context("failed to send ValidatorAuthentication response")?; + + Ok(()) +} + +async fn handle_validator_authentication_initiation_request( + from: PeerId, + req: p2p::ValidatorAuthenticationInitiationRequest, + context: &Context, +) -> Result { + use rand_v8::Rng as _; + use shared::security::request_signer::sign_message; + + // generate a fresh cryptographically secure challenge message for this auth attempt + let challenge_bytes: [u8; 32] = rand_v8::rngs::OsRng.gen(); + let challenge_message = hex::encode(challenge_bytes); + let signature = sign_message(&req.message, &context.wallet) + .await + .map_err(|e| anyhow::anyhow!("failed to sign message: {e:?}"))?; + + // store the challenge message in nonce cache to prevent replay + let mut nonce_cache = context.nonce_cache.write().await; + nonce_cache.insert(challenge_message.clone(), SystemTime::now()); + + // store the current challenge for this peer + let mut ongoing_auth_challenges = context.ongoing_auth_challenges.write().await; + ongoing_auth_challenges.insert(from, challenge_message.clone()); + + Ok(p2p::ValidatorAuthenticationInitiationResponse { + message: challenge_message, + signature, + } + .into()) +} + +async fn handle_validator_authentication_solution_request( + from: PeerId, + req: p2p::ValidatorAuthenticationSolutionRequest, + context: &Context, +) -> Result<()> { + use std::str::FromStr as _; + + let mut ongoing_auth_challenges = context.ongoing_auth_challenges.write().await; + let challenge_message = ongoing_auth_challenges + .remove(&from) + .ok_or_else(|| anyhow::anyhow!("no ongoing authentication challenge for peer {from}"))?; + + let mut nonce_cache = context.nonce_cache.write().await; + if nonce_cache.remove(&challenge_message).is_none() { + anyhow::bail!("challenge message {challenge_message} not found in nonce cache"); + } + + let Ok(signature) = alloy::primitives::Signature::from_str(&req.signature) else { + anyhow::bail!("failed to parse signature from message"); + }; + + let Ok(recovered_address) = signature.recover_address_from_msg(challenge_message) else { + anyhow::bail!("failed to recover address from signature and message"); + }; + + if !context.validator_addresses.contains(&recovered_address) { + anyhow::bail!("recovered address {recovered_address} is not in the list of authorized validator addresses"); + } + + let mut authorized_peers = context.authorized_peers.write().await; + authorized_peers.insert(from); + Ok(()) +} + +async fn handle_hardware_challenge_request( + from: PeerId, + request: p2p::HardwareChallengeRequest, + context: &Context, +) -> Result { + let authorized_peers = context.authorized_peers.read().await; + if !authorized_peers.contains(&from) { + // TODO: error response variant? + anyhow::bail!("unauthorized peer {from} attempted to access HardwareChallenge request"); + } + + let challenge_response = p2p::calc_matrix(&request.challenge); + let response = p2p::HardwareChallengeResponse { + response: challenge_response, + timestamp: SystemTime::now(), + }; + Ok(response.into()) +} + +async fn handle_get_task_logs_request(from: PeerId, context: &Context) -> Response { + let authorized_peers = context.authorized_peers.read().await; + if !authorized_peers.contains(&from) { + return p2p::GetTaskLogsResponse::Error("unauthorized".to_string()).into(); + } + + match context.docker_service.get_logs().await { + Ok(logs) => p2p::GetTaskLogsResponse::Ok(logs).into(), + Err(e) => p2p::GetTaskLogsResponse::Error(format!("failed to get task logs: {e:?}")).into(), + } +} + +async fn handle_restart_request(from: PeerId, context: &Context) -> Response { + let authorized_peers = context.authorized_peers.read().await; + if !authorized_peers.contains(&from) { + return p2p::RestartResponse::Error("unauthorized".to_string()).into(); + } + + match context.docker_service.restart_task().await { + Ok(()) => p2p::RestartResponse::Ok.into(), + Err(e) => p2p::RestartResponse::Error(format!("failed to restart task: {e:?}")).into(), + } +} + +fn handle_incoming_response(response: p2p::Response) { + // critical developer error if any of these happen, could panic here + match response { + p2p::Response::ValidatorAuthentication(_) => { + tracing::error!("worker should never receive ValidatorAuthentication responses"); + } + p2p::Response::HardwareChallenge(_) => { + tracing::error!("worker should never receive HardwareChallenge responses"); + } + p2p::Response::Invite(_) => { + tracing::error!("worker should never receive Invite responses"); + } + p2p::Response::GetTaskLogs(_) => { + tracing::error!("worker should never receive GetTaskLogs responses"); + } + p2p::Response::Restart(_) => { + tracing::error!("worker should never receive Restart responses"); + } + p2p::Response::General(_) => { + todo!() + } + } +} + +async fn handle_invite_request( + from: PeerId, + req: p2p::InviteRequest, + context: &Context, +) -> Result<()> { + use crate::console::Console; + use shared::web3::contracts::helpers::utils::retry_call; + use shared::web3::contracts::structs::compute_pool::PoolStatus; + + let authorized_peers = context.authorized_peers.read().await; + if !authorized_peers.contains(&from) { + return Err(anyhow::anyhow!( + "unauthorized peer {from} attempted to send invite" + )); + } + + if context.system_state.is_running().await { + anyhow::bail!("heartbeat is currently running and in a compute pool"); + } + + if req.pool_id != context.system_state.get_compute_pool_id() { + anyhow::bail!( + "pool ID mismatch: expected {}, got {}", + context.system_state.get_compute_pool_id(), + req.pool_id + ); + } + + let invite_bytes = hex::decode(&req.invite).context("failed to decode invite hex")?; + + if invite_bytes.len() < 65 { + anyhow::bail!("invite data is too short, expected at least 65 bytes"); + } + + let contracts = &context.contracts; + let pool_id = alloy::primitives::U256::from(req.pool_id); + + let bytes_array: [u8; 65] = match invite_bytes[..65].try_into() { + Ok(array) => array, + Err(_) => { + anyhow::bail!("failed to convert invite bytes to 65 byte array"); + } + }; + + let provider_address = context.provider_wallet.wallet.default_signer().address(); + + let pool_info = match contracts.compute_pool.get_pool_info(pool_id).await { + Ok(info) => info, + Err(err) => { + anyhow::bail!("failed to get pool info: {err:?}"); + } + }; + + if let PoolStatus::PENDING = pool_info.status { + anyhow::bail!("invalid invite; pool is pending"); + } + + let node_address = vec![context.wallet.wallet.default_signer().address()]; + let signatures = vec![alloy::primitives::FixedBytes::from(&bytes_array)]; + let call = contracts + .compute_pool + .build_join_compute_pool_call( + pool_id, + provider_address, + node_address, + vec![req.nonce], + vec![req.expiration], + signatures, + ) + .map_err(|e| anyhow::anyhow!("failed to build join compute pool call: {e:?}"))?; + + let provider = &context.provider_wallet.provider; + match retry_call(call, 3, provider.clone(), None).await { + Ok(result) => { + Console::section("WORKER JOINED COMPUTE POOL"); + Console::success(&format!( + "Successfully registered on chain with tx: {result}" + )); + Console::info( + "Status", + "Worker is now part of the compute pool and ready to receive tasks", + ); + } + Err(err) => { + anyhow::bail!("failed to join compute pool: {err:?}"); + } + } + + let heartbeat_endpoint = match req.url { + InviteRequestUrl::MasterIpPort(ip, port) => { + format!("http://{ip}:{port}/heartbeat") + } + InviteRequestUrl::MasterUrl(url) => format!("{url}/heartbeat"), + }; + + context + .heartbeat_service + .start(heartbeat_endpoint) + .await + .context("failed to start heartbeat service")?; + Ok(()) +} diff --git a/crates/worker/src/p2p/service.rs b/crates/worker/src/p2p/service.rs deleted file mode 100644 index 51a68405..00000000 --- a/crates/worker/src/p2p/service.rs +++ /dev/null @@ -1,736 +0,0 @@ -use crate::console::Console; -use crate::docker::DockerService; -use crate::operations::heartbeat::service::HeartbeatService; -use crate::state::system_state::SystemState; -use alloy::primitives::{Address, FixedBytes, U256}; -use anyhow::Result; -use dashmap::DashMap; -use iroh::endpoint::Incoming; -use iroh::{Endpoint, RelayMode, SecretKey}; -use lazy_static::lazy_static; -use log::{debug, error, info, warn}; -use rand_v8::Rng; -use shared::models::challenge::calc_matrix; -use shared::models::invite::InviteRequest; -use shared::p2p::messages::MAX_MESSAGE_SIZE; -use shared::p2p::messages::{P2PMessage, P2PRequest, P2PResponse}; -use shared::p2p::protocol::PRIME_P2P_PROTOCOL; -use shared::security::request_signer::sign_message; -use shared::web3::contracts::core::builder::Contracts; -use shared::web3::contracts::helpers::utils::retry_call; -use shared::web3::contracts::structs::compute_pool::PoolStatus; -use shared::web3::wallet::{Wallet, WalletProvider}; -use std::str::FromStr; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; -use tokio_util::sync::CancellationToken; - -lazy_static! { - static ref NONCE_CACHE: DashMap = DashMap::new(); -} - -#[derive(Clone)] -pub(crate) struct P2PContext { - pub docker_service: Arc, - pub heartbeat_service: Arc, - pub system_state: Arc, - pub contracts: Contracts, - pub node_wallet: Wallet, - pub provider_wallet: Wallet, -} - -#[derive(Clone)] -pub(crate) struct P2PService { - endpoint: Endpoint, - secret_key: SecretKey, - node_id: String, - listening_addrs: Vec, - cancellation_token: CancellationToken, - context: Option, - allowed_addresses: Vec
, - wallet: Wallet, -} - -enum EndpointLoopResult { - Shutdown, - EndpointClosed, -} - -impl P2PService { - /// Create a new P2P service with a unique worker identity - pub(crate) async fn new( - worker_p2p_seed: Option, - cancellation_token: CancellationToken, - context: Option, - wallet: Wallet, - allowed_addresses: Vec
, - ) -> Result { - // Generate or derive the secret key for this worker - let secret_key = if let Some(seed) = worker_p2p_seed { - // Derive from seed for deterministic identity - let mut seed_bytes = [0u8; 32]; - seed_bytes[..8].copy_from_slice(&seed.to_le_bytes()); - SecretKey::from_bytes(&seed_bytes) - } else { - let mut rng = rand_v8::thread_rng(); - SecretKey::generate(&mut rng) - }; - - let node_id = secret_key.public().to_string(); - info!("Starting P2P service with node ID: {node_id}"); - - // Create the endpoint - let endpoint = Endpoint::builder() - .secret_key(secret_key.clone()) - .alpns(vec![PRIME_P2P_PROTOCOL.to_vec()]) - .discovery_n0() - .relay_mode(RelayMode::Default) - .bind() - .await?; - - // Get listening addresses - let node_addr = endpoint.node_addr().await?; - let listening_addrs = node_addr - .direct_addresses - .iter() - .map(|addr| addr.to_string()) - .collect::>(); - - info!("P2P service listening on: {listening_addrs:?}"); - - Ok(Self { - endpoint, - secret_key, - node_id, - listening_addrs, - cancellation_token, - context, - allowed_addresses, - wallet, - }) - } - - /// Get the P2P node ID - pub(crate) fn node_id(&self) -> &str { - &self.node_id - } - - /// Get the listening addresses - pub(crate) fn listening_addresses(&self) -> &[String] { - &self.listening_addrs - } - - /// Recreate the endpoint with the same identity - async fn recreate_endpoint(&self) -> Result { - info!("Recreating P2P endpoint with node ID: {}", self.node_id); - - let endpoint = Endpoint::builder() - .secret_key(self.secret_key.clone()) - .alpns(vec![PRIME_P2P_PROTOCOL.to_vec()]) - .discovery_n0() - .relay_mode(RelayMode::Default) - .bind() - .await?; - - let node_addr = endpoint.node_addr().await?; - let listening_addrs = node_addr - .direct_addresses - .iter() - .map(|addr| addr.to_string()) - .collect::>(); - - info!("P2P endpoint recreated, listening on: {listening_addrs:?}"); - Ok(endpoint) - } - /// Start accepting incoming connections with automatic recovery - pub(crate) fn start(&self) -> Result<()> { - let service = Arc::new(self.clone()); - let cancellation_token = self.cancellation_token.clone(); - - tokio::spawn(async move { - service.run_with_recovery(cancellation_token).await; - }); - - Ok(()) - } - - /// Run the P2P service with automatic endpoint recovery - async fn run_with_recovery(&self, cancellation_token: CancellationToken) { - let mut endpoint = self.endpoint.clone(); - let mut retry_delay = Duration::from_secs(1); - const MAX_RETRY_DELAY: Duration = Duration::from_secs(60); - - loop { - tokio::select! { - _ = cancellation_token.cancelled() => { - info!("P2P service shutting down"); - break; - } - result = self.run_endpoint_loop(&endpoint, &cancellation_token) => { - match result { - EndpointLoopResult::Shutdown => break, - EndpointLoopResult::EndpointClosed => { - warn!("P2P endpoint closed, attempting recovery in {retry_delay:?}"); - - tokio::select! { - _ = cancellation_token.cancelled() => break, - _ = tokio::time::sleep(retry_delay) => {} - } - - match self.recreate_endpoint().await { - Ok(new_endpoint) => { - info!("P2P endpoint successfully recovered"); - endpoint = new_endpoint; - retry_delay = Duration::from_secs(1); - } - Err(e) => { - error!("Failed to recreate P2P endpoint: {e}"); - retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY); - } - } - } - } - } - } - } - } - - /// Run the main endpoint acceptance loop - async fn run_endpoint_loop( - &self, - endpoint: &Endpoint, - cancellation_token: &CancellationToken, - ) -> EndpointLoopResult { - let context = self.context.clone(); - let allowed_addresses = self.allowed_addresses.clone(); - let wallet = self.wallet.clone(); - - loop { - tokio::select! { - _ = cancellation_token.cancelled() => { - return EndpointLoopResult::Shutdown; - } - incoming = endpoint.accept() => { - if let Some(incoming) = incoming { - tokio::spawn(Self::handle_connection(incoming, context.clone(), allowed_addresses.clone(), wallet.clone())); - } else { - return EndpointLoopResult::EndpointClosed; - } - } - } - } - } - - /// Handle an incoming connection - async fn handle_connection( - incoming: Incoming, - context: Option, - allowed_addresses: Vec
, - wallet: Wallet, - ) { - match incoming.await { - Ok(connection) => { - match connection.accept_bi().await { - Ok((send, recv)) => { - if let Err(e) = - Self::handle_stream(send, recv, context, allowed_addresses, wallet) - .await - { - error!("Error handling stream: {e}"); - } - // Wait a bit before closing to ensure client has processed response - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - } - Err(e) => { - error!("Failed to accept bi-stream: {e}"); - connection.close(1u32.into(), b"stream error"); - } - } - } - Err(e) => { - // Only log as debug for protocol mismatches, which are expected - if e.to_string() - .contains("peer doesn't support any known protocol") - { - debug!("Connection attempt with unsupported protocol: {e}"); - } else { - error!("Failed to accept connection: {e}"); - } - } - } - } - - /// Read a message from the stream - async fn read_message(recv: &mut iroh::endpoint::RecvStream) -> Result { - // Read message length - let mut msg_len_bytes = [0u8; 4]; - match recv.read_exact(&mut msg_len_bytes).await { - Ok(_) => {} - Err(e) => { - debug!("Stream read ended: {e}"); - return Err(anyhow::anyhow!("Stream closed")); - } - } - let msg_len = u32::from_be_bytes(msg_len_bytes) as usize; - - // Enforce maximum message size - if msg_len > MAX_MESSAGE_SIZE { - error!("Message size {msg_len} exceeds maximum allowed size {MAX_MESSAGE_SIZE}"); - return Err(anyhow::anyhow!("Message too large")); - } - - let mut msg_bytes = vec![0u8; msg_len]; - recv.read_exact(&mut msg_bytes).await?; - - let request: P2PRequest = serde_json::from_slice(&msg_bytes) - .map_err(|e| anyhow::anyhow!("Failed to deserialize P2P request: {}", e))?; - - debug!("Received P2P request: {request:?}"); - Ok(request) - } - - async fn write_response( - send: &mut iroh::endpoint::SendStream, - response: P2PResponse, - ) -> Result<()> { - let response_bytes = serde_json::to_vec(&response)?; - - // Check response size before sending - if response_bytes.len() > MAX_MESSAGE_SIZE { - error!( - "Response size {} exceeds maximum allowed size {}", - response_bytes.len(), - MAX_MESSAGE_SIZE - ); - return Err(anyhow::anyhow!("Response too large")); - } - - send.write_all(&(response_bytes.len() as u32).to_be_bytes()) - .await?; - send.write_all(&response_bytes).await?; - Ok(()) - } - - /// Handle a bidirectional stream - async fn handle_stream( - mut send: iroh::endpoint::SendStream, - mut recv: iroh::endpoint::RecvStream, - context: Option, - allowed_addresses: Vec
, - wallet: Wallet, - ) -> Result<()> { - // Handle multiple messages in sequence - let mut is_authorized = false; - let mut current_challenge: Option = None; - - loop { - let Ok(request) = Self::read_message(&mut recv).await else { - break; - }; - - // Handle the request - let response = match request.message { - P2PMessage::Ping { nonce, .. } => { - info!("Received ping with nonce: {nonce}"); - P2PResponse::new( - request.id, - P2PMessage::Pong { - timestamp: SystemTime::now(), - nonce, - }, - ) - } - P2PMessage::RequestAuthChallenge { message } => { - // Generate a fresh cryptographically secure challenge message for this auth attempt - let challenge_bytes: [u8; 32] = rand_v8::rngs::OsRng.gen(); - let challenge_message = hex::encode(challenge_bytes); - - debug!("Received request auth challenge"); - let signature = match sign_message(&message, &wallet).await { - Ok(signature) => signature, - Err(e) => { - error!("Failed to sign message: {e}"); - return Err(anyhow::anyhow!("Failed to sign message: {}", e)); - } - }; - - // Store the challenge message in nonce cache to prevent replay - NONCE_CACHE.insert(challenge_message.clone(), SystemTime::now()); - - // Store the current challenge for this connection - current_challenge = Some(challenge_message.clone()); - - P2PResponse::new( - request.id, - P2PMessage::AuthChallenge { - message: challenge_message, - signed_message: signature, - }, - ) - } - P2PMessage::AuthSolution { signed_message } => { - // Get the challenge message for this connection - debug!("Received auth solution"); - let Some(challenge_message) = ¤t_challenge else { - warn!("No active challenge for auth solution"); - let response = P2PResponse::new(request.id, P2PMessage::AuthRejected {}); - Self::write_response(&mut send, response).await?; - continue; - }; - - // Check if challenge message has been used before (replay attack prevention) - if !NONCE_CACHE.contains_key(challenge_message) { - warn!("Challenge message not found or expired: {challenge_message}"); - let response = P2PResponse::new(request.id, P2PMessage::AuthRejected {}); - Self::write_response(&mut send, response).await?; - continue; - } - - // Clean up old nonces (older than 5 minutes) - let cutoff_time = SystemTime::now() - Duration::from_secs(300); - NONCE_CACHE.retain(|_, &mut timestamp| timestamp > cutoff_time); - - // Parse the signature - let Ok(parsed_signature) = - alloy::primitives::Signature::from_str(&signed_message) - else { - // Handle signature parsing error - let response = P2PResponse::new(request.id, P2PMessage::AuthRejected {}); - Self::write_response(&mut send, response).await?; - continue; - }; - - // Recover address from the challenge message that the client signed - let Ok(recovered_address) = - parsed_signature.recover_address_from_msg(challenge_message) - else { - // Handle address recovery error - let response = P2PResponse::new(request.id, P2PMessage::AuthRejected {}); - Self::write_response(&mut send, response).await?; - continue; - }; - - // Check if the recovered address is in allowed addresses - NONCE_CACHE.remove(challenge_message); - current_challenge = None; - if allowed_addresses.contains(&recovered_address) { - is_authorized = true; - P2PResponse::new(request.id, P2PMessage::AuthGranted {}) - } else { - P2PResponse::new(request.id, P2PMessage::AuthRejected {}) - } - } - P2PMessage::HardwareChallenge { challenge, .. } if is_authorized => { - info!("Received hardware challenge"); - let challenge_response = calc_matrix(&challenge); - P2PResponse::new( - request.id, - P2PMessage::HardwareChallengeResponse { - response: challenge_response, - timestamp: SystemTime::now(), - }, - ) - } - P2PMessage::Invite(invite) if is_authorized => { - if let Some(context) = &context { - let (status, error) = Self::handle_invite(invite, context).await; - P2PResponse::new(request.id, P2PMessage::InviteResponse { status, error }) - } else { - P2PResponse::new( - request.id, - P2PMessage::InviteResponse { - status: "error".to_string(), - error: Some("No context".to_string()), - }, - ) - } - } - P2PMessage::GetTaskLogs if is_authorized => { - if let Some(context) = &context { - let logs = context.docker_service.get_logs().await; - let response_logs = logs - .map(|log_string| vec![log_string]) - .map_err(|e| e.to_string()); - P2PResponse::new( - request.id, - P2PMessage::GetTaskLogsResponse { - logs: response_logs, - }, - ) - } else { - P2PResponse::new( - request.id, - P2PMessage::GetTaskLogsResponse { logs: Ok(vec![]) }, - ) - } - } - P2PMessage::RestartTask if is_authorized => { - if let Some(context) = &context { - let result = context.docker_service.restart_task().await; - let response_result = result.map_err(|e| e.to_string()); - P2PResponse::new( - request.id, - P2PMessage::RestartTaskResponse { - result: response_result, - }, - ) - } else { - P2PResponse::new( - request.id, - P2PMessage::RestartTaskResponse { result: Ok(()) }, - ) - } - } - _ => { - warn!("Unexpected message type"); - continue; - } - }; - - // Send response - Self::write_response(&mut send, response).await?; - } - - Ok(()) - } - - async fn handle_invite( - invite: InviteRequest, - context: &P2PContext, - ) -> (String, Option) { - if context.system_state.is_running().await { - return ( - "error".to_string(), - Some("Heartbeat is currently running and in a compute pool".to_string()), - ); - } - if let Some(pool_id) = context.system_state.compute_pool_id.clone() { - if invite.pool_id.to_string() != pool_id { - return ("error".to_string(), Some("Invalid pool ID".to_string())); - } - } - - let invite_bytes = match hex::decode(&invite.invite) { - Ok(bytes) => bytes, - Err(err) => { - error!("Failed to decode invite hex string: {err:?}"); - return ( - "error".to_string(), - Some("Invalid invite format".to_string()), - ); - } - }; - - if invite_bytes.len() < 65 { - return ( - "error".to_string(), - Some("Invite data is too short".to_string()), - ); - } - - let contracts = &context.contracts; - let wallet = &context.node_wallet; - let pool_id = U256::from(invite.pool_id); - - let bytes_array: [u8; 65] = match invite_bytes[..65].try_into() { - Ok(array) => array, - Err(_) => { - error!("Failed to convert invite bytes to fixed-size array"); - return ( - "error".to_string(), - Some("Invalid invite signature format".to_string()), - ); - } - }; - - let provider_address = context.provider_wallet.wallet.default_signer().address(); - - let pool_info = match contracts.compute_pool.get_pool_info(pool_id).await { - Ok(info) => info, - Err(err) => { - error!("Failed to get pool info: {err:?}"); - return ( - "error".to_string(), - Some("Failed to get pool information".to_string()), - ); - } - }; - - if let PoolStatus::PENDING = pool_info.status { - Console::user_error("Pool is pending - Invite is invalid"); - return ( - "error".to_string(), - Some("Pool is pending - Invite is invalid".to_string()), - ); - } - - let node_address = vec![wallet.wallet.default_signer().address()]; - let signatures = vec![FixedBytes::from(&bytes_array)]; - let nonces = vec![invite.nonce]; - let expirations = vec![invite.expiration]; - let call = match contracts.compute_pool.build_join_compute_pool_call( - pool_id, - provider_address, - node_address, - nonces, - expirations, - signatures, - ) { - Ok(call) => call, - Err(err) => { - error!("Failed to build join compute pool call: {err:?}"); - return ( - "error".to_string(), - Some("Failed to build join compute pool call".to_string()), - ); - } - }; - let provider = &context.provider_wallet.provider; - match retry_call(call, 3, provider.clone(), None).await { - Ok(result) => { - Console::section("WORKER JOINED COMPUTE POOL"); - Console::success(&format!( - "Successfully registered on chain with tx: {result}" - )); - Console::info( - "Status", - "Worker is now part of the compute pool and ready to receive tasks", - ); - } - Err(err) => { - error!("Failed to join compute pool: {err:?}"); - return ( - "error".to_string(), - Some(format!("Failed to join compute pool: {err}")), - ); - } - } - let endpoint = if let Some(url) = &invite.master_url { - format!("{url}/heartbeat") - } else { - match (&invite.master_ip, &invite.master_port) { - (Some(ip), Some(port)) => format!("http://{ip}:{port}/heartbeat"), - _ => { - error!("Missing master IP or port in invite request"); - return ( - "error".to_string(), - Some("Missing master IP or port".to_string()), - ); - } - } - }; - - if let Err(err) = context.heartbeat_service.start(endpoint).await { - error!("Failed to start heartbeat service: {err:?}"); - return ( - "error".to_string(), - Some("Failed to start heartbeat service".to_string()), - ); - } - - ("ok".to_string(), None) - } -} - -#[cfg(test)] -mod tests { - use rand_v8::Rng; - use serial_test::serial; - use shared::p2p::P2PClient; - use url::Url; - - use super::*; - - async fn setup_test_service( - include_addresses: bool, - ) -> (P2PService, P2PClient, Address, Address) { - let validator_wallet = shared::web3::wallet::Wallet::new( - "0000000000000000000000000000000000000000000000000000000000000001", - Url::parse("https://mainnet.infura.io/v3/9aa3d95b3bc440fa88ea12eaa4456161").unwrap(), - ) - .unwrap(); - let worker_wallet = shared::web3::wallet::Wallet::new( - "0000000000000000000000000000000000000000000000000000000000000002", - Url::parse("https://mainnet.infura.io/v3/9aa3d95b3bc440fa88ea12eaa4456161").unwrap(), - ) - .unwrap(); - let validator_wallet_address = validator_wallet.wallet.default_signer().address(); - let worker_wallet_address = worker_wallet.wallet.default_signer().address(); - let service = P2PService::new( - None, - CancellationToken::new(), - None, - worker_wallet, - if include_addresses { - vec![validator_wallet_address] - } else { - vec![] - }, - ) - .await - .unwrap(); - let client = P2PClient::new(validator_wallet.clone()).await.unwrap(); - ( - service, - client, - validator_wallet_address, - worker_wallet_address, - ) - } - - #[tokio::test] - #[serial] - async fn test_ping() { - let (service, client, _, worker_wallet_address) = setup_test_service(true).await; - let node_id = service.node_id().to_string(); - let addresses = service.listening_addresses().to_vec(); - let random_nonce = rand_v8::thread_rng().gen::(); - - tokio::spawn(async move { - service.start().unwrap(); - }); - - let ping = P2PMessage::Ping { - nonce: random_nonce, - timestamp: SystemTime::now(), - }; - - let response = client - .send_request(&node_id, &addresses, worker_wallet_address, ping, 20) - .await - .unwrap(); - - let response_nonce = match response { - P2PMessage::Pong { nonce, .. } => nonce, - _ => panic!("Expected Pong message"), - }; - assert_eq!(response_nonce, random_nonce); - } - #[tokio::test] - #[serial] - async fn test_auth_error() { - let (service, client, _, worker_wallet_address) = setup_test_service(false).await; - let node_id = service.node_id().to_string(); - let addresses = service.listening_addresses().to_vec(); - - tokio::spawn(async move { - service.start().unwrap(); - }); - - let ping = P2PMessage::Ping { - nonce: rand_v8::thread_rng().gen::(), - timestamp: SystemTime::now(), - }; - - // Since we set include_addresses to false, the client's wallet address - // is not in the allowed_addresses list, so we expect auth to be rejected - let result = client - .send_request(&node_id, &addresses, worker_wallet_address, ping, 20) - .await; - - assert!( - result.is_err(), - "Expected auth to be rejected but request succeeded" - ); - } -} diff --git a/crates/worker/src/state/system_state.rs b/crates/worker/src/state/system_state.rs index fd8f0a3a..bed32693 100644 --- a/crates/worker/src/state/system_state.rs +++ b/crates/worker/src/state/system_state.rs @@ -2,7 +2,6 @@ use anyhow::Result; use directories::ProjectDirs; use log::debug; use log::error; -use log::warn; use serde::{Deserialize, Serialize}; use std::fs; use std::path::Path; @@ -10,9 +9,6 @@ use std::path::PathBuf; use std::sync::Arc; use tokio::sync::RwLock; -use crate::utils::p2p::generate_iroh_node_id_from_seed; -use crate::utils::p2p::generate_random_seed; - const STATE_FILENAME: &str = "heartbeat_state.toml"; fn get_default_state_dir() -> Option { @@ -23,8 +19,29 @@ fn get_default_state_dir() -> Option { #[derive(Debug, Clone, Serialize, Deserialize)] struct PersistedSystemState { endpoint: Option, - p2p_seed: Option, - worker_p2p_seed: Option, + #[serde( + serialize_with = "serialize_keypair", + deserialize_with = "deserialize_keypair" + )] + p2p_keypair: p2p::Keypair, +} + +fn serialize_keypair(keypair: &p2p::Keypair, serializer: S) -> Result +where + S: serde::Serializer, +{ + let serialized = keypair + .to_protobuf_encoding() + .map_err(serde::ser::Error::custom)?; + serializer.serialize_bytes(&serialized) +} + +fn deserialize_keypair<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let serialized: Vec = Deserialize::deserialize(deserializer)?; + p2p::Keypair::from_protobuf_encoding(&serialized).map_err(serde::de::Error::custom) } #[derive(Debug, Clone)] @@ -34,18 +51,15 @@ pub(crate) struct SystemState { endpoint: Arc>>, state_dir_overwrite: Option, disable_state_storing: bool, - pub compute_pool_id: Option, - - pub worker_p2p_seed: Option, - pub p2p_id: Option, - pub p2p_seed: Option, + compute_pool_id: u32, + p2p_keypair: p2p::Keypair, } impl SystemState { pub(crate) fn new( state_dir: Option, disable_state_storing: bool, - compute_pool_id: Option, + compute_pool_id: u32, ) -> Self { let default_state_dir = get_default_state_dir(); debug!("Default state dir: {default_state_dir:?}"); @@ -53,9 +67,10 @@ impl SystemState { .map(PathBuf::from) .or_else(|| default_state_dir.map(PathBuf::from)); debug!("State path: {state_path:?}"); + let mut endpoint = None; - let mut p2p_seed: Option = None; - let mut worker_p2p_seed: Option = None; + let mut p2p_keypair = None; + // Try to load state, log info if creating new file if !disable_state_storing { if let Some(path) = &state_path { @@ -67,31 +82,15 @@ impl SystemState { } else if let Ok(Some(loaded_state)) = SystemState::load_state(path) { debug!("Loaded previous state from {state_file:?}"); endpoint = loaded_state.endpoint; - p2p_seed = loaded_state.p2p_seed; - worker_p2p_seed = loaded_state.worker_p2p_seed; + p2p_keypair = Some(loaded_state.p2p_keypair); } else { debug!("Failed to load state from {state_file:?}"); } } } - if p2p_seed.is_none() { - let seed = generate_random_seed(); - p2p_seed = Some(seed); - } - // Generate p2p_id from seed if available - - let p2p_id: Option = - p2p_seed.and_then(|seed| match generate_iroh_node_id_from_seed(seed) { - Ok(id) => Some(id), - Err(_) => { - warn!("Failed to generate p2p_id from seed"); - None - } - }); - if worker_p2p_seed.is_none() { - let seed = generate_random_seed(); - worker_p2p_seed = Some(seed); + if p2p_keypair.is_none() { + p2p_keypair = Some(p2p::Keypair::generate_ed25519()); } Self { @@ -101,44 +100,34 @@ impl SystemState { state_dir_overwrite: state_path.clone(), disable_state_storing, compute_pool_id, - p2p_seed, - p2p_id, - worker_p2p_seed, + p2p_keypair: p2p_keypair.expect("p2p keypair must be Some at this point"), } } + fn save_state(&self, heartbeat_endpoint: Option) -> Result<()> { if !self.disable_state_storing { debug!("Saving state"); if let Some(state_dir) = &self.state_dir_overwrite { - // Get values without block_on - debug!("Saving p2p_seed: {:?}", self.p2p_seed); - - // Ensure p2p_seed is valid before creating state - if let Some(seed) = self.p2p_seed { - let state = PersistedSystemState { - endpoint: heartbeat_endpoint, - p2p_seed: Some(seed), - worker_p2p_seed: self.worker_p2p_seed, - }; - - debug!("state: {state:?}"); - - fs::create_dir_all(state_dir)?; - let state_path = state_dir.join(STATE_FILENAME); - - // Use JSON serialization instead of TOML - match serde_json::to_string_pretty(&state) { - Ok(json_string) => { - fs::write(&state_path, json_string)?; - debug!("Saved state to {state_path:?}"); - } - Err(e) => { - error!("Failed to serialize state: {e}"); - return Err(anyhow::anyhow!("Failed to serialize state: {}", e)); - } + let state = PersistedSystemState { + endpoint: heartbeat_endpoint, + p2p_keypair: self.p2p_keypair.clone(), + }; + + debug!("state: {state:?}"); + + fs::create_dir_all(state_dir)?; + let state_path = state_dir.join(STATE_FILENAME); + + // Use JSON serialization instead of TOML + match serde_json::to_string_pretty(&state) { + Ok(json_string) => { + fs::write(&state_path, json_string)?; + debug!("Saved state to {state_path:?}"); + } + Err(e) => { + error!("Failed to serialize state: {e}"); + return Err(anyhow::anyhow!("Failed to serialize state: {}", e)); } - } else { - warn!("Cannot save state: p2p_seed is None"); } } } @@ -160,12 +149,16 @@ impl SystemState { Ok(None) } - pub(crate) fn get_p2p_seed(&self) -> Option { - self.p2p_seed + pub(crate) fn get_compute_pool_id(&self) -> u32 { + self.compute_pool_id + } + + pub(crate) fn get_p2p_keypair(&self) -> &p2p::Keypair { + &self.p2p_keypair } - pub(crate) fn get_p2p_id(&self) -> Option { - self.p2p_id.clone() + pub(crate) fn get_p2p_id(&self) -> p2p::PeerId { + self.p2p_keypair.public().to_peer_id() } pub(crate) async fn update_last_heartbeat(&self) { @@ -238,9 +231,8 @@ mod tests { let state = SystemState::new( Some(temp_dir.path().to_string_lossy().to_string()), false, - None, + 0, ); - assert!(state.p2p_id.is_some()); let _ = state .set_running(true, Some("http://localhost:8080/heartbeat".to_string())) .await; @@ -266,7 +258,7 @@ mod tests { let state = SystemState::new( Some(temp_dir.path().to_string_lossy().to_string()), false, - None, + 0, ); assert!(!(state.is_running().await)); assert_eq!(state.get_heartbeat_endpoint().await, None); @@ -285,7 +277,7 @@ mod tests { let state = SystemState::new( Some(temp_dir.path().to_string_lossy().to_string()), false, - None, + 0, ); assert_eq!( state.get_heartbeat_endpoint().await, diff --git a/crates/worker/src/utils/mod.rs b/crates/worker/src/utils/mod.rs index 210f1e35..6a79dd07 100644 --- a/crates/worker/src/utils/mod.rs +++ b/crates/worker/src/utils/mod.rs @@ -1,2 +1 @@ pub(crate) mod logging; -pub(crate) mod p2p; diff --git a/crates/worker/src/utils/p2p.rs b/crates/worker/src/utils/p2p.rs deleted file mode 100644 index ef07b28c..00000000 --- a/crates/worker/src/utils/p2p.rs +++ /dev/null @@ -1,60 +0,0 @@ -use iroh::SecretKey; -use rand_v8::Rng; -use rand_v8::{rngs::StdRng, SeedableRng}; -use std::error::Error; - -/// Generate a random seed -pub(crate) fn generate_random_seed() -> u64 { - rand_v8::thread_rng().gen() -} - -// Generate an Iroh node ID from a seed -pub(crate) fn generate_iroh_node_id_from_seed(seed: u64) -> Result> { - // Create a deterministic RNG from the seed - let mut rng = StdRng::seed_from_u64(seed); - - // Generate the secret key using Iroh's method - // This matches exactly how it's done in your Node implementation - let secret_key = SecretKey::generate(&mut rng); - - // Get the node ID (public key) as a string - let node_id = secret_key.public().to_string(); - - Ok(node_id) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_generate_random_seed() { - let seed1 = generate_random_seed(); - let seed2 = generate_random_seed(); - - assert_ne!(seed1, seed2); - } - - #[test] - fn test_known_generation() { - let seed: u32 = 848364385; - let result = generate_iroh_node_id_from_seed(seed as u64).unwrap(); - assert_eq!( - result, - "6ba970180efbd83909282ac741085431f54aa516e1783852978bd529a400d0e9" - ); - assert_eq!(result.len(), 64); - } - - #[test] - fn test_deterministic_generation() { - // Same seed should generate same node_id - let seed = generate_random_seed(); - println!("seed: {}", seed); - let result1 = generate_iroh_node_id_from_seed(seed).unwrap(); - let result2 = generate_iroh_node_id_from_seed(seed).unwrap(); - println!("result1: {}", result1); - - assert_eq!(result1, result2); - } -}