diff --git a/payjoin/src/core/persist.rs b/payjoin/src/core/persist.rs index 2fa987e77..bb5083e43 100644 --- a/payjoin/src/core/persist.rs +++ b/payjoin/src/core/persist.rs @@ -1,4 +1,48 @@ use std::fmt; + +/// Representation of the actions that the persister should take, if any. +pub(crate) enum PersistActions { + /// Do nothing. + NoOp, + /// Save an event. + Save(Event), + /// Save an event and close the session. + SaveAndClose(Event), +} + +impl PersistActions { + pub fn execute

(self, persister: &P) -> Result<(), P::InternalStorageError> + where + P: SessionPersister, + { + match self { + Self::NoOp => {} + Self::Save(event) => persister.save_event(event)?, + Self::SaveAndClose(event) => { + persister.save_event(event)?; + persister.close()?; + } + } + Ok(()) + } + + pub async fn execute_async

(self, persister: &P) -> Result<(), P::InternalStorageError> + where + P: AsyncSessionPersister, + Event: Send, + { + match self { + Self::NoOp => {} + Self::Save(event) => persister.save_event(event).await?, + Self::SaveAndClose(event) => { + persister.save_event(event).await?; + persister.close().await?; + } + } + Ok(()) + } +} + /// Handles cases where the transition either succeeds with a final result that ends the session, or hits a static condition and stays in the same state. /// State transition may also be a fatal error or transient error. pub struct MaybeSuccessTransitionWithNoResults( @@ -7,6 +51,8 @@ pub struct MaybeSuccessTransitionWithNoResults MaybeSuccessTransitionWithNoResults +where + Err: std::error::Error, { pub(crate) fn fatal(event: Event, error: Err) -> Self { MaybeSuccessTransitionWithNoResults(Err(Rejection::fatal(event, error))) @@ -27,6 +73,29 @@ impl )))) } + #[allow(clippy::type_complexity)] + pub(crate) fn deconstruct( + self, + ) -> ( + PersistActions, + Result, ApiError>, + ) { + match self.0 { + Ok(AcceptOptionalTransition::Success(AcceptNextState(event, success_value))) => ( + PersistActions::SaveAndClose(event), + Ok(OptionalTransitionOutcome::Progress(success_value)), + ), + Ok(AcceptOptionalTransition::NoResults(current_state)) => + (PersistActions::NoOp, Ok(OptionalTransitionOutcome::Stasis(current_state))), + Err(Rejection::Fatal(RejectFatal(event, error))) => + (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), + Err(Rejection::Transient(RejectTransient(error))) => + (PersistActions::NoOp, Err(ApiError::Transient(error))), + Err(Rejection::ReplyableError(RejectReplyableError(event, _, error))) => + (PersistActions::Save(event), Err(ApiError::Fatal(error))), + } + } + pub fn save

( self, persister: &P, @@ -36,11 +105,32 @@ impl > where P: SessionPersister, - Err: std::error::Error, { - persister.save_maybe_no_results_success_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result< + OptionalTransitionOutcome, + PersistedError, + > + where + P: AsyncSessionPersister, + Err: Send, + SuccessValue: Send, + CurrentState: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } + /// A transition that can result in a state transition, fatal error, or successfully have no results. pub struct MaybeFatalTransitionWithNoResults( Result, Rejection>, @@ -48,6 +138,8 @@ pub struct MaybeFatalTransitionWithNoResults MaybeFatalTransitionWithNoResults +where + Err: std::error::Error, { pub(crate) fn fatal(event: Event, error: Err) -> Self { MaybeFatalTransitionWithNoResults(Err(Rejection::fatal(event, error))) @@ -67,6 +159,27 @@ impl )))) } + #[allow(clippy::type_complexity)] + pub(crate) fn deconstruct( + self, + ) -> ( + PersistActions, + Result, ApiError>, + ) { + match self.0 { + Ok(AcceptOptionalTransition::Success(AcceptNextState(event, next_state))) => + (PersistActions::Save(event), Ok(OptionalTransitionOutcome::Progress(next_state))), + Ok(AcceptOptionalTransition::NoResults(current_state)) => + (PersistActions::NoOp, Ok(OptionalTransitionOutcome::Stasis(current_state))), + Err(Rejection::Fatal(RejectFatal(event, error))) => + (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), + Err(Rejection::Transient(RejectTransient(error))) => + (PersistActions::NoOp, Err(ApiError::Transient(error))), + Err(Rejection::ReplyableError(RejectReplyableError(event, _, error))) => + (PersistActions::Save(event), Err(ApiError::Fatal(error))), + } + } + pub fn save

( self, persister: &P, @@ -76,9 +189,29 @@ impl > where P: SessionPersister, - Err: std::error::Error, { - persister.save_maybe_no_results_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result< + OptionalTransitionOutcome, + PersistedError, + > + where + P: AsyncSessionPersister, + Err: Send, + NextState: Send, + CurrentState: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } @@ -89,6 +222,7 @@ pub struct MaybeFatalTransition( impl MaybeFatalTransition where + Err: std::error::Error, ErrorState: fmt::Debug, { pub(crate) fn fatal(event: Event, error: Err) -> Self { @@ -107,15 +241,46 @@ where MaybeFatalTransition(Err(Rejection::replyable_error(event, error_state, error))) } + pub(crate) fn deconstruct( + self, + ) -> (PersistActions, Result>) { + match self.0 { + Ok(AcceptNextState(event, next_state)) => (PersistActions::Save(event), Ok(next_state)), + Err(Rejection::Fatal(RejectFatal(event, error))) => + (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), + Err(Rejection::Transient(RejectTransient(error))) => + (PersistActions::NoOp, Err(ApiError::Transient(error))), + Err(Rejection::ReplyableError(RejectReplyableError(event, error_state, error))) => + (PersistActions::Save(event), Err(ApiError::FatalWithState(error, error_state))), + } + } + pub fn save

( self, persister: &P, ) -> Result> where P: SessionPersister, - Err: std::error::Error, { - persister.save_maybe_fatal_error_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result> + where + P: AsyncSessionPersister, + Err: Send, + ErrorState: Send, + NextState: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } @@ -125,7 +290,10 @@ pub struct MaybeTransientTransition( Result, RejectTransient>, ); -impl MaybeTransientTransition { +impl MaybeTransientTransition +where + Err: std::error::Error, +{ pub(crate) fn success(event: Event, next_state: NextState) -> Self { MaybeTransientTransition(Ok(AcceptNextState(event, next_state))) } @@ -134,15 +302,38 @@ impl MaybeTransientTransition { MaybeTransientTransition(Err(RejectTransient(error))) } + pub(crate) fn deconstruct(self) -> (PersistActions, Result>) { + match self.0 { + Ok(AcceptNextState(event, next_state)) => (PersistActions::Save(event), Ok(next_state)), + Err(RejectTransient(error)) => (PersistActions::NoOp, Err(ApiError::Transient(error))), + } + } + pub fn save

( self, persister: &P, ) -> Result> where P: SessionPersister, - Err: std::error::Error, { - persister.save_maybe_transient_error_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result> + where + P: AsyncSessionPersister, + Err: Send, + NextState: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } @@ -168,6 +359,21 @@ where MaybeSuccessTransition(Err(Rejection::fatal(event, error))) } + pub(crate) fn deconstruct( + self, + ) -> (PersistActions, Result>) { + match self.0 { + Ok(AcceptNextState(event, success_value)) => + (PersistActions::SaveAndClose(event), Ok(success_value)), + Err(Rejection::Transient(RejectTransient(error))) => + (PersistActions::NoOp, Err(ApiError::Transient(error))), + Err(Rejection::Fatal(RejectFatal(event, error))) => + (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), + Err(Rejection::ReplyableError(RejectReplyableError(event, _, error))) => + (PersistActions::Save(event), Err(ApiError::Fatal(error))), + } + } + pub fn save

( self, persister: &P, @@ -175,7 +381,24 @@ where where P: SessionPersister, { - persister.save_maybe_success_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result> + where + P: AsyncSessionPersister, + Err: Send, + SuccessValue: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } @@ -187,11 +410,29 @@ impl NextStateTransition { NextStateTransition(AcceptNextState(event, next_state)) } + pub(crate) fn deconstruct(self) -> (PersistActions, NextState) { + let AcceptNextState(event, next_state) = self.0; + (PersistActions::Save(event), next_state) + } + pub fn save

(self, persister: &P) -> Result where P: SessionPersister, { - persister.save_progression_transition(self) + let (actions, next_state) = self.deconstruct(); + actions.execute(persister)?; + Ok(next_state) + } + + pub async fn save_async

(self, persister: &P) -> Result + where + P: AsyncSessionPersister, + NextState: Send, + Event: Send, + { + let (actions, next_state) = self.deconstruct(); + actions.execute_async(persister).await?; + Ok(next_state) } } @@ -223,6 +464,23 @@ where MaybeFatalOrSuccessTransition::NoResults(current_state) } + #[allow(clippy::type_complexity)] + pub(crate) fn deconstruct( + self, + ) -> (PersistActions, Result, ApiError>) + { + match self { + MaybeFatalOrSuccessTransition::Success(event) => + (PersistActions::SaveAndClose(event), Ok(OptionalTransitionOutcome::Progress(()))), + MaybeFatalOrSuccessTransition::NoResults(current_state) => + (PersistActions::NoOp, Ok(OptionalTransitionOutcome::Stasis(current_state))), + MaybeFatalOrSuccessTransition::Transient(RejectTransient(error)) => + (PersistActions::NoOp, Err(ApiError::Transient(error))), + MaybeFatalOrSuccessTransition::Fatal(RejectFatal(event, error)) => + (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), + } + } + pub fn save

( self, persister: &P, @@ -232,9 +490,28 @@ where > where P: SessionPersister, - Err: std::error::Error, { - persister.save_maybe_fatal_or_success_transition(self) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) + } + + pub async fn save_async

( + self, + persister: &P, + ) -> Result< + OptionalTransitionOutcome<(), CurrentState>, + PersistedError, + > + where + P: AsyncSessionPersister, + Err: Send, + CurrentState: Send, + Event: Send, + { + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } @@ -313,9 +590,9 @@ where pub fn api_error(self) -> Option { match self.0 { - InternalPersistedError::Fatal(e) - | InternalPersistedError::Transient(e) - | InternalPersistedError::FatalWithState(e, _) => Some(e), + InternalPersistedError::Api( + ApiError::Fatal(e) | ApiError::Transient(e) | ApiError::FatalWithState(e, _), + ) => Some(e), _ => None, } } @@ -329,16 +606,16 @@ where pub fn api_error_ref(&self) -> Option<&ApiErr> { match &self.0 { - InternalPersistedError::Fatal(e) - | InternalPersistedError::Transient(e) - | InternalPersistedError::FatalWithState(e, _) => Some(e), + InternalPersistedError::Api( + ApiError::Fatal(e) | ApiError::Transient(e) | ApiError::FatalWithState(e, _), + ) => Some(e), _ => None, } } pub fn error_state(self) -> Option { match self.0 { - InternalPersistedError::FatalWithState(_, state) => Some(state), + InternalPersistedError::Api(ApiError::FatalWithState(_, state)) => Some(state), _ => None, } } @@ -358,37 +635,54 @@ impl - fmt::Display for PersistedError +impl + fmt::Display for PersistedError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.0 { - InternalPersistedError::Transient(err) => write!(f, "Transient error: {err}"), - InternalPersistedError::Fatal(err) | InternalPersistedError::FatalWithState(err, _) => - write!(f, "Fatal error: {err}"), + InternalPersistedError::Api(ApiError::Transient(err)) => + write!(f, "Transient error: {err}"), + InternalPersistedError::Api( + ApiError::Fatal(err) | ApiError::FatalWithState(err, _), + ) => write!(f, "Fatal error: {err}"), InternalPersistedError::Storage(err) => write!(f, "Storage error: {err}"), } } } #[derive(Debug)] -pub(crate) enum InternalPersistedError +pub(crate) enum ApiError { + /// Error indicating that the session should be retried from the same state + Transient(Err), + /// Error indicating that the session is terminally closed + Fatal(Err), + /// Fatal error that results in a state transition to ErrorState + FatalWithState(Err, ErrorState), +} + +#[derive(Debug)] +pub(crate) enum InternalPersistedError where - InternalApiError: std::error::Error, + ApiErr: std::error::Error, StorageErr: std::error::Error, ErrorState: fmt::Debug, { - /// Error indicating that the session should be retried from the same state - Transient(InternalApiError), - /// Error indicating that the session is terminally closed - Fatal(InternalApiError), - /// Fatal error that results in a state transition to ErrorState - FatalWithState(InternalApiError, ErrorState), - /// Error indicating that application failed to save the session event. This should be treated as a transient error - /// but is represented as a separate error because this error is propagated from the application's storage layer + /// Error indicating that the session failed to progress to the next success state. + Api(ApiError), + /// Error indicating that application failed to save the session event. Storage(StorageErr), } +impl From> + for InternalPersistedError +where + Err: std::error::Error, + StorageErr: std::error::Error, + ErrorState: fmt::Debug, +{ + fn from(api: ApiError) -> Self { InternalPersistedError::Api(api) } +} + /// Represents a state transition that either progresses to a new state or maintains the current state #[derive(Debug, PartialEq)] pub enum OptionalTransitionOutcome { @@ -421,225 +715,41 @@ pub trait SessionPersister { fn close(&self) -> Result<(), Self::InternalStorageError>; } -/// Internal logic for processing specific state transitions. Each method is strongly typed to the state transition type. -/// Methods are not meant to be called directly, but are invoked through a state transition object's `save` method. -trait InternalSessionPersister: SessionPersister { - fn save_maybe_fatal_or_success_transition( - &self, - state_transition: MaybeFatalOrSuccessTransition, - ) -> Result< - OptionalTransitionOutcome<(), CurrentState>, - PersistedError, - > - where - Err: std::error::Error, - { - match state_transition { - MaybeFatalOrSuccessTransition::Success(event) => { - // Success value here would be the something to save - self.save_event(event).map_err(InternalPersistedError::Storage)?; - self.close().map_err(InternalPersistedError::Storage)?; - Ok(OptionalTransitionOutcome::Progress(())) - } - MaybeFatalOrSuccessTransition::NoResults(current_state) => - Ok(OptionalTransitionOutcome::Stasis(current_state)), - MaybeFatalOrSuccessTransition::Fatal(reject_fatal) => - Err(self.handle_fatal_reject(reject_fatal).into()), - MaybeFatalOrSuccessTransition::Transient(RejectTransient(err)) => - Err(InternalPersistedError::Transient(err).into()), - } - } - - /// Save state transition where state transition does not return an error - /// Only returns an error if the storage fails - fn save_progression_transition( - &self, - state_transition: NextStateTransition, - ) -> Result { - self.save_event(state_transition.0 .0)?; - Ok(state_transition.0 .1) - } +/// Async version of [`SessionPersister`] for use in async contexts. +// +// Methods use `impl Future<...> + Send` instead of `async fn` because `async fn` in traits +// doesn't guarantee the returned future is `Send`. This triggers the `async_fn_in_trait` lint. +// https://doc.rust-lang.org/stable/nightly-rustc/rustc_lint/async_fn_in_trait/static.ASYNC_FN_IN_TRAIT.html +pub trait AsyncSessionPersister: Send + Sync { + /// Errors that may arise from implementers storage layer + type InternalStorageError: std::error::Error + Send + Sync + 'static; + /// Session events types that we are persisting + type SessionEvent: Send; - /// Save a transition that can be a state transition or a transient error - fn save_maybe_success_transition( + /// Appends to list of session updates, Receives generic events + fn save_event( &self, - state_transition: MaybeSuccessTransition, - ) -> Result> - where - Err: std::error::Error, - { - match state_transition.0 { - Ok(AcceptNextState(event, success_value)) => { - self.save_event(event).map_err(InternalPersistedError::Storage)?; - self.close().map_err(InternalPersistedError::Storage)?; - Ok(success_value) - } - Err(Rejection::Transient(RejectTransient(err))) => - Err(InternalPersistedError::Transient(err).into()), - Err(Rejection::Fatal(reject_fatal)) => - Err(self.handle_fatal_reject(reject_fatal).into()), - Err(Rejection::ReplyableError(reject_replyable_error)) => - Err(self.handle_replyable_error_reject(reject_replyable_error).into()), - } - } + event: Self::SessionEvent, + ) -> impl std::future::Future> + Send; - /// Persists the outcome of a state transition that may result in one of the following: - /// - A successful state transition, in which case the success value is returned and the session is closed. - /// - No state change (stasis), where the current state is retained and nothing is persisted. - /// - A transient error, which does not affect persistent storage and is returned to the caller. - /// - A fatal error, which is persisted and returned to the caller. - fn save_maybe_no_results_success_transition( - &self, - state_transition: MaybeSuccessTransitionWithNoResults< - Self::SessionEvent, - SuccessValue, - CurrentState, - Err, - >, - ) -> Result< - OptionalTransitionOutcome, - PersistedError, - > - where - Err: std::error::Error, - { - match state_transition.0 { - Ok(AcceptOptionalTransition::Success(AcceptNextState(event, success_value))) => { - self.save_event(event).map_err(InternalPersistedError::Storage)?; - self.close().map_err(InternalPersistedError::Storage)?; - Ok(OptionalTransitionOutcome::Progress(success_value)) - } - Ok(AcceptOptionalTransition::NoResults(current_state)) => - Ok(OptionalTransitionOutcome::Stasis(current_state)), - Err(Rejection::Fatal(reject_fatal)) => - Err(self.handle_fatal_reject(reject_fatal).into()), - Err(Rejection::Transient(RejectTransient(err))) => - Err(InternalPersistedError::Transient(err).into()), - Err(Rejection::ReplyableError(reject_replyable_error)) => - Err(self.handle_replyable_error_reject(reject_replyable_error).into()), - } - } - /// Save a transition that can result in: - /// - A successful state transition - /// - No state change (no results) - /// - A transient error - /// - A fatal error - fn save_maybe_no_results_transition( + /// Loads all the events from the session in the same order they were saved + fn load( &self, - state_transition: MaybeFatalTransitionWithNoResults< - Self::SessionEvent, - NextState, - CurrentState, - Err, + ) -> impl std::future::Future< + Output = Result< + Box + Send>, + Self::InternalStorageError, >, - ) -> Result< - OptionalTransitionOutcome, - PersistedError, - > - where - Err: std::error::Error, - { - match state_transition.0 { - Ok(AcceptOptionalTransition::Success(AcceptNextState(event, next_state))) => { - self.save_event(event).map_err(InternalPersistedError::Storage)?; - Ok(OptionalTransitionOutcome::Progress(next_state)) - } - Ok(AcceptOptionalTransition::NoResults(current_state)) => - Ok(OptionalTransitionOutcome::Stasis(current_state)), - Err(Rejection::Fatal(reject_fatal)) => - Err(self.handle_fatal_reject(reject_fatal).into()), - Err(Rejection::Transient(RejectTransient(err))) => - Err(InternalPersistedError::Transient(err).into()), - Err(Rejection::ReplyableError(reject_replyable_error)) => - Err(self.handle_replyable_error_reject(reject_replyable_error).into()), - } - } - - /// Save a transition that can be a transient error or a state transition - fn save_maybe_transient_error_transition( - &self, - state_transition: MaybeTransientTransition, - ) -> Result> - where - Err: std::error::Error, - { - match state_transition.0 { - Ok(AcceptNextState(event, next_state)) => { - self.save_event(event).map_err(InternalPersistedError::Storage)?; - Ok(next_state) - } - Err(RejectTransient(err)) => Err(InternalPersistedError::Transient(err).into()), - } - } - - /// Save a transition that can be a fatal error, transient error or a state transition - fn save_maybe_fatal_error_transition( - &self, - state_transition: MaybeFatalTransition, - ) -> Result> - where - Err: std::error::Error, - ErrorState: fmt::Debug, - { - match state_transition.0 { - Ok(AcceptNextState(event, next_state)) => { - self.save_event(event).map_err(InternalPersistedError::Storage)?; - Ok(next_state) - } - Err(e) => { - match e { - Rejection::Fatal(reject_fatal) => - Err(self.handle_fatal_reject(reject_fatal).into()), - Rejection::Transient(RejectTransient(err)) => { - // No event to store for transient errors - Err(InternalPersistedError::Transient(err).into()) - } - Rejection::ReplyableError(reject_replyable_error) => - Err(self.handle_replyable_error_reject(reject_replyable_error).into()), - } - } - } - } - - fn handle_fatal_reject( - &self, - reject_fatal: RejectFatal, - ) -> InternalPersistedError - where - Err: std::error::Error, - ErrorState: fmt::Debug, - { - let RejectFatal(event, error) = reject_fatal; - if let Err(e) = self.save_event(event) { - return InternalPersistedError::Storage(e); - } - // Session is in a terminal state, close it - if let Err(e) = self.close() { - return InternalPersistedError::Storage(e); - } - - InternalPersistedError::Fatal(error) - } + > + Send; - fn handle_replyable_error_reject( + /// Marks the session as closed, no more events will be appended. + /// This is invoked when the session is terminated due to a fatal error + /// or when the session is closed due to a success state + fn close( &self, - reject_replyable_error: RejectReplyableError, - ) -> InternalPersistedError - where - Err: std::error::Error, - ErrorState: fmt::Debug, - { - let RejectReplyableError(event, error_state, error) = reject_replyable_error; - if let Err(e) = self.save_event(event) { - return InternalPersistedError::Storage(e); - } - // For replyable errors, don't close the session - keep it open for error response - InternalPersistedError::FatalWithState(error, error_state) - } + ) -> impl std::future::Future> + Send; } -impl InternalSessionPersister for T {} - /// A persister that does nothing /// This persister cannot be used to replay a session #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -725,6 +835,53 @@ pub mod test_utils { Ok(()) } } + + #[cfg(test)] + #[derive(Clone)] + /// Async in-memory session persister for testing async session replays and introspecting session events + pub struct InMemoryAsyncTestPersister { + pub(crate) inner: Arc>>, + } + + #[cfg(test)] + impl Default for InMemoryAsyncTestPersister { + fn default() -> Self { + Self { inner: Arc::new(tokio::sync::RwLock::new(InnerStorage::default())) } + } + } + + #[cfg(test)] + impl crate::persist::AsyncSessionPersister for InMemoryAsyncTestPersister + where + V: Clone + Send + Sync + 'static, + { + type InternalStorageError = std::convert::Infallible; + type SessionEvent = V; + + async fn save_event( + &self, + event: Self::SessionEvent, + ) -> Result<(), Self::InternalStorageError> { + let mut inner = self.inner.write().await; + Arc::make_mut(&mut inner.events).push(event); + Ok(()) + } + + async fn load( + &self, + ) -> Result + Send>, Self::InternalStorageError> + { + let inner = self.inner.read().await; + let events = Arc::clone(&inner.events); + Ok(Box::new(Arc::try_unwrap(events).unwrap_or_else(|arc| (*arc).clone()).into_iter())) + } + + async fn close(&self) -> Result<(), Self::InternalStorageError> { + let mut inner = self.inner.write().await; + inner.is_closed = true; + Ok(()) + } + } } #[cfg(test)] @@ -732,7 +889,7 @@ mod tests { use serde::{Deserialize, Serialize}; use super::*; - use crate::persist::test_utils::InMemoryTestPersister; + use crate::persist::test_utils::{InMemoryAsyncTestPersister, InMemoryTestPersister}; type InMemoryTestState = String; @@ -751,12 +908,8 @@ mod tests { } } - struct TestCase { - // Allow type complexity for the test closure - #[allow(clippy::type_complexity)] - test: Box< - dyn Fn(&InMemoryTestPersister) -> Result, - >, + struct TestCase { + make_transition: Box Transition>, expected_result: ExpectedResult, } @@ -771,12 +924,11 @@ mod tests { success: Option, } - fn do_test( + fn verify_sync( persister: &InMemoryTestPersister, - test_case: &TestCase, + result: Result, + expected_result: &ExpectedResult, ) { - let expected_result = &test_case.expected_result; - let res = (test_case.test)(persister); let events = persister.load().expect("Persister should not fail").collect::>(); assert_eq!(events.len(), expected_result.events.len()); for (event, expected_event) in events.iter().zip(expected_result.events.iter()) { @@ -788,7 +940,7 @@ mod tests { expected_result.is_closed ); - match (&res, &expected_result.error) { + match (&result, &expected_result.error) { (Ok(actual), None) => { assert_eq!(Some(actual), expected_result.success.as_ref()); } @@ -801,415 +953,450 @@ mod tests { } } - #[test] - fn test_initial_transition() { + async fn verify_async< + SuccessState: std::fmt::Debug + PartialEq + Send, + ErrorState: std::error::Error + Send, + >( + persister: &InMemoryAsyncTestPersister, + result: Result, + expected_result: &ExpectedResult, + ) { + let events = persister.load().await.expect("Persister should not fail").collect::>(); + assert_eq!(events.len(), expected_result.events.len()); + for (event, expected_event) in events.iter().zip(expected_result.events.iter()) { + assert_eq!(event.0, expected_event.0); + } + + assert_eq!(persister.inner.read().await.is_closed, expected_result.is_closed); + + match (&result, &expected_result.error) { + (Ok(actual), None) => { + assert_eq!(Some(actual), expected_result.success.as_ref()); + } + (Err(actual), Some(exp)) => { + // TODO: replace .to_string() with .eq(). This would introduce a trait bound on the internal API error type + // And not all internal API errors implement PartialEq + assert_eq!(actual.to_string(), exp.to_string()); + } + _ => panic!("Unexpected result state"), + } + } + + macro_rules! run_test_cases { + ($test_cases:expr) => { + for test in &$test_cases { + let persister = InMemoryTestPersister::default(); + let result = (test.make_transition)().save(&persister); + verify_sync(&persister, result, &test.expected_result); + + let persister = InMemoryAsyncTestPersister::default(); + let result = (test.make_transition)().save_async(&persister).await; + verify_async(&persister, result, &test.expected_result).await; + } + }; + } + + #[tokio::test] + async fn test_initial_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); - let test_cases: Vec> = vec![ - // Success - TestCase { - expected_result: ExpectedResult { - events: vec![event.clone()], - is_closed: false, - error: None, - success: Some(next_state.clone()), - }, - test: Box::new(move |persister| { - NextStateTransition::success(event.clone(), next_state.clone()).save(persister) - }), + + let test_cases = vec![TestCase { + make_transition: Box::new({ + let event = event.clone(); + let next_state = next_state.clone(); + move || NextStateTransition::success(event.clone(), next_state.clone()) + }), + expected_result: ExpectedResult { + events: vec![event.clone()], + is_closed: false, + error: None, + success: Some(next_state.clone()), }, - ]; + }]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } - #[test] - fn test_maybe_transient_transition() { + #[tokio::test] + async fn test_maybe_transient_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); - let test_cases: Vec< - TestCase< - InMemoryTestState, - PersistedError, - >, - > = vec![ - // Success + + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + let next_state = next_state.clone(); + move || MaybeTransientTransition::success(event.clone(), next_state.clone()) + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: false, error: None, success: Some(next_state.clone()), }, - test: Box::new(move |persister| { - MaybeTransientTransition::success(event.clone(), next_state.clone()) - .save(persister) - }), }, - // Transient error TestCase { + make_transition: Box::new(|| { + MaybeTransientTransition::transient(InMemoryTestError {}) + }), expected_result: ExpectedResult { events: vec![], is_closed: false, - error: Some(InternalPersistedError::Transient(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) + .into(), + ), success: None, }, - test: Box::new(move |persister| { - MaybeTransientTransition::transient(InMemoryTestError {}).save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } - #[test] - fn test_next_state_transition() { + #[tokio::test] + async fn test_next_state_transition() { let event = InMemoryTestEvent("foo".to_string()); let next_state = "Next state".to_string(); - let test_cases: Vec> = vec![ - // Success - TestCase { - expected_result: ExpectedResult { - events: vec![event.clone()], - is_closed: false, - error: None, - success: Some(next_state.clone()), - }, - test: Box::new(move |persister| { - NextStateTransition::success(event.clone(), next_state.clone()).save(persister) - }), + + let test_cases = vec![TestCase { + make_transition: Box::new({ + let event = event.clone(); + let next_state = next_state.clone(); + move || NextStateTransition::success(event.clone(), next_state.clone()) + }), + expected_result: ExpectedResult { + events: vec![event.clone()], + is_closed: false, + error: None, + success: Some(next_state.clone()), }, - ]; + }]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } - #[test] - fn test_maybe_success_transition() { + #[tokio::test] + async fn test_maybe_success_transition() { let event = InMemoryTestEvent("foo".to_string()); - let test_cases: Vec< - TestCase<(), PersistedError>, - > = vec![ - // Success + let error_event = InMemoryTestEvent("error event".to_string()); + + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + move || MaybeSuccessTransition::success(event.clone(), ()) + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: true, error: None, success: Some(()), }, - test: Box::new(move |persister| { - MaybeSuccessTransition::success(event.clone(), ()).save(persister) - }), }, - // Transient error TestCase { + make_transition: Box::new(|| { + MaybeSuccessTransition::transient(InMemoryTestError {}) + }), expected_result: ExpectedResult { events: vec![], is_closed: false, - error: Some(InternalPersistedError::Transient(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) + .into(), + ), success: None, }, - test: Box::new(move |persister| { - MaybeSuccessTransition::transient(InMemoryTestError {}).save(persister) - }), }, - // Fatal error TestCase { + make_transition: Box::new({ + let error_event = error_event.clone(); + move || MaybeSuccessTransition::fatal(error_event.clone(), InMemoryTestError {}) + }), expected_result: ExpectedResult { - events: vec![InMemoryTestEvent("error event".to_string())], + events: vec![error_event.clone()], is_closed: true, - error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), + ), success: None, }, - test: Box::new(move |persister| { - MaybeSuccessTransition::fatal( - InMemoryTestEvent("error event".to_string()), - InMemoryTestError {}, - ) - .save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } - #[test] - fn test_maybe_fatal_transition() { + #[tokio::test] + async fn test_maybe_fatal_transition() { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); let next_state = "Next state".to_string(); - let test_cases: Vec< - TestCase< - InMemoryTestState, - PersistedError, - >, - > = vec![ + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + let next_state = next_state.clone(); + move || MaybeFatalTransition::success(event.clone(), next_state.clone()) + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: false, error: None, success: Some(next_state.clone()), }, - test: Box::new(move |persister| { - MaybeFatalTransition::success(event.clone(), next_state.clone()).save(persister) - }), }, - // Transient error TestCase { - expected_result: ExpectedResult { + make_transition: Box::new(|| MaybeFatalTransition::transient(InMemoryTestError {})), + expected_result: ExpectedResult::< + _, + PersistedError, + > { events: vec![], is_closed: false, - error: Some(InternalPersistedError::Transient(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) + .into(), + ), success: None, }, - test: Box::new(move |persister| { - MaybeFatalTransition::transient(InMemoryTestError {}).save(persister) - }), }, - // Fatal error TestCase { + make_transition: Box::new({ + let error_event = error_event.clone(); + move || MaybeFatalTransition::fatal(error_event.clone(), InMemoryTestError {}) + }), expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, - error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), + ), success: None, }, - test: Box::new(move |persister| { - MaybeFatalTransition::fatal(error_event.clone(), InMemoryTestError {}) - .save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } - #[test] - fn test_maybe_success_transition_with_no_results() { + #[tokio::test] + async fn test_maybe_success_transition_with_no_results() { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); let success_value = "Success value".to_string(); - let test_cases: Vec< - TestCase< - OptionalTransitionOutcome, - PersistedError, - >, - > = vec![ - // Success + + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + let success_value = success_value.clone(); + move || { + MaybeSuccessTransitionWithNoResults::success( + success_value.clone(), + event.clone(), + ) + } + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: true, error: None, success: Some(OptionalTransitionOutcome::Progress(success_value.clone())), }, - test: Box::new(move |persister| { - MaybeSuccessTransitionWithNoResults::success( - success_value.clone(), - event.clone(), - ) - .save(persister) - }), }, - // No results TestCase { - expected_result: ExpectedResult { + make_transition: Box::new({ + let current_state = current_state.clone(); + move || MaybeSuccessTransitionWithNoResults::no_results(current_state.clone()) + }), + expected_result: ExpectedResult::< + OptionalTransitionOutcome, + PersistedError, + > { events: vec![], is_closed: false, error: None, success: Some(OptionalTransitionOutcome::Stasis(current_state.clone())), }, - test: Box::new(move |persister| { - MaybeSuccessTransitionWithNoResults::no_results(current_state.clone()) - .save(persister) - }), }, - // Transient error TestCase { + make_transition: Box::new(|| { + MaybeSuccessTransitionWithNoResults::transient(InMemoryTestError {}) + }), expected_result: ExpectedResult { events: vec![], is_closed: false, - error: Some(InternalPersistedError::Transient(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) + .into(), + ), success: None, }, - test: Box::new(move |persister| { - MaybeSuccessTransitionWithNoResults::transient(InMemoryTestError {}) - .save(persister) - }), }, - // Fatal error TestCase { + make_transition: Box::new({ + let error_event = error_event.clone(); + move || { + MaybeSuccessTransitionWithNoResults::fatal( + error_event.clone(), + InMemoryTestError {}, + ) + } + }), expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, - error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), + ), success: None, }, - test: Box::new(move |persister| { - MaybeSuccessTransitionWithNoResults::fatal( - error_event.clone(), - InMemoryTestError {}, - ) - .save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } - #[test] - fn test_maybe_fatal_transition_with_no_results() { + #[tokio::test] + async fn test_maybe_fatal_transition_with_no_results() { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); let next_state = "Next state".to_string(); - let test_cases: Vec< - TestCase< - OptionalTransitionOutcome, - PersistedError, - >, - > = vec![ - // Success + + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + let next_state = next_state.clone(); + move || { + MaybeFatalTransitionWithNoResults::success( + event.clone(), + next_state.clone(), + ) + } + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: false, error: None, success: Some(OptionalTransitionOutcome::Progress(next_state.clone())), }, - test: Box::new(move |persister| { - MaybeFatalTransitionWithNoResults::success(event.clone(), next_state.clone()) - .save(persister) - }), }, - // No results TestCase { - expected_result: ExpectedResult { + make_transition: Box::new({ + let current_state = current_state.clone(); + move || MaybeFatalTransitionWithNoResults::no_results(current_state.clone()) + }), + expected_result: ExpectedResult::< + OptionalTransitionOutcome, + PersistedError, + > { events: vec![], is_closed: false, error: None, success: Some(OptionalTransitionOutcome::Stasis(current_state.clone())), }, - test: Box::new(move |persister| { - MaybeFatalTransitionWithNoResults::no_results(current_state.clone()) - .save(persister) - }), }, - // Fatal error TestCase { + make_transition: Box::new({ + let error_event = error_event.clone(); + move || { + MaybeFatalTransitionWithNoResults::fatal( + error_event.clone(), + InMemoryTestError {}, + ) + } + }), expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, - error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), + ), success: None, }, - test: Box::new(move |persister| { - MaybeFatalTransitionWithNoResults::fatal( - error_event.clone(), - InMemoryTestError {}, - ) - .save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } - #[test] - fn test_maybe_fatal_or_success_transition() { + #[tokio::test] + async fn test_maybe_fatal_or_success_transition() { let event = InMemoryTestEvent("foo".to_string()); let error_event = InMemoryTestEvent("error event".to_string()); let current_state = "Current state".to_string(); - let test_cases: Vec< - TestCase< - OptionalTransitionOutcome<(), InMemoryTestState>, - PersistedError, - >, - > = vec![ - // Success + + let test_cases = vec![ TestCase { + make_transition: Box::new({ + let event = event.clone(); + move || MaybeFatalOrSuccessTransition::Success(event.clone()) + }), expected_result: ExpectedResult { events: vec![event.clone()], is_closed: true, error: None, success: Some(OptionalTransitionOutcome::Progress(())), }, - test: Box::new(move |persister| { - MaybeFatalOrSuccessTransition::Success(event.clone()).save(persister) - }), }, - // No results TestCase { - expected_result: ExpectedResult { + make_transition: Box::new({ + let current_state = current_state.clone(); + move || MaybeFatalOrSuccessTransition::NoResults(current_state.clone()) + }), + expected_result: ExpectedResult::< + OptionalTransitionOutcome<(), InMemoryTestState>, + PersistedError, + > { events: vec![], is_closed: false, error: None, success: Some(OptionalTransitionOutcome::Stasis(current_state.clone())), }, - test: Box::new(move |persister| { - MaybeFatalOrSuccessTransition::NoResults(current_state.clone()).save(persister) - }), }, - // Fatal error TestCase { + make_transition: Box::new({ + let error_event = error_event.clone(); + move || { + MaybeFatalOrSuccessTransition::fatal( + error_event.clone(), + InMemoryTestError {}, + ) + } + }), expected_result: ExpectedResult { events: vec![error_event.clone()], is_closed: true, - error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), + ), success: None, }, - test: Box::new(move |persister| { - MaybeFatalOrSuccessTransition::fatal(error_event.clone(), InMemoryTestError {}) - .save(persister) - }), }, - // Transient error TestCase { + make_transition: Box::new(|| { + MaybeFatalOrSuccessTransition::transient(InMemoryTestError {}) + }), expected_result: ExpectedResult { events: vec![], is_closed: false, - error: Some(InternalPersistedError::Transient(InMemoryTestError {}).into()), + error: Some( + InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) + .into(), + ), success: None, }, - test: Box::new(move |persister| { - MaybeFatalOrSuccessTransition::transient(InMemoryTestError {}).save(persister) - }), }, ]; - for test in test_cases { - let persister = InMemoryTestPersister::default(); - do_test(&persister, &test); - } + run_test_cases!(test_cases); } #[test] @@ -1225,13 +1412,13 @@ mod tests { // Test Internal API error cases let fatal_error = PersistedError::( - InternalPersistedError::Fatal(api_err.clone()), + InternalPersistedError::Api(ApiError::Fatal(api_err.clone())), ); assert!(fatal_error.storage_error_ref().is_none()); assert!(fatal_error.api_error_ref().is_some()); let transient_error = PersistedError::( - InternalPersistedError::Transient(api_err.clone()), + InternalPersistedError::Api(ApiError::Transient(api_err.clone())), ); assert!(transient_error.storage_error_ref().is_none()); assert!(transient_error.api_error_ref().is_some());