diff --git a/Cargo.lock b/Cargo.lock index c16f0570..99309465 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6211,8 +6211,6 @@ dependencies = [ "actix-web-prometheus", "alloy", "anyhow", - "async-trait", - "base64 0.22.1", "chrono", "clap", "env_logger", @@ -6220,11 +6218,10 @@ dependencies = [ "google-cloud-auth 0.18.0", "google-cloud-storage", "hex", - "iroh", "log", "mockito", + "p2p", "prometheus 0.14.0", - "rand 0.8.5", "rand 0.9.1", "redis", "redis-test", @@ -6233,6 +6230,7 @@ dependencies = [ "serde_json", "shared", "tokio", + "tokio-util", "url", "utoipa", "utoipa-swagger-ui", @@ -8233,12 +8231,14 @@ dependencies = [ "base64 0.22.1", "chrono", "dashmap", + "futures", "futures-util", "google-cloud-storage", "hex", "iroh", "log", "nalgebra", + "p2p", "rand 0.8.5", "rand 0.9.1", "redis", @@ -8247,6 +8247,7 @@ dependencies = [ "serde_json", "subtle", "tokio", + "tokio-util", "url", "utoipa", "uuid", @@ -9448,13 +9449,11 @@ dependencies = [ "env_logger", "futures", "hex", - "iroh", "lazy_static", "log", "mockito", - "nalgebra", + "p2p", "prometheus 0.14.0", - "rand 0.8.5", "rand 0.9.1", "redis", "redis-test", @@ -9466,7 +9465,6 @@ dependencies = [ "tempfile", "tokio", "tokio-util", - "toml", "url", ] diff --git a/crates/orchestrator/Cargo.toml b/crates/orchestrator/Cargo.toml index 6ac53140..ce733ee6 100644 --- a/crates/orchestrator/Cargo.toml +++ b/crates/orchestrator/Cargo.toml @@ -7,35 +7,35 @@ edition.workspace = true workspace = true [dependencies] +p2p = { workspace = true} +shared = { workspace = true } + actix-web = { workspace = true } -actix-web-prometheus = "0.1.2" alloy = { workspace = true } anyhow = { workspace = true } -async-trait = "0.1.88" -base64 = "0.22.1" chrono = { workspace = true, features = ["serde"] } clap = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } -google-cloud-auth = "0.18.0" -google-cloud-storage = "0.24.0" hex = { workspace = true } log = { workspace = true } -prometheus = "0.14.0" -rand = 
"0.9.0" redis = { workspace = true, features = ["tokio-comp"] } redis-test = { workspace = true } reqwest = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } -shared = { workspace = true } tokio = { workspace = true } +tokio-util = { workspace = true } url = { workspace = true } +uuid = { workspace = true } + +actix-web-prometheus = "0.1.2" +google-cloud-auth = "0.18.0" +google-cloud-storage = "0.24.0" +prometheus = "0.14.0" +rand = "0.9.0" utoipa = { version = "5.3.0", features = ["actix_extras", "chrono", "uuid"] } utoipa-swagger-ui = { version = "9.0.2", features = ["actix-web", "debug-embed", "reqwest", "vendored"] } -uuid = { workspace = true } -iroh = { workspace = true } -rand_v8 = { workspace = true } [dev-dependencies] mockito = { workspace = true } diff --git a/crates/orchestrator/src/api/routes/groups.rs b/crates/orchestrator/src/api/routes/groups.rs index 44b22cd9..414f524a 100644 --- a/crates/orchestrator/src/api/routes/groups.rs +++ b/crates/orchestrator/src/api/routes/groups.rs @@ -236,9 +236,6 @@ async fn fetch_node_logs_p2p( match node { Some(node) => { - // Check if P2P client is available - let p2p_client = app_state.p2p_client.clone(); - // Check if node has P2P information let (worker_p2p_id, worker_p2p_addresses) = match (&node.worker_p2p_id, &node.worker_p2p_addresses) { @@ -254,11 +251,22 @@ async fn fetch_node_logs_p2p( }; // Send P2P request for task logs - match tokio::time::timeout( - Duration::from_secs(NODE_REQUEST_TIMEOUT), - p2p_client.get_task_logs(node_address, worker_p2p_id, worker_p2p_addresses), - ) - .await + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let get_task_logs_request = crate::p2p::GetTaskLogsRequest { + worker_wallet_address: node_address, + worker_p2p_id: worker_p2p_id.clone(), + worker_addresses: worker_p2p_addresses.clone(), + response_tx, + }; + if let Err(e) = app_state.get_task_logs_tx.send(get_task_logs_request).await { + error!("Failed to send 
GetTaskLogsRequest for node {node_address}: {e}"); + return json!({ + "success": false, + "error": format!("Failed to send request: {}", e), + "status": node.status.to_string() + }); + }; + match tokio::time::timeout(Duration::from_secs(NODE_REQUEST_TIMEOUT), response_rx).await { Ok(Ok(log_lines)) => { json!({ diff --git a/crates/orchestrator/src/api/routes/nodes.rs b/crates/orchestrator/src/api/routes/nodes.rs index a260706a..9debddde 100644 --- a/crates/orchestrator/src/api/routes/nodes.rs +++ b/crates/orchestrator/src/api/routes/nodes.rs @@ -164,11 +164,22 @@ async fn restart_node_task(node_id: web::Path, app_state: Data .as_ref() .expect("worker_p2p_addresses should be present"); - match app_state - .p2p_client - .restart_task(node_address, p2p_id, p2p_addresses) - .await - { + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let restart_task_request = crate::p2p::RestartTaskRequest { + worker_wallet_address: node.address, + worker_p2p_id: p2p_id.clone(), + worker_addresses: p2p_addresses.clone(), + response_tx, + }; + if let Err(e) = app_state.restart_task_tx.send(restart_task_request).await { + error!("Failed to send restart task request: {e}"); + return HttpResponse::InternalServerError().json(json!({ + "success": false, + "error": "Failed to send restart task request" + })); + } + + match response_rx.await { Ok(_) => HttpResponse::Ok().json(json!({ "success": true, "message": "Task restarted successfully" @@ -240,11 +251,22 @@ async fn get_node_logs(node_id: web::Path, app_state: Data) -> })); }; - match app_state - .p2p_client - .get_task_logs(node_address, p2p_id, p2p_addresses) - .await - { + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let get_task_logs_request = crate::p2p::GetTaskLogsRequest { + worker_wallet_address: node.address, + worker_p2p_id: p2p_id.clone(), + worker_addresses: p2p_addresses.clone(), + response_tx, + }; + if let Err(e) = app_state.get_task_logs_tx.send(get_task_logs_request).await { + 
error!("Failed to send get task logs request: {e}"); + return HttpResponse::InternalServerError().json(json!({ + "success": false, + "error": "Failed to send get task logs request" + })); + } + + match response_rx.await { Ok(logs) => HttpResponse::Ok().json(json!({ "success": true, "logs": logs diff --git a/crates/orchestrator/src/api/server.rs b/crates/orchestrator/src/api/server.rs index 095bcb6c..fc5943c9 100644 --- a/crates/orchestrator/src/api/server.rs +++ b/crates/orchestrator/src/api/server.rs @@ -5,7 +5,7 @@ use crate::api::routes::task::tasks_routes; use crate::api::routes::{heartbeat::heartbeat_routes, metrics::metrics_routes}; use crate::metrics::MetricsContext; use crate::models::node::NodeStatus; -use crate::p2p::client::P2PClient; +use crate::p2p::{GetTaskLogsRequest, RestartTaskRequest}; use crate::plugins::node_groups::NodeGroupsPlugin; use crate::scheduler::Scheduler; use crate::store::core::{RedisStore, StoreContext}; @@ -23,6 +23,7 @@ use shared::utils::StorageProvider; use shared::web3::contracts::core::builder::Contracts; use shared::web3::wallet::WalletProvider; use std::sync::Arc; +use tokio::sync::mpsc::Sender; use utoipa::{ openapi::security::{ApiKey, ApiKeyValue, SecurityScheme}, Modify, OpenApi, @@ -116,17 +117,18 @@ async fn health_check(data: web::Data) -> HttpResponse { } pub(crate) struct AppState { - pub store_context: Arc, - pub storage_provider: Option>, - pub heartbeats: Arc, - pub redis_store: Arc, - pub hourly_upload_limit: i64, - pub contracts: Option>, - pub pool_id: u32, - pub scheduler: Scheduler, - pub node_groups_plugin: Option>, - pub metrics: Arc, - pub p2p_client: Arc, + pub(crate) store_context: Arc, + pub(crate) storage_provider: Option>, + pub(crate) heartbeats: Arc, + pub(crate) redis_store: Arc, + pub(crate) hourly_upload_limit: i64, + pub(crate) contracts: Option>, + pub(crate) pool_id: u32, + pub(crate) scheduler: Scheduler, + pub(crate) node_groups_plugin: Option>, + pub(crate) metrics: Arc, + pub(crate) 
get_task_logs_tx: Sender, + pub(crate) restart_task_tx: Sender, } #[allow(clippy::too_many_arguments)] @@ -145,7 +147,8 @@ pub async fn start_server( scheduler: Scheduler, node_groups_plugin: Option>, metrics: Arc, - p2p_client: Arc, + get_task_logs_tx: Sender, + restart_task_tx: Sender, ) -> Result<(), Error> { info!("Starting server at http://{host}:{port}"); let app_state = Data::new(AppState { @@ -159,7 +162,8 @@ pub async fn start_server( scheduler, node_groups_plugin, metrics, - p2p_client, + get_task_logs_tx, + restart_task_tx, }); let node_store = app_state.store_context.node_store.clone(); let node_store_clone = node_store.clone(); diff --git a/crates/orchestrator/src/api/tests/helper.rs b/crates/orchestrator/src/api/tests/helper.rs index ca2e65c1..92b26cce 100644 --- a/crates/orchestrator/src/api/tests/helper.rs +++ b/crates/orchestrator/src/api/tests/helper.rs @@ -18,12 +18,12 @@ use std::sync::Arc; use url::Url; #[cfg(test)] -pub async fn create_test_app_state() -> Data { +pub(crate) async fn create_test_app_state() -> Data { use shared::utils::MockStorageProvider; use crate::{ - metrics::MetricsContext, p2p::client::P2PClient, scheduler::Scheduler, - utils::loop_heartbeats::LoopHeartbeats, ServerMode, + metrics::MetricsContext, scheduler::Scheduler, utils::loop_heartbeats::LoopHeartbeats, + ServerMode, }; let store = Arc::new(RedisStore::new_test()); @@ -46,12 +46,8 @@ pub async fn create_test_app_state() -> Data { let mock_storage = MockStorageProvider::new(); let storage_provider = Arc::new(mock_storage); let metrics = Arc::new(MetricsContext::new(1.to_string())); - let wallet = Wallet::new( - "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97", - Url::parse("http://localhost:8545").unwrap(), - ) - .unwrap(); - let p2p_client = Arc::new(P2PClient::new(wallet.clone()).await.unwrap()); + let (get_task_logs_tx, _) = tokio::sync::mpsc::channel(1); + let (restart_task_tx, _) = tokio::sync::mpsc::channel(1); Data::new(AppState { 
store_context: store_context.clone(), @@ -64,17 +60,17 @@ pub async fn create_test_app_state() -> Data { scheduler, node_groups_plugin: None, metrics, - p2p_client: p2p_client.clone(), + get_task_logs_tx, + restart_task_tx, }) } #[cfg(test)] -pub async fn create_test_app_state_with_nodegroups() -> Data { +pub(crate) async fn create_test_app_state_with_nodegroups() -> Data { use shared::utils::MockStorageProvider; use crate::{ metrics::MetricsContext, - p2p::client::P2PClient, plugins::node_groups::{NodeGroupConfiguration, NodeGroupsPlugin}, scheduler::Scheduler, utils::loop_heartbeats::LoopHeartbeats, }; @@ -116,12 +112,8 @@ pub async fn create_test_app_state_with_nodegroups() -> Data { let mock_storage = MockStorageProvider::new(); let storage_provider = Arc::new(mock_storage); let metrics = Arc::new(MetricsContext::new(1.to_string())); - let wallet = Wallet::new( - "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97", - Url::parse("http://localhost:8545").unwrap(), - ) - .unwrap(); - let p2p_client = Arc::new(P2PClient::new(wallet.clone()).await.unwrap()); + let (get_task_logs_tx, _) = tokio::sync::mpsc::channel(1); + let (restart_task_tx, _) = tokio::sync::mpsc::channel(1); Data::new(AppState { store_context: store_context.clone(), @@ -134,12 +126,13 @@ pub async fn create_test_app_state_with_nodegroups() -> Data { scheduler, node_groups_plugin, metrics, - p2p_client: p2p_client.clone(), + get_task_logs_tx, + restart_task_tx, }) } #[cfg(test)] -pub fn setup_contract() -> Contracts { +pub(crate) fn setup_contract() -> Contracts { let coordinator_key = "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97"; let rpc_url: Url = Url::parse("http://localhost:8545").unwrap(); let wallet = Wallet::new(coordinator_key, rpc_url).unwrap(); @@ -154,12 +147,12 @@ pub fn setup_contract() -> Contracts { } #[cfg(test)] -pub async fn create_test_app_state_with_metrics() -> Data { +pub(crate) async fn create_test_app_state_with_metrics() -> Data { use 
shared::utils::MockStorageProvider; use crate::{ - metrics::MetricsContext, p2p::client::P2PClient, scheduler::Scheduler, - utils::loop_heartbeats::LoopHeartbeats, ServerMode, + metrics::MetricsContext, scheduler::Scheduler, utils::loop_heartbeats::LoopHeartbeats, + ServerMode, }; let store = Arc::new(RedisStore::new_test()); @@ -182,12 +175,8 @@ pub async fn create_test_app_state_with_metrics() -> Data { let mock_storage = MockStorageProvider::new(); let storage_provider = Arc::new(mock_storage); let metrics = Arc::new(MetricsContext::new("0".to_string())); - let wallet = Wallet::new( - "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97", - Url::parse("http://localhost:8545").unwrap(), - ) - .unwrap(); - let p2p_client = Arc::new(P2PClient::new(wallet.clone()).await.unwrap()); + let (get_task_logs_tx, _) = tokio::sync::mpsc::channel(1); + let (restart_task_tx, _) = tokio::sync::mpsc::channel(1); Data::new(AppState { store_context: store_context.clone(), @@ -200,6 +189,7 @@ pub async fn create_test_app_state_with_metrics() -> Data { scheduler, node_groups_plugin: None, metrics, - p2p_client: p2p_client.clone(), + get_task_logs_tx, + restart_task_tx, }) } diff --git a/crates/orchestrator/src/discovery/monitor.rs b/crates/orchestrator/src/discovery/monitor.rs index 56fed833..d1ea3133 100644 --- a/crates/orchestrator/src/discovery/monitor.rs +++ b/crates/orchestrator/src/discovery/monitor.rs @@ -384,15 +384,12 @@ impl DiscoveryMonitor { if let Some(balance) = discovery_node.latest_balance { if balance == U256::ZERO { - info!( - "Node {} has zero balance, marking as low balance", - node_address - ); + info!("Node {node_address} has zero balance, marking as low balance"); if let Err(e) = self .update_node_status(&node_address, NodeStatus::LowBalance) .await { - error!("Error updating node status: {}", e); + error!("Error updating node status: {e}"); } } } diff --git a/crates/orchestrator/src/lib.rs b/crates/orchestrator/src/lib.rs index 
5f82d58d..19d13eba 100644 --- a/crates/orchestrator/src/lib.rs +++ b/crates/orchestrator/src/lib.rs @@ -16,7 +16,7 @@ pub use metrics::sync_service::MetricsSyncService; pub use metrics::webhook_sender::MetricsWebhookSender; pub use metrics::MetricsContext; pub use node::invite::NodeInviter; -pub use p2p::client::P2PClient; +pub use p2p::Service as P2PService; pub use plugins::node_groups::NodeGroupConfiguration; pub use plugins::node_groups::NodeGroupsPlugin; pub use plugins::webhook::WebhookConfig; diff --git a/crates/orchestrator/src/main.rs b/crates/orchestrator/src/main.rs index f9beaccb..5f8e2af2 100644 --- a/crates/orchestrator/src/main.rs +++ b/crates/orchestrator/src/main.rs @@ -9,12 +9,13 @@ use shared::web3::contracts::core::builder::ContractBuilder; use shared::web3::wallet::Wallet; use std::sync::Arc; use tokio::task::JoinSet; +use tokio_util::sync::CancellationToken; use url::Url; use orchestrator::{ start_server, DiscoveryMonitor, LoopHeartbeats, MetricsContext, MetricsSyncService, MetricsWebhookSender, NodeGroupConfiguration, NodeGroupsPlugin, NodeInviter, NodeStatusUpdater, - P2PClient, RedisStore, Scheduler, SchedulerPlugin, ServerMode, StatusUpdatePlugin, + P2PService, RedisStore, Scheduler, SchedulerPlugin, ServerMode, StatusUpdatePlugin, StoreContext, WebhookConfig, WebhookPlugin, }; @@ -91,6 +92,10 @@ struct Args { /// Max healthy nodes with same endpoint #[arg(long, default_value = "1")] max_healthy_nodes_with_same_endpoint: u32, + + /// Libp2p port + #[arg(long, default_value = "4004")] + libp2p_port: u16, } #[tokio::main] @@ -143,7 +148,27 @@ async fn main() -> Result<()> { let store = Arc::new(RedisStore::new(&args.redis_store_url)); let store_context = Arc::new(StoreContext::new(store.clone())); - let p2p_client = Arc::new(P2PClient::new(wallet.clone()).await.unwrap()); + let keypair = p2p::Keypair::generate_ed25519(); + let cancellation_token = CancellationToken::new(); + let (p2p_service, invite_tx, get_task_logs_tx, restart_task_tx) = { 
+ match P2PService::new( + keypair, + args.libp2p_port, + cancellation_token.clone(), + wallet.clone(), + ) { + Ok(res) => { + info!("p2p service initialized successfully"); + res + } + Err(e) => { + error!("failed to initialize p2p service: {e}"); + std::process::exit(1); + } + } + }; + + tokio::task::spawn(p2p_service.run()); let contracts = ContractBuilder::new(wallet.provider()) .with_compute_registry() @@ -297,24 +322,29 @@ async fn main() -> Result<()> { let inviter_store_context = store_context.clone(); let inviter_heartbeats = heartbeats.clone(); - tasks.spawn({ - let wallet = wallet.clone(); - let p2p_client = p2p_client.clone(); - async move { - let inviter = NodeInviter::new( - wallet, - compute_pool_id, - domain_id, - args.host.as_deref(), - Some(&args.port), - args.url.as_deref(), - inviter_store_context.clone(), - inviter_heartbeats.clone(), - p2p_client, - ); - inviter.run().await + let wallet = wallet.clone(); + let inviter = match NodeInviter::new( + wallet, + compute_pool_id, + domain_id, + args.host.as_deref(), + Some(&args.port), + args.url.as_deref(), + inviter_store_context.clone(), + inviter_heartbeats.clone(), + invite_tx, + ) { + Ok(inviter) => { + info!("Node inviter initialized successfully"); + inviter } - }); + Err(e) => { + error!("Failed to initialize node inviter: {e}"); + std::process::exit(1); + } + }; + + tasks.spawn(async move { inviter.run().await }); // Create status_update_plugins for status updater let mut status_updater_plugins: Vec = vec![]; @@ -387,7 +417,8 @@ async fn main() -> Result<()> { scheduler, node_groups_plugin, metrics_context, - p2p_client, + get_task_logs_tx, + restart_task_tx, ) => { if let Err(e) = res { error!("Server error: {e}"); @@ -403,6 +434,8 @@ async fn main() -> Result<()> { } } + // TODO: use cancellation token to gracefully shutdown tasks + cancellation_token.cancel(); tasks.shutdown().await; Ok(()) } diff --git a/crates/orchestrator/src/node/invite.rs b/crates/orchestrator/src/node/invite.rs 
index 17ae4207..8391d047 100644 --- a/crates/orchestrator/src/node/invite.rs +++ b/crates/orchestrator/src/node/invite.rs @@ -1,40 +1,40 @@ use crate::models::node::NodeStatus; use crate::models::node::OrchestratorNode; -use crate::p2p::client::P2PClient; +use crate::p2p::InviteRequest as InviteRequestWithMetadata; use crate::store::core::StoreContext; use crate::utils::loop_heartbeats::LoopHeartbeats; use alloy::primitives::utils::keccak256 as keccak; use alloy::primitives::U256; use alloy::signers::Signer; -use anyhow::Result; +use anyhow::{bail, Result}; use futures::stream; use futures::StreamExt; use log::{debug, error, info, warn}; -use shared::models::invite::InviteRequest; +use p2p::InviteRequest; +use p2p::InviteRequestUrl; use shared::web3::wallet::Wallet; use std::sync::Arc; use std::time::SystemTime; use std::time::UNIX_EPOCH; +use tokio::sync::mpsc::Sender; use tokio::time::{interval, Duration}; // Timeout constants const DEFAULT_INVITE_CONCURRENT_COUNT: usize = 32; // Max concurrent count of nodes being invited -pub struct NodeInviter<'a> { +pub struct NodeInviter { wallet: Wallet, pool_id: u32, domain_id: u32, - host: Option<&'a str>, - port: Option<&'a u16>, - url: Option<&'a str>, + url: InviteRequestUrl, store_context: Arc, heartbeats: Arc, - p2p_client: Arc, + invite_tx: Sender, } -impl<'a> NodeInviter<'a> { +impl NodeInviter { #[allow(clippy::too_many_arguments)] - pub fn new( + pub fn new<'a>( wallet: Wallet, pool_id: u32, domain_id: u32, @@ -43,19 +43,31 @@ impl<'a> NodeInviter<'a> { url: Option<&'a str>, store_context: Arc, heartbeats: Arc, - p2p_client: Arc, - ) -> Self { - Self { + invite_tx: Sender, + ) -> Result { + let url = if let Some(url) = url { + InviteRequestUrl::MasterUrl(url.to_string()) + } else { + let Some(host) = host else { + bail!("either host or url must be provided"); + }; + + let Some(port) = port else { + bail!("either port or url must be provided"); + }; + + InviteRequestUrl::MasterIpPort(host.to_string(), *port) + }; 
+ + Ok(Self { wallet, pool_id, domain_id, - host, - port, url, store_context, heartbeats, - p2p_client, - } + invite_tx, + }) } pub async fn run(&self) -> Result<()> { @@ -71,7 +83,7 @@ impl<'a> NodeInviter<'a> { } } - async fn _generate_invite( + async fn generate_invite( &self, node: &OrchestratorNode, nonce: [u8; 32], @@ -102,7 +114,7 @@ impl<'a> NodeInviter<'a> { Ok(signature) } - async fn _send_invite(&self, node: &OrchestratorNode) -> Result<(), anyhow::Error> { + async fn send_invite(&self, node: &OrchestratorNode) -> Result<(), anyhow::Error> { if node.worker_p2p_id.is_none() || node.worker_p2p_addresses.is_none() { return Err(anyhow::anyhow!("Node does not have p2p information")); } @@ -120,21 +132,11 @@ impl<'a> NodeInviter<'a> { ) .to_be_bytes(); - let invite_signature = self._generate_invite(node, nonce, expiration).await?; + let invite_signature = self.generate_invite(node, nonce, expiration).await?; let payload = InviteRequest { invite: hex::encode(invite_signature), pool_id: self.pool_id, - master_url: self.url.map(|u| u.to_string()), - master_ip: if self.url.is_none() { - self.host.map(|h| h.to_string()) - } else { - None - }, - master_port: if self.url.is_none() { - self.port.copied() - } else { - None - }, + url: self.url.clone(), timestamp: SystemTime::now() .duration_since(UNIX_EPOCH) .map_err(|e| anyhow::anyhow!("System time error: {}", e))? 
@@ -145,11 +147,19 @@ impl<'a> NodeInviter<'a> { info!("Sending invite to node: {p2p_id}"); - match self - .p2p_client - .invite_worker(node.address, p2p_id, p2p_addresses, payload) + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let invite = InviteRequestWithMetadata { + worker_wallet_address: node.address, + worker_p2p_id: p2p_id.clone(), + worker_addresses: p2p_addresses.clone(), + invite: payload, + response_tx, + }; + self.invite_tx + .send(invite) .await - { + .map_err(|_| anyhow::anyhow!("failed to send invite request"))?; + match response_rx.await { Ok(_) => { info!("Successfully invited node"); if let Err(e) = self @@ -182,7 +192,7 @@ impl<'a> NodeInviter<'a> { let invited_nodes = stream::iter(nodes.into_iter().map(|node| async move { info!("Processing node {:?}", node.address); - match self._send_invite(&node).await { + match self.send_invite(&node).await { Ok(_) => { info!("Successfully processed node {:?}", node.address); Ok(()) diff --git a/crates/orchestrator/src/p2p/client.rs b/crates/orchestrator/src/p2p/client.rs deleted file mode 100644 index 39810151..00000000 --- a/crates/orchestrator/src/p2p/client.rs +++ /dev/null @@ -1,102 +0,0 @@ -use alloy::primitives::Address; -use anyhow::Result; -use log::{info, warn}; -use shared::models::invite::InviteRequest; -use shared::p2p::{client::P2PClient as SharedP2PClient, messages::P2PMessage}; -use shared::web3::wallet::Wallet; - -pub struct P2PClient { - shared_client: SharedP2PClient, -} - -impl P2PClient { - pub async fn new(wallet: Wallet) -> Result { - let shared_client = SharedP2PClient::new(wallet).await?; - Ok(Self { shared_client }) - } - - pub async fn invite_worker( - &self, - worker_wallet_address: Address, - worker_p2p_id: &str, - worker_addresses: &[String], - invite: InviteRequest, - ) -> Result<()> { - let response = self - .shared_client - .send_request( - worker_p2p_id, - worker_addresses, - worker_wallet_address, - P2PMessage::Invite(invite), - 20, - ) - .await?; - - 
match response { - P2PMessage::InviteResponse { status, error } => { - if status == "ok" { - info!("Successfully invited worker {worker_p2p_id}"); - Ok(()) - } else { - let error_msg = error.unwrap_or_else(|| "Unknown error".to_string()); - warn!("Failed to invite worker {worker_p2p_id}: {error_msg}"); - Err(anyhow::anyhow!("Invite failed: {}", error_msg)) - } - } - _ => Err(anyhow::anyhow!("Unexpected response type for invite")), - } - } - - pub async fn get_task_logs( - &self, - worker_wallet_address: Address, - worker_p2p_id: &str, - worker_addresses: &[String], - ) -> Result> { - let response = self - .shared_client - .send_request( - worker_p2p_id, - worker_addresses, - worker_wallet_address, - P2PMessage::GetTaskLogs, - 20, - ) - .await?; - - match response { - P2PMessage::GetTaskLogsResponse { logs } => { - logs.map_err(|e| anyhow::anyhow!("Failed to get task logs: {}", e)) - } - _ => Err(anyhow::anyhow!( - "Unexpected response type for get_task_logs" - )), - } - } - - pub async fn restart_task( - &self, - worker_wallet_address: Address, - worker_p2p_id: &str, - worker_addresses: &[String], - ) -> Result<()> { - let response = self - .shared_client - .send_request( - worker_p2p_id, - worker_addresses, - worker_wallet_address, - P2PMessage::RestartTask, - 25, - ) - .await?; - - match response { - P2PMessage::RestartTaskResponse { result } => { - result.map_err(|e| anyhow::anyhow!("Failed to restart task: {}", e)) - } - _ => Err(anyhow::anyhow!("Unexpected response type for restart_task")), - } - } -} diff --git a/crates/orchestrator/src/p2p/mod.rs b/crates/orchestrator/src/p2p/mod.rs index 1d331315..f3bf57cf 100644 --- a/crates/orchestrator/src/p2p/mod.rs +++ b/crates/orchestrator/src/p2p/mod.rs @@ -1 +1,171 @@ -pub(crate) mod client; +use anyhow::{bail, Context as _, Result}; +use futures::stream::FuturesUnordered; +use futures::FutureExt; +use p2p::{Keypair, Protocols}; +use shared::p2p::OutgoingRequest; +use shared::p2p::Service as P2PService; +use 
shared::web3::wallet::Wallet; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio_util::sync::CancellationToken; + +pub struct Service { + inner: P2PService, + outgoing_message_tx: Sender, + invite_rx: Receiver, + get_task_logs_rx: Receiver, + restart_task_rx: Receiver, +} + +impl Service { + #[allow(clippy::type_complexity)] + pub fn new( + keypair: Keypair, + port: u16, + cancellation_token: CancellationToken, + wallet: Wallet, + ) -> Result<( + Self, + Sender, + Sender, + Sender, + )> { + let (invite_tx, invite_rx) = tokio::sync::mpsc::channel(100); + let (get_task_logs_tx, get_task_logs_rx) = tokio::sync::mpsc::channel(100); + let (restart_task_tx, restart_task_rx) = tokio::sync::mpsc::channel(100); + let (inner, outgoing_message_tx) = P2PService::new( + keypair, + port, + cancellation_token.clone(), + wallet, + Protocols::new() + .with_invite() + .with_get_task_logs() + .with_restart() + .with_validator_authentication(), + ) + .context("failed to create p2p service")?; + Ok(( + Self { + inner, + outgoing_message_tx, + invite_rx, + get_task_logs_rx, + restart_task_rx, + }, + invite_tx, + get_task_logs_tx, + restart_task_tx, + )) + } + + pub async fn run(self) -> Result<()> { + use futures::StreamExt as _; + + let Self { + inner, + outgoing_message_tx, + mut invite_rx, + mut get_task_logs_rx, + mut restart_task_rx, + } = self; + + tokio::task::spawn(inner.run()); + + let mut futures = FuturesUnordered::new(); + + loop { + tokio::select! { + Some(request) = invite_rx.recv() => { + let (incoming_resp_tx, incoming_resp_rx) = tokio::sync::oneshot::channel(); + let fut = async move { + let p2p::Response::Invite(resp) = incoming_resp_rx.await.context("outgoing request tx channel was dropped")? 
else { + bail!("unexpected response type for invite request"); + }; + request.response_tx.send(resp).map_err(|_|anyhow::anyhow!("caller dropped response channel"))?; + Ok(()) + }.boxed(); + futures.push(fut); + + let outgoing_request = OutgoingRequest { + peer_wallet_address: request.worker_wallet_address, + peer_id: request.worker_p2p_id, + multiaddrs: request.worker_addresses, + request: request.invite.into(), + response_tx: incoming_resp_tx, + }; + outgoing_message_tx.send(outgoing_request).await + .context("failed to send outgoing invite request")?; + } + Some(request) = get_task_logs_rx.recv() => { + let (incoming_resp_tx, incoming_resp_rx) = tokio::sync::oneshot::channel(); + let fut = async move { + let p2p::Response::GetTaskLogs(resp) = incoming_resp_rx.await.context("outgoing request tx channel was dropped")? else { + bail!("unexpected response type for get task logs request"); + }; + request.response_tx.send(resp).map_err(|_|anyhow::anyhow!("caller dropped response channel"))?; + Ok(()) + }.boxed(); + futures.push(fut); + + let outgoing_request = OutgoingRequest { + peer_wallet_address: request.worker_wallet_address, + peer_id: request.worker_p2p_id, + multiaddrs: request.worker_addresses, + request: p2p::Request::GetTaskLogs, + response_tx: incoming_resp_tx, + }; + outgoing_message_tx.send(outgoing_request).await + .context("failed to send outgoing get task logs request")?; + } + Some(request) = restart_task_rx.recv() => { + let (incoming_resp_tx, incoming_resp_rx) = tokio::sync::oneshot::channel(); + let fut = async move { + let p2p::Response::Restart(resp) = incoming_resp_rx.await.context("outgoing request tx channel was dropped")? 
else { + bail!("unexpected response type for restart task request"); + }; + request.response_tx.send(resp).map_err(|_|anyhow::anyhow!("caller dropped response channel"))?; + Ok(()) + }.boxed(); + futures.push(fut); + + let outgoing_request = OutgoingRequest { + peer_wallet_address: request.worker_wallet_address, + peer_id: request.worker_p2p_id, + multiaddrs: request.worker_addresses, + request: p2p::Request::Restart, + response_tx: incoming_resp_tx, + }; + outgoing_message_tx.send(outgoing_request).await + .context("failed to send outgoing restart task request")?; + } + Some(res) = futures.next() => { + if let Err(e) = res { + log::error!("failed to handle response conversion: {e}"); + } + } + } + } + } +} + +pub struct InviteRequest { + pub(crate) worker_wallet_address: alloy::primitives::Address, + pub(crate) worker_p2p_id: String, + pub(crate) worker_addresses: Vec, + pub(crate) invite: p2p::InviteRequest, + pub(crate) response_tx: tokio::sync::oneshot::Sender, +} + +pub struct GetTaskLogsRequest { + pub(crate) worker_wallet_address: alloy::primitives::Address, + pub(crate) worker_p2p_id: String, + pub(crate) worker_addresses: Vec, + pub(crate) response_tx: tokio::sync::oneshot::Sender, +} + +pub struct RestartTaskRequest { + pub(crate) worker_wallet_address: alloy::primitives::Address, + pub(crate) worker_p2p_id: String, + pub(crate) worker_addresses: Vec, + pub(crate) response_tx: tokio::sync::oneshot::Sender, +} diff --git a/crates/p2p/src/lib.rs b/crates/p2p/src/lib.rs index 0a5637a9..0ad1e4a5 100644 --- a/crates/p2p/src/lib.rs +++ b/crates/p2p/src/lib.rs @@ -15,9 +15,9 @@ mod message; mod protocol; use behaviour::Behaviour; -use protocol::Protocols; pub use message::*; +pub use protocol::*; pub type Libp2pIncomingMessage = libp2p::request_response::Message; pub type ResponseChannel = libp2p::request_response::ResponseChannel; @@ -120,7 +120,8 @@ impl Node { } Some(message) = outgoing_message_rx.recv() => { match message { - OutgoingMessage::Request((peer, 
request)) => { + OutgoingMessage::Request((peer, _addrs, request)) => { + // TODO: if we're not connected to the peer, we should dial it swarm.behaviour_mut().request_response().send_request(&peer, request); } OutgoingMessage::Response((channel, response)) => { @@ -238,6 +239,11 @@ impl NodeBuilder { self } + pub fn with_protocols(mut self, protocols: Protocols) -> Self { + self.protocols.join(protocols); + self + } + pub fn with_bootnode(mut self, bootnode: Multiaddr) -> Self { self.bootnodes.push(bootnode); self @@ -370,7 +376,7 @@ mod test { // send request from node1->node2 let request = message::Request::GetTaskLogs; outgoing_message_tx1 - .send(request.into_outgoing_message(node2_peer_id)) + .send(request.into_outgoing_message(node2_peer_id, vec![])) .await .unwrap(); let message = incoming_message_rx2.recv().await.unwrap(); diff --git a/crates/p2p/src/message/mod.rs b/crates/p2p/src/message/mod.rs index adff99ac..dc2403e3 100644 --- a/crates/p2p/src/message/mod.rs +++ b/crates/p2p/src/message/mod.rs @@ -1,3 +1,4 @@ +use crate::Protocol; use libp2p::PeerId; use serde::{Deserialize, Serialize}; use std::time::SystemTime; @@ -15,7 +16,7 @@ pub struct IncomingMessage { #[allow(clippy::large_enum_variant)] #[derive(Debug)] pub enum OutgoingMessage { - Request((PeerId, Request)), + Request((PeerId, Vec, Request)), Response( ( libp2p::request_response::ResponseChannel, @@ -35,8 +36,23 @@ pub enum Request { } impl Request { - pub fn into_outgoing_message(self, peer: PeerId) -> OutgoingMessage { - OutgoingMessage::Request((peer, self)) + pub fn into_outgoing_message( + self, + peer: PeerId, + multiaddrs: Vec, + ) -> OutgoingMessage { + OutgoingMessage::Request((peer, multiaddrs, self)) + } + + pub fn protocol(&self) -> Protocol { + match self { + Request::ValidatorAuthentication(_) => Protocol::ValidatorAuthentication, + Request::HardwareChallenge(_) => Protocol::HardwareChallenge, + Request::Invite(_) => Protocol::Invite, + Request::GetTaskLogs => 
Protocol::GetTaskLogs, + Request::Restart => Protocol::Restart, + Request::General(_) => Protocol::General, + } } } diff --git a/crates/p2p/src/protocol.rs b/crates/p2p/src/protocol.rs index df423ef8..ae839cec 100644 --- a/crates/p2p/src/protocol.rs +++ b/crates/p2p/src/protocol.rs @@ -2,7 +2,7 @@ use libp2p::StreamProtocol; use std::{collections::HashSet, hash::Hash}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub(crate) enum Protocol { +pub enum Protocol { // validator -> worker ValidatorAuthentication, // validator -> worker @@ -33,42 +33,70 @@ impl Protocol { } #[derive(Debug, Clone)] -pub(crate) struct Protocols(HashSet); +pub struct Protocols(HashSet); impl Protocols { - pub(crate) fn new() -> Self { + pub fn new() -> Self { Self(HashSet::new()) } - pub(crate) fn with_validator_authentication(mut self) -> Self { + pub fn has_validator_authentication(&self) -> bool { + self.0.contains(&Protocol::ValidatorAuthentication) + } + + pub fn has_hardware_challenge(&self) -> bool { + self.0.contains(&Protocol::HardwareChallenge) + } + + pub fn has_invite(&self) -> bool { + self.0.contains(&Protocol::Invite) + } + + pub fn has_get_task_logs(&self) -> bool { + self.0.contains(&Protocol::GetTaskLogs) + } + + pub fn has_restart(&self) -> bool { + self.0.contains(&Protocol::Restart) + } + + pub fn has_general(&self) -> bool { + self.0.contains(&Protocol::General) + } + + pub fn with_validator_authentication(mut self) -> Self { self.0.insert(Protocol::ValidatorAuthentication); self } - pub(crate) fn with_hardware_challenge(mut self) -> Self { + pub fn with_hardware_challenge(mut self) -> Self { self.0.insert(Protocol::HardwareChallenge); self } - pub(crate) fn with_invite(mut self) -> Self { + pub fn with_invite(mut self) -> Self { self.0.insert(Protocol::Invite); self } - pub(crate) fn with_get_task_logs(mut self) -> Self { + pub fn with_get_task_logs(mut self) -> Self { self.0.insert(Protocol::GetTaskLogs); self } - pub(crate) fn with_restart(mut self) -> Self { + pub 
fn with_restart(mut self) -> Self { self.0.insert(Protocol::Restart); self } - pub(crate) fn with_general(mut self) -> Self { + pub fn with_general(mut self) -> Self { self.0.insert(Protocol::General); self } + + pub(crate) fn join(&mut self, other: Protocols) { + self.0.extend(other.0); + } } impl IntoIterator for Protocols { diff --git a/crates/shared/Cargo.toml b/crates/shared/Cargo.toml index 9afdafff..4d3a8760 100644 --- a/crates/shared/Cargo.toml +++ b/crates/shared/Cargo.toml @@ -15,6 +15,8 @@ default = [] testnet = [] [dependencies] +p2p = { workspace = true} + tokio = { workspace = true } alloy = { workspace = true } alloy-provider = { workspace = true } @@ -40,3 +42,5 @@ iroh = { workspace = true } rand_v8 = { workspace = true } subtle = "2.6.1" utoipa = { version = "5.3.0", features = ["actix_extras", "chrono", "uuid"] } +futures = { workspace = true } +tokio-util = { workspace = true } diff --git a/crates/shared/src/p2p/mod.rs b/crates/shared/src/p2p/mod.rs index f505f3b1..cac69a8a 100644 --- a/crates/shared/src/p2p/mod.rs +++ b/crates/shared/src/p2p/mod.rs @@ -1,6 +1,9 @@ pub mod client; pub mod messages; pub mod protocol; +mod service; pub use client::P2PClient; pub use protocol::*; + +pub use service::*; diff --git a/crates/shared/src/p2p/service.rs b/crates/shared/src/p2p/service.rs new file mode 100644 index 00000000..f5a7bbe3 --- /dev/null +++ b/crates/shared/src/p2p/service.rs @@ -0,0 +1,453 @@ +use crate::web3::wallet::Wallet; +use anyhow::{bail, Context as _, Result}; +use futures::stream::FuturesUnordered; +use p2p::{ + IncomingMessage, Libp2pIncomingMessage, Node, NodeBuilder, OutgoingMessage, PeerId, Protocol, + Protocols, Response, ValidatorAuthenticationInitiationRequest, ValidatorAuthenticationResponse, + ValidatorAuthenticationSolutionRequest, +}; +use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Arc; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::RwLock; +use 
tokio_util::sync::CancellationToken; + +pub struct OutgoingRequest { + pub peer_wallet_address: alloy::primitives::Address, + pub request: p2p::Request, + pub peer_id: String, + pub multiaddrs: Vec, + pub response_tx: tokio::sync::oneshot::Sender, +} + +/// A p2p service implementation that is used by the validator and the orchestrator. +/// It handles the authentication protocol used before sending +/// requests to the worker. +pub struct Service { + _node: Node, + dial_tx: p2p::DialSender, + incoming_messages_rx: Receiver, + outgoing_messages_rx: Receiver, + cancellation_token: CancellationToken, + context: Context, +} + +impl Service { + pub fn new( + keypair: p2p::Keypair, + port: u16, + cancellation_token: CancellationToken, + wallet: Wallet, + protocols: Protocols, + ) -> Result<(Self, Sender)> { + let (node, dial_tx, incoming_messages_rx, outgoing_messages) = + build_p2p_node(keypair, port, cancellation_token.clone(), protocols.clone()) + .context("failed to build p2p node")?; + let (outgoing_messages_tx, outgoing_messages_rx) = tokio::sync::mpsc::channel(100); + + Ok(( + Self { + _node: node, + dial_tx, + incoming_messages_rx, + outgoing_messages_rx, + cancellation_token, + context: Context::new(outgoing_messages, wallet, protocols), + }, + outgoing_messages_tx, + )) + } + + pub async fn run(self) { + use futures::StreamExt as _; + + let Self { + _node, + dial_tx, + mut incoming_messages_rx, + mut outgoing_messages_rx, + cancellation_token, + context, + } = self; + + let mut message_handlers = FuturesUnordered::new(); + + loop { + tokio::select! 
{ + _ = cancellation_token.cancelled() => { + break; + } + Some(message) = outgoing_messages_rx.recv() => { + if let Err(e) = handle_outgoing_message(message, dial_tx.clone(), context.clone()) + .await { + log::error!("failed to handle outgoing message: {e}"); + } + } + Some(message) = incoming_messages_rx.recv() => { + let context = context.clone(); + let handle = tokio::task::spawn( + handle_incoming_message(message, context) + ); + message_handlers.push(handle); + } + Some(res) = message_handlers.next() => { + if let Err(e) = res { + log::error!("failed to handle incoming message: {e}"); + } + } + } + } + } +} + +fn build_p2p_node( + keypair: p2p::Keypair, + port: u16, + cancellation_token: CancellationToken, + protocols: Protocols, +) -> Result<( + Node, + p2p::DialSender, + Receiver, + Sender, +)> { + NodeBuilder::new() + .with_keypair(keypair) + .with_port(port) + .with_validator_authentication() + .with_protocols(protocols) + .with_cancellation_token(cancellation_token) + .try_build() +} + +#[derive(Clone)] +struct Context { + // outbound message channel; receiver is held by libp2p node + outgoing_messages: Sender, + + // ongoing authentication requests + ongoing_auth_requests: Arc>>, + is_authenticated_with_peer: Arc>>, + + // this assumes that there is only one outbound request per protocol per peer at a time, + // is this a correct assumption? 
+ // response channel is for sending the response back to the caller who initiated the request + ongoing_outbound_requests: + Arc>>>, + + wallet: Wallet, + protocols: Protocols, +} + +#[derive(Debug)] +struct OngoingAuthChallenge { + peer_wallet_address: alloy::primitives::Address, + auth_challenge_request_message: String, + outgoing_message: p2p::Request, + response_tx: tokio::sync::oneshot::Sender, +} + +impl Context { + fn new( + outgoing_messages: Sender, + wallet: Wallet, + protocols: Protocols, + ) -> Self { + Self { + outgoing_messages, + ongoing_auth_requests: Arc::new(RwLock::new(HashMap::new())), + is_authenticated_with_peer: Arc::new(RwLock::new(HashSet::new())), + ongoing_outbound_requests: Arc::new(RwLock::new(HashMap::new())), + wallet, + protocols, + } + } +} + +async fn handle_outgoing_message( + message: OutgoingRequest, + dial_tx: p2p::DialSender, + context: Context, +) -> Result<()> { + use rand_v8::rngs::OsRng; + use rand_v8::Rng as _; + use std::str::FromStr as _; + + let OutgoingRequest { + peer_wallet_address, + request, + peer_id, + multiaddrs, + response_tx, + } = message; + + let peer_id = PeerId::from_str(&peer_id).context("failed to parse peer id")?; + + // check if we're authenticated already + let is_authenticated_with_peer = context.is_authenticated_with_peer.read().await; + if is_authenticated_with_peer.contains(&peer_id) { + log::debug!( + "already authenticated with peer {peer_id}, skipping validation authentication" + ); + // multiaddresses are already known, as we've connected to them previously + context + .outgoing_messages + .send(request.into_outgoing_message(peer_id, vec![])) + .await + .context("failed to send outgoing message")?; + return Ok(()); + } + + log::debug!("sending validation authentication request to {peer_id}"); + + // first, dial the worker + // ensure there's no ongoing challenge + // use write-lock to make this atomic until we finish sending the auth request and writing to the map + let mut 
ongoing_auth_requests = context.ongoing_auth_requests.write().await; + if ongoing_auth_requests.contains_key(&peer_id) { + bail!("ongoing auth request for {} already exists", peer_id); + } + + let multiaddrs = multiaddrs + .iter() + .filter_map(|addr| p2p::Multiaddr::from_str(addr).ok()?.with_p2p(peer_id).ok()) + .collect::>(); + if multiaddrs.is_empty() { + bail!("no valid multiaddrs for peer id {peer_id}"); + } + + // TODO: we can improve this by checking if we're already connected to the peer before dialing + let (res_tx, res_rx) = tokio::sync::oneshot::channel(); + dial_tx + .send((multiaddrs.clone(), res_tx)) + .await + .context("failed to send dial request")?; + res_rx + .await + .context("failed to receive dial response")? + .context("failed to dial worker")?; + + // create the authentication challenge request message + let challenge_bytes: [u8; 32] = OsRng.gen(); + let auth_challenge_message: String = hex::encode(challenge_bytes); + + let req: p2p::Request = ValidatorAuthenticationInitiationRequest { + message: auth_challenge_message.clone(), + } + .into(); + let outgoing_message = req.into_outgoing_message(peer_id, multiaddrs); + log::debug!("sending ValidatorAuthenticationInitiationRequest to {peer_id}"); + context + .outgoing_messages + .send(outgoing_message) + .await + .context("failed to send outgoing message")?; + + // store the ongoing auth challenge + let ongoing_challenge = OngoingAuthChallenge { + peer_wallet_address, + auth_challenge_request_message: auth_challenge_message.clone(), + outgoing_message: request, + response_tx, + }; + + ongoing_auth_requests.insert(peer_id, ongoing_challenge); + Ok(()) +} + +async fn handle_incoming_message(message: IncomingMessage, context: Context) -> Result<()> { + match message.message { + Libp2pIncomingMessage::Request { + request_id: _, + request, + channel: _, + } => { + log::error!( + "node should not receive incoming requests: {request:?} from {}", + message.peer + ); + } + Libp2pIncomingMessage::Response 
{ + request_id: _, + response, + } => { + log::debug!("received incoming response {response:?}"); + handle_incoming_response(message.peer, response, context) + .await + .context("failed to handle incoming response")?; + } + } + Ok(()) +} + +async fn handle_incoming_response( + from: PeerId, + response: p2p::Response, + context: Context, +) -> Result<()> { + match response { + p2p::Response::ValidatorAuthentication(resp) => { + log::debug!("received ValidatorAuthenticationSolutionResponse from {from}: {resp:?}"); + handle_validation_authentication_response(from, resp, context) + .await + .context("failed to handle validator authentication response")?; + } + p2p::Response::HardwareChallenge(ref resp) => { + if !context.protocols.has_hardware_challenge() { + bail!("received HardwareChallengeResponse from {from}, but hardware challenge protocol is not enabled"); + } + + log::debug!("received HardwareChallengeResponse from {from}: {resp:?}"); + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + let Some(response_tx) = + ongoing_outbound_requests.remove(&(from, Protocol::HardwareChallenge)) + else { + bail!( + "no ongoing hardware challenge for peer {from}, cannot handle HardwareChallengeResponse" + ); + }; + let _ = response_tx.send(response); + } + p2p::Response::Invite(ref resp) => { + if !context.protocols.has_invite() { + bail!("received InviteResponse from {from}, but invite protocol is not enabled"); + } + + log::debug!("received InviteResponse from {from}: {resp:?}"); + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + let Some(response_tx) = ongoing_outbound_requests.remove(&(from, Protocol::Invite)) + else { + bail!("no ongoing invite for peer {from}, cannot handle InviteResponse"); + }; + let _ = response_tx.send(response); + } + p2p::Response::GetTaskLogs(ref resp) => { + if !context.protocols.has_get_task_logs() { + bail!("received GetTaskLogsResponse from {from}, but get task logs 
protocol is not enabled"); + } + + log::debug!("received GetTaskLogsResponse from {from}: {resp:?}"); + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + let Some(response_tx) = + ongoing_outbound_requests.remove(&(from, Protocol::GetTaskLogs)) + else { + bail!("no ongoing GetTaskLogs for peer {from}, cannot handle GetTaskLogsResponse"); + }; + let _ = response_tx.send(response); + } + p2p::Response::Restart(ref resp) => { + if !context.protocols.has_restart() { + bail!("received RestartResponse from {from}, but restart protocol is not enabled"); + } + + log::debug!("received RestartResponse from {from}: {resp:?}"); + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + let Some(response_tx) = ongoing_outbound_requests.remove(&(from, Protocol::Restart)) + else { + bail!("no ongoing Restart for peer {from}, cannot handle RestartResponse"); + }; + let _ = response_tx.send(response); + } + p2p::Response::General(ref resp) => { + if !context.protocols.has_general() { + bail!("received GeneralResponse from {from}, but general protocol is not enabled"); + } + + log::debug!("received GeneralResponse from {from}: {resp:?}"); + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + let Some(response_tx) = ongoing_outbound_requests.remove(&(from, Protocol::General)) + else { + bail!("no ongoing General for peer {from}, cannot handle GeneralResponse"); + }; + let _ = response_tx.send(response); + } + } + + Ok(()) +} + +async fn handle_validation_authentication_response( + from: PeerId, + response: p2p::ValidatorAuthenticationResponse, + context: Context, +) -> Result<()> { + use crate::security::request_signer::sign_message; + use std::str::FromStr as _; + + match response { + ValidatorAuthenticationResponse::Initiation(req) => { + let ongoing_auth_requests = context.ongoing_auth_requests.read().await; + let Some(ongoing_challenge) = ongoing_auth_requests.get(&from) else { + 
bail!( + "no ongoing hardware challenge for peer {from}, cannot handle ValidatorAuthenticationInitiationResponse" + ); + }; + + let Ok(parsed_signature) = alloy::primitives::Signature::from_str(&req.signature) + else { + bail!("failed to parse signature from response"); + }; + + // recover address from the challenge message that the peer signed + let Ok(recovered_address) = parsed_signature + .recover_address_from_msg(&ongoing_challenge.auth_challenge_request_message) + else { + bail!("Failed to recover address from response signature") + }; + + // verify the recovered address matches the expected worker wallet address + if recovered_address != ongoing_challenge.peer_wallet_address { + bail!( + "peer address verification failed: expected {}, got {recovered_address}", + ongoing_challenge.peer_wallet_address, + ) + } + + log::debug!("auth challenge initiation response received from node: {from}"); + let signature = sign_message(&req.message, &context.wallet).await.unwrap(); + + let req: p2p::Request = ValidatorAuthenticationSolutionRequest { signature }.into(); + let req = req.into_outgoing_message(from, vec![]); + context + .outgoing_messages + .send(req) + .await + .context("failed to send outgoing message")?; + } + ValidatorAuthenticationResponse::Solution(req) => { + let mut ongoing_auth_requests = context.ongoing_auth_requests.write().await; + let Some(ongoing_challenge) = ongoing_auth_requests.remove(&from) else { + bail!( + "no ongoing hardware challenge for peer {from}, cannot handle ValidatorAuthenticationSolutionResponse" + ); + }; + + match req { + p2p::ValidatorAuthenticationSolutionResponse::Granted => {} + p2p::ValidatorAuthenticationSolutionResponse::Rejected => { + log::debug!("auth challenge rejected by node: {from}"); + return Ok(()); + } + } + + // auth was granted, finally send the hardware challenge + let mut is_authenticated_with_peer = context.is_authenticated_with_peer.write().await; + is_authenticated_with_peer.insert(from); + + let protocol 
= ongoing_challenge.outgoing_message.protocol(); + let req = ongoing_challenge + .outgoing_message + .into_outgoing_message(from, vec![]); + context + .outgoing_messages + .send(req) + .await + .context("failed to send outgoing message")?; + + let mut ongoing_outbound_requests = context.ongoing_outbound_requests.write().await; + ongoing_outbound_requests.insert((from, protocol), ongoing_challenge.response_tx); + } + } + Ok(()) +} diff --git a/crates/validator/Cargo.toml b/crates/validator/Cargo.toml index db3694ca..4d329921 100644 --- a/crates/validator/Cargo.toml +++ b/crates/validator/Cargo.toml @@ -7,6 +7,9 @@ edition.workspace = true workspace = true [dependencies] +shared = { workspace = true } +p2p = { workspace = true} + actix-web = { workspace = true } alloy = { workspace = true } anyhow = { workspace = true } @@ -16,25 +19,21 @@ directories = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } hex = { workspace = true } -iroh = { workspace = true } -rand_v8 = { workspace = true } -lazy_static = "1.5.0" log = { workspace = true } -nalgebra = { workspace = true } -prometheus = "0.14.0" -rand = "0.9.0" redis = { workspace = true, features = ["tokio-comp"] } -redis-test = { workspace = true } -regex = "1.11.1" reqwest = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } -shared = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } -toml = { workspace = true } url = { workspace = true } +lazy_static = "1.5.0" +prometheus = "0.14.0" +rand = "0.9.0" +regex = "1.11.1" + [dev-dependencies] mockito = { workspace = true } +redis-test = { workspace = true } tempfile = "=3.14.0" diff --git a/crates/validator/src/lib.rs b/crates/validator/src/lib.rs index 760af2d1..9fac5ce8 100644 --- a/crates/validator/src/lib.rs +++ b/crates/validator/src/lib.rs @@ -5,7 +5,7 @@ mod validators; pub use metrics::export_metrics; pub use metrics::MetricsContext; -pub use p2p::P2PClient; +pub 
use p2p::Service as P2PService; pub use store::redis::RedisStore; pub use validators::hardware::HardwareValidator; pub use validators::synthetic_data::types::InvalidationType; diff --git a/crates/validator/src/main.rs b/crates/validator/src/main.rs index 55b3900d..d17f5004 100644 --- a/crates/validator/src/main.rs +++ b/crates/validator/src/main.rs @@ -23,7 +23,7 @@ use tokio_util::sync::CancellationToken; use url::Url; use validator::{ - export_metrics, HardwareValidator, InvalidationType, MetricsContext, P2PClient, RedisStore, + export_metrics, HardwareValidator, InvalidationType, MetricsContext, P2PService, RedisStore, SyntheticDataValidator, }; @@ -196,6 +196,10 @@ struct Args { /// Redis URL #[arg(long, default_value = "redis://localhost:6380")] redis_url: String, + + /// Libp2p port + #[arg(long, default_value = "4003")] + libp2p_port: u16, } #[tokio::main] @@ -269,19 +273,27 @@ async fn main() -> anyhow::Result<()> { MetricsContext::new(validator_wallet.address().to_string(), args.pool_id.clone()); // Initialize P2P client if enabled - let p2p_client = { - match P2PClient::new(validator_wallet.clone()).await { - Ok(client) => { - info!("P2P client initialized for testing"); - Some(client) + let keypair = p2p::Keypair::generate_ed25519(); + let (p2p_service, hardware_challenge_tx) = { + match P2PService::new( + keypair, + args.libp2p_port, + cancellation_token.clone(), + validator_wallet.clone(), + ) { + Ok(res) => { + info!("p2p service initialized successfully"); + res } Err(e) => { - error!("Failed to initialize P2P client: {e}"); - None + error!("failed to initialize p2p service: {e}"); + std::process::exit(1); } } }; + tokio::task::spawn(p2p_service.run()); + if let Some(pool_id) = args.pool_id.clone() { let pool = match contracts .compute_pool @@ -308,8 +320,7 @@ async fn main() -> anyhow::Result<()> { let contracts = contract_builder.build().unwrap(); - let hardware_validator = - HardwareValidator::new(&validator_wallet, contracts.clone(), 
p2p_client.as_ref()); + let hardware_validator = HardwareValidator::new(contracts.clone(), hardware_challenge_tx); let synthetic_validator = if let Some(pool_id) = args.pool_id.clone() { let penalty = U256::from(args.validator_penalty) * Unit::ETHER.wei(); diff --git a/crates/validator/src/p2p/client.rs b/crates/validator/src/p2p/client.rs deleted file mode 100644 index a0b21db1..00000000 --- a/crates/validator/src/p2p/client.rs +++ /dev/null @@ -1,89 +0,0 @@ -use alloy::primitives::Address; -use anyhow::Result; -use log::info; -use rand_v8::Rng; -use shared::models::challenge::{ChallengeRequest, ChallengeResponse}; -use shared::p2p::{client::P2PClient as SharedP2PClient, messages::P2PMessage}; -use shared::web3::wallet::Wallet; -use std::time::SystemTime; - -pub struct P2PClient { - shared_client: SharedP2PClient, -} - -impl P2PClient { - pub async fn new(wallet: Wallet) -> Result { - let shared_client = SharedP2PClient::new(wallet).await?; - Ok(Self { shared_client }) - } - - pub async fn ping_worker( - &self, - worker_wallet_address: Address, - worker_p2p_id: &str, - worker_addresses: &[String], - ) -> Result { - let nonce = rand_v8::thread_rng().gen::(); - - let response = self - .shared_client - .send_request( - worker_p2p_id, - worker_addresses, - worker_wallet_address, - P2PMessage::Ping { - timestamp: SystemTime::now(), - nonce, - }, - 10, - ) - .await?; - - match response { - P2PMessage::Pong { - nonce: returned_nonce, - .. 
- } => { - if returned_nonce == nonce { - info!("Received valid pong from worker {worker_p2p_id} with nonce: {nonce}"); - Ok(nonce) - } else { - Err(anyhow::anyhow!("Invalid nonce in pong response")) - } - } - _ => Err(anyhow::anyhow!("Unexpected response type for ping")), - } - } - - pub async fn send_hardware_challenge( - &self, - worker_wallet_address: Address, - worker_p2p_id: &str, - worker_addresses: &[String], - challenge: ChallengeRequest, - ) -> Result { - let response = self - .shared_client - .send_request( - worker_p2p_id, - worker_addresses, - worker_wallet_address, - P2PMessage::HardwareChallenge { - challenge, - timestamp: SystemTime::now(), - }, - 30, - ) - .await?; - - match response { - P2PMessage::HardwareChallengeResponse { response, .. } => { - info!("Received hardware challenge response from worker {worker_p2p_id}"); - Ok(response) - } - _ => Err(anyhow::anyhow!( - "Unexpected response type for hardware challenge" - )), - } - } -} diff --git a/crates/validator/src/p2p/mod.rs b/crates/validator/src/p2p/mod.rs index 33dad50c..dc6b23e6 100644 --- a/crates/validator/src/p2p/mod.rs +++ b/crates/validator/src/p2p/mod.rs @@ -1,3 +1,102 @@ -pub(crate) mod client; +use anyhow::{bail, Context as _, Result}; +use futures::stream::FuturesUnordered; +use p2p::{Keypair, Protocols}; +use shared::p2p::OutgoingRequest; +use shared::p2p::Service as P2PService; +use shared::web3::wallet::Wallet; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio_util::sync::CancellationToken; -pub use client::P2PClient; +pub struct Service { + inner: P2PService, + + // converts incoming hardware challenges to outgoing requests + outgoing_message_tx: Sender, + hardware_challenge_rx: Receiver, +} + +impl Service { + pub fn new( + keypair: Keypair, + port: u16, + cancellation_token: CancellationToken, + wallet: Wallet, + ) -> Result<(Self, Sender)> { + let (hardware_challenge_tx, hardware_challenge_rx) = tokio::sync::mpsc::channel(100); + let (inner, outgoing_message_tx) = 
P2PService::new( + keypair, + port, + cancellation_token.clone(), + wallet, + Protocols::new() + .with_hardware_challenge() + .with_validator_authentication(), + ) + .context("failed to create p2p service")?; + Ok(( + Self { + inner, + outgoing_message_tx, + hardware_challenge_rx, + }, + hardware_challenge_tx, + )) + } + + pub async fn run(self) -> Result<()> { + use futures::StreamExt as _; + + let Self { + inner, + outgoing_message_tx, + mut hardware_challenge_rx, + } = self; + + tokio::task::spawn(inner.run()); + + let mut futures = FuturesUnordered::new(); + + loop { + tokio::select! { + Some(request) = hardware_challenge_rx.recv() => { + let (incoming_resp_tx, incoming_resp_rx) = tokio::sync::oneshot::channel(); + let fut = async move { + let resp = match incoming_resp_rx.await.context("outgoing request tx channel was dropped")? { + p2p::Response::HardwareChallenge(resp) => resp.response, + _ => bail!("unexpected response type for hardware challenge request"), + }; + request.response_tx.send(resp).map_err(|_|anyhow::anyhow!("caller dropped response channel"))?; + Ok(()) + }; + futures.push(fut); + + let outgoing_request = OutgoingRequest { + peer_wallet_address: request.worker_wallet_address, + peer_id: request.worker_p2p_id, + multiaddrs: request.worker_addresses, + request: p2p::HardwareChallengeRequest { + challenge: request.challenge, + timestamp: std::time::SystemTime::now(), + }.into(), + response_tx: incoming_resp_tx, + }; + outgoing_message_tx.send(outgoing_request).await + .context("failed to send outgoing hardware challenge request")?; + } + Some(res) = futures.next() => { + if let Err(e) = res { + log::error!("failed to handle response conversion: {e}"); + } + } + } + } + } +} + +pub struct HardwareChallengeRequest { + pub(crate) worker_wallet_address: alloy::primitives::Address, + pub(crate) worker_p2p_id: String, + pub(crate) worker_addresses: Vec, + pub(crate) challenge: p2p::ChallengeRequest, + pub(crate) response_tx: 
tokio::sync::oneshot::Sender, +} diff --git a/crates/validator/src/validators/hardware.rs b/crates/validator/src/validators/hardware.rs index 00736d34..877861da 100644 --- a/crates/validator/src/validators/hardware.rs +++ b/crates/validator/src/validators/hardware.rs @@ -1,15 +1,13 @@ use alloy::primitives::Address; +use anyhow::bail; use anyhow::Result; use log::{debug, error, info}; use shared::{ models::node::DiscoveryNode, - web3::{ - contracts::core::builder::Contracts, - wallet::{Wallet, WalletProvider}, - }, + web3::{contracts::core::builder::Contracts, wallet::WalletProvider}, }; -use crate::p2p::client::P2PClient; +use crate::p2p::HardwareChallengeRequest; use crate::validators::hardware_challenge::HardwareChallenge; /// Hardware validator implementation @@ -17,35 +15,27 @@ use crate::validators::hardware_challenge::HardwareChallenge; /// NOTE: This is a temporary implementation that will be replaced with a proper /// hardware validator in the near future. The current implementation only performs /// basic matrix multiplication challenges and does not verify actual hardware specs. -pub struct HardwareValidator<'a> { - wallet: &'a Wallet, +pub struct HardwareValidator { contracts: Contracts, - p2p_client: Option<&'a P2PClient>, + challenge_tx: tokio::sync::mpsc::Sender, } -impl<'a> HardwareValidator<'a> { +impl HardwareValidator { pub fn new( - wallet: &'a Wallet, contracts: Contracts, - p2p_client: Option<&'a P2PClient>, + challenge_tx: tokio::sync::mpsc::Sender, ) -> Self { Self { - wallet, contracts, - p2p_client, + challenge_tx, } } - async fn validate_node( - _wallet: &'a Wallet, - contracts: Contracts, - p2p_client: Option<&'a P2PClient>, - node: DiscoveryNode, - ) -> Result<()> { + async fn validate_node(&self, node: DiscoveryNode) -> Result<()> { let node_address = match node.id.trim_start_matches("0x").parse::
() { Ok(addr) => addr, Err(e) => { - return Err(anyhow::anyhow!("Failed to parse node address: {}", e)); + bail!("failed to parse node address: {e:?}"); } }; @@ -56,30 +46,22 @@ impl<'a> HardwareValidator<'a> { { Ok(addr) => addr, Err(e) => { - return Err(anyhow::anyhow!("Failed to parse provider address: {}", e)); + bail!("failed to parse provider address: {e:?}"); } }; // Perform hardware challenge - if let Some(p2p_client) = p2p_client { - let hardware_challenge = HardwareChallenge::new(p2p_client); - let challenge_result = hardware_challenge.challenge_node(&node).await; - - if let Err(e) = challenge_result { - println!("Challenge failed for node: {}, error: {}", node.id, e); - error!("Challenge failed for node: {}, error: {}", node.id, e); - return Err(anyhow::anyhow!("Failed to challenge node: {}", e)); - } - } else { - debug!( - "P2P client not available, skipping hardware challenge for node {}", - node.id - ); + let hardware_challenge = HardwareChallenge::new(self.challenge_tx.clone()); + let challenge_result = hardware_challenge.challenge_node(&node).await; + + if let Err(e) = challenge_result { + bail!("failed to challenge node: {e:?}"); } debug!("Sending validation transaction for node {}", node.id); - if let Err(e) = contracts + if let Err(e) = self + .contracts .prime_network .validate_node(provider_address, node_address) .await @@ -100,17 +82,11 @@ impl<'a> HardwareValidator<'a> { debug!("Non validated nodes: {non_validated:?}"); info!("Starting validation for {} nodes", non_validated.len()); - let contracts = self.contracts.clone(); - let wallet = self.wallet; - let p2p_client = self.p2p_client; - // Process non validated nodes sequentially as simple fix // to avoid nonce conflicts for now. 
Will sophisticate this in the future for node in non_validated { let node_id = node.id.clone(); - match HardwareValidator::validate_node(wallet, contracts.clone(), p2p_client, node) - .await - { + match self.validate_node(node).await { Ok(_) => (), Err(e) => { error!("Failed to validate node {node_id}: {e}"); @@ -134,7 +110,6 @@ mod tests { async fn test_challenge_node() { let coordinator_key = "0xdbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97"; let rpc_url: Url = Url::parse("http://localhost:8545").unwrap(); - let coordinator_wallet = Arc::new(Wallet::new(coordinator_key, rpc_url).unwrap()); let contracts = ContractBuilder::new(coordinator_wallet.provider()) @@ -145,7 +120,8 @@ mod tests { .build() .unwrap(); - let validator = HardwareValidator::new(&coordinator_wallet, contracts, None); + let (tx, _rx) = tokio::sync::mpsc::channel(100); + let validator = HardwareValidator::new(contracts, tx); let fake_discovery_node1 = DiscoveryNode { is_validated: false, diff --git a/crates/validator/src/validators/hardware_challenge.rs b/crates/validator/src/validators/hardware_challenge.rs index c881c542..6970355d 100644 --- a/crates/validator/src/validators/hardware_challenge.rs +++ b/crates/validator/src/validators/hardware_challenge.rs @@ -1,40 +1,38 @@ -use crate::p2p::client::P2PClient; use alloy::primitives::Address; -use anyhow::{Error, Result}; +use anyhow::{bail, Context as _, Result}; use log::{error, info}; use rand::{rng, Rng}; -use shared::models::{ - challenge::{calc_matrix, ChallengeRequest, FixedF64}, - node::DiscoveryNode, -}; +use shared::models::node::DiscoveryNode; use std::str::FromStr; -pub(crate) struct HardwareChallenge<'a> { - p2p_client: &'a P2PClient, +use crate::p2p::HardwareChallengeRequest; + +pub(crate) struct HardwareChallenge { + challenge_tx: tokio::sync::mpsc::Sender, } -impl<'a> HardwareChallenge<'a> { - pub(crate) fn new(p2p_client: &'a P2PClient) -> Self { - Self { p2p_client } +impl HardwareChallenge { + pub(crate) fn 
new(challenge_tx: tokio::sync::mpsc::Sender) -> Self { + Self { challenge_tx } } - pub(crate) async fn challenge_node(&self, node: &DiscoveryNode) -> Result { + pub(crate) async fn challenge_node(&self, node: &DiscoveryNode) -> Result<()> { // Check if node has P2P ID and addresses let p2p_id = node .node .worker_p2p_id - .as_ref() + .clone() .ok_or_else(|| anyhow::anyhow!("Node {} does not have P2P ID", node.id))?; let p2p_addresses = node .node .worker_p2p_addresses - .as_ref() + .clone() .ok_or_else(|| anyhow::anyhow!("Node {} does not have P2P addresses", node.id))?; // create random challenge matrix let challenge_matrix = self.random_challenge(3, 3, 3, 3); - let challenge_expected = calc_matrix(&challenge_matrix); + let challenge_expected = p2p::calc_matrix(&challenge_matrix); // Add timestamp to the challenge let current_time = std::time::SystemTime::now() @@ -47,34 +45,36 @@ impl<'a> HardwareChallenge<'a> { let node_address = Address::from_str(&node.node.id) .map_err(|e| anyhow::anyhow!("Failed to parse node address {}: {}", node.node.id, e))?; + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let hardware_challenge = HardwareChallengeRequest { + worker_wallet_address: node_address, + worker_p2p_id: p2p_id, + worker_addresses: p2p_addresses, + challenge: challenge_with_timestamp, + response_tx, + }; + // Send challenge via P2P - match self - .p2p_client - .send_hardware_challenge( - node_address, - p2p_id, - p2p_addresses, - challenge_with_timestamp, - ) + self.challenge_tx + .send(hardware_challenge) + .await + .context("failed to send hardware challenge request to p2p service")?; + + let resp = response_rx .await - { - Ok(response) => { - if challenge_expected.result == response.result { - info!("Challenge for node {} successful", node.id); - Ok(0) - } else { - error!( - "Challenge failed for node {}: expected {:?}, got {:?}", - node.id, challenge_expected.result, response.result - ); - Err(anyhow::anyhow!("Node failed challenge")) - } 
- } - Err(e) => { - error!("Failed to send challenge to node {}: {}", node.id, e); - Err(anyhow::anyhow!("Failed to send challenge: {}", e)) - } + .context("failed to receive response from node")?; + + if challenge_expected.result == resp.result { + info!("Challenge for node {} successful", node.id); + } else { + error!( + "Challenge failed for node {}: expected {:?}, got {:?}", + node.id, challenge_expected.result, resp.result + ); + bail!("Node failed challenge"); } + + Ok(()) } fn random_challenge( @@ -83,7 +83,9 @@ impl<'a> HardwareChallenge<'a> { cols_a: usize, rows_b: usize, cols_b: usize, - ) -> ChallengeRequest { + ) -> p2p::ChallengeRequest { + use p2p::FixedF64; + let mut rng = rng(); let data_a_vec: Vec = (0..(rows_a * cols_a)) @@ -98,7 +100,7 @@ impl<'a> HardwareChallenge<'a> { let data_a: Vec = data_a_vec.iter().map(|x| FixedF64(*x)).collect(); let data_b: Vec = data_b_vec.iter().map(|x| FixedF64(*x)).collect(); - ChallengeRequest { + p2p::ChallengeRequest { rows_a, cols_a, data_a,