From e8680dee5a8a41437bbb5cb6fdf681f995c4bc10 Mon Sep 17 00:00:00 2001 From: Johnosezele Date: Wed, 6 Aug 2025 16:31:19 +0100 Subject: [PATCH] Implement sticky max fee rate per payjoin session Make max_fee_rate persistent across payjoin session resumptions and simplify FFI layer architecture by eliminating duplicate fee rate capping logic. - Add max_fee_rate field to SessionContext with serde defaults for backward compatibility - Remove session_max_fee_rate() accessor method (no longer needed) - Simplify FFI layer to convert Option to Option and pass directly to core apply_fee_range() method - Eliminate code duplication between FFI and core layers - Maintain all fee rate capping logic in core library only Fixes #897 --- payjoin-cli/src/app/config.rs | 2 +- payjoin-cli/src/app/rpc.rs | 8 +- payjoin-cli/src/app/v2/mod.rs | 1 + payjoin-cli/src/cli/mod.rs | 2 +- payjoin-cli/src/main.rs | 2 +- .../test/test_payjoin_integration_test.dart | 2 +- .../dart/test/test_payjoin_unit_test.dart | 2 + .../test/test_payjoin_integration_test.py | 6 +- .../python/test/test_payjoin_unit_test.py | 6 +- payjoin-ffi/src/receive/mod.rs | 6 +- payjoin/src/core/receive/v2/mod.rs | 103 ++++++++++++++++++ payjoin/src/core/send/v2/mod.rs | 13 ++- payjoin/tests/integration.rs | 18 ++- 13 files changed, 145 insertions(+), 26 deletions(-) diff --git a/payjoin-cli/src/app/config.rs b/payjoin-cli/src/app/config.rs index 017492962..f2ef2af61 100644 --- a/payjoin-cli/src/app/config.rs +++ b/payjoin-cli/src/app/config.rs @@ -328,7 +328,7 @@ fn handle_subcommands(config: Builder, cli: &Cli) -> Result Ok(config), + Commands::Resume { .. } => Ok(config), } } diff --git a/payjoin-cli/src/app/rpc.rs b/payjoin-cli/src/app/rpc.rs index 3574eb4df..09b353df8 100644 --- a/payjoin-cli/src/app/rpc.rs +++ b/payjoin-cli/src/app/rpc.rs @@ -82,12 +82,12 @@ impl AsyncBitcoinRpc { .basic_auth(&self.username, Some(&self.password)); let response = - request.send().await.with_context(|| format!("RPC '{}': connection failed", method))?; + request.send().await.with_context(|| format!("RPC '{method}': connection failed"))?; let json = response .json::>() .await - .with_context(|| format!("RPC '{}': invalid response", method))?; + .with_context(|| format!("RPC '{method}': invalid response"))?; match json { RpcResponse::Success { result, .. } => Ok(result), @@ -291,7 +291,7 @@ mod tests { .await .expect_err("Should fail due to invalid address"); let error_msg = error.to_string(); - println!("{}", error_msg); + println!("{error_msg}"); assert_rpc_error_format( &error_msg, @@ -321,7 +321,7 @@ mod tests { .await .expect_err("Should fail due to insufficient funds"); let error_msg = error.to_string(); - println!("{}", error_msg); + println!("{error_msg}"); assert_rpc_error_format( &error_msg, diff --git a/payjoin-cli/src/app/v2/mod.rs b/payjoin-cli/src/app/v2/mod.rs index b679c9493..496e08f4c 100644 --- a/payjoin-cli/src/app/v2/mod.rs +++ b/payjoin-cli/src/app/v2/mod.rs @@ -159,6 +159,7 @@ impl AppTrait for App { ohttp_keys, None, Some(amount), + self.config.max_fee_rate, )? .save(&persister)?; println!("Receive session established"); diff --git a/payjoin-cli/src/cli/mod.rs b/payjoin-cli/src/cli/mod.rs index a3db87d46..b25183360 100644 --- a/payjoin-cli/src/cli/mod.rs +++ b/payjoin-cli/src/cli/mod.rs @@ -129,7 +129,7 @@ pub enum Commands { }, /// Resume pending payjoins (BIP77/v2 only) #[cfg(feature = "v2")] - Resume, + Resume {}, } pub fn parse_amount_in_sat(s: &str) -> Result { diff --git a/payjoin-cli/src/main.rs b/payjoin-cli/src/main.rs index 373875aae..ed58071e1 100644 --- a/payjoin-cli/src/main.rs +++ b/payjoin-cli/src/main.rs @@ -64,7 +64,7 @@ async fn main() -> Result<()> { app.receive_payjoin(*amount).await?; } #[cfg(feature = "v2")] - Commands::Resume => { + Commands::Resume { .. } => { app.resume_payjoins().await?; } }; diff --git a/payjoin-ffi/dart/test/test_payjoin_integration_test.dart b/payjoin-ffi/dart/test/test_payjoin_integration_test.dart index 79fcb9b4b..f654dd1f8 100644 --- a/payjoin-ffi/dart/test/test_payjoin_integration_test.dart +++ b/payjoin-ffi/dart/test/test_payjoin_integration_test.dart @@ -135,7 +135,7 @@ payjoin.Initialized create_receiver_context( payjoin.OhttpKeys ohttp_keys, InMemoryReceiverPersister persister) { var receiver = payjoin.UninitializedReceiver() - .createSession(address, directory, ohttp_keys, null, null) + .createSession(address, directory, ohttp_keys, null, null, null) .save(persister); return receiver; } diff --git a/payjoin-ffi/dart/test/test_payjoin_unit_test.dart b/payjoin-ffi/dart/test/test_payjoin_unit_test.dart index c6fe0dad2..9bfabe676 100644 --- a/payjoin-ffi/dart/test/test_payjoin_unit_test.dart +++ b/payjoin-ffi/dart/test/test_payjoin_unit_test.dart @@ -110,6 +110,7 @@ void main() { payjoin.OhttpKeys.fromString( "OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC"), null, + null, null) .save(persister); final result = payjoin.replayReceiverEventLog(persister); @@ -128,6 +129,7 @@ void main() { payjoin.OhttpKeys.fromString( "OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC"), null, + null, null) .save(receiver_persister); var uri = receiver.pjUri(); diff --git a/payjoin-ffi/python/test/test_payjoin_integration_test.py b/payjoin-ffi/python/test/test_payjoin_integration_test.py index 36cec8b1b..dd51f52b5 100644 --- a/payjoin-ffi/python/test/test_payjoin_integration_test.py +++ b/payjoin-ffi/python/test/test_payjoin_integration_test.py @@ -84,7 +84,7 @@ async def process_receiver_proposal(self, receiver: ReceiveSession, recv_persist raise Exception(f"Unknown receiver state: {receiver}") def create_receiver_context(self, receiver_address: bitcoinffi.Address, directory: Url, ohttp_keys: OhttpKeys, recv_persister: InMemoryReceiverSessionEventLog) -> Initialized: - receiver = UninitializedReceiver().create_session(address=receiver_address, directory=directory.as_string(), ohttp_keys=ohttp_keys, expire_after=None, amount=None).save(recv_persister) + receiver = UninitializedReceiver().create_session(address=receiver_address, directory=directory.as_string(), ohttp_keys=ohttp_keys, expire_after=None, amount=None, max_fee_rate_sat_per_vb=10).save(recv_persister) return receiver async def retrieve_receiver_proposal(self, receiver: Initialized, recv_persister: InMemoryReceiverSessionEventLog, ohttp_relay: Url): @@ -124,9 +124,9 @@ async def process_wants_outputs(self, proposal: WantsOutputs, recv_persister: In async def process_wants_inputs(self, proposal: WantsInputs, recv_persister: InMemoryReceiverSessionEventLog): provisional_proposal = proposal.contribute_inputs(get_inputs(self.receiver)).commit_inputs().save(recv_persister) return await self.process_wants_fee_range(provisional_proposal, recv_persister) - + async def process_wants_fee_range(self, proposal: WantsFeeRange, recv_persister: InMemoryReceiverSessionEventLog): - provisional_proposal = proposal.apply_fee_range(1, 10).save(recv_persister) + provisional_proposal = proposal.apply_fee_range(1, None).save(recv_persister) return await self.process_provisional_proposal(provisional_proposal, recv_persister) async def process_provisional_proposal(self, proposal: ProvisionalProposal, recv_persister: InMemoryReceiverSessionEventLog): diff --git a/payjoin-ffi/python/test/test_payjoin_unit_test.py b/payjoin-ffi/python/test/test_payjoin_unit_test.py index 63d5f06f4..ff8b394f0 100644 --- a/payjoin-ffi/python/test/test_payjoin_unit_test.py +++ b/payjoin-ffi/python/test/test_payjoin_unit_test.py @@ -55,7 +55,8 @@ def test_receiver_persistence(self): "https://example.com", payjoin.OhttpKeys.from_string("OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC"), None, - None + None, + None, ).save(persister) result = payjoin.payjoin_ffi.replay_receiver_event_log(persister) self.assertTrue(result.state().is_INITIALIZED()) @@ -85,7 +86,8 @@ def test_sender_persistence(self): "https://example.com", payjoin.OhttpKeys.from_string("OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC"), None, - None + None, + None, ).save(persister) uri = receiver.pj_uri() diff --git a/payjoin-ffi/src/receive/mod.rs b/payjoin-ffi/src/receive/mod.rs index 65563e2fb..821a9b010 100644 --- a/payjoin-ffi/src/receive/mod.rs +++ b/payjoin-ffi/src/receive/mod.rs @@ -266,6 +266,7 @@ impl UninitializedReceiver { ohttp_keys: Arc, expire_after: Option, amount: Option, + max_fee_rate_sat_per_vb: Option, ) -> Result { payjoin::receive::v2::Receiver::create_session( (*address).clone().into(), @@ -273,9 +274,10 @@ impl UninitializedReceiver { (*ohttp_keys).clone().into(), expire_after.map(Duration::from_secs), amount.map(payjoin::bitcoin::Amount::from_sat), + max_fee_rate_sat_per_vb.and_then(FeeRate::from_sat_per_vb), ) - .map(|receiver| InitialReceiveTransition(Arc::new(RwLock::new(Some(receiver))))) - .map_err(IntoUrlError::from) + .map_err(Into::into) + .map(|session| InitialReceiveTransition(Arc::new(RwLock::new(Some(session))))) } } diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 79c5932f7..76e759202 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -73,6 +73,7 @@ pub struct SessionContext { amount: Option, s: HpkeKeyPair, e: Option, + max_fee_rate: FeeRate, } impl SessionContext { @@ -268,8 +269,10 @@ impl Receiver { ohttp_keys: OhttpKeys, expire_after: Option, amount: Option, + max_fee_rate: Option, ) -> Result>, IntoUrlError> { let directory = directory.into_url()?; + let session_context = SessionContext { address, directory, @@ -279,6 +282,7 @@ impl Receiver { s: HpkeKeyPair::gen_keypair(), e: None, amount, + max_fee_rate: max_fee_rate.unwrap_or(FeeRate::BROADCAST_MIN), }; Ok(NextStateTransition::success( SessionEvent::Created(session_context.clone()), @@ -892,6 +896,9 @@ impl Receiver { min_fee_rate: Option, max_effective_fee_rate: Option, ) -> MaybeFatalTransition, ReplyableError> { + let max_effective_fee_rate = + max_effective_fee_rate.or(Some(self.state.session_context.max_fee_rate)); + let inner = match self.state.v1.apply_fee_range(min_fee_rate, max_effective_fee_rate) { Ok(inner) => inner, Err(e) => { @@ -1116,6 +1123,7 @@ pub mod test { ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).expect("valid key config"), ), expiry: SystemTime::now() + Duration::from_secs(60), + max_fee_rate: FeeRate::BROADCAST_MIN, s: HpkeKeyPair::gen_keypair(), e: None, amount: None, @@ -1141,6 +1149,23 @@ pub mod test { } } + fn create_wants_fee_range_with_context(context: SessionContext) -> WantsFeeRange { + let unchecked = v1::test::unchecked_proposal_from_test_vector(); + let wants_outputs = unchecked + .assume_interactive_receiver() + .check_inputs_not_owned(&mut |_| Ok(false)) + .expect("No inputs should be owned") + .check_no_inputs_seen_before(&mut |_| Ok(false)) + .expect("No inputs should be seen before") + .identify_receiver_outputs(&mut |_| Ok(true)) + .expect("Receiver output should be identified"); + + let wants_inputs = wants_outputs.commit_outputs(); + let v1_wants_fee_range = wants_inputs.commit_inputs(); + + WantsFeeRange { v1: v1_wants_fee_range, session_context: context } + } + #[test] fn test_v2_mutable_receiver_state_closures() { let persister = NoopSessionPersister::default(); @@ -1361,6 +1386,7 @@ pub mod test { SHARED_CONTEXT.ohttp_keys.clone(), None, None, + None, ) .expect("constructor on test vector should not fail") .save(&noop_persister) @@ -1375,6 +1401,83 @@ pub mod test { #[test] fn test_v2_pj_uri() { + let context = SHARED_CONTEXT.clone(); + let uri = pj_uri(&context, OutputSubstitution::Disabled); + assert!(!uri.to_string().is_empty()); + } + + #[test] + fn test_session_creation_with_max_fee_rate() { + let custom_fee_rate = FeeRate::from_sat_per_vb_unchecked(5); + let address = Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4") + .expect("valid address") + .assume_checked(); + + let session = Receiver::create_session( + address, + EXAMPLE_URL.clone(), + SHARED_CONTEXT.ohttp_keys.clone(), + None, + None, + Some(custom_fee_rate), + ); + + let noop_persister = NoopSessionPersister::default(); + let session = session + .expect("Session creation should not fail") + .save(&noop_persister) + .expect("Noop persister shouldn't fail"); + + assert_eq!(session.context.max_fee_rate, custom_fee_rate); + } + + #[test] + fn test_apply_fee_range_session_max_overrides_parameter() { + let session_max = FeeRate::from_sat_per_vb_unchecked(5); + let context = SessionContext { max_fee_rate: session_max, ..SHARED_CONTEXT.clone() }; + let receiver = Receiver { state: create_wants_fee_range_with_context(context) }; + + let higher_rate = FeeRate::from_sat_per_vb_unchecked(10); + + let result = receiver + .apply_fee_range(None, Some(higher_rate)) + .save(&NoopSessionPersister::default()) + .expect("Noop persister shouldn't fail"); + + let payjoin_psbt = &result.state.psbt_context.payjoin_psbt; + let payjoin_fee = payjoin_psbt.fee().expect("PSBT should have fee"); + let actual_fee_rate = + payjoin_fee / payjoin_psbt.clone().extract_tx_unchecked_fee_rate().weight(); + + assert!( + actual_fee_rate <= session_max, + "Fee rate {actual_fee_rate} should be capped at session maximum {session_max} even when higher parameter rate {higher_rate} is provided" + ); + } + + #[test] + fn test_apply_fee_range_with_none_uses_session_max() { + let session_max = FeeRate::from_sat_per_vb_unchecked(7); + let context = SessionContext { max_fee_rate: session_max, ..SHARED_CONTEXT.clone() }; + let receiver = Receiver { state: create_wants_fee_range_with_context(context) }; + + let result = receiver + .apply_fee_range(None, None) + .save(&NoopSessionPersister::default()) + .expect("Noop persister shouldn't fail"); + let payjoin_psbt = &result.state.psbt_context.payjoin_psbt; + let payjoin_fee = payjoin_psbt.fee().expect("PSBT should have fee"); + let actual_fee_rate = + payjoin_fee / payjoin_psbt.clone().extract_tx_unchecked_fee_rate().weight(); + + assert!( + actual_fee_rate <= session_max, + "Fee rate {actual_fee_rate} should not exceed session maximum {session_max}" + ); + } + + #[test] + fn test_v2_pj_uri_with_output_substitution() { let uri = Receiver { state: Initialized { context: SHARED_CONTEXT.clone() } }.pj_uri(); assert_ne!(uri.extras.pj_param.endpoint(), EXAMPLE_URL.clone()); assert_eq!(uri.extras.output_substitution, OutputSubstitution::Disabled); diff --git a/payjoin/src/core/send/v2/mod.rs b/payjoin/src/core/send/v2/mod.rs index b3c2a71ba..e9e4b11b3 100644 --- a/payjoin/src/core/send/v2/mod.rs +++ b/payjoin/src/core/send/v2/mod.rs @@ -593,7 +593,7 @@ mod test { } #[test] - fn test_v2_sender_builder() { + fn test_v2_sender_builder() -> Result<(), BoxError> { let address = Address::from_str("2N47mmrWXsNBvQR6k78hWJoTji57zXwNcU7") .expect("valid address") .assume_checked(); @@ -601,11 +601,11 @@ mod test { let ohttp_keys = OhttpKeys( ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).expect("valid key config"), ); - let pj_uri = Receiver::create_session(address.clone(), directory, ohttp_keys, None, None) - .expect("constructor on test vector should not fail") - .save(&NoopSessionPersister::default()) - .expect("receiver should succeed") - .pj_uri(); + let pj_uri = + Receiver::create_session(address.clone(), directory, ohttp_keys, None, None, None)? + .save(&NoopSessionPersister::default()) + .expect("receiver should succeed") + .pj_uri(); let req_ctx = SenderBuilder::new(PARSED_ORIGINAL_PSBT.clone(), pj_uri.clone()) .build_recommended(FeeRate::BROADCAST_MIN) .expect("build on test vector should succeed") @@ -641,5 +641,6 @@ mod test { .save(&NoopSessionPersister::default()) .expect("sender should succeed"); assert_eq!(req_ctx.state.psbt_ctx.output_substitution, OutputSubstitution::Disabled); + Ok(()) } } diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index ecc36fd50..10d01f94c 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -209,9 +209,15 @@ mod integration { let mock_address = Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4")? .assume_checked(); let noop_persister = NoopSessionPersister::default(); - let mut bad_initializer = - Receiver::create_session(mock_address, directory, bad_ohttp_keys, None, None)? - .save(&noop_persister)?; + let mut bad_initializer = Receiver::create_session( + mock_address, + directory, + bad_ohttp_keys, + None, + None, + None, + )? + .save(&noop_persister)?; let (req, _ctx) = bad_initializer.create_poll_request(&ohttp_relay)?; agent .post(req.url) @@ -255,6 +261,7 @@ mod integration { ohttp_keys, Some(Duration::from_secs(0)), None, + None, )? .save(&recv_noop_persister)?; match expired_receiver.create_poll_request(&ohttp_relay) { @@ -307,7 +314,7 @@ mod integration { let address = receiver.get_new_address(None, None)?.assume_checked(); let mut session = - Receiver::create_session(address, directory, ohttp_keys, None, None)? + Receiver::create_session(address, directory, ohttp_keys, None, None, None)? .save(&persister)?; println!("session: {:#?}", &session); // Poll receive request @@ -425,7 +432,7 @@ mod integration { // test session with expiry in the future let mut session = - Receiver::create_session(address, directory, ohttp_keys, None, None)? + Receiver::create_session(address, directory, ohttp_keys, None, None, None)? .save(&recv_persister)?; println!("session: {:#?}", &session); // Poll receive request @@ -610,6 +617,7 @@ mod integration { ohttp_keys.clone(), None, None, + None, )? .save(&recv_persister)?;