From 9180eddd560eec60763c6b23b6143dde20e1f5a7 Mon Sep 17 00:00:00 2001 From: Armin Sabouri Date: Tue, 26 Aug 2025 13:36:20 -0400 Subject: [PATCH] Add recv pk as session metadata for pj sender This allows the refrence to find a resumable session without having to replay each session. And removes a misuse of pj endpoint when trying to filter resumable sessions. --- payjoin-cli/src/app/v2/mod.rs | 25 +++++++++++++++---------- payjoin-cli/src/db/mod.rs | 1 + payjoin-cli/src/db/v2.rs | 18 +++++++++++++++--- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/payjoin-cli/src/app/v2/mod.rs b/payjoin-cli/src/app/v2/mod.rs index b679c9493..4337bf84b 100644 --- a/payjoin-cli/src/app/v2/mod.rs +++ b/payjoin-cli/src/app/v2/mod.rs @@ -104,26 +104,31 @@ impl AppTrait for App { Ok(()) } PjParam::V2(pj_param) => { - // TODO: perhaps we should store pj uri in the session wrapper as to not replay the event log for each session + let receiver_pubkey = pj_param.receiver_pubkey(); let sender_state = self.db.get_send_session_ids()?.into_iter().find_map(|session_id| { - let sender_persister = - SenderPersister::from_id(self.db.clone(), session_id).ok()?; - let (send_session, session_history) = - replay_sender_event_log(&sender_persister) + let session_receiver_pubkey = self + .db + .get_send_session_receiver_pk(&session_id) + .expect("Receiver pubkey should exist if session id exists"); + if session_receiver_pubkey == *receiver_pubkey { + let sender_persister = + SenderPersister::from_id(self.db.clone(), session_id).ok()?; + let (send_session, _) = replay_sender_event_log(&sender_persister) .map_err(|e| anyhow!("Failed to replay sender event log: {:?}", e)) .ok()?; - let pj_uri = session_history.pj_param().map(|pj_param| pj_param.endpoint()); - let sender_state = - pj_uri.filter(|uri| uri == &pj_param.endpoint()).map(|_| send_session); - sender_state.map(|sender_state| (sender_state, sender_persister)) + Some((send_session, sender_persister)) + } else { + None + } }); let (sender_state, persister) = match sender_state { Some((sender_state, persister)) => (sender_state, persister), None => { - let persister = SenderPersister::new(self.db.clone())?; + let persister = + SenderPersister::new(self.db.clone(), receiver_pubkey.clone())?; let psbt = self.create_original_psbt(&address, amount, fee_rate)?; let sender = SenderBuilder::from_parts(psbt, pj_param, &address, Some(amount)) diff --git a/payjoin-cli/src/db/mod.rs b/payjoin-cli/src/db/mod.rs index 2931fab5a..093a25b0d 100644 --- a/payjoin-cli/src/db/mod.rs +++ b/payjoin-cli/src/db/mod.rs @@ -37,6 +37,7 @@ impl Database { conn.execute( "CREATE TABLE IF NOT EXISTS send_sessions ( session_id INTEGER PRIMARY KEY AUTOINCREMENT, + receiver_pubkey BLOB NOT NULL, completed_at INTEGER )", [], diff --git a/payjoin-cli/src/db/v2.rs b/payjoin-cli/src/db/v2.rs index 980612a3a..708f3b9c3 100644 --- a/payjoin-cli/src/db/v2.rs +++ b/payjoin-cli/src/db/v2.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use payjoin::persist::SessionPersister; use payjoin::receive::v2::SessionEvent as ReceiverSessionEvent; use payjoin::send::v2::SessionEvent as SenderSessionEvent; +use payjoin::HpkePublicKey; use rusqlite::params; use super::*; @@ -22,13 +23,13 @@ pub(crate) struct SenderPersister { } impl SenderPersister { - pub fn new(db: Arc) -> crate::db::Result { + pub fn new(db: Arc, receiver_pubkey: HpkePublicKey) -> crate::db::Result { let conn = db.get_connection()?; // Create a new session in send_sessions and get its ID let session_id: i64 = conn.query_row( - "INSERT INTO send_sessions (session_id) VALUES (NULL) RETURNING session_id", - [], + "INSERT INTO send_sessions (session_id, receiver_pubkey) VALUES (NULL, ?1) RETURNING session_id", + params![receiver_pubkey.to_compressed_bytes()], |row| row.get(0), )?; @@ -217,4 +218,15 @@ impl Database { Ok(session_ids) } + + pub(crate) fn get_send_session_receiver_pk( + &self, + session_id: &SessionId, + ) -> Result { + let conn = self.get_connection()?; + let mut stmt = + conn.prepare("SELECT receiver_pubkey FROM send_sessions WHERE session_id = ?1")?; + let receiver_pubkey: Vec = stmt.query_row(params![session_id.0], |row| row.get(0))?; + Ok(HpkePublicKey::from_compressed_bytes(&receiver_pubkey).expect("Valid receiver pubkey")) + } }