diff --git a/libs/gl-signerproxy/Cargo.toml b/libs/gl-signerproxy/Cargo.toml index 7f4788edd..366938742 100644 --- a/libs/gl-signerproxy/Cargo.toml +++ b/libs/gl-signerproxy/Cargo.toml @@ -15,16 +15,17 @@ name = "gl-signerproxy" path = "src/bin/signerproxy.rs" [build-dependencies] -tonic-build = "0.3" +tonic-build = "0.8" [dependencies] anyhow = { workspace = true } env_logger = { workspace = true } -tokio = { version = "0.2", features = ["full"] } -tonic = { version = "0.3", features = ["tls", "transport"] } -prost = "0.6" -log = "0.4" -tower = "0.3" +# Minimal tokio - only for gRPC client runtime +tokio = { version = "1", features = ["rt", "net", "io-util"] } +tonic = { version = "0.8", features = ["tls", "transport"] } +prost = "0.11" +log = "*" +tower = "0.4" which = "4.4.2" libc = "0.2" byteorder = "1.5.0" diff --git a/libs/gl-signerproxy/src/bin/signerproxy.rs b/libs/gl-signerproxy/src/bin/signerproxy.rs index 2b2e11d48..c08e45bce 100644 --- a/libs/gl-signerproxy/src/bin/signerproxy.rs +++ b/libs/gl-signerproxy/src/bin/signerproxy.rs @@ -1,11 +1,10 @@ use anyhow::Result; use gl_signerproxy::Proxy; -#[tokio::main] -async fn main() -> Result<()> { +fn main() -> Result<()> { env_logger::builder() .target(env_logger::Target::Stderr) .init(); - Proxy::new().run().await + Proxy::new().run() } diff --git a/libs/gl-signerproxy/src/hsmproxy.rs b/libs/gl-signerproxy/src/hsmproxy.rs index c1b385759..b3d292741 100644 --- a/libs/gl-signerproxy/src/hsmproxy.rs +++ b/libs/gl-signerproxy/src/hsmproxy.rs @@ -4,17 +4,18 @@ use crate::pb::{hsm_client::HsmClient, Empty, HsmRequest, HsmRequestContext}; use crate::wire::{DaemonConnection, Message}; use anyhow::{anyhow, Context}; use anyhow::{Error, Result}; -use log::{error, info, warn}; +use log::{debug, error, info, warn}; use std::convert::TryFrom; use std::env; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::os::unix::net::UnixStream; +use std::path::PathBuf; use std::process::Command; use std::str; use std::sync::atomic; use std::sync::Arc; -#[cfg(unix)] -use tokio::net::UnixStream as TokioUnixStream; +use std::thread; +use tokio::runtime::Runtime; use tonic::transport::{Endpoint, Uri}; use tower::service_fn; use which::which; @@ -42,32 +43,35 @@ fn version() -> String { fn setup_node_stream() -> Result { let ms = unsafe { UnixStream::from_raw_fd(3) }; - Ok(DaemonConnection::new(TokioUnixStream::from_std(ms)?)) + Ok(DaemonConnection::new(ms)) } -fn start_handler(local: NodeConnection, counter: Arc, grpc: GrpcClient) { - tokio::spawn(async { - match process_requests(local, counter, grpc) - .await - .context("processing requests") - { +fn start_handler( + local: NodeConnection, + counter: Arc, + grpc: GrpcClient, + runtime: Arc, +) { + thread::spawn(move || { + match process_requests(local, counter, grpc, runtime).context("processing requests") { Ok(()) => panic!("why did the hsmproxy stop processing requests without an error?"), Err(e) => warn!("hsmproxy stopped processing requests with error: {}", e), } }); } -async fn process_requests( +fn process_requests( node_conn: NodeConnection, request_counter: Arc, mut server: GrpcClient, + runtime: Arc, ) -> Result<(), Error> { let conn = node_conn.conn; let context = node_conn.context; info!("Pinging server"); - server.ping(Empty::default()).await?; + runtime.block_on(server.ping(Empty::default()))?; loop { - if let Ok(msg) = conn.read().await { + if let Ok(msg) = conn.read() { match msg.msgtype() { 9 => { eprintln!("Got a message from node: {:?}", &msg.body); @@ -79,7 +83,7 @@ async fn process_requests( let (local, remote) = UnixStream::pair()?; let local = NodeConnection { - conn: DaemonConnection::new(TokioUnixStream::from_std(local)?), + conn: DaemonConnection::new(local), context: Some(ctx), }; let remote = remote.as_raw_fd(); @@ -87,8 +91,8 @@ async fn process_requests( let grpc = server.clone(); // Start new handler for the client - start_handler(local, request_counter.clone(), grpc); - if let Err(e) = conn.write(msg).await { + start_handler(local, request_counter.clone(), grpc, runtime.clone()); + if let Err(e) = conn.write(msg) { error!("error writing msg to node_connection: {:?}", e); return Err(e); } @@ -102,22 +106,23 @@ async fn process_requests( requests: Vec::new(), signer_state: Vec::new(), }); - let start_time = tokio::time::Instant::now(); + eprintln!( "WIRE: lightningd -> hsmd: Got a message from node: {:?}", &req ); - eprintln!("WIRE: hsmd -> plugin: Forwarding: {:?}", &req); - let res = server.request(req).await?.into_inner(); - let msg = Message::from_raw(res.raw); + let start_time = tokio::time::Instant::now(); + debug!("Got a message from node: {:?}", &req); + let res = runtime.block_on(server.request(req))?.into_inner(); let delta = start_time.elapsed(); + let msg = Message::from_raw(res.raw); eprintln!( "WIRE: plugin -> hsmd: Got respone from hsmd: {:?} after {}ms", &msg, delta.as_millis() ); eprintln!("WIRE: hsmd -> lightningd: {:?}", &msg); - conn.write(msg).await? + conn.write(msg)? } } } else { @@ -126,32 +131,34 @@ async fn process_requests( } } } -use std::path::PathBuf; -async fn grpc_connect() -> Result { - // We will ignore this uri because uds do not use it - // if your connector does use the uri it will be provided - // as the request to the `MakeConnection`. - // Connect to a Uds socket - let channel = Endpoint::try_from("http://[::]:50051")? - .connect_with_connector(service_fn(|_: Uri| { - let sock_path = get_sock_path().unwrap(); - let mut path = PathBuf::new(); - if !sock_path.starts_with('/') { - path.push(env::current_dir().unwrap()); - } - path.push(&sock_path); - let path = path.to_str().unwrap().to_string(); - info!("Connecting to hsmserver at {}", path); - TokioUnixStream::connect(path) - })) - .await - .context("could not connect to the socket file")?; +fn grpc_connect(runtime: &Runtime) -> Result { + runtime.block_on(async { + // We will ignore this uri because uds do not use it + // if your connector does use the uri it will be provided + // as the request to the `MakeConnection`. + // Connect to a Uds socket + let channel = Endpoint::try_from("http://[::]:50051")? + .connect_with_connector(service_fn(|_: Uri| { + let sock_path = get_sock_path().unwrap(); + let mut path = PathBuf::new(); + if !sock_path.starts_with('/') { + path.push(env::current_dir().unwrap()); + } + path.push(&sock_path); + + let path = path.to_str().unwrap().to_string(); + info!("Connecting to hsmserver at {}", path); + tokio::net::UnixStream::connect(path) + })) + .await + .context("could not connect to the socket file")?; - Ok(HsmClient::new(channel)) + Ok(HsmClient::new(channel)) + }) } -pub async fn run() -> Result<(), Error> { +pub fn run() -> Result<(), Error> { let args: Vec = std::env::args().collect(); // Start the counter at 1000 so we can inject some message before @@ -164,8 +171,16 @@ pub async fn run() -> Result<(), Error> { info!("Starting hsmproxy"); + // Create a dedicated tokio runtime for gRPC operations + let runtime = Arc::new( + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .context("failed to create tokio runtime")?, + ); + let node = setup_node_stream()?; - let grpc = grpc_connect().await?; + let grpc = grpc_connect(&runtime)?; process_requests( NodeConnection { @@ -174,6 +189,6 @@ pub async fn run() -> Result<(), Error> { }, request_counter, grpc, + runtime, ) - .await } diff --git a/libs/gl-signerproxy/src/lib.rs b/libs/gl-signerproxy/src/lib.rs index 240a03ed1..a3af560a0 100644 --- a/libs/gl-signerproxy/src/lib.rs +++ b/libs/gl-signerproxy/src/lib.rs @@ -12,7 +12,7 @@ impl Proxy { Proxy {} } - pub async fn run(&self) -> Result<()> { - hsmproxy::run().await + pub fn run(&self) -> Result<()> { + hsmproxy::run() } } diff --git a/libs/gl-signerproxy/src/wire.rs b/libs/gl-signerproxy/src/wire.rs index 9aeb1755c..28e4ba459 100644 --- a/libs/gl-signerproxy/src/wire.rs +++ b/libs/gl-signerproxy/src/wire.rs @@ -2,10 +2,10 @@ use crate::passfd::SyncFdPassingExt; use anyhow::{anyhow, Error, Result}; use byteorder::{BigEndian, ByteOrder}; use log::trace; +use std::io::{Read, Write}; use std::os::unix::io::{AsRawFd, RawFd}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::UnixStream; -use tokio::sync::Mutex; +use std::os::unix::net::UnixStream; +use std::sync::Mutex; /// A simple implementation of the inter-daemon protocol wrapping a /// UnixStream. Easy to read from and write to. @@ -65,11 +65,18 @@ impl DaemonConnection { } } - pub async fn read(&self) -> Result { - let mut sock = self.conn.lock().await; - let msglen = sock.read_u32().await?; - let mut buf = vec![0 as u8; msglen as usize]; - sock.read_exact(&mut buf).await?; + pub fn read(&self) -> Result { + let mut sock = self.conn.lock().unwrap(); + + // Read 4-byte length prefix in big-endian + let mut len_buf = [0u8; 4]; + sock.read_exact(&mut len_buf)?; + let msglen = BigEndian::read_u32(&len_buf); + + // Read the message body + let mut buf = vec![0u8; msglen as usize]; + sock.read_exact(&mut buf)?; + if buf.len() < msglen as usize { return Err(anyhow!("Short read from client")); } @@ -77,6 +84,7 @@ impl DaemonConnection { let typ = BigEndian::read_u16(&buf); let mut fds = vec![]; + // Receive any file descriptors associated with this message type let numfds = DaemonConnection::count_fds(typ); for _ in 0..numfds { fds.push(sock.as_raw_fd().recv_fd()?); @@ -89,17 +97,24 @@ impl DaemonConnection { } } - pub async fn write(&self, msg: Message) -> Result<(), Error> { + pub fn write(&self, msg: Message) -> Result<(), Error> { trace!( "Sending message {} ({} bytes, {} FDs)", msg.typ, msg.body.len(), msg.fds.len() ); - let mut client = self.conn.lock().await; - client.write_u32(msg.body.len() as u32).await?; - client.write_all(&msg.body).await?; + let mut client = self.conn.lock().unwrap(); + + // Write 4-byte length prefix in big-endian + let mut len_buf = [0u8; 4]; + BigEndian::write_u32(&mut len_buf, msg.body.len() as u32); + client.write_all(&len_buf)?; + + // Write the message body + client.write_all(&msg.body)?; + // Send any file descriptors for fd in msg.fds { client.as_raw_fd().send_fd(fd)?; } diff --git a/libs/gl-testing/gltesting/network.py b/libs/gl-testing/gltesting/network.py index 8e36b50f1..1bc9fb31c 100644 --- a/libs/gl-testing/gltesting/network.py +++ b/libs/gl-testing/gltesting/network.py @@ -27,6 +27,8 @@ def get_node(self, options=None, *args, **kwargs): options = {} options["allow-deprecated-apis"] = True options["developer"] = None + # Disable cln-grpc plugin to avoid port conflicts with GL nodes + options["disable-plugin"] = "cln-grpc" return NodeFactory.get_node(self, options=options, *args, **kwargs)