diff --git a/lightning-block-sync/Cargo.toml b/lightning-block-sync/Cargo.toml index 97f199963ac..d8d71da3fae 100644 --- a/lightning-block-sync/Cargo.toml +++ b/lightning-block-sync/Cargo.toml @@ -16,15 +16,16 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [features] -rest-client = [ "serde_json", "chunked_transfer" ] -rpc-client = [ "serde_json", "chunked_transfer" ] +rest-client = [ "serde_json", "dep:bitreq" ] +rpc-client = [ "serde_json", "dep:bitreq" ] +tokio = [ "dep:tokio", "bitreq?/async" ] [dependencies] bitcoin = "0.32.2" lightning = { version = "0.3.0", path = "../lightning" } tokio = { version = "1.35", features = [ "io-util", "net", "time", "rt" ], optional = true } serde_json = { version = "1.0", optional = true } -chunked_transfer = { version = "1.4", optional = true } +bitreq = { version = "0.3", default-features = false, features = ["std"], optional = true } [dev-dependencies] lightning = { version = "0.3.0", path = "../lightning", features = ["_test_utils"] } diff --git a/lightning-block-sync/src/http.rs b/lightning-block-sync/src/http.rs index 0fb82b4acde..8668e86b2fc 100644 --- a/lightning-block-sync/src/http.rs +++ b/lightning-block-sync/src/http.rs @@ -1,27 +1,16 @@ //! Simple HTTP implementation which supports both async and traditional execution environments //! with minimal dependencies. This is used as the basis for REST and RPC clients. -use chunked_transfer; use serde_json; +#[cfg(feature = "tokio")] +use bitreq::RequestExt; + use std::convert::TryFrom; use std::fmt; -#[cfg(not(feature = "tokio"))] -use std::io::Write; use std::net::{SocketAddr, ToSocketAddrs}; use std::time::Duration; -#[cfg(feature = "tokio")] -use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}; -#[cfg(feature = "tokio")] -use tokio::net::TcpStream; - -#[cfg(not(feature = "tokio"))] -use std::io::BufRead; -use std::io::Read; -#[cfg(not(feature = "tokio"))] -use std::net::TcpStream; - /// Timeout for operations on TCP streams. const TCP_STREAM_TIMEOUT: Duration = Duration::from_secs(5); @@ -92,10 +81,15 @@ impl<'a> std::net::ToSocketAddrs for &'a HttpEndpoint { } } +/// Maximum number of cached connections in the connection pool. +#[cfg(feature = "tokio")] +const MAX_CONNECTIONS: usize = 10; + /// Client for making HTTP requests. pub(crate) struct HttpClient { address: SocketAddr, - stream: TcpStream, + #[cfg(feature = "tokio")] + client: bitreq::Client, } impl HttpClient { @@ -110,17 +104,18 @@ impl HttpClient { }, Some(address) => address, }; + + // Verify reachability by attempting a connection. let stream = std::net::TcpStream::connect_timeout(&address, TCP_STREAM_TIMEOUT)?; stream.set_read_timeout(Some(TCP_STREAM_TIMEOUT))?; stream.set_write_timeout(Some(TCP_STREAM_TIMEOUT))?; + drop(stream); - #[cfg(feature = "tokio")] - let stream = { - stream.set_nonblocking(true)?; - TcpStream::from_std(stream)? - }; - - Ok(Self { address, stream }) + Ok(Self { + address, + #[cfg(feature = "tokio")] + client: bitreq::Client::new(MAX_CONNECTIONS), + }) } /// Sends a `GET` request for a resource identified by `uri` at the `host`. @@ -131,14 +126,19 @@ impl HttpClient { where F: TryFrom, Error = std::io::Error>, { - let request = format!( - "GET {} HTTP/1.1\r\n\ - Host: {}\r\n\ - Connection: keep-alive\r\n\ - \r\n", - uri, host - ); - let response_body = self.send_request_with_retry(&request).await?; + let address = self.address; + let response_body = self + .send_request_with_retry(|| { + let url = format!("http://{}{}", address, uri); + bitreq::get(url) + .with_header("Host", host) + .with_header("Connection", "keep-alive") + .with_timeout(TCP_STREAM_RESPONSE_TIMEOUT.as_secs()) + .with_max_headers_size(Some(MAX_HTTP_MESSAGE_HEADER_SIZE)) + .with_max_status_line_length(Some(MAX_HTTP_MESSAGE_HEADER_SIZE)) + .with_max_body_size(Some(MAX_HTTP_MESSAGE_BODY_SIZE)) + }) + .await?; F::try_from(response_body) } @@ -154,30 +154,32 @@ impl HttpClient { where F: TryFrom, Error = std::io::Error>, { + let address = self.address; let content = content.to_string(); - let request = format!( - "POST {} HTTP/1.1\r\n\ - Host: {}\r\n\ - Authorization: {}\r\n\ - Connection: keep-alive\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}", - uri, - host, - auth, - content.len(), - content - ); - let response_body = self.send_request_with_retry(&request).await?; + let response_body = self + .send_request_with_retry(|| { + let url = format!("http://{}{}", address, uri); + bitreq::post(url) + .with_header("Host", host) + .with_header("Authorization", auth) + .with_header("Connection", "keep-alive") + .with_header("Content-Type", "application/json") + .with_timeout(TCP_STREAM_RESPONSE_TIMEOUT.as_secs()) + .with_max_headers_size(Some(MAX_HTTP_MESSAGE_HEADER_SIZE)) + .with_max_status_line_length(Some(MAX_HTTP_MESSAGE_HEADER_SIZE)) + .with_max_body_size(Some(MAX_HTTP_MESSAGE_BODY_SIZE)) + .with_body(content.clone()) + }) + .await?; F::try_from(response_body) } /// Sends an HTTP request message and reads the response, returning its body. Attempts to /// reconnect and retry if the connection has been closed. - async fn send_request_with_retry(&mut self, request: &str) -> std::io::Result> { - match self.send_request(request).await { + async fn send_request_with_retry( + &mut self, build_request: impl Fn() -> bitreq::Request, + ) -> std::io::Result> { + match self.send_request(build_request()).await { Ok(bytes) => Ok(bytes), Err(_) => { // Reconnect and retry on fail. This can happen if the connection was closed after @@ -191,204 +193,53 @@ impl HttpClient { #[cfg(not(feature = "tokio"))] std::thread::sleep(Duration::from_millis(100)); *self = Self::connect(self.address)?; - self.send_request(request).await + self.send_request(build_request()).await }, } } /// Sends an HTTP request message and reads the response, returning its body. - async fn send_request(&mut self, request: &str) -> std::io::Result> { - self.write_request(request).await?; - self.read_response().await - } - - /// Writes an HTTP request message. - async fn write_request(&mut self, request: &str) -> std::io::Result<()> { - #[cfg(feature = "tokio")] - { - self.stream.write_all(request.as_bytes()).await?; - self.stream.flush().await - } - #[cfg(not(feature = "tokio"))] - { - self.stream.write_all(request.as_bytes())?; - self.stream.flush() - } - } - - /// Reads an HTTP response message. - async fn read_response(&mut self) -> std::io::Result> { + async fn send_request(&self, request: bitreq::Request) -> std::io::Result> { #[cfg(feature = "tokio")] - let stream = self.stream.split().0; + let response = request.send_async_with_client(&self.client).await.map_err(bitreq_to_io_error)?; #[cfg(not(feature = "tokio"))] - let stream = std::io::Read::by_ref(&mut self.stream); + let response = request.send().map_err(bitreq_to_io_error)?; - let limited_stream = stream.take(MAX_HTTP_MESSAGE_HEADER_SIZE as u64); - - #[cfg(feature = "tokio")] - let mut reader = tokio::io::BufReader::new(limited_stream); - #[cfg(not(feature = "tokio"))] - let mut reader = std::io::BufReader::new(limited_stream); - - macro_rules! read_line { - () => { - read_line!(0) - }; - ($retry_count: expr) => {{ - let mut line = String::new(); - let mut timeout_count: u64 = 0; - let bytes_read = loop { - #[cfg(feature = "tokio")] - let read_res = reader.read_line(&mut line).await; - #[cfg(not(feature = "tokio"))] - let read_res = reader.read_line(&mut line); - match read_res { - Ok(bytes_read) => break bytes_read, - Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { - timeout_count += 1; - if timeout_count > $retry_count { - return Err(e); - } else { - continue; - } - }, - Err(e) => return Err(e), - } - }; - - match bytes_read { - 0 => None, - _ => { - // Remove trailing CRLF - if line.ends_with('\n') { - line.pop(); - if line.ends_with('\r') { - line.pop(); - } - } - Some(line) - }, - } - }}; - } - - // Read and parse status line - // Note that we allow retrying a few times to reach TCP_STREAM_RESPONSE_TIMEOUT. - let status_line = - read_line!(TCP_STREAM_RESPONSE_TIMEOUT.as_secs() / TCP_STREAM_TIMEOUT.as_secs()) - .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no status line"))?; - let status = HttpStatus::parse(&status_line)?; - - // Read and parse relevant headers - let mut message_length = HttpMessageLength::Empty; - loop { - let line = read_line!() - .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no headers"))?; - if line.is_empty() { - break; - } - - let header = HttpHeader::parse(&line)?; - if header.has_name("Content-Length") { - let length = header - .value - .parse() - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - if let HttpMessageLength::Empty = message_length { - message_length = HttpMessageLength::ContentLength(length); - } - continue; - } - - if header.has_name("Transfer-Encoding") { - message_length = HttpMessageLength::TransferEncoding(header.value.into()); - continue; - } - } + let status_code = response.status_code; + let body = response.into_bytes(); - // Read message body - let read_limit = MAX_HTTP_MESSAGE_BODY_SIZE - reader.buffer().len(); - reader.get_mut().set_limit(read_limit as u64); - let contents = match message_length { - HttpMessageLength::Empty => Vec::new(), - HttpMessageLength::ContentLength(length) => { - if length == 0 || length > MAX_HTTP_MESSAGE_BODY_SIZE { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("invalid response length: {} bytes", length), - )); - } else { - let mut content = vec![0; length]; - #[cfg(feature = "tokio")] - reader.read_exact(&mut content[..]).await?; - #[cfg(not(feature = "tokio"))] - reader.read_exact(&mut content[..])?; - content - } - }, - HttpMessageLength::TransferEncoding(coding) => { - if !coding.eq_ignore_ascii_case("chunked") { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "unsupported transfer coding", - )); - } else { - let mut content = Vec::new(); - #[cfg(feature = "tokio")] - { - // Since chunked_transfer doesn't have an async interface, only use it to - // determine the size of each chunk to read. - // - // TODO: Replace with an async interface when available. - // https://github.com/frewsxcv/rust-chunked-transfer/issues/7 - loop { - // Read the chunk header which contains the chunk size. - let mut chunk_header = String::new(); - reader.read_line(&mut chunk_header).await?; - if chunk_header == "0\r\n" { - // Read the terminator chunk since the decoder consumes the CRLF - // immediately when this chunk is encountered. - reader.read_line(&mut chunk_header).await?; - } - - // Decode the chunk header to obtain the chunk size. - let mut buffer = Vec::new(); - let mut decoder = - chunked_transfer::Decoder::new(chunk_header.as_bytes()); - decoder.read_to_end(&mut buffer)?; - - // Read the chunk body. - let chunk_size = match decoder.remaining_chunks_size() { - None => break, - Some(chunk_size) => chunk_size, - }; - let chunk_offset = content.len(); - content.resize(chunk_offset + chunk_size + "\r\n".len(), 0); - reader.read_exact(&mut content[chunk_offset..]).await?; - content.resize(chunk_offset + chunk_size, 0); - } - content - } - #[cfg(not(feature = "tokio"))] - { - let mut decoder = chunked_transfer::Decoder::new(reader); - decoder.read_to_end(&mut content)?; - content - } - } - }, - }; - - if !status.is_ok() { - // TODO: Handle 3xx redirection responses. - let error = HttpError { status_code: status.code.to_string(), contents }; + if !(200..300).contains(&status_code) { + let error = HttpError { status_code: status_code.to_string(), contents: body }; return Err(std::io::Error::new(std::io::ErrorKind::Other, error)); } - Ok(contents) + Ok(body) } } +/// Converts a bitreq error to an std::io::Error. +fn bitreq_to_io_error(err: bitreq::Error) -> std::io::Error { + use std::io::ErrorKind; + + let kind = match &err { + bitreq::Error::IoError(e) => e.kind(), + bitreq::Error::HeadersOverflow + | bitreq::Error::StatusLineOverflow + | bitreq::Error::BodyOverflow + | bitreq::Error::MalformedChunkLength + | bitreq::Error::MalformedChunkEnd + | bitreq::Error::MalformedContentLength + | bitreq::Error::InvalidUtf8InResponse + | bitreq::Error::InvalidUtf8InBody(_) => ErrorKind::InvalidData, + bitreq::Error::AddressNotFound | bitreq::Error::HttpsFeatureNotEnabled => { + ErrorKind::InvalidInput + }, + _ => ErrorKind::Other, + }; + + std::io::Error::new(kind, err) +} + /// HTTP error consisting of a status code and body contents. #[derive(Debug)] pub(crate) struct HttpError { @@ -405,94 +256,6 @@ impl fmt::Display for HttpError { } } -/// HTTP response status code as defined by [RFC 7231]. -/// -/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-6 -struct HttpStatus<'a> { - code: &'a str, -} - -impl<'a> HttpStatus<'a> { - /// Parses an HTTP status line as defined by [RFC 7230]. - /// - /// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.1.2 - fn parse(line: &'a String) -> std::io::Result> { - let mut tokens = line.splitn(3, ' '); - - let http_version = tokens - .next() - .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no HTTP-Version"))?; - if !http_version.eq_ignore_ascii_case("HTTP/1.1") - && !http_version.eq_ignore_ascii_case("HTTP/1.0") - { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid HTTP-Version", - )); - } - - let code = tokens - .next() - .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no Status-Code"))?; - if code.len() != 3 || !code.chars().all(|c| c.is_ascii_digit()) { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid Status-Code", - )); - } - - let _reason = tokens - .next() - .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no Reason-Phrase"))?; - - Ok(Self { code }) - } - - /// Returns whether the status is successful (i.e., 2xx status class). - fn is_ok(&self) -> bool { - self.code.starts_with('2') - } -} - -/// HTTP response header as defined by [RFC 7231]. -/// -/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-7 -struct HttpHeader<'a> { - name: &'a str, - value: &'a str, -} - -impl<'a> HttpHeader<'a> { - /// Parses an HTTP header field as defined by [RFC 7230]. - /// - /// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.2 - fn parse(line: &'a String) -> std::io::Result> { - let mut tokens = line.splitn(2, ':'); - let name = tokens - .next() - .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no header name"))?; - let value = tokens - .next() - .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no header value"))? - .trim_start(); - Ok(Self { name, value }) - } - - /// Returns whether the header field has the given name. - fn has_name(&self, name: &str) -> bool { - self.name.eq_ignore_ascii_case(name) - } -} - -/// HTTP message body length as defined by [RFC 7230]. -/// -/// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.3.3 -enum HttpMessageLength { - Empty, - ContentLength(usize), - TransferEncoding(String), -} - /// An HTTP response body in binary format. pub struct BinaryResponse(pub Vec); @@ -572,16 +335,26 @@ mod endpoint_tests { #[cfg(test)] pub(crate) mod client_tests { use super::*; - use std::io::BufRead; - use std::io::Write; + use std::io::{BufRead, Read, Write}; /// Server for handling HTTP client requests with a stock response. pub struct HttpServer { address: std::net::SocketAddr, - handler: std::thread::JoinHandle<()>, + handler: Option>, shutdown: std::sync::Arc, } + impl Drop for HttpServer { + fn drop(&mut self) { + self.shutdown.store(true, std::sync::atomic::Ordering::SeqCst); + // Make a connection to unblock the listener's accept() call + let _ = std::net::TcpStream::connect(self.address); + if let Some(handler) = self.handler.take() { + let _ = handler.join(); + } + } + } + /// Body of HTTP response messages. pub enum MessageBody { Empty, @@ -589,10 +362,27 @@ pub(crate) mod client_tests { ChunkedContent(T), } + /// Encodes a body using chunked transfer encoding. + fn encode_chunked(body: &str, chunk_size: usize) -> String { + let mut out = String::new(); + for chunk in body.as_bytes().chunks(chunk_size) { + out.push_str(&format!("{:X}\r\n", chunk.len())); + out.push_str(std::str::from_utf8(chunk).unwrap()); + out.push_str("\r\n"); + } + out.push_str("0\r\n\r\n"); + out + } + impl HttpServer { fn responding_with_body(status: &str, body: MessageBody) -> Self { let response = match body { - MessageBody::Empty => format!("{}\r\n\r\n", status), + MessageBody::Empty => format!( + "{}\r\n\ + Content-Length: 0\r\n\ + \r\n", + status + ), MessageBody::Content(body) => { let body = body.to_string(); format!( @@ -606,19 +396,14 @@ pub(crate) mod client_tests { ) }, MessageBody::ChunkedContent(body) => { - let mut chuncked_body = Vec::new(); - { - use chunked_transfer::Encoder; - let mut encoder = Encoder::with_chunks_size(&mut chuncked_body, 8); - encoder.write_all(body.to_string().as_bytes()).unwrap(); - } + let body = body.to_string(); + let chunked_body = encode_chunked(&body, 8); format!( "{}\r\n\ Transfer-Encoding: chunked\r\n\ \r\n\ {}", - status, - String::from_utf8(chuncked_body).unwrap() + status, chunked_body ) }, }; @@ -646,38 +431,75 @@ pub(crate) mod client_tests { let shutdown_signaled = std::sync::Arc::clone(&shutdown); let handler = std::thread::spawn(move || { for stream in listener.incoming() { - let mut stream = stream.unwrap(); + if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) { + return; + } + + let stream = stream.unwrap(); stream.set_write_timeout(Some(TCP_STREAM_TIMEOUT)).unwrap(); + stream.set_read_timeout(Some(TCP_STREAM_TIMEOUT)).unwrap(); - let lines_read = std::io::BufReader::new(&stream) - .lines() - .take_while(|line| !line.as_ref().unwrap().is_empty()) - .count(); - if lines_read == 0 { - continue; - } + let mut reader = std::io::BufReader::new(stream); - for chunk in response.as_bytes().chunks(16) { + // Handle multiple requests on the same connection (keep-alive) + loop { if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) { return; - } else { - if let Err(_) = stream.write(chunk) { + } + + // Read request headers + let mut lines_read = 0; + let mut content_length: usize = 0; + loop { + let mut line = String::new(); + match reader.read_line(&mut line) { + Ok(0) => break, // eof + Ok(_) => { + if line == "\r\n" || line == "\n" { + break; // end of headers + } + // Parse content_length for POST body handling + if let Some(value) = line.strip_prefix("Content-Length:") { + content_length = value.trim().parse().unwrap_or(0); + } + lines_read += 1; + }, + Err(_) => break, // Read error or timeout + } + } + + if lines_read == 0 { + break; // No request received, connection closed + } + + // Consume request body if present (needed for POST keep-alive) + if content_length > 0 { + let mut body = vec![0u8; content_length]; + if reader.read_exact(&mut body).is_err() { break; } - if let Err(_) = stream.flush() { + } + + // Send response + let stream = reader.get_mut(); + let mut write_error = false; + for chunk in response.as_bytes().chunks(16) { + if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) { + return; + } + if stream.write(chunk).is_err() || stream.flush().is_err() { + write_error = true; break; } } + if write_error { + break; + } } } }); - Self { address, handler, shutdown } - } - - fn shutdown(self) { - self.shutdown.store(true, std::sync::atomic::Ordering::SeqCst); - self.handler.join().unwrap(); + Self { address, handler: Some(handler), shutdown } } pub fn endpoint(&self) -> HttpEndpoint { @@ -735,93 +557,6 @@ pub(crate) mod client_tests { } } - #[tokio::test] - async fn read_empty_message() { - let server = HttpServer::responding_with("".to_string()); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof); - assert_eq!(e.get_ref().unwrap().to_string(), "no status line"); - }, - Ok(_) => panic!("Expected error"), - } - } - - #[tokio::test] - async fn read_incomplete_message() { - let server = HttpServer::responding_with("HTTP/1.1 200 OK".to_string()); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof); - assert_eq!(e.get_ref().unwrap().to_string(), "no headers"); - }, - Ok(_) => panic!("Expected error"), - } - } - - #[tokio::test] - async fn read_too_large_message_headers() { - let response = format!( - "HTTP/1.1 302 Found\r\n\ - Location: {}\r\n\ - \r\n", - "Z".repeat(MAX_HTTP_MESSAGE_HEADER_SIZE) - ); - let server = HttpServer::responding_with(response); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof); - assert_eq!(e.get_ref().unwrap().to_string(), "no headers"); - }, - Ok(_) => panic!("Expected error"), - } - } - - #[tokio::test] - async fn read_too_large_message_body() { - let body = "Z".repeat(MAX_HTTP_MESSAGE_BODY_SIZE + 1); - let server = HttpServer::responding_with_ok::(MessageBody::Content(body)); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); - assert_eq!( - e.get_ref().unwrap().to_string(), - "invalid response length: 8032001 bytes" - ); - }, - Ok(_) => panic!("Expected error"), - } - server.shutdown(); - } - - #[tokio::test] - async fn read_message_with_unsupported_transfer_coding() { - let response = String::from( - "HTTP/1.1 200 OK\r\n\ - Transfer-Encoding: gzip\r\n\ - \r\n\ - foobar", - ); - let server = HttpServer::responding_with(response); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput); - assert_eq!(e.get_ref().unwrap().to_string(), "unsupported transfer coding"); - }, - Ok(_) => panic!("Expected error"), - } - } - #[tokio::test] async fn read_error() { let server = HttpServer::responding_with_server_error("foo");