diff --git a/lightning/src/blinded_path/payment.rs b/lightning/src/blinded_path/payment.rs index 03b676adc92..e97f93146f9 100644 --- a/lightning/src/blinded_path/payment.rs +++ b/lightning/src/blinded_path/payment.rs @@ -161,8 +161,35 @@ impl BlindedPaymentPath { ) } - fn new_inner( - intermediate_nodes: &[PaymentForwardNode], payee_node_id: PublicKey, + /// Create a blinded path for a trampoline payment, to be forwarded along `intermediate_nodes`. + #[cfg(any(test, feature = "_test_utils"))] + pub(crate) fn new_for_trampoline< + ES: EntropySource, + T: secp256k1::Signing + secp256k1::Verification, + >( + intermediate_nodes: &[ForwardNode], payee_node_id: PublicKey, + local_node_receive_key: ReceiveAuthKey, payee_tlvs: ReceiveTlvs, htlc_maximum_msat: u64, + min_final_cltv_expiry_delta: u16, entropy_source: ES, secp_ctx: &Secp256k1, + ) -> Result { + Self::new_inner( + intermediate_nodes, + payee_node_id, + local_node_receive_key, + &[], + payee_tlvs, + htlc_maximum_msat, + min_final_cltv_expiry_delta, + entropy_source, + secp_ctx, + ) + } + + fn new_inner< + F: ForwardTlvsInfo, + ES: EntropySource, + T: secp256k1::Signing + secp256k1::Verification, + >( + intermediate_nodes: &[ForwardNode], payee_node_id: PublicKey, local_node_receive_key: ReceiveAuthKey, dummy_tlvs: &[DummyTlvs], payee_tlvs: ReceiveTlvs, htlc_maximum_msat: u64, min_final_cltv_expiry_delta: u16, entropy_source: ES, secp_ctx: &Secp256k1, @@ -323,18 +350,36 @@ impl BlindedPaymentPath { } } -/// An intermediate node, its outbound channel, and relay parameters. +/// Common interface for forward TLV types used in blinded payment paths. +/// +/// Both [`ForwardTlvs`] (channel-based forwarding) and [`TrampolineForwardTlvs`] (trampoline +/// node-based forwarding) implement this trait, allowing blinded path construction to be generic +/// over the forwarding mechanism. +pub trait ForwardTlvsInfo: Writeable + Clone { + /// The payment relay parameters for this hop. + fn payment_relay(&self) -> &PaymentRelay; + /// The payment constraints for this hop. + fn payment_constraints(&self) -> &PaymentConstraints; + /// The features for this hop. + fn features(&self) -> &BlindedHopFeatures; +} + +/// An intermediate node, its forwarding parameters, and its [`ForwardTlvsInfo`] for use in a +/// [`BlindedPaymentPath`]. #[derive(Clone, Debug)] -pub struct PaymentForwardNode { +pub struct ForwardNode { /// The TLVs for this node's [`BlindedHop`], where the fee parameters contained within are also /// used for [`BlindedPayInfo`] construction. - pub tlvs: ForwardTlvs, + pub tlvs: F, /// This node's pubkey. pub node_id: PublicKey, /// The maximum value, in msat, that may be accepted by this node. pub htlc_maximum_msat: u64, } +/// An intermediate node for a regular (non-trampoline) [`BlindedPaymentPath`]. +pub type PaymentForwardNode = ForwardNode; + /// Data to construct a [`BlindedHop`] for forwarding a payment. #[derive(Clone, Debug)] pub struct ForwardTlvs { @@ -354,6 +399,18 @@ pub struct ForwardTlvs { pub next_blinding_override: Option, } +impl ForwardTlvsInfo for ForwardTlvs { + fn payment_relay(&self) -> &PaymentRelay { + &self.payment_relay + } + fn payment_constraints(&self) -> &PaymentConstraints { + &self.payment_constraints + } + fn features(&self) -> &BlindedHopFeatures { + &self.features + } +} + /// Data to construct a [`BlindedHop`] for forwarding a Trampoline payment. #[derive(Clone, Debug)] pub struct TrampolineForwardTlvs { @@ -373,6 +430,18 @@ pub struct TrampolineForwardTlvs { pub next_blinding_override: Option, } +impl ForwardTlvsInfo for TrampolineForwardTlvs { + fn payment_relay(&self) -> &PaymentRelay { + &self.payment_relay + } + fn payment_constraints(&self) -> &PaymentConstraints { + &self.payment_constraints + } + fn features(&self) -> &BlindedHopFeatures { + &self.features + } +} + /// TLVs carried by a dummy hop within a blinded payment path. /// /// Dummy hops do not correspond to real forwarding decisions, but are processed @@ -440,8 +509,8 @@ pub(crate) enum BlindedTrampolineTlvs { // Used to include forward and receive TLVs in the same iterator for encoding. #[derive(Clone)] -enum BlindedPaymentTlvsRef<'a> { - Forward(&'a ForwardTlvs), +enum BlindedPaymentTlvsRef<'a, F: ForwardTlvsInfo = ForwardTlvs> { + Forward(&'a F), Dummy(&'a DummyTlvs), Receive(&'a ReceiveTlvs), } @@ -619,7 +688,7 @@ impl Writeable for ReceiveTlvs { } } -impl<'a> Writeable for BlindedPaymentTlvsRef<'a> { +impl<'a, F: ForwardTlvsInfo> Writeable for BlindedPaymentTlvsRef<'a, F> { fn write(&self, w: &mut W) -> Result<(), io::Error> { match self { Self::Forward(tlvs) => tlvs.write(w)?, @@ -723,8 +792,8 @@ impl Readable for BlindedTrampolineTlvs { pub(crate) const PAYMENT_PADDING_ROUND_OFF: usize = 30; /// Construct blinded payment hops for the given `intermediate_nodes` and payee info. -pub(super) fn blinded_hops( - secp_ctx: &Secp256k1, intermediate_nodes: &[PaymentForwardNode], payee_node_id: PublicKey, +pub(super) fn blinded_hops( + secp_ctx: &Secp256k1, intermediate_nodes: &[ForwardNode], payee_node_id: PublicKey, dummy_tlvs: &[DummyTlvs], payee_tlvs: ReceiveTlvs, session_priv: &SecretKey, local_node_receive_key: ReceiveAuthKey, ) -> Vec { @@ -823,15 +892,15 @@ where Ok((curr_base_fee, curr_prop_mil)) } -pub(super) fn compute_payinfo( - intermediate_nodes: &[PaymentForwardNode], dummy_tlvs: &[DummyTlvs], payee_tlvs: &ReceiveTlvs, +pub(super) fn compute_payinfo( + intermediate_nodes: &[ForwardNode], dummy_tlvs: &[DummyTlvs], payee_tlvs: &ReceiveTlvs, payee_htlc_maximum_msat: u64, min_final_cltv_expiry_delta: u16, ) -> Result { let routing_fees = intermediate_nodes .iter() .map(|node| RoutingFees { - base_msat: node.tlvs.payment_relay.fee_base_msat, - proportional_millionths: node.tlvs.payment_relay.fee_proportional_millionths, + base_msat: node.tlvs.payment_relay().fee_base_msat, + proportional_millionths: node.tlvs.payment_relay().fee_proportional_millionths, }) .chain(dummy_tlvs.iter().map(|tlvs| RoutingFees { base_msat: tlvs.payment_relay.fee_base_msat, @@ -847,24 +916,24 @@ pub(super) fn compute_payinfo( for node in intermediate_nodes.iter() { // In the future, we'll want to take the intersection of all supported features for the // `BlindedPayInfo`, but there are no features in that context right now. - if node.tlvs.features.requires_unknown_bits_from(&BlindedHopFeatures::empty()) { + if node.tlvs.features().requires_unknown_bits_from(&BlindedHopFeatures::empty()) { return Err(()); } cltv_expiry_delta = - cltv_expiry_delta.checked_add(node.tlvs.payment_relay.cltv_expiry_delta).ok_or(())?; + cltv_expiry_delta.checked_add(node.tlvs.payment_relay().cltv_expiry_delta).ok_or(())?; // The min htlc for an intermediate node is that node's min minus the fees charged by all of the // following hops for forwarding that min, since that fee amount will automatically be included // in the amount that this node receives and contribute towards reaching its min. htlc_minimum_msat = amt_to_forward_msat( - core::cmp::max(node.tlvs.payment_constraints.htlc_minimum_msat, htlc_minimum_msat), - &node.tlvs.payment_relay, + core::cmp::max(node.tlvs.payment_constraints().htlc_minimum_msat, htlc_minimum_msat), + node.tlvs.payment_relay(), ) .unwrap_or(1); // If underflow occurs, we definitely reached this node's min htlc_maximum_msat = amt_to_forward_msat( core::cmp::min(node.htlc_maximum_msat, htlc_maximum_msat), - &node.tlvs.payment_relay, + node.tlvs.payment_relay(), ) .ok_or(())?; // If underflow occurs, we cannot send to this hop without exceeding their max } @@ -1038,8 +1107,14 @@ mod tests { payment_constraints: PaymentConstraints { max_cltv_expiry: 0, htlc_minimum_msat: 1 }, payment_context: PaymentContext::Bolt12Refund(Bolt12RefundContext {}), }; - let blinded_payinfo = - super::compute_payinfo(&[], &[], &recv_tlvs, 4242, TEST_FINAL_CLTV as u16).unwrap(); + let blinded_payinfo = super::compute_payinfo::( + &[], + &[], + &recv_tlvs, + 4242, + TEST_FINAL_CLTV as u16, + ) + .unwrap(); assert_eq!(blinded_payinfo.fee_base_msat, 0); assert_eq!(blinded_payinfo.fee_proportional_millionths, 0); assert_eq!(blinded_payinfo.cltv_expiry_delta, TEST_FINAL_CLTV as u16); diff --git a/lightning/src/events/mod.rs b/lightning/src/events/mod.rs index 011b7f595bc..df359497438 100644 --- a/lightning/src/events/mod.rs +++ b/lightning/src/events/mod.rs @@ -174,6 +174,9 @@ pub enum PaymentPurpose { /// Because this is a spontaneous payment, the payer generated their own preimage rather than us /// (the payee) providing a preimage. SpontaneousPayment(PaymentPreimage), + /// HTLCs terminating at our node are intended for forwarding onwards as a trampoline + /// forward. + Trampoline {}, } impl PaymentPurpose { @@ -184,6 +187,7 @@ impl PaymentPurpose { PaymentPurpose::Bolt12OfferPayment { payment_preimage, .. } => *payment_preimage, PaymentPurpose::Bolt12RefundPayment { payment_preimage, .. } => *payment_preimage, PaymentPurpose::SpontaneousPayment(preimage) => Some(*preimage), + PaymentPurpose::Trampoline {} => None, } } @@ -193,6 +197,7 @@ impl PaymentPurpose { PaymentPurpose::Bolt12OfferPayment { .. } => false, PaymentPurpose::Bolt12RefundPayment { .. } => false, PaymentPurpose::SpontaneousPayment(..) => true, + PaymentPurpose::Trampoline {} => false, } } @@ -240,8 +245,9 @@ impl_writeable_tlv_based_enum_legacy!(PaymentPurpose, (2, payment_secret, required), (4, payment_context, required), }, + (3, Trampoline) => {}, ; - (2, SpontaneousPayment) + (2, SpontaneousPayment), ); /// Information about an HTLC that is part of a payment that can be claimed. @@ -1932,6 +1938,11 @@ impl Writeable for Event { PaymentPurpose::SpontaneousPayment(preimage) => { payment_preimage = Some(*preimage); }, + PaymentPurpose::Trampoline {} => { + payment_secret = None; + payment_preimage = None; + payment_context = None; + }, } let skimmed_fee_opt = if counterparty_skimmed_fee_msat == 0 { None diff --git a/lightning/src/ln/blinded_payment_tests.rs b/lightning/src/ln/blinded_payment_tests.rs index e148ce2c474..c9cb4211ac4 100644 --- a/lightning/src/ln/blinded_payment_tests.rs +++ b/lightning/src/ln/blinded_payment_tests.rs @@ -8,13 +8,14 @@ // licenses. use crate::blinded_path::payment::{ - BlindedPaymentPath, Bolt12RefundContext, DummyTlvs, ForwardTlvs, PaymentConstraints, - PaymentContext, PaymentForwardNode, PaymentRelay, ReceiveTlvs, PAYMENT_PADDING_ROUND_OFF, + BlindedPaymentPath, Bolt12RefundContext, DummyTlvs, ForwardNode, ForwardTlvs, + PaymentConstraints, PaymentContext, PaymentForwardNode, PaymentRelay, ReceiveTlvs, + PAYMENT_PADDING_ROUND_OFF, }; use crate::blinded_path::utils::is_padded; use crate::blinded_path::{self, BlindedHop}; use crate::events::{Event, HTLCHandlingFailureType, PaymentFailureReason}; -use crate::ln::channelmanager::{self, HTLCFailureMsg, PaymentId}; +use crate::ln::channelmanager::{self, HTLCFailureMsg, PaymentId, MPP_TIMEOUT_TICKS}; use crate::ln::functional_test_utils::*; use crate::ln::inbound_payment::ExpandedKey; use crate::ln::msgs::{ @@ -34,7 +35,7 @@ use crate::routing::router::{ use crate::sign::{NodeSigner, PeerStorageKey, ReceiveAuthKey, Recipient}; use crate::types::features::{BlindedHopFeatures, ChannelFeatures, NodeFeatures}; use crate::types::payment::{PaymentHash, PaymentSecret}; -use crate::util::config::{HTLCInterceptionFlags, UserConfig}; +use crate::util::config::{ChannelConfig, HTLCInterceptionFlags, UserConfig}; use crate::util::ser::{WithoutLength, Writeable}; use crate::util::test_utils::{self, bytes_from_hex, pubkey_from_hex, secret_from_hex}; use bitcoin::hex::DisplayHex; @@ -2420,50 +2421,6 @@ fn test_trampoline_blinded_receive() { do_test_trampoline_relay(true, TrampolineTestCase::OuterCLTVLessThanTrampoline); } -/// Creates a blinded tail where Carol receives via a blinded path. -fn create_blinded_tail( - secp_ctx: &Secp256k1, override_random_bytes: [u8; 32], carol_node_id: PublicKey, - carol_auth_key: ReceiveAuthKey, trampoline_cltv_expiry_delta: u32, - excess_final_cltv_delta: u32, final_value_msat: u64, payment_secret: PaymentSecret, -) -> BlindedTail { - let outer_session_priv = SecretKey::from_slice(&override_random_bytes).unwrap(); - let trampoline_session_priv = onion_utils::compute_trampoline_session_priv(&outer_session_priv); - - let carol_blinding_point = PublicKey::from_secret_key(&secp_ctx, &trampoline_session_priv); - let carol_blinded_hops = { - let payee_tlvs = ReceiveTlvs { - payment_secret, - payment_constraints: PaymentConstraints { - max_cltv_expiry: u32::max_value(), - htlc_minimum_msat: final_value_msat, - }, - payment_context: PaymentContext::Bolt12Refund(Bolt12RefundContext {}), - } - .encode(); - - let path = [((carol_node_id, Some(carol_auth_key)), WithoutLength(&payee_tlvs))]; - - blinded_path::utils::construct_blinded_hops( - &secp_ctx, - path.into_iter(), - &trampoline_session_priv, - ) - }; - - BlindedTail { - trampoline_hops: vec![TrampolineHop { - pubkey: carol_node_id, - node_features: Features::empty(), - fee_msat: final_value_msat, - cltv_expiry_delta: trampoline_cltv_expiry_delta + excess_final_cltv_delta, - }], - hops: carol_blinded_hops, - blinding_point: carol_blinding_point, - excess_final_cltv_expiry_delta: excess_final_cltv_delta, - final_value_msat, - } -} - // Creates a replacement onion that is used to produce scenarios that we don't support, specifically // payloads that send to unblinded receives and invalid payloads. fn replacement_onion( @@ -2631,15 +2588,23 @@ fn do_test_trampoline_relay(blinded: bool, test_case: TrampolineTestCase) { // Create a blinded tail where Carol is receiving. In our unblinded test cases, we'll // override this anyway (with a tail sending to an unblinded receive, which LDK doesn't // allow). - blinded_tail: Some(create_blinded_tail( + blinded_tail: Some(create_trampoline_forward_blinded_tail( &secp_ctx, - override_random_bytes, + &nodes[2].keys_manager, + &[], carol_node_id, nodes[2].keys_manager.get_receive_auth_key(), + ReceiveTlvs { + payment_secret, + payment_constraints: PaymentConstraints { + max_cltv_expiry: u32::max_value(), + htlc_minimum_msat: original_amt_msat, + }, + payment_context: PaymentContext::Bolt12Refund(Bolt12RefundContext {}), + }, original_trampoline_cltv, excess_final_cltv, original_amt_msat, - payment_secret, )), }], route_params: None, @@ -2752,122 +2717,289 @@ fn do_test_trampoline_relay(blinded: bool, test_case: TrampolineTestCase) { } } -#[test] -#[rustfmt::skip] -fn test_trampoline_forward_rejection() { - const TOTAL_NODE_COUNT: usize = 3; - - let chanmon_cfgs = create_chanmon_cfgs(TOTAL_NODE_COUNT); - let node_cfgs = create_node_cfgs(TOTAL_NODE_COUNT, &chanmon_cfgs); - let node_chanmgrs = create_node_chanmgrs(TOTAL_NODE_COUNT, &node_cfgs, &vec![None; TOTAL_NODE_COUNT]); - let mut nodes = create_network(TOTAL_NODE_COUNT, &node_cfgs, &node_chanmgrs); - - let (_, _, chan_id_alice_bob, _) = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1_000_000, 0); - let (_, _, chan_id_bob_carol, _) = create_announced_chan_between_nodes_with_value(&nodes, 1, 2, 1_000_000, 0); - - for i in 0..TOTAL_NODE_COUNT { // connect all nodes' blocks - connect_blocks(&nodes[i], (TOTAL_NODE_COUNT as u32) * CHAN_CONFIRM_DEPTH + 1 - nodes[i].best_block_info().1); - } - - let alice_node_id = nodes[0].node().get_our_node_id(); - let bob_node_id = nodes[1].node().get_our_node_id(); - let carol_node_id = nodes[2].node().get_our_node_id(); +/// Sets up channels and sends a trampoline MPP payment across two paths. +/// +/// Topology: +/// Alice (0) --> Bob (1) --> Carol (2, trampoline node) +/// Alice (0) --> Barry (3) --> Carol (2, trampoline node) +/// +/// Carol's inner trampoline onion is a forward to an unknown next node. We don't need the +/// next hop as a real node since forwarding isn't implemented yet -- we just need the onion to +/// contain a valid forward payload. +/// +/// Returns (payment_hash, per_path_amount, ev_to_bob, ev_to_barry). +fn send_trampoline_mpp_payment<'a, 'b, 'c>( + nodes: &'a Vec>, +) -> (PaymentHash, u64, MessageSendEvent, MessageSendEvent) { + let secp_ctx = Secp256k1::new(); - let alice_bob_scid = nodes[0].node().list_channels().iter().find(|c| c.channel_id == chan_id_alice_bob).unwrap().short_channel_id.unwrap(); - let bob_carol_scid = nodes[1].node().list_channels().iter().find(|c| c.channel_id == chan_id_bob_carol).unwrap().short_channel_id.unwrap(); + let alice_bob_chan = + create_announced_chan_between_nodes_with_value(nodes, 0, 1, 1_000_000, 0).2; + let bob_carol_chan = + create_announced_chan_between_nodes_with_value(nodes, 1, 2, 1_000_000, 0).2; + let alice_barry_chan = + create_announced_chan_between_nodes_with_value(nodes, 0, 3, 1_000_000, 0).2; + let barry_carol_chan = + create_announced_chan_between_nodes_with_value(nodes, 3, 2, 1_000_000, 0).2; + + let per_path_amt = 500_000; + let total_amt = per_path_amt * 2; + let (_, payment_hash, payment_secret) = + get_payment_preimage_hash(&nodes[2], Some(total_amt), None); + + let bob_node_id = nodes[1].node.get_our_node_id(); + let carol_node_id = nodes[2].node.get_our_node_id(); + let barry_node_id = nodes[3].node.get_our_node_id(); + + let alice_bob_scid = get_scid_from_channel_id(&nodes[0], alice_bob_chan); + let bob_carol_scid = get_scid_from_channel_id(&nodes[1], bob_carol_chan); + let alice_barry_scid = get_scid_from_channel_id(&nodes[0], alice_barry_chan); + let barry_carol_scid = get_scid_from_channel_id(&nodes[3], barry_carol_chan); + + let trampoline_cltv = 42; + let excess_final_cltv = 70; - let amt_msat = 1000; - let (payment_preimage, payment_hash, _) = get_payment_preimage_hash(&nodes[2], Some(amt_msat), None); + // Not we don't actually have an outgoing channel for Carol, we just use our default fee + // policy. + let carol_relay = ChannelConfig::default(); - let route = Route { - paths: vec![Path { - hops: vec![ - // Bob - RouteHop { - pubkey: bob_node_id, - node_features: NodeFeatures::empty(), - short_channel_id: alice_bob_scid, - channel_features: ChannelFeatures::empty(), - fee_msat: 1000, - cltv_expiry_delta: 48, - maybe_announced_channel: false, + let next_trampoline = PublicKey::from_slice(&[2; 33]).unwrap(); + let fwd_tail = || { + let intermediate_nodes = [ForwardNode { + tlvs: blinded_path::payment::TrampolineForwardTlvs { + next_trampoline, + payment_constraints: PaymentConstraints { + max_cltv_expiry: u32::max_value(), + htlc_minimum_msat: 1, }, + features: BlindedHopFeatures::empty(), + payment_relay: PaymentRelay { + cltv_expiry_delta: carol_relay.cltv_expiry_delta, + fee_proportional_millionths: carol_relay.forwarding_fee_proportional_millionths, + fee_base_msat: carol_relay.forwarding_fee_base_msat, + }, + next_blinding_override: None, + }, + node_id: carol_node_id, + htlc_maximum_msat: u64::max_value(), + }]; + let payee_tlvs = ReceiveTlvs { + payment_secret: PaymentSecret([0; 32]), + payment_constraints: PaymentConstraints { + max_cltv_expiry: u32::max_value(), + htlc_minimum_msat: 1, + }, + payment_context: PaymentContext::Bolt12Refund(Bolt12RefundContext {}), + }; + create_trampoline_forward_blinded_tail( + &secp_ctx, + &nodes[2].keys_manager, + &intermediate_nodes, + next_trampoline, + ReceiveAuthKey([0; 32]), + payee_tlvs, + trampoline_cltv, + excess_final_cltv, + per_path_amt, + ) + }; - // Carol - RouteHop { - pubkey: carol_node_id, - node_features: NodeFeatures::empty(), - short_channel_id: bob_carol_scid, - channel_features: ChannelFeatures::empty(), - fee_msat: 0, - cltv_expiry_delta: 24 + 24 + 39, - maybe_announced_channel: false, - } - ], - blinded_tail: Some(BlindedTail { - trampoline_hops: vec![ - // Carol - TrampolineHop { - pubkey: carol_node_id, - node_features: Features::empty(), - fee_msat: amt_msat, - cltv_expiry_delta: 24, - }, + let hop = |pubkey, short_channel_id, fee_msat, cltv_expiry_delta| RouteHop { + pubkey, + node_features: NodeFeatures::empty(), + short_channel_id, + channel_features: ChannelFeatures::empty(), + fee_msat, + cltv_expiry_delta, + maybe_announced_channel: true, + }; + let build_path_hops = |first_hop_node_id, first_hop_scid, second_hop_scid| { + vec![ + hop(first_hop_node_id, first_hop_scid, 1000, 48), + hop(carol_node_id, second_hop_scid, 0, trampoline_cltv + excess_final_cltv), + ] + }; - // Alice (unreachable) - TrampolineHop { - pubkey: alice_node_id, - node_features: Features::empty(), - fee_msat: amt_msat, - cltv_expiry_delta: 24 + 39, - }, - ], - hops: vec![BlindedHop{ - // Fake public key - blinded_node_id: alice_node_id, - encrypted_payload: vec![], - }], - blinding_point: alice_node_id, - excess_final_cltv_expiry_delta: 39, - final_value_msat: amt_msat, - }) - }], + let placeholder_tail = fwd_tail(); + let mut route = Route { + paths: vec![ + Path { + hops: build_path_hops(bob_node_id, alice_bob_scid, bob_carol_scid), + blinded_tail: Some(placeholder_tail.clone()), + }, + Path { + hops: build_path_hops(barry_node_id, alice_barry_scid, barry_carol_scid), + blinded_tail: Some(placeholder_tail), + }, + ], route_params: None, }; - nodes[0].node.send_payment_with_route(route.clone(), payment_hash, RecipientOnionFields::spontaneous_empty(amt_msat), PaymentId(payment_hash.0)).unwrap(); + let cur_height = nodes[0].best_block_info().1 + 1; + let payment_id = PaymentId(payment_hash.0); + let onion = RecipientOnionFields::secret_only(payment_secret, total_amt); + let session_privs = nodes[0] + .node + .test_add_new_pending_payment(payment_hash, onion.clone(), payment_id, &route) + .unwrap(); - check_added_monitors(&nodes[0], 1); + route.paths[0].blinded_tail = Some(fwd_tail()); + route.paths[1].blinded_tail = Some(fwd_tail()); + + for (i, path) in route.paths.iter().enumerate() { + nodes[0] + .node + .test_send_payment_along_path( + path, + &payment_hash, + onion.clone(), + cur_height, + payment_id, + &None, + session_privs[i], + ) + .unwrap(); + check_added_monitors(&nodes[0], 1); + } let mut events = nodes[0].node.get_and_clear_pending_msg_events(); - assert_eq!(events.len(), 1); - let first_message_event = remove_first_msg_event_to_node(&nodes[1].node.get_our_node_id(), &mut events); + assert_eq!(events.len(), 2); + let ev_bob = remove_first_msg_event_to_node(&bob_node_id, &mut events); + let ev_barry = remove_first_msg_event_to_node(&barry_node_id, &mut events); + (payment_hash, per_path_amt, ev_bob, ev_barry) +} - let route: &[&Node] = &[&nodes[1], &nodes[2]]; - let args = PassAlongPathArgs::new(&nodes[0], route, amt_msat, payment_hash, first_message_event) - .with_payment_preimage(payment_preimage) - .without_claimable_event() - .expect_failure(HTLCHandlingFailureType::Receive { payment_hash }); +/// How an incomplete trampoline MPP times out (if at all). +enum TrampolineTimeout { + /// Tick timers until MPP timeout fires. + Ticks, + /// Mine blocks until on-chain CLTV timeout fires. + OnChain, +} + +fn do_trampoline_mpp_test(timeout: Option) { + let chanmon_cfgs = create_chanmon_cfgs(4); + let node_cfgs = create_node_cfgs(4, &chanmon_cfgs); + let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &vec![None; 4]); + let nodes = create_network(4, &node_cfgs, &node_chanmgrs); + + let (payment_hash, per_path_amt, ev_bob, ev_barry) = send_trampoline_mpp_payment(&nodes); + let send_both = timeout.is_none(); + + let bob_path: &[&Node] = &[&nodes[1], &nodes[2]]; + let barry_path: &[&Node] = &[&nodes[3], &nodes[2]]; + + // Pass first part along Alice -> Bob -> Carol. + let args = PassAlongPathArgs::new(&nodes[0], bob_path, per_path_amt, payment_hash, ev_bob) + .without_claimable_event(); do_pass_along_path(args); - { - let unblinded_node_updates = get_htlc_update_msgs(&nodes[2], &nodes[1].node.get_our_node_id()); - nodes[1].node.handle_update_fail_htlc( - nodes[2].node.get_our_node_id(), &unblinded_node_updates.update_fail_htlcs[0] - ); - do_commitment_signed_dance(&nodes[1], &nodes[2], &unblinded_node_updates.commitment_signed, true, false); + // Either complete the MPP (triggering trampoline rejection) or trigger a timeout. + let expected_reason = match timeout { + None => { + let args = + PassAlongPathArgs::new(&nodes[0], barry_path, per_path_amt, payment_hash, ev_barry) + .without_clearing_recipient_events(); + do_pass_along_path(args); + LocalHTLCFailureReason::TemporaryTrampolineFailure + }, + Some(TrampolineTimeout::Ticks) => { + for _ in 0..MPP_TIMEOUT_TICKS { + nodes[2].node.timer_tick_occurred(); + } + LocalHTLCFailureReason::MPPTimeout + }, + Some(TrampolineTimeout::OnChain) => { + let current_height = nodes[2].best_block_info().1; + connect_blocks(&nodes[2], 200 - current_height); + LocalHTLCFailureReason::CLTVExpiryTooSoon + }, + }; + + // Carol rejects the trampoline forward (either after MPP completion or timeout). + let events = nodes[2].node.get_and_clear_pending_events(); + assert_eq!(events.len(), 1); + match events[0] { + crate::events::Event::HTLCHandlingFailed { + ref failure_type, ref failure_reason, .. + } => { + assert_eq!(failure_type, &HTLCHandlingFailureType::TrampolineForward {}); + match failure_reason { + Some(crate::events::HTLCHandlingFailureReason::Local { reason }) => { + assert_eq!(*reason, expected_reason) + }, + Some(_) | None => panic!("expected failure_reason for failed trampoline"), + } + }, + _ => panic!("Unexpected destination"), } - { - let unblinded_node_updates = get_htlc_update_msgs(&nodes[1], &nodes[0].node.get_our_node_id()); - nodes[0].node.handle_update_fail_htlc( - nodes[1].node.get_our_node_id(), &unblinded_node_updates.update_fail_htlcs[0] - ); - do_commitment_signed_dance(&nodes[0], &nodes[1], &unblinded_node_updates.commitment_signed, false, false); + expect_and_process_pending_htlcs(&nodes[2], false); + assert!(nodes[2].node.get_and_clear_pending_events().is_empty()); + + // Propagate failures back through each forwarded path to Alice. + let both: [&[&Node]; 2] = [bob_path, barry_path]; + let one: [&[&Node]; 1] = [bob_path]; + let forwarded: &[&[&Node]] = if send_both { &both } else { &one }; + let carol_id = nodes[2].node.get_our_node_id(); + check_added_monitors(&nodes[2], forwarded.len()); + let mut carol_msgs = nodes[2].node.get_and_clear_pending_msg_events(); + assert_eq!(carol_msgs.len(), forwarded.len()); + for path in forwarded { + let hop = path[0]; + let hop_id = hop.node.get_our_node_id(); + let ev = remove_first_msg_event_to_node(&hop_id, &mut carol_msgs); + let updates = match ev { + MessageSendEvent::UpdateHTLCs { updates, .. } => updates, + _ => panic!("Expected UpdateHTLCs"), + }; + hop.node.handle_update_fail_htlc(carol_id, &updates.update_fail_htlcs[0]); + do_commitment_signed_dance(hop, &nodes[2], &updates.commitment_signed, true, false); + + let fwd = get_htlc_update_msgs(hop, &nodes[0].node.get_our_node_id()); + nodes[0].node.handle_update_fail_htlc(hop_id, &fwd.update_fail_htlcs[0]); + do_commitment_signed_dance(&nodes[0], hop, &fwd.commitment_signed, false, false); } - { - // Expect UnknownNextPeer error while we are unable to route forwarding Trampoline payments. - let payment_failed_conditions = PaymentFailedConditions::new() - .expected_htlc_error_data(LocalHTLCFailureReason::UnknownNextPeer, &[0; 0]); - expect_payment_failed_conditions(&nodes[0], payment_hash, false, payment_failed_conditions); + + // Check Alice's failure events. + let events = nodes[0].node.get_and_clear_pending_events(); + assert_eq!(events.len(), if send_both { 3 } else { 1 }); + for ev in &events[..forwarded.len()] { + match ev { + Event::PaymentPathFailed { payment_hash: h, payment_failed_permanently, .. } => { + assert_eq!(*h, payment_hash); + assert!(!payment_failed_permanently); + }, + _ => panic!("Expected PaymentPathFailed, got {:?}", ev), + } + } + if send_both { + match &events[2] { + Event::PaymentFailed { payment_hash: h, reason, .. } => { + assert_eq!(*h, Some(payment_hash)); + assert_eq!(*reason, Some(PaymentFailureReason::RetriesExhausted)); + }, + _ => panic!("Expected PaymentFailed, got {:?}", events[2]), + } + + // Verify no spurious timeout fires after the MPP set was dispatched. + for _ in 0..(MPP_TIMEOUT_TICKS * 3) { + nodes[2].node.timer_tick_occurred(); + } + assert!(nodes[2].node.get_and_clear_pending_events().is_empty()); } } + +#[test] +fn test_trampoline_mpp_receive_success() { + do_trampoline_mpp_test(None); +} + +#[test] +fn test_trampoline_mpp_timeout_partial() { + do_trampoline_mpp_test(Some(TrampolineTimeout::Ticks)); +} + +#[test] +fn test_trampoline_mpp_onchain_timeout() { + do_trampoline_mpp_test(Some(TrampolineTimeout::OnChain)); +} diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index f8b5ef32fc3..f010112871e 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -89,9 +89,9 @@ use crate::ln::outbound_payment; #[cfg(any(test, feature = "_externalize_tests"))] use crate::ln::outbound_payment::PaymentSendFailure; use crate::ln::outbound_payment::{ - Bolt11PaymentError, Bolt12PaymentError, OutboundPayments, PendingOutboundPayment, - ProbeSendFailure, RecipientCustomTlvs, RecipientOnionFields, Retry, RetryableInvoiceRequest, - RetryableSendFailure, SendAlongPathArgs, StaleExpiration, + Bolt11PaymentError, Bolt12PaymentError, NextTrampolineHopInfo, OutboundPayments, + PendingOutboundPayment, ProbeSendFailure, RecipientCustomTlvs, RecipientOnionFields, Retry, + RetryableInvoiceRequest, RetryableSendFailure, SendAlongPathArgs, StaleExpiration, }; use crate::ln::types::ChannelId; use crate::offers::async_receive_offer_cache::AsyncReceiveOfferCache; @@ -231,11 +231,12 @@ pub enum PendingHTLCRouting { }, /// An HTLC which should be forwarded on to another Trampoline node. TrampolineForward { - /// The onion shared secret we build with the sender (or the preceding Trampoline node) used - /// to decrypt the onion. + /// The onion shared secret we build with the node that forwarded us this trampoline + /// forward (either the original sender, or a preceding Trampoline node), used to decrypt + /// the inner trampoline onion. /// /// This is later used to encrypt failure packets in the event that the HTLC is failed. - incoming_shared_secret: [u8; 32], + trampoline_shared_secret: [u8; 32], /// The onion which should be included in the forwarded HTLC, telling the next hop what to /// do with the HTLC. onion_packet: msgs::TrampolineOnionPacket, @@ -245,6 +246,12 @@ pub enum PendingHTLCRouting { blinded: Option, /// The absolute CLTV of the inbound HTLC incoming_cltv_expiry: u32, + /// MPP data for accumulating incoming HTLCs before dispatching an outbound payment. + incoming_multipath_data: Option, + /// The amount that the next trampoline is expecting to receive. + next_trampoline_amt_msat: u64, + /// The CLTV expiry height that the next trampoline is expecting to receive. + next_trampoline_cltv_expiry: u32, }, /// The onion indicates that this is a payment for an invoice (supposedly) generated by us. /// @@ -474,6 +481,9 @@ impl PendingAddHTLCInfo { PendingHTLCRouting::Receive { trampoline_shared_secret, .. } => { trampoline_shared_secret }, + PendingHTLCRouting::TrampolineForward { trampoline_shared_secret, .. } => { + Some(trampoline_shared_secret) + }, _ => None, }; @@ -514,7 +524,7 @@ pub enum BlindedFailure { } #[derive(PartialEq, Eq)] -enum OnionPayload { +pub(crate) enum OnionPayload { /// Indicates this incoming onion payload is for the purpose of paying an invoice. Invoice { /// This is only here for backwards-compatibility in serialization, in the future it can be @@ -523,11 +533,13 @@ enum OnionPayload { }, /// Contains the payer-provided preimage. Spontaneous(PaymentPreimage), + /// Indicates that the incoming onion payload is for a trampoline forward. + Trampoline { next_hop_info: NextTrampolineHopInfo, next_trampoline: PublicKey }, } /// HTLCs that are to us and can be failed/claimed by the user #[derive(PartialEq, Eq)] -struct ClaimableHTLC { +pub(crate) struct ClaimableHTLC { prev_hop: HTLCPreviousHopData, cltv_expiry: u32, /// The amount (in msats) of this MPP part @@ -544,6 +556,36 @@ struct ClaimableHTLC { counterparty_skimmed_fee_msat: Option, } +impl ClaimableHTLC { + pub(crate) fn new( + prev_hop: HTLCPreviousHopData, value: u64, sender_intended_value: u64, cltv_expiry: u32, + onion_payload: OnionPayload, counterparty_skimmed_fee_msat: Option, + ) -> Self { + ClaimableHTLC { + prev_hop, + cltv_expiry, + value, + sender_intended_value, + onion_payload, + timer_ticks: 0, + total_value_received: None, + counterparty_skimmed_fee_msat, + } + } + + // Increments timer ticks and returns a boolean indicating whether HTLC is timed out. + fn mpp_timer_tick(&mut self) -> bool { + self.timer_ticks += 1; + self.timer_ticks >= MPP_TIMEOUT_TICKS + } + + /// Returns a boolean indicating whether the HTLC has timed out on chain, accounting for a buffer + /// that gives us time to resolve it. + fn check_onchain_timeout(&self, height: u32, buffer: u32) -> bool { + height >= self.cltv_expiry - buffer + } +} + impl From<&ClaimableHTLC> for events::ClaimedHTLC { fn from(val: &ClaimableHTLC) -> Self { events::ClaimedHTLC { @@ -820,7 +862,6 @@ mod fuzzy_channelmanager { /// We might be forwarding an incoming payment that was received over MPP, and therefore /// need to store the vector of corresponding `HTLCPreviousHopData` values. previous_hop_data: Vec, - incoming_trampoline_shared_secret: [u8; 32], /// Track outbound payment details once the payment has been dispatched, will be `None` /// when waiting for incoming MPP to accumulate. outbound_payment: Option, @@ -857,6 +898,14 @@ mod fuzzy_channelmanager { }, } } + + pub fn previous_hop_data(&self) -> &[HTLCPreviousHopData] { + match self { + HTLCSource::PreviousHopData(prev_hop) => core::slice::from_ref(prev_hop), + HTLCSource::TrampolineForward { previous_hop_data, .. } => &previous_hop_data[..], + HTLCSource::OutboundRoute { .. } => &[], + } + } } /// Tracks the inbound corresponding to an outbound HTLC @@ -917,14 +966,9 @@ impl core::hash::Hash for HTLCSource { first_hop_htlc_msat.hash(hasher); bolt12_invoice.hash(hasher); }, - HTLCSource::TrampolineForward { - previous_hop_data, - incoming_trampoline_shared_secret, - outbound_payment, - } => { + HTLCSource::TrampolineForward { previous_hop_data, outbound_payment } => { 2u8.hash(hasher); previous_hop_data.hash(hasher); - incoming_trampoline_shared_secret.hash(hasher); if let Some(payment) = outbound_payment { payment.payment_id.hash(hasher); payment.path.hash(hasher); @@ -1231,6 +1275,25 @@ impl ClaimablePayment { .map(|htlc| (htlc.prev_hop.channel_id, htlc.prev_hop.user_channel_id)) .collect() } + + /// Returns the total counterparty skimmed fee across all HTLCs. + fn total_counterparty_skimmed_msat(&self) -> u64 { + self.htlcs.iter().map(|htlc| htlc.counterparty_skimmed_fee_msat.unwrap_or(0)).sum() + } +} + +/// Increments MPP timeout tick for all HTLCs and returns a boolean indicating whether the HTLC +/// set has hit its MPP timeout. Will return false if the set have reached the sender's intended +/// total, as the MPP has completed in this case. +fn check_mpp_timeout(payment: &mut ClaimablePayment) -> bool { + // This condition determining whether the MPP is complete here must match exactly the condition + // used in `process_pending_htlc_forwards`. + let total_intended_recvd_value = payment.htlcs.iter().map(|h| h.sender_intended_value).sum(); + let total_mpp_value = payment.onion_fields.total_mpp_amount_msat; + if total_mpp_value <= total_intended_recvd_value { + return false; + } + payment.htlcs.iter_mut().any(|htlc| htlc.mpp_timer_tick()) } /// Represent the channel funding transaction type. @@ -2810,6 +2873,10 @@ pub struct ChannelManager< /// [`ClaimablePayments`]' individual field docs for more info. claimable_payments: Mutex, + /// The sets of trampoline payments which are in the process of being accumulated on inbound + /// channel(s). + awaiting_trampoline_forwards: Mutex>, + /// The set of outbound SCID aliases across all our channels, including unconfirmed channels /// and some closed channels which reached a usable state prior to being closed. This is used /// only to avoid duplicates, and is not persisted explicitly to disk, but rebuilt from the @@ -3602,6 +3669,7 @@ impl< forward_htlcs: Mutex::new(new_hash_map()), decode_update_add_htlcs: Mutex::new(new_hash_map()), claimable_payments: Mutex::new(ClaimablePayments { claimable_payments: new_hash_map(), pending_claiming_payments: new_hash_map() }), + awaiting_trampoline_forwards: Mutex::new(new_hash_map()), pending_intercepted_htlcs: Mutex::new(new_hash_map()), short_to_chan_info: FairRwLock::new(new_hash_map()), @@ -5110,6 +5178,7 @@ impl< fn can_forward_htlc_should_intercept( &self, msg: &msgs::UpdateAddHTLC, prev_chan_public: bool, next_hop: &NextPacketDetails, ) -> Result { + let cur_height = self.best_block.read().unwrap().height + 1; let outgoing_scid = match next_hop.outgoing_connector { HopConnector::ShortChannelId(scid) => scid, HopConnector::Dummy => { @@ -5117,8 +5186,24 @@ impl< debug_assert!(false, "Dummy hop reached HTLC handling."); return Err(LocalHTLCFailureReason::InvalidOnionPayload); }, + // We can't make forwarding checks on trampoline forwards where we don't know the + // outgoing channel on receipt of the incoming htlc. Our trampoline logic will check + // our required delta and fee later on, so here we just check that the forwarding node + // did not "skim" off some of the sender's intended fee/cltv. HopConnector::Trampoline(_) => { - return Err(LocalHTLCFailureReason::InvalidTrampolineForward); + if msg.amount_msat < next_hop.outgoing_amt_msat { + return Err(LocalHTLCFailureReason::FeeInsufficient); + } + + check_incoming_htlc_cltv( + cur_height, + next_hop.outgoing_cltv_value, + msg.cltv_expiry, + 0, + )?; + + // TODO: add interception flag specifically for trampoline + return Ok(false); }, }; // TODO: We do the fake SCID namespace check a bunch of times here (and indirectly via @@ -5157,9 +5242,12 @@ impl< }, }; - let cur_height = self.best_block.read().unwrap().height + 1; - check_incoming_htlc_cltv(cur_height, next_hop.outgoing_cltv_value, msg.cltv_expiry)?; - + check_incoming_htlc_cltv( + cur_height, + next_hop.outgoing_cltv_value, + msg.cltv_expiry, + MIN_CLTV_EXPIRY_DELTA.into(), + )?; Ok(intercept) } @@ -5700,6 +5788,20 @@ impl< self.pending_outbound_payments.test_set_payment_metadata(payment_id, new_payment_metadata); } + #[cfg(test)] + pub(crate) fn test_handle_trampoline_htlc( + &self, claimable_htlc: ClaimableHTLC, onion_fields: RecipientOnionFields, + payment_hash: PaymentHash, next_hop_info: NextTrampolineHopInfo, next_node_id: PublicKey, + ) -> Result<(), (HTLCSource, onion_utils::HTLCFailReason)> { + self.handle_trampoline_htlc( + claimable_htlc, + onion_fields, + payment_hash, + next_hop_info, + next_node_id, + ) + } + /// Pays a [`Bolt11Invoice`] associated with the `payment_id`. See [`Self::send_payment`] for more info. /// /// # Payment Id @@ -8192,6 +8294,259 @@ impl< } } + // Checks whether an incoming htlc can be added to our [`claimable_payments`], and handles + // MPP accumulation. On successful add, returns Ok() with a boolean indicating whether all + // MPP parts have arrrived. Callers *MUST NOT* fail htlcs if Ok(..) is returned. + fn check_claimable_incoming_htlc( + &self, claimable_payment: &mut ClaimablePayment, claimable_htlc: ClaimableHTLC, + mut onion_fields: RecipientOnionFields, payment_hash: PaymentHash, + ) -> Result { + let onions_compatible = claimable_payment.onion_fields.check_merge(&mut onion_fields); + if onions_compatible.is_err() { + return Err(()); + } + let mut total_intended_recvd_value = claimable_htlc.sender_intended_value; + let mut earliest_expiry = claimable_htlc.cltv_expiry; + for htlc in claimable_payment.htlcs.iter() { + total_intended_recvd_value += htlc.sender_intended_value; + earliest_expiry = cmp::min(earliest_expiry, htlc.cltv_expiry); + if total_intended_recvd_value >= msgs::MAX_VALUE_MSAT { + break; + } + } + let total_mpp_value = claimable_payment.onion_fields.total_mpp_amount_msat; + // The condition determining whether an MPP is complete must + // match exactly the condition used in `timer_tick_occurred` + if total_intended_recvd_value >= msgs::MAX_VALUE_MSAT { + return Err(()); + } else if total_intended_recvd_value - claimable_htlc.sender_intended_value + >= total_mpp_value + { + log_trace!( + self.logger, + "Failing HTLC with payment_hash {} as payment is already claimable", + &payment_hash + ); + return Err(()); + } else if total_intended_recvd_value >= total_mpp_value { + claimable_payment.htlcs.push(claimable_htlc); + let amount_msat = claimable_payment.htlcs.iter().map(|htlc| htlc.value).sum(); + claimable_payment + .htlcs + .iter_mut() + .for_each(|htlc| htlc.total_value_received = Some(amount_msat)); + let counterparty_skimmed_fee_msat = claimable_payment + .htlcs + .iter() + .map(|htlc| htlc.counterparty_skimmed_fee_msat.unwrap_or(0)) + .sum(); + debug_assert!( + total_intended_recvd_value.saturating_sub(amount_msat) + <= counterparty_skimmed_fee_msat + ); + claimable_payment.htlcs.sort(); + Ok(true) + } else { + // Nothing to do - we haven't reached the total + // payment value yet, wait until we receive more + // MPP parts. + claimable_payment.htlcs.push(claimable_htlc); + Ok(false) + } + } + + // Handles the addition of a HTLC associated with a payment we're receiving. Err(bool) indicates + // whether we have failed after adding committing to the HTLC - callers should assert that this + // value is false. + fn handle_claimable_htlc( + &self, purpose: events::PaymentPurpose, claimable_htlc: ClaimableHTLC, + onion_fields: RecipientOnionFields, payment_hash: PaymentHash, receiver_node_id: PublicKey, + new_events: &mut VecDeque<(Event, Option)>, + ) -> Result<(), bool> { + let mut committed_to_claimable = false; + + let mut claimable_payments = self.claimable_payments.lock().unwrap(); + if claimable_payments.pending_claiming_payments.contains_key(&payment_hash) { + return Err(committed_to_claimable); + } + + let ref mut claimable_payment = claimable_payments + .claimable_payments + .entry(payment_hash) + // Note that if we insert here we MUST NOT fail_htlc!() + .or_insert_with(|| { + committed_to_claimable = true; + ClaimablePayment { + purpose: purpose.clone(), + htlcs: Vec::new(), + onion_fields: onion_fields.clone(), + } + }); + + let is_keysend = purpose.is_keysend(); + if purpose != claimable_payment.purpose { + let log_keysend = |keysend| if keysend { "keysend" } else { "non-keysend" }; + log_trace!(self.logger, "Failing new {} HTLC with payment_hash {} as we already had an existing {} HTLC with the same payment hash", log_keysend(is_keysend), &payment_hash, log_keysend(!is_keysend)); + return Err(committed_to_claimable); + } + + if self + .check_claimable_incoming_htlc( + claimable_payment, + claimable_htlc, + onion_fields, + payment_hash, + ) + .map_err(|_| committed_to_claimable)? + { + new_events.push_back(( + events::Event::PaymentClaimable { + receiver_node_id: Some(receiver_node_id), + payment_hash, + purpose, + amount_msat: claimable_payment.htlcs.iter().map(|htlc| htlc.value).sum(), + counterparty_skimmed_fee_msat: claimable_payment + .total_counterparty_skimmed_msat(), + receiving_channel_ids: claimable_payment.receiving_channel_ids(), + claim_deadline: Some( + claimable_payment.htlcs.iter().map(|h| h.cltv_expiry).min().unwrap() // TODO: don't unwrap + - HTLC_FAIL_BACK_BUFFER, + ), + onion_fields: Some(claimable_payment.onion_fields.clone()), + payment_id: Some( + claimable_payment.inbound_payment_id(&self.inbound_payment_id_secret), + ), + }, + None, + )); + } + Ok(()) + } + + // Handles the addition of a HTLC associated with a trampoline forward that we need to accumulate + // on the incoming link before forwarding onwards. If the HTLC is failed, it returns the source + // and error that should be used to fail the HTLC(s) back. + fn handle_trampoline_htlc( + &self, claimable_htlc: ClaimableHTLC, onion_fields: RecipientOnionFields, + payment_hash: PaymentHash, next_hop_info: NextTrampolineHopInfo, _next_node_id: PublicKey, + ) -> Result<(), (HTLCSource, HTLCFailReason)> { + let mut trampoline_payments = self.awaiting_trampoline_forwards.lock().unwrap(); + + let mut committed_to_claimable = false; + let claimable_payment = trampoline_payments.entry(payment_hash).or_insert_with(|| { + committed_to_claimable = true; + ClaimablePayment { + purpose: events::PaymentPurpose::Trampoline {}, + htlcs: Vec::new(), + onion_fields: onion_fields.clone(), + } + }); + + // If MPP hasn't fully arrived yet, return early (saving indentation below). + let prev_hop = claimable_htlc.prev_hop.clone(); + if !self + .check_claimable_incoming_htlc( + claimable_payment, + claimable_htlc, + onion_fields, + payment_hash, + ) + .map_err(|_| { + debug_assert!(!committed_to_claimable); + ( + // When we couldn't add a new HTLC, we just fail back our last received htlc, + // allowing others to wait for more MPP parts to arrive. If this was the first + // htlc we'll eventually clean up the awaiting_trampoline_forwards entry in + // our MPP timeout logic. + HTLCSource::TrampolineForward { + previous_hop_data: vec![prev_hop], + outbound_payment: None, + }, + HTLCFailReason::reason( + LocalHTLCFailureReason::InvalidTrampolineForward, + vec![], + ), + ) + })? { + return Ok(()); + } + + let incoming_amt_msat: u64 = claimable_payment.htlcs.iter().map(|h| h.value).sum(); + let incoming_cltv_expiry = + claimable_payment.htlcs.iter().map(|h| h.cltv_expiry).min().unwrap(); + + let (forwarding_fee_proportional_millionths, forwarding_fee_base_msat, cltv_delta) = { + let config = self.config.read().unwrap(); + ( + config.channel_config.forwarding_fee_proportional_millionths, + config.channel_config.forwarding_fee_base_msat, + config.channel_config.cltv_expiry_delta as u32, + ) + }; + + let proportional_fee = (forwarding_fee_proportional_millionths as u128 + * next_hop_info.amount_msat as u128 + / 1_000_000) as u64; + let our_forwarding_fee_msat = proportional_fee + forwarding_fee_base_msat as u64; + + let trampoline_source = || -> HTLCSource { + HTLCSource::TrampolineForward { + previous_hop_data: claimable_payment + .htlcs + .iter() + .map(|htlc| htlc.prev_hop.clone()) + .collect(), + outbound_payment: None, + } + }; + let trampoline_failure = || -> HTLCFailReason { + let mut err_data = Vec::with_capacity(10); + err_data.extend_from_slice(&forwarding_fee_base_msat.to_be_bytes()); + err_data.extend_from_slice(&forwarding_fee_proportional_millionths.to_be_bytes()); + err_data.extend_from_slice(&(cltv_delta as u16).to_be_bytes()); + HTLCFailReason::reason( + LocalHTLCFailureReason::TrampolineFeeOrExpiryInsufficient, + err_data, + ) + }; + + let _max_total_routing_fee_msat = match incoming_amt_msat + .checked_sub(our_forwarding_fee_msat + next_hop_info.amount_msat) + { + Some(amount) => amount, + None => { + return Err((trampoline_source(), trampoline_failure())); + }, + }; + + let _max_total_cltv_expiry_delta = + match incoming_cltv_expiry.checked_sub(next_hop_info.cltv_expiry_height + cltv_delta) { + Some(cltv_delta) => cltv_delta, + None => { + return Err((trampoline_source(), trampoline_failure())); + }, + }; + + log_debug!( + self.logger, + "Rejecting trampoline forward because we do not fully support forwarding yet.", + ); + + let source = trampoline_source(); + if trampoline_payments.remove(&payment_hash).is_none() { + log_error!( + &self.logger, + "Dispatched trampoline payment: {} was not present in awaiting inbound", + payment_hash + ); + } + + Err(( + source, + HTLCFailReason::reason(LocalHTLCFailureReason::TemporaryTrampolineFailure, vec![]), + )) + } + fn process_receive_htlcs( &self, pending_forwards: &mut Vec, new_events: &mut VecDeque<(Event, Option)>, @@ -8222,7 +8577,7 @@ impl< payment_data, payment_context, phantom_shared_secret, - mut onion_fields, + onion_fields, has_recipient_created_payment_secret, invoice_request_opt, trampoline_shared_secret, @@ -8290,57 +8645,117 @@ impl< None, ) }, + PendingHTLCRouting::TrampolineForward { + trampoline_shared_secret: incoming_trampoline_shared_secret, + onion_packet, + node_id: next_trampoline, + blinded, + incoming_cltv_expiry, + incoming_multipath_data, + next_trampoline_amt_msat, + next_trampoline_cltv_expiry, + } => { + // Trampoline forwards only *need* to have MPP data if they're + // multi-part. + let onion_fields = match incoming_multipath_data { + Some(ref final_mpp) => RecipientOnionFields::secret_only( + final_mpp.payment_secret, + final_mpp.total_msat, + ), + None => RecipientOnionFields::spontaneous_empty(outgoing_amt_msat), + }; + ( + incoming_cltv_expiry, + OnionPayload::Trampoline { + next_hop_info: NextTrampolineHopInfo { + onion_packet, + blinding_point: blinded.and_then(|b| { + b.next_blinding_override.or_else(|| { + let encrypted_tlvs_ss = self + .node_signer + .ecdh( + Recipient::Node, + &b.inbound_blinding_point, + None, + ) + .unwrap() + .secret_bytes(); + onion_utils::next_hop_pubkey( + &self.secp_ctx, + b.inbound_blinding_point, + &encrypted_tlvs_ss, + ) + .ok() + }) + }), + amount_msat: next_trampoline_amt_msat, + cltv_expiry_height: next_trampoline_cltv_expiry, + }, + next_trampoline, + }, + incoming_multipath_data, + None, + None, + onion_fields, + false, + None, + Some(incoming_trampoline_shared_secret), + ) + }, _ => { panic!("short_channel_id == 0 should imply any pending_forward entries are of type Receive"); }, }; - let claimable_htlc = ClaimableHTLC { + let htlc_value = incoming_amt_msat.unwrap_or(outgoing_amt_msat); + let HTLCPreviousHopData { + prev_outbound_scid_alias, + user_channel_id, + counterparty_node_id, + htlc_id, + incoming_packet_shared_secret, + .. + } = prev_hop; + + let claimable_htlc = ClaimableHTLC::new( prev_hop, // We differentiate the received value from the sender intended value // if possible so that we don't prematurely mark MPP payments complete // if routing nodes overpay - value: incoming_amt_msat.unwrap_or(outgoing_amt_msat), - sender_intended_value: outgoing_amt_msat, - timer_ticks: 0, - total_value_received: None, + htlc_value, + outgoing_amt_msat, cltv_expiry, onion_payload, - counterparty_skimmed_fee_msat: skimmed_fee_msat, - }; - - let mut committed_to_claimable = false; + skimmed_fee_msat, + ); - macro_rules! fail_htlc { - ($htlc: expr, $payment_hash: expr) => { - debug_assert!(!committed_to_claimable); + macro_rules! fail_receive_htlc { + ($committed_to_claimable: expr) => { + let htlc_source = HTLCSource::PreviousHopData(HTLCPreviousHopData { + prev_outbound_scid_alias, + user_channel_id, + counterparty_node_id, + channel_id: prev_channel_id, + outpoint: prev_funding_outpoint, + htlc_id, + incoming_packet_shared_secret, + phantom_shared_secret, + trampoline_shared_secret, + blinded_failure, + cltv_expiry: Some(cltv_expiry), + }); + debug_assert!(!$committed_to_claimable); let err_data = invalid_payment_err_data( - $htlc.value, + htlc_value, self.best_block.read().unwrap().height, ); - let counterparty_node_id = $htlc.prev_hop.counterparty_node_id; - let incoming_packet_shared_secret = - $htlc.prev_hop.incoming_packet_shared_secret; - let prev_outbound_scid_alias = $htlc.prev_hop.prev_outbound_scid_alias; failed_forwards.push(( - HTLCSource::PreviousHopData(HTLCPreviousHopData { - prev_outbound_scid_alias, - user_channel_id: $htlc.prev_hop.user_channel_id, - counterparty_node_id, - channel_id: prev_channel_id, - outpoint: prev_funding_outpoint, - htlc_id: $htlc.prev_hop.htlc_id, - incoming_packet_shared_secret, - phantom_shared_secret, - trampoline_shared_secret, - blinded_failure, - cltv_expiry: Some(cltv_expiry), - }), + htlc_source, payment_hash, HTLCFailReason::reason( LocalHTLCFailureReason::IncorrectPaymentDetails, err_data, ), - HTLCHandlingFailureType::Receive { payment_hash: $payment_hash }, + HTLCHandlingFailureType::Receive { payment_hash }, )); continue 'next_forwardable_htlc; }; @@ -8354,94 +8769,6 @@ impl< .expect("Failed to get node_id for phantom node recipient"); } - macro_rules! check_total_value { - ($purpose: expr) => {{ - let mut payment_claimable_generated = false; - let is_keysend = $purpose.is_keysend(); - let mut claimable_payments = self.claimable_payments.lock().unwrap(); - if claimable_payments.pending_claiming_payments.contains_key(&payment_hash) { - fail_htlc!(claimable_htlc, payment_hash); - } - let ref mut claimable_payment = claimable_payments.claimable_payments - .entry(payment_hash) - // Note that if we insert here we MUST NOT fail_htlc!() - .or_insert_with(|| { - committed_to_claimable = true; - ClaimablePayment { - purpose: $purpose.clone(), - htlcs: Vec::new(), - onion_fields: onion_fields.clone(), - } - }); - if $purpose != claimable_payment.purpose { - let log_keysend = |keysend| if keysend { "keysend" } else { "non-keysend" }; - log_trace!(self.logger, "Failing new {} HTLC with payment_hash {} as we already had an existing {} HTLC with the same payment hash", log_keysend(is_keysend), &payment_hash, log_keysend(!is_keysend)); - fail_htlc!(claimable_htlc, payment_hash); - } - let onions_compatible = - claimable_payment.onion_fields.check_merge(&mut onion_fields); - if onions_compatible.is_err() { - fail_htlc!(claimable_htlc, payment_hash); - } - let mut total_intended_recvd_value = - claimable_htlc.sender_intended_value; - let mut earliest_expiry = claimable_htlc.cltv_expiry; - for htlc in claimable_payment.htlcs.iter() { - total_intended_recvd_value += htlc.sender_intended_value; - earliest_expiry = cmp::min(earliest_expiry, htlc.cltv_expiry); - if total_intended_recvd_value >= msgs::MAX_VALUE_MSAT { break; } - } - let total_mpp_value = - claimable_payment.onion_fields.total_mpp_amount_msat; - // The condition determining whether an MPP is complete must - // match exactly the condition used in `timer_tick_occurred` - if total_intended_recvd_value >= msgs::MAX_VALUE_MSAT { - fail_htlc!(claimable_htlc, payment_hash); - } else if total_intended_recvd_value - claimable_htlc.sender_intended_value >= total_mpp_value { - log_trace!(self.logger, "Failing HTLC with payment_hash {} as payment is already claimable", - &payment_hash); - fail_htlc!(claimable_htlc, payment_hash); - } else if total_intended_recvd_value >= total_mpp_value { - #[allow(unused_assignments)] { - committed_to_claimable = true; - } - claimable_payment.htlcs.push(claimable_htlc); - let amount_msat = - claimable_payment.htlcs.iter().map(|htlc| htlc.value).sum(); - claimable_payment.htlcs.iter_mut() - .for_each(|htlc| htlc.total_value_received = Some(amount_msat)); - let counterparty_skimmed_fee_msat = claimable_payment.htlcs.iter() - .map(|htlc| htlc.counterparty_skimmed_fee_msat.unwrap_or(0)).sum(); - debug_assert!(total_intended_recvd_value.saturating_sub(amount_msat) - <= counterparty_skimmed_fee_msat); - claimable_payment.htlcs.sort(); - let payment_id = - claimable_payment.inbound_payment_id(&self.inbound_payment_id_secret); - new_events.push_back((events::Event::PaymentClaimable { - receiver_node_id: Some(receiver_node_id), - payment_hash, - purpose: $purpose, - amount_msat, - counterparty_skimmed_fee_msat, - receiving_channel_ids: claimable_payment.receiving_channel_ids(), - claim_deadline: Some(earliest_expiry - HTLC_FAIL_BACK_BUFFER), - onion_fields: Some(claimable_payment.onion_fields.clone()), - payment_id: Some(payment_id), - }, None)); - payment_claimable_generated = true; - } else { - // Nothing to do - we haven't reached the total - // payment value yet, wait until we receive more - // MPP parts. - claimable_payment.htlcs.push(claimable_htlc); - #[allow(unused_assignments)] { - committed_to_claimable = true; - } - } - payment_claimable_generated - }} - } - // Check that the payment hash and secret are known. Note that we // MUST take care to handle the "unknown payment hash" and // "incorrect payment secret" cases here identically or we'd expose @@ -8461,7 +8788,7 @@ impl< Ok(result) => result, Err(()) => { log_trace!(self.logger, "Failing new HTLC with payment_hash {} as payment verification failed", &payment_hash); - fail_htlc!(claimable_htlc, payment_hash); + fail_receive_htlc!(false); }, }; if let Some(min_final_cltv_expiry_delta) = min_final_cltv_expiry_delta { @@ -8471,12 +8798,12 @@ impl< if (cltv_expiry as u64) < expected_min_expiry_height { log_trace!(self.logger, "Failing new HTLC with payment_hash {} as its CLTV expiry was too soon (had {}, earliest expected {})", &payment_hash, cltv_expiry, expected_min_expiry_height); - fail_htlc!(claimable_htlc, payment_hash); + fail_receive_htlc!(false); } } payment_preimage } else { - fail_htlc!(claimable_htlc, payment_hash); + fail_receive_htlc!(false); } } else { None @@ -8492,10 +8819,20 @@ impl< let purpose = match from_parts_res { Ok(purpose) => purpose, Err(()) => { - fail_htlc!(claimable_htlc, payment_hash); + fail_receive_htlc!(false); }, }; - check_total_value!(purpose); + + if let Err(committed_to_claimable) = self.handle_claimable_htlc( + purpose, + claimable_htlc, + onion_fields, + payment_hash, + receiver_node_id, + new_events, + ) { + fail_receive_htlc!(committed_to_claimable); + } }, OnionPayload::Spontaneous(keysend_preimage) => { let purpose = if let Some(PaymentContext::AsyncBolt12Offer( @@ -8509,7 +8846,7 @@ impl< false, "We checked that payment_data is Some above" ); - fail_htlc!(claimable_htlc, payment_hash); + fail_receive_htlc!(false); }, }; @@ -8528,13 +8865,13 @@ impl< verified_invreq.amount_msats() { if payment_data.total_msat < invreq_amt_msat { - fail_htlc!(claimable_htlc, payment_hash); + fail_receive_htlc!(false); } } verified_invreq }, None => { - fail_htlc!(claimable_htlc, payment_hash); + fail_receive_htlc!(false); }, }; let payment_purpose_context = @@ -8550,16 +8887,43 @@ impl< match from_parts_res { Ok(purpose) => purpose, Err(()) => { - fail_htlc!(claimable_htlc, payment_hash); + fail_receive_htlc!(false); }, } } else if payment_context.is_some() { log_trace!(self.logger, "Failing new HTLC with payment_hash {}: received a keysend payment to a non-async payments context {:#?}", payment_hash, payment_context); - fail_htlc!(claimable_htlc, payment_hash); + fail_receive_htlc!(false); } else { events::PaymentPurpose::SpontaneousPayment(keysend_preimage) }; - check_total_value!(purpose); + if let Err(committed_to_claimable) = self.handle_claimable_htlc( + purpose, + claimable_htlc, + onion_fields, + payment_hash, + receiver_node_id, + new_events, + ) { + fail_receive_htlc!(committed_to_claimable); + } + }, + OnionPayload::Trampoline { ref next_hop_info, next_trampoline } => { + let next_hop_info = next_hop_info.clone(); + if let Err((htlc_source, failure_reason)) = self.handle_trampoline_htlc( + claimable_htlc, + onion_fields, + payment_hash, + next_hop_info, + next_trampoline, + ) { + failed_forwards.push(( + htlc_source, + payment_hash, + failure_reason, + HTLCHandlingFailureType::TrampolineForward {}, + )); + continue 'next_forwardable_htlc; + } }, } }, @@ -8866,42 +9230,68 @@ impl< self.claimable_payments.lock().unwrap().claimable_payments.retain( |payment_hash, payment| { if payment.htlcs.is_empty() { - // This should be unreachable debug_assert!(false); return false; } if let OnionPayload::Invoice { .. } = payment.htlcs[0].onion_payload { - // Check if we've received all the parts we need for an MPP (the value of the parts adds to total_msat). - // In this case we're not going to handle any timeouts of the parts here. - // This condition determining whether the MPP is complete here must match - // exactly the condition used in `process_pending_htlc_forwards`. - let total_intended_recvd_value = - payment.htlcs.iter().map(|h| h.sender_intended_value).sum(); - let total_mpp_value = payment.onion_fields.total_mpp_amount_msat; - if total_mpp_value <= total_intended_recvd_value { - return true; - } else if payment.htlcs.iter_mut().any(|htlc| { - htlc.timer_ticks += 1; - return htlc.timer_ticks >= MPP_TIMEOUT_TICKS; - }) { - let htlcs = payment - .htlcs - .drain(..) - .map(|htlc: ClaimableHTLC| (htlc.prev_hop, *payment_hash)); - timed_out_mpp_htlcs.extend(htlcs); - return false; + let mpp_timeout = check_mpp_timeout(payment); + if mpp_timeout { + timed_out_mpp_htlcs.extend(payment.htlcs.drain(..).map(|h| { + ( + HTLCSource::PreviousHopData(h.prev_hop), + *payment_hash, + HTLCHandlingFailureType::Receive { + payment_hash: *payment_hash, + }, + ) + })); } + return !mpp_timeout; } true }, ); - for htlc_source in timed_out_mpp_htlcs.drain(..) { - let source = HTLCSource::PreviousHopData(htlc_source.0.clone()); + self.awaiting_trampoline_forwards.lock().unwrap().retain(|payment_hash, payment| { + if payment.htlcs.is_empty() { + debug_assert!(false); + return false; + } + if let OnionPayload::Trampoline { .. } = payment.htlcs[0].onion_payload { + let mpp_timeout = check_mpp_timeout(payment); + if mpp_timeout { + let previous_hop_data = + payment.htlcs.drain(..).map(|claimable| claimable.prev_hop).collect(); + + timed_out_mpp_htlcs.push(( + HTLCSource::TrampolineForward { + previous_hop_data, + outbound_payment: None, + }, + *payment_hash, + HTLCHandlingFailureType::TrampolineForward {}, + )); + } + !mpp_timeout + } else { + debug_assert!( + false, + "awaiting_trampoline_forwards should only contain trampolines" + ); + true + } + }); + + for (htlc_source, payment_hash, failure_type) in timed_out_mpp_htlcs.drain(..) { let failure_reason = LocalHTLCFailureReason::MPPTimeout; let reason = HTLCFailReason::from_failure_code(failure_reason); - let receiver = HTLCHandlingFailureType::Receive { payment_hash: htlc_source.1 }; - self.fail_htlc_backwards_internal(&source, &htlc_source.1, &reason, receiver, None); + self.fail_htlc_backwards_internal( + &htlc_source, + &payment_hash, + &reason, + failure_type, + None, + ); } for (err, counterparty_node_id) in handle_errors { @@ -9196,11 +9586,7 @@ impl< None, )); }, - HTLCSource::TrampolineForward { - previous_hop_data, - incoming_trampoline_shared_secret, - .. - } => { + HTLCSource::TrampolineForward { previous_hop_data, .. } => { let decoded_onion_failure = onion_error.decode_onion_failure(&self.secp_ctx, &self.logger, &source); log_trace!( @@ -9212,8 +9598,6 @@ impl< "unknown channel".to_string() }, ); - let incoming_trampoline_shared_secret = Some(*incoming_trampoline_shared_secret); - // TODO: when we receive a failure from a single outgoing trampoline HTLC, we don't // necessarily want to fail all of our incoming HTLCs back yet. We may have other // outgoing HTLCs that need to resolve first. This will be tracked in our @@ -9225,6 +9609,7 @@ impl< incoming_packet_shared_secret, blinded_failure, channel_id, + trampoline_shared_secret, .. } = current_hop_data; log_trace!( @@ -9236,13 +9621,17 @@ impl< LocalHTLCFailureReason::TemporaryTrampolineFailure, Vec::new(), ); + debug_assert!( + trampoline_shared_secret.is_some(), + "trampoline hop should have secret" + ); push_forward_htlcs_failure( *prev_outbound_scid_alias, get_htlc_forward_failure( blinded_failure, &onion_error, incoming_packet_shared_secret, - &incoming_trampoline_shared_secret, + &trampoline_shared_secret, &None, *htlc_id, ), @@ -12481,15 +12870,8 @@ This indicates a bug inside LDK. Please report this error at https://github.com/ chan.update_fulfill_htlc(&msg), chan_entry ); - let prev_hops = match &res.0 { - HTLCSource::PreviousHopData(prev_hop) => vec![prev_hop], - HTLCSource::TrampolineForward { previous_hop_data, .. } => { - previous_hop_data.iter().collect() - }, - _ => vec![], - }; let logger = WithChannelContext::from(&self.logger, &chan.context, None); - for prev_hop in prev_hops { + for prev_hop in res.0.previous_hop_data() { log_trace!(logger, "Holding the next revoke_and_ack until the preimage is durably persisted in the inbound edge's ChannelMonitor", ); @@ -16126,14 +16508,16 @@ impl< } if let Some(height) = height_opt { + // If height is approaching the number of blocks we think it takes us to get our + // commitment transaction confirmed before the HTLC expires, plus the number of blocks + // we generally consider it to take to do a commitment update, just give up on it and + // fail the HTLC. self.claimable_payments.lock().unwrap().claimable_payments.retain( |payment_hash, payment| { payment.htlcs.retain(|htlc| { - // If height is approaching the number of blocks we think it takes us to get - // our commitment transaction confirmed before the HTLC expires, plus the - // number of blocks we generally consider it to take to do a commitment update, - // just give up on it and fail the HTLC. - if height >= htlc.cltv_expiry - HTLC_FAIL_BACK_BUFFER { + let htlc_timed_out = + htlc.check_onchain_timeout(height, HTLC_FAIL_BACK_BUFFER); + if htlc_timed_out { let reason = LocalHTLCFailureReason::PaymentClaimBuffer; timed_out_htlcs.push(( HTLCSource::PreviousHopData(htlc.prev_hop.clone()), @@ -16146,15 +16530,51 @@ impl< payment_hash: payment_hash.clone(), }, )); - false - } else { - true } + !htlc_timed_out }); !payment.htlcs.is_empty() // Only retain this entry if htlcs has at least one entry. }, ); + self.awaiting_trampoline_forwards.lock().unwrap().retain(|payment_hash, payment| { + if payment.htlcs.is_empty() { + debug_assert!(false); + return false; + } + if let OnionPayload::Trampoline { .. } = payment.htlcs[0].onion_payload { + let htlc_timed_out = payment + .htlcs + .iter() + .any(|htlc| htlc.check_onchain_timeout(height, HTLC_FAIL_BACK_BUFFER)); + if htlc_timed_out { + let previous_hop_data = + payment.htlcs.drain(..).map(|claimable| claimable.prev_hop).collect(); + + let failure_reason = LocalHTLCFailureReason::CLTVExpiryTooSoon; + timed_out_htlcs.push(( + HTLCSource::TrampolineForward { + previous_hop_data, + outbound_payment: None, + }, + *payment_hash, + HTLCFailReason::reason( + failure_reason, + self.get_htlc_inbound_temp_fail_data(failure_reason), + ), + HTLCHandlingFailureType::TrampolineForward {}, + )); + } + !htlc_timed_out + } else { + debug_assert!( + false, + "awaiting_trampoline_forwards should only contain trampolines" + ); + true + } + }); + let mut intercepted_htlcs = self.pending_intercepted_htlcs.lock().unwrap(); intercepted_htlcs.retain(|_, htlc| { if height >= htlc.forward_info.outgoing_cltv_value - HTLC_FAIL_BACK_BUFFER { @@ -17495,11 +17915,14 @@ impl_writeable_tlv_based_enum!(PendingHTLCRouting, (11, invoice_request, option), }, (3, TrampolineForward) => { - (0, incoming_shared_secret, required), + (0, trampoline_shared_secret, required), (2, onion_packet, required), (4, blinded, option), (6, node_id, required), (8, incoming_cltv_expiry, required), + (10, incoming_multipath_data, option), + (12, next_trampoline_amt_msat, required), + (14, next_trampoline_cltv_expiry, required), } ); @@ -17617,9 +18040,14 @@ impl_writeable_tlv_based!(HTLCPreviousHopData, { fn write_claimable_htlc( htlc: &ClaimableHTLC, total_mpp_value_msat: u64, writer: &mut W, ) -> Result<(), io::Error> { - let (payment_data, keysend_preimage) = match &htlc.onion_payload { - OnionPayload::Invoice { _legacy_hop_data } => (_legacy_hop_data.as_ref(), None), - OnionPayload::Spontaneous(preimage) => (None, Some(preimage)), + let (payment_data, keysend_preimage, trampoline_next_hop, trampoline_next_node) = match &htlc + .onion_payload + { + OnionPayload::Invoice { _legacy_hop_data } => (_legacy_hop_data.as_ref(), None, None, None), + OnionPayload::Spontaneous(preimage) => (None, Some(preimage), None, None), + OnionPayload::Trampoline { next_hop_info, next_trampoline } => { + (None, None, Some(next_hop_info), Some(next_trampoline)) + }, }; write_tlv_fields!(writer, { (0, htlc.prev_hop, required), @@ -17631,6 +18059,8 @@ fn write_claimable_htlc( (6, htlc.cltv_expiry, required), (8, keysend_preimage, option), (10, htlc.counterparty_skimmed_fee_msat, option), + (12, trampoline_next_hop, option), + (14, trampoline_next_node, option) }); Ok(()) } @@ -17648,17 +18078,26 @@ impl Readable for (ClaimableHTLC, u64) { (6, cltv_expiry, required), (8, keysend_preimage, option), (10, counterparty_skimmed_fee_msat, option), + (12, trampoline_next_hop, option), + (14, trampoline_next_node, option) }); let payment_data: Option = payment_data_opt; let value = value_ser.0.unwrap(); - let onion_payload = match keysend_preimage { - Some(p) => { + let onion_payload = match (keysend_preimage, trampoline_next_hop, trampoline_next_node) { + (Some(p), None, None) => { if payment_data.is_some() { return Err(DecodeError::InvalidValue) } OnionPayload::Spontaneous(p) }, - None => OnionPayload::Invoice { _legacy_hop_data: payment_data }, + (None, None, None) => OnionPayload::Invoice { _legacy_hop_data: payment_data }, + (None, Some(next_hop_info), Some(next_trampoline)) => { + OnionPayload::Trampoline { + next_hop_info, + next_trampoline, + } + }, + _ => return Err(DecodeError::InvalidValue), }; Ok((ClaimableHTLC { prev_hop: prev_hop.0.unwrap(), @@ -17754,16 +18193,11 @@ impl Writeable for HTLCSource { 1u8.write(writer)?; field.write(writer)?; }, - HTLCSource::TrampolineForward { - ref previous_hop_data, - incoming_trampoline_shared_secret, - ref outbound_payment, - } => { + HTLCSource::TrampolineForward { ref previous_hop_data, ref outbound_payment } => { 2u8.write(writer)?; write_tlv_fields!(writer, { (1, *previous_hop_data, required_vec), - (3, incoming_trampoline_shared_secret, required), - (5, outbound_payment, option), + (3, outbound_payment, option), }); }, } @@ -19716,17 +20150,11 @@ impl< .into_iter() .filter_map(|(htlc_source, (htlc, preimage_opt))| { let payment_preimage = preimage_opt?; - let prev_htlcs = match &htlc_source { - HTLCSource::PreviousHopData(prev_hop) => vec![prev_hop], - HTLCSource::TrampolineForward { previous_hop_data, .. } => { - previous_hop_data.iter().collect() - }, - // If it was an outbound payment, we've handled it above - if a preimage - // came in and we persisted the `ChannelManager` we either handled it - // and are good to go or the channel force-closed - we don't have to - // handle the channel still live case here. - _ => vec![], - }; + // If it was an outbound payment, we've handled it above - if a preimage + // came in and we persisted the `ChannelManager` we either handled it + // and are good to go or the channel force-closed - we don't have to + // handle the channel still live case here. + let prev_htlcs = htlc_source.previous_hop_data(); let prev_htlcs_count = prev_htlcs.len(); if prev_htlcs_count == 0 { return None; @@ -20077,6 +20505,7 @@ impl< claimable_payments, pending_claiming_payments, }), + awaiting_trampoline_forwards: Mutex::new(new_hash_map()), outbound_scid_aliases: Mutex::new(outbound_scid_aliases), short_to_chan_info: FairRwLock::new(short_to_chan_info), fake_scid_rand_bytes: fake_scid_rand_bytes.unwrap(), diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index 641842ddaff..73e3071fb23 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -10,7 +10,9 @@ //! A bunch of useful utilities for building networks of nodes and exchanging messages between //! nodes for functional tests. -use crate::blinded_path::payment::DummyTlvs; +use crate::blinded_path::payment::{ + BlindedPaymentPath, DummyTlvs, ForwardNode, ReceiveTlvs, TrampolineForwardTlvs, +}; use crate::chain::channelmonitor::{ChannelMonitor, HTLC_FAIL_BACK_BUFFER}; use crate::chain::transaction::OutPoint; use crate::chain::{BestBlock, ChannelMonitorUpdateStatus, Confirm, Listen, Watch}; @@ -40,7 +42,8 @@ use crate::ln::types::ChannelId; use crate::onion_message::messenger::OnionMessenger; use crate::routing::gossip::{NetworkGraph, NetworkUpdate, P2PGossipSync}; use crate::routing::router::{self, PaymentParameters, Route, RouteParameters}; -use crate::sign::{EntropySource, RandomBytes}; +use crate::routing::router::{compute_fees, BlindedTail, TrampolineHop}; +use crate::sign::{EntropySource, RandomBytes, ReceiveAuthKey}; use crate::types::features::ChannelTypeFeatures; use crate::types::features::InitFeatures; use crate::types::payment::{PaymentHash, PaymentPreimage, PaymentSecret}; @@ -3735,6 +3738,9 @@ pub fn do_pass_along_path<'a, 'b, 'c>(args: PassAlongPathArgs) -> Option onion_fields.as_ref().unwrap().payment_secret ); }, + PaymentPurpose::Trampoline {} => { + panic!("Trampoline should not emit PaymentClaimable"); + }, } assert_eq!(*amount_msat, recv_value); let channels = node.node.list_channels(); @@ -5762,3 +5768,47 @@ pub fn get_scid_from_channel_id<'a, 'b, 'c>(node: &Node<'a, 'b, 'c>, channel_id: .short_channel_id .unwrap() } + +/// Creates a [`BlindedTail`] for a trampoline forward through a single intermediate node. +/// +/// The resulting tail contains blinded hops built from `intermediate_nodes` plus a dummy receive +/// TLV, with the `TrampolineHop` fee and CLTV derived from the blinded path's aggregated payinfo. +pub fn create_trampoline_forward_blinded_tail( + secp_ctx: &bitcoin::secp256k1::Secp256k1, entropy_source: ES, + intermediate_nodes: &[ForwardNode], payee_node_id: PublicKey, + payee_receive_key: ReceiveAuthKey, payee_tlvs: ReceiveTlvs, min_final_cltv_expiry_delta: u32, + excess_final_cltv_delta: u32, final_value_msat: u64, +) -> BlindedTail { + let blinded_path = BlindedPaymentPath::new_for_trampoline( + intermediate_nodes, + payee_node_id, + payee_receive_key, + payee_tlvs, + u64::max_value(), + min_final_cltv_expiry_delta as u16, + entropy_source, + secp_ctx, + ) + .unwrap(); + + BlindedTail { + trampoline_hops: vec![TrampolineHop { + pubkey: intermediate_nodes.first().map(|n| n.node_id).unwrap_or(payee_node_id), + node_features: types::features::Features::empty(), + fee_msat: compute_fees( + final_value_msat, + lightning_types::routing::RoutingFees { + base_msat: blinded_path.payinfo.fee_base_msat, + proportional_millionths: blinded_path.payinfo.fee_proportional_millionths, + }, + ) + .unwrap(), + cltv_expiry_delta: blinded_path.payinfo.cltv_expiry_delta as u32 + + excess_final_cltv_delta, + }], + hops: blinded_path.blinded_hops().to_vec(), + blinding_point: blinded_path.blinding_point(), + excess_final_cltv_expiry_delta: excess_final_cltv_delta, + final_value_msat, + } +} diff --git a/lightning/src/ln/mod.rs b/lightning/src/ln/mod.rs index d6e0b92f1d0..30a8109fc43 100644 --- a/lightning/src/ln/mod.rs +++ b/lightning/src/ln/mod.rs @@ -118,6 +118,8 @@ mod reorg_tests; mod shutdown_tests; #[cfg(any(feature = "_test_utils", test))] pub mod splicing_tests; +#[cfg(test)] +mod trampoline_forward_tests; #[cfg(any(test, feature = "_externalize_tests"))] #[allow(unused_mut)] pub mod update_fee_tests; diff --git a/lightning/src/ln/onion_payment.rs b/lightning/src/ln/onion_payment.rs index 5111f6982fe..9fed000eb72 100644 --- a/lightning/src/ln/onion_payment.rs +++ b/lightning/src/ln/onion_payment.rs @@ -111,6 +111,9 @@ enum RoutingInfo { next_hop_hmac: [u8; 32], shared_secret: SharedSecret, current_path_key: Option, + incoming_multipath_data: Option, + next_trampoline_amt_msat: u64, + next_trampoline_cltv: u32, }, } @@ -167,24 +170,27 @@ pub(super) fn create_fwd_pending_htlc_info( reason: LocalHTLCFailureReason::InvalidOnionPayload, err_data: Vec::new(), }), - onion_utils::Hop::TrampolineForward { next_trampoline_hop_data, next_trampoline_hop_hmac, new_trampoline_packet_bytes, trampoline_shared_secret, .. } => { + onion_utils::Hop::TrampolineForward { outer_hop_data, next_trampoline_hop_data, next_trampoline_hop_hmac, new_trampoline_packet_bytes, trampoline_shared_secret, .. } => { ( RoutingInfo::Trampoline { next_trampoline: next_trampoline_hop_data.next_trampoline, new_packet_bytes: new_trampoline_packet_bytes, next_hop_hmac: next_trampoline_hop_hmac, shared_secret: trampoline_shared_secret, - current_path_key: None + current_path_key: None, + incoming_multipath_data: outer_hop_data.multipath_trampoline_data, + next_trampoline_amt_msat: next_trampoline_hop_data.amt_to_forward, + next_trampoline_cltv: next_trampoline_hop_data.outgoing_cltv_value, }, - next_trampoline_hop_data.amt_to_forward, - next_trampoline_hop_data.outgoing_cltv_value, + outer_hop_data.amt_to_forward, + outer_hop_data.outgoing_cltv_value, None, None ) }, onion_utils::Hop::TrampolineBlindedForward { outer_hop_data, next_trampoline_hop_data, next_trampoline_hop_hmac, new_trampoline_packet_bytes, trampoline_shared_secret, .. } => { - let (amt_to_forward, outgoing_cltv_value) = check_blinded_forward( - msg.amount_msat, msg.cltv_expiry, &next_trampoline_hop_data.payment_relay, &next_trampoline_hop_data.payment_constraints, &next_trampoline_hop_data.features + let (next_hop_amount, next_hop_cltv) = check_blinded_forward( + outer_hop_data.multipath_trampoline_data.as_ref().map(|f| f.total_msat).unwrap_or(msg.amount_msat), msg.cltv_expiry, &next_trampoline_hop_data.payment_relay, &next_trampoline_hop_data.payment_constraints, &next_trampoline_hop_data.features ).map_err(|()| { // We should be returning malformed here if `msg.blinding_point` is set, but this is // unreachable right now since we checked it in `decode_update_add_htlc_onion`. @@ -200,10 +206,13 @@ pub(super) fn create_fwd_pending_htlc_info( new_packet_bytes: new_trampoline_packet_bytes, next_hop_hmac: next_trampoline_hop_hmac, shared_secret: trampoline_shared_secret, - current_path_key: outer_hop_data.current_path_key + current_path_key: outer_hop_data.current_path_key, + incoming_multipath_data: outer_hop_data.multipath_trampoline_data, + next_trampoline_amt_msat: next_hop_amount, + next_trampoline_cltv: next_hop_cltv, }, - amt_to_forward, - outgoing_cltv_value, + outer_hop_data.amt_to_forward, + outer_hop_data.outgoing_cltv_value, next_trampoline_hop_data.intro_node_blinding_point, next_trampoline_hop_data.next_blinding_override ) @@ -233,7 +242,7 @@ pub(super) fn create_fwd_pending_htlc_info( }), } } - RoutingInfo::Trampoline { next_trampoline, new_packet_bytes, next_hop_hmac, shared_secret, current_path_key } => { + RoutingInfo::Trampoline { next_trampoline, new_packet_bytes, next_hop_hmac, shared_secret, current_path_key, incoming_multipath_data: multipath_trampoline_data, next_trampoline_amt_msat: next_hop_amount, next_trampoline_cltv: next_hop_cltv} => { let next_trampoline_packet_pubkey = match next_packet_pubkey_opt { Some(Ok(pubkey)) => pubkey, _ => return Err(InboundHTLCErr { @@ -249,7 +258,7 @@ pub(super) fn create_fwd_pending_htlc_info( hmac: next_hop_hmac, }; PendingHTLCRouting::TrampolineForward { - incoming_shared_secret: shared_secret.secret_bytes(), + trampoline_shared_secret: shared_secret.secret_bytes(), onion_packet: outgoing_packet, node_id: next_trampoline, incoming_cltv_expiry: msg.cltv_expiry, @@ -260,7 +269,10 @@ pub(super) fn create_fwd_pending_htlc_info( failure: intro_node_blinding_point .map(|_| BlindedFailure::FromIntroductionNode) .unwrap_or(BlindedFailure::FromBlindedNode), - }) + }), + incoming_multipath_data: multipath_trampoline_data, + next_trampoline_amt_msat: next_hop_amount, + next_trampoline_cltv_expiry: next_hop_cltv, } } }; @@ -515,7 +527,7 @@ pub fn peel_payment_onion }; if let Err(reason) = check_incoming_htlc_cltv( - cur_height, outgoing_cltv_value, msg.cltv_expiry, + cur_height, outgoing_cltv_value, msg.cltv_expiry, MIN_CLTV_EXPIRY_DELTA.into(), ) { return Err(InboundHTLCErr { msg: "incoming cltv check failed", @@ -683,33 +695,24 @@ pub(super) fn decode_incoming_update_add_htlc_onion { + onion_utils::Hop::TrampolineForward { next_trampoline_hop_data: msgs::InboundTrampolineForwardPayload { next_trampoline, .. }, ref outer_hop_data, trampoline_shared_secret, incoming_trampoline_public_key, .. } => { let next_trampoline_packet_pubkey = onion_utils::next_hop_pubkey(secp_ctx, incoming_trampoline_public_key, &trampoline_shared_secret.secret_bytes()); Some(NextPacketDetails { next_packet_pubkey: next_trampoline_packet_pubkey, outgoing_connector: HopConnector::Trampoline(next_trampoline), - outgoing_amt_msat: amt_to_forward, - outgoing_cltv_value, + outgoing_amt_msat: outer_hop_data.amt_to_forward, + outgoing_cltv_value: outer_hop_data.outgoing_cltv_value, }) } - onion_utils::Hop::TrampolineBlindedForward { next_trampoline_hop_data: msgs::InboundTrampolineBlindedForwardPayload { next_trampoline, ref payment_relay, ref payment_constraints, ref features, .. }, outer_shared_secret, trampoline_shared_secret, incoming_trampoline_public_key, .. } => { - let (amt_to_forward, outgoing_cltv_value) = match check_blinded_forward( - msg.amount_msat, msg.cltv_expiry, &payment_relay, &payment_constraints, &features - ) { - Ok((amt, cltv)) => (amt, cltv), - Err(()) => { - return encode_relay_error("Underflow calculating outbound amount or cltv value for blinded trampoline forward", - LocalHTLCFailureReason::InvalidOnionBlinding, outer_shared_secret.secret_bytes(), Some(trampoline_shared_secret.secret_bytes()), &[0; 32]); - } - }; + onion_utils::Hop::TrampolineBlindedForward { next_trampoline_hop_data: msgs::InboundTrampolineBlindedForwardPayload { next_trampoline, .. }, ref outer_hop_data, trampoline_shared_secret, incoming_trampoline_public_key, .. } => { let next_trampoline_packet_pubkey = onion_utils::next_hop_pubkey(secp_ctx, incoming_trampoline_public_key, &trampoline_shared_secret.secret_bytes()); Some(NextPacketDetails { next_packet_pubkey: next_trampoline_packet_pubkey, outgoing_connector: HopConnector::Trampoline(next_trampoline), - outgoing_amt_msat: amt_to_forward, - outgoing_cltv_value, + outgoing_amt_msat: outer_hop_data.amt_to_forward, + outgoing_cltv_value: outer_hop_data.outgoing_cltv_value, }) } _ => None @@ -719,9 +722,9 @@ pub(super) fn decode_incoming_update_add_htlc_onion Result<(), LocalHTLCFailureReason> { - if (cltv_expiry as u64) < (outgoing_cltv_value) as u64 + MIN_CLTV_EXPIRY_DELTA as u64 { + if (cltv_expiry as u64) < (outgoing_cltv_value) as u64 + min_cltv_expiry_delta { return Err(LocalHTLCFailureReason::IncorrectCLTVExpiry); } // Theoretically, channel counterparty shouldn't send us a HTLC expiring now, diff --git a/lightning/src/ln/onion_utils.rs b/lightning/src/ln/onion_utils.rs index 9b1b009e93a..bc867e85694 100644 --- a/lightning/src/ln/onion_utils.rs +++ b/lightning/src/ln/onion_utils.rs @@ -1909,7 +1909,7 @@ impl From<&HTLCFailReason> for HTLCHandlingFailureReason { #[derive(Clone)] // See Channel::revoke_and_ack for why, tl;dr: Rust bug #[cfg_attr(test, derive(PartialEq))] -pub(super) struct HTLCFailReason(HTLCFailReasonRepr); +pub(crate) struct HTLCFailReason(HTLCFailReasonRepr); #[derive(Clone)] // See Channel::revoke_and_ack for why, tl;dr: Rust bug #[cfg_attr(test, derive(PartialEq))] @@ -2124,6 +2124,10 @@ impl HTLCFailReason { let mut err = err.clone(); let hold_time = hold_time.unwrap_or(0); + if let Some(secondary_shared_secret) = secondary_shared_secret { + process_failure_packet(&mut err, secondary_shared_secret, hold_time); + crypt_failure_packet(secondary_shared_secret, &mut err); + } process_failure_packet(&mut err, incoming_packet_shared_secret, hold_time); crypt_failure_packet(incoming_packet_shared_secret, &mut err); @@ -2135,6 +2139,23 @@ impl HTLCFailReason { pub(super) fn decode_onion_failure( &self, secp_ctx: &Secp256k1, logger: &L, htlc_source: &HTLCSource, ) -> DecodedOnionFailure { + macro_rules! decoded_onion_failure { + ($short_channel_id:expr, $failure_reason:expr, $data:expr) => { + DecodedOnionFailure { + network_update: None, + payment_failed_permanently: false, + short_channel_id: $short_channel_id, + failed_within_blinded_path: false, + hold_times: Vec::new(), + #[cfg(any(test, feature = "_test_utils"))] + onion_error_code: Some($failure_reason), + #[cfg(any(test, feature = "_test_utils"))] + onion_error_data: Some($data.clone()), + #[cfg(test)] + attribution_failed_channel: None, + } + }; + } match self.0 { HTLCFailReasonRepr::LightningError { ref err, .. } => { process_onion_failure(secp_ctx, logger, &htlc_source, err.clone()) @@ -2146,22 +2167,19 @@ impl HTLCFailReason { // failures here, but that would be insufficient as find_route // generally ignores its view of our own channels as we provide them via // ChannelDetails. - if let &HTLCSource::OutboundRoute { ref path, .. } = htlc_source { - DecodedOnionFailure { - network_update: None, - payment_failed_permanently: false, - short_channel_id: Some(path.hops[0].short_channel_id), - failed_within_blinded_path: false, - hold_times: Vec::new(), - #[cfg(any(test, feature = "_test_utils"))] - onion_error_code: Some(*failure_reason), - #[cfg(any(test, feature = "_test_utils"))] - onion_error_data: Some(data.clone()), - #[cfg(test)] - attribution_failed_channel: None, - } - } else { - unreachable!(); + match htlc_source { + &HTLCSource::OutboundRoute { ref path, .. } => { + decoded_onion_failure!( + (Some(path.hops[0].short_channel_id)), + *failure_reason, + data + ) + }, + &HTLCSource::TrampolineForward { ref outbound_payment, .. } => { + debug_assert!(outbound_payment.is_none()); + decoded_onion_failure!(None, *failure_reason, data) + }, + _ => unreachable!(), } }, } diff --git a/lightning/src/ln/outbound_payment.rs b/lightning/src/ln/outbound_payment.rs index b08b0f5a886..91728e390c3 100644 --- a/lightning/src/ln/outbound_payment.rs +++ b/lightning/src/ln/outbound_payment.rs @@ -11,7 +11,7 @@ use bitcoin::hashes::sha256::Hash as Sha256; use bitcoin::hashes::Hash; -use bitcoin::secp256k1::{self, Secp256k1, SecretKey}; +use bitcoin::secp256k1::{self, PublicKey, Secp256k1, SecretKey}; use lightning_invoice::Bolt11Invoice; use crate::blinded_path::{IntroductionNode, NodeIdLookUp}; @@ -21,7 +21,7 @@ use crate::ln::channelmanager::{ EventCompletionAction, HTLCSource, OptionalBolt11PaymentParams, PaymentCompleteUpdate, PaymentId, }; -use crate::ln::msgs::DecodeError; +use crate::ln::msgs::{DecodeError, TrampolineOnionPacket}; use crate::ln::onion_utils; use crate::ln::onion_utils::{DecodedOnionFailure, HTLCFailReason}; use crate::offers::invoice::{Bolt12Invoice, DerivedSigningPubkey, InvoiceBuilder}; @@ -167,6 +167,25 @@ pub(crate) enum PendingOutboundPayment { }, } +#[derive(Clone, Eq, PartialEq)] +pub(crate) struct NextTrampolineHopInfo { + /// The Trampoline packet to include for the next Trampoline hop. + pub(crate) onion_packet: TrampolineOnionPacket, + /// If blinded, the current_path_key to set at the next Trampoline hop. + pub(crate) blinding_point: Option, + /// The amount that the next trampoline is expecting to receive. + pub(crate) amount_msat: u64, + /// The cltv expiry height that the next trampoline is expecting. + pub(crate) cltv_expiry_height: u32, +} + +impl_writeable_tlv_based!(NextTrampolineHopInfo, { + (1, onion_packet, required), + (3, blinding_point, option), + (5, amount_msat, required), + (7, cltv_expiry_height, required), +}); + #[derive(Clone)] pub(crate) struct RetryableInvoiceRequest { pub(crate) invoice_request: InvoiceRequest, diff --git a/lightning/src/ln/trampoline_forward_tests.rs b/lightning/src/ln/trampoline_forward_tests.rs new file mode 100644 index 00000000000..1a1ba42c0f8 --- /dev/null +++ b/lightning/src/ln/trampoline_forward_tests.rs @@ -0,0 +1,197 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +//! Tests for trampoline MPP accumulation and forwarding validation in +//! [`ChannelManager::handle_trampoline_htlc`]. + +use crate::chain::transaction::OutPoint; +use crate::events::HTLCHandlingFailureReason; +use crate::ln::channelmanager::{ClaimableHTLC, HTLCPreviousHopData, OnionPayload}; +use crate::ln::functional_test_utils::*; +use crate::ln::msgs; +use crate::ln::onion_utils::LocalHTLCFailureReason; +use crate::ln::outbound_payment::{NextTrampolineHopInfo, RecipientOnionFields}; +use crate::ln::types::ChannelId; +use crate::types::payment::{PaymentHash, PaymentSecret}; + +use bitcoin::hashes::Hash; +use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; + +fn test_prev_hop_data() -> HTLCPreviousHopData { + HTLCPreviousHopData { + prev_outbound_scid_alias: 0, + user_channel_id: None, + htlc_id: 0, + incoming_packet_shared_secret: [0; 32], + phantom_shared_secret: None, + trampoline_shared_secret: Some([0; 32]), + blinded_failure: None, + channel_id: ChannelId::from_bytes([0; 32]), + outpoint: OutPoint { txid: bitcoin::Txid::all_zeros(), index: 0 }, + counterparty_node_id: None, + cltv_expiry: None, + } +} + +fn test_trampoline_onion_packet() -> msgs::TrampolineOnionPacket { + let secp = Secp256k1::new(); + let test_secret = SecretKey::from_slice(&[42; 32]).unwrap(); + msgs::TrampolineOnionPacket { + version: 0, + public_key: PublicKey::from_secret_key(&secp, &test_secret), + hop_data: vec![0; 650], + hmac: [0; 32], + } +} + +fn test_onion_fields(total_msat: u64) -> RecipientOnionFields { + RecipientOnionFields { + payment_secret: Some(PaymentSecret([0; 32])), + total_mpp_amount_msat: total_msat, + payment_metadata: None, + custom_tlvs: Vec::new(), + } +} + +enum TrampolineMppValidationTestCase { + FeeInsufficient, + CltvInsufficient, + TrampolineAmountExceedsReceived, + TrampolineCLTVExceedsReceived, + MismatchedPaymentSecret, +} + +/// Sends two MPP parts through [`ChannelManager::handle_trampoline_htlc`], testing various MPP +/// validation steps with a base case that succeeds. +fn do_test_trampoline_mpp_validation(test_case: Option) { + let update_add_value: u64 = 500_000; // Actual amount we received in update_add_htlc. + let update_add_cltv: u32 = 500; // Actual CLTV we received in update_add_htlc. + let sender_intended_incoming_value: u64 = 500_000; // Amount we expect for one HTLC, outer onion. + let incoming_mpp_total: u64 = 1_000_000; // Total we expect to receive across MPP parts, outer onion. + let mut next_trampoline_amount: u64 = 750_000; // Total next trampoline expects, inner onion. + let mut next_trampoline_cltv: u32 = 100; // CLTV next trampoline expects, inner onion. + + // By default, set our forwarding fee and CLTV delta to exactly what we're being offered + // for this trampoline forward, so that we can force failures by just adding one. + let mut forwarding_fee_base_msat = incoming_mpp_total - next_trampoline_amount; + let mut cltv_delta = update_add_cltv - next_trampoline_cltv; + let mut mismatch_payment_secret = false; + + let expected = match test_case { + Some(TrampolineMppValidationTestCase::FeeInsufficient) => { + forwarding_fee_base_msat += 1; + LocalHTLCFailureReason::TrampolineFeeOrExpiryInsufficient + }, + Some(TrampolineMppValidationTestCase::CltvInsufficient) => { + cltv_delta += 1; + LocalHTLCFailureReason::TrampolineFeeOrExpiryInsufficient + }, + Some(TrampolineMppValidationTestCase::TrampolineAmountExceedsReceived) => { + next_trampoline_amount = incoming_mpp_total + 1; + LocalHTLCFailureReason::TrampolineFeeOrExpiryInsufficient + }, + Some(TrampolineMppValidationTestCase::TrampolineCLTVExceedsReceived) => { + next_trampoline_cltv = update_add_cltv + 1; + LocalHTLCFailureReason::TrampolineFeeOrExpiryInsufficient + }, + Some(TrampolineMppValidationTestCase::MismatchedPaymentSecret) => { + mismatch_payment_secret = true; + LocalHTLCFailureReason::InvalidTrampolineForward + }, + // We currently reject trampoline forwards once accumulated. + None => LocalHTLCFailureReason::TemporaryTrampolineFailure, + }; + + let chanmon_cfgs = create_chanmon_cfgs(1); + let node_cfgs = create_node_cfgs(1, &chanmon_cfgs); + let mut cfg = test_default_channel_config(); + cfg.channel_config.forwarding_fee_base_msat = forwarding_fee_base_msat as u32; + cfg.channel_config.forwarding_fee_proportional_millionths = 0; + cfg.channel_config.cltv_expiry_delta = cltv_delta as u16; + let node_chanmgrs = create_node_chanmgrs(1, &node_cfgs, &[Some(cfg)]); + let nodes = create_network(1, &node_cfgs, &node_chanmgrs); + + let payment_hash = PaymentHash([1; 32]); + + let secp = Secp256k1::new(); + let test_secret = SecretKey::from_slice(&[2; 32]).unwrap(); + let next_trampoline = PublicKey::from_secret_key(&secp, &test_secret); + let next_hop_info = NextTrampolineHopInfo { + onion_packet: test_trampoline_onion_packet(), + blinding_point: None, + amount_msat: next_trampoline_amount, + cltv_expiry_height: next_trampoline_cltv, + }; + + let htlc1 = ClaimableHTLC::new( + test_prev_hop_data(), + update_add_value, + sender_intended_incoming_value, + update_add_cltv, + OnionPayload::Trampoline { next_hop_info: next_hop_info.clone(), next_trampoline }, + None, + ); + assert!(nodes[0] + .node + .test_handle_trampoline_htlc( + htlc1, + test_onion_fields(incoming_mpp_total), + payment_hash, + next_hop_info.clone(), + next_trampoline, + ) + .is_ok()); + + let htlc2 = ClaimableHTLC::new( + test_prev_hop_data(), + update_add_value, + sender_intended_incoming_value, + update_add_cltv, + OnionPayload::Trampoline { next_hop_info: next_hop_info.clone(), next_trampoline }, + None, + ); + let onion2 = if mismatch_payment_secret { + RecipientOnionFields { + payment_secret: Some(PaymentSecret([1; 32])), + total_mpp_amount_msat: incoming_mpp_total, + payment_metadata: None, + custom_tlvs: Vec::new(), + } + } else { + test_onion_fields(incoming_mpp_total) + }; + let result = nodes[0].node.test_handle_trampoline_htlc( + htlc2, + onion2, + payment_hash, + next_hop_info, + next_trampoline, + ); + + assert_eq!( + HTLCHandlingFailureReason::from(&result.expect_err("expect trampoline failure").1), + HTLCHandlingFailureReason::Local { reason: expected }, + ); +} + +#[test] +fn test_trampoline_mpp_validation() { + do_test_trampoline_mpp_validation(Some(TrampolineMppValidationTestCase::FeeInsufficient)); + do_test_trampoline_mpp_validation(Some(TrampolineMppValidationTestCase::CltvInsufficient)); + do_test_trampoline_mpp_validation(Some( + TrampolineMppValidationTestCase::TrampolineAmountExceedsReceived, + )); + do_test_trampoline_mpp_validation(Some( + TrampolineMppValidationTestCase::TrampolineCLTVExceedsReceived, + )); + do_test_trampoline_mpp_validation(Some( + TrampolineMppValidationTestCase::MismatchedPaymentSecret, + )); + do_test_trampoline_mpp_validation(None); +} diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 0c0d14b43fd..edb048c8c7d 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -2464,7 +2464,7 @@ impl<'a> PaymentPath<'a> { #[inline(always)] /// Calculate the fees required to route the given amount over a channel with the given fees. #[rustfmt::skip] -fn compute_fees(amount_msat: u64, channel_fees: RoutingFees) -> Option { +pub(crate) fn compute_fees(amount_msat: u64, channel_fees: RoutingFees) -> Option { amount_msat.checked_mul(channel_fees.proportional_millionths as u64) .and_then(|part| (channel_fees.base_msat as u64).checked_add(part / 1_000_000)) }