Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 157 additions & 5 deletions lightning-custom-message/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,25 @@ macro_rules! composite_custom_message_handler {
}

fn peer_connected(&self, their_node_id: $crate::bitcoin::secp256k1::PublicKey, msg: &$crate::lightning::ln::msgs::Init, inbound: bool) -> Result<(), ()> {
let mut result = Ok(());
// Per the `CustomMessageHandler::peer_connected` contract, `peer_disconnected`
// will not be called by `PeerManager` if we return `Err`. To avoid leaking
// per-peer state in sub-handlers that already returned `Ok` when a later one
// errors, record each sub-handler's result and roll back the successful ones
// ourselves before propagating the failure.
$(
if let Err(e) = self.$field.peer_connected(their_node_id, msg, inbound) {
result = Err(e);
}
let $field = self.$field.peer_connected(their_node_id, msg, inbound);
)*
result
let any_err = false $( || $field.is_err() )*;
if any_err {
$(
if $field.is_ok() {
self.$field.peer_disconnected(their_node_id);
}
)*
Err(())
} else {
Ok(())
}
}

fn provided_node_features(&self) -> $crate::lightning::types::features::NodeFeatures {
Expand Down Expand Up @@ -376,3 +388,143 @@ macro_rules! composite_custom_message_handler {
}
}
}

#[cfg(test)]
mod tests {
use bitcoin::secp256k1::PublicKey;
use core::sync::atomic::{AtomicUsize, Ordering};
use lightning::io;
use lightning::ln::msgs::{DecodeError, Init, LightningError};
use lightning::ln::peer_handler::CustomMessageHandler;
use lightning::ln::wire::{CustomMessageReader, Type};
use lightning::types::features::{InitFeatures, NodeFeatures};
use lightning::util::ser::{LengthLimitedRead, Writeable, Writer};

#[derive(Debug)]
pub struct Foo;
impl Type for Foo {
fn type_id(&self) -> u16 {
32768
}
}
impl Writeable for Foo {
fn write<W: Writer>(&self, _: &mut W) -> Result<(), io::Error> {
Ok(())
}
}

pub struct CountingHandler {
pub connect_count: AtomicUsize,
}
impl CustomMessageReader for CountingHandler {
type CustomMessage = Foo;
fn read<R: LengthLimitedRead>(
&self, _t: u16, _b: &mut R,
) -> Result<Option<Foo>, DecodeError> {
Ok(None)
}
}
impl CustomMessageHandler for CountingHandler {
fn handle_custom_message(&self, _msg: Foo, _: PublicKey) -> Result<(), LightningError> {
Ok(())
}
fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Foo)> {
vec![]
}
fn peer_disconnected(&self, _: PublicKey) {
self.connect_count.fetch_sub(1, Ordering::SeqCst);
}
fn peer_connected(&self, _: PublicKey, _: &Init, _: bool) -> Result<(), ()> {
self.connect_count.fetch_add(1, Ordering::SeqCst);
Ok(())
}
fn provided_node_features(&self) -> NodeFeatures {
NodeFeatures::empty()
}
fn provided_init_features(&self, _: PublicKey) -> InitFeatures {
InitFeatures::empty()
}
}

#[derive(Debug)]
pub struct Bar;
impl Type for Bar {
fn type_id(&self) -> u16 {
32769
}
}
impl Writeable for Bar {
fn write<W: Writer>(&self, _: &mut W) -> Result<(), io::Error> {
Ok(())
}
}

pub struct ErroringHandler;
impl CustomMessageReader for ErroringHandler {
type CustomMessage = Bar;
fn read<R: LengthLimitedRead>(
&self, _t: u16, _b: &mut R,
) -> Result<Option<Bar>, DecodeError> {
Ok(None)
}
}
impl CustomMessageHandler for ErroringHandler {
fn handle_custom_message(&self, _msg: Bar, _: PublicKey) -> Result<(), LightningError> {
Ok(())
}
fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Bar)> {
vec![]
}
fn peer_disconnected(&self, _: PublicKey) {
debug_assert!(false);
}
fn peer_connected(&self, _: PublicKey, _: &Init, _: bool) -> Result<(), ()> {
Err(())
}
fn provided_node_features(&self) -> NodeFeatures {
NodeFeatures::empty()
}
fn provided_init_features(&self, _: PublicKey) -> InitFeatures {
InitFeatures::empty()
}
}

composite_custom_message_handler!(
pub struct CompositeHandler {
counting: CountingHandler,
erroring: ErroringHandler,
}

pub enum CompositeMessage {
Foo(32768),
Bar(32769),
}
);

#[test]
fn peer_connected_failure_does_not_leak_subhandler_state() {
let composite = CompositeHandler {
counting: CountingHandler { connect_count: AtomicUsize::new(0) },
erroring: ErroringHandler,
};
let pk_bytes = [
0x02, 0x79, 0xBE, 0x66, 0x7E, 0xF9, 0xDC, 0xBB, 0xAC, 0x55, 0xA0, 0x62, 0x95, 0xCE,
0x87, 0x0B, 0x07, 0x02, 0x9B, 0xFC, 0xDB, 0x2D, 0xCE, 0x28, 0xD9, 0x59, 0xF2, 0x81,
0x5B, 0x16, 0xF8, 0x17, 0x98,
];
let pk = PublicKey::from_slice(&pk_bytes).unwrap();
let init =
Init { features: InitFeatures::empty(), networks: None, remote_network_address: None };

let result = composite.peer_connected(pk, &init, true);
assert!(result.is_err(), "Composite must propagate the inner Err");

let leaked = composite.counting.connect_count.load(Ordering::SeqCst);
assert_eq!(
leaked, 0,
"CountingHandler tracked {leaked} connected peer(s) after the composite \
returned Err; this state will never be cleaned up because per the trait \
contract peer_disconnected won't be called when peer_connected returns Err.",
);
}
}
Loading