diff --git a/payjoin-cli/src/app/v2/mod.rs b/payjoin-cli/src/app/v2/mod.rs index 37e104756..58a9d371d 100644 --- a/payjoin-cli/src/app/v2/mod.rs +++ b/payjoin-cli/src/app/v2/mod.rs @@ -6,9 +6,9 @@ use payjoin::bitcoin::consensus::encode::serialize_hex; use payjoin::bitcoin::{Amount, FeeRate}; use payjoin::persist::OptionalTransitionOutcome; use payjoin::receive::v2::{ - process_err_res, replay_event_log as replay_receiver_event_log, Initialized, MaybeInputsOwned, - MaybeInputsSeen, OutputsUnknown, PayjoinProposal, ProvisionalProposal, ReceiveSession, - Receiver, ReceiverBuilder, SessionHistory, UncheckedOriginalPayload, WantsFeeRange, + replay_event_log as replay_receiver_event_log, HasReplyableError, Initialized, + MaybeInputsOwned, MaybeInputsSeen, OutputsUnknown, PayjoinProposal, ProvisionalProposal, + ReceiveSession, Receiver, ReceiverBuilder, UncheckedOriginalPayload, WantsFeeRange, WantsInputs, WantsOutputs, }; use payjoin::send::v2::{ @@ -70,7 +70,7 @@ impl StatusText for ReceiveSession { | ReceiveSession::WantsFeeRange(_) | ReceiveSession::ProvisionalProposal(_) => "Processing original proposal", ReceiveSession::PayjoinProposal(_) => "Payjoin proposal sent", - ReceiveSession::TerminalFailure => "Session failure", + ReceiveSession::HasReplyableError(_) => "Session failure", } } } @@ -350,13 +350,17 @@ impl AppTrait for App { self.db.get_inactive_send_session_ids()?.into_iter().try_for_each( |(session_id, completed_at)| { let persister = SenderPersister::from_id(self.db.clone(), session_id.clone()); - if let Ok((sender_state, session_history)) = replay_sender_event_log(&persister) { + if let Ok((sender_state, _)) = replay_sender_event_log(&persister) { let row = SessionHistoryRow { session_id, role: Role::Sender, - status: sender_state, + status: sender_state.clone(), completed_at: Some(completed_at), - error_message: session_history.terminal_error(), + error_message: match sender_state { + SendSession::TerminalFailure => + Some("Sender terminal failure".to_string()), + _ => None, + }, }; send_rows.push(row); } @@ -367,14 +371,17 @@ impl AppTrait for App { self.db.get_inactive_recv_session_ids()?.into_iter().try_for_each( |(session_id, completed_at)| { let persister = ReceiverPersister::from_id(self.db.clone(), session_id.clone()); - if let Ok((receiver_state, session_history)) = replay_receiver_event_log(&persister) - { + if let Ok((receiver_state, _)) = replay_receiver_event_log(&persister) { let row = SessionHistoryRow { session_id, role: Role::Receiver, - status: receiver_state, + status: receiver_state.clone(), completed_at: Some(completed_at), - error_message: session_history.terminal_error().map(|e| e.0), + error_message: match &receiver_state { + ReceiveSession::HasReplyableError(replyable_error) => + Some(replyable_error.error_reply().to_json().to_string()), + _ => None, + }, }; recv_rows.push(row); } @@ -519,22 +526,11 @@ impl App { self.finalize_proposal(proposal, persister).await, ReceiveSession::PayjoinProposal(proposal) => self.send_payjoin_proposal(proposal, persister).await, - ReceiveSession::TerminalFailure => - return Err(anyhow!("Terminal receiver session")), + ReceiveSession::HasReplyableError(error) => + self.handle_error(error, persister).await, } }; - - match res { - Ok(_) => Ok(()), - Err(e) => { - let (_, session_history) = replay_receiver_event_log(persister)?; - let pj_uri = session_history.pj_uri().extras.endpoint().clone(); - let ohttp_relay = self.unwrap_relay_or_else_fetch(Some(pj_uri)).await?; - self.handle_recoverable_error(&ohttp_relay, &session_history).await?; - - Err(e) - } - } + res } #[allow(clippy::incompatible_msrv)] @@ -700,20 +696,14 @@ impl App { Ok(ohttp_relay) } - /// Handle request error by sending an error response over the directory - async fn handle_recoverable_error( + /// Handle error by attempting to send an error response over the directory + async fn handle_error( &self, - ohttp_relay: &payjoin::Url, - session_history: &SessionHistory, + session: Receiver, + persister: &ReceiverPersister, ) -> Result<()> { - let e = match session_history.terminal_error() { - Some((_, Some(e))) => e, - _ => return Ok(()), - }; - let (err_req, err_ctx) = session_history - .extract_err_req(ohttp_relay.as_str())? - .expect("If JsonReply is Some, then err_req and err_ctx should be Some"); - let to_return = anyhow!("Replied with error: {}", e.to_json()); + let (err_req, err_ctx) = + session.create_error_request(self.unwrap_relay_or_else_fetch(None).await?.as_str())?; let err_response = match self.post_request(err_req).await { Ok(response) => response, @@ -725,11 +715,11 @@ impl App { Err(e) => return Err(anyhow!("Failed to get error response bytes: {}", e)), }; - if let Err(e) = process_err_res(&err_bytes, err_ctx) { + if let Err(e) = session.process_error_response(&err_bytes, err_ctx).save(persister) { return Err(anyhow!("Failed to process error response: {}", e)); } - Err(to_return) + Ok(()) } async fn post_request(&self, req: payjoin::Request) -> Result { diff --git a/payjoin-ffi/src/receive/mod.rs b/payjoin-ffi/src/receive/mod.rs index b282800da..6f87beb4b 100644 --- a/payjoin-ffi/src/receive/mod.rs +++ b/payjoin-ffi/src/receive/mod.rs @@ -85,7 +85,7 @@ pub enum ReceiveSession { WantsFeeRange { inner: Arc }, ProvisionalProposal { inner: Arc }, PayjoinProposal { inner: Arc }, - TerminalFailure, + HasReplyableError { inner: Arc }, } impl From for ReceiveSession { @@ -112,7 +112,8 @@ impl From for ReceiveSession { Self::ProvisionalProposal { inner: Arc::new(inner.into()) }, ReceiveSession::PayjoinProposal(inner) => Self::PayjoinProposal { inner: Arc::new(inner.into()) }, - ReceiveSession::TerminalFailure => Self::TerminalFailure, + ReceiveSession::HasReplyableError(inner) => + Self::HasReplyableError { inner: Arc::new(inner.into()) }, } } } @@ -146,19 +147,6 @@ impl From for SessionHistory { fn from(value: payjoin::receive::v2::SessionHistory) -> Self { Self(value) } } -#[derive(uniffi::Object)] -pub struct TerminalErr { - error: String, - reply: Option, -} - -#[uniffi::export] -impl TerminalErr { - pub fn error(&self) -> String { self.error.clone() } - - pub fn reply(&self) -> Option> { self.reply.clone().map(Arc::new) } -} - #[uniffi::export] impl SessionHistory { /// Receiver session Payjoin URI @@ -169,33 +157,10 @@ impl SessionHistory { self.0.psbt_ready_for_signing().map(|psbt| Arc::new(psbt.into())) } - /// Terminal error from the session if present - pub fn terminal_error(&self) -> Option> { - self.0.terminal_error().map(|(error, reply)| { - Arc::new(TerminalErr { error, reply: reply.map(|reply| reply.into()) }) - }) - } - /// Fallback transaction from the session if present pub fn fallback_tx(&self) -> Option> { self.0.fallback_tx().map(|tx| Arc::new(tx.into())) } - - /// Construct the error request to be posted on the directory if an error occurred. - /// To process the response, use [process_err_res] - pub fn extract_err_req( - &self, - ohttp_relay: String, - ) -> Result, SessionError> { - match self.0.extract_err_req(ohttp_relay) { - Ok(Some((request, ctx))) => Ok(Some(RequestResponse { - request: request.into(), - client_response: Arc::new(ctx.into()), - })), - Ok(None) => Ok(None), - Err(e) => Err(SessionError::from(e)), - } - } } #[derive(uniffi::Object)] @@ -502,12 +467,6 @@ impl UncheckedOriginalPayload { } } -/// Process an OHTTP Encapsulated HTTP POST Error response -/// to ensure it has been posted properly -#[uniffi::export] -pub fn process_err_res(body: &[u8], context: &ClientResponse) -> Result<(), SessionError> { - payjoin::receive::v2::process_err_res(body, context.into()).map_err(Into::into) -} #[derive(Clone, uniffi::Object)] pub struct MaybeInputsOwned(payjoin::receive::v2::Receiver); @@ -961,7 +920,7 @@ pub struct PayjoinProposalTransition( payjoin::persist::MaybeSuccessTransition< payjoin::receive::v2::SessionEvent, (), - payjoin::receive::Error, + payjoin::receive::ProtocolError, >, >, >, @@ -1036,6 +995,85 @@ impl PayjoinProposal { } } +#[derive(Clone, uniffi::Object)] +pub struct HasReplyableError( + pub payjoin::receive::v2::Receiver, +); + +impl From + for payjoin::receive::v2::Receiver +{ + fn from(value: HasReplyableError) -> Self { value.0 } +} + +impl From> + for HasReplyableError +{ + fn from( + value: payjoin::receive::v2::Receiver, + ) -> Self { + Self(value) + } +} + +#[derive(uniffi::Object)] +pub struct HasReplyableErrorTransition( + Arc< + RwLock< + Option< + payjoin::persist::MaybeSuccessTransition< + payjoin::receive::v2::SessionEvent, + (), + payjoin::receive::Error, + >, + >, + >, + >, +); + +#[uniffi::export] +impl HasReplyableErrorTransition { + pub fn save( + &self, + persister: Arc, + ) -> Result<(), ReceiverPersistedError> { + let adapter = CallbackPersisterAdapter::new(persister); + let mut inner = + self.0.write().map_err(|_| ImplementationError::from("Lock poisoned".to_string()))?; + + let value = inner + .take() + .ok_or_else(|| ImplementationError::from("Already saved or moved".to_string()))?; + + value.save(&adapter).map_err(ReceiverPersistedError::from)?; + Ok(()) + } +} + +#[uniffi::export] +impl HasReplyableError { + pub fn create_error_request( + &self, + ohttp_relay: String, + ) -> Result { + self.0.clone().create_error_request(ohttp_relay).map_err(Into::into).map(|(req, ctx)| { + RequestResponse { request: req.into(), client_response: Arc::new(ctx.into()) } + }) + } + + pub fn process_error_response( + &self, + body: &[u8], + ohttp_context: &ClientResponse, + ) -> PayjoinProposalTransition { + PayjoinProposalTransition(Arc::new(RwLock::new(Some( + self.0.clone().process_error_response(body, ohttp_context.into()), + )))) + } + + pub fn error_reply(&self) -> String { self.0.error_reply().to_json().to_string() } +} + /// Session persister that should save and load events as JSON strings. #[uniffi::export(with_foreign)] pub trait JsonReceiverSessionPersister: Send + Sync { diff --git a/payjoin/src/core/persist.rs b/payjoin/src/core/persist.rs index 8709f45e7..6b01e5fad 100644 --- a/payjoin/src/core/persist.rs +++ b/payjoin/src/core/persist.rs @@ -231,7 +231,7 @@ impl Rejection { /// Represents a fatal rejection of a state transition. /// When this error occurs, the session must be closed and cannot be resumed. -pub struct RejectFatal(Event, Err); +pub struct RejectFatal(pub(crate) Event, pub(crate) Err); /// Represents a transient rejection of a state transition. /// When this error occurs, the session should resume from its current state. pub struct RejectTransient(pub(crate) Err); @@ -336,6 +336,17 @@ pub enum OptionalTransitionOutcome { Stasis(CurrentState), } +/// Sealed trait to prevent external implementation of SessionEventTrait. +pub(crate) mod sealed { + pub trait Sealed {} +} + +/// Trait for session events that determines if an event represents a session closure. +pub trait SessionEventTrait: sealed::Sealed { + /// Returns true if this event represents a session closure that should close the persister. + fn is_closed_event(&self) -> bool; +} + /// A session that can persist events to an append-only log. /// A session represents a sequence of events generated by the BIP78 state machine. /// The events can be replayed from the log to reconstruct the state machine's state. @@ -343,7 +354,7 @@ pub trait SessionPersister { /// Errors that may arise from implementers storage layer type InternalStorageError: std::error::Error + Send + Sync + 'static; /// Session events types that we are persisting - type SessionEvent; + type SessionEvent: SessionEventTrait; /// Appends to list of session updates, Receives generic events fn save_event(&self, event: Self::SessionEvent) -> Result<(), Self::InternalStorageError>; @@ -512,12 +523,14 @@ trait InternalSessionPersister: SessionPersister { Err: std::error::Error, { let RejectFatal(event, error) = reject_fatal; + let should_close = event.is_closed_event(); if let Err(e) = self.save_event(event) { return InternalPersistedError::Storage(e); } - // Session is in a terminal state, close it - if let Err(e) = self.close() { - return InternalPersistedError::Storage(e); + if should_close { + if let Err(e) = self.close() { + return InternalPersistedError::Storage(e); + } } InternalPersistedError::Fatal(error) @@ -531,6 +544,12 @@ impl InternalSessionPersister for T {} #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct NoopPersisterEvent; +impl sealed::Sealed for NoopPersisterEvent {} + +impl SessionEventTrait for NoopPersisterEvent { + fn is_closed_event(&self) -> bool { false } +} + #[derive(Debug, Clone)] pub struct NoopSessionPersister(std::marker::PhantomData); @@ -538,7 +557,7 @@ impl Default for NoopSessionPersister { fn default() -> Self { Self(std::marker::PhantomData) } } -impl SessionPersister for NoopSessionPersister { +impl SessionPersister for NoopSessionPersister { type InternalStorageError = std::convert::Infallible; type SessionEvent = E; @@ -559,7 +578,7 @@ impl SessionPersister for NoopSessionPersister { pub mod test_utils { use std::sync::{Arc, RwLock}; - use crate::persist::SessionPersister; + use crate::persist::{SessionEventTrait, SessionPersister}; #[derive(Clone)] /// In-memory session persister for testing session replays and introspecting session events @@ -583,7 +602,7 @@ pub mod test_utils { impl SessionPersister for InMemoryTestPersister where - V: Clone + 'static, + V: Clone + 'static + SessionEventTrait, { type InternalStorageError = std::convert::Infallible; type SessionEvent = V; @@ -625,6 +644,12 @@ mod tests { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct InMemoryTestEvent(String); + impl crate::persist::sealed::Sealed for InMemoryTestEvent {} + + impl crate::persist::SessionEventTrait for InMemoryTestEvent { + fn is_closed_event(&self) -> bool { self.0 == "error close" } + } + #[derive(Debug, Clone, PartialEq)] /// Dummy error type for testing struct InMemoryTestError {} @@ -783,6 +808,7 @@ mod tests { #[test] fn test_maybe_success_transition() { let event = InMemoryTestEvent("foo".to_string()); + let error_event = InMemoryTestEvent("error close".to_string()); let test_cases: Vec< TestCase<(), PersistedError>, > = vec![ @@ -813,17 +839,14 @@ mod tests { // Fatal error TestCase { expected_result: ExpectedResult { - events: vec![InMemoryTestEvent("error event".to_string())], + events: vec![error_event.clone()], is_closed: true, error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), success: None, }, test: Box::new(move |persister| { - MaybeSuccessTransition::fatal( - InMemoryTestEvent("error event".to_string()), - InMemoryTestError {}, - ) - .save(persister) + MaybeSuccessTransition::fatal(error_event.clone(), InMemoryTestError {}) + .save(persister) }), }, ]; @@ -873,7 +896,7 @@ mod tests { TestCase { expected_result: ExpectedResult { events: vec![error_event.clone()], - is_closed: true, + is_closed: false, error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), success: None, }, @@ -893,7 +916,7 @@ mod tests { #[test] fn test_maybe_success_transition_with_no_results() { let event = InMemoryTestEvent("foo".to_string()); - let error_event = InMemoryTestEvent("error event".to_string()); + let error_event = InMemoryTestEvent("error close".to_string()); let current_state = "Current state".to_string(); let success_value = "Success value".to_string(); let test_cases: Vec< @@ -1010,7 +1033,7 @@ mod tests { TestCase { expected_result: ExpectedResult { events: vec![error_event.clone()], - is_closed: true, + is_closed: false, error: Some(InternalPersistedError::Fatal(InMemoryTestError {}).into()), success: None, }, diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 59dec2dc6..64722d46d 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -34,7 +34,7 @@ pub(crate) use error::InternalSessionError; pub use error::SessionError; use serde::de::Deserializer; use serde::{Deserialize, Serialize}; -pub use session::{replay_event_log, SessionEvent, SessionHistory}; +pub use session::{replay_event_log, SessionEvent, SessionHistory, SessionStatus}; use url::Url; use super::error::{Error, InputContributionError}; @@ -123,7 +123,7 @@ fn short_id_from_pubkey(pubkey: &HpkePublicKey) -> ShortId { } /// Represents the various states of a Payjoin receiver session during the protocol flow. -/// Each variant parameterizes a `Receiver` with a specific state type, and [`ReceiveSession::TerminalFailure`] which indicates the session has ended or is invalid. +/// Each variant parameterizes a `Receiver` with a specific state type. /// /// This provides type erasure for the receive session state, allowing for the session to be replayed /// and the state to be updated with the next event over a uniform interface. @@ -139,7 +139,7 @@ pub enum ReceiveSession { WantsFeeRange(Receiver), ProvisionalProposal(Receiver), PayjoinProposal(Receiver), - TerminalFailure, + HasReplyableError(Receiver), } impl ReceiveSession { @@ -185,7 +185,23 @@ impl ReceiveSession { SessionEvent::PayjoinProposal(payjoin_proposal), ) => Ok(state.apply_payjoin_proposal(payjoin_proposal)), - (_, SessionEvent::SessionInvalid(_, _)) => Ok(ReceiveSession::TerminalFailure), + (session, SessionEvent::HasReplyableError(error)) => + Ok(ReceiveSession::HasReplyableError(Receiver { + state: HasReplyableError { error_reply: error.clone() }, + session_context: match session { + ReceiveSession::Initialized(r) => r.session_context, + ReceiveSession::UncheckedOriginalPayload(r) => r.session_context, + ReceiveSession::MaybeInputsOwned(r) => r.session_context, + ReceiveSession::MaybeInputsSeen(r) => r.session_context, + ReceiveSession::OutputsUnknown(r) => r.session_context, + ReceiveSession::WantsOutputs(r) => r.session_context, + ReceiveSession::WantsInputs(r) => r.session_context, + ReceiveSession::WantsFeeRange(r) => r.session_context, + ReceiveSession::ProvisionalProposal(r) => r.session_context, + ReceiveSession::PayjoinProposal(r) => r.session_context, + ReceiveSession::HasReplyableError(r) => r.session_context, + }, + })), (current_state, SessionEvent::Closed(_)) => Ok(current_state), @@ -211,6 +227,7 @@ mod sealed { impl State for super::WantsFeeRange {} impl State for super::ProvisionalProposal {} impl State for super::PayjoinProposal {} + impl State for super::HasReplyableError {} } /// Sealed trait for V2 receive session states. @@ -250,34 +267,6 @@ impl core::ops::DerefMut for Receiver { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.state } } -/// Construct an OHTTP Encapsulated HTTP POST request to return -/// a Receiver Error Response -fn extract_err_req( - err: &JsonReply, - ohttp_relay: impl IntoUrl, - session_context: &SessionContext, -) -> Result<(Request, ohttp::ClientResponse), SessionError> { - if session_context.expiration.elapsed() { - return Err(InternalSessionError::Expired(session_context.expiration).into()); - } - let mailbox = mailbox_endpoint(&session_context.directory, &session_context.reply_mailbox_id()); - let (body, ohttp_ctx) = ohttp_encapsulate( - &session_context.ohttp_keys.0, - "POST", - mailbox.as_str(), - Some(err.to_json().to_string().as_bytes()), - ) - .map_err(InternalSessionError::OhttpEncapsulation)?; - let req = Request::new_v2(&session_context.full_relay_url(ohttp_relay)?, &body); - Ok((req, ohttp_ctx)) -} - -/// Process an OHTTP Encapsulated HTTP POST Error response -/// to ensure it has been posted properly -pub fn process_err_res(body: &[u8], context: ohttp::ClientResponse) -> Result<(), SessionError> { - process_post_res(body, context).map_err(|e| InternalSessionError::DirectoryResponse(e).into()) -} - #[derive(Debug, Clone)] pub struct ReceiverBuilder(SessionContext); @@ -377,12 +366,13 @@ impl Receiver { let proposal = match self.inner_process_res(body, context) { Ok(proposal) => proposal, Err(e) => match e { - // Implementation errors should be unreachable + // Fatal V2 errors result in immediate session closure because it is infeasible + // to reply to the sender ProtocolError::V2(ref session_error) => match session_error { SessionError(InternalSessionError::DirectoryResponse(directory_error)) => if directory_error.is_fatal() { return MaybeFatalTransitionWithNoResults::fatal( - SessionEvent::SessionInvalid(e.to_string(), None), + SessionEvent::Closed(SessionOutcome::Failure), e, ); } else { @@ -390,13 +380,14 @@ impl Receiver { }, _ => return MaybeFatalTransitionWithNoResults::fatal( - SessionEvent::SessionInvalid(session_error.to_string(), None), + SessionEvent::Closed(SessionOutcome::Failure), e, ), }, + // Replyable errors should not close the session until a reply was attempted _ => return MaybeFatalTransitionWithNoResults::fatal( - SessionEvent::SessionInvalid(e.to_string(), None), + SessionEvent::HasReplyableError((&e).into()), e, ), }, @@ -566,10 +557,7 @@ impl Receiver { ), Err(Error::Implementation(e)) => MaybeFatalTransition::transient(Error::Implementation(e)), - Err(e) => MaybeFatalTransition::fatal( - SessionEvent::SessionInvalid(e.to_string(), Some(JsonReply::from(&e))), - e, - ), + Err(e) => MaybeFatalTransition::fatal(SessionEvent::HasReplyableError((&e).into()), e), } } @@ -635,7 +623,7 @@ impl Receiver { } _ => { return MaybeFatalTransition::fatal( - SessionEvent::SessionInvalid(e.to_string(), Some(JsonReply::from(&e))), + SessionEvent::HasReplyableError((&e).into()), e, ); } @@ -688,7 +676,7 @@ impl Receiver { } _ => { return MaybeFatalTransition::fatal( - SessionEvent::SessionInvalid(e.to_string(), Some(JsonReply::from(&e))), + SessionEvent::HasReplyableError((&e).into()), e, ); } @@ -746,7 +734,7 @@ impl Receiver { } _ => { return MaybeFatalTransition::fatal( - SessionEvent::SessionInvalid(e.to_string(), Some(JsonReply::from(&e))), + SessionEvent::HasReplyableError((&e).into()), e, ); } @@ -935,15 +923,7 @@ impl Receiver { { Ok(inner) => inner, Err(e) => { - // FIXME: follow up by returning a terminal error rather than replyable error - let payload_error = super::PayloadError::from(e); - return MaybeFatalTransition::fatal( - SessionEvent::SessionInvalid( - payload_error.to_string(), - Some(JsonReply::from(&payload_error)), - ), - ProtocolError::OriginalPayload(payload_error), - ); + return MaybeFatalTransition::transient(ProtocolError::OriginalPayload(e.into())); } }; MaybeFatalTransition::success( @@ -1071,23 +1051,79 @@ impl Receiver { self, res: &[u8], ohttp_context: ohttp::ClientResponse, - ) -> MaybeSuccessTransition { + ) -> MaybeSuccessTransition { match process_post_res(res, ohttp_context) { Ok(_) => MaybeSuccessTransition::success(SessionEvent::Closed(SessionOutcome::Success), ()), Err(e) => if e.is_fatal() { MaybeSuccessTransition::fatal( - SessionEvent::SessionInvalid(e.to_string(), None), - InternalSessionError::DirectoryResponse(e).into(), + SessionEvent::Closed(SessionOutcome::Failure), + ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()), ) } else { - MaybeSuccessTransition::transient( + MaybeSuccessTransition::transient(ProtocolError::V2( InternalSessionError::DirectoryResponse(e).into(), + )) + }, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct HasReplyableError { + error_reply: JsonReply, +} + +impl Receiver { + /// Construct an OHTTP Encapsulated HTTP POST request to return + /// a Receiver Error Response + pub fn create_error_request( + &self, + ohttp_relay: impl IntoUrl, + ) -> Result<(Request, ohttp::ClientResponse), SessionError> { + let session_context = &self.session_context; + if session_context.expiration.elapsed() { + return Err(InternalSessionError::Expired(session_context.expiration).into()); + } + let mailbox = + mailbox_endpoint(&session_context.directory, &session_context.reply_mailbox_id()); + let (body, ohttp_ctx) = ohttp_encapsulate( + &session_context.ohttp_keys.0, + "POST", + mailbox.as_str(), + Some(self.error_reply.to_json().to_string().as_bytes()), + ) + .map_err(InternalSessionError::OhttpEncapsulation)?; + let req = Request::new_v2(&session_context.full_relay_url(ohttp_relay)?, &body); + Ok((req, ohttp_ctx)) + } + + /// Process an OHTTP Encapsulated HTTP POST Error response + /// to ensure it has been posted properly + pub fn process_error_response( + &self, + res: &[u8], + ohttp_context: ohttp::ClientResponse, + ) -> MaybeSuccessTransition { + match process_post_res(res, ohttp_context) { + Ok(_) => + MaybeSuccessTransition::success(SessionEvent::Closed(SessionOutcome::Failure), ()), + Err(e) => + if e.is_fatal() { + MaybeSuccessTransition::fatal( + SessionEvent::Closed(SessionOutcome::Failure), + ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()), ) + } else { + MaybeSuccessTransition::transient(ProtocolError::V2( + InternalSessionError::DirectoryResponse(e).into(), + )) }, } } + + pub fn error_reply(&self) -> &JsonReply { &self.error_reply } } /// Derive a mailbox endpoint on a directory given a [`ShortId`]. @@ -1135,7 +1171,7 @@ pub mod test { use super::*; use crate::output_substitution::OutputSubstitution; - use crate::persist::{NoopSessionPersister, RejectTransient, Rejection}; + use crate::persist::{NoopSessionPersister, RejectFatal, RejectTransient, Rejection}; use crate::receive::optional_parameters::Params; use crate::receive::v2; use crate::ImplementationError; @@ -1174,7 +1210,7 @@ pub mod test { } } - pub(crate) fn mock_err() -> (String, JsonReply) { + pub(crate) fn mock_err() -> JsonReply { let noop_persister = NoopSessionPersister::default(); let receiver = Receiver { state: unchecked_proposal_v2_from_test_vector(), @@ -1189,8 +1225,7 @@ pub mod test { let error = server_error().expect_err("Server error should be populated with mock error"); let res = error.api_error().expect("check_broadcast error should propagate to api error"); - let actual_json = JsonReply::from(&res); - (res.to_string(), actual_json) + JsonReply::from(&res) } #[test] @@ -1248,6 +1283,41 @@ pub mod test { Ok(()) } + #[test] + fn test_unchecked_proposal_fatal_error() -> Result<(), BoxError> { + let unchecked_proposal = unchecked_proposal_v2_from_test_vector(); + let receiver = + v2::Receiver { state: unchecked_proposal, session_context: SHARED_CONTEXT.clone() }; + + let receive_session = ReceiveSession::UncheckedOriginalPayload(receiver.clone()); + let unchecked_proposal = + receiver.check_broadcast_suitability(Some(FeeRate::MIN), |_| Ok(false)); + + let event = match &unchecked_proposal { + MaybeFatalTransition(Err(Rejection::Fatal(RejectFatal( + event, + Error::Protocol(error), + )))) => { + assert_eq!( + error.to_string(), + InternalPayloadError::OriginalPsbtNotBroadcastable.to_string() + ); + event.clone() + } + _ => panic!("Expected fatal error"), + }; + + let has_error = match receive_session.process_event(event) { + Ok(ReceiveSession::HasReplyableError(r)) => r, + _ => panic!("Expected HasError"), + }; + + let _err_req = has_error.create_error_request(EXAMPLE_URL.as_str())?; + // TODO: assert process_error_response terminally closes session + + Ok(()) + } + #[test] fn test_maybe_inputs_seen_transient_error() -> Result<(), BoxError> { let persister = NoopSessionPersister::default(); @@ -1343,55 +1413,35 @@ pub mod test { } #[test] - fn test_extract_err_req() -> Result<(), BoxError> { - let receiver = Receiver { - state: unchecked_proposal_v2_from_test_vector(), - session_context: SHARED_CONTEXT.clone(), - }; + fn test_create_error_request() -> Result<(), BoxError> { let mock_err = mock_err(); let expected_json = serde_json::json!({ "errorCode": "unavailable", "message": "Receiver error" }); - assert_eq!(mock_err.1.to_json(), expected_json); + assert_eq!(mock_err.to_json(), expected_json); - let (_req, _ctx) = - extract_err_req(&mock_err.1, EXAMPLE_URL.as_str(), &receiver.session_context)?; + let receiver = Receiver { + state: HasReplyableError { error_reply: mock_err.clone() }, + session_context: SHARED_CONTEXT.clone(), + }; + + let (_req, _ctx) = receiver.create_error_request(EXAMPLE_URL.as_str())?; - let internal_error: Error = InternalPayloadError::MissingPayment.into(); - let (_req, _ctx) = extract_err_req( - &(&internal_error).into(), - EXAMPLE_URL.as_str(), - &receiver.session_context, - )?; Ok(()) } #[test] - fn test_extract_err_req_expiration() -> Result<(), BoxError> { + fn test_create_error_request_expiration() -> Result<(), BoxError> { let now = crate::time::Time::now(); - let noop_persister = NoopSessionPersister::default(); let context = SessionContext { expiration: now, ..SHARED_CONTEXT.clone() }; let receiver = Receiver { - state: UncheckedOriginalPayload { - original: crate::receive::tests::original_from_test_vector(), - }, + state: HasReplyableError { error_reply: mock_err() }, session_context: context.clone(), }; - let server_error = || { - receiver - .clone() - .check_broadcast_suitability(None, |_| Err("mock error".into())) - .save(&noop_persister) - }; - - let error = server_error().expect_err("Server error should be populated with mock error"); - let res = error.api_error().expect("check_broadcast error should propagate to api error"); - let actual_json = JsonReply::from(&res); - - let expiration = extract_err_req(&actual_json, EXAMPLE_URL.as_str(), &context); + let expiration = receiver.create_error_request(EXAMPLE_URL.as_str()); match expiration { Err(error) => assert_eq!( diff --git a/payjoin/src/core/receive/v2/session.rs b/payjoin/src/core/receive/v2/session.rs index 8e8de81df..1df493373 100644 --- a/payjoin/src/core/receive/v2/session.rs +++ b/payjoin/src/core/receive/v2/session.rs @@ -3,10 +3,9 @@ use serde::{Deserialize, Serialize}; use super::{ReceiveSession, SessionContext}; use crate::error::{InternalReplayError, ReplayError}; use crate::output_substitution::OutputSubstitution; -use crate::persist::SessionPersister; -use crate::receive::v2::{extract_err_req, SessionError}; +use crate::persist::{SessionEventTrait, SessionPersister}; use crate::receive::{common, InputPair, JsonReply, OriginalPayload, PsbtContext}; -use crate::{ImplementationError, IntoUrl, PjUri, Request}; +use crate::{ImplementationError, PjUri}; /// Replay a receiver event log to get the receiver in its current state [ReceiveSession] /// and a session history [SessionHistory] @@ -104,43 +103,6 @@ impl SessionHistory { }) } - /// Terminal error from the session if present - pub fn terminal_error(&self) -> Option<(String, Option)> { - self.events.iter().find_map(|event| match event { - SessionEvent::SessionInvalid(err_str, reply) => Some((err_str.clone(), reply.clone())), - _ => None, - }) - } - - /// Construct the error request to be posted on the directory if an error occurred. - /// To process the response, use [crate::receive::v2::process_err_res] - pub fn extract_err_req( - &self, - ohttp_relay: impl IntoUrl, - ) -> Result, SessionError> { - // FIXME ideally this should be more like a method of - // Receiver and subsequent states instead of the - // history as a whole since it doesn't make sense to call it before, - // reaching that state. - if !self.received_sender_proposal() { - return Ok(None); - } - - let session_context = self.session_context(); - let json_reply = match self.terminal_error() { - Some((_, Some(json_reply))) => json_reply, - _ => return Ok(None), - }; - let (req, ctx) = extract_err_req(&json_reply, ohttp_relay, &session_context)?; - Ok(Some((req, ctx))) - } - - fn received_sender_proposal(&self) -> bool { - self.events - .iter() - .any(|event| matches!(event, SessionEvent::UncheckedOriginalPayload { .. })) - } - fn session_context(&self) -> SessionContext { let mut initial_session_context = self .events @@ -189,10 +151,7 @@ pub enum SessionStatus { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub enum SessionEvent { Created(SessionContext), - UncheckedOriginalPayload { - original: OriginalPayload, - reply_key: Option, - }, + UncheckedOriginalPayload { original: OriginalPayload, reply_key: Option }, MaybeInputsOwned(), MaybeInputsSeen(), OutputsUnknown(), @@ -201,11 +160,7 @@ pub enum SessionEvent { WantsFeeRange(Vec), ProvisionalProposal(PsbtContext), PayjoinProposal(bitcoin::Psbt), - /// Session is invalid. This is a irrecoverable error. Fallback tx should be broadcasted. - /// TODO this should be any error type that is impl std::error and works well with serde, or as a fallback can be formatted as a string - /// Reason being in some cases we still want to preserve the error b/c we can action on it. For now this is a terminal state and there is nothing to replay and is saved to be displayed. - /// b/c its a terminal state and there is nothing to replay. So serialization will be lossy and that is fine. - SessionInvalid(String, Option), + HasReplyableError(JsonReply), Closed(SessionOutcome), } @@ -222,6 +177,12 @@ pub enum SessionOutcome { Cancel, } +impl crate::persist::sealed::Sealed for SessionEvent {} + +impl SessionEventTrait for SessionEvent { + fn is_closed_event(&self) -> bool { matches!(self, SessionEvent::Closed(_)) } +} + #[cfg(test)] mod tests { use std::time::{Duration, SystemTime}; @@ -234,9 +195,10 @@ mod tests { use crate::receive::tests::original_from_test_vector; use crate::receive::v2::test::{mock_err, SHARED_CONTEXT}; use crate::receive::v2::{ - Initialized, MaybeInputsOwned, PayjoinProposal, ProvisionalProposal, Receiver, - UncheckedOriginalPayload, + HasReplyableError, Initialized, MaybeInputsOwned, PayjoinProposal, ProvisionalProposal, + Receiver, UncheckedOriginalPayload, }; + use crate::receive::{InternalPayloadError, PayloadError}; fn unchecked_receiver_from_test_vector() -> Receiver { Receiver { @@ -301,6 +263,7 @@ mod tests { SessionEvent::WantsFeeRange(wants_fee_range.state.inner.receiver_inputs.clone()), SessionEvent::ProvisionalProposal(provisional_proposal.state.psbt_context.clone()), SessionEvent::PayjoinProposal(payjoin_proposal.psbt().clone()), + SessionEvent::HasReplyableError(mock_err()), ]; for event in test_cases { @@ -613,87 +576,90 @@ mod tests { } #[test] - fn test_session_history_uri() -> Result<(), BoxError> { + fn test_session_fatal_error() -> Result<(), BoxError> { + let persister = NoopSessionPersister::::default(); let session_context = SHARED_CONTEXT.clone(); - let events = vec![SessionEvent::Created(session_context.clone())]; + let mut events = vec![]; - let uri = SessionHistory { events }.pj_uri(); + let original = original_from_test_vector(); + // Original PSBT is not broadcastable + let _unbroadcastable = unchecked_receiver_from_test_vector() + .check_broadcast_suitability(None, |_| Ok(false)) + .save(&persister) + .expect_err("Unbroadcastable should error"); + // NOTE: it would be good to assert against the internal error type but InternalPersistedError is private + let expected_error = PayloadError(InternalPayloadError::OriginalPsbtNotBroadcastable); + let reply_key = Some(crate::HpkeKeyPair::gen_keypair().1); - assert_ne!(uri.extras.pj_param.endpoint(), EXAMPLE_URL.clone()); - assert_eq!(uri.extras.output_substitution, OutputSubstitution::Disabled); + events.push(SessionEvent::Created(session_context.clone())); + events.push(SessionEvent::UncheckedOriginalPayload { + original: original.clone(), + reply_key: reply_key.clone(), + }); + events.push(SessionEvent::HasReplyableError((&expected_error).into())); + events.push(SessionEvent::Closed(SessionOutcome::Failure)); - Ok(()) + let test = SessionHistoryTest { + events, + expected_session_history: SessionHistoryExpectedOutcome { + psbt_with_fee_contributions: None, + fallback_tx: None, + expected_status: SessionStatus::Failed, + }, + expected_receiver_state: ReceiveSession::HasReplyableError(Receiver { + state: HasReplyableError { error_reply: (&expected_error).into() }, + session_context: SessionContext { reply_key, ..session_context }, + }), + }; + run_session_history_test(test) } #[test] - fn test_skipped_session_extract_err_request() -> Result<(), BoxError> { - let ohttp_relay = EXAMPLE_URL.as_str(); - let mock_err = mock_err(); + fn test_session_transient_error() -> Result<(), BoxError> { + let persister = NoopSessionPersister::::default(); + let session_context = SHARED_CONTEXT.clone(); + let mut events = vec![]; - let session_history = SessionHistory { events: vec![SessionEvent::MaybeInputsOwned()] }; - let err_req = session_history.extract_err_req(ohttp_relay)?; - assert!(err_req.is_none()); + let original = original_from_test_vector(); + // Mock some implementation error + let _maybe_broadcastable = unchecked_receiver_from_test_vector() + .check_broadcast_suitability(None, |_| Err("mock error".into())) + .save(&persister) + .expect_err("Mock error should error"); + // NOTE: it would be good to assert against the internal error type but InternalPersistedError is private - let session_history = SessionHistory { - events: vec![ - SessionEvent::MaybeInputsOwned(), - SessionEvent::SessionInvalid(mock_err.0.clone(), Some(mock_err.1.clone())), - ], - }; + let reply_key = Some(crate::HpkeKeyPair::gen_keypair().1); - let err_req = session_history.extract_err_req(ohttp_relay)?; - assert!(err_req.is_none()); + events.push(SessionEvent::Created(session_context.clone())); + events.push(SessionEvent::UncheckedOriginalPayload { + original: original.clone(), + reply_key: reply_key.clone(), + }); - let session_history = SessionHistory { - events: vec![ - SessionEvent::Created(SHARED_CONTEXT.clone()), - SessionEvent::MaybeInputsOwned(), - SessionEvent::SessionInvalid(mock_err.0.clone(), Some(mock_err.1.clone())), - ], + let test = SessionHistoryTest { + events, + expected_session_history: SessionHistoryExpectedOutcome { + psbt_with_fee_contributions: None, + fallback_tx: None, + expected_status: SessionStatus::Active, + }, + expected_receiver_state: ReceiveSession::UncheckedOriginalPayload(Receiver { + state: UncheckedOriginalPayload { original }, + session_context: SessionContext { reply_key, ..session_context }, + }), }; - - let err_req = session_history.extract_err_req(ohttp_relay)?; - assert!(err_req.is_none()); - Ok(()) + run_session_history_test(test) } #[test] - fn test_session_extract_err_req_reply_key() -> Result<(), BoxError> { - let proposal = original_from_test_vector(); - let ohttp_relay = EXAMPLE_URL.as_str(); - let mock_err = mock_err(); - - let session_history_one = SessionHistory { - events: vec![ - SessionEvent::Created(SHARED_CONTEXT.clone()), - SessionEvent::UncheckedOriginalPayload { - original: proposal.clone(), - reply_key: Some(crate::HpkeKeyPair::gen_keypair().1), - }, - SessionEvent::SessionInvalid(mock_err.0.clone(), Some(mock_err.1.clone())), - ], - }; - - let err_req_one = session_history_one.extract_err_req(ohttp_relay)?; - assert!(err_req_one.is_some()); + fn test_session_history_uri() -> Result<(), BoxError> { + let session_context = SHARED_CONTEXT.clone(); + let events = vec![SessionEvent::Created(session_context.clone())]; - let session_history_two = SessionHistory { - events: vec![ - SessionEvent::Created(SHARED_CONTEXT.clone()), - SessionEvent::UncheckedOriginalPayload { - original: proposal.clone(), - reply_key: Some(crate::HpkeKeyPair::gen_keypair().1), - }, - SessionEvent::SessionInvalid(mock_err.0, Some(mock_err.1)), - ], - }; + let uri = SessionHistory { events }.pj_uri(); - let err_req_two = session_history_two.extract_err_req(ohttp_relay)?; - assert!(err_req_two.is_some()); - assert_ne!( - session_history_one.session_context().reply_key, - session_history_two.session_context().reply_key - ); + assert_ne!(uri.extras.pj_param.endpoint(), EXAMPLE_URL.clone()); + assert_eq!(uri.extras.output_substitution, OutputSubstitution::Disabled); Ok(()) } diff --git a/payjoin/src/core/send/v2/session.rs b/payjoin/src/core/send/v2/session.rs index 0834e8471..845e31204 100644 --- a/payjoin/src/core/send/v2/session.rs +++ b/payjoin/src/core/send/v2/session.rs @@ -1,6 +1,6 @@ use super::WithReplyKey; use crate::error::{InternalReplayError, ReplayError}; -use crate::persist::SessionPersister; +use crate::persist::{SessionEventTrait, SessionPersister}; use crate::send::v2::{PollingForProposal, SendSession}; use crate::uri::v2::PjParam; use crate::ImplementationError; @@ -77,13 +77,6 @@ impl SessionHistory { }) .expect("Session event log must contain at least one event with pj_param") } - - pub fn terminal_error(&self) -> Option { - self.events.iter().find_map(|event| match event { - SessionEvent::SessionInvalid(error) => Some(error.clone()), - _ => None, - }) - } } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -98,6 +91,16 @@ pub enum SessionEvent { SessionInvalid(String), } +impl crate::persist::sealed::Sealed for SessionEvent {} + +impl SessionEventTrait for SessionEvent { + fn is_closed_event(&self) -> bool { + // Sender doesn't have a Closed variant yet, so always return true for now. + // This maintains current behavior where all fatal events close the session. + true + } +} + #[cfg(test)] mod tests { use bitcoin::{FeeRate, ScriptBuf}; @@ -180,7 +183,6 @@ mod tests { events: Vec, expected_session_history: SessionHistoryExpectedOutcome, expected_sender_state: SendSession, - expected_error: Option, } fn run_session_history_test(test: SessionHistoryTest) { @@ -194,7 +196,6 @@ mod tests { assert_eq!(sender, test.expected_sender_state); assert_eq!(session_history.fallback_tx(), test.expected_session_history.fallback_tx); assert_eq!(*session_history.pj_param(), test.expected_session_history.pj_param); - assert_eq!(session_history.terminal_error(), test.expected_error); } #[test] @@ -268,7 +269,7 @@ mod tests { crate::OhttpKeys( ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).expect("valid key config"), ), - HpkeKeyPair::gen_keypair().1, + reply_key.1, ); let with_reply_key = WithReplyKey { pj_param: pj_param.clone(), @@ -280,7 +281,6 @@ mod tests { events: vec![SessionEvent::CreatedReplyKey(with_reply_key)], expected_session_history: SessionHistoryExpectedOutcome { fallback_tx, pj_param }, expected_sender_state: SendSession::WithReplyKey(sender), - expected_error: None, }; run_session_history_test(test); } diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index b3936536a..9bd6c868d 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -198,8 +198,8 @@ mod integration { use http::StatusCode; use payjoin::persist::{NoopSessionPersister, OptionalTransitionOutcome}; use payjoin::receive::v2::{ - replay_event_log as replay_receiver_event_log, PayjoinProposal, Receiver, - ReceiverBuilder, UncheckedOriginalPayload, + replay_event_log as replay_receiver_event_log, PayjoinProposal, ReceiveSession, + Receiver, ReceiverBuilder, SessionStatus, UncheckedOriginalPayload, }; use payjoin::send::v2::SenderBuilder; use payjoin::{OhttpKeys, PjUri, UriExt}; @@ -317,12 +317,12 @@ mod integration { let result = tokio::select!( err = services.take_ohttp_relay_handle() => panic!("Ohttp relay exited early: {:?}", err), err = services.take_directory_handle() => panic!("Directory server exited early: {:?}", err), - res = process_err_res(&services) => res + res = do_err_test(&services) => res ); assert!(result.is_ok(), "v2 send receive failed: {:#?}", result.unwrap_err()); - async fn process_err_res(services: &TestServices) -> Result<(), BoxError> { + async fn do_err_test(services: &TestServices) -> Result<(), BoxError> { let (_bitcoind, sender, receiver) = init_bitcoind_sender_receiver(None, None)?; let agent = services.http_agent(); services.wait_for_services_ready().await?; @@ -419,10 +419,14 @@ mod integration { "Protocol error: Can't broadcast. PSBT rejected by mempool." ); - let (_, session_history) = replay_receiver_event_log(&persister)?; - let (err_req, err_ctx) = session_history - .extract_err_req(services.ohttp_relay_url().as_str())? - .expect("error request should exist"); + let (session, session_history) = replay_receiver_event_log(&persister)?; + assert_eq!(session_history.status(), SessionStatus::Active); + let has_error = match session { + ReceiveSession::HasReplyableError(r) => r, + _ => panic!("Expected HasError"), + }; + let (err_req, err_ctx) = + has_error.create_error_request(services.ohttp_relay_url().as_str())?; let err_response = agent .post(err_req.url) .header("Content-Type", err_req.content_type) @@ -431,8 +435,16 @@ mod integration { .await?; let err_bytes = err_response.bytes().await?; - // Ensure that the error was handled properly - assert!(payjoin::receive::v2::process_err_res(&err_bytes, err_ctx).is_ok()); + has_error.process_error_response(&err_bytes, err_ctx).save(&persister)?; + + // Ensure the session is closed properly + let (receiver_state, session_history) = replay_receiver_event_log(&persister)?; + assert_eq!(session_history.status(), SessionStatus::Failed); + match receiver_state { + ReceiveSession::HasReplyableError(e) => assert_eq!(e, has_error), + _ => panic!("Expected HasReplyableError"), + }; + // TODO: Sender should retrieve the error response to complete the error flow Ok(()) }