diff --git a/Cargo.lock b/Cargo.lock index 84c4ad45..c16f0570 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6257,6 +6257,20 @@ dependencies = [ "sha2", ] +[[package]] +name = "p2p" +version = "0.3.11" +dependencies = [ + "anyhow", + "libp2p", + "nalgebra", + "serde", + "tokio", + "tokio-util", + "tracing", + "void", +] + [[package]] name = "parity-scale-codec" version = "3.7.4" diff --git a/Cargo.toml b/Cargo.toml index 00702d19..1bc9e2ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,12 +5,15 @@ members = [ "crates/validator", "crates/shared", "crates/orchestrator", + "crates/p2p", "crates/dev-utils", ] resolver = "2" [workspace.dependencies] shared = { path = "crates/shared" } +p2p = { path = "crates/p2p" } + actix-web = "4.9.0" clap = { version = "4.5.27", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] } @@ -42,6 +45,7 @@ rand_core_v6 = { package = "rand_core", version = "0.6.4", features = ["std"] } ipld-core = "0.4" rust-ipfs = "0.14" cid = "0.11" +tracing = "0.1.41" [workspace.package] version = "0.3.11" diff --git a/crates/p2p/Cargo.toml b/crates/p2p/Cargo.toml new file mode 100644 index 00000000..bb670107 --- /dev/null +++ b/crates/p2p/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "p2p" +version.workspace = true +edition.workspace = true + +[dependencies] +libp2p = { version = "0.54", features = ["request-response", "identify", "ping", "mdns", "noise", "tcp", "autonat", "kad", "tokio", "cbor", "macros", "yamux"] } +void = "1.0" + +anyhow = {workspace = true} +nalgebra = {workspace = true} +serde = {workspace = true} +tokio = {workspace = true, features = ["sync"]} +tokio-util = { workspace = true, features = ["rt"] } +tracing = { workspace = true } + +[lints] +workspace = true diff --git a/crates/p2p/src/behaviour.rs b/crates/p2p/src/behaviour.rs new file mode 100644 index 00000000..b114b61e --- /dev/null +++ b/crates/p2p/src/behaviour.rs @@ -0,0 +1,184 @@ +use anyhow::Context as _; +use anyhow::Result; +use libp2p::autonat; +use libp2p::connection_limits; +use libp2p::connection_limits::ConnectionLimits; +use libp2p::identify; +use libp2p::identity; +use libp2p::kad; +use libp2p::kad::store::MemoryStore; +use libp2p::mdns; +use libp2p::ping; +use libp2p::request_response; +use libp2p::swarm::NetworkBehaviour; +use std::time::Duration; +use tracing::debug; + +use crate::message::IncomingMessage; +use crate::message::{Request, Response}; +use crate::Protocols; +use crate::PRIME_STREAM_PROTOCOL; + +#[derive(NetworkBehaviour)] +#[behaviour(to_swarm = "BehaviourEvent")] +pub(crate) struct Behaviour { + // connection gating + connection_limits: connection_limits::Behaviour, + + // discovery + mdns: mdns::tokio::Behaviour, + kademlia: kad::Behaviour, + + // protocols + identify: identify::Behaviour, + ping: ping::Behaviour, + request_response: request_response::cbor::Behaviour, + + // nat traversal + autonat: autonat::Behaviour, +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum BehaviourEvent { + Autonat(autonat::Event), + Identify(identify::Event), + Kademlia(kad::Event), + Mdns(mdns::Event), + Ping(ping::Event), + RequestResponse(request_response::Event), +} + +impl From for BehaviourEvent { + fn from(_: void::Void) -> Self { + unreachable!("void::Void cannot be converted to BehaviourEvent") + } +} + +impl From for BehaviourEvent { + fn from(event: autonat::Event) -> Self { + BehaviourEvent::Autonat(event) + } +} + +impl From for BehaviourEvent { + fn from(event: kad::Event) -> Self { + BehaviourEvent::Kademlia(event) + } +} + +impl From for BehaviourEvent { + fn from(event: libp2p::mdns::Event) -> Self { + BehaviourEvent::Mdns(event) + } +} + +impl From for BehaviourEvent { + fn from(event: ping::Event) -> Self { + BehaviourEvent::Ping(event) + } +} + +impl From for BehaviourEvent { + fn from(event: identify::Event) -> Self { + BehaviourEvent::Identify(event) + } +} + +impl From> for BehaviourEvent { + fn from(event: request_response::Event) -> Self { + BehaviourEvent::RequestResponse(event) + } +} + +impl Behaviour { + pub(crate) fn new( + keypair: &identity::Keypair, + protocols: Protocols, + agent_version: String, + ) -> Result { + let peer_id = keypair.public().to_peer_id(); + + let protocols = protocols.into_iter().map(|protocol| { + ( + protocol.as_stream_protocol(), + request_response::ProtocolSupport::Full, // TODO: configure inbound/outbound based on node role and protocol + ) + }); + + let autonat = autonat::Behaviour::new(peer_id, autonat::Config::default()); + let connection_limits = connection_limits::Behaviour::new( + ConnectionLimits::default().with_max_established(Some(100)), + ); + + let mdns = mdns::tokio::Behaviour::new(mdns::Config::default(), peer_id) + .context("failed to create mDNS behaviour")?; + let kademlia = kad::Behaviour::new(peer_id, MemoryStore::new(peer_id)); + + let identify = identify::Behaviour::new( + identify::Config::new(PRIME_STREAM_PROTOCOL.to_string(), keypair.public()) + .with_agent_version(agent_version), + ); + let ping = ping::Behaviour::new(ping::Config::new().with_interval(Duration::from_secs(10))); + + Ok(Self { + autonat, + connection_limits, + kademlia, + mdns, + identify, + ping, + request_response: request_response::cbor::Behaviour::new( + protocols, + request_response::Config::default(), + ), + }) + } + + pub(crate) fn request_response( + &mut self, + ) -> &mut request_response::cbor::Behaviour { + &mut self.request_response + } +} + +impl BehaviourEvent { + pub(crate) async fn handle(self, message_tx: tokio::sync::mpsc::Sender) { + match self { + BehaviourEvent::Autonat(_event) => {} + BehaviourEvent::Identify(_event) => {} + BehaviourEvent::Kademlia(_event) => { // TODO: potentially on outbound queries + } + BehaviourEvent::Mdns(_event) => {} + BehaviourEvent::Ping(_event) => {} + BehaviourEvent::RequestResponse(event) => match event { + request_response::Event::Message { peer, message } => { + debug!("received message from peer {peer:?}: {message:?}"); + // if this errors, user dropped their incoming message channel + let _ = message_tx.send(IncomingMessage { peer, message }).await; + } + request_response::Event::ResponseSent { peer, request_id } => { + debug!("response sent to peer {peer:?} for request ID {request_id:?}"); + } + request_response::Event::InboundFailure { + peer, + request_id, + error, + } => { + debug!( + "inbound failure from peer {peer:?} for request ID {request_id:?}: {error}" + ); + } + request_response::Event::OutboundFailure { + peer, + request_id, + error, + } => { + debug!( + "outbound failure to peer {peer:?} for request ID {request_id:?}: {error}" + ); + } + }, + } + } +} diff --git a/crates/p2p/src/lib.rs b/crates/p2p/src/lib.rs new file mode 100644 index 00000000..0a5637a9 --- /dev/null +++ b/crates/p2p/src/lib.rs @@ -0,0 +1,408 @@ +use anyhow::Context; +use anyhow::Result; +use libp2p::noise; +use libp2p::swarm::SwarmEvent; +use libp2p::tcp; +use libp2p::yamux; +use libp2p::Swarm; +use libp2p::SwarmBuilder; +use libp2p::{identity, Transport}; +use std::time::Duration; +use tracing::debug; + +mod behaviour; +mod message; +mod protocol; + +use behaviour::Behaviour; +use protocol::Protocols; + +pub use message::*; + +pub type Libp2pIncomingMessage = libp2p::request_response::Message; +pub type ResponseChannel = libp2p::request_response::ResponseChannel; +pub type PeerId = libp2p::PeerId; +pub type Multiaddr = libp2p::Multiaddr; +pub type Keypair = libp2p::identity::Keypair; +pub type DialSender = + tokio::sync::mpsc::Sender<(Vec, tokio::sync::oneshot::Sender>)>; + +pub const PRIME_STREAM_PROTOCOL: libp2p::StreamProtocol = + libp2p::StreamProtocol::new("/prime/1.0.0"); +// TODO: force this to be passed by the user +pub const DEFAULT_AGENT_VERSION: &str = "prime-node/0.1.0"; + +pub struct Node { + peer_id: PeerId, + listen_addrs: Vec, + swarm: Swarm, + bootnodes: Vec, + cancellation_token: tokio_util::sync::CancellationToken, + + dial_rx: + tokio::sync::mpsc::Receiver<(Vec, tokio::sync::oneshot::Sender>)>, + + // channel for sending incoming messages to the consumer of this library + incoming_message_tx: tokio::sync::mpsc::Sender, + + // channel for receiving outgoing messages from the consumer of this library + outgoing_message_rx: tokio::sync::mpsc::Receiver, +} + +impl Node { + pub fn peer_id(&self) -> PeerId { + self.peer_id + } + + pub fn listen_addrs(&self) -> &[libp2p::Multiaddr] { + &self.listen_addrs + } + + /// Returns the multiaddresses that this node is listening on, with the peer ID included. + pub fn multiaddrs(&self) -> Vec { + self.listen_addrs + .iter() + .map(|addr| { + addr.clone() + .with_p2p(self.peer_id) + .expect("can add peer ID to multiaddr") + }) + .collect() + } + + pub async fn run(self) -> Result<()> { + use libp2p::futures::StreamExt as _; + + let Node { + peer_id: _, + listen_addrs, + mut swarm, + bootnodes, + cancellation_token, + mut dial_rx, + incoming_message_tx, + mut outgoing_message_rx, + } = self; + + for addr in listen_addrs { + swarm + .listen_on(addr) + .context("swarm failed to listen on multiaddr")?; + } + + for bootnode in bootnodes { + match swarm.dial(bootnode.clone()) { + Ok(_) => {} + Err(e) => { + debug!("failed to dial bootnode {bootnode}: {e:?}"); + } + } + } + + loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + debug!("cancellation token triggered, shutting down node"); + break Ok(()); + } + Some((addrs, res_tx)) = dial_rx.recv() => { + let mut res = Ok(()); + for addr in addrs { + match swarm.dial(addr.clone()) { + Ok(_) => {} + Err(e) => { + res = Err(anyhow::anyhow!("failed to dial {addr}: {e:?}")); + break; + } + } + } + let _ = res_tx.send(res); + } + Some(message) = outgoing_message_rx.recv() => { + match message { + OutgoingMessage::Request((peer, request)) => { + swarm.behaviour_mut().request_response().send_request(&peer, request); + } + OutgoingMessage::Response((channel, response)) => { + if let Err(e) = swarm.behaviour_mut().request_response().send_response(channel, response) { + debug!("failed to send response: {e:?}"); + } + } + } + } + event = swarm.select_next_some() => { + match event { + SwarmEvent::NewListenAddr { + listener_id: _, + address, + } => { + debug!("new listen address: {address}"); + } + SwarmEvent::ExternalAddrConfirmed { address } => { + debug!("external address confirmed: {address}"); + } + SwarmEvent::ConnectionClosed { + peer_id, + cause, + endpoint: _, + connection_id: _, + num_established: _, + } => { + debug!("connection closed with peer {peer_id}: {cause:?}"); + } + SwarmEvent::Behaviour(event) => event.handle(incoming_message_tx.clone()).await, + _ => continue, + } + }, + } + } + } +} + +pub struct NodeBuilder { + port: Option, + listen_addrs: Vec, + keypair: Option, + agent_version: Option, + protocols: Protocols, + bootnodes: Vec, + cancellation_token: Option, +} + +impl Default for NodeBuilder { + fn default() -> Self { + Self::new() + } +} + +impl NodeBuilder { + pub fn new() -> Self { + Self { + port: None, + listen_addrs: Vec::new(), + keypair: None, + agent_version: None, + protocols: Protocols::new(), + bootnodes: Vec::new(), + cancellation_token: None, + } + } + + pub fn with_port(mut self, port: u16) -> Self { + self.port = Some(port); + self + } + + pub fn with_listen_addr(mut self, addr: libp2p::Multiaddr) -> Self { + self.listen_addrs.push(addr); + self + } + + pub fn with_keypair(mut self, keypair: identity::Keypair) -> Self { + self.keypair = Some(keypair); + self + } + + pub fn with_agent_version(mut self, agent_version: String) -> Self { + self.agent_version = Some(agent_version); + self + } + + pub fn with_validator_authentication(mut self) -> Self { + self.protocols = self.protocols.with_validator_authentication(); + self + } + + pub fn with_hardware_challenge(mut self) -> Self { + self.protocols = self.protocols.with_hardware_challenge(); + self + } + + pub fn with_invite(mut self) -> Self { + self.protocols = self.protocols.with_invite(); + self + } + + pub fn with_get_task_logs(mut self) -> Self { + self.protocols = self.protocols.with_get_task_logs(); + self + } + + pub fn with_restart(mut self) -> Self { + self.protocols = self.protocols.with_restart(); + self + } + + pub fn with_general(mut self) -> Self { + self.protocols = self.protocols.with_general(); + self + } + + pub fn with_bootnode(mut self, bootnode: Multiaddr) -> Self { + self.bootnodes.push(bootnode); + self + } + + pub fn with_bootnodes(mut self, bootnodes: I) -> Self + where + I: IntoIterator, + T: Into, + { + for bootnode in bootnodes { + self.bootnodes.push(bootnode.into()); + } + self + } + + pub fn with_cancellation_token( + mut self, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> Self { + self.cancellation_token = Some(cancellation_token); + self + } + + pub fn try_build( + self, + ) -> Result<( + Node, + DialSender, + tokio::sync::mpsc::Receiver, + tokio::sync::mpsc::Sender, + )> { + let Self { + port, + mut listen_addrs, + keypair, + agent_version, + protocols, + bootnodes, + cancellation_token, + } = self; + + let keypair = keypair.unwrap_or(identity::Keypair::generate_ed25519()); + let peer_id = keypair.public().to_peer_id(); + + let transport = create_transport(&keypair)?; + let behaviour = Behaviour::new( + &keypair, + protocols, + agent_version.unwrap_or(DEFAULT_AGENT_VERSION.to_string()), + ) + .context("failed to create behaviour")?; + + let swarm = SwarmBuilder::with_existing_identity(keypair) + .with_tokio() + .with_other_transport(|_| transport)? + .with_behaviour(|_| behaviour)? + .with_swarm_config(|cfg| { + cfg.with_idle_connection_timeout(Duration::from_secs(u64::MAX)) // don't disconnect from idle peers + }) + .build(); + + if listen_addrs.is_empty() { + let port = port.unwrap_or(0); + let listen_addr = format!("/ip4/0.0.0.0/tcp/{port}") + .parse() + .expect("can parse valid multiaddr"); + listen_addrs.push(listen_addr); + } + + let (dial_tx, dial_rx) = tokio::sync::mpsc::channel(100); + let (incoming_message_tx, incoming_message_rx) = tokio::sync::mpsc::channel(100); + let (outgoing_message_tx, outgoing_message_rx) = tokio::sync::mpsc::channel(100); + + Ok(( + Node { + peer_id, + swarm, + listen_addrs, + bootnodes, + dial_rx, + incoming_message_tx, + outgoing_message_rx, + cancellation_token: cancellation_token.unwrap_or_default(), + }, + dial_tx, + incoming_message_rx, + outgoing_message_tx, + )) + } +} + +fn create_transport( + keypair: &identity::Keypair, +) -> Result> { + let transport = tcp::tokio::Transport::new(tcp::Config::default()) + .upgrade(libp2p::core::upgrade::Version::V1) + .authenticate(noise::Config::new(keypair)?) + .multiplex(yamux::Config::default()) + .timeout(Duration::from_secs(20)) + .boxed(); + + Ok(transport) +} + +#[cfg(test)] +mod test { + use super::NodeBuilder; + use crate::message; + + #[tokio::test] + async fn two_nodes_can_connect_and_do_request_response() { + let (node1, _, mut incoming_message_rx1, outgoing_message_tx1) = + NodeBuilder::new().with_get_task_logs().try_build().unwrap(); + let node1_peer_id = node1.peer_id(); + + let (node2, _, mut incoming_message_rx2, outgoing_message_tx2) = NodeBuilder::new() + .with_get_task_logs() + .with_bootnodes(node1.multiaddrs()) + .try_build() + .unwrap(); + let node2_peer_id = node2.peer_id(); + + tokio::spawn(async move { node1.run().await }); + tokio::spawn(async move { node2.run().await }); + + // TODO: implement a way to get peer count + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + + // send request from node1->node2 + let request = message::Request::GetTaskLogs; + outgoing_message_tx1 + .send(request.into_outgoing_message(node2_peer_id)) + .await + .unwrap(); + let message = incoming_message_rx2.recv().await.unwrap(); + assert_eq!(message.peer, node1_peer_id); + let libp2p::request_response::Message::Request { + request_id: _, + request: message::Request::GetTaskLogs, + channel, + } = message.message + else { + panic!("expected a GetTaskLogs request message"); + }; + + // send response from node2->node1 + let response = + message::Response::GetTaskLogs(message::GetTaskLogsResponse::Ok("logs".to_string())); + outgoing_message_tx2 + .send(response.into_outgoing_message(channel)) + .await + .unwrap(); + let message = incoming_message_rx1.recv().await.unwrap(); + assert_eq!(message.peer, node2_peer_id); + let libp2p::request_response::Message::Response { + request_id: _, + response: message::Response::GetTaskLogs(response), + } = message.message + else { + panic!("expected a GetTaskLogs response message"); + }; + let message::GetTaskLogsResponse::Ok(logs) = response else { + panic!("expected a successful GetTaskLogs response"); + }; + assert_eq!(logs, "logs"); + } +} diff --git a/crates/p2p/src/message/hardware_challenge.rs b/crates/p2p/src/message/hardware_challenge.rs new file mode 100644 index 00000000..639cc602 --- /dev/null +++ b/crates/p2p/src/message/hardware_challenge.rs @@ -0,0 +1,89 @@ +use nalgebra::DMatrix; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; +use std::fmt; + +#[derive(Debug, Clone)] +pub struct FixedF64(pub f64); + +impl Serialize for FixedF64 { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + // adjust precision as needed + serializer.serialize_str(&format!("{:.12}", self.0)) + } +} + +impl<'de> Deserialize<'de> for FixedF64 { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct FixedF64Visitor; + + impl Visitor<'_> for FixedF64Visitor { + type Value = FixedF64; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string representing a fixed precision float") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + value + .parse::() + .map(FixedF64) + .map_err(|_| E::custom(format!("invalid f64: {value}"))) + } + } + + deserializer.deserialize_str(FixedF64Visitor) + } +} + +impl PartialEq for FixedF64 { + fn eq(&self, other: &Self) -> bool { + format!("{:.10}", self.0) == format!("{:.10}", other.0) + } +} + +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] +pub struct ChallengeRequest { + pub rows_a: usize, + pub cols_a: usize, + pub data_a: Vec, + pub rows_b: usize, + pub cols_b: usize, + pub data_b: Vec, + pub timestamp: Option, +} + +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] +pub struct ChallengeResponse { + pub result: Vec, + pub rows: usize, + pub cols: usize, +} + +pub fn calc_matrix(req: &ChallengeRequest) -> ChallengeResponse { + // convert FixedF64 to f64 + let data_a: Vec = req.data_a.iter().map(|x| x.0).collect(); + let data_b: Vec = req.data_b.iter().map(|x| x.0).collect(); + let a = DMatrix::from_vec(req.rows_a, req.cols_a, data_a); + let b = DMatrix::from_vec(req.rows_b, req.cols_b, data_b); + let c = a * b; + + let data_c: Vec = c.iter().map(|x| FixedF64(*x)).collect(); + + ChallengeResponse { + rows: c.nrows(), + cols: c.ncols(), + result: data_c, + } +} diff --git a/crates/p2p/src/message/mod.rs b/crates/p2p/src/message/mod.rs new file mode 100644 index 00000000..adff99ac --- /dev/null +++ b/crates/p2p/src/message/mod.rs @@ -0,0 +1,234 @@ +use libp2p::PeerId; +use serde::{Deserialize, Serialize}; +use std::time::SystemTime; + +mod hardware_challenge; + +pub use hardware_challenge::*; + +#[derive(Debug)] +pub struct IncomingMessage { + pub peer: PeerId, + pub message: libp2p::request_response::Message, +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub enum OutgoingMessage { + Request((PeerId, Request)), + Response( + ( + libp2p::request_response::ResponseChannel, + Response, + ), + ), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Request { + ValidatorAuthentication(ValidatorAuthenticationRequest), + HardwareChallenge(HardwareChallengeRequest), + Invite(InviteRequest), + GetTaskLogs, + Restart, + General(GeneralRequest), +} + +impl Request { + pub fn into_outgoing_message(self, peer: PeerId) -> OutgoingMessage { + OutgoingMessage::Request((peer, self)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Response { + ValidatorAuthentication(ValidatorAuthenticationResponse), + HardwareChallenge(HardwareChallengeResponse), + Invite(InviteResponse), + GetTaskLogs(GetTaskLogsResponse), + Restart(RestartResponse), + General(GeneralResponse), +} + +impl Response { + pub fn into_outgoing_message( + self, + channel: libp2p::request_response::ResponseChannel, + ) -> OutgoingMessage { + OutgoingMessage::Response((channel, self)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ValidatorAuthenticationRequest { + Initiation(ValidatorAuthenticationInitiationRequest), + Solution(ValidatorAuthenticationSolutionRequest), +} + +impl From for Request { + fn from(request: ValidatorAuthenticationRequest) -> Self { + Request::ValidatorAuthentication(request) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ValidatorAuthenticationResponse { + Initiation(ValidatorAuthenticationInitiationResponse), + Solution(ValidatorAuthenticationSolutionResponse), +} + +impl From for Response { + fn from(response: ValidatorAuthenticationResponse) -> Self { + Response::ValidatorAuthentication(response) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ValidatorAuthenticationInitiationRequest { + pub message: String, +} + +impl From for Request { + fn from(request: ValidatorAuthenticationInitiationRequest) -> Self { + Request::ValidatorAuthentication(ValidatorAuthenticationRequest::Initiation(request)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ValidatorAuthenticationInitiationResponse { + pub signature: String, + pub message: String, +} + +impl From for Response { + fn from(response: ValidatorAuthenticationInitiationResponse) -> Self { + Response::ValidatorAuthentication(ValidatorAuthenticationResponse::Initiation(response)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ValidatorAuthenticationSolutionRequest { + pub signature: String, +} + +impl From for Request { + fn from(request: ValidatorAuthenticationSolutionRequest) -> Self { + Request::ValidatorAuthentication(ValidatorAuthenticationRequest::Solution(request)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ValidatorAuthenticationSolutionResponse { + Granted, + Rejected, +} + +impl From for Response { + fn from(response: ValidatorAuthenticationSolutionResponse) -> Self { + Response::ValidatorAuthentication(ValidatorAuthenticationResponse::Solution(response)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HardwareChallengeRequest { + pub challenge: ChallengeRequest, + pub timestamp: SystemTime, +} + +impl From for Request { + fn from(request: HardwareChallengeRequest) -> Self { + Request::HardwareChallenge(request) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HardwareChallengeResponse { + pub response: ChallengeResponse, + pub timestamp: SystemTime, +} + +impl From for Response { + fn from(response: HardwareChallengeResponse) -> Self { + Response::HardwareChallenge(response) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum InviteRequestUrl { + MasterUrl(String), + MasterIpPort(String, u16), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InviteRequest { + pub invite: String, + pub pool_id: u32, + pub url: InviteRequestUrl, + pub timestamp: u64, + pub expiration: [u8; 32], + pub nonce: [u8; 32], +} + +impl From for Request { + fn from(request: InviteRequest) -> Self { + Request::Invite(request) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum InviteResponse { + Ok, + Error(String), +} + +impl From for Response { + fn from(response: InviteResponse) -> Self { + Response::Invite(response) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum GetTaskLogsResponse { + Ok(String), + Error(String), +} + +impl From for Response { + fn from(response: GetTaskLogsResponse) -> Self { + Response::GetTaskLogs(response) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum RestartResponse { + Ok, + Error(String), +} + +impl From for Response { + fn from(response: RestartResponse) -> Self { + Response::Restart(response) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeneralRequest { + data: Vec, +} + +impl From for Request { + fn from(request: GeneralRequest) -> Self { + Request::General(request) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeneralResponse { + data: Vec, +} + +impl From for Response { + fn from(response: GeneralResponse) -> Self { + Response::General(response) + } +} diff --git a/crates/p2p/src/protocol.rs b/crates/p2p/src/protocol.rs new file mode 100644 index 00000000..df423ef8 --- /dev/null +++ b/crates/p2p/src/protocol.rs @@ -0,0 +1,81 @@ +use libp2p::StreamProtocol; +use std::{collections::HashSet, hash::Hash}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) enum Protocol { + // validator -> worker + ValidatorAuthentication, + // validator -> worker + HardwareChallenge, + // orchestrator -> worker + Invite, + // any -> worker + GetTaskLogs, + // any -> worker + Restart, + // any -> any + General, +} + +impl Protocol { + pub(crate) fn as_stream_protocol(&self) -> StreamProtocol { + match self { + Protocol::ValidatorAuthentication => { + StreamProtocol::new("/prime/validator_authentication/1.0.0") + } + Protocol::HardwareChallenge => StreamProtocol::new("/prime/hardware_challenge/1.0.0"), + Protocol::Invite => StreamProtocol::new("/prime/invite/1.0.0"), + Protocol::GetTaskLogs => StreamProtocol::new("/prime/get_task_logs/1.0.0"), + Protocol::Restart => StreamProtocol::new("/prime/restart/1.0.0"), + Protocol::General => StreamProtocol::new("/prime/general/1.0.0"), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct Protocols(HashSet); + +impl Protocols { + pub(crate) fn new() -> Self { + Self(HashSet::new()) + } + + pub(crate) fn with_validator_authentication(mut self) -> Self { + self.0.insert(Protocol::ValidatorAuthentication); + self + } + + pub(crate) fn with_hardware_challenge(mut self) -> Self { + self.0.insert(Protocol::HardwareChallenge); + self + } + + pub(crate) fn with_invite(mut self) -> Self { + self.0.insert(Protocol::Invite); + self + } + + pub(crate) fn with_get_task_logs(mut self) -> Self { + self.0.insert(Protocol::GetTaskLogs); + self + } + + pub(crate) fn with_restart(mut self) -> Self { + self.0.insert(Protocol::Restart); + self + } + + pub(crate) fn with_general(mut self) -> Self { + self.0.insert(Protocol::General); + self + } +} + +impl IntoIterator for Protocols { + type Item = Protocol; + type IntoIter = std::collections::hash_set::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} diff --git a/crates/worker/Cargo.toml b/crates/worker/Cargo.toml index 18596ba5..0f08e404 100644 --- a/crates/worker/Cargo.toml +++ b/crates/worker/Cargo.toml @@ -50,7 +50,7 @@ unicode-width = "0.2.0" rand = "0.9.0" tempfile = "3.14.0" tracing-loki = "0.2.6" -tracing = "0.1.41" +tracing = { workspace = true } tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } tracing-log = "0.2.0" time = "0.3.41"