From c20eecdc8946d0f1538adbda5f14056c6501dd65 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 6 Mar 2025 18:40:15 +0100 Subject: [PATCH 1/3] trampoline: Rename ChannelData -> Channel Signed-off-by: Peter Neuroth --- libs/gl-plugin/src/tramp.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/libs/gl-plugin/src/tramp.rs b/libs/gl-plugin/src/tramp.rs index b9a8b7d1c..a3f6975bb 100644 --- a/libs/gl-plugin/src/tramp.rs +++ b/libs/gl-plugin/src/tramp.rs @@ -192,7 +192,7 @@ pub async fn trampolinepay( debug!("overpay={}, total_amt={}", overpay, amount_msat); - let mut channels: Vec = rpc + let mut channels: Vec = rpc .call_typed(&cln_rpc::model::requests::ListpeerchannelsRequest { id: Some(node_id) }) .await? .channels @@ -226,7 +226,7 @@ pub async fn trampolinepay( return None; } }; - return Some(ChannelData { + return Some(Channel { short_channel_id, spendable_msat, min_htlc_out_msat, @@ -448,11 +448,11 @@ async fn do_pay( } async fn reestablished_channels( - channels: Vec, + channels: Vec, node_id: PublicKey, rpc_path: PathBuf, deadline: Instant, -) -> Result> { +) -> Result> { // Wait for channels to re-establish. crate::awaitables::assert_send(AwaitableChannel::new( node_id, @@ -487,10 +487,11 @@ async fn reestablished_channels( Ok(_amount) => Some(channel_data), _ => None, }) - .collect::>()) + .collect::>()) } -struct ChannelData { +#[derive(Clone, Debug, PartialEq, Eq)] +struct Channel { short_channel_id: cln_rpc::primitives::ShortChannelId, spendable_msat: u64, min_htlc_out_msat: u64, From 24ff8798562e8805bf8d7c092ae78c255f1af6ff Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 6 Mar 2025 18:41:07 +0100 Subject: [PATCH 2/3] trampoline: Add channel filter for spendable_msat It makes more sense to filter directly from the get go instead of skipping channels that have an unsufficient spendable_msat amount and can cause zero value htlcs. Signed-off-by: Peter Neuroth --- libs/gl-plugin/src/tramp.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libs/gl-plugin/src/tramp.rs b/libs/gl-plugin/src/tramp.rs index a3f6975bb..48fd54c0c 100644 --- a/libs/gl-plugin/src/tramp.rs +++ b/libs/gl-plugin/src/tramp.rs @@ -232,6 +232,8 @@ pub async fn trampolinepay( min_htlc_out_msat, }); }) + .filter(|ch| ch.spendable_msat > 0) + .filter(|ch| ch.spendable_msat > ch.min_htlc_out_msat) .collect(); channels.sort_by(|a, b| b.spendable_msat.cmp(&a.spendable_msat)); From 32e64c3640cdc62e83cd8e03ad13f1e4ad99b6b5 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 6 Mar 2025 18:42:51 +0100 Subject: [PATCH 3/3] trampoline: Add new channel allocation logic The way we allocated channels for a trampoline payment could lead to a case where we could get stuck selecting channels for a payment when the rest amount was lower than the lower bound of the channel. This commit introduces a new selection logic that tries to be greedy first but will split payments more carefully if it is actually needded. Signed-off-by: Peter Neuroth --- libs/gl-plugin/src/tramp.rs | 643 ++++++++++++++++++++++++++++++------ 1 file changed, 550 insertions(+), 93 deletions(-) diff --git a/libs/gl-plugin/src/tramp.rs b/libs/gl-plugin/src/tramp.rs index 48fd54c0c..c3e28d189 100644 --- a/libs/gl-plugin/src/tramp.rs +++ b/libs/gl-plugin/src/tramp.rs @@ -32,8 +32,6 @@ const PAY_UNPARSEABLE_ONION_MSG: &str = "Malformed error reply"; const PAY_UNPARSEABLE_ONION_CODE: i32 = 202; // How long do we wait for channels to re-establish? const AWAIT_CHANNELS_TIMEOUT_SEC: u64 = 20; -// Minimum amount we can send through a channel. -const MIN_HTLC_AMOUNT: u64 = 1; fn feature_guard(features: impl Into>, feature_bit: usize) -> Result<()> { let mut features = features.into(); @@ -192,7 +190,7 @@ pub async fn trampolinepay( debug!("overpay={}, total_amt={}", overpay, amount_msat); - let mut channels: Vec = rpc + let channels: Vec = rpc .call_typed(&cln_rpc::model::requests::ListpeerchannelsRequest { id: Some(node_id) }) .await? .channels @@ -236,8 +234,6 @@ pub async fn trampolinepay( .filter(|ch| ch.spendable_msat > ch.min_htlc_out_msat) .collect(); - channels.sort_by(|a, b| b.spendable_msat.cmp(&a.spendable_msat)); - // Check if we actually got a channel to the trampoline node. if channels.is_empty() { return Err(anyhow!("Has no channels with trampoline node")); @@ -245,20 +241,39 @@ pub async fn trampolinepay( // Await and filter out re-established channels. let deadline = Instant::now() + Duration::from_secs(AWAIT_CHANNELS_TIMEOUT_SEC); - let channels = + let mut channels = reestablished_channels(channels, node_id, rpc_path.as_ref().to_path_buf(), deadline) .await?; - // Note: We can also do this inside the reestablished_channels function - // but as we want to be greedy picking our channels we don't want to - // introduce a race of the choosen channels for now. - let choosen = pick_channels(amount_msat, channels)?; - - // FIXME should not be neccessary as we already check on the amount. - let parts = choosen.len(); - if parts == 0 { - return Err(anyhow!("no channels found to send")); - } + // Try different allocation strategies in sequence. First try in ascending + // order of spendable_msat, giving us most drained channels first. Then + // try in descending order of spendable_msat giving us the channels with the + // biggest local balance first. + debug!( + "Trying to allocate {}msat accross {} channels in ascending order", + amount_msat, + channels.len() + ); + let alloc = match find_allocation_ascending_order(&mut channels, amount_msat) + .filter(|alloc| !alloc.is_empty()) + { + Some(alloc) => alloc, + None => { + debug!("Failed to allocate {}msat in ascending channel order {:?}, trying in descending order",amount_msat, &channels); + match find_allocation_descending_order(&mut channels, amount_msat) + .filter(|alloc| !alloc.is_empty()) + { + Some(alloc) => alloc, + None => { + return Err(anyhow!( + "could not allocate enough funds across channels {}msat<{}msat", + channels.iter().map(|ch| ch.spendable_msat).sum::(), + amount_msat + )); + } + } + } + }; // All set we can preapprove the invoice let _ = rpc @@ -274,16 +289,18 @@ pub async fn trampolinepay( payload.set_tu64(TLV_AMT_MSAT, tlv_amount_msat); let payload_hex = hex::encode(SerializedTlvStream::to_bytes(payload)); - let mut part_id = if choosen.len() == 1 { 0 } else { 1 }; + let mut part_id = if alloc.len() == 1 { 0 } else { 1 }; let group_id = max_group_id + 1; let mut handles: Vec< tokio::task::JoinHandle< std::result::Result, >, > = vec![]; - for (scid, part_amt) in choosen { + for ch in &alloc { let bolt11 = req.bolt11.clone(); let label = req.label.clone(); + let part_amt = ch.contrib_msat.clone(); + let scid = ch.channel.short_channel_id.clone(); let description = decoded.description.clone(); let payload_hex = payload_hex.clone(); let mut rpc = ClnRpc::new(&rpc_path).await?; @@ -331,7 +348,7 @@ pub async fn trampolinepay( amount_msat: cln_rpc::primitives::Amount::from_msat(amount_msat), amount_sent_msat: cln_rpc::primitives::Amount::from_msat(amount_msat), created_at: 0., - parts: parts as u32, + parts: alloc.len() as u32, payment_hash: decoded.payment_hash, payment_preimage, }) @@ -340,45 +357,6 @@ pub async fn trampolinepay( } } -fn pick_channels( - amount_msat: u64, - mut channels: Vec, -) -> Result> { - let mut acc = 0; - let mut choosen = vec![]; - while let Some(channel) = channels.pop() { - if acc == amount_msat { - break; - } - - // Filter out channels that lack minimum funds and can not send an htlc. - if std::cmp::max(MIN_HTLC_AMOUNT, channel.min_htlc_out_msat) > channel.spendable_msat { - debug!("Skip channel {}: has spendable_msat={} and minimum_htlc_out_msat={} and can not send htlc.", - channel.short_channel_id, - channel.spendable_msat, - channel.min_htlc_out_msat, - ); - continue; - } - - if (channel.spendable_msat + acc) <= amount_msat { - choosen.push((channel.short_channel_id, channel.spendable_msat)); - acc += channel.spendable_msat; - } else { - let rest = amount_msat - acc; - choosen.push((channel.short_channel_id, rest)); - acc += rest; - break; - } - } - - if acc < amount_msat { - return Err(anyhow!("missing balance {}msat<{}msat", acc, amount_msat)); - } - - Ok(choosen) -} - async fn do_pay( rpc: &mut ClnRpc, node_id: PublicKey, @@ -499,6 +477,191 @@ struct Channel { min_htlc_out_msat: u64, } +#[derive(Clone, Debug, PartialEq, Eq)] +struct ChannelContribution<'a> { + channel: &'a Channel, + contrib_msat: u64, +} + +// Finds a payment allocation by sorting channels in descending order of +// spendable amount. +/// +/// This strategy prioritizes channels with the most funds first, which tends to +/// minimize the number of channels used for large payments. For each spendable +/// amount, it further prioritizes channels with smaller minimum HTLC +/// requirements. +fn find_allocation_descending_order<'a>( + channels: &'a mut [Channel], + target_msat: u64, +) -> Option>> { + // We sort in descending order for spendable_msat and ascending for the + // min_htlc_out_msat, which means that we process the channels with the + // biggest local funds first. + channels.sort_by(|a, b| { + b.spendable_msat + .cmp(&a.spendable_msat) + .then_with(|| a.min_htlc_out_msat.cmp(&b.min_htlc_out_msat)) + }); + + find_allocation(channels, target_msat) +} + +/// Finds a payment allocation by sorting channels in ascending order of +/// spendable amount. +/// +/// This strategy prioritizes draining smaller channels first, which can help +/// consolidate funds into fewer channels. For each spendable amount, +/// it further prioritizes channels with smaller minimum HTLC requirements. +fn find_allocation_ascending_order<'a>( + channels: &'a mut [Channel], + target_msat: u64, +) -> Option>> { + // We sort in ascending order for spendable_msat and min_htlc_out_msat, + // which means that we process the smallest channels first. + channels.sort_by(|a, b| { + a.spendable_msat + .cmp(&b.spendable_msat) + .then_with(|| a.min_htlc_out_msat.cmp(&b.min_htlc_out_msat)) + }); + + find_allocation(channels, target_msat) +} + +/// Finds an allocation that covers the target amount while respecting channel +/// constraints. +/// +/// This function implements a recursive backtracking algorithm that attempts to +/// allocate funds from channels in the order they are provided. It handles +/// complex scenarios where channel minimum requirements may need cascading +/// adjustments to find a valid solution. +/// +/// # Algorithm Details +/// +/// The algorithm works by: +/// 1. Trying to allocate the maximum possible from each channel +/// 2. If a channel's minimum exceeds the remaining target, it tries to skip +/// that channel +/// 3. When a channel minimum can't be met, it backtracks and adjusts previous +/// allocations +/// 4. It uses a cascading approach to free up just enough space from previous +/// channels +fn find_allocation<'a>( + channels: &'a [Channel], + target_msat: u64, +) -> Option>> { + // We can not allocate channels for a zero amount. + if target_msat == 0 { + return None; + } + + /// Result type for the recursive allocation function + enum AllocResult { + /// Allocation succeeded + Success, + /// Allocation is impossible with current channels + Impossible, + /// Need more space (in msat) to satisfy minimum requirements + NeedSpace(u64), + } + + /// Recursive helper function that tries to find a valid allocation + /// + /// # Arguments + /// * `channels` - Remaining channels to consider + /// * `target_msat` - Remaining amount to allocate + /// * `allocations` - Current allocation state (modified in-place) + fn try_allocate<'a>( + channels: &'a [Channel], + target_msat: u64, + allocations: &mut Vec>, + ) -> AllocResult { + // Base case: If we've exactly allocated the target, we found a solution. + if target_msat == 0 { + return AllocResult::Success; + } + + // Check that we have channels left to allocate from. + if channels.is_empty() { + return AllocResult::Impossible; + } + + // Try to use the current channel (smallest amount) first. + let ch = &channels[0]; + + // Channel is drained or unusable, skip it. + if ch.spendable_msat < ch.min_htlc_out_msat || ch.spendable_msat == 0 { + return try_allocate(&channels[1..], target_msat, allocations); + } + + // Each channel has an upper and a lower bound defined by the minimum + // HTLC amount and the spendable amount. + let lower = ch.min_htlc_out_msat; + let upper = ch.spendable_msat.min(target_msat); + + // We need a higher target amount. + if target_msat < lower { + // First we try skipping this channel to see if later channels can + // handle it. + match try_allocate(&channels[1..], target_msat, allocations) { + AllocResult::Success => return AllocResult::Success, + // If that doesn't work, we need space from earlier allocations + _ => return AllocResult::NeedSpace(lower - target_msat), + } + } + + // We can allocate from this channel - try max amount first. + allocations.push(ChannelContribution { + channel: ch, + contrib_msat: upper, + }); + + // Try to allocate the remaining amount from subsequent channels. + match try_allocate(&channels[1..], target_msat - upper, allocations) { + // Success! We're done. + AllocResult::Success => return AllocResult::Success, + + // No solution possible with current allocations + AllocResult::Impossible => return AllocResult::Impossible, + + // Need to free up space + AllocResult::NeedSpace(shortfall) => { + // Calculate how much we can free from this allocation. + let free = upper - lower; + if shortfall <= free { + // We can cover the shortfall with free space in this channel + allocations.pop(); + let adjusted_amount = upper - shortfall; + allocations.push(ChannelContribution { + channel: ch, + contrib_msat: adjusted_amount, + }); + + // Try allocation with the adjusted amount. + match try_allocate(&channels[1..], target_msat - adjusted_amount, allocations) { + AllocResult::Success => return AllocResult::Success, + _ => { + // If that still don't work skip this channel completely. + // NOTE: We could also try to skip the next channel. + allocations.pop(); + return try_allocate(&channels[1..], target_msat, allocations); + } + } + } else { + // We can't fully cover the shortfall, need to pass up a remainder. + allocations.pop(); + return AllocResult::NeedSpace(shortfall - free); + } + } + }; + } + + let mut allocations = Vec::with_capacity(channels.len()); + match try_allocate(channels, target_msat, &mut allocations) { + AllocResult::Success => Some(allocations), + _ => None, + } +} + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct SendpayRequest { #[serde(skip_serializing_if = "Option::is_none")] @@ -524,47 +687,341 @@ pub struct SendpayRequest { } #[cfg(test)] -mod tests { +mod channel_allocation_tests { use super::*; + fn create_channel( + short_channel_id: ShortChannelId, + spendable_msat: u64, + min_htlc_out_msat: u64, + ) -> Channel { + Channel { + short_channel_id, + spendable_msat, + min_htlc_out_msat, + } + } + + fn scid(s: &str) -> ShortChannelId { + ShortChannelId::from_str(s).unwrap() + } + + fn verify_allocation( + allocations: &[ChannelContribution], + target_msat: u64, + expected_scids: &[ShortChannelId], + ) { + // Check that the total amount matches. + let total: u64 = allocations.iter().map(|ch| ch.contrib_msat).sum(); + assert_eq!(total, target_msat); + + // Check that the expected order matches. + for (i, alloc) in allocations.iter().enumerate() { + let ch = alloc.channel; + if i < expected_scids.len() { + assert_eq!(ch.short_channel_id, expected_scids[i]); + } + + // Check that the lower and upper limits have been respected. + assert!(alloc.contrib_msat >= ch.min_htlc_out_msat); + assert!(alloc.contrib_msat <= ch.spendable_msat); + } + } + #[test] - fn test_picking_channels() { - let scid1 = ShortChannelId::from_str("100000x100x0").unwrap(); - let scid2 = ShortChannelId::from_str("100000x101x0").unwrap(); - let scid3 = ShortChannelId::from_str("100000x102x0").unwrap(); - let scid4 = ShortChannelId::from_str("100000x103x0").unwrap(); + fn zero_target_amount() { + // A target sum of 0 should return None. let channels = vec![ - ChannelData { - short_channel_id: scid1, - spendable_msat: 100000, - min_htlc_out_msat: 0, - }, - // Below MIN_HTLC_AMOUNT. - ChannelData { - short_channel_id: scid2, - spendable_msat: 0, - min_htlc_out_msat: 0, - }, - // min_htlc_out_msat is larger than spendable_msat. - ChannelData { - short_channel_id: scid3, - spendable_msat: 1, - min_htlc_out_msat: 2, - }, - ChannelData { - short_channel_id: scid4, - spendable_msat: 55000, - min_htlc_out_msat: 55000, - }, + create_channel(scid("1x1x1"), 1_000, 100), + create_channel(scid("2x1x1"), 2_000, 200), + ]; + + let result = find_allocation(&channels, 0); + assert_eq!(result, None, "Zero target should return None"); + } + + #[test] + fn single_channel_exact_amount() { + // A single channel that can take the exact amount. + let channels = vec![ + create_channel(scid("1x1x1"), 1_000, 100), + create_channel(scid("2x1x1"), 2_000, 200), + ]; + + let result = find_allocation(&channels, 1_000); + assert!(result.is_some(), "Should find an allocation"); + + let allocations = result.unwrap(); + verify_allocation(&allocations, 1_000, &[scid("1x1x1")]); + } + + #[test] + fn single_channel_partial() { + let channels = vec![ + create_channel(scid("1x1x1"), 1_000, 100), + create_channel(scid("2x1x1"), 2_000, 200), + ]; + + let result = find_allocation(&channels, 500); + assert!(result.is_some(), "Should find an allocation"); + + let allocations = result.unwrap(); + verify_allocation(&allocations, 500, &[scid("1x1x1")]); + assert_eq!(allocations[0].contrib_msat, 500); + } + + #[test] + fn multiple_channels_simple() { + let channels = vec![ + create_channel(scid("1x1x1"), 1_000, 100), + create_channel(scid("2x1x1"), 2_000, 200), + create_channel(scid("3x1x1"), 3_000, 300), + ]; + + let result = find_allocation(&channels, 2_500); + assert!(result.is_some(), "Should find an allocation"); + + let allocations = result.unwrap(); + verify_allocation(&allocations, 2_500, &[scid("1x1x1"), scid("2x1x1")]); + assert_eq!(allocations[0].contrib_msat, 1_000); // Use all of first channel + assert_eq!(allocations[1].contrib_msat, 1_500); // Use part of second channel + } + + #[test] + fn minimum_constraint() { + // Target is below channel's minimum + let channels = vec![create_channel(scid("1x1x1"), 1_000, 500)]; + + let result = find_allocation(&channels, 400); + assert_eq!(result, None, "Can't allocate below minimum"); + } + + #[test] + fn not_enough_funds() { + let channels = vec![ + create_channel(scid("1x1x1"), 1_000, 100), + create_channel(scid("2x1x1"), 2_000, 200), + ]; + + let result = find_allocation(&channels, 5_000); + assert_eq!(result, None, "Can't allocate more than total available"); + } + + #[test] + fn adjusting_for_minimum_simple() { + // Need to adjust allocation to meet minimum of next channel + let channels = vec![ + create_channel(scid("1x1x1"), 1_000, 100), + create_channel(scid("2x1x1"), 2_000, 600), + ]; + + // Target 1500 would normally use 1000 from channel 1, + // leaving 500 for channel 2, but channel 2 needs at least 600 + let result = find_allocation(&channels, 1_500); + assert!(result.is_some(), "Should find an allocation by adjusting"); + + let allocations = result.unwrap(); + verify_allocation(&allocations, 1_500, &[scid("1x1x1"), scid("2x1x1")]); + assert_eq!(allocations[0].contrib_msat, 900); // Reduced from 1000 + assert_eq!(allocations[1].contrib_msat, 600); // Minimum of channel 2 + } + + #[test] + fn cascading_adjustment() { + // Need to adjust multiple channels to satisfy constraints + let channels = vec![ + create_channel(scid("1x1x1"), 1_000, 100), + create_channel(scid("2x1x1"), 1_500, 1_300), + create_channel(scid("3x1x1"), 2_000, 800), + ]; + + let result = find_allocation(&channels, 3_000); + assert!( + result.is_some(), + "Should find an allocation by cascading adjustment" + ); + + let allocations = result.unwrap(); + verify_allocation( + &allocations, + 3_000, + &[scid("1x1x1"), scid("2x1x1"), scid("3x1x1")], + ); + + // Verify channel 3 gets at least its minimum + assert_eq!(allocations.len(), 3); + assert_eq!(allocations[0].contrib_msat, 900); + assert_eq!(allocations[1].contrib_msat, 1_300); + assert_eq!(allocations[2].contrib_msat, 800); + } + + #[test] + fn complex_adjustment() { + // Complex case requiring multiple adjustments + let channels = vec![ + create_channel(scid("1x1x1"), 1_000, 300), + create_channel(scid("2x1x1"), 1_200, 500), + create_channel(scid("3x1x1"), 1_500, 700), + create_channel(scid("4x1x1"), 2_000, 1_000), + ]; + + let result = find_allocation(&channels, 4_000); + assert!( + result.is_some(), + "Should find an allocation for complex case" + ); + + let allocations = result.unwrap(); + verify_allocation( + &allocations, + 4_000, + &[scid("1x1x1"), scid("2x1x1"), scid("3x1x1"), scid("4x1x1")], + ); + + assert_eq!(allocations.len(), 4); + assert_eq!(allocations[0].contrib_msat, 1_000); + assert_eq!(allocations[1].contrib_msat, 1_200); + assert_eq!(allocations[2].contrib_msat, 800); + assert_eq!(allocations[3].contrib_msat, 1_000); + } + + #[test] + fn skip_channel() { + // Case where we need to skip a channel with higher minimum + let channels = vec![ + create_channel(scid("1x1x1"), 1_000, 900), + create_channel(scid("2x1x1"), 1_500, 800), + create_channel(scid("3x1x1"), 2_000, 200), + ]; + + // For target 1500, we should skip channel 2 and use 1 and 3 + let result = find_allocation(&channels, 1_500); + assert!( + result.is_some(), + "Should find an allocation by skipping a channel" + ); + + let allocations = result.unwrap(); + + // We expect to use channels 1 and 3, not channel 2 + verify_allocation(&allocations, 1_500, &[scid("1x1x1"), scid("3x1x1")]); + // Skip specific ID check + } + + #[test] + fn exact_minimum_allocation() { + let channels = vec![ + create_channel(scid("1x1x1"), 999, 500), + create_channel(scid("2x1x1"), 1_000, 500), + ]; + let result = find_allocation(&channels, 1_000); + assert!(result.is_some()); + let allocations = result.unwrap(); + verify_allocation(&allocations, 1_000, &[scid("1x1x1"), scid("2x1x1")]); + // Both should be at their minimums + assert_eq!(allocations[0].contrib_msat, 500); + assert_eq!(allocations[1].contrib_msat, 500); + } + + #[test] + fn all_channels_at_maximum() { + let channels = vec![ + create_channel(scid("1x1x1"), 1_000, 100), + create_channel(scid("2x1x1"), 2_000, 200), + create_channel(scid("3x1x1"), 3_000, 300), + ]; + let result = find_allocation(&channels, 6_000); + assert!(result.is_some()); + let allocations = result.unwrap(); + verify_allocation( + &allocations, + 6_000, + &[scid("1x1x1"), scid("2x1x1"), scid("3x1x1")], + ); + assert_eq!(allocations[0].contrib_msat, 1_000); + assert_eq!(allocations[1].contrib_msat, 2_000); + assert_eq!(allocations[2].contrib_msat, 3_000); + } + + #[test] + fn drained_channel_skip() { + let channels = vec![ + create_channel(scid("1x1x1"), 50, 100), // Spendable < min_htlc + create_channel(scid("2x1x1"), 1_000, 100), + ]; + let result = find_allocation(&channels, 500); + assert!(result.is_some()); + let allocations = result.unwrap(); + verify_allocation(&allocations, 500, &[scid("2x1x1")]); + } + + #[test] + fn zero_spendable_skip() { + // We have some channels with 0 spendable_msat and we do not want to + // allocate them with 0 amount HTLCs. + let channels = [ + create_channel(scid("1x1x1"), 0, 0), + create_channel(scid("2x1x1"), 5_000, 0), + ]; + let target_msat = 5_000; + let allocations = + find_allocation(&channels, target_msat).expect("Should be able to allocate"); + verify_allocation(&allocations, 5_000, &[scid("2x1x1")]); + } + + #[test] + fn respects_channel_order() { + // Same channels but different order should produce different allocations + let channels1 = vec![ + create_channel(scid("1x1x1"), 1_000, 100), + create_channel(scid("2x1x1"), 2_000, 200), + ]; + let channels2 = vec![ + create_channel(scid("2x1x1"), 2_000, 200), + create_channel(scid("1x1x1"), 1_000, 100), + ]; + + let result1 = find_allocation(&channels1, 1_500); + let result2 = find_allocation(&channels2, 1_500); + + assert!(result1.is_some()); + assert!(result2.is_some()); + + let alloc1 = result1.unwrap(); + let alloc2 = result2.unwrap(); + + // The allocations should be different because channel order is different + assert_eq!(alloc1[0].channel.short_channel_id, scid("1x1x1")); + assert_eq!(alloc2[0].channel.short_channel_id, scid("2x1x1")); + } + + #[test] + fn test_ascending_order() { + let mut channels = vec![ + create_channel(scid("2x1x1"), 2_000, 200), + create_channel(scid("1x1x1"), 1_000, 100), + ]; + + let result = find_allocation_ascending_order(&mut channels, 1_500); + assert!(result.is_some()); + let allocations = result.unwrap(); + + // Should use the smallest channel first + assert_eq!(allocations[0].channel.short_channel_id, scid("1x1x1")); + } + + #[test] + fn test_descending_order() { + let mut channels = vec![ + create_channel(scid("1x1x1"), 1_000, 100), + create_channel(scid("2x1x1"), 2_000, 200), ]; - let amount_msat = 150000; - let choosen = pick_channels(amount_msat, channels).unwrap(); + let result = find_allocation_descending_order(&mut channels, 1_500); + assert!(result.is_some()); + let allocations = result.unwrap(); - assert_eq!(choosen.len(), 2); - assert_eq!(choosen[0].0, scid4); - assert_eq!(choosen[0].1, 55000); - assert_eq!(choosen[1].0, scid1); - assert_eq!(choosen[1].1, 95000); + // Should use the largest channel first + assert_eq!(allocations[0].channel.short_channel_id, scid("2x1x1")); } }