From 9af7eb318bb0e6671bf554e7cea36a1792b7d92f Mon Sep 17 00:00:00 2001 From: spacebear Date: Fri, 19 Dec 2025 12:27:34 -0500 Subject: [PATCH 1/6] Refactor InternalPersistedError Conceptually, Storage errors are distinct from API errors. Storage errors are never persisted and reflect an application/implementation error. This commit splits API errors into their own enum to better reflect this distinction. --- payjoin/src/core/persist.rs | 127 ++++++++++++++++++++++++------------ 1 file changed, 85 insertions(+), 42 deletions(-) diff --git a/payjoin/src/core/persist.rs b/payjoin/src/core/persist.rs index 2fa987e77..66d641d9c 100644 --- a/payjoin/src/core/persist.rs +++ b/payjoin/src/core/persist.rs @@ -313,9 +313,9 @@ where pub fn api_error(self) -> Option { match self.0 { - InternalPersistedError::Fatal(e) - | InternalPersistedError::Transient(e) - | InternalPersistedError::FatalWithState(e, _) => Some(e), + InternalPersistedError::Api( + ApiError::Fatal(e) | ApiError::Transient(e) | ApiError::FatalWithState(e, _), + ) => Some(e), _ => None, } } @@ -329,16 +329,16 @@ where pub fn api_error_ref(&self) -> Option<&ApiErr> { match &self.0 { - InternalPersistedError::Fatal(e) - | InternalPersistedError::Transient(e) - | InternalPersistedError::FatalWithState(e, _) => Some(e), + InternalPersistedError::Api( + ApiError::Fatal(e) | ApiError::Transient(e) | ApiError::FatalWithState(e, _), + ) => Some(e), _ => None, } } pub fn error_state(self) -> Option { match self.0 { - InternalPersistedError::FatalWithState(_, state) => Some(state), + InternalPersistedError::Api(ApiError::FatalWithState(_, state)) => Some(state), _ => None, } } @@ -358,37 +358,54 @@ impl - fmt::Display for PersistedError +impl + fmt::Display for PersistedError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.0 { - InternalPersistedError::Transient(err) => write!(f, "Transient error: {err}"), - InternalPersistedError::Fatal(err) | InternalPersistedError::FatalWithState(err, _) => - write!(f, "Fatal error: {err}"), + InternalPersistedError::Api(ApiError::Transient(err)) => + write!(f, "Transient error: {err}"), + InternalPersistedError::Api( + ApiError::Fatal(err) | ApiError::FatalWithState(err, _), + ) => write!(f, "Fatal error: {err}"), InternalPersistedError::Storage(err) => write!(f, "Storage error: {err}"), } } } #[derive(Debug)] -pub(crate) enum InternalPersistedError +pub(crate) enum ApiError { + /// Error indicating that the session should be retried from the same state + Transient(Err), + /// Error indicating that the session is terminally closed + Fatal(Err), + /// Fatal error that results in a state transition to ErrorState + FatalWithState(Err, ErrorState), +} + +#[derive(Debug)] +pub(crate) enum InternalPersistedError where - InternalApiError: std::error::Error, + ApiErr: std::error::Error, StorageErr: std::error::Error, ErrorState: fmt::Debug, { - /// Error indicating that the session should be retried from the same state - Transient(InternalApiError), - /// Error indicating that the session is terminally closed - Fatal(InternalApiError), - /// Fatal error that results in a state transition to ErrorState - FatalWithState(InternalApiError, ErrorState), - /// Error indicating that application failed to save the session event. This should be treated as a transient error - /// but is represented as a separate error because this error is propagated from the application's storage layer + /// Error indicating that the session failed to progress to the next success state. + Api(ApiError), + /// Error indicating that application failed to save the session event. Storage(StorageErr), } +impl From> + for InternalPersistedError +where + Err: std::error::Error, + StorageErr: std::error::Error, + ErrorState: fmt::Debug, +{ + fn from(api: ApiError) -> Self { InternalPersistedError::Api(api) } +} + /// Represents a state transition that either progresses to a new state or maintains the current state #[derive(Debug, PartialEq)] pub enum OptionalTransitionOutcome { @@ -446,7 +463,7 @@ trait InternalSessionPersister: SessionPersister { MaybeFatalOrSuccessTransition::Fatal(reject_fatal) => Err(self.handle_fatal_reject(reject_fatal).into()), MaybeFatalOrSuccessTransition::Transient(RejectTransient(err)) => - Err(InternalPersistedError::Transient(err).into()), + Err(InternalPersistedError::Api(ApiError::Transient(err)).into()), } } @@ -475,7 +492,7 @@ trait InternalSessionPersister: SessionPersister { Ok(success_value) } Err(Rejection::Transient(RejectTransient(err))) => - Err(InternalPersistedError::Transient(err).into()), + Err(InternalPersistedError::Api(ApiError::Transient(err)).into()), Err(Rejection::Fatal(reject_fatal)) => Err(self.handle_fatal_reject(reject_fatal).into()), Err(Rejection::ReplyableError(reject_replyable_error)) => @@ -514,7 +531,7 @@ trait InternalSessionPersister: SessionPersister { Err(Rejection::Fatal(reject_fatal)) => Err(self.handle_fatal_reject(reject_fatal).into()), Err(Rejection::Transient(RejectTransient(err))) => - Err(InternalPersistedError::Transient(err).into()), + Err(InternalPersistedError::Api(ApiError::Transient(err)).into()), Err(Rejection::ReplyableError(reject_replyable_error)) => Err(self.handle_replyable_error_reject(reject_replyable_error).into()), } @@ -549,7 +566,7 @@ trait InternalSessionPersister: SessionPersister { Err(Rejection::Fatal(reject_fatal)) => Err(self.handle_fatal_reject(reject_fatal).into()), Err(Rejection::Transient(RejectTransient(err))) => - Err(InternalPersistedError::Transient(err).into()), + Err(InternalPersistedError::Api(ApiError::Transient(err)).into()), Err(Rejection::ReplyableError(reject_replyable_error)) => Err(self.handle_replyable_error_reject(reject_replyable_error).into()), } @@ -568,7 +585,8 @@ trait InternalSessionPersister: SessionPersister { self.save_event(event).map_err(InternalPersistedError::Storage)?; Ok(next_state) } - Err(RejectTransient(err)) => Err(InternalPersistedError::Transient(err).into()), + Err(RejectTransient(err)) => + Err(InternalPersistedError::Api(ApiError::Transient(err)).into()), } } @@ -592,7 +610,7 @@ trait InternalSessionPersister: SessionPersister { Err(self.handle_fatal_reject(reject_fatal).into()), Rejection::Transient(RejectTransient(err)) => { // No event to store for transient errors - Err(InternalPersistedError::Transient(err).into()) + Err(InternalPersistedError::Api(ApiError::Transient(err)).into()) } Rejection::ReplyableError(reject_replyable_error) => Err(self.handle_replyable_error_reject(reject_replyable_error).into()), @@ -618,7 +636,7 @@ trait InternalSessionPersister: SessionPersister { return InternalPersistedError::Storage(e); } - InternalPersistedError::Fatal(error) + InternalPersistedError::Api(ApiError::Fatal(error)) } fn handle_replyable_error_reject( @@ -634,7 +652,7 @@ trait InternalSessionPersister: SessionPersister { return InternalPersistedError::Storage(e); } // For replyable errors, don't close the session - keep it open for error response - InternalPersistedError::FatalWithState(error, error_state) + InternalPersistedError::Api(ApiError::FatalWithState(error, error_state)) } } @@ -854,7 +872,10 @@ mod tests { expected_result: ExpectedResult { events: vec![], is_closed: false, - error: Some(InternalPersistedError::Transient(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) + .into(), + ), success: None, }, test: Box::new(move |persister| { @@ -917,7 +938,10 @@ mod tests { expected_result: ExpectedResult { events: vec![], is_closed: false, - error: Some(InternalPersistedError::Transient(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) + .into(), + ), success: None, }, test: Box::new(move |persister| { @@ -929,7 +953,9 @@ mod tests { expected_result: ExpectedResult { events: vec![InMemoryTestEvent("error event".to_string())], is_closed: true, - error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), + ), success: None, }, test: Box::new(move |persister| { @@ -976,7 +1002,10 @@ mod tests { expected_result: ExpectedResult { events: vec![], is_closed: false, - error: Some(InternalPersistedError::Transient(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) + .into(), + ), success: None, }, test: Box::new(move |persister| { @@ -988,7 +1017,9 @@ mod tests { expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, - error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), + ), success: None, }, test: Box::new(move |persister| { @@ -1050,7 +1081,10 @@ mod tests { expected_result: ExpectedResult { events: vec![], is_closed: false, - error: Some(InternalPersistedError::Transient(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) + .into(), + ), success: None, }, test: Box::new(move |persister| { @@ -1063,7 +1097,9 @@ mod tests { expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, - error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), + ), success: None, }, test: Box::new(move |persister| { @@ -1125,7 +1161,9 @@ mod tests { expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, - error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), + ), success: None, }, test: Box::new(move |persister| { @@ -1184,7 +1222,9 @@ mod tests { expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, - error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), + ), success: None, }, test: Box::new(move |persister| { @@ -1197,7 +1237,10 @@ mod tests { expected_result: ExpectedResult { events: vec![], is_closed: false, - error: Some(InternalPersistedError::Transient(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) + .into(), + ), success: None, }, test: Box::new(move |persister| { @@ -1225,13 +1268,13 @@ mod tests { // Test Internal API error cases let fatal_error = PersistedError::( - InternalPersistedError::Fatal(api_err.clone()), + InternalPersistedError::Api(ApiError::Fatal(api_err.clone())), ); assert!(fatal_error.storage_error_ref().is_none()); assert!(fatal_error.api_error_ref().is_some()); let transient_error = PersistedError::( - InternalPersistedError::Transient(api_err.clone()), + InternalPersistedError::Api(ApiError::Transient(api_err.clone())), ); assert!(transient_error.storage_error_ref().is_none()); assert!(transient_error.api_error_ref().is_some()); From e0aee18f1bc8ee1019f8f0b52ea9079d12f4425e Mon Sep 17 00:00:00 2001 From: spacebear Date: Fri, 19 Dec 2025 14:13:04 -0500 Subject: [PATCH 2/6] Refactor persistence state transition actions Split `save()` into distinct "deconstruction" and "execution" steps, and have the `deconstruct()` method live on the Transition structs directly. `deconstruct()` returns a `PersistAction` which tells the persister what action to take (do nothing, save an event, or save an event and close the session). --- payjoin/src/core/persist.rs | 378 ++++++++++++++---------------------- 1 file changed, 151 insertions(+), 227 deletions(-) diff --git a/payjoin/src/core/persist.rs b/payjoin/src/core/persist.rs index 66d641d9c..3dd7af4e5 100644 --- a/payjoin/src/core/persist.rs +++ b/payjoin/src/core/persist.rs @@ -1,4 +1,32 @@ use std::fmt; + +/// Representation of the actions that the persister should take, if any. +pub(crate) enum PersistActions { + /// Do nothing. + NoOp, + /// Save an event. + Save(Event), + /// Save an event and close the session. + SaveAndClose(Event), +} + +impl PersistActions { + pub fn execute

(self, persister: &P) -> Result<(), P::InternalStorageError> + where + P: SessionPersister, + { + match self { + Self::NoOp => {} + Self::Save(event) => persister.save_event(event)?, + Self::SaveAndClose(event) => { + persister.save_event(event)?; + persister.close()?; + } + } + Ok(()) + } +} + /// Handles cases where the transition either succeeds with a final result that ends the session, or hits a static condition and stays in the same state. /// State transition may also be a fatal error or transient error. pub struct MaybeSuccessTransitionWithNoResults( @@ -27,6 +55,29 @@ impl )))) } + #[allow(clippy::type_complexity)] + pub(crate) fn deconstruct( + self, + ) -> ( + PersistActions, + Result, ApiError>, + ) { + match self.0 { + Ok(AcceptOptionalTransition::Success(AcceptNextState(event, success_value))) => ( + PersistActions::SaveAndClose(event), + Ok(OptionalTransitionOutcome::Progress(success_value)), + ), + Ok(AcceptOptionalTransition::NoResults(current_state)) => + (PersistActions::NoOp, Ok(OptionalTransitionOutcome::Stasis(current_state))), + Err(Rejection::Fatal(RejectFatal(event, error))) => + (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), + Err(Rejection::Transient(RejectTransient(error))) => + (PersistActions::NoOp, Err(ApiError::Transient(error))), + Err(Rejection::ReplyableError(RejectReplyableError(event, _, error))) => + (PersistActions::Save(event), Err(ApiError::Fatal(error))), + } + } + pub fn save

( self, persister: &P, @@ -38,7 +89,9 @@ impl P: SessionPersister, Err: std::error::Error, { - persister.save_maybe_no_results_success_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } /// A transition that can result in a state transition, fatal error, or successfully have no results. @@ -67,6 +120,27 @@ impl )))) } + #[allow(clippy::type_complexity)] + pub(crate) fn deconstruct( + self, + ) -> ( + PersistActions, + Result, ApiError>, + ) { + match self.0 { + Ok(AcceptOptionalTransition::Success(AcceptNextState(event, next_state))) => + (PersistActions::Save(event), Ok(OptionalTransitionOutcome::Progress(next_state))), + Ok(AcceptOptionalTransition::NoResults(current_state)) => + (PersistActions::NoOp, Ok(OptionalTransitionOutcome::Stasis(current_state))), + Err(Rejection::Fatal(RejectFatal(event, error))) => + (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), + Err(Rejection::Transient(RejectTransient(error))) => + (PersistActions::NoOp, Err(ApiError::Transient(error))), + Err(Rejection::ReplyableError(RejectReplyableError(event, _, error))) => + (PersistActions::Save(event), Err(ApiError::Fatal(error))), + } + } + pub fn save

( self, persister: &P, @@ -78,7 +152,9 @@ impl P: SessionPersister, Err: std::error::Error, { - persister.save_maybe_no_results_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } @@ -107,6 +183,20 @@ where MaybeFatalTransition(Err(Rejection::replyable_error(event, error_state, error))) } + pub(crate) fn deconstruct( + self, + ) -> (PersistActions, Result>) { + match self.0 { + Ok(AcceptNextState(event, next_state)) => (PersistActions::Save(event), Ok(next_state)), + Err(Rejection::Fatal(RejectFatal(event, error))) => + (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), + Err(Rejection::Transient(RejectTransient(error))) => + (PersistActions::NoOp, Err(ApiError::Transient(error))), + Err(Rejection::ReplyableError(RejectReplyableError(event, error_state, error))) => + (PersistActions::Save(event), Err(ApiError::FatalWithState(error, error_state))), + } + } + pub fn save

( self, persister: &P, @@ -115,7 +205,9 @@ where P: SessionPersister, Err: std::error::Error, { - persister.save_maybe_fatal_error_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } @@ -134,6 +226,13 @@ impl MaybeTransientTransition { MaybeTransientTransition(Err(RejectTransient(error))) } + pub(crate) fn deconstruct(self) -> (PersistActions, Result>) { + match self.0 { + Ok(AcceptNextState(event, next_state)) => (PersistActions::Save(event), Ok(next_state)), + Err(RejectTransient(error)) => (PersistActions::NoOp, Err(ApiError::Transient(error))), + } + } + pub fn save

( self, persister: &P, @@ -142,7 +241,9 @@ impl MaybeTransientTransition { P: SessionPersister, Err: std::error::Error, { - persister.save_maybe_transient_error_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } @@ -168,6 +269,21 @@ where MaybeSuccessTransition(Err(Rejection::fatal(event, error))) } + pub(crate) fn deconstruct( + self, + ) -> (PersistActions, Result>) { + match self.0 { + Ok(AcceptNextState(event, success_value)) => + (PersistActions::SaveAndClose(event), Ok(success_value)), + Err(Rejection::Transient(RejectTransient(error))) => + (PersistActions::NoOp, Err(ApiError::Transient(error))), + Err(Rejection::Fatal(RejectFatal(event, error))) => + (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), + Err(Rejection::ReplyableError(RejectReplyableError(event, _, error))) => + (PersistActions::Save(event), Err(ApiError::Fatal(error))), + } + } + pub fn save

( self, persister: &P, @@ -175,7 +291,9 @@ where where P: SessionPersister, { - persister.save_maybe_success_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } @@ -187,11 +305,18 @@ impl NextStateTransition { NextStateTransition(AcceptNextState(event, next_state)) } + pub(crate) fn deconstruct(self) -> (PersistActions, NextState) { + let AcceptNextState(event, next_state) = self.0; + (PersistActions::Save(event), next_state) + } + pub fn save

(self, persister: &P) -> Result where P: SessionPersister, { - persister.save_progression_transition(self) + let (actions, next_state) = self.deconstruct(); + actions.execute(persister)?; + Ok(next_state) } } @@ -223,6 +348,23 @@ where MaybeFatalOrSuccessTransition::NoResults(current_state) } + #[allow(clippy::type_complexity)] + pub(crate) fn deconstruct( + self, + ) -> (PersistActions, Result, ApiError>) + { + match self { + MaybeFatalOrSuccessTransition::Success(event) => + (PersistActions::SaveAndClose(event), Ok(OptionalTransitionOutcome::Progress(()))), + MaybeFatalOrSuccessTransition::NoResults(current_state) => + (PersistActions::NoOp, Ok(OptionalTransitionOutcome::Stasis(current_state))), + MaybeFatalOrSuccessTransition::Transient(RejectTransient(error)) => + (PersistActions::NoOp, Err(ApiError::Transient(error))), + MaybeFatalOrSuccessTransition::Fatal(RejectFatal(event, error)) => + (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), + } + } + pub fn save

( self, persister: &P, @@ -234,7 +376,9 @@ where P: SessionPersister, Err: std::error::Error, { - persister.save_maybe_fatal_or_success_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } @@ -438,226 +582,6 @@ pub trait SessionPersister { fn close(&self) -> Result<(), Self::InternalStorageError>; } -/// Internal logic for processing specific state transitions. Each method is strongly typed to the state transition type. -/// Methods are not meant to be called directly, but are invoked through a state transition object's `save` method. -trait InternalSessionPersister: SessionPersister { - fn save_maybe_fatal_or_success_transition( - &self, - state_transition: MaybeFatalOrSuccessTransition, - ) -> Result< - OptionalTransitionOutcome<(), CurrentState>, - PersistedError, - > - where - Err: std::error::Error, - { - match state_transition { - MaybeFatalOrSuccessTransition::Success(event) => { - // Success value here would be the something to save - self.save_event(event).map_err(InternalPersistedError::Storage)?; - self.close().map_err(InternalPersistedError::Storage)?; - Ok(OptionalTransitionOutcome::Progress(())) - } - MaybeFatalOrSuccessTransition::NoResults(current_state) => - Ok(OptionalTransitionOutcome::Stasis(current_state)), - MaybeFatalOrSuccessTransition::Fatal(reject_fatal) => - Err(self.handle_fatal_reject(reject_fatal).into()), - MaybeFatalOrSuccessTransition::Transient(RejectTransient(err)) => - Err(InternalPersistedError::Api(ApiError::Transient(err)).into()), - } - } - - /// Save state transition where state transition does not return an error - /// Only returns an error if the storage fails - fn save_progression_transition( - &self, - state_transition: NextStateTransition, - ) -> Result { - self.save_event(state_transition.0 .0)?; - Ok(state_transition.0 .1) - } - - /// Save a transition that can be a state transition or a transient error - fn save_maybe_success_transition( - &self, - state_transition: MaybeSuccessTransition, - ) -> Result> - where - Err: std::error::Error, - { - match state_transition.0 { - Ok(AcceptNextState(event, success_value)) => { - self.save_event(event).map_err(InternalPersistedError::Storage)?; - self.close().map_err(InternalPersistedError::Storage)?; - Ok(success_value) - } - Err(Rejection::Transient(RejectTransient(err))) => - Err(InternalPersistedError::Api(ApiError::Transient(err)).into()), - Err(Rejection::Fatal(reject_fatal)) => - Err(self.handle_fatal_reject(reject_fatal).into()), - Err(Rejection::ReplyableError(reject_replyable_error)) => - Err(self.handle_replyable_error_reject(reject_replyable_error).into()), - } - } - - /// Persists the outcome of a state transition that may result in one of the following: - /// - A successful state transition, in which case the success value is returned and the session is closed. - /// - No state change (stasis), where the current state is retained and nothing is persisted. - /// - A transient error, which does not affect persistent storage and is returned to the caller. - /// - A fatal error, which is persisted and returned to the caller. - fn save_maybe_no_results_success_transition( - &self, - state_transition: MaybeSuccessTransitionWithNoResults< - Self::SessionEvent, - SuccessValue, - CurrentState, - Err, - >, - ) -> Result< - OptionalTransitionOutcome, - PersistedError, - > - where - Err: std::error::Error, - { - match state_transition.0 { - Ok(AcceptOptionalTransition::Success(AcceptNextState(event, success_value))) => { - self.save_event(event).map_err(InternalPersistedError::Storage)?; - self.close().map_err(InternalPersistedError::Storage)?; - Ok(OptionalTransitionOutcome::Progress(success_value)) - } - Ok(AcceptOptionalTransition::NoResults(current_state)) => - Ok(OptionalTransitionOutcome::Stasis(current_state)), - Err(Rejection::Fatal(reject_fatal)) => - Err(self.handle_fatal_reject(reject_fatal).into()), - Err(Rejection::Transient(RejectTransient(err))) => - Err(InternalPersistedError::Api(ApiError::Transient(err)).into()), - Err(Rejection::ReplyableError(reject_replyable_error)) => - Err(self.handle_replyable_error_reject(reject_replyable_error).into()), - } - } - /// Save a transition that can result in: - /// - A successful state transition - /// - No state change (no results) - /// - A transient error - /// - A fatal error - fn save_maybe_no_results_transition( - &self, - state_transition: MaybeFatalTransitionWithNoResults< - Self::SessionEvent, - NextState, - CurrentState, - Err, - >, - ) -> Result< - OptionalTransitionOutcome, - PersistedError, - > - where - Err: std::error::Error, - { - match state_transition.0 { - Ok(AcceptOptionalTransition::Success(AcceptNextState(event, next_state))) => { - self.save_event(event).map_err(InternalPersistedError::Storage)?; - Ok(OptionalTransitionOutcome::Progress(next_state)) - } - Ok(AcceptOptionalTransition::NoResults(current_state)) => - Ok(OptionalTransitionOutcome::Stasis(current_state)), - Err(Rejection::Fatal(reject_fatal)) => - Err(self.handle_fatal_reject(reject_fatal).into()), - Err(Rejection::Transient(RejectTransient(err))) => - Err(InternalPersistedError::Api(ApiError::Transient(err)).into()), - Err(Rejection::ReplyableError(reject_replyable_error)) => - Err(self.handle_replyable_error_reject(reject_replyable_error).into()), - } - } - - /// Save a transition that can be a transient error or a state transition - fn save_maybe_transient_error_transition( - &self, - state_transition: MaybeTransientTransition, - ) -> Result> - where - Err: std::error::Error, - { - match state_transition.0 { - Ok(AcceptNextState(event, next_state)) => { - self.save_event(event).map_err(InternalPersistedError::Storage)?; - Ok(next_state) - } - Err(RejectTransient(err)) => - Err(InternalPersistedError::Api(ApiError::Transient(err)).into()), - } - } - - /// Save a transition that can be a fatal error, transient error or a state transition - fn save_maybe_fatal_error_transition( - &self, - state_transition: MaybeFatalTransition, - ) -> Result> - where - Err: std::error::Error, - ErrorState: fmt::Debug, - { - match state_transition.0 { - Ok(AcceptNextState(event, next_state)) => { - self.save_event(event).map_err(InternalPersistedError::Storage)?; - Ok(next_state) - } - Err(e) => { - match e { - Rejection::Fatal(reject_fatal) => - Err(self.handle_fatal_reject(reject_fatal).into()), - Rejection::Transient(RejectTransient(err)) => { - // No event to store for transient errors - Err(InternalPersistedError::Api(ApiError::Transient(err)).into()) - } - Rejection::ReplyableError(reject_replyable_error) => - Err(self.handle_replyable_error_reject(reject_replyable_error).into()), - } - } - } - } - - fn handle_fatal_reject( - &self, - reject_fatal: RejectFatal, - ) -> InternalPersistedError - where - Err: std::error::Error, - ErrorState: fmt::Debug, - { - let RejectFatal(event, error) = reject_fatal; - if let Err(e) = self.save_event(event) { - return InternalPersistedError::Storage(e); - } - // Session is in a terminal state, close it - if let Err(e) = self.close() { - return InternalPersistedError::Storage(e); - } - - InternalPersistedError::Api(ApiError::Fatal(error)) - } - - fn handle_replyable_error_reject( - &self, - reject_replyable_error: RejectReplyableError, - ) -> InternalPersistedError - where - Err: std::error::Error, - ErrorState: fmt::Debug, - { - let RejectReplyableError(event, error_state, error) = reject_replyable_error; - if let Err(e) = self.save_event(event) { - return InternalPersistedError::Storage(e); - } - // For replyable errors, don't close the session - keep it open for error response - InternalPersistedError::Api(ApiError::FatalWithState(error, error_state)) - } -} - -impl InternalSessionPersister for T {} - /// A persister that does nothing /// This persister cannot be used to replay a session #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] From fa1b466cdec3e960bbb2b681608ac083aafc3c78 Mon Sep 17 00:00:00 2001 From: spacebear Date: Wed, 17 Dec 2025 11:56:14 -0500 Subject: [PATCH 3/6] Introduce `AsyncSessionPersister` and save_async This gives the implementer the choice of to implement and call an asynchronous persister as an alternative to the existing synchronous one, by exposing the `AsyncSessionPersister` trait and `save_async` method in the API alongside their synchronous counterparts. --- payjoin/src/core/persist.rs | 165 ++++++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) diff --git a/payjoin/src/core/persist.rs b/payjoin/src/core/persist.rs index 3dd7af4e5..d7a7f551d 100644 --- a/payjoin/src/core/persist.rs +++ b/payjoin/src/core/persist.rs @@ -25,6 +25,22 @@ impl PersistActions { } Ok(()) } + + pub async fn execute_async

(self, persister: &P) -> Result<(), P::InternalStorageError> + where + P: AsyncSessionPersister, + Event: Send, + { + match self { + Self::NoOp => {} + Self::Save(event) => persister.save_event(event).await?, + Self::SaveAndClose(event) => { + persister.save_event(event).await?; + persister.close().await?; + } + } + Ok(()) + } } /// Handles cases where the transition either succeeds with a final result that ends the session, or hits a static condition and stays in the same state. @@ -93,7 +109,27 @@ impl actions.execute(persister).map_err(InternalPersistedError::Storage)?; Ok(outcome.map_err(InternalPersistedError::Api)?) } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result< + OptionalTransitionOutcome, + PersistedError, + > + where + P: AsyncSessionPersister, + Err: std::error::Error + Send, + SuccessValue: Send, + CurrentState: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } } + /// A transition that can result in a state transition, fatal error, or successfully have no results. pub struct MaybeFatalTransitionWithNoResults( Result, Rejection>, @@ -156,6 +192,25 @@ impl actions.execute(persister).map_err(InternalPersistedError::Storage)?; Ok(outcome.map_err(InternalPersistedError::Api)?) } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result< + OptionalTransitionOutcome, + PersistedError, + > + where + P: AsyncSessionPersister, + Err: std::error::Error + Send, + NextState: Send, + CurrentState: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } } /// A transition that can be either fatal, transient, or a state transition. @@ -209,6 +264,22 @@ where actions.execute(persister).map_err(InternalPersistedError::Storage)?; Ok(outcome.map_err(InternalPersistedError::Api)?) } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result> + where + P: AsyncSessionPersister, + Err: std::error::Error + Send, + ErrorState: Send, + NextState: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } } /// A transition that can result in a state transition or a transient error. @@ -245,6 +316,21 @@ impl MaybeTransientTransition { actions.execute(persister).map_err(InternalPersistedError::Storage)?; Ok(outcome.map_err(InternalPersistedError::Api)?) } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result> + where + P: AsyncSessionPersister, + Err: std::error::Error + Send, + NextState: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } } /// A transition that can result in the completion of a state machine or a transient error @@ -295,6 +381,21 @@ where actions.execute(persister).map_err(InternalPersistedError::Storage)?; Ok(outcome.map_err(InternalPersistedError::Api)?) } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result> + where + P: AsyncSessionPersister, + Err: Send, + SuccessValue: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } } /// A transition that always results in a state transition. @@ -318,6 +419,17 @@ impl NextStateTransition { actions.execute(persister)?; Ok(next_state) } + + pub async fn save_async

(self, persister: &P) -> Result + where + P: AsyncSessionPersister, + NextState: Send, + Event: Send, + { + let (actions, next_state) = self.deconstruct(); + actions.execute_async(persister).await?; + Ok(next_state) + } } /// A transition that can result in a succession completion, fatal error, or transient error. @@ -380,6 +492,24 @@ where actions.execute(persister).map_err(InternalPersistedError::Storage)?; Ok(outcome.map_err(InternalPersistedError::Api)?) } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result< + OptionalTransitionOutcome<(), CurrentState>, + PersistedError, + > + where + P: AsyncSessionPersister, + Err: std::error::Error + Send, + CurrentState: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } } /// Wrapper that marks the progression of a state machine @@ -582,6 +712,41 @@ pub trait SessionPersister { fn close(&self) -> Result<(), Self::InternalStorageError>; } +/// Async version of [`SessionPersister`] for use in async contexts. +// +// Methods use `impl Future<...> + Send` instead of `async fn` because `async fn` in traits +// doesn't guarantee the returned future is `Send`. This triggers the `async_fn_in_trait` lint. +// https://doc.rust-lang.org/stable/nightly-rustc/rustc_lint/async_fn_in_trait/static.ASYNC_FN_IN_TRAIT.html +pub trait AsyncSessionPersister: Send + Sync { + /// Errors that may arise from implementers storage layer + type InternalStorageError: std::error::Error + Send + Sync + 'static; + /// Session events types that we are persisting + type SessionEvent: Send; + + /// Appends to list of session updates, Receives generic events + fn save_event( + &self, + event: Self::SessionEvent, + ) -> impl std::future::Future> + Send; + + /// Loads all the events from the session in the same order they were saved + fn load( + &self, + ) -> impl std::future::Future< + Output = Result< + Box + Send>, + Self::InternalStorageError, + >, + > + Send; + + /// Marks the session as closed, no more events will be appended. + /// This is invoked when the session is terminated due to a fatal error + /// or when the session is closed due to a success state + fn close( + &self, + ) -> impl std::future::Future> + Send; +} + /// A persister that does nothing /// This persister cannot be used to replay a session #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] From ab07480346de28ae9b366e54c668826249350eff Mon Sep 17 00:00:00 2001 From: spacebear Date: Thu, 18 Dec 2025 23:11:43 -0500 Subject: [PATCH 4/6] Refactor persist.rs unit tests Simplify the `TestCase` struct by making the transition types in the struct directly. This allows more granular control over how the test cases are performed, which enables adding async tests with minimal code duplication in the next commit. --- payjoin/src/core/persist.rs | 381 +++++++++++++++++------------------- 1 file changed, 176 insertions(+), 205 deletions(-) diff --git a/payjoin/src/core/persist.rs b/payjoin/src/core/persist.rs index d7a7f551d..398490840 100644 --- a/payjoin/src/core/persist.rs +++ b/payjoin/src/core/persist.rs @@ -858,12 +858,8 @@ mod tests { } } - struct TestCase { - // Allow type complexity for the test closure - #[allow(clippy::type_complexity)] - test: Box< - dyn Fn(&InMemoryTestPersister) -> Result, - >, + struct TestCase { + make_transition: Box Transition>, expected_result: ExpectedResult, } @@ -878,12 +874,11 @@ mod tests { success: Option, } - fn do_test( + fn verify_sync( persister: &InMemoryTestPersister, - test_case: &TestCase, + result: Result, + expected_result: &ExpectedResult, ) { - let expected_result = &test_case.expected_result; - let res = (test_case.test)(persister); let events = persister.load().expect("Persister should not fail").collect::>(); assert_eq!(events.len(), expected_result.events.len()); for (event, expected_event) in events.iter().zip(expected_result.events.iter()) { @@ -895,7 +890,7 @@ mod tests { expected_result.is_closed ); - match (&res, &expected_result.error) { + match (&result, &expected_result.error) { (Ok(actual), None) => { assert_eq!(Some(actual), expected_result.success.as_ref()); } @@ -907,57 +902,61 @@ mod tests { _ => panic!("Unexpected result state"), } } + macro_rules! run_test_cases { + ($test_cases:expr) => { + for test in &$test_cases { + let persister = InMemoryTestPersister::default(); + let result = (test.make_transition)().save(&persister); + verify_sync(&persister, result, &test.expected_result); + } + }; + } #[test] fn test_initial_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); - let test_cases: Vec> = vec![ - // Success - TestCase { - expected_result: ExpectedResult { - events: vec![event.clone()], - is_closed: false, - error: None, - success: Some(next_state.clone()), - }, - test: Box::new(move |persister| { - NextStateTransition::success(event.clone(), next_state.clone()).save(persister) - }), + + let test_cases = vec![TestCase { + make_transition: Box::new({ + let event = event.clone(); + let next_state = next_state.clone(); + move || NextStateTransition::success(event.clone(), next_state.clone()) + }), + expected_result: ExpectedResult { + events: vec![event.clone()], + is_closed: false, + error: None, + success: Some(next_state.clone()), }, - ]; + }]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } #[test] fn test_maybe_transient_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); - let test_cases: Vec< - TestCase< - InMemoryTestState, - PersistedError, - >, - > = vec![ - // Success + + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + let next_state = next_state.clone(); + move || MaybeTransientTransition::success(event.clone(), next_state.clone()) + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: false, error: None, success: Some(next_state.clone()), }, - test: Box::new(move |persister| { - MaybeTransientTransition::success(event.clone(), next_state.clone()) - .save(persister) - }), }, - // Transient error TestCase { + make_transition: Box::new(|| { + MaybeTransientTransition::transient(InMemoryTestError {}) + }), expected_result: ExpectedResult { events: vec![], is_closed: false, @@ -967,63 +966,56 @@ mod tests { ), success: None, }, - test: Box::new(move |persister| { - MaybeTransientTransition::transient(InMemoryTestError {}).save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } #[test] fn test_next_state_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); - let test_cases: Vec> = vec![ - // Success - TestCase { - expected_result: ExpectedResult { - events: vec![event.clone()], - is_closed: false, - error: None, - success: Some(next_state.clone()), - }, - test: Box::new(move |persister| { - NextStateTransition::success(event.clone(), next_state.clone()).save(persister) - }), + + let test_cases = vec![TestCase { + make_transition: Box::new({ + let event = event.clone(); + let next_state = next_state.clone(); + move || NextStateTransition::success(event.clone(), next_state.clone()) + }), + expected_result: ExpectedResult { + events: vec![event.clone()], + is_closed: false, + error: None, + success: Some(next_state.clone()), }, - ]; + }]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } #[test] fn test_maybe_success_transition() { let event = InMemoryTestEvent("foo".to_string()); - let test_cases: Vec< - TestCase<(), PersistedError>, - > = vec![ - // Success + let error_event = InMemoryTestEvent("error event".to_string()); + + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + move || MaybeSuccessTransition::success(event.clone(), ()) + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: true, error: None, success: Some(()), }, - test: Box::new(move |persister| { - MaybeSuccessTransition::success(event.clone(), ()).save(persister) - }), }, - // Transient error TestCase { + make_transition: Box::new(|| { + MaybeSuccessTransition::transient(InMemoryTestError {}) + }), expected_result: ExpectedResult { events: vec![], is_closed: false, @@ -1033,34 +1025,24 @@ mod tests { ), success: None, }, - test: Box::new(move |persister| { - MaybeSuccessTransition::transient(InMemoryTestError {}).save(persister) - }), }, - // Fatal error TestCase { + make_transition: Box::new({ + let error_event = error_event.clone(); + move || MaybeSuccessTransition::fatal(error_event.clone(), InMemoryTestError {}) + }), expected_result: ExpectedResult { - events: vec![InMemoryTestEvent("error event".to_string())], + events: vec![error_event.clone()], is_closed: true, error: Some( InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), ), success: None, }, - test: Box::new(move |persister| { - MaybeSuccessTransition::fatal( - InMemoryTestEvent("error event".to_string()), - InMemoryTestError {}, - ) - .save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } #[test] @@ -1069,26 +1051,26 @@ mod tests { let error_event = InMemoryTestEvent("error event".to_string()); let next_state = "Next state".to_string(); - let test_cases: Vec< - TestCase< - InMemoryTestState, - PersistedError, - >, - > = vec![ + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + let next_state = next_state.clone(); + move || MaybeFatalTransition::success(event.clone(), next_state.clone()) + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: false, error: None, success: Some(next_state.clone()), }, - test: Box::new(move |persister| { - MaybeFatalTransition::success(event.clone(), next_state.clone()).save(persister) - }), }, - // Transient error TestCase { - expected_result: ExpectedResult { + make_transition: Box::new(|| MaybeFatalTransition::transient(InMemoryTestError {})), + expected_result: ExpectedResult::< + _, + PersistedError, + > { events: vec![], is_closed: false, error: Some( @@ -1097,12 +1079,12 @@ mod tests { ), success: None, }, - test: Box::new(move |persister| { - MaybeFatalTransition::transient(InMemoryTestError {}).save(persister) - }), }, - // Fatal error TestCase { + make_transition: Box::new({ + let error_event = error_event.clone(); + move || MaybeFatalTransition::fatal(error_event.clone(), InMemoryTestError {}) + }), expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, @@ -1111,17 +1093,10 @@ mod tests { ), success: None, }, - test: Box::new(move |persister| { - MaybeFatalTransition::fatal(error_event.clone(), InMemoryTestError {}) - .save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } #[test] @@ -1130,43 +1105,45 @@ mod tests { let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); let success_value = "Success value".to_string(); - let test_cases: Vec< - TestCase< - OptionalTransitionOutcome, - PersistedError, - >, - > = vec![ - // Success + + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + let success_value = success_value.clone(); + move || { + MaybeSuccessTransitionWithNoResults::success( + success_value.clone(), + event.clone(), + ) + } + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: true, error: None, success: Some(OptionalTransitionOutcome::Progress(success_value.clone())), }, - test: Box::new(move |persister| { - MaybeSuccessTransitionWithNoResults::success( - success_value.clone(), - event.clone(), - ) - .save(persister) - }), }, - // No results TestCase { - expected_result: ExpectedResult { + make_transition: Box::new({ + let current_state = current_state.clone(); + move || MaybeSuccessTransitionWithNoResults::no_results(current_state.clone()) + }), + expected_result: ExpectedResult::< + OptionalTransitionOutcome, + PersistedError, + > { events: vec![], is_closed: false, error: None, success: Some(OptionalTransitionOutcome::Stasis(current_state.clone())), }, - test: Box::new(move |persister| { - MaybeSuccessTransitionWithNoResults::no_results(current_state.clone()) - .save(persister) - }), }, - // Transient error TestCase { + make_transition: Box::new(|| { + MaybeSuccessTransitionWithNoResults::transient(InMemoryTestError {}) + }), expected_result: ExpectedResult { events: vec![], is_closed: false, @@ -1176,13 +1153,17 @@ mod tests { ), success: None, }, - test: Box::new(move |persister| { - MaybeSuccessTransitionWithNoResults::transient(InMemoryTestError {}) - .save(persister) - }), }, - // Fatal error TestCase { + make_transition: Box::new({ + let error_event = error_event.clone(); + move || { + MaybeSuccessTransitionWithNoResults::fatal( + error_event.clone(), + InMemoryTestError {}, + ) + } + }), expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, @@ -1191,20 +1172,10 @@ mod tests { ), success: None, }, - test: Box::new(move |persister| { - MaybeSuccessTransitionWithNoResults::fatal( - error_event.clone(), - InMemoryTestError {}, - ) - .save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } #[test] @@ -1213,40 +1184,51 @@ mod tests { let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); let next_state = "Next state".to_string(); - let test_cases: Vec< - TestCase< - OptionalTransitionOutcome, - PersistedError, - >, - > = vec![ - // Success + + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + let next_state = next_state.clone(); + move || { + MaybeFatalTransitionWithNoResults::success( + event.clone(), + next_state.clone(), + ) + } + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: false, error: None, success: Some(OptionalTransitionOutcome::Progress(next_state.clone())), }, - test: Box::new(move |persister| { - MaybeFatalTransitionWithNoResults::success(event.clone(), next_state.clone()) - .save(persister) - }), }, - // No results TestCase { - expected_result: ExpectedResult { + make_transition: Box::new({ + let current_state = current_state.clone(); + move || MaybeFatalTransitionWithNoResults::no_results(current_state.clone()) + }), + expected_result: ExpectedResult::< + OptionalTransitionOutcome, + PersistedError, + > { events: vec![], is_closed: false, error: None, success: Some(OptionalTransitionOutcome::Stasis(current_state.clone())), }, - test: Box::new(move |persister| { - MaybeFatalTransitionWithNoResults::no_results(current_state.clone()) - .save(persister) - }), }, - // Fatal error TestCase { + make_transition: Box::new({ + let error_event = error_event.clone(); + move || { + MaybeFatalTransitionWithNoResults::fatal( + error_event.clone(), + InMemoryTestError {}, + ) + } + }), expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, @@ -1255,20 +1237,10 @@ mod tests { ), success: None, }, - test: Box::new(move |persister| { - MaybeFatalTransitionWithNoResults::fatal( - error_event.clone(), - InMemoryTestError {}, - ) - .save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } #[test] @@ -1276,38 +1248,45 @@ mod tests { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); - let test_cases: Vec< - TestCase< - OptionalTransitionOutcome<(), InMemoryTestState>, - PersistedError, - >, - > = vec![ - // Success + + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + move || MaybeFatalOrSuccessTransition::Success(event.clone()) + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: true, error: None, success: Some(OptionalTransitionOutcome::Progress(())), }, - test: Box::new(move |persister| { - MaybeFatalOrSuccessTransition::Success(event.clone()).save(persister) - }), }, - // No results TestCase { - expected_result: ExpectedResult { + make_transition: Box::new({ + let current_state = current_state.clone(); + move || MaybeFatalOrSuccessTransition::NoResults(current_state.clone()) + }), + expected_result: ExpectedResult::< + OptionalTransitionOutcome<(), InMemoryTestState>, + PersistedError, + > { events: vec![], is_closed: false, error: None, success: Some(OptionalTransitionOutcome::Stasis(current_state.clone())), }, - test: Box::new(move |persister| { - MaybeFatalOrSuccessTransition::NoResults(current_state.clone()).save(persister) - }), }, - // Fatal error TestCase { + make_transition: Box::new({ + let error_event = error_event.clone(); + move || { + MaybeFatalOrSuccessTransition::fatal( + error_event.clone(), + InMemoryTestError {}, + ) + } + }), expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, @@ -1316,13 +1295,11 @@ mod tests { ), success: None, }, - test: Box::new(move |persister| { - MaybeFatalOrSuccessTransition::fatal(error_event.clone(), InMemoryTestError {}) - .save(persister) - }), }, - // Transient error TestCase { + make_transition: Box::new(|| { + MaybeFatalOrSuccessTransition::transient(InMemoryTestError {}) + }), expected_result: ExpectedResult { events: vec![], is_closed: false, @@ -1332,16 +1309,10 @@ mod tests { ), success: None, }, - test: Box::new(move |persister| { - MaybeFatalOrSuccessTransition::transient(InMemoryTestError {}).save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } #[test] From 4b6ce5b87e8a4fa57ddeb042487964a040d72740 Mon Sep 17 00:00:00 2001 From: spacebear Date: Wed, 17 Dec 2025 12:25:35 -0500 Subject: [PATCH 5/6] Add unit tests for AsyncSessionPersister --- payjoin/src/core/persist.rs | 115 ++++++++++++++++++++++++++++++------ 1 file changed, 98 insertions(+), 17 deletions(-) diff --git a/payjoin/src/core/persist.rs b/payjoin/src/core/persist.rs index 398490840..525fc36b4 100644 --- a/payjoin/src/core/persist.rs +++ b/payjoin/src/core/persist.rs @@ -832,6 +832,53 @@ pub mod test_utils { Ok(()) } } + + #[cfg(test)] + #[derive(Clone)] + /// Async in-memory session persister for testing async session replays and introspecting session events + pub struct InMemoryAsyncTestPersister { + pub(crate) inner: Arc>>, + } + + #[cfg(test)] + impl Default for InMemoryAsyncTestPersister { + fn default() -> Self { + Self { inner: Arc::new(tokio::sync::RwLock::new(InnerStorage::default())) } + } + } + + #[cfg(test)] + impl crate::persist::AsyncSessionPersister for InMemoryAsyncTestPersister + where + V: Clone + Send + Sync + 'static, + { + type InternalStorageError = std::convert::Infallible; + type SessionEvent = V; + + async fn save_event( + &self, + event: Self::SessionEvent, + ) -> Result<(), Self::InternalStorageError> { + let mut inner = self.inner.write().await; + Arc::make_mut(&mut inner.events).push(event); + Ok(()) + } + + async fn load( + &self, + ) -> Result + Send>, Self::InternalStorageError> + { + let inner = self.inner.read().await; + let events = Arc::clone(&inner.events); + Ok(Box::new(Arc::try_unwrap(events).unwrap_or_else(|arc| (*arc).clone()).into_iter())) + } + + async fn close(&self) -> Result<(), Self::InternalStorageError> { + let mut inner = self.inner.write().await; + inner.is_closed = true; + Ok(()) + } + } } #[cfg(test)] @@ -839,7 +886,7 @@ mod tests { use serde::{Deserialize, Serialize}; use super::*; - use crate::persist::test_utils::InMemoryTestPersister; + use crate::persist::test_utils::{InMemoryAsyncTestPersister, InMemoryTestPersister}; type InMemoryTestState = String; @@ -902,18 +949,52 @@ mod tests { _ => panic!("Unexpected result state"), } } + + async fn verify_async< + SuccessState: std::fmt::Debug + PartialEq + Send, + ErrorState: std::error::Error + Send, + >( + persister: &InMemoryAsyncTestPersister, + result: Result, + expected_result: &ExpectedResult, + ) { + let events = persister.load().await.expect("Persister should not fail").collect::>(); + assert_eq!(events.len(), expected_result.events.len()); + for (event, expected_event) in events.iter().zip(expected_result.events.iter()) { + assert_eq!(event.0, expected_event.0); + } + + assert_eq!(persister.inner.read().await.is_closed, expected_result.is_closed); + + match (&result, &expected_result.error) { + (Ok(actual), None) => { + assert_eq!(Some(actual), expected_result.success.as_ref()); + } + (Err(actual), Some(exp)) => { + // TODO: replace .to_string() with .eq(). This would introduce a trait bound on the internal API error type + // And not all internal API errors implement PartialEq + assert_eq!(actual.to_string(), exp.to_string()); + } + _ => panic!("Unexpected result state"), + } + } + macro_rules! run_test_cases { ($test_cases:expr) => { for test in &$test_cases { let persister = InMemoryTestPersister::default(); let result = (test.make_transition)().save(&persister); verify_sync(&persister, result, &test.expected_result); + + let persister = InMemoryAsyncTestPersister::default(); + let result = (test.make_transition)().save_async(&persister).await; + verify_async(&persister, result, &test.expected_result).await; } }; } - #[test] - fn test_initial_transition() { + #[tokio::test] + async fn test_initial_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); @@ -934,8 +1015,8 @@ mod tests { run_test_cases!(test_cases); } - #[test] - fn test_maybe_transient_transition() { + #[tokio::test] + async fn test_maybe_transient_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); @@ -972,8 +1053,8 @@ mod tests { run_test_cases!(test_cases); } - #[test] - fn test_next_state_transition() { + #[tokio::test] + async fn test_next_state_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); @@ -994,8 +1075,8 @@ mod tests { run_test_cases!(test_cases); } - #[test] - fn test_maybe_success_transition() { + #[tokio::test] + async fn test_maybe_success_transition() { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); @@ -1045,8 +1126,8 @@ mod tests { run_test_cases!(test_cases); } - #[test] - fn test_maybe_fatal_transition() { + #[tokio::test] + async fn test_maybe_fatal_transition() { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); let next_state = "Next state".to_string(); @@ -1099,8 +1180,8 @@ mod tests { run_test_cases!(test_cases); } - #[test] - fn test_maybe_success_transition_with_no_results() { + #[tokio::test] + async fn test_maybe_success_transition_with_no_results() { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); @@ -1178,8 +1259,8 @@ mod tests { run_test_cases!(test_cases); } - #[test] - fn test_maybe_fatal_transition_with_no_results() { + #[tokio::test] + async fn test_maybe_fatal_transition_with_no_results() { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); @@ -1243,8 +1324,8 @@ mod tests { run_test_cases!(test_cases); } - #[test] - fn test_maybe_fatal_or_success_transition() { + #[tokio::test] + async fn test_maybe_fatal_or_success_transition() { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); From c78ba7e581f0a85affeeb04e304a70fe36916ad7 Mon Sep 17 00:00:00 2001 From: spacebear Date: Mon, 12 Jan 2026 17:48:15 -0500 Subject: [PATCH 6/6] Move Err trait bounds to impl block in persist.rs Specify the error trait bound on the impl block for all methods instead of redundantly specifying it for every method. Also fixes some inconsistencies in how it was previously specified. --- payjoin/src/core/persist.rs | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/payjoin/src/core/persist.rs b/payjoin/src/core/persist.rs index 525fc36b4..bb5083e43 100644 --- a/payjoin/src/core/persist.rs +++ b/payjoin/src/core/persist.rs @@ -51,6 +51,8 @@ pub struct MaybeSuccessTransitionWithNoResults MaybeSuccessTransitionWithNoResults +where + Err: std::error::Error, { pub(crate) fn fatal(event: Event, error: Err) -> Self { MaybeSuccessTransitionWithNoResults(Err(Rejection::fatal(event, error))) @@ -103,7 +105,6 @@ impl > where P: SessionPersister, - Err: std::error::Error, { let (actions, outcome) = self.deconstruct(); actions.execute(persister).map_err(InternalPersistedError::Storage)?; @@ -119,7 +120,7 @@ impl > where P: AsyncSessionPersister, - Err: std::error::Error + Send, + Err: Send, SuccessValue: Send, CurrentState: Send, Event: Send, @@ -137,6 +138,8 @@ pub struct MaybeFatalTransitionWithNoResults MaybeFatalTransitionWithNoResults +where + Err: std::error::Error, { pub(crate) fn fatal(event: Event, error: Err) -> Self { MaybeFatalTransitionWithNoResults(Err(Rejection::fatal(event, error))) @@ -186,7 +189,6 @@ impl > where P: SessionPersister, - Err: std::error::Error, { let (actions, outcome) = self.deconstruct(); actions.execute(persister).map_err(InternalPersistedError::Storage)?; @@ -202,7 +204,7 @@ impl > where P: AsyncSessionPersister, - Err: std::error::Error + Send, + Err: Send, NextState: Send, CurrentState: Send, Event: Send, @@ -220,6 +222,7 @@ pub struct MaybeFatalTransition( impl MaybeFatalTransition where + Err: std::error::Error, ErrorState: fmt::Debug, { pub(crate) fn fatal(event: Event, error: Err) -> Self { @@ -258,7 +261,6 @@ where ) -> Result> where P: SessionPersister, - Err: std::error::Error, { let (actions, outcome) = self.deconstruct(); actions.execute(persister).map_err(InternalPersistedError::Storage)?; @@ -271,7 +273,7 @@ where ) -> Result> where P: AsyncSessionPersister, - Err: std::error::Error + Send, + Err: Send, ErrorState: Send, NextState: Send, Event: Send, @@ -288,7 +290,10 @@ pub struct MaybeTransientTransition( Result, RejectTransient>, ); -impl MaybeTransientTransition { +impl MaybeTransientTransition +where + Err: std::error::Error, +{ pub(crate) fn success(event: Event, next_state: NextState) -> Self { MaybeTransientTransition(Ok(AcceptNextState(event, next_state))) } @@ -310,7 +315,6 @@ impl MaybeTransientTransition { ) -> Result> where P: SessionPersister, - Err: std::error::Error, { let (actions, outcome) = self.deconstruct(); actions.execute(persister).map_err(InternalPersistedError::Storage)?; @@ -323,7 +327,7 @@ impl MaybeTransientTransition { ) -> Result> where P: AsyncSessionPersister, - Err: std::error::Error + Send, + Err: Send, NextState: Send, Event: Send, { @@ -486,7 +490,6 @@ where > where P: SessionPersister, - Err: std::error::Error, { let (actions, outcome) = self.deconstruct(); actions.execute(persister).map_err(InternalPersistedError::Storage)?; @@ -502,7 +505,7 @@ where > where P: AsyncSessionPersister, - Err: std::error::Error + Send, + Err: Send, CurrentState: Send, Event: Send, {