From 1c0233cf831227ee62719fe3f514347e1716f060 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Mon, 16 Mar 2026 08:13:07 +0100 Subject: [PATCH 01/14] Refactor each session state into its own module --- crates/hotfix/src/session.rs | 22 +- crates/hotfix/src/session/state.rs | 221 +++--------------- crates/hotfix/src/session/state/active.rs | 19 ++ .../src/session/state/awaiting_logon.rs | 8 + .../src/session/state/awaiting_logout.rs | 8 + .../src/session/state/awaiting_resend.rs | 121 ++++++++++ .../hotfix/src/session/state/disconnected.rs | 35 +++ 7 files changed, 237 insertions(+), 197 deletions(-) create mode 100644 crates/hotfix/src/session/state/active.rs create mode 100644 crates/hotfix/src/session/state/awaiting_logon.rs create mode 100644 crates/hotfix/src/session/state/awaiting_logout.rs create mode 100644 crates/hotfix/src/session/state/awaiting_resend.rs create mode 100644 crates/hotfix/src/session/state/disconnected.rs diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 3926bc74..f7ed64df 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -45,7 +45,9 @@ pub(crate) use crate::session::session_ref::InternalSessionRef; pub use crate::session::session_ref::InternalSessionRef; use crate::session::session_ref::OutboundRequest; use crate::session::state::SessionState; -use crate::session::state::{AwaitingResendTransitionOutcome, TestRequestId}; +use crate::session::state::{ + AwaitingLogonState, AwaitingLogoutState, AwaitingResendTransitionOutcome, TestRequestId, +}; use crate::session_schedule::{SessionPeriodComparison, SessionSchedule}; use crate::store::MessageStore; use crate::transport::writer::WriterRef; @@ -200,7 +202,7 @@ where } } - if let SessionState::AwaitingLogon { .. } = &mut self.state { + if let SessionState::AwaitingLogon(_) = &mut self.state { // TODO: should this (and all inbound message processing) logic be pushed into the state? if message_type != Logon::MSG_TYPE { self.state.disconnect_writer().await; @@ -332,11 +334,11 @@ where } async fn on_connect(&mut self, writer: WriterRef) -> Result<(), SessionOperationError> { - self.state = SessionState::AwaitingLogon { + self.state = SessionState::AwaitingLogon(AwaitingLogonState { writer, logon_sent: false, logon_timeout: Instant::now() + Duration::from_secs(self.config.logon_timeout), - }; + }); self.reset_peer_timer(None); self.send_logon().await?; @@ -345,23 +347,23 @@ where async fn on_disconnect(&mut self, reason: String) { match self.state { - SessionState::Active { .. } - | SessionState::AwaitingLogon { .. } + SessionState::Active(_) + | SessionState::AwaitingLogon(_) | SessionState::AwaitingResend(_) => { self.state.disconnect_writer().await; self.state = SessionState::new_disconnected(true, &reason); } - SessionState::Disconnected { .. } => { + SessionState::Disconnected(_) => { warn!("disconnect message was received, but the session is already disconnected") } - SessionState::AwaitingLogout { reconnect, .. } => { + SessionState::AwaitingLogout(AwaitingLogoutState { reconnect, .. }) => { self.state = SessionState::new_disconnected(reconnect, &reason); } } } async fn on_logon(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let SessionState::AwaitingLogon { writer, .. } = &self.state { + if let SessionState::AwaitingLogon(AwaitingLogonState { writer, .. }) = &self.state { match self.verify_message(message, true, true) { Ok(_) => { // happy logon flow, the session is now active @@ -395,7 +397,7 @@ where // if the session is already disconnected, we have nothing else to do SessionState::Disconnected(..) => {} // if we initiated the logout, preserve the reconnect flag - SessionState::AwaitingLogout { reconnect, .. } => { + SessionState::AwaitingLogout(AwaitingLogoutState { reconnect, .. }) => { self.state.disconnect_writer().await; self.state = SessionState::new_disconnected(reconnect, "logout completed"); } diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index fa84472d..87abd5fa 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -1,36 +1,37 @@ +mod active; +mod awaiting_logon; +mod awaiting_logout; +mod awaiting_resend; +mod disconnected; + +pub(crate) use active::{ActiveState, calculate_peer_interval}; +pub(crate) use awaiting_logon::AwaitingLogonState; +pub(crate) use awaiting_logout::AwaitingLogoutState; +pub(crate) use awaiting_resend::{AwaitingResendState, AwaitingResendTransitionOutcome}; +pub(crate) use disconnected::DisconnectedState; + use crate::message::logon::Logon; use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; use crate::session::event::AwaitingActiveSessionResponse; use crate::session::info::Status as SessionInfoStatus; use crate::transport::writer::WriterRef; -use hotfix_message::message::Message; -use std::collections::VecDeque; use std::time::Duration; use tokio::sync::oneshot; use tokio::time::Instant; use tracing::{debug, error}; const TEST_REQUEST_THRESHOLD: f64 = 1.2; -const MAX_RESEND_ATTEMPTS: usize = 3; pub(crate) type TestRequestId = String; pub enum SessionState { /// We have established a connection, sent a logon message and await a response. - AwaitingLogon { - writer: WriterRef, - logon_sent: bool, - logon_timeout: Instant, - }, + AwaitingLogon(AwaitingLogonState), /// We are awaiting the target to resend the gap we have. AwaitingResend(AwaitingResendState), /// We are in the process of gracefully logging out - AwaitingLogout { - writer: WriterRef, // we need the writer so we can disconnect it on successful logout - logout_timeout: Instant, - reconnect: bool, // we carry this forward for the subsequent disconnected state - }, + AwaitingLogout(AwaitingLogoutState), /// The session is active, we have connected and mutually logged on. Active(ActiveState), /// The TCP connection has been dropped. @@ -73,9 +74,9 @@ impl SessionState { writer.send_raw_message(message).await } } - Self::AwaitingLogon { + Self::AwaitingLogon(AwaitingLogonState { writer, logon_sent, .. - } => match message_type { + }) => match message_type { Logon::MSG_TYPE => { if *logon_sent { error!("trying to send logon twice"); @@ -89,7 +90,7 @@ impl SessionState { } _ => error!("invalid outgoing message for AwaitingLogon state"), }, - Self::AwaitingLogout { writer, .. } => { + Self::AwaitingLogout(AwaitingLogoutState { writer, .. }) => { // Logout messages are allowed because we first transition into AwaitingLogout // and only then send the logout message if message_type == Logout::MSG_TYPE { @@ -103,8 +104,8 @@ impl SessionState { pub async fn disconnect_writer(&self) { match self { Self::Active(ActiveState { writer, .. }) - | Self::AwaitingLogon { writer, .. } - | Self::AwaitingLogout { writer, .. } + | Self::AwaitingLogon(AwaitingLogonState { writer, .. }) + | Self::AwaitingLogout(AwaitingLogoutState { writer, .. }) | Self::AwaitingResend(AwaitingResendState { writer, .. }) => writer.disconnect().await, _ => debug!("disconnecting an already disconnected session"), } @@ -113,8 +114,8 @@ impl SessionState { fn get_writer(&self) -> Option<&WriterRef> { match self { Self::Active(ActiveState { writer, .. }) - | Self::AwaitingLogon { writer, .. } - | Self::AwaitingLogout { writer, .. } + | Self::AwaitingLogon(AwaitingLogonState { writer, .. }) + | Self::AwaitingLogout(AwaitingLogoutState { writer, .. }) | Self::AwaitingResend(AwaitingResendState { writer, .. }) => Some(writer), _ => None, } @@ -125,17 +126,17 @@ impl SessionState { logout_timeout: Duration, reconnect: bool, ) -> bool { - if matches!(self, SessionState::AwaitingLogout { .. }) { + if matches!(self, SessionState::AwaitingLogout(_)) { debug!("already in awaiting logout state"); return false; } if let Some(writer) = self.get_writer() { - *self = SessionState::AwaitingLogout { + *self = SessionState::AwaitingLogout(AwaitingLogoutState { writer: writer.clone(), logout_timeout: Instant::now() + logout_timeout, reconnect, - }; + }); true } else { error!("trying to transition to awaiting logout without an established connection"); @@ -149,14 +150,14 @@ impl SessionState { end: u64, ) -> AwaitingResendTransitionOutcome { match self { - SessionState::AwaitingLogon { writer, .. } + SessionState::AwaitingLogon(AwaitingLogonState { writer, .. }) | SessionState::Active(ActiveState { writer, .. }) => { let awaiting_resend = AwaitingResendState::new(writer.to_owned(), begin, end); *self = SessionState::AwaitingResend(awaiting_resend); AwaitingResendTransitionOutcome::Success } SessionState::AwaitingResend(state) => state.update(begin, end), - SessionState::AwaitingLogout { .. } => AwaitingResendTransitionOutcome::InvalidState( + SessionState::AwaitingLogout(_) => AwaitingResendTransitionOutcome::InvalidState( "trying to request a resend while we are already logging out".to_string(), ), SessionState::Disconnected(_) => AwaitingResendTransitionOutcome::InvalidState( @@ -227,8 +228,10 @@ impl SessionState { pub fn peer_deadline(&self) -> Option<&Instant> { match self { Self::Active(ActiveState { peer_deadline, .. }) => Some(peer_deadline), - Self::AwaitingLogon { logon_timeout, .. } => Some(logon_timeout), - Self::AwaitingLogout { logout_timeout, .. } => Some(logout_timeout), + Self::AwaitingLogon(AwaitingLogonState { logon_timeout, .. }) => Some(logon_timeout), + Self::AwaitingLogout(AwaitingLogoutState { + logout_timeout, .. + }) => Some(logout_timeout), _ => None, } } @@ -274,16 +277,16 @@ impl SessionState { } pub fn is_awaiting_logon(&self) -> bool { - matches!(self, SessionState::AwaitingLogon { .. }) + matches!(self, SessionState::AwaitingLogon(_)) } pub fn is_awaiting_logout(&self) -> bool { - matches!(self, SessionState::AwaitingLogout { .. }) + matches!(self, SessionState::AwaitingLogout(_)) } pub fn as_status(&self) -> SessionInfoStatus { match self { - SessionState::AwaitingLogon { .. } => SessionInfoStatus::AwaitingLogon, + SessionState::AwaitingLogon(_) => SessionInfoStatus::AwaitingLogon, SessionState::AwaitingResend(AwaitingResendState { begin_seq_number, end_seq_number, @@ -294,165 +297,9 @@ impl SessionState { end: *end_seq_number, attempts: *resend_attempts, }, - SessionState::AwaitingLogout { .. } => SessionInfoStatus::AwaitingLogout, + SessionState::AwaitingLogout(_) => SessionInfoStatus::AwaitingLogout, SessionState::Active(_) => SessionInfoStatus::Active, SessionState::Disconnected(_) => SessionInfoStatus::Disconnected, } } } - -#[inline] -fn calculate_peer_interval(heartbeat_interval: u64) -> u64 { - (heartbeat_interval as f64 * TEST_REQUEST_THRESHOLD).round() as u64 -} - -pub struct ActiveState { - /// The writer's reference to send messages to the counterparty - writer: WriterRef, - /// When we should send the next heartbeat message to the counterparty - heartbeat_deadline: Instant, - /// When the next message from the counterparty is expected at the latest - peer_deadline: Instant, - /// The ID of the test request we sent on peer timer expiry - sent_test_request_id: Option, -} - -/// Session state we're in while processing messages we requested to be resent. -pub struct AwaitingResendState { - /// The reference to the writer loop. - pub(crate) writer: WriterRef, - /// The beginning of the gap we're waiting for the target to resend. - pub(crate) begin_seq_number: u64, - /// The end of the gap we're waiting for the target to resend. - pub(crate) end_seq_number: u64, - /// Inbound messages we receive while processing the resend. - pub(crate) inbound_queue: VecDeque, - /// The number of times we've attempted to ask the counterparty to resend the gap. - pub(crate) resend_attempts: usize, -} - -impl AwaitingResendState { - fn new(writer: WriterRef, begin_seq_number: u64, end_seq_number: u64) -> Self { - Self { - writer, - begin_seq_number, - end_seq_number, - inbound_queue: Default::default(), - resend_attempts: 1, - } - } - - fn update( - &mut self, - begin_seq_number: u64, - end_seq_number: u64, - ) -> AwaitingResendTransitionOutcome { - let resend_attempts = if self.begin_seq_number == begin_seq_number { - if self.resend_attempts + 1 > MAX_RESEND_ATTEMPTS { - return AwaitingResendTransitionOutcome::AttemptsExceeded; - } - self.resend_attempts + 1 - } else if begin_seq_number < self.begin_seq_number { - return AwaitingResendTransitionOutcome::BeginSeqNumberTooLow; - } else { - 1 - }; - - self.resend_attempts = resend_attempts; - self.begin_seq_number = begin_seq_number; - self.end_seq_number = end_seq_number; - - AwaitingResendTransitionOutcome::Success - } -} - -pub struct DisconnectedState { - reconnect: bool, - session_awaiter: Option>, - reason: String, -} - -impl DisconnectedState { - fn new(reconnect: bool, reason: &str) -> Self { - Self { - reconnect, - session_awaiter: None, - reason: reason.to_string(), - } - } - - fn set_session_awaiter(&mut self, responder: oneshot::Sender) { - self.session_awaiter = Some(responder); - } - - fn has_session_awaiter(&self) -> bool { - self.session_awaiter.is_some() - } - - fn take_session_awaiter(&mut self) -> Option> { - self.session_awaiter.take() - } -} - -pub enum AwaitingResendTransitionOutcome { - Success, - InvalidState(String), - BeginSeqNumberTooLow, - AttemptsExceeded, -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::sync::mpsc; - - #[test] - fn test_awaiting_resend_transition_begin_seq_number_too_low() { - let writer = create_writer_ref(); - let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); - let result = state.try_transition_to_awaiting_resend(0, 5); - assert!(matches!( - result, - AwaitingResendTransitionOutcome::BeginSeqNumberTooLow - )); - } - - #[test] - fn test_awaiting_resend_transition_attempts_exceeded() { - let writer = create_writer_ref(); - let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); - - // we can transition twice more without hitting the limit - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); - - // the fourth time we'd get into an AwaitingResendState with the same begin seq number, we get an error - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!( - result, - AwaitingResendTransitionOutcome::AttemptsExceeded - )); - } - - #[test] - fn test_awaiting_resend_transition_when_awaiting_logout_is_prevented() { - let mut state = SessionState::AwaitingLogout { - writer: create_writer_ref(), - logout_timeout: Instant::now(), - reconnect: false, - }; - - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!( - result, - AwaitingResendTransitionOutcome::InvalidState(_) - )); - } - - fn create_writer_ref() -> WriterRef { - let (sender, _) = mpsc::channel(10); - WriterRef::new(sender) - } -} diff --git a/crates/hotfix/src/session/state/active.rs b/crates/hotfix/src/session/state/active.rs new file mode 100644 index 00000000..b0828bbe --- /dev/null +++ b/crates/hotfix/src/session/state/active.rs @@ -0,0 +1,19 @@ +use crate::session::state::TestRequestId; +use crate::transport::writer::WriterRef; +use tokio::time::Instant; + +pub(crate) struct ActiveState { + /// The writer's reference to send messages to the counterparty + pub(crate) writer: WriterRef, + /// When we should send the next heartbeat message to the counterparty + pub(crate) heartbeat_deadline: Instant, + /// When the next message from the counterparty is expected at the latest + pub(crate) peer_deadline: Instant, + /// The ID of the test request we sent on peer timer expiry + pub(crate) sent_test_request_id: Option, +} + +#[inline] +pub(crate) fn calculate_peer_interval(heartbeat_interval: u64) -> u64 { + (heartbeat_interval as f64 * super::TEST_REQUEST_THRESHOLD).round() as u64 +} diff --git a/crates/hotfix/src/session/state/awaiting_logon.rs b/crates/hotfix/src/session/state/awaiting_logon.rs new file mode 100644 index 00000000..534bb4ff --- /dev/null +++ b/crates/hotfix/src/session/state/awaiting_logon.rs @@ -0,0 +1,8 @@ +use crate::transport::writer::WriterRef; +use tokio::time::Instant; + +pub(crate) struct AwaitingLogonState { + pub(crate) writer: WriterRef, + pub(crate) logon_sent: bool, + pub(crate) logon_timeout: Instant, +} diff --git a/crates/hotfix/src/session/state/awaiting_logout.rs b/crates/hotfix/src/session/state/awaiting_logout.rs new file mode 100644 index 00000000..cef2f457 --- /dev/null +++ b/crates/hotfix/src/session/state/awaiting_logout.rs @@ -0,0 +1,8 @@ +use crate::transport::writer::WriterRef; +use tokio::time::Instant; + +pub(crate) struct AwaitingLogoutState { + pub(crate) writer: WriterRef, + pub(crate) logout_timeout: Instant, + pub(crate) reconnect: bool, +} diff --git a/crates/hotfix/src/session/state/awaiting_resend.rs b/crates/hotfix/src/session/state/awaiting_resend.rs new file mode 100644 index 00000000..ede8a2ed --- /dev/null +++ b/crates/hotfix/src/session/state/awaiting_resend.rs @@ -0,0 +1,121 @@ +use crate::transport::writer::WriterRef; +use hotfix_message::message::Message; +use std::collections::VecDeque; + +const MAX_RESEND_ATTEMPTS: usize = 3; + +/// Session state we're in while processing messages we requested to be resent. +pub(crate) struct AwaitingResendState { + /// The reference to the writer loop. + pub(crate) writer: WriterRef, + /// The beginning of the gap we're waiting for the target to resend. + pub(crate) begin_seq_number: u64, + /// The end of the gap we're waiting for the target to resend. + pub(crate) end_seq_number: u64, + /// Inbound messages we receive while processing the resend. + pub(crate) inbound_queue: VecDeque, + /// The number of times we've attempted to ask the counterparty to resend the gap. + pub(crate) resend_attempts: usize, +} + +impl AwaitingResendState { + pub(crate) fn new(writer: WriterRef, begin_seq_number: u64, end_seq_number: u64) -> Self { + Self { + writer, + begin_seq_number, + end_seq_number, + inbound_queue: Default::default(), + resend_attempts: 1, + } + } + + pub(crate) fn update( + &mut self, + begin_seq_number: u64, + end_seq_number: u64, + ) -> AwaitingResendTransitionOutcome { + let resend_attempts = if self.begin_seq_number == begin_seq_number { + if self.resend_attempts + 1 > MAX_RESEND_ATTEMPTS { + return AwaitingResendTransitionOutcome::AttemptsExceeded; + } + self.resend_attempts + 1 + } else if begin_seq_number < self.begin_seq_number { + return AwaitingResendTransitionOutcome::BeginSeqNumberTooLow; + } else { + 1 + }; + + self.resend_attempts = resend_attempts; + self.begin_seq_number = begin_seq_number; + self.end_seq_number = end_seq_number; + + AwaitingResendTransitionOutcome::Success + } +} + +pub(crate) enum AwaitingResendTransitionOutcome { + Success, + InvalidState(String), + BeginSeqNumberTooLow, + AttemptsExceeded, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::session::state::SessionState; + use tokio::sync::mpsc; + use tokio::time::Instant; + + #[test] + fn test_awaiting_resend_transition_begin_seq_number_too_low() { + let writer = create_writer_ref(); + let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); + let result = state.try_transition_to_awaiting_resend(0, 5); + assert!(matches!( + result, + AwaitingResendTransitionOutcome::BeginSeqNumberTooLow + )); + } + + #[test] + fn test_awaiting_resend_transition_attempts_exceeded() { + let writer = create_writer_ref(); + let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); + + // we can transition twice more without hitting the limit + let result = state.try_transition_to_awaiting_resend(1, 5); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + let result = state.try_transition_to_awaiting_resend(1, 5); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + + // the fourth time we'd get into an AwaitingResendState with the same begin seq number, we get an error + let result = state.try_transition_to_awaiting_resend(1, 5); + assert!(matches!( + result, + AwaitingResendTransitionOutcome::AttemptsExceeded + )); + } + + #[test] + fn test_awaiting_resend_transition_when_awaiting_logout_is_prevented() { + use crate::session::state::AwaitingLogoutState; + + let mut state = SessionState::AwaitingLogout(AwaitingLogoutState { + writer: create_writer_ref(), + logout_timeout: Instant::now(), + reconnect: false, + }); + + let result = state.try_transition_to_awaiting_resend(1, 5); + assert!(matches!( + result, + AwaitingResendTransitionOutcome::InvalidState(_) + )); + } + + fn create_writer_ref() -> WriterRef { + let (sender, _) = mpsc::channel(10); + WriterRef::new(sender) + } +} diff --git a/crates/hotfix/src/session/state/disconnected.rs b/crates/hotfix/src/session/state/disconnected.rs new file mode 100644 index 00000000..54a55e4b --- /dev/null +++ b/crates/hotfix/src/session/state/disconnected.rs @@ -0,0 +1,35 @@ +use crate::session::event::AwaitingActiveSessionResponse; +use tokio::sync::oneshot; + +pub(crate) struct DisconnectedState { + pub(crate) reconnect: bool, + session_awaiter: Option>, + pub(crate) reason: String, +} + +impl DisconnectedState { + pub(crate) fn new(reconnect: bool, reason: &str) -> Self { + Self { + reconnect, + session_awaiter: None, + reason: reason.to_string(), + } + } + + pub(crate) fn set_session_awaiter( + &mut self, + responder: oneshot::Sender, + ) { + self.session_awaiter = Some(responder); + } + + pub(crate) fn has_session_awaiter(&self) -> bool { + self.session_awaiter.is_some() + } + + pub(crate) fn take_session_awaiter( + &mut self, + ) -> Option> { + self.session_awaiter.take() + } +} From ec931c101884fbf6d7dffb658043709342c4d2f0 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 09:03:18 +0100 Subject: [PATCH 02/14] Move further logic into state variants --- crates/hotfix/src/session/state.rs | 52 +++++-------------- crates/hotfix/src/session/state/active.rs | 29 +++++++++++ .../hotfix/src/session/state/disconnected.rs | 27 ++++++++++ 3 files changed, 70 insertions(+), 38 deletions(-) diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index 87abd5fa..4d9ca6a0 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -59,7 +59,7 @@ impl SessionState { pub fn should_reconnect(&self) -> bool { match self { - SessionState::Disconnected(DisconnectedState { reconnect, .. }) => *reconnect, + SessionState::Disconnected(state) => state.should_reconnect(), _ => true, } } @@ -173,7 +173,7 @@ impl SessionState { ) { match self { SessionState::Disconnected(state) => { - if state.has_session_awaiter() { + if let Err(responder) = state.register_session_awaiter(responder) { let reason = &state.reason; error!( "session awaiter already registered on state disconnected due to: {reason}" @@ -181,9 +181,6 @@ impl SessionState { if let Err(err) = responder.send(AwaitingActiveSessionResponse::Shutdown) { error!("failed to send session awaiter response: {err:?}"); } - } else { - state.set_session_awaiter(responder); - debug!("registered session awaiter"); } } _ => { @@ -196,42 +193,31 @@ impl SessionState { } pub fn notify_session_awaiter(&mut self) { - if let SessionState::Disconnected(state) = self - && let Some(awaiter) = state.take_session_awaiter() - { - if let Err(err) = awaiter.send(AwaitingActiveSessionResponse::Active) { - error!("failed to send session awaiter response: {err:?}"); - } else { - debug!("notified session awaiter"); - } + if let SessionState::Disconnected(state) = self { + state.notify_session_awaiter(); } } pub fn heartbeat_deadline(&self) -> Option<&Instant> { match self { - Self::Active(ActiveState { - heartbeat_deadline, .. - }) => Some(heartbeat_deadline), + Self::Active(state) => Some(state.heartbeat_deadline()), _ => None, } } pub fn reset_heartbeat_timer(&mut self, heartbeat_interval: u64) { - if let Self::Active(ActiveState { - heartbeat_deadline, .. - }) = self - { - *heartbeat_deadline = Instant::now() + Duration::from_secs(heartbeat_interval); + if let Self::Active(state) = self { + state.reset_heartbeat_timer(heartbeat_interval); } } pub fn peer_deadline(&self) -> Option<&Instant> { match self { - Self::Active(ActiveState { peer_deadline, .. }) => Some(peer_deadline), + Self::Active(state) => Some(state.peer_deadline()), Self::AwaitingLogon(AwaitingLogonState { logon_timeout, .. }) => Some(logon_timeout), - Self::AwaitingLogout(AwaitingLogoutState { - logout_timeout, .. - }) => Some(logout_timeout), + Self::AwaitingLogout(AwaitingLogoutState { logout_timeout, .. }) => { + Some(logout_timeout) + } _ => None, } } @@ -241,24 +227,14 @@ impl SessionState { heartbeat_interval: u64, test_request_id: Option, ) { - if let Self::Active(ActiveState { - peer_deadline, - sent_test_request_id, - .. - }) = self - { - let interval = calculate_peer_interval(heartbeat_interval); - *peer_deadline = Instant::now() + Duration::from_secs(interval); - *sent_test_request_id = test_request_id; + if let Self::Active(state) = self { + state.reset_peer_timer(heartbeat_interval, test_request_id); } } pub fn expected_test_response_id(&self) -> Option<&TestRequestId> { match self { - Self::Active(ActiveState { - sent_test_request_id: expected_test_response_id, - .. - }) => expected_test_response_id.as_ref(), + Self::Active(state) => state.expected_test_response_id(), _ => None, } } diff --git a/crates/hotfix/src/session/state/active.rs b/crates/hotfix/src/session/state/active.rs index b0828bbe..e3151bf2 100644 --- a/crates/hotfix/src/session/state/active.rs +++ b/crates/hotfix/src/session/state/active.rs @@ -1,5 +1,6 @@ use crate::session::state::TestRequestId; use crate::transport::writer::WriterRef; +use std::time::Duration; use tokio::time::Instant; pub(crate) struct ActiveState { @@ -13,6 +14,34 @@ pub(crate) struct ActiveState { pub(crate) sent_test_request_id: Option, } +impl ActiveState { + pub(crate) fn heartbeat_deadline(&self) -> &Instant { + &self.heartbeat_deadline + } + + pub(crate) fn reset_heartbeat_timer(&mut self, heartbeat_interval: u64) { + self.heartbeat_deadline = Instant::now() + Duration::from_secs(heartbeat_interval); + } + + pub(crate) fn peer_deadline(&self) -> &Instant { + &self.peer_deadline + } + + pub(crate) fn reset_peer_timer( + &mut self, + heartbeat_interval: u64, + test_request_id: Option, + ) { + let interval = calculate_peer_interval(heartbeat_interval); + self.peer_deadline = Instant::now() + Duration::from_secs(interval); + self.sent_test_request_id = test_request_id; + } + + pub(crate) fn expected_test_response_id(&self) -> Option<&TestRequestId> { + self.sent_test_request_id.as_ref() + } +} + #[inline] pub(crate) fn calculate_peer_interval(heartbeat_interval: u64) -> u64 { (heartbeat_interval as f64 * super::TEST_REQUEST_THRESHOLD).round() as u64 diff --git a/crates/hotfix/src/session/state/disconnected.rs b/crates/hotfix/src/session/state/disconnected.rs index 54a55e4b..a8a1c1a9 100644 --- a/crates/hotfix/src/session/state/disconnected.rs +++ b/crates/hotfix/src/session/state/disconnected.rs @@ -1,5 +1,6 @@ use crate::session::event::AwaitingActiveSessionResponse; use tokio::sync::oneshot; +use tracing::{debug, error}; pub(crate) struct DisconnectedState { pub(crate) reconnect: bool, @@ -32,4 +33,30 @@ impl DisconnectedState { ) -> Option> { self.session_awaiter.take() } + + pub(crate) fn should_reconnect(&self) -> bool { + self.reconnect + } + + pub(crate) fn register_session_awaiter( + &mut self, + responder: oneshot::Sender, + ) -> Result<(), oneshot::Sender> { + if self.has_session_awaiter() { + Err(responder) + } else { + self.set_session_awaiter(responder); + Ok(()) + } + } + + pub(crate) fn notify_session_awaiter(&mut self) { + if let Some(awaiter) = self.take_session_awaiter() { + if let Err(err) = awaiter.send(AwaitingActiveSessionResponse::Active) { + error!("failed to send session awaiter response: {err:?}"); + } else { + debug!("notified session awaiter"); + } + } + } } From 92b16e5dc730f0fc9f94bcd79e9cfc1a3209cf82 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 09:22:36 +0100 Subject: [PATCH 03/14] Proof-of-concept to move heartbeat handler into session state --- crates/hotfix/src/session.rs | 14 ++++-- crates/hotfix/src/session/state.rs | 55 +++++++++++++++++++++++ crates/hotfix/src/session/state/active.rs | 20 ++++++++- 3 files changed, 85 insertions(+), 4 deletions(-) diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index f7ed64df..ac3b0412 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -46,7 +46,8 @@ pub use crate::session::session_ref::InternalSessionRef; use crate::session::session_ref::OutboundRequest; use crate::session::state::SessionState; use crate::session::state::{ - AwaitingLogonState, AwaitingLogoutState, AwaitingResendTransitionOutcome, TestRequestId, + AwaitingLogonState, AwaitingLogoutState, AwaitingResendTransitionOutcome, SessionCtx, + TestRequestId, }; use crate::session_schedule::{SessionPeriodComparison, SessionSchedule}; use crate::store::MessageStore; @@ -1088,8 +1089,15 @@ where } async fn handle_heartbeat_timeout(&mut self) { - if let Err(err) = self.send_message(Heartbeat::default()).await { - error!(err = ?err, "failed to send heartbeat message"); + let Session { + ref mut state, + ref mut store, + ref config, + .. + } = *self; + if let SessionState::Active(active) = state { + let mut ctx = SessionCtx { config, store }; + active.on_heartbeat_timeout(&mut ctx).await; } } diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index 4d9ca6a0..f0eeed63 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -10,17 +10,72 @@ pub(crate) use awaiting_logout::AwaitingLogoutState; pub(crate) use awaiting_resend::{AwaitingResendState, AwaitingResendTransitionOutcome}; pub(crate) use disconnected::DisconnectedState; +use crate::config::SessionConfig; use crate::message::logon::Logon; use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; +use crate::message::{OutboundMessage, generate_message}; +use crate::session::error::InternalSendError; use crate::session::event::AwaitingActiveSessionResponse; use crate::session::info::Status as SessionInfoStatus; +use crate::store::StoreError; use crate::transport::writer::WriterRef; +use hotfix_store::MessageStore; use std::time::Duration; use tokio::sync::oneshot; use tokio::time::Instant; use tracing::{debug, error}; +pub(crate) struct SessionCtx<'a, Store> { + pub config: &'a SessionConfig, + pub store: &'a mut Store, +} + +#[allow(dead_code)] // fields used in later sub-phases +pub(crate) struct PreparedMessage { + pub seq_num: u64, + pub msg_type: String, + pub raw: RawFixMessage, +} + +impl SessionCtx<'_, Store> { + pub async fn prepare_message( + &mut self, + message: impl OutboundMessage, + ) -> Result { + let seq_num = self.store.next_sender_seq_number(); + let msg_type = message.message_type().to_string(); + let msg = generate_message( + &self.config.begin_string, + &self.config.sender_comp_id, + &self.config.target_comp_id, + seq_num, + message, + ) + .map_err(|e| { + InternalSendError::Persist(StoreError::PersistMessage { + sequence_number: seq_num, + source: e.into(), + }) + })?; + + self.store + .increment_sender_seq_number() + .await + .map_err(InternalSendError::SequenceNumber)?; + self.store + .add(seq_num, &msg) + .await + .map_err(InternalSendError::Persist)?; + + Ok(PreparedMessage { + seq_num, + msg_type, + raw: RawFixMessage::new(msg), + }) + } +} + const TEST_REQUEST_THRESHOLD: f64 = 1.2; pub(crate) type TestRequestId = String; diff --git a/crates/hotfix/src/session/state/active.rs b/crates/hotfix/src/session/state/active.rs index e3151bf2..55eeb1cb 100644 --- a/crates/hotfix/src/session/state/active.rs +++ b/crates/hotfix/src/session/state/active.rs @@ -1,7 +1,10 @@ -use crate::session::state::TestRequestId; +use crate::message::heartbeat::Heartbeat; +use crate::session::state::{SessionCtx, TestRequestId}; use crate::transport::writer::WriterRef; +use hotfix_store::MessageStore; use std::time::Duration; use tokio::time::Instant; +use tracing::error; pub(crate) struct ActiveState { /// The writer's reference to send messages to the counterparty @@ -40,6 +43,21 @@ impl ActiveState { pub(crate) fn expected_test_response_id(&self) -> Option<&TestRequestId> { self.sent_test_request_id.as_ref() } + + pub(crate) async fn on_heartbeat_timeout( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + ) { + let prepared = match ctx.prepare_message(Heartbeat::default()).await { + Ok(prepared) => prepared, + Err(err) => { + error!(err = ?err, "failed to send heartbeat message"); + return; + } + }; + self.writer.send_raw_message(prepared.raw).await; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + } } #[inline] From e46905c66fae1aa1af6eb3927603a18c2c5aa9d6 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 09:32:35 +0100 Subject: [PATCH 04/14] Move is_awaiting_logon and is_awaiting_logout logic to session state --- crates/hotfix/src/session.rs | 39 +++++++++++-------- crates/hotfix/src/session/state.rs | 8 ---- crates/hotfix/src/session/state/active.rs | 36 ++++++++++++++++- .../src/session/state/awaiting_logon.rs | 8 ++++ .../src/session/state/awaiting_logout.rs | 9 +++++ 5 files changed, 73 insertions(+), 27 deletions(-) diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index ac3b0412..0ecdd726 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -1102,23 +1102,28 @@ where } async fn handle_peer_timeout(&mut self) { - if self.state.is_expecting_test_response() { - warn!("peer didn't respond, terminating.."); - self.logout_and_terminate("peer timeout").await; - } else if self.state.is_awaiting_logon() { - warn!("peer didn't respond to our Logon, disconnecting.."); - self.state.disconnect_writer().await; - } else if self.state.is_awaiting_logout() { - warn!("peer didn't respond to our Logout, disconnecting.."); - self.state.disconnect_writer().await; - } else { - let req_id = format!("TEST_{}", self.store.next_target_seq_number()); - info!("sending TestRequest due to peer timer expiring"); - let request = TestRequest::new(req_id.clone()); - if let Err(err) = self.send_message(request).await { - error!(err = ?err, "failed to send TestRequest"); + let Session { + ref mut state, + ref mut store, + ref config, + .. + } = *self; + let transition = match state { + SessionState::Active(active) => { + let mut ctx = SessionCtx { config, store }; + active.on_peer_timeout(&mut ctx).await + } + SessionState::AwaitingLogon(awaiting_logon) => { + awaiting_logon.on_peer_timeout().await; + None } - self.reset_peer_timer(Some(req_id)); + SessionState::AwaitingLogout(awaiting_logout) => { + Some(awaiting_logout.on_peer_timeout().await) + } + _ => None, + }; + if let Some(new_state) = transition { + self.state = new_state; } } @@ -1616,7 +1621,7 @@ mod tests { // State should be AwaitingLogout (graceful logout initiated) assert!( - session.state.is_awaiting_logout(), + matches!(session.state, SessionState::AwaitingLogout(_)), "State should be AwaitingLogout when schedule is inactive and was connected" ); } diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index f0eeed63..d14aa2a4 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -307,14 +307,6 @@ impl SessionState { self.expected_test_response_id().is_some() } - pub fn is_awaiting_logon(&self) -> bool { - matches!(self, SessionState::AwaitingLogon(_)) - } - - pub fn is_awaiting_logout(&self) -> bool { - matches!(self, SessionState::AwaitingLogout(_)) - } - pub fn as_status(&self) -> SessionInfoStatus { match self { SessionState::AwaitingLogon(_) => SessionInfoStatus::AwaitingLogon, diff --git a/crates/hotfix/src/session/state/active.rs b/crates/hotfix/src/session/state/active.rs index 55eeb1cb..982c1f9f 100644 --- a/crates/hotfix/src/session/state/active.rs +++ b/crates/hotfix/src/session/state/active.rs @@ -1,10 +1,12 @@ use crate::message::heartbeat::Heartbeat; -use crate::session::state::{SessionCtx, TestRequestId}; +use crate::message::logout::Logout; +use crate::message::test_request::TestRequest; +use crate::session::state::{SessionCtx, SessionState, TestRequestId}; use crate::transport::writer::WriterRef; use hotfix_store::MessageStore; use std::time::Duration; use tokio::time::Instant; -use tracing::error; +use tracing::{error, info, warn}; pub(crate) struct ActiveState { /// The writer's reference to send messages to the counterparty @@ -44,6 +46,36 @@ impl ActiveState { self.sent_test_request_id.as_ref() } + pub(crate) async fn on_peer_timeout( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + ) -> Option { + if self.sent_test_request_id.is_some() { + warn!("peer didn't respond, terminating.."); + let logout = Logout::with_reason("peer timeout".to_string()); + if let Ok(prepared) = ctx.prepare_message(logout).await { + self.writer.send_raw_message(prepared.raw).await; + } + self.writer.disconnect().await; + return Some(SessionState::new_disconnected(true, "peer timeout")); + } + + let req_id = format!("TEST_{}", ctx.store.next_target_seq_number()); + info!("sending TestRequest due to peer timer expiring"); + let request = TestRequest::new(req_id.clone()); + match ctx.prepare_message(request).await { + Ok(prepared) => { + self.writer.send_raw_message(prepared.raw).await; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + } + Err(err) => { + error!(err = ?err, "failed to send TestRequest"); + } + } + self.reset_peer_timer(ctx.config.heartbeat_interval, Some(req_id)); + None + } + pub(crate) async fn on_heartbeat_timeout( &mut self, ctx: &mut SessionCtx<'_, Store>, diff --git a/crates/hotfix/src/session/state/awaiting_logon.rs b/crates/hotfix/src/session/state/awaiting_logon.rs index 534bb4ff..67285d0e 100644 --- a/crates/hotfix/src/session/state/awaiting_logon.rs +++ b/crates/hotfix/src/session/state/awaiting_logon.rs @@ -1,8 +1,16 @@ use crate::transport::writer::WriterRef; use tokio::time::Instant; +use tracing::warn; pub(crate) struct AwaitingLogonState { pub(crate) writer: WriterRef, pub(crate) logon_sent: bool, pub(crate) logon_timeout: Instant, } + +impl AwaitingLogonState { + pub(crate) async fn on_peer_timeout(&self) { + warn!("peer didn't respond to our Logon, disconnecting.."); + self.writer.disconnect().await; + } +} diff --git a/crates/hotfix/src/session/state/awaiting_logout.rs b/crates/hotfix/src/session/state/awaiting_logout.rs index cef2f457..ed820cbe 100644 --- a/crates/hotfix/src/session/state/awaiting_logout.rs +++ b/crates/hotfix/src/session/state/awaiting_logout.rs @@ -1,8 +1,17 @@ use crate::transport::writer::WriterRef; use tokio::time::Instant; +use tracing::warn; pub(crate) struct AwaitingLogoutState { pub(crate) writer: WriterRef, pub(crate) logout_timeout: Instant, pub(crate) reconnect: bool, } + +impl AwaitingLogoutState { + pub(crate) async fn on_peer_timeout(&self) -> super::SessionState { + warn!("peer didn't respond to our Logout, disconnecting.."); + self.writer.disconnect().await; + super::SessionState::new_disconnected(self.reconnect, "logout timeout") + } +} From 7fcbb24be75d75f68898f86ccf74a164b95f19e1 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 09:41:17 +0100 Subject: [PATCH 05/14] Move on_disconnect handling to state variants --- crates/hotfix/src/session.rs | 29 +++++++++---------- crates/hotfix/src/session/state/active.rs | 5 ++++ .../src/session/state/awaiting_logon.rs | 5 ++++ .../src/session/state/awaiting_logout.rs | 4 +++ .../src/session/state/awaiting_resend.rs | 5 ++++ .../hotfix/src/session/state/disconnected.rs | 16 ++++++++++ 6 files changed, 48 insertions(+), 16 deletions(-) diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 0ecdd726..80662388 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -335,11 +335,9 @@ where } async fn on_connect(&mut self, writer: WriterRef) -> Result<(), SessionOperationError> { - self.state = SessionState::AwaitingLogon(AwaitingLogonState { - writer, - logon_sent: false, - logon_timeout: Instant::now() + Duration::from_secs(self.config.logon_timeout), - }); + if let SessionState::Disconnected(s) = &self.state { + self.state = s.on_connect(writer, Duration::from_secs(self.config.logon_timeout)); + } self.reset_peer_timer(None); self.send_logon().await?; @@ -347,19 +345,18 @@ where } async fn on_disconnect(&mut self, reason: String) { - match self.state { - SessionState::Active(_) - | SessionState::AwaitingLogon(_) - | SessionState::AwaitingResend(_) => { - self.state.disconnect_writer().await; - self.state = SessionState::new_disconnected(true, &reason); - } + let transition = match &self.state { + SessionState::Active(s) => Some(s.on_disconnect(&reason).await), + SessionState::AwaitingLogon(s) => Some(s.on_disconnect(&reason).await), + SessionState::AwaitingResend(s) => Some(s.on_disconnect(&reason).await), + SessionState::AwaitingLogout(s) => Some(s.on_disconnect(&reason)), SessionState::Disconnected(_) => { - warn!("disconnect message was received, but the session is already disconnected") - } - SessionState::AwaitingLogout(AwaitingLogoutState { reconnect, .. }) => { - self.state = SessionState::new_disconnected(reconnect, &reason); + warn!("disconnect message was received, but the session is already disconnected"); + None } + }; + if let Some(new_state) = transition { + self.state = new_state; } } diff --git a/crates/hotfix/src/session/state/active.rs b/crates/hotfix/src/session/state/active.rs index 982c1f9f..aa506d2a 100644 --- a/crates/hotfix/src/session/state/active.rs +++ b/crates/hotfix/src/session/state/active.rs @@ -46,6 +46,11 @@ impl ActiveState { self.sent_test_request_id.as_ref() } + pub(crate) async fn on_disconnect(&self, reason: &str) -> SessionState { + self.writer.disconnect().await; + SessionState::new_disconnected(true, reason) + } + pub(crate) async fn on_peer_timeout( &mut self, ctx: &mut SessionCtx<'_, Store>, diff --git a/crates/hotfix/src/session/state/awaiting_logon.rs b/crates/hotfix/src/session/state/awaiting_logon.rs index 67285d0e..9fe1aa3d 100644 --- a/crates/hotfix/src/session/state/awaiting_logon.rs +++ b/crates/hotfix/src/session/state/awaiting_logon.rs @@ -9,6 +9,11 @@ pub(crate) struct AwaitingLogonState { } impl AwaitingLogonState { + pub(crate) async fn on_disconnect(&self, reason: &str) -> super::SessionState { + self.writer.disconnect().await; + super::SessionState::new_disconnected(true, reason) + } + pub(crate) async fn on_peer_timeout(&self) { warn!("peer didn't respond to our Logon, disconnecting.."); self.writer.disconnect().await; diff --git a/crates/hotfix/src/session/state/awaiting_logout.rs b/crates/hotfix/src/session/state/awaiting_logout.rs index ed820cbe..a0e3b90f 100644 --- a/crates/hotfix/src/session/state/awaiting_logout.rs +++ b/crates/hotfix/src/session/state/awaiting_logout.rs @@ -9,6 +9,10 @@ pub(crate) struct AwaitingLogoutState { } impl AwaitingLogoutState { + pub(crate) fn on_disconnect(&self, reason: &str) -> super::SessionState { + super::SessionState::new_disconnected(self.reconnect, reason) + } + pub(crate) async fn on_peer_timeout(&self) -> super::SessionState { warn!("peer didn't respond to our Logout, disconnecting.."); self.writer.disconnect().await; diff --git a/crates/hotfix/src/session/state/awaiting_resend.rs b/crates/hotfix/src/session/state/awaiting_resend.rs index ede8a2ed..3e5afd9c 100644 --- a/crates/hotfix/src/session/state/awaiting_resend.rs +++ b/crates/hotfix/src/session/state/awaiting_resend.rs @@ -19,6 +19,11 @@ pub(crate) struct AwaitingResendState { } impl AwaitingResendState { + pub(crate) async fn on_disconnect(&self, reason: &str) -> super::SessionState { + self.writer.disconnect().await; + super::SessionState::new_disconnected(true, reason) + } + pub(crate) fn new(writer: WriterRef, begin_seq_number: u64, end_seq_number: u64) -> Self { Self { writer, diff --git a/crates/hotfix/src/session/state/disconnected.rs b/crates/hotfix/src/session/state/disconnected.rs index a8a1c1a9..c96e2be0 100644 --- a/crates/hotfix/src/session/state/disconnected.rs +++ b/crates/hotfix/src/session/state/disconnected.rs @@ -1,5 +1,9 @@ use crate::session::event::AwaitingActiveSessionResponse; +use crate::session::state::AwaitingLogonState; +use crate::transport::writer::WriterRef; +use std::time::Duration; use tokio::sync::oneshot; +use tokio::time::Instant; use tracing::{debug, error}; pub(crate) struct DisconnectedState { @@ -34,6 +38,18 @@ impl DisconnectedState { self.session_awaiter.take() } + pub(crate) fn on_connect( + &self, + writer: WriterRef, + logon_timeout: Duration, + ) -> super::SessionState { + super::SessionState::AwaitingLogon(AwaitingLogonState { + writer, + logon_sent: false, + logon_timeout: Instant::now() + logon_timeout, + }) + } + pub(crate) fn should_reconnect(&self) -> bool { self.reconnect } From 47d1a94c212c871fb8cb7b281cbb6cc7143f48ff Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 11:06:32 +0100 Subject: [PATCH 06/14] Move verification methods from session to session state --- crates/hotfix/src/session.rs | 353 +++++-------------------- crates/hotfix/src/session/state.rs | 405 ++++++++++++++++++++++++++++- 2 files changed, 466 insertions(+), 292 deletions(-) diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 80662388..4b5d9842 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -14,7 +14,7 @@ use std::pin::Pin; use tokio::select; use tokio::sync::mpsc; use tokio::time::{Duration, Instant, Sleep, sleep, sleep_until}; -use tracing::{debug, enabled, error, info, warn}; +use tracing::{debug, error, info, warn}; use crate::Application; use crate::application::{InboundDecision, OutboundDecision}; @@ -31,8 +31,7 @@ use crate::message::resend_request::ResendRequest; use crate::message::sequence_reset::SequenceReset; use crate::message::test_request::TestRequest; use crate::message::verification::verify_message; -use crate::message::verification_error::{CompIdType, MessageVerificationError}; -use crate::message::{is_admin, prepare_message_for_resend}; +use crate::message::verification_error::MessageVerificationError; use crate::session::admin_request::AdminRequest; use crate::session::error::SessionCreationError; use crate::session::error::{InternalSendError, InternalSendResultExt, SessionOperationError}; @@ -45,10 +44,7 @@ pub(crate) use crate::session::session_ref::InternalSessionRef; pub use crate::session::session_ref::InternalSessionRef; use crate::session::session_ref::OutboundRequest; use crate::session::state::SessionState; -use crate::session::state::{ - AwaitingLogonState, AwaitingLogoutState, AwaitingResendTransitionOutcome, SessionCtx, - TestRequestId, -}; +use crate::session::state::{AwaitingLogonState, AwaitingLogoutState, SessionCtx, TestRequestId}; use crate::session_schedule::{SessionPeriodComparison, SessionSchedule}; use crate::store::MessageStore; use crate::transport::writer::WriterRef; @@ -587,186 +583,42 @@ where &mut self, error: MessageVerificationError, ) -> Result<(), SessionOperationError> { - match error { - MessageVerificationError::SeqNumberTooLow { - expected, - actual, - possible_duplicate, - } => { - self.handle_sequence_number_too_low(expected, actual, possible_duplicate) - .await; - } - MessageVerificationError::SeqNumberTooHigh { expected, actual } => { - self.handle_sequence_number_too_high(expected, actual) - .await?; - } - MessageVerificationError::IncorrectBeginString(begin_string) => { - self.handle_incorrect_begin_string(begin_string).await; - } - MessageVerificationError::IncorrectCompId { - comp_id, - comp_id_type, - msg_seq_num, - } => { - self.handle_incorrect_comp_id(comp_id, comp_id_type, msg_seq_num) - .await; - } - MessageVerificationError::SendingTimeAccuracyIssue { msg_seq_num } => { - self.handle_sending_time_accuracy_problem(msg_seq_num, "unexpected sending time") - .await; - } - MessageVerificationError::SendingTimeMissing { msg_seq_num } => { - self.handle_sending_time_accuracy_problem(msg_seq_num, "sending time missing") - .await; - } - MessageVerificationError::OriginalSendingTimeMissing { msg_seq_num } => { - self.handle_original_sending_time_missing(msg_seq_num).await; - } - MessageVerificationError::OriginalSendingTimeAfterSendingTime { - msg_seq_num, .. - } => { - self.handle_sending_time_accuracy_problem( - msg_seq_num, - "original sending time is after sending time", - ) - .await; - } - } - - Ok(()) - } - - async fn handle_incorrect_begin_string(&mut self, received_begin_string: String) { - self.logout_and_terminate(&format!( - "beginString={received_begin_string} is not supported" - )) - .await; - } - - async fn handle_incorrect_comp_id( - &mut self, - received_comp_id: String, - comp_id_type: CompIdType, - msg_seq_num: u64, - ) { - error!( - "rejecting message with incorrect comp ID: {received_comp_id} (type: {comp_id_type:?})" - ); - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::ValueIsIncorrect) - .text(&format!("invalid comp ID {received_comp_id}")); - if let Err(err) = self.send_message(reject).await { - error!("failed to send reject message with invalid comp ID: {err}"); + let Session { + ref mut state, + ref mut store, + ref config, + ref message_builder, + ref message_config, + .. + } = *self; + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, }; - - self.logout_and_terminate("incorrect comp ID received") - .await; - } - - async fn handle_sequence_number_too_low( - &mut self, - expected: u64, - actual: u64, - possible_duplicate: bool, - ) { - if possible_duplicate { - warn!( - "sequence number too low (expected {expected}, actual {actual}, but counterparty indicated it's poss duplicate, ignoring" - ); - return; - } - error!( - "we expected {expected} sequence number, but target sent lower ({actual}), terminating..." - ); - let reason = format!("sequence number too low (actual {actual}, expected {expected})"); - self.logout_and_terminate(&reason).await; - self.state = SessionState::new_disconnected(false, &reason); - } - - async fn handle_sequence_number_too_high( - &mut self, - expected: u64, - actual: u64, - ) -> Result<(), SessionOperationError> { - match self - .state - .try_transition_to_awaiting_resend(expected, actual) - { - AwaitingResendTransitionOutcome::Success => { - debug!( - "we are behind target (ours: {expected}, theirs: {actual}), requesting resend." - ); - self.send_resend_request(expected, actual).await?; - } - AwaitingResendTransitionOutcome::InvalidState(reason) => { - error!("failed to request resend: {reason}"); - } - AwaitingResendTransitionOutcome::BeginSeqNumberTooLow => { - self.state.disconnect_writer().await; - self.state = SessionState::new_disconnected( - false, - "awaiting resend begin seq number unexpectedly lower than the previous resend request's", - ); - } - AwaitingResendTransitionOutcome::AttemptsExceeded => { - self.state.disconnect_writer().await; - self.state = SessionState::new_disconnected( - false, - "resend request attempts exceeded, manual intervention required", - ); - } + if let Some(new_state) = ctx.handle_verification_error(state, error).await? { + *state = new_state; } - Ok(()) } async fn handle_invalid_msg_type(&mut self, message: Message, msg_type: &str) { - match message.header().get(MSG_SEQ_NUM) { - Ok(msg_seq_num) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::InvalidMsgtype) - .text(&format!("invalid message type {msg_type}")); - if let Err(err) = self.send_message(reject).await { - error!("failed to send reject message for invalid msgtype: {err}"); - }; - - #[allow(clippy::collapsible_if)] - if let Ok(seq_num) = message.header().get::(MSG_SEQ_NUM) - && self.store.next_target_seq_number() == seq_num - { - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); - }; - } - } - Err(err) => { - error!("failed to get message seq num: {:?}", err); - } - } - } - - async fn handle_sending_time_accuracy_problem(&mut self, msg_seq_num: u64, text: &str) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::SendingtimeAccuracyProblem) - .text(text); - if let Err(err) = self.send_message(reject).await { - error!("failed to send reject for time accuracy problem: {err}"); - }; - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); - }; - } - - async fn handle_original_sending_time_missing(&mut self, msg_seq_num: u64) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::RequiredTagMissing) - .text("original sending time is required"); - if let Err(err) = self.send_message(reject).await { - error!("failed to send reject for time missing tag: {err}"); - }; - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); + let Session { + ref mut state, + ref mut store, + ref config, + ref message_builder, + ref message_config, + .. + } = *self; + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, }; + ctx.handle_invalid_msg_type(state, &message, msg_type).await; } async fn resend_messages( @@ -775,79 +627,27 @@ where end: u64, _message: &Message, ) -> Result<(), SessionOperationError> { - info!(begin, end, "resending messages as requested"); - let messages = self.store.get_slice(begin as usize, end as usize).await?; - - let no = messages.len(); - debug!(number_of_messages = no, "number of messages"); - - let mut reset_start: Option = None; - let mut sequence_number = 0; - - for msg in messages { - let mut message = self - .message_builder - .build(msg.as_slice()) - .into_message() - .ok_or_else(|| { - SessionOperationError::StoredMessageParse(format!( - "failed to build message for raw message: {msg:?}" - )) - })?; - sequence_number = get_msg_seq_num(&message); - let message_type: String = message - .header() - .get::<&str>(MSG_TYPE) - .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))? - .to_string(); - - if is_admin(&message_type) { - if reset_start.is_none() { - reset_start = Some(sequence_number); - } - continue; - } - - if let Some(begin) = reset_start { - let end = sequence_number; - Self::log_skipped_admin_messages(begin, end); - self.send_sequence_reset(begin, end).await?; - reset_start = None; - } - - if let Err(e) = prepare_message_for_resend(&mut message) { - error!( - error = e, - "failed to prepare message for resend, sending original" - ); - } - self.send_raw(&message_type, message.encode(&self.message_config)?) - .await; - - if enabled!(tracing::Level::DEBUG) - && let Ok(m) = String::from_utf8(msg.clone()) - { - debug!(sequence_number, message = m, "resent message"); - } - } - - if let Some(begin) = reset_start { - // the final reset if needed - let end = sequence_number; - Self::log_skipped_admin_messages(begin, end); - self.send_sequence_reset(begin, end).await?; + let writer = self.state.get_writer().cloned(); + if let Some(writer) = writer { + let Session { + ref mut store, + ref config, + ref message_builder, + ref message_config, + .. + } = *self; + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, + }; + ctx.resend_messages(&writer, begin, end).await?; + self.reset_heartbeat_timer(); } - Ok(()) } - fn log_skipped_admin_messages(begin: u64, end: u64) { - info!( - begin, - end, "skipped admin message(s) during resend, requesting reset for these" - ); - } - fn reset_heartbeat_timer(&mut self) { self.state .reset_heartbeat_timer(self.config.heartbeat_interval); @@ -922,41 +722,6 @@ where self.reset_heartbeat_timer(); } - async fn send_sequence_reset( - &mut self, - begin: u64, - end: u64, - ) -> Result<(), SessionOperationError> { - let sequence_reset = SequenceReset { - gap_fill: true, - new_seq_no: end, - }; - let raw_message = generate_message( - &self.config.begin_string, - &self.config.sender_comp_id, - &self.config.target_comp_id, - begin, - sequence_reset, - )?; - - self.send_raw(SequenceReset::MSG_TYPE, raw_message).await; - debug!(begin, end, "sent reset sequence"); - - Ok(()) - } - - async fn send_resend_request( - &mut self, - begin: u64, - end: u64, - ) -> Result<(), SessionOperationError> { - let request = ResendRequest::new(begin, end); - self.send_message(request) - .await - .with_send_context("resend request")?; - Ok(()) - } - async fn send_logon(&mut self) -> Result<(), SessionOperationError> { let reset_config = if self.config.reset_on_logon || self.reset_on_next_logon { self.store.reset().await?; @@ -1090,10 +855,17 @@ where ref mut state, ref mut store, ref config, + ref message_builder, + ref message_config, .. } = *self; if let SessionState::Active(active) = state { - let mut ctx = SessionCtx { config, store }; + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, + }; active.on_heartbeat_timeout(&mut ctx).await; } } @@ -1103,11 +875,18 @@ where ref mut state, ref mut store, ref config, + ref message_builder, + ref message_config, .. } = *self; let transition = match state { SessionState::Active(active) => { - let mut ctx = SessionCtx { config, store }; + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, + }; active.on_peer_timeout(&mut ctx).await } SessionState::AwaitingLogon(awaiting_logon) => { @@ -1198,7 +977,7 @@ where /// Panics if the message does not contain a valid MsgSeqNum field. /// This should never happen for messages that have passed validation. #[allow(clippy::expect_used)] -fn get_msg_seq_num(message: &Message) -> u64 { +pub(crate) fn get_msg_seq_num(message: &Message) -> u64 { message .header() .get(MSG_SEQ_NUM) diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index d14aa2a4..0e1fbfee 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -14,26 +14,39 @@ use crate::config::SessionConfig; use crate::message::logon::Logon; use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; -use crate::message::{OutboundMessage, generate_message}; -use crate::session::error::InternalSendError; +use crate::message::reject::Reject; +use crate::message::resend_request::ResendRequest; +use crate::message::sequence_reset::SequenceReset; +use crate::message::verification::verify_message as verify_message_impl; +use crate::message::verification_error::{CompIdType, MessageVerificationError}; +use crate::message::{OutboundMessage, generate_message, is_admin, prepare_message_for_resend}; +use crate::session::error::{InternalSendError, InternalSendResultExt, SessionOperationError}; use crate::session::event::AwaitingActiveSessionResponse; +use crate::session::get_msg_seq_num; use crate::session::info::Status as SessionInfoStatus; use crate::store::StoreError; use crate::transport::writer::WriterRef; +use hotfix_message::message::{Config as MessageConfig, Message}; +use hotfix_message::session_fields::SessionRejectReason; +use hotfix_message::{MessageBuilder, Part}; use hotfix_store::MessageStore; use std::time::Duration; use tokio::sync::oneshot; use tokio::time::Instant; -use tracing::{debug, error}; +use tracing::{debug, enabled, error, info, warn}; + +use hotfix_message::session_fields::{MSG_SEQ_NUM, MSG_TYPE}; pub(crate) struct SessionCtx<'a, Store> { pub config: &'a SessionConfig, pub store: &'a mut Store, + pub message_builder: &'a MessageBuilder, + pub message_config: &'a MessageConfig, } -#[allow(dead_code)] // fields used in later sub-phases pub(crate) struct PreparedMessage { pub seq_num: u64, + #[allow(dead_code)] // used in later sub-phases pub msg_type: String, pub raw: RawFixMessage, } @@ -74,6 +87,388 @@ impl SessionCtx<'_, Store> { raw: RawFixMessage::new(msg), }) } + + /// Prepare, persist, and send a message via the given writer. + pub async fn send_message( + &mut self, + writer: &WriterRef, + message: impl OutboundMessage, + ) -> Result { + let prepared = self.prepare_message(message).await?; + writer.send_raw_message(prepared.raw).await; + Ok(prepared.seq_num) + } + + #[allow(dead_code)] // used when states handle their own messages in 2e + pub fn verify_message( + &self, + message: &Message, + check_too_high: bool, + check_too_low: bool, + ) -> Result<(), MessageVerificationError> { + let expected_seq_number = if check_too_high || check_too_low { + Some(self.store.next_target_seq_number()) + } else { + None + }; + verify_message_impl( + message, + self.config, + expected_seq_number, + check_too_high, + check_too_low, + ) + } + + /// Handle a verification error. Returns `Some(new_state)` if a state transition is needed. + pub async fn handle_verification_error( + &mut self, + state: &mut SessionState, + error: MessageVerificationError, + ) -> Result, SessionOperationError> { + match error { + MessageVerificationError::SeqNumberTooLow { + expected, + actual, + possible_duplicate, + } => Ok(self + .handle_sequence_number_too_low(state, expected, actual, possible_duplicate) + .await), + MessageVerificationError::SeqNumberTooHigh { expected, actual } => { + self.handle_sequence_number_too_high(state, expected, actual) + .await + } + MessageVerificationError::IncorrectBeginString(begin_string) => Ok(Some( + self.handle_incorrect_begin_string(state, begin_string) + .await, + )), + MessageVerificationError::IncorrectCompId { + comp_id, + comp_id_type, + msg_seq_num, + } => Ok(Some( + self.handle_incorrect_comp_id(state, comp_id, comp_id_type, msg_seq_num) + .await, + )), + MessageVerificationError::SendingTimeAccuracyIssue { msg_seq_num } => { + self.handle_sending_time_accuracy_problem( + state, + msg_seq_num, + "unexpected sending time", + ) + .await; + Ok(None) + } + MessageVerificationError::SendingTimeMissing { msg_seq_num } => { + self.handle_sending_time_accuracy_problem( + state, + msg_seq_num, + "sending time missing", + ) + .await; + Ok(None) + } + MessageVerificationError::OriginalSendingTimeMissing { msg_seq_num } => { + self.handle_original_sending_time_missing(state, msg_seq_num) + .await; + Ok(None) + } + MessageVerificationError::OriginalSendingTimeAfterSendingTime { + msg_seq_num, .. + } => { + self.handle_sending_time_accuracy_problem( + state, + msg_seq_num, + "original sending time is after sending time", + ) + .await; + Ok(None) + } + } + } + + async fn handle_incorrect_begin_string( + &mut self, + state: &SessionState, + received_begin_string: String, + ) -> SessionState { + self.logout_and_terminate( + state, + &format!("beginString={received_begin_string} is not supported"), + ) + .await; + SessionState::new_disconnected(true, "incorrect begin string") + } + + async fn handle_incorrect_comp_id( + &mut self, + state: &SessionState, + received_comp_id: String, + comp_id_type: CompIdType, + msg_seq_num: u64, + ) -> SessionState { + error!( + "rejecting message with incorrect comp ID: {received_comp_id} (type: {comp_id_type:?})" + ); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::ValueIsIncorrect) + .text(&format!("invalid comp ID {received_comp_id}")); + if let Some(writer) = state.get_writer() + && let Err(err) = self.send_message(writer, reject).await + { + error!("failed to send reject message with invalid comp ID: {err}"); + } + self.logout_and_terminate(state, "incorrect comp ID received") + .await; + SessionState::new_disconnected(true, "incorrect comp ID") + } + + async fn handle_sequence_number_too_low( + &mut self, + state: &SessionState, + expected: u64, + actual: u64, + possible_duplicate: bool, + ) -> Option { + if possible_duplicate { + warn!( + "sequence number too low (expected {expected}, actual {actual}, but counterparty indicated it's poss duplicate, ignoring" + ); + return None; + } + error!( + "we expected {expected} sequence number, but target sent lower ({actual}), terminating..." + ); + let reason = format!("sequence number too low (actual {actual}, expected {expected})"); + self.logout_and_terminate(state, &reason).await; + Some(SessionState::new_disconnected(false, &reason)) + } + + async fn handle_sequence_number_too_high( + &mut self, + state: &mut SessionState, + expected: u64, + actual: u64, + ) -> Result, SessionOperationError> { + match state.try_transition_to_awaiting_resend(expected, actual) { + AwaitingResendTransitionOutcome::Success => { + debug!( + "we are behind target (ours: {expected}, theirs: {actual}), requesting resend." + ); + if let Some(writer) = state.get_writer() { + let request = ResendRequest::new(expected, actual); + self.send_message(writer, request) + .await + .with_send_context("resend request")?; + } + Ok(None) // state already mutated by try_transition_to_awaiting_resend + } + AwaitingResendTransitionOutcome::InvalidState(reason) => { + error!("failed to request resend: {reason}"); + Ok(None) + } + AwaitingResendTransitionOutcome::BeginSeqNumberTooLow => { + state.disconnect_writer().await; + Ok(Some(SessionState::new_disconnected( + false, + "awaiting resend begin seq number unexpectedly lower than the previous resend request's", + ))) + } + AwaitingResendTransitionOutcome::AttemptsExceeded => { + state.disconnect_writer().await; + Ok(Some(SessionState::new_disconnected( + false, + "resend request attempts exceeded, manual intervention required", + ))) + } + } + } + + async fn handle_sending_time_accuracy_problem( + &mut self, + state: &SessionState, + msg_seq_num: u64, + text: &str, + ) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::SendingtimeAccuracyProblem) + .text(text); + if let Some(writer) = state.get_writer() + && let Err(err) = self.send_message(writer, reject).await + { + error!("failed to send reject for time accuracy problem: {err}"); + } + if let Err(err) = self.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } + } + + async fn handle_original_sending_time_missing( + &mut self, + state: &SessionState, + msg_seq_num: u64, + ) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("original sending time is required"); + if let Some(writer) = state.get_writer() + && let Err(err) = self.send_message(writer, reject).await + { + error!("failed to send reject for time missing tag: {err}"); + } + if let Err(err) = self.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } + } + + /// Send a logout message and immediately disconnect. + async fn logout_and_terminate(&mut self, state: &SessionState, reason: &str) { + if let Some(writer) = state.get_writer() { + let logout = Logout::with_reason(reason.to_string()); + match self.prepare_message(logout).await { + Ok(prepared) => writer.send_raw_message(prepared.raw).await, + Err(err) => warn!("failed to send logout during session termination: {err}"), + } + writer.disconnect().await; + } + } + + pub async fn resend_messages( + &mut self, + writer: &WriterRef, + begin: u64, + end: u64, + ) -> Result<(), SessionOperationError> { + info!(begin, end, "resending messages as requested"); + let messages = self.store.get_slice(begin as usize, end as usize).await?; + + let no = messages.len(); + debug!(number_of_messages = no, "number of messages"); + + let mut reset_start: Option = None; + let mut sequence_number = 0; + + for msg in messages { + let mut message = self + .message_builder + .build(msg.as_slice()) + .into_message() + .ok_or_else(|| { + SessionOperationError::StoredMessageParse(format!( + "failed to build message for raw message: {msg:?}" + )) + })?; + sequence_number = get_msg_seq_num(&message); + let message_type: String = message + .header() + .get::<&str>(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))? + .to_string(); + + if is_admin(&message_type) { + if reset_start.is_none() { + reset_start = Some(sequence_number); + } + continue; + } + + if let Some(begin) = reset_start { + let end = sequence_number; + Self::log_skipped_admin_messages(begin, end); + self.send_sequence_reset(writer, begin, end).await?; + reset_start = None; + } + + if let Err(e) = prepare_message_for_resend(&mut message) { + error!( + error = e, + "failed to prepare message for resend, sending original" + ); + } + writer + .send_raw_message(RawFixMessage::new(message.encode(self.message_config)?)) + .await; + + if enabled!(tracing::Level::DEBUG) + && let Ok(m) = String::from_utf8(msg.clone()) + { + debug!(sequence_number, message = m, "resent message"); + } + } + + if let Some(begin) = reset_start { + let end = sequence_number; + Self::log_skipped_admin_messages(begin, end); + self.send_sequence_reset(writer, begin, end).await?; + } + + Ok(()) + } + + pub async fn send_sequence_reset( + &mut self, + writer: &WriterRef, + begin: u64, + end: u64, + ) -> Result<(), SessionOperationError> { + let sequence_reset = SequenceReset { + gap_fill: true, + new_seq_no: end, + }; + let raw_message = generate_message( + &self.config.begin_string, + &self.config.sender_comp_id, + &self.config.target_comp_id, + begin, + sequence_reset, + )?; + + writer + .send_raw_message(RawFixMessage::new(raw_message)) + .await; + debug!(begin, end, "sent reset sequence"); + + Ok(()) + } + + fn log_skipped_admin_messages(begin: u64, end: u64) { + info!( + begin, + end, "skipped admin message(s) during resend, requesting reset for these" + ); + } + + pub async fn handle_invalid_msg_type( + &mut self, + state: &SessionState, + message: &Message, + msg_type: &str, + ) { + match message.header().get(MSG_SEQ_NUM) { + Ok(msg_seq_num) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::InvalidMsgtype) + .text(&format!("invalid message type {msg_type}")); + if let Some(writer) = state.get_writer() + && let Err(err) = self.send_message(writer, reject).await + { + error!("failed to send reject message for invalid msgtype: {err}"); + } + + #[allow(clippy::collapsible_if)] + if let Ok(seq_num) = message.header().get::(MSG_SEQ_NUM) + && self.store.next_target_seq_number() == seq_num + { + if let Err(err) = self.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } + } + } + Err(err) => { + error!("failed to get message seq num: {:?}", err); + } + } + } } const TEST_REQUEST_THRESHOLD: f64 = 1.2; @@ -166,7 +561,7 @@ impl SessionState { } } - fn get_writer(&self) -> Option<&WriterRef> { + pub(crate) fn get_writer(&self) -> Option<&WriterRef> { match self { Self::Active(ActiveState { writer, .. }) | Self::AwaitingLogon(AwaitingLogonState { writer, .. }) From 79518abd4ab9bf519442b14adf34448cb7f0cc9c Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 12:02:19 +0100 Subject: [PATCH 07/14] Move process_message and a bunch more functions off of session --- crates/hotfix/src/session.rs | 694 +++++------------- crates/hotfix/src/session/state.rs | 170 ++--- crates/hotfix/src/session/state/active.rs | 415 ++++++++++- .../src/session/state/awaiting_logon.rs | 68 +- .../src/session/state/awaiting_logout.rs | 56 +- .../src/session/state/awaiting_resend.rs | 412 ++++++++++- 6 files changed, 1205 insertions(+), 610 deletions(-) diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 4b5d9842..177685d5 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -10,6 +10,7 @@ use chrono::Utc; use hotfix_message::dict::Dictionary; use hotfix_message::message::{Config as MessageConfig, Message}; use hotfix_message::{MessageBuilder, Part}; +use std::future::Future; use std::pin::Pin; use tokio::select; use tokio::sync::mpsc; @@ -17,21 +18,14 @@ use tokio::time::{Duration, Instant, Sleep, sleep, sleep_until}; use tracing::{debug, error, info, warn}; use crate::Application; -use crate::application::{InboundDecision, OutboundDecision}; use crate::config::SessionConfig; use crate::message::OutboundMessage; -use crate::message::business_reject::BusinessReject; use crate::message::generate_message; -use crate::message::heartbeat::Heartbeat; use crate::message::logon::{Logon, ResetSeqNumConfig}; use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; use crate::message::reject::Reject; use crate::message::resend_request::ResendRequest; -use crate::message::sequence_reset::SequenceReset; -use crate::message::test_request::TestRequest; -use crate::message::verification::verify_message; -use crate::message::verification_error::MessageVerificationError; use crate::session::admin_request::AdminRequest; use crate::session::error::SessionCreationError; use crate::session::error::{InternalSendError, InternalSendResultExt, SessionOperationError}; @@ -44,16 +38,13 @@ pub(crate) use crate::session::session_ref::InternalSessionRef; pub use crate::session::session_ref::InternalSessionRef; use crate::session::session_ref::OutboundRequest; use crate::session::state::SessionState; -use crate::session::state::{AwaitingLogonState, AwaitingLogoutState, SessionCtx, TestRequestId}; +use crate::session::state::{SessionCtx, TestRequestId, TransitionResult}; use crate::session_schedule::{SessionPeriodComparison, SessionSchedule}; use crate::store::MessageStore; use crate::transport::writer::WriterRef; use event::SessionEvent; use hotfix_message::parsed_message::{InvalidReason, ParsedMessage}; -use hotfix_message::session_fields::{ - BEGIN_SEQ_NO, END_SEQ_NO, GAP_FILL_FLAG, MSG_SEQ_NUM, MSG_TYPE, NEW_SEQ_NO, - SessionRejectReason, TEST_REQ_ID, -}; +use hotfix_message::session_fields::{MSG_SEQ_NUM, SessionRejectReason}; const SCHEDULE_CHECK_INTERVAL: u64 = 1; @@ -119,64 +110,24 @@ where raw_message: RawFixMessage, ) -> Result<(), SessionOperationError> { debug!("received message: {}", raw_message); + + // Reset peer timer before dispatching (if not expecting test response) if !self.state.is_expecting_test_response() { - // if we are not awaiting a specific test response, any message can reset the timer - // otherwise only a heartbeat with the corresponding TestReqID can self.reset_peer_timer(None); } match self.message_builder.build(raw_message.as_bytes()) { ParsedMessage::Valid(message) => { - self.process_message(message).await?; - self.check_end_of_resend().await?; + self.dispatch_valid_message(message).await?; } ParsedMessage::Garbled(r) => { - // garbled messages should be skipped and we should assume it was a transmission error let message = raw_message.to_string(); let reason = format!("{r:?}"); error!(message, reason, "received garbled message"); } - ParsedMessage::Invalid { message, reason } => match reason { - InvalidReason::InvalidField(tag) | InvalidReason::InvalidGroup(tag) => { - match message.header().get(MSG_SEQ_NUM) { - Ok(msg_seq_num) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::InvalidTagNumber) - .text(&format!("invalid field {tag}")); - self.send_message(reject) - .await - .with_send_context("reject for invalid field")?; - } - Err(err) => { - error!("failed to get message seq num: {:?}", err); - } - } - } - InvalidReason::InvalidComponent(_component_name) => { - // TODO: what's the correct way to handle this? - warn!("received invalid component"); - } - InvalidReason::InvalidMsgType(msg_type) => { - self.handle_invalid_msg_type(message, &msg_type).await; - } - InvalidReason::InvalidOrderInGroup { tag, .. } => { - match message.header().get(MSG_SEQ_NUM) { - Ok(msg_seq_num) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason( - SessionRejectReason::RepeatingGroupFieldsOutOfOrder, - ) - .text(&format!("field appears in incorrect order:{tag}")); - self.send_message(reject) - .await - .with_send_context("reject for invalid group order")?; - } - Err(err) => { - error!("failed to get message seq num: {:?}", err); - } - } - } - }, + ParsedMessage::Invalid { message, reason } => { + self.handle_invalid_parsed_message(message, reason).await?; + } ParsedMessage::UnexpectedError(err) => { error!("unexpected error: {:?}", err); } @@ -185,151 +136,160 @@ where Ok(()) } - async fn process_message(&mut self, message: Message) -> Result<(), SessionOperationError> { - let message_type: &str = message - .header() - .get(MSG_TYPE) - .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; - - if let SessionState::AwaitingResend(state) = &mut self.state { - let seq_number = get_msg_seq_num(&message); - if seq_number > state.end_seq_number && message_type != ResendRequest::MSG_TYPE { - state.inbound_queue.push_back(message); - return Ok(()); - } - } - - if let SessionState::AwaitingLogon(_) = &mut self.state { - // TODO: should this (and all inbound message processing) logic be pushed into the state? - if message_type != Logon::MSG_TYPE { - self.state.disconnect_writer().await; - return Ok(()); - } - } - - match message_type { - Heartbeat::MSG_TYPE => { - self.on_heartbeat(&message).await?; - } - TestRequest::MSG_TYPE => { - self.on_test_request(&message).await?; - } - ResendRequest::MSG_TYPE => { - self.on_resend_request(&message).await?; - } - Reject::MSG_TYPE => { - self.on_reject(&message).await?; - } - SequenceReset::MSG_TYPE => { - self.on_sequence_reset(&message).await?; + async fn handle_invalid_parsed_message( + &mut self, + message: Message, + reason: InvalidReason, + ) -> Result<(), SessionOperationError> { + match reason { + InvalidReason::InvalidField(tag) | InvalidReason::InvalidGroup(tag) => { + match message.header().get(MSG_SEQ_NUM) { + Ok(msg_seq_num) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::InvalidTagNumber) + .text(&format!("invalid field {tag}")); + self.send_message(reject) + .await + .with_send_context("reject for invalid field")?; + } + Err(err) => { + error!("failed to get message seq num: {:?}", err); + } + } } - Logout::MSG_TYPE => { - self.on_logout(&message).await?; + InvalidReason::InvalidComponent(_component_name) => { + warn!("received invalid component"); } - Logon::MSG_TYPE => { - self.on_logon(&message).await?; + InvalidReason::InvalidMsgType(msg_type) => { + let Session { + ref state, + ref mut store, + ref config, + ref message_builder, + ref message_config, + .. + } = *self; + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, + }; + if let Some(writer) = state.get_writer() { + ctx.handle_invalid_msg_type(writer, &message, &msg_type) + .await; + } } - _ => self.process_app_message(&message).await?, - } - - Ok(()) - } - - async fn process_app_message( - &mut self, - message: &Message, - ) -> Result<(), SessionOperationError> { - match self.verify_message(message, true, true) { - Ok(_) => { - match self.application.on_inbound_message(message).await { - InboundDecision::Accept => {} - InboundDecision::Reject { reason, text } => { - let msg_type: &str = message - .header() - .get(MSG_TYPE) - .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; - let mut reject = BusinessReject::new(msg_type, reason) - .ref_seq_num(get_msg_seq_num(message)); - if let Some(text) = text { - reject = reject.text(&text); - } + InvalidReason::InvalidOrderInGroup { tag, .. } => { + match message.header().get(MSG_SEQ_NUM) { + Ok(msg_seq_num) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason( + SessionRejectReason::RepeatingGroupFieldsOutOfOrder, + ) + .text(&format!("field appears in incorrect order:{tag}")); self.send_message(reject) .await - .with_send_context("business message reject")?; + .with_send_context("reject for invalid group order")?; } - InboundDecision::TerminateSession => { - error!("failed to send inbound message to application"); - self.state.disconnect_writer().await; + Err(err) => { + error!("failed to get message seq num: {:?}", err); } } - self.store.increment_target_seq_number().await?; } - Err(err) => self.handle_verification_error(err).await?, } - Ok(()) } - async fn check_end_of_resend(&mut self) -> Result<(), SessionOperationError> { - let ended_state = if let SessionState::AwaitingResend(state) = &mut self.state { - if self.store.next_target_seq_number() > state.end_seq_number { - let new_state = - SessionState::new_active(state.writer.clone(), self.config.heartbeat_interval); - Some(std::mem::replace(&mut self.state, new_state)) - } else { - None + fn dispatch_valid_message( + &mut self, + message: Message, + ) -> Pin> + Send + '_>> { + Box::pin(self.dispatch_valid_message_inner(message)) + } + + async fn dispatch_valid_message_inner( + &mut self, + message: Message, + ) -> Result<(), SessionOperationError> { + let Session { + ref mut state, + ref mut store, + ref config, + ref message_builder, + ref message_config, + ref mut application, + .. + } = *self; + + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, + }; + + let transition = match state { + SessionState::Active(s) => s.on_fix_message(&mut ctx, application, message).await?, + SessionState::AwaitingLogon(s) => { + s.on_fix_message(&mut ctx, application, message).await? } - } else { - None + SessionState::AwaitingResend(s) => { + s.on_fix_message(&mut ctx, application, message).await? + } + SessionState::AwaitingLogout(s) => { + s.on_fix_message(&mut ctx, application, message).await? + } + SessionState::Disconnected(_) => TransitionResult::Stay, }; - if let Some(SessionState::AwaitingResend(mut state)) = ended_state { - // we have reached the end of the resend, - // process queued messages and resume normal operation - debug!("resend is done, processing backlog"); - while let Some(msg) = state.inbound_queue.pop_front() { - let seq_number: u64 = msg.get(MSG_SEQ_NUM).unwrap_or_else(|e| { - error!("failed to get seq number: {:?}", e); - 0 - }); - let msg_type: &str = msg.header().get(MSG_TYPE).unwrap_or(""); - debug!(seq_number, msg_type, "processing queued message"); - - if msg_type == ResendRequest::MSG_TYPE { - // ResendRequest was already processed when it arrived (it bypasses - // the queue in process_message). Just increment the target seq number - // for sequence accounting purposes. - self.store.increment_target_seq_number().await?; - } else { - self.process_message(msg).await?; + // Let ctx go out of scope before we can mutate self.state + let _ = ctx; + + self.apply_transition(transition).await + } + + async fn apply_transition( + &mut self, + transition: TransitionResult, + ) -> Result<(), SessionOperationError> { + match transition { + TransitionResult::Stay => {} + TransitionResult::TransitionTo(new_state) => { + self.state = new_state; + } + TransitionResult::TransitionWithBacklog { + new_state, + mut backlog, + } => { + self.state = new_state; + debug!("resend is done, processing backlog"); + while let Some(msg) = backlog.pop_front() { + let seq_number: u64 = msg.get(MSG_SEQ_NUM).unwrap_or_else(|e| { + error!("failed to get seq number: {:?}", e); + 0 + }); + let msg_type: &str = msg + .header() + .get(hotfix_message::session_fields::MSG_TYPE) + .unwrap_or(""); + debug!(seq_number, msg_type, "processing queued message"); + + if msg_type == ResendRequest::MSG_TYPE { + // ResendRequest was already processed when it arrived (it bypasses + // the queue). Just increment the target seq number + // for sequence accounting purposes. + self.store.increment_target_seq_number().await?; + } else { + self.dispatch_valid_message(msg).await?; + } } + debug!("resend backlog is cleared, resuming normal operation"); } - debug!("resend backlog is cleared, resuming normal operation"); } - Ok(()) } - fn verify_message( - &self, - message: &Message, - check_too_high: bool, - check_too_low: bool, - ) -> Result<(), MessageVerificationError> { - let expected_seq_number = if check_too_high || check_too_low { - Some(self.store.next_target_seq_number()) - } else { - None - }; - verify_message( - message, - &self.config, - expected_seq_number, - check_too_high, - check_too_low, - ) - } - async fn on_connect(&mut self, writer: WriterRef) -> Result<(), SessionOperationError> { if let SessionState::Disconnected(s) = &self.state { self.state = s.on_connect(writer, Duration::from_secs(self.config.logon_timeout)); @@ -356,298 +316,6 @@ where } } - async fn on_logon(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let SessionState::AwaitingLogon(AwaitingLogonState { writer, .. }) = &self.state { - match self.verify_message(message, true, true) { - Ok(_) => { - // happy logon flow, the session is now active - self.state = - SessionState::new_active(writer.clone(), self.config.heartbeat_interval); - self.application.on_logon().await; - self.store.increment_target_seq_number().await?; - } - Err(err) => self.handle_verification_error(err).await?, - } - } else { - error!("received unexpected logon message"); - } - - Ok(()) - } - - async fn on_logout(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let Err(err) = self.verify_message(message, false, false) { - self.handle_verification_error(err).await?; - return Ok(()); - } - - if self.state.is_logged_on() { - self.send_logout("Logout acknowledged").await?; - } - - self.application.on_logout("peer has logged us out").await; - - match self.state { - // if the session is already disconnected, we have nothing else to do - SessionState::Disconnected(..) => {} - // if we initiated the logout, preserve the reconnect flag - SessionState::AwaitingLogout(AwaitingLogoutState { reconnect, .. }) => { - self.state.disconnect_writer().await; - self.state = SessionState::new_disconnected(reconnect, "logout completed"); - } - // otherwise assume it makes sense to try to reconnect - _ => { - self.state.disconnect_writer().await; - self.state = SessionState::new_disconnected(true, "peer has logged us out") - } - } - - self.store.increment_target_seq_number().await?; - Ok(()) - } - - async fn on_heartbeat(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let Err(err) = self.verify_message(message, true, true) { - self.handle_verification_error(err).await?; - return Ok(()); - } - - if let (Some(expected_req_id), Ok(message_req_id)) = ( - &self.state.expected_test_response_id(), - message.get::<&str>(TEST_REQ_ID), - ) && expected_req_id.as_str() == message_req_id - { - debug!("received response for TestRequest, resetting timer"); - self.reset_peer_timer(None); - } - - self.store.increment_target_seq_number().await?; - Ok(()) - } - - async fn on_test_request(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let Err(err) = self.verify_message(message, true, true) { - self.handle_verification_error(err).await?; - return Ok(()); - } - - let req_id: &str = message.get(TEST_REQ_ID).unwrap_or_else(|_| { - // TODO: send reject? - todo!() - }); - - self.store.increment_target_seq_number().await?; - - self.send_message(Heartbeat::for_request(req_id.to_string())) - .await - .with_send_context("heartbeat response")?; - - Ok(()) - } - - async fn on_resend_request(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if !self.state.is_connected() { - warn!("received resend request while disconnected, ignoring"); - return Ok(()); - } - - // Verify with check_too_high=false so ResendRequest is never blocked by seq-too-high. - // This is the key part of the QFJ-673 deadlock fix: when both sides send ResendRequest - // simultaneously, each side's ResendRequest will have a seq number higher than expected. - // By not treating that as an error, we allow the ResendRequest to be processed. - match self.verify_message(message, false, true) { - Ok(_) => {} - Err(err) => { - self.handle_verification_error(err).await?; - return Ok(()); - } - } - - let msg_seq_num = get_msg_seq_num(message); - let expected = self.store.next_target_seq_number(); - - // If seq is too high and we're in AwaitingResend, queue it for seq accounting - // when the gap fill catches up, but still process the resend below. - if msg_seq_num > expected - && let SessionState::AwaitingResend(state) = &mut self.state - { - state.inbound_queue.push_back(message.clone()); - } - - let begin_seq_number: u64 = match message.get(BEGIN_SEQ_NO) { - Ok(seq_number) => seq_number, - Err(_) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::RequiredTagMissing) - .text("missing begin sequence number for resend request"); - self.send_message(reject) - .await - .with_send_context("reject for missing BEGIN_SEQ_NO")?; - return Ok(()); - } - }; - - let end_seq_number: u64 = match message.get(END_SEQ_NO) { - Ok(seq_number) => { - let last_seq_number = self.store.next_sender_seq_number() - 1; - if seq_number == 0 { - last_seq_number - } else { - std::cmp::min(seq_number, last_seq_number) - } - } - Err(_) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::RequiredTagMissing) - .text("missing end sequence number for resend request"); - self.send_message(reject) - .await - .with_send_context("reject for missing END_SEQ_NO")?; - return Ok(()); - } - }; - - // Only increment target seq if seq matches expected - if msg_seq_num == expected { - self.store.increment_target_seq_number().await?; - } - - self.resend_messages(begin_seq_number, end_seq_number, message) - .await?; - - Ok(()) - } - - /// Handle Reject messages. - async fn on_reject(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let Err(err) = self.verify_message(message, false, true) { - self.handle_verification_error(err).await?; - return Ok(()); - } - - self.store.increment_target_seq_number().await?; - Ok(()) - } - - async fn on_sequence_reset(&mut self, message: &Message) -> Result<(), SessionOperationError> { - let msg_seq_num = get_msg_seq_num(message); - let is_gap_fill: bool = message.get(GAP_FILL_FLAG).unwrap_or(false); - if let Err(err) = self.verify_message(message, is_gap_fill, is_gap_fill) { - self.handle_verification_error(err).await?; - return Ok(()); - } - - let end: u64 = match message.get(NEW_SEQ_NO) { - Ok(new_seq_no) => new_seq_no, - Err(err) => { - error!( - "received sequence reset message without new sequence number: {:?}", - err - ); - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::RequiredTagMissing) - .text("missing NewSeqNo tag in sequence reset message"); - self.send_message(reject) - .await - .with_send_context("reject for missing NEW_SEQ_NO")?; - - // note: we don't increment the target seq number here - // this is an ambiguous case in the specification, but leaving the - // sequence number as is feels the safest - return Ok(()); - } - }; - - // sequence resets cannot move the target seq number backwards - // regardless of whether the message is a gap fill or not - if end <= self.store.next_target_seq_number() { - error!( - "received sequence reset message which would move target seq number backwards: {end}", - ); - let text = - format!("attempt to lower sequence number, invalid value NewSeqNo(36)={end}"); - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::ValueIsIncorrect) - .text(&text); - self.send_message(reject) - .await - .with_send_context("reject for invalid sequence reset")?; - return Ok(()); - } - - self.store.set_target_seq_number(end - 1).await?; - Ok(()) - } - - async fn handle_verification_error( - &mut self, - error: MessageVerificationError, - ) -> Result<(), SessionOperationError> { - let Session { - ref mut state, - ref mut store, - ref config, - ref message_builder, - ref message_config, - .. - } = *self; - let mut ctx = SessionCtx { - config, - store, - message_builder, - message_config, - }; - if let Some(new_state) = ctx.handle_verification_error(state, error).await? { - *state = new_state; - } - Ok(()) - } - - async fn handle_invalid_msg_type(&mut self, message: Message, msg_type: &str) { - let Session { - ref mut state, - ref mut store, - ref config, - ref message_builder, - ref message_config, - .. - } = *self; - let mut ctx = SessionCtx { - config, - store, - message_builder, - message_config, - }; - ctx.handle_invalid_msg_type(state, &message, msg_type).await; - } - - async fn resend_messages( - &mut self, - begin: u64, - end: u64, - _message: &Message, - ) -> Result<(), SessionOperationError> { - let writer = self.state.get_writer().cloned(); - if let Some(writer) = writer { - let Session { - ref mut store, - ref config, - ref message_builder, - ref message_config, - .. - } = *self; - let mut ctx = SessionCtx { - config, - store, - message_builder, - message_config, - }; - ctx.resend_messages(&writer, begin, end).await?; - self.reset_heartbeat_timer(); - } - Ok(()) - } - fn reset_heartbeat_timer(&mut self) { self.state .reset_heartbeat_timer(self.config.heartbeat_interval); @@ -658,10 +326,12 @@ where .reset_peer_timer(self.config.heartbeat_interval, test_request_id); } - async fn send_app_message(&mut self, message: App::Outbound) -> Result { - if !self.state.is_connected() { - return Err(SendError::Disconnected); - } + /// Legacy send_app_message for non-Active connected states. + async fn send_app_message_legacy( + &mut self, + message: App::Outbound, + ) -> Result { + use crate::application::OutboundDecision; match self.application.on_outbound_message(&message).await { OutboundDecision::Send => { @@ -680,6 +350,7 @@ where } } + /// Legacy send_message used by send_logon, send_logout, and error handling paths. async fn send_message( &mut self, message: impl OutboundMessage, @@ -745,12 +416,6 @@ where } /// Sends a logout message and immediately disconnects the counterparty. - /// - /// This should be used sparingly in scenarios where there is a major issue - /// requiring operational intervention, such as the sequence number being lower - /// than expected, or some other key header field containing an invalid value. - /// - /// In other scenarios, [`initiate_graceful_logout`] should be preferred. async fn logout_and_terminate(&mut self, reason: &str) { if let Err(err) = self.send_logout(reason).await { warn!("failed to send logout during session termination: {}", err); @@ -758,11 +423,7 @@ where self.state.disconnect_writer().await; } - /// Sends a logout message and puts the session state into an [`AwaitingLogout`] state. - /// - /// The session waits for a configurable timeout period for the counterparty to - /// respond with a `Logout` message. If no response is received within the timeout - /// period, it disconnects the counterparty. + /// Sends a logout message and puts the session state into an AwaitingLogout state. async fn initiate_graceful_logout( &mut self, reason: &str, @@ -812,10 +473,41 @@ where async fn handle_outbound_message(&mut self, request: OutboundRequest) { let OutboundRequest { message, confirm } = request; - let result = self.send_app_message(message).await; + + let is_active = matches!(self.state, SessionState::Active(_)); + let is_connected = self.state.is_connected(); + + let result = if !is_connected { + Err(SendError::Disconnected) + } else if is_active { + let Session { + ref mut state, + ref mut store, + ref config, + ref message_builder, + ref message_config, + ref mut application, + .. + } = *self; + + if let SessionState::Active(s) = state { + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, + }; + s.send_app_message(&mut ctx, application, message).await + } else { + unreachable!() + } + } else { + // Legacy path: session is connected but not Active (e.g. AwaitingLogon). + self.send_app_message_legacy(message).await + }; + match confirm { Some(tx) => { - // Ignore send errors - receiver may have been dropped let _ = tx.send(result); } None => { @@ -917,8 +609,6 @@ where // we are in the same period, nothing needs to be done } Ok(SessionPeriodComparison::DifferentPeriod) => { - // the message store is for a previous session, - // we need to terminate this session, reset the store, and reestablish the session self.logout_and_terminate("session period changed").await; if let Err(err) = self.store.reset().await { error!("error resetting session store: {err:}"); @@ -927,8 +617,6 @@ where } } Ok(SessionPeriodComparison::OutsideSessionTime { .. }) => { - // the creation_time was recorded outside the session schedule, - // treat this similarly to a different period - reset the store warn!("store creation time is outside session schedule, resetting store"); self.logout_and_terminate("creation time outside schedule") .await; @@ -939,22 +627,18 @@ where } } Err(err) => { - // actual schedule calculation error (e.g., DST transition, date overflow) error!("error checking session period: {err:?}"); self.logout_and_terminate("internal error").await; } } - } else if self.state.is_connected() { - // we are currently outside scheduled session time - if let Err(err) = self + } else if self.state.is_connected() + && let Err(err) = self .initiate_graceful_logout("End of session time", true) .await - { - error!(err = ?err, "failed to initiate graceful logout"); - } + { + error!(err = ?err, "failed to initiate graceful logout"); } - // we always need to reschedule the check, otherwise we won't be able to resume an inactive session let deadline = Instant::now() + Duration::from_secs(SCHEDULE_CHECK_INTERVAL); self.schedule_check_timer.as_mut().reset(deadline); } @@ -970,9 +654,6 @@ where /// Extracts MsgSeqNum from a message header. /// -/// To be removed once https://github.com/Validus-Risk-Management/hotfix/issues/301 -/// is implemented. -/// /// # Panics /// Panics if the message does not contain a valid MsgSeqNum field. /// This should never happen for messages that have passed validation. @@ -1282,8 +963,6 @@ mod tests { session.handle_schedule_check().await; // Store reset should have been called (indicates DifferentPeriod branch was taken) - // Note: logout_and_terminate disconnects the writer but state transition to - // Disconnected happens asynchronously via event processing, not in this call assert!( session.store.was_reset_called(), "Store reset should be called for different period" @@ -1348,7 +1027,6 @@ mod tests { let state = SessionState::new_active(writer, 30); // Creation time is today but at a time outside the schedule window - // Use a time that's definitely outside the window (6 hours from now) let outside_hour = (current_hour + 6) % 24; let creation_time = DateTime::from_naive_utc_and_offset( NaiveDate::from_ymd_opt(now.year(), now.month(), now.day()) diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index 0e1fbfee..ae41127b 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -7,7 +7,9 @@ mod disconnected; pub(crate) use active::{ActiveState, calculate_peer_interval}; pub(crate) use awaiting_logon::AwaitingLogonState; pub(crate) use awaiting_logout::AwaitingLogoutState; -pub(crate) use awaiting_resend::{AwaitingResendState, AwaitingResendTransitionOutcome}; +pub(crate) use awaiting_resend::AwaitingResendState; +#[cfg(test)] +pub(crate) use awaiting_resend::AwaitingResendTransitionOutcome; pub(crate) use disconnected::DisconnectedState; use crate::config::SessionConfig; @@ -15,12 +17,11 @@ use crate::message::logon::Logon; use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; use crate::message::reject::Reject; -use crate::message::resend_request::ResendRequest; use crate::message::sequence_reset::SequenceReset; use crate::message::verification::verify_message as verify_message_impl; use crate::message::verification_error::{CompIdType, MessageVerificationError}; use crate::message::{OutboundMessage, generate_message, is_admin, prepare_message_for_resend}; -use crate::session::error::{InternalSendError, InternalSendResultExt, SessionOperationError}; +use crate::session::error::{InternalSendError, SessionOperationError}; use crate::session::event::AwaitingActiveSessionResponse; use crate::session::get_msg_seq_num; use crate::session::info::Status as SessionInfoStatus; @@ -30,6 +31,7 @@ use hotfix_message::message::{Config as MessageConfig, Message}; use hotfix_message::session_fields::SessionRejectReason; use hotfix_message::{MessageBuilder, Part}; use hotfix_store::MessageStore; +use std::collections::VecDeque; use std::time::Duration; use tokio::sync::oneshot; use tokio::time::Instant; @@ -46,11 +48,26 @@ pub(crate) struct SessionCtx<'a, Store> { pub(crate) struct PreparedMessage { pub seq_num: u64, - #[allow(dead_code)] // used in later sub-phases + #[allow(dead_code)] pub msg_type: String, pub raw: RawFixMessage, } +pub(crate) enum TransitionResult { + Stay, + TransitionTo(SessionState), + TransitionWithBacklog { + new_state: SessionState, + backlog: VecDeque, + }, +} + +pub(crate) enum VerifyResult { + Passed, + SeqTooHigh { expected: u64, actual: u64 }, + ErrorHandled(Option), +} + impl SessionCtx<'_, Store> { pub async fn prepare_message( &mut self, @@ -99,7 +116,6 @@ impl SessionCtx<'_, Store> { Ok(prepared.seq_num) } - #[allow(dead_code)] // used when states handle their own messages in 2e pub fn verify_message( &self, message: &Message, @@ -120,10 +136,33 @@ impl SessionCtx<'_, Store> { ) } - /// Handle a verification error. Returns `Some(new_state)` if a state transition is needed. + /// Verify a message and handle the error if verification fails. + /// For SeqNumberTooHigh, returns `VerifyResult::SeqTooHigh` instead of handling it, + /// allowing the caller to handle the transition. + pub async fn verify_and_handle( + &mut self, + writer: &WriterRef, + message: &Message, + check_too_high: bool, + check_too_low: bool, + ) -> Result { + match self.verify_message(message, check_too_high, check_too_low) { + Ok(()) => Ok(VerifyResult::Passed), + Err(MessageVerificationError::SeqNumberTooHigh { expected, actual }) => { + Ok(VerifyResult::SeqTooHigh { expected, actual }) + } + Err(err) => { + let transition = self.handle_verification_error(writer, err).await?; + Ok(VerifyResult::ErrorHandled(transition)) + } + } + } + + /// Handle a verification error (excluding SeqNumberTooHigh which is returned separately). + /// Returns `Some(new_state)` if a state transition is needed. pub async fn handle_verification_error( &mut self, - state: &mut SessionState, + writer: &WriterRef, error: MessageVerificationError, ) -> Result, SessionOperationError> { match error { @@ -132,14 +171,18 @@ impl SessionCtx<'_, Store> { actual, possible_duplicate, } => Ok(self - .handle_sequence_number_too_low(state, expected, actual, possible_duplicate) + .handle_sequence_number_too_low(writer, expected, actual, possible_duplicate) .await), MessageVerificationError::SeqNumberTooHigh { expected, actual } => { - self.handle_sequence_number_too_high(state, expected, actual) - .await + // This shouldn't be called for SeqTooHigh anymore (it's returned via VerifyResult), + // but handle gracefully if it is. + warn!( + "handle_verification_error called with SeqNumberTooHigh({expected}, {actual}) - caller should use verify_and_handle" + ); + Ok(None) } MessageVerificationError::IncorrectBeginString(begin_string) => Ok(Some( - self.handle_incorrect_begin_string(state, begin_string) + self.handle_incorrect_begin_string(writer, begin_string) .await, )), MessageVerificationError::IncorrectCompId { @@ -147,12 +190,12 @@ impl SessionCtx<'_, Store> { comp_id_type, msg_seq_num, } => Ok(Some( - self.handle_incorrect_comp_id(state, comp_id, comp_id_type, msg_seq_num) + self.handle_incorrect_comp_id(writer, comp_id, comp_id_type, msg_seq_num) .await, )), MessageVerificationError::SendingTimeAccuracyIssue { msg_seq_num } => { self.handle_sending_time_accuracy_problem( - state, + writer, msg_seq_num, "unexpected sending time", ) @@ -161,7 +204,7 @@ impl SessionCtx<'_, Store> { } MessageVerificationError::SendingTimeMissing { msg_seq_num } => { self.handle_sending_time_accuracy_problem( - state, + writer, msg_seq_num, "sending time missing", ) @@ -169,7 +212,7 @@ impl SessionCtx<'_, Store> { Ok(None) } MessageVerificationError::OriginalSendingTimeMissing { msg_seq_num } => { - self.handle_original_sending_time_missing(state, msg_seq_num) + self.handle_original_sending_time_missing(writer, msg_seq_num) .await; Ok(None) } @@ -177,7 +220,7 @@ impl SessionCtx<'_, Store> { msg_seq_num, .. } => { self.handle_sending_time_accuracy_problem( - state, + writer, msg_seq_num, "original sending time is after sending time", ) @@ -189,11 +232,11 @@ impl SessionCtx<'_, Store> { async fn handle_incorrect_begin_string( &mut self, - state: &SessionState, + writer: &WriterRef, received_begin_string: String, ) -> SessionState { self.logout_and_terminate( - state, + writer, &format!("beginString={received_begin_string} is not supported"), ) .await; @@ -202,7 +245,7 @@ impl SessionCtx<'_, Store> { async fn handle_incorrect_comp_id( &mut self, - state: &SessionState, + writer: &WriterRef, received_comp_id: String, comp_id_type: CompIdType, msg_seq_num: u64, @@ -213,19 +256,17 @@ impl SessionCtx<'_, Store> { let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::ValueIsIncorrect) .text(&format!("invalid comp ID {received_comp_id}")); - if let Some(writer) = state.get_writer() - && let Err(err) = self.send_message(writer, reject).await - { + if let Err(err) = self.send_message(writer, reject).await { error!("failed to send reject message with invalid comp ID: {err}"); } - self.logout_and_terminate(state, "incorrect comp ID received") + self.logout_and_terminate(writer, "incorrect comp ID received") .await; SessionState::new_disconnected(true, "incorrect comp ID") } async fn handle_sequence_number_too_low( &mut self, - state: &SessionState, + writer: &WriterRef, expected: u64, actual: u64, possible_duplicate: bool, @@ -240,62 +281,20 @@ impl SessionCtx<'_, Store> { "we expected {expected} sequence number, but target sent lower ({actual}), terminating..." ); let reason = format!("sequence number too low (actual {actual}, expected {expected})"); - self.logout_and_terminate(state, &reason).await; + self.logout_and_terminate(writer, &reason).await; Some(SessionState::new_disconnected(false, &reason)) } - async fn handle_sequence_number_too_high( - &mut self, - state: &mut SessionState, - expected: u64, - actual: u64, - ) -> Result, SessionOperationError> { - match state.try_transition_to_awaiting_resend(expected, actual) { - AwaitingResendTransitionOutcome::Success => { - debug!( - "we are behind target (ours: {expected}, theirs: {actual}), requesting resend." - ); - if let Some(writer) = state.get_writer() { - let request = ResendRequest::new(expected, actual); - self.send_message(writer, request) - .await - .with_send_context("resend request")?; - } - Ok(None) // state already mutated by try_transition_to_awaiting_resend - } - AwaitingResendTransitionOutcome::InvalidState(reason) => { - error!("failed to request resend: {reason}"); - Ok(None) - } - AwaitingResendTransitionOutcome::BeginSeqNumberTooLow => { - state.disconnect_writer().await; - Ok(Some(SessionState::new_disconnected( - false, - "awaiting resend begin seq number unexpectedly lower than the previous resend request's", - ))) - } - AwaitingResendTransitionOutcome::AttemptsExceeded => { - state.disconnect_writer().await; - Ok(Some(SessionState::new_disconnected( - false, - "resend request attempts exceeded, manual intervention required", - ))) - } - } - } - async fn handle_sending_time_accuracy_problem( &mut self, - state: &SessionState, + writer: &WriterRef, msg_seq_num: u64, text: &str, ) { let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::SendingtimeAccuracyProblem) .text(text); - if let Some(writer) = state.get_writer() - && let Err(err) = self.send_message(writer, reject).await - { + if let Err(err) = self.send_message(writer, reject).await { error!("failed to send reject for time accuracy problem: {err}"); } if let Err(err) = self.store.increment_target_seq_number().await { @@ -303,17 +302,11 @@ impl SessionCtx<'_, Store> { } } - async fn handle_original_sending_time_missing( - &mut self, - state: &SessionState, - msg_seq_num: u64, - ) { + async fn handle_original_sending_time_missing(&mut self, writer: &WriterRef, msg_seq_num: u64) { let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::RequiredTagMissing) .text("original sending time is required"); - if let Some(writer) = state.get_writer() - && let Err(err) = self.send_message(writer, reject).await - { + if let Err(err) = self.send_message(writer, reject).await { error!("failed to send reject for time missing tag: {err}"); } if let Err(err) = self.store.increment_target_seq_number().await { @@ -322,15 +315,13 @@ impl SessionCtx<'_, Store> { } /// Send a logout message and immediately disconnect. - async fn logout_and_terminate(&mut self, state: &SessionState, reason: &str) { - if let Some(writer) = state.get_writer() { - let logout = Logout::with_reason(reason.to_string()); - match self.prepare_message(logout).await { - Ok(prepared) => writer.send_raw_message(prepared.raw).await, - Err(err) => warn!("failed to send logout during session termination: {err}"), - } - writer.disconnect().await; + pub(crate) async fn logout_and_terminate(&mut self, writer: &WriterRef, reason: &str) { + let logout = Logout::with_reason(reason.to_string()); + match self.prepare_message(logout).await { + Ok(prepared) => writer.send_raw_message(prepared.raw).await, + Err(err) => warn!("failed to send logout during session termination: {err}"), } + writer.disconnect().await; } pub async fn resend_messages( @@ -397,6 +388,7 @@ impl SessionCtx<'_, Store> { } if let Some(begin) = reset_start { + // the final reset if needed let end = sequence_number; Self::log_skipped_admin_messages(begin, end); self.send_sequence_reset(writer, begin, end).await?; @@ -440,7 +432,7 @@ impl SessionCtx<'_, Store> { pub async fn handle_invalid_msg_type( &mut self, - state: &SessionState, + writer: &WriterRef, message: &Message, msg_type: &str, ) { @@ -449,9 +441,7 @@ impl SessionCtx<'_, Store> { let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::InvalidMsgtype) .text(&format!("invalid message type {msg_type}")); - if let Some(writer) = state.get_writer() - && let Err(err) = self.send_message(writer, reject).await - { + if let Err(err) = self.send_message(writer, reject).await { error!("failed to send reject message for invalid msgtype: {err}"); } @@ -594,6 +584,7 @@ impl SessionState { } } + #[cfg(test)] pub fn try_transition_to_awaiting_resend( &mut self, begin: u64, @@ -693,6 +684,7 @@ impl SessionState { self.get_writer().is_some() } + #[cfg(test)] pub fn is_logged_on(&self) -> bool { matches!(self, SessionState::Active(_)) || matches!(self, SessionState::AwaitingResend { .. }) diff --git a/crates/hotfix/src/session/state/active.rs b/crates/hotfix/src/session/state/active.rs index aa506d2a..6d87f2ae 100644 --- a/crates/hotfix/src/session/state/active.rs +++ b/crates/hotfix/src/session/state/active.rs @@ -1,12 +1,27 @@ +use crate::Application; +use crate::application::{InboundDecision, OutboundDecision}; +use crate::message::business_reject::BusinessReject; use crate::message::heartbeat::Heartbeat; +use crate::message::logon::Logon; use crate::message::logout::Logout; +use crate::message::reject::Reject; +use crate::message::resend_request::ResendRequest; +use crate::message::sequence_reset::SequenceReset; use crate::message::test_request::TestRequest; -use crate::session::state::{SessionCtx, SessionState, TestRequestId}; +use crate::session::error::{InternalSendResultExt, SendError, SendOutcome, SessionOperationError}; +use crate::session::get_msg_seq_num; +use crate::session::state::{ + AwaitingResendState, SessionCtx, SessionState, TestRequestId, TransitionResult, VerifyResult, +}; use crate::transport::writer::WriterRef; +use hotfix_message::Part; +use hotfix_message::session_fields::{ + BEGIN_SEQ_NO, END_SEQ_NO, GAP_FILL_FLAG, MSG_TYPE, NEW_SEQ_NO, SessionRejectReason, TEST_REQ_ID, +}; use hotfix_store::MessageStore; use std::time::Duration; use tokio::time::Instant; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; pub(crate) struct ActiveState { /// The writer's reference to send messages to the counterparty @@ -95,6 +110,402 @@ impl ActiveState { self.writer.send_raw_message(prepared.raw).await; self.reset_heartbeat_timer(ctx.config.heartbeat_interval); } + + pub(crate) async fn on_fix_message( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: hotfix_message::message::Message, + ) -> Result { + let message_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + + match message_type { + Heartbeat::MSG_TYPE => self.on_heartbeat(ctx, &message).await, + TestRequest::MSG_TYPE => self.on_test_request(ctx, &message).await, + ResendRequest::MSG_TYPE => self.on_resend_request(ctx, &message).await, + Reject::MSG_TYPE => self.on_reject(ctx, &message).await, + SequenceReset::MSG_TYPE => self.on_sequence_reset(ctx, &message).await, + Logout::MSG_TYPE => self.on_logout(ctx, app, &message).await, + Logon::MSG_TYPE => { + error!("received unexpected logon message"); + Ok(TransitionResult::Stay) + } + _ => self.on_app_message(ctx, app, &message).await, + } + } + + async fn on_heartbeat( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &hotfix_message::message::Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, true, true) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self + .transition_to_awaiting_resend(ctx, expected, actual) + .await; + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + if let (Some(expected_req_id), Ok(message_req_id)) = ( + self.expected_test_response_id(), + message.get::<&str>(TEST_REQ_ID), + ) && expected_req_id.as_str() == message_req_id + { + debug!("received response for TestRequest, resetting timer"); + self.reset_peer_timer(ctx.config.heartbeat_interval, None); + } + + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::Stay) + } + + async fn on_test_request( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &hotfix_message::message::Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, true, true) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self + .transition_to_awaiting_resend(ctx, expected, actual) + .await; + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + let req_id: &str = message.get(TEST_REQ_ID).unwrap_or_else(|_| { + // TODO: send reject? + todo!() + }); + + ctx.store.increment_target_seq_number().await?; + + ctx.send_message(&self.writer, Heartbeat::for_request(req_id.to_string())) + .await + .with_send_context("heartbeat response")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + + Ok(TransitionResult::Stay) + } + + async fn on_resend_request( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &hotfix_message::message::Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, false, true) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { .. } => { + // ResendRequest with check_too_high=false should never get SeqTooHigh, + // but handle gracefully + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + let msg_seq_num = get_msg_seq_num(message); + let expected = ctx.store.next_target_seq_number(); + + let begin_seq_number: u64 = match message.get(BEGIN_SEQ_NO) { + Ok(seq_number) => seq_number, + Err(_) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing begin sequence number for resend request"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing BEGIN_SEQ_NO")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + return Ok(TransitionResult::Stay); + } + }; + + let end_seq_number: u64 = match message.get(END_SEQ_NO) { + Ok(seq_number) => { + let last_seq_number = ctx.store.next_sender_seq_number() - 1; + if seq_number == 0 { + last_seq_number + } else { + std::cmp::min(seq_number, last_seq_number) + } + } + Err(_) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing end sequence number for resend request"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing END_SEQ_NO")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + return Ok(TransitionResult::Stay); + } + }; + + // Only increment target seq if seq matches expected + if msg_seq_num == expected { + ctx.store.increment_target_seq_number().await?; + } + + ctx.resend_messages(&self.writer, begin_seq_number, end_seq_number) + .await?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + + Ok(TransitionResult::Stay) + } + + async fn on_reject( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &hotfix_message::message::Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, false, true) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self + .transition_to_awaiting_resend(ctx, expected, actual) + .await; + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::Stay) + } + + async fn on_sequence_reset( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &hotfix_message::message::Message, + ) -> Result { + let msg_seq_num = get_msg_seq_num(message); + let is_gap_fill: bool = message.get(GAP_FILL_FLAG).unwrap_or(false); + match ctx + .verify_and_handle(&self.writer, message, is_gap_fill, is_gap_fill) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self + .transition_to_awaiting_resend(ctx, expected, actual) + .await; + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + let end: u64 = match message.get(NEW_SEQ_NO) { + Ok(new_seq_no) => new_seq_no, + Err(err) => { + error!( + "received sequence reset message without new sequence number: {:?}", + err + ); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing NewSeqNo tag in sequence reset message"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing NEW_SEQ_NO")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + return Ok(TransitionResult::Stay); + } + }; + + if end <= ctx.store.next_target_seq_number() { + error!( + "received sequence reset message which would move target seq number backwards: {end}", + ); + let text = + format!("attempt to lower sequence number, invalid value NewSeqNo(36)={end}"); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::ValueIsIncorrect) + .text(&text); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for invalid sequence reset")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + return Ok(TransitionResult::Stay); + } + + ctx.store.set_target_seq_number(end - 1).await?; + Ok(TransitionResult::Stay) + } + + async fn on_logout( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: &hotfix_message::message::Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, false, false) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { .. } => { + // verify with check_too_high=false, so this shouldn't happen + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + // We are logged on, send logout response + let logout = Logout::with_reason("Logout acknowledged".to_string()); + match ctx.prepare_message(logout).await { + Ok(prepared) => { + self.writer.send_raw_message(prepared.raw).await; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + } + Err(err) => warn!("failed to send logout acknowledgement: {err}"), + } + + app.on_logout("peer has logged us out").await; + + self.writer.disconnect().await; + ctx.store.increment_target_seq_number().await?; + + Ok(TransitionResult::TransitionTo( + SessionState::new_disconnected(true, "peer has logged us out"), + )) + } + + async fn on_app_message( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: &hotfix_message::message::Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, true, true) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self + .transition_to_awaiting_resend(ctx, expected, actual) + .await; + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + match app.on_inbound_message(message).await { + InboundDecision::Accept => {} + InboundDecision::Reject { reason, text } => { + let msg_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + let mut reject = + BusinessReject::new(msg_type, reason).ref_seq_num(get_msg_seq_num(message)); + if let Some(text) = text { + reject = reject.text(&text); + } + ctx.send_message(&self.writer, reject) + .await + .with_send_context("business message reject")?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + } + InboundDecision::TerminateSession => { + error!("failed to send inbound message to application"); + self.writer.disconnect().await; + } + } + ctx.store.increment_target_seq_number().await?; + + Ok(TransitionResult::Stay) + } + + async fn transition_to_awaiting_resend( + &self, + ctx: &mut SessionCtx<'_, Store>, + expected: u64, + actual: u64, + ) -> Result { + debug!("we are behind target (ours: {expected}, theirs: {actual}), requesting resend."); + let request = ResendRequest::new(expected, actual); + ctx.send_message(&self.writer, request) + .await + .with_send_context("resend request")?; + let new_state = SessionState::AwaitingResend(AwaitingResendState::new( + self.writer.clone(), + expected, + actual, + )); + Ok(TransitionResult::TransitionTo(new_state)) + } + + pub(crate) async fn send_app_message( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: App::Outbound, + ) -> Result { + match app.on_outbound_message(&message).await { + OutboundDecision::Send => { + let seq_num = + ctx.send_message(&self.writer, message) + .await + .map_err(|e| match e { + crate::session::error::InternalSendError::Persist(e) => { + SendError::Persist(e) + } + crate::session::error::InternalSendError::SequenceNumber(e) => { + SendError::SequenceNumber(e) + } + })?; + self.reset_heartbeat_timer(ctx.config.heartbeat_interval); + Ok(SendOutcome::Sent { + sequence_number: seq_num, + }) + } + OutboundDecision::Drop => { + debug!("dropped outbound message as instructed by the application"); + Ok(SendOutcome::Dropped) + } + OutboundDecision::TerminateSession => { + warn!("the application indicated we should terminate the session"); + self.writer.disconnect().await; + Err(SendError::SessionTerminated) + } + } + } } #[inline] diff --git a/crates/hotfix/src/session/state/awaiting_logon.rs b/crates/hotfix/src/session/state/awaiting_logon.rs index 9fe1aa3d..bb83e532 100644 --- a/crates/hotfix/src/session/state/awaiting_logon.rs +++ b/crates/hotfix/src/session/state/awaiting_logon.rs @@ -1,4 +1,11 @@ +use crate::Application; +use crate::message::logon::Logon; +use crate::session::error::SessionOperationError; +use crate::session::state::{SessionCtx, SessionState, TransitionResult, VerifyResult}; use crate::transport::writer::WriterRef; +use hotfix_message::Part; +use hotfix_message::session_fields::MSG_TYPE; +use hotfix_store::MessageStore; use tokio::time::Instant; use tracing::warn; @@ -9,13 +16,70 @@ pub(crate) struct AwaitingLogonState { } impl AwaitingLogonState { - pub(crate) async fn on_disconnect(&self, reason: &str) -> super::SessionState { + pub(crate) async fn on_disconnect(&self, reason: &str) -> SessionState { self.writer.disconnect().await; - super::SessionState::new_disconnected(true, reason) + SessionState::new_disconnected(true, reason) } pub(crate) async fn on_peer_timeout(&self) { warn!("peer didn't respond to our Logon, disconnecting.."); self.writer.disconnect().await; } + + pub(crate) async fn on_fix_message( + &self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: hotfix_message::message::Message, + ) -> Result { + let message_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + + if message_type != Logon::MSG_TYPE { + self.writer.disconnect().await; + return Ok(TransitionResult::Stay); + } + + // process logon + match ctx + .verify_and_handle(&self.writer, &message, true, true) + .await? + { + VerifyResult::Passed => { + // happy logon flow, the session is now active + let new_state = + SessionState::new_active(self.writer.clone(), ctx.config.heartbeat_interval); + app.on_logon().await; + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::TransitionTo(new_state)) + } + VerifyResult::SeqTooHigh { expected, actual } => { + // Unusual during logon, but handle it + use crate::message::resend_request::ResendRequest; + use crate::session::error::InternalSendResultExt; + use crate::session::state::AwaitingResendState; + use tracing::debug; + + debug!( + "we are behind target during logon (ours: {expected}, theirs: {actual}), requesting resend." + ); + let request = ResendRequest::new(expected, actual); + ctx.send_message(&self.writer, request) + .await + .with_send_context("resend request")?; + let new_state = SessionState::AwaitingResend(AwaitingResendState::new( + self.writer.clone(), + expected, + actual, + )); + Ok(TransitionResult::TransitionTo(new_state)) + } + VerifyResult::ErrorHandled(Some(new_state)) => { + Ok(TransitionResult::TransitionTo(new_state)) + } + VerifyResult::ErrorHandled(None) => Ok(TransitionResult::Stay), + } + } } diff --git a/crates/hotfix/src/session/state/awaiting_logout.rs b/crates/hotfix/src/session/state/awaiting_logout.rs index a0e3b90f..2466657e 100644 --- a/crates/hotfix/src/session/state/awaiting_logout.rs +++ b/crates/hotfix/src/session/state/awaiting_logout.rs @@ -1,4 +1,11 @@ +use crate::Application; +use crate::message::logout::Logout; +use crate::session::error::SessionOperationError; +use crate::session::state::{SessionCtx, SessionState, TransitionResult, VerifyResult}; use crate::transport::writer::WriterRef; +use hotfix_message::Part; +use hotfix_message::session_fields::MSG_TYPE; +use hotfix_store::MessageStore; use tokio::time::Instant; use tracing::warn; @@ -9,13 +16,54 @@ pub(crate) struct AwaitingLogoutState { } impl AwaitingLogoutState { - pub(crate) fn on_disconnect(&self, reason: &str) -> super::SessionState { - super::SessionState::new_disconnected(self.reconnect, reason) + pub(crate) fn on_disconnect(&self, reason: &str) -> SessionState { + SessionState::new_disconnected(self.reconnect, reason) } - pub(crate) async fn on_peer_timeout(&self) -> super::SessionState { + pub(crate) async fn on_peer_timeout(&self) -> SessionState { warn!("peer didn't respond to our Logout, disconnecting.."); self.writer.disconnect().await; - super::SessionState::new_disconnected(self.reconnect, "logout timeout") + SessionState::new_disconnected(self.reconnect, "logout timeout") + } + + pub(crate) async fn on_fix_message( + &self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: hotfix_message::message::Message, + ) -> Result { + let message_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + + if message_type == Logout::MSG_TYPE { + // Process the logout + match ctx + .verify_and_handle(&self.writer, &message, false, false) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { .. } => { + // verify with check_too_high=false, shouldn't happen + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + app.on_logout("peer has logged us out").await; + self.writer.disconnect().await; + ctx.store.increment_target_seq_number().await?; + + Ok(TransitionResult::TransitionTo( + SessionState::new_disconnected(self.reconnect, "logout completed"), + )) + } else { + // Other messages during logout: increment target seq and stay + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::Stay) + } } } diff --git a/crates/hotfix/src/session/state/awaiting_resend.rs b/crates/hotfix/src/session/state/awaiting_resend.rs index 3e5afd9c..d0bcb98c 100644 --- a/crates/hotfix/src/session/state/awaiting_resend.rs +++ b/crates/hotfix/src/session/state/awaiting_resend.rs @@ -1,6 +1,25 @@ +use crate::Application; +use crate::application::InboundDecision; +use crate::message::business_reject::BusinessReject; +use crate::message::heartbeat::Heartbeat; +use crate::message::logon::Logon; +use crate::message::logout::Logout; +use crate::message::reject::Reject; +use crate::message::resend_request::ResendRequest; +use crate::message::sequence_reset::SequenceReset; +use crate::message::test_request::TestRequest; +use crate::session::error::{InternalSendResultExt, SessionOperationError}; +use crate::session::get_msg_seq_num; +use crate::session::state::{SessionCtx, SessionState, TransitionResult, VerifyResult}; use crate::transport::writer::WriterRef; +use hotfix_message::Part; use hotfix_message::message::Message; +use hotfix_message::session_fields::{ + BEGIN_SEQ_NO, END_SEQ_NO, GAP_FILL_FLAG, MSG_TYPE, NEW_SEQ_NO, SessionRejectReason, TEST_REQ_ID, +}; +use hotfix_store::MessageStore; use std::collections::VecDeque; +use tracing::{debug, error, warn}; const MAX_RESEND_ATTEMPTS: usize = 3; @@ -19,9 +38,9 @@ pub(crate) struct AwaitingResendState { } impl AwaitingResendState { - pub(crate) async fn on_disconnect(&self, reason: &str) -> super::SessionState { + pub(crate) async fn on_disconnect(&self, reason: &str) -> SessionState { self.writer.disconnect().await; - super::SessionState::new_disconnected(true, reason) + SessionState::new_disconnected(true, reason) } pub(crate) fn new(writer: WriterRef, begin_seq_number: u64, end_seq_number: u64) -> Self { @@ -56,8 +75,393 @@ impl AwaitingResendState { AwaitingResendTransitionOutcome::Success } + + pub(crate) async fn on_fix_message( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: Message, + ) -> Result { + let message_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + + let seq_number = get_msg_seq_num(&message); + + // If msg seq > end_seq_number AND not ResendRequest: queue it + if seq_number > self.end_seq_number && message_type != ResendRequest::MSG_TYPE { + self.inbound_queue.push_back(message); + return Ok(TransitionResult::Stay); + } + + // Dispatch by message type + let result = match message_type { + Heartbeat::MSG_TYPE => self.on_heartbeat(ctx, &message).await?, + TestRequest::MSG_TYPE => self.on_test_request(ctx, &message).await?, + ResendRequest::MSG_TYPE => self.on_resend_request(ctx, &message).await?, + Reject::MSG_TYPE => self.on_reject(ctx, &message).await?, + SequenceReset::MSG_TYPE => self.on_sequence_reset(ctx, &message).await?, + Logout::MSG_TYPE => self.on_logout(ctx, app, &message).await?, + Logon::MSG_TYPE => { + error!("received unexpected logon message"); + TransitionResult::Stay + } + _ => self.on_app_message(ctx, app, &message).await?, + }; + + // If a transition happened, return it directly + if !matches!(result, TransitionResult::Stay) { + return Ok(result); + } + + // Check if resend is done + self.check_end_of_resend(ctx) + } + + fn check_end_of_resend( + &mut self, + ctx: &SessionCtx<'_, Store>, + ) -> Result { + if ctx.store.next_target_seq_number() > self.end_seq_number { + let new_state = + SessionState::new_active(self.writer.clone(), ctx.config.heartbeat_interval); + let backlog = std::mem::take(&mut self.inbound_queue); + Ok(TransitionResult::TransitionWithBacklog { new_state, backlog }) + } else { + Ok(TransitionResult::Stay) + } + } + + async fn on_heartbeat( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, true, true) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self.handle_seq_too_high(ctx, expected, actual).await; + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::Stay) + } + + async fn on_test_request( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, true, true) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self.handle_seq_too_high(ctx, expected, actual).await; + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + let req_id: &str = message.get(TEST_REQ_ID).unwrap_or_else(|_| todo!()); + + ctx.store.increment_target_seq_number().await?; + + ctx.send_message(&self.writer, Heartbeat::for_request(req_id.to_string())) + .await + .with_send_context("heartbeat response")?; + + Ok(TransitionResult::Stay) + } + + async fn on_resend_request( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, false, true) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { .. } => { + // check_too_high=false, shouldn't happen + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + let msg_seq_num = get_msg_seq_num(message); + let expected = ctx.store.next_target_seq_number(); + + // If seq is too high, queue it for seq accounting when the gap fill catches up, + // but still process the resend below. + if msg_seq_num > expected { + self.inbound_queue.push_back(message.clone()); + } + + let begin_seq_number: u64 = match message.get(BEGIN_SEQ_NO) { + Ok(seq_number) => seq_number, + Err(_) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing begin sequence number for resend request"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing BEGIN_SEQ_NO")?; + return Ok(TransitionResult::Stay); + } + }; + + let end_seq_number: u64 = match message.get(END_SEQ_NO) { + Ok(seq_number) => { + let last_seq_number = ctx.store.next_sender_seq_number() - 1; + if seq_number == 0 { + last_seq_number + } else { + std::cmp::min(seq_number, last_seq_number) + } + } + Err(_) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing end sequence number for resend request"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing END_SEQ_NO")?; + return Ok(TransitionResult::Stay); + } + }; + + if msg_seq_num == expected { + ctx.store.increment_target_seq_number().await?; + } + + ctx.resend_messages(&self.writer, begin_seq_number, end_seq_number) + .await?; + + Ok(TransitionResult::Stay) + } + + async fn on_reject( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, false, true) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self.handle_seq_too_high(ctx, expected, actual).await; + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + ctx.store.increment_target_seq_number().await?; + Ok(TransitionResult::Stay) + } + + async fn on_sequence_reset( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + message: &Message, + ) -> Result { + let msg_seq_num = get_msg_seq_num(message); + let is_gap_fill: bool = message.get(GAP_FILL_FLAG).unwrap_or(false); + match ctx + .verify_and_handle(&self.writer, message, is_gap_fill, is_gap_fill) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self.handle_seq_too_high(ctx, expected, actual).await; + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + let end: u64 = match message.get(NEW_SEQ_NO) { + Ok(new_seq_no) => new_seq_no, + Err(err) => { + error!( + "received sequence reset message without new sequence number: {:?}", + err + ); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("missing NewSeqNo tag in sequence reset message"); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for missing NEW_SEQ_NO")?; + return Ok(TransitionResult::Stay); + } + }; + + if end <= ctx.store.next_target_seq_number() { + error!( + "received sequence reset message which would move target seq number backwards: {end}", + ); + let text = + format!("attempt to lower sequence number, invalid value NewSeqNo(36)={end}"); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::ValueIsIncorrect) + .text(&text); + ctx.send_message(&self.writer, reject) + .await + .with_send_context("reject for invalid sequence reset")?; + return Ok(TransitionResult::Stay); + } + + ctx.store.set_target_seq_number(end - 1).await?; + Ok(TransitionResult::Stay) + } + + async fn on_logout( + &self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: &Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, false, false) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { .. } => {} + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + // We are in AwaitingResend (logged on), send logout response + let logout = Logout::with_reason("Logout acknowledged".to_string()); + match ctx.prepare_message(logout).await { + Ok(prepared) => self.writer.send_raw_message(prepared.raw).await, + Err(err) => warn!("failed to send logout acknowledgement: {err}"), + } + + app.on_logout("peer has logged us out").await; + + self.writer.disconnect().await; + ctx.store.increment_target_seq_number().await?; + + Ok(TransitionResult::TransitionTo( + SessionState::new_disconnected(true, "peer has logged us out"), + )) + } + + async fn on_app_message( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + app: &mut App, + message: &Message, + ) -> Result { + match ctx + .verify_and_handle(&self.writer, message, true, true) + .await? + { + VerifyResult::Passed => {} + VerifyResult::SeqTooHigh { expected, actual } => { + return self.handle_seq_too_high(ctx, expected, actual).await; + } + VerifyResult::ErrorHandled(Some(new_state)) => { + return Ok(TransitionResult::TransitionTo(new_state)); + } + VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + } + + match app.on_inbound_message(message).await { + InboundDecision::Accept => {} + InboundDecision::Reject { reason, text } => { + let msg_type: &str = message + .header() + .get(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))?; + let mut reject = + BusinessReject::new(msg_type, reason).ref_seq_num(get_msg_seq_num(message)); + if let Some(text) = text { + reject = reject.text(&text); + } + ctx.send_message(&self.writer, reject) + .await + .with_send_context("business message reject")?; + } + InboundDecision::TerminateSession => { + error!("failed to send inbound message to application"); + self.writer.disconnect().await; + } + } + ctx.store.increment_target_seq_number().await?; + + Ok(TransitionResult::Stay) + } + + async fn handle_seq_too_high( + &mut self, + ctx: &mut SessionCtx<'_, Store>, + expected: u64, + actual: u64, + ) -> Result { + match self.update(expected, actual) { + AwaitingResendTransitionOutcome::Success => { + debug!( + "we are behind target (ours: {expected}, theirs: {actual}), requesting resend." + ); + let request = ResendRequest::new(expected, actual); + ctx.send_message(&self.writer, request) + .await + .with_send_context("resend request")?; + Ok(TransitionResult::Stay) + } + AwaitingResendTransitionOutcome::InvalidState(reason) => { + error!("failed to request resend: {reason}"); + Ok(TransitionResult::Stay) + } + AwaitingResendTransitionOutcome::BeginSeqNumberTooLow => { + self.writer.disconnect().await; + Ok(TransitionResult::TransitionTo( + SessionState::new_disconnected( + false, + "awaiting resend begin seq number unexpectedly lower than the previous resend request's", + ), + )) + } + AwaitingResendTransitionOutcome::AttemptsExceeded => { + self.writer.disconnect().await; + Ok(TransitionResult::TransitionTo( + SessionState::new_disconnected( + false, + "resend request attempts exceeded, manual intervention required", + ), + )) + } + } + } } +#[allow(dead_code)] // InvalidState is used only by AwaitingResendState::handle_seq_too_high and tests pub(crate) enum AwaitingResendTransitionOutcome { Success, InvalidState(String), @@ -68,7 +472,7 @@ pub(crate) enum AwaitingResendTransitionOutcome { #[cfg(test)] mod tests { use super::*; - use crate::session::state::SessionState; + use crate::session::state::AwaitingLogoutState; use tokio::sync::mpsc; use tokio::time::Instant; @@ -104,8 +508,6 @@ mod tests { #[test] fn test_awaiting_resend_transition_when_awaiting_logout_is_prevented() { - use crate::session::state::AwaitingLogoutState; - let mut state = SessionState::AwaitingLogout(AwaitingLogoutState { writer: create_writer_ref(), logout_timeout: Instant::now(), From c8b0e3730461383943166f20b266abe29796a8d0 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 12:13:54 +0100 Subject: [PATCH 08/14] Move bulk of message sending logic out of session --- crates/hotfix/src/initiator.rs | 8 +- crates/hotfix/src/session.rs | 247 ++++++------------ crates/hotfix/src/session/state.rs | 79 ------ .../src/session/state/awaiting_logon.rs | 1 - .../hotfix/src/session/state/disconnected.rs | 1 - 5 files changed, 93 insertions(+), 243 deletions(-) diff --git a/crates/hotfix/src/initiator.rs b/crates/hotfix/src/initiator.rs index 19543e4a..f06d1adf 100644 --- a/crates/hotfix/src/initiator.rs +++ b/crates/hotfix/src/initiator.rs @@ -402,9 +402,13 @@ mod tests { .await .expect("initiator should connect"); - // Message should be received by session and persisted (seq 2 after Logon) + // Session is in AwaitingLogon (no logon response from counterparty), + // so send should be rejected — only Active sessions accept app messages let result = initiator.send(DummyMessage).await; - assert!(matches!(result, Ok(SendOutcome::Sent { .. }))); + assert!( + matches!(result, Err(crate::session::error::SendError::Disconnected)), + "expected Disconnected error, got: {result:?}" + ); } #[tokio::test] diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 177685d5..3708f1e8 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -19,8 +19,6 @@ use tracing::{debug, error, info, warn}; use crate::Application; use crate::config::SessionConfig; -use crate::message::OutboundMessage; -use crate::message::generate_message; use crate::message::logon::{Logon, ResetSeqNumConfig}; use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; @@ -28,7 +26,7 @@ use crate::message::reject::Reject; use crate::message::resend_request::ResendRequest; use crate::session::admin_request::AdminRequest; use crate::session::error::SessionCreationError; -use crate::session::error::{InternalSendError, InternalSendResultExt, SessionOperationError}; +use crate::session::error::{InternalSendResultExt, SessionOperationError}; pub use crate::session::error::{SendError, SendOutcome}; pub use crate::session::info::{SessionInfo, Status}; pub use crate::session::session_handle::SessionHandle; @@ -38,7 +36,7 @@ pub(crate) use crate::session::session_ref::InternalSessionRef; pub use crate::session::session_ref::InternalSessionRef; use crate::session::session_ref::OutboundRequest; use crate::session::state::SessionState; -use crate::session::state::{SessionCtx, TestRequestId, TransitionResult}; +use crate::session::state::{SessionCtx, TransitionResult}; use crate::session_schedule::{SessionPeriodComparison, SessionSchedule}; use crate::store::MessageStore; use crate::transport::writer::WriterRef; @@ -112,8 +110,10 @@ where debug!("received message: {}", raw_message); // Reset peer timer before dispatching (if not expecting test response) - if !self.state.is_expecting_test_response() { - self.reset_peer_timer(None); + if let SessionState::Active(active) = &mut self.state + && active.expected_test_response_id().is_none() + { + active.reset_peer_timer(self.config.heartbeat_interval, None); } match self.message_builder.build(raw_message.as_bytes()) { @@ -141,60 +141,52 @@ where message: Message, reason: InvalidReason, ) -> Result<(), SessionOperationError> { + let Session { + ref state, + ref mut store, + ref config, + ref message_builder, + ref message_config, + .. + } = *self; + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, + }; + let Some(writer) = state.get_writer() else { + return Ok(()); + }; + match reason { InvalidReason::InvalidField(tag) | InvalidReason::InvalidGroup(tag) => { - match message.header().get(MSG_SEQ_NUM) { - Ok(msg_seq_num) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::InvalidTagNumber) - .text(&format!("invalid field {tag}")); - self.send_message(reject) - .await - .with_send_context("reject for invalid field")?; - } - Err(err) => { - error!("failed to get message seq num: {:?}", err); - } + if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::InvalidTagNumber) + .text(&format!("invalid field {tag}")); + ctx.send_message(writer, reject) + .await + .with_send_context("reject for invalid field")?; } } InvalidReason::InvalidComponent(_component_name) => { warn!("received invalid component"); } InvalidReason::InvalidMsgType(msg_type) => { - let Session { - ref state, - ref mut store, - ref config, - ref message_builder, - ref message_config, - .. - } = *self; - let mut ctx = SessionCtx { - config, - store, - message_builder, - message_config, - }; - if let Some(writer) = state.get_writer() { - ctx.handle_invalid_msg_type(writer, &message, &msg_type) - .await; - } + ctx.handle_invalid_msg_type(writer, &message, &msg_type) + .await; } InvalidReason::InvalidOrderInGroup { tag, .. } => { - match message.header().get(MSG_SEQ_NUM) { - Ok(msg_seq_num) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason( - SessionRejectReason::RepeatingGroupFieldsOutOfOrder, - ) - .text(&format!("field appears in incorrect order:{tag}")); - self.send_message(reject) - .await - .with_send_context("reject for invalid group order")?; - } - Err(err) => { - error!("failed to get message seq num: {:?}", err); - } + if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason( + SessionRejectReason::RepeatingGroupFieldsOutOfOrder, + ) + .text(&format!("field appears in incorrect order:{tag}")); + ctx.send_message(writer, reject) + .await + .with_send_context("reject for invalid group order")?; } } } @@ -294,7 +286,8 @@ where if let SessionState::Disconnected(s) = &self.state { self.state = s.on_connect(writer, Duration::from_secs(self.config.logon_timeout)); } - self.reset_peer_timer(None); + // Reset peer timer on the new AwaitingLogon state — no-op since it uses logon_timeout + // Send logon self.send_logon().await?; Ok(()) @@ -316,81 +309,15 @@ where } } - fn reset_heartbeat_timer(&mut self) { - self.state - .reset_heartbeat_timer(self.config.heartbeat_interval); - } - - fn reset_peer_timer(&mut self, test_request_id: Option) { - self.state - .reset_peer_timer(self.config.heartbeat_interval, test_request_id); - } - - /// Legacy send_app_message for non-Active connected states. - async fn send_app_message_legacy( - &mut self, - message: App::Outbound, - ) -> Result { - use crate::application::OutboundDecision; - - match self.application.on_outbound_message(&message).await { - OutboundDecision::Send => { - let sequence_number = self.send_message(message).await?; - Ok(SendOutcome::Sent { sequence_number }) - } - OutboundDecision::Drop => { - debug!("dropped outbound message as instructed by the application"); - Ok(SendOutcome::Dropped) - } - OutboundDecision::TerminateSession => { - warn!("the application indicated we should terminate the session"); - self.state.disconnect_writer().await; - Err(SendError::SessionTerminated) - } - } - } - - /// Legacy send_message used by send_logon, send_logout, and error handling paths. - async fn send_message( - &mut self, - message: impl OutboundMessage, - ) -> Result { - let seq_num = self.store.next_sender_seq_number(); - let msg_type = message.message_type().to_string(); - let msg = generate_message( - &self.config.begin_string, - &self.config.sender_comp_id, - &self.config.target_comp_id, - seq_num, - message, - ) - .map_err(|e| { - InternalSendError::Persist(crate::store::StoreError::PersistMessage { - sequence_number: seq_num, - source: e.into(), - }) - })?; - - self.store - .increment_sender_seq_number() - .await - .map_err(InternalSendError::SequenceNumber)?; - - self.store - .add(seq_num, &msg) - .await - .map_err(InternalSendError::Persist)?; - - self.send_raw(&msg_type, msg).await; - - Ok(seq_num) - } - - async fn send_raw(&mut self, message_type: &str, data: Vec) { - self.state - .send_message(message_type, RawFixMessage::new(data)) - .await; - self.reset_heartbeat_timer(); + fn make_ctx(&mut self) -> (SessionCtx<'_, Store>, Option<&WriterRef>) { + let writer = self.state.get_writer(); + let ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; + (ctx, writer) } async fn send_logon(&mut self) -> Result<(), SessionOperationError> { @@ -403,15 +330,23 @@ where self.reset_on_next_logon = false; let logon = Logon::new(self.config.heartbeat_interval, reset_config); - self.send_message(logon).await.with_send_context("logon")?; + let (mut ctx, writer) = self.make_ctx(); + if let Some(writer) = writer { + ctx.send_message(writer, logon) + .await + .with_send_context("logon")?; + } Ok(()) } async fn send_logout(&mut self, reason: &str) -> Result<(), SessionOperationError> { let logout = Logout::with_reason(reason.to_string()); - self.send_message(logout) - .await - .with_send_context("logout")?; + let (mut ctx, writer) = self.make_ctx(); + if let Some(writer) = writer { + ctx.send_message(writer, logout) + .await + .with_send_context("logout")?; + } Ok(()) } @@ -420,7 +355,9 @@ where if let Err(err) = self.send_logout(reason).await { warn!("failed to send logout during session termination: {}", err); } - self.state.disconnect_writer().await; + if let Some(writer) = self.state.get_writer() { + writer.disconnect().await; + } } /// Sends a logout message and puts the session state into an AwaitingLogout state. @@ -474,36 +411,26 @@ where async fn handle_outbound_message(&mut self, request: OutboundRequest) { let OutboundRequest { message, confirm } = request; - let is_active = matches!(self.state, SessionState::Active(_)); - let is_connected = self.state.is_connected(); + let Session { + ref mut state, + ref mut store, + ref config, + ref message_builder, + ref message_config, + ref mut application, + .. + } = *self; - let result = if !is_connected { - Err(SendError::Disconnected) - } else if is_active { - let Session { - ref mut state, - ref mut store, - ref config, - ref message_builder, - ref message_config, - ref mut application, - .. - } = *self; - - if let SessionState::Active(s) = state { - let mut ctx = SessionCtx { - config, - store, - message_builder, - message_config, - }; - s.send_app_message(&mut ctx, application, message).await - } else { - unreachable!() - } + let result = if let SessionState::Active(s) = state { + let mut ctx = SessionCtx { + config, + store, + message_builder, + message_config, + }; + s.send_app_message(&mut ctx, application, message).await } else { - // Legacy path: session is connected but not Active (e.g. AwaitingLogon). - self.send_app_message_legacy(message).await + Err(SendError::Disconnected) }; match confirm { @@ -631,7 +558,7 @@ where self.logout_and_terminate("internal error").await; } } - } else if self.state.is_connected() + } else if self.state.get_writer().is_some() && let Err(err) = self .initiate_graceful_logout("End of session time", true) .await diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index ae41127b..0b78e6b1 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -13,7 +13,6 @@ pub(crate) use awaiting_resend::AwaitingResendTransitionOutcome; pub(crate) use disconnected::DisconnectedState; use crate::config::SessionConfig; -use crate::message::logon::Logon; use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; use crate::message::reject::Reject; @@ -504,53 +503,6 @@ impl SessionState { } } - pub async fn send_message(&mut self, message_type: &str, message: RawFixMessage) { - match self { - Self::Active(ActiveState { writer, .. }) - | Self::AwaitingResend(AwaitingResendState { writer, .. }) => { - if message_type == Logon::MSG_TYPE { - error!("logon message is invalid for active sessions") - } else { - writer.send_raw_message(message).await - } - } - Self::AwaitingLogon(AwaitingLogonState { - writer, logon_sent, .. - }) => match message_type { - Logon::MSG_TYPE => { - if *logon_sent { - error!("trying to send logon twice"); - } else { - writer.send_raw_message(message).await; - *logon_sent = true; - } - } - Logout::MSG_TYPE => { - writer.send_raw_message(message).await; - } - _ => error!("invalid outgoing message for AwaitingLogon state"), - }, - Self::AwaitingLogout(AwaitingLogoutState { writer, .. }) => { - // Logout messages are allowed because we first transition into AwaitingLogout - // and only then send the logout message - if message_type == Logout::MSG_TYPE { - writer.send_raw_message(message).await - } - } - _ => error!("trying to write without an established connection"), - } - } - - pub async fn disconnect_writer(&self) { - match self { - Self::Active(ActiveState { writer, .. }) - | Self::AwaitingLogon(AwaitingLogonState { writer, .. }) - | Self::AwaitingLogout(AwaitingLogoutState { writer, .. }) - | Self::AwaitingResend(AwaitingResendState { writer, .. }) => writer.disconnect().await, - _ => debug!("disconnecting an already disconnected session"), - } - } - pub(crate) fn get_writer(&self) -> Option<&WriterRef> { match self { Self::Active(ActiveState { writer, .. }) @@ -646,12 +598,6 @@ impl SessionState { } } - pub fn reset_heartbeat_timer(&mut self, heartbeat_interval: u64) { - if let Self::Active(state) = self { - state.reset_heartbeat_timer(heartbeat_interval); - } - } - pub fn peer_deadline(&self) -> Option<&Instant> { match self { Self::Active(state) => Some(state.peer_deadline()), @@ -663,37 +609,12 @@ impl SessionState { } } - pub fn reset_peer_timer( - &mut self, - heartbeat_interval: u64, - test_request_id: Option, - ) { - if let Self::Active(state) = self { - state.reset_peer_timer(heartbeat_interval, test_request_id); - } - } - - pub fn expected_test_response_id(&self) -> Option<&TestRequestId> { - match self { - Self::Active(state) => state.expected_test_response_id(), - _ => None, - } - } - - pub fn is_connected(&self) -> bool { - self.get_writer().is_some() - } - #[cfg(test)] pub fn is_logged_on(&self) -> bool { matches!(self, SessionState::Active(_)) || matches!(self, SessionState::AwaitingResend { .. }) } - pub fn is_expecting_test_response(&self) -> bool { - self.expected_test_response_id().is_some() - } - pub fn as_status(&self) -> SessionInfoStatus { match self { SessionState::AwaitingLogon(_) => SessionInfoStatus::AwaitingLogon, diff --git a/crates/hotfix/src/session/state/awaiting_logon.rs b/crates/hotfix/src/session/state/awaiting_logon.rs index bb83e532..897fcb27 100644 --- a/crates/hotfix/src/session/state/awaiting_logon.rs +++ b/crates/hotfix/src/session/state/awaiting_logon.rs @@ -11,7 +11,6 @@ use tracing::warn; pub(crate) struct AwaitingLogonState { pub(crate) writer: WriterRef, - pub(crate) logon_sent: bool, pub(crate) logon_timeout: Instant, } diff --git a/crates/hotfix/src/session/state/disconnected.rs b/crates/hotfix/src/session/state/disconnected.rs index c96e2be0..04efa2fd 100644 --- a/crates/hotfix/src/session/state/disconnected.rs +++ b/crates/hotfix/src/session/state/disconnected.rs @@ -45,7 +45,6 @@ impl DisconnectedState { ) -> super::SessionState { super::SessionState::AwaitingLogon(AwaitingLogonState { writer, - logon_sent: false, logon_timeout: Instant::now() + logon_timeout, }) } From a3ef9e5ffccc2034db70e30f9e7aac9494133c9d Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 12:27:53 +0100 Subject: [PATCH 09/14] Break session context out into its own module --- crates/hotfix/src/session.rs | 44 +- crates/hotfix/src/session/state.rs | 495 +----------------- .../src/session/state/awaiting_resend.rs | 48 +- crates/hotfix/src/session/state/ctx.rs | 441 ++++++++++++++++ 4 files changed, 497 insertions(+), 531 deletions(-) create mode 100644 crates/hotfix/src/session/state/ctx.rs diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 3708f1e8..1aa7e24e 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -10,7 +10,6 @@ use chrono::Utc; use hotfix_message::dict::Dictionary; use hotfix_message::message::{Config as MessageConfig, Message}; use hotfix_message::{MessageBuilder, Part}; -use std::future::Future; use std::pin::Pin; use tokio::select; use tokio::sync::mpsc; @@ -193,17 +192,18 @@ where Ok(()) } - fn dispatch_valid_message( + async fn dispatch_valid_message( &mut self, message: Message, - ) -> Pin> + Send + '_>> { - Box::pin(self.dispatch_valid_message_inner(message)) + ) -> Result<(), SessionOperationError> { + let transition = self.dispatch_to_state(message).await?; + self.apply_transition(transition).await } - async fn dispatch_valid_message_inner( + async fn dispatch_to_state( &mut self, message: Message, - ) -> Result<(), SessionOperationError> { + ) -> Result { let Session { ref mut state, ref mut store, @@ -235,10 +235,7 @@ where SessionState::Disconnected(_) => TransitionResult::Stay, }; - // Let ctx go out of scope before we can mutate self.state - let _ = ctx; - - self.apply_transition(transition).await + Ok(transition) } async fn apply_transition( @@ -273,7 +270,12 @@ where // for sequence accounting purposes. self.store.increment_target_seq_number().await?; } else { - self.dispatch_valid_message(msg).await?; + let inner_transition = self.dispatch_to_state(msg).await?; + // Backlog messages can't produce more backlogs (only AwaitingResend + // produces TransitionWithBacklog, and we've already transitioned to Active) + if let TransitionResult::TransitionTo(s) = inner_transition { + self.state = s; + } } } debug!("resend backlog is cleared, resuming normal operation"); @@ -366,13 +368,23 @@ where reason: &str, reconnect: bool, ) -> Result<(), SessionOperationError> { - if self.state.try_transition_to_awaiting_logout( - Duration::from_secs(self.config.logout_timeout), - reconnect, - ) { - self.send_logout(reason).await?; + if matches!(self.state, SessionState::AwaitingLogout(_)) { + debug!("already in awaiting logout state"); + return Ok(()); } + let Some(writer) = self.state.get_writer().cloned() else { + error!("trying to transition to awaiting logout without an established connection"); + return Ok(()); + }; + + self.state = SessionState::AwaitingLogout(state::AwaitingLogoutState { + writer, + logout_timeout: Instant::now() + Duration::from_secs(self.config.logout_timeout), + reconnect, + }); + self.send_logout(reason).await?; + Ok(()) } diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index 0b78e6b1..03414df6 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -2,465 +2,25 @@ mod active; mod awaiting_logon; mod awaiting_logout; mod awaiting_resend; +mod ctx; mod disconnected; pub(crate) use active::{ActiveState, calculate_peer_interval}; pub(crate) use awaiting_logon::AwaitingLogonState; pub(crate) use awaiting_logout::AwaitingLogoutState; pub(crate) use awaiting_resend::AwaitingResendState; -#[cfg(test)] -pub(crate) use awaiting_resend::AwaitingResendTransitionOutcome; +pub(crate) use ctx::{SessionCtx, TransitionResult, VerifyResult}; pub(crate) use disconnected::DisconnectedState; -use crate::config::SessionConfig; -use crate::message::logout::Logout; -use crate::message::parser::RawFixMessage; -use crate::message::reject::Reject; -use crate::message::sequence_reset::SequenceReset; -use crate::message::verification::verify_message as verify_message_impl; -use crate::message::verification_error::{CompIdType, MessageVerificationError}; -use crate::message::{OutboundMessage, generate_message, is_admin, prepare_message_for_resend}; -use crate::session::error::{InternalSendError, SessionOperationError}; use crate::session::event::AwaitingActiveSessionResponse; -use crate::session::get_msg_seq_num; use crate::session::info::Status as SessionInfoStatus; -use crate::store::StoreError; use crate::transport::writer::WriterRef; -use hotfix_message::message::{Config as MessageConfig, Message}; -use hotfix_message::session_fields::SessionRejectReason; -use hotfix_message::{MessageBuilder, Part}; -use hotfix_store::MessageStore; -use std::collections::VecDeque; use std::time::Duration; use tokio::sync::oneshot; use tokio::time::Instant; -use tracing::{debug, enabled, error, info, warn}; +use tracing::error; -use hotfix_message::session_fields::{MSG_SEQ_NUM, MSG_TYPE}; - -pub(crate) struct SessionCtx<'a, Store> { - pub config: &'a SessionConfig, - pub store: &'a mut Store, - pub message_builder: &'a MessageBuilder, - pub message_config: &'a MessageConfig, -} - -pub(crate) struct PreparedMessage { - pub seq_num: u64, - #[allow(dead_code)] - pub msg_type: String, - pub raw: RawFixMessage, -} - -pub(crate) enum TransitionResult { - Stay, - TransitionTo(SessionState), - TransitionWithBacklog { - new_state: SessionState, - backlog: VecDeque, - }, -} - -pub(crate) enum VerifyResult { - Passed, - SeqTooHigh { expected: u64, actual: u64 }, - ErrorHandled(Option), -} - -impl SessionCtx<'_, Store> { - pub async fn prepare_message( - &mut self, - message: impl OutboundMessage, - ) -> Result { - let seq_num = self.store.next_sender_seq_number(); - let msg_type = message.message_type().to_string(); - let msg = generate_message( - &self.config.begin_string, - &self.config.sender_comp_id, - &self.config.target_comp_id, - seq_num, - message, - ) - .map_err(|e| { - InternalSendError::Persist(StoreError::PersistMessage { - sequence_number: seq_num, - source: e.into(), - }) - })?; - - self.store - .increment_sender_seq_number() - .await - .map_err(InternalSendError::SequenceNumber)?; - self.store - .add(seq_num, &msg) - .await - .map_err(InternalSendError::Persist)?; - - Ok(PreparedMessage { - seq_num, - msg_type, - raw: RawFixMessage::new(msg), - }) - } - - /// Prepare, persist, and send a message via the given writer. - pub async fn send_message( - &mut self, - writer: &WriterRef, - message: impl OutboundMessage, - ) -> Result { - let prepared = self.prepare_message(message).await?; - writer.send_raw_message(prepared.raw).await; - Ok(prepared.seq_num) - } - - pub fn verify_message( - &self, - message: &Message, - check_too_high: bool, - check_too_low: bool, - ) -> Result<(), MessageVerificationError> { - let expected_seq_number = if check_too_high || check_too_low { - Some(self.store.next_target_seq_number()) - } else { - None - }; - verify_message_impl( - message, - self.config, - expected_seq_number, - check_too_high, - check_too_low, - ) - } - - /// Verify a message and handle the error if verification fails. - /// For SeqNumberTooHigh, returns `VerifyResult::SeqTooHigh` instead of handling it, - /// allowing the caller to handle the transition. - pub async fn verify_and_handle( - &mut self, - writer: &WriterRef, - message: &Message, - check_too_high: bool, - check_too_low: bool, - ) -> Result { - match self.verify_message(message, check_too_high, check_too_low) { - Ok(()) => Ok(VerifyResult::Passed), - Err(MessageVerificationError::SeqNumberTooHigh { expected, actual }) => { - Ok(VerifyResult::SeqTooHigh { expected, actual }) - } - Err(err) => { - let transition = self.handle_verification_error(writer, err).await?; - Ok(VerifyResult::ErrorHandled(transition)) - } - } - } - - /// Handle a verification error (excluding SeqNumberTooHigh which is returned separately). - /// Returns `Some(new_state)` if a state transition is needed. - pub async fn handle_verification_error( - &mut self, - writer: &WriterRef, - error: MessageVerificationError, - ) -> Result, SessionOperationError> { - match error { - MessageVerificationError::SeqNumberTooLow { - expected, - actual, - possible_duplicate, - } => Ok(self - .handle_sequence_number_too_low(writer, expected, actual, possible_duplicate) - .await), - MessageVerificationError::SeqNumberTooHigh { expected, actual } => { - // This shouldn't be called for SeqTooHigh anymore (it's returned via VerifyResult), - // but handle gracefully if it is. - warn!( - "handle_verification_error called with SeqNumberTooHigh({expected}, {actual}) - caller should use verify_and_handle" - ); - Ok(None) - } - MessageVerificationError::IncorrectBeginString(begin_string) => Ok(Some( - self.handle_incorrect_begin_string(writer, begin_string) - .await, - )), - MessageVerificationError::IncorrectCompId { - comp_id, - comp_id_type, - msg_seq_num, - } => Ok(Some( - self.handle_incorrect_comp_id(writer, comp_id, comp_id_type, msg_seq_num) - .await, - )), - MessageVerificationError::SendingTimeAccuracyIssue { msg_seq_num } => { - self.handle_sending_time_accuracy_problem( - writer, - msg_seq_num, - "unexpected sending time", - ) - .await; - Ok(None) - } - MessageVerificationError::SendingTimeMissing { msg_seq_num } => { - self.handle_sending_time_accuracy_problem( - writer, - msg_seq_num, - "sending time missing", - ) - .await; - Ok(None) - } - MessageVerificationError::OriginalSendingTimeMissing { msg_seq_num } => { - self.handle_original_sending_time_missing(writer, msg_seq_num) - .await; - Ok(None) - } - MessageVerificationError::OriginalSendingTimeAfterSendingTime { - msg_seq_num, .. - } => { - self.handle_sending_time_accuracy_problem( - writer, - msg_seq_num, - "original sending time is after sending time", - ) - .await; - Ok(None) - } - } - } - - async fn handle_incorrect_begin_string( - &mut self, - writer: &WriterRef, - received_begin_string: String, - ) -> SessionState { - self.logout_and_terminate( - writer, - &format!("beginString={received_begin_string} is not supported"), - ) - .await; - SessionState::new_disconnected(true, "incorrect begin string") - } - - async fn handle_incorrect_comp_id( - &mut self, - writer: &WriterRef, - received_comp_id: String, - comp_id_type: CompIdType, - msg_seq_num: u64, - ) -> SessionState { - error!( - "rejecting message with incorrect comp ID: {received_comp_id} (type: {comp_id_type:?})" - ); - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::ValueIsIncorrect) - .text(&format!("invalid comp ID {received_comp_id}")); - if let Err(err) = self.send_message(writer, reject).await { - error!("failed to send reject message with invalid comp ID: {err}"); - } - self.logout_and_terminate(writer, "incorrect comp ID received") - .await; - SessionState::new_disconnected(true, "incorrect comp ID") - } - - async fn handle_sequence_number_too_low( - &mut self, - writer: &WriterRef, - expected: u64, - actual: u64, - possible_duplicate: bool, - ) -> Option { - if possible_duplicate { - warn!( - "sequence number too low (expected {expected}, actual {actual}, but counterparty indicated it's poss duplicate, ignoring" - ); - return None; - } - error!( - "we expected {expected} sequence number, but target sent lower ({actual}), terminating..." - ); - let reason = format!("sequence number too low (actual {actual}, expected {expected})"); - self.logout_and_terminate(writer, &reason).await; - Some(SessionState::new_disconnected(false, &reason)) - } - - async fn handle_sending_time_accuracy_problem( - &mut self, - writer: &WriterRef, - msg_seq_num: u64, - text: &str, - ) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::SendingtimeAccuracyProblem) - .text(text); - if let Err(err) = self.send_message(writer, reject).await { - error!("failed to send reject for time accuracy problem: {err}"); - } - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); - } - } - - async fn handle_original_sending_time_missing(&mut self, writer: &WriterRef, msg_seq_num: u64) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::RequiredTagMissing) - .text("original sending time is required"); - if let Err(err) = self.send_message(writer, reject).await { - error!("failed to send reject for time missing tag: {err}"); - } - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); - } - } - - /// Send a logout message and immediately disconnect. - pub(crate) async fn logout_and_terminate(&mut self, writer: &WriterRef, reason: &str) { - let logout = Logout::with_reason(reason.to_string()); - match self.prepare_message(logout).await { - Ok(prepared) => writer.send_raw_message(prepared.raw).await, - Err(err) => warn!("failed to send logout during session termination: {err}"), - } - writer.disconnect().await; - } - - pub async fn resend_messages( - &mut self, - writer: &WriterRef, - begin: u64, - end: u64, - ) -> Result<(), SessionOperationError> { - info!(begin, end, "resending messages as requested"); - let messages = self.store.get_slice(begin as usize, end as usize).await?; - - let no = messages.len(); - debug!(number_of_messages = no, "number of messages"); - - let mut reset_start: Option = None; - let mut sequence_number = 0; - - for msg in messages { - let mut message = self - .message_builder - .build(msg.as_slice()) - .into_message() - .ok_or_else(|| { - SessionOperationError::StoredMessageParse(format!( - "failed to build message for raw message: {msg:?}" - )) - })?; - sequence_number = get_msg_seq_num(&message); - let message_type: String = message - .header() - .get::<&str>(MSG_TYPE) - .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))? - .to_string(); - - if is_admin(&message_type) { - if reset_start.is_none() { - reset_start = Some(sequence_number); - } - continue; - } - - if let Some(begin) = reset_start { - let end = sequence_number; - Self::log_skipped_admin_messages(begin, end); - self.send_sequence_reset(writer, begin, end).await?; - reset_start = None; - } - - if let Err(e) = prepare_message_for_resend(&mut message) { - error!( - error = e, - "failed to prepare message for resend, sending original" - ); - } - writer - .send_raw_message(RawFixMessage::new(message.encode(self.message_config)?)) - .await; - - if enabled!(tracing::Level::DEBUG) - && let Ok(m) = String::from_utf8(msg.clone()) - { - debug!(sequence_number, message = m, "resent message"); - } - } - - if let Some(begin) = reset_start { - // the final reset if needed - let end = sequence_number; - Self::log_skipped_admin_messages(begin, end); - self.send_sequence_reset(writer, begin, end).await?; - } - - Ok(()) - } - - pub async fn send_sequence_reset( - &mut self, - writer: &WriterRef, - begin: u64, - end: u64, - ) -> Result<(), SessionOperationError> { - let sequence_reset = SequenceReset { - gap_fill: true, - new_seq_no: end, - }; - let raw_message = generate_message( - &self.config.begin_string, - &self.config.sender_comp_id, - &self.config.target_comp_id, - begin, - sequence_reset, - )?; - - writer - .send_raw_message(RawFixMessage::new(raw_message)) - .await; - debug!(begin, end, "sent reset sequence"); - - Ok(()) - } - - fn log_skipped_admin_messages(begin: u64, end: u64) { - info!( - begin, - end, "skipped admin message(s) during resend, requesting reset for these" - ); - } - - pub async fn handle_invalid_msg_type( - &mut self, - writer: &WriterRef, - message: &Message, - msg_type: &str, - ) { - match message.header().get(MSG_SEQ_NUM) { - Ok(msg_seq_num) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::InvalidMsgtype) - .text(&format!("invalid message type {msg_type}")); - if let Err(err) = self.send_message(writer, reject).await { - error!("failed to send reject message for invalid msgtype: {err}"); - } - - #[allow(clippy::collapsible_if)] - if let Ok(seq_num) = message.header().get::(MSG_SEQ_NUM) - && self.store.next_target_seq_number() == seq_num - { - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); - } - } - } - Err(err) => { - error!("failed to get message seq num: {:?}", err); - } - } - } -} - -const TEST_REQUEST_THRESHOLD: f64 = 1.2; +pub(crate) const TEST_REQUEST_THRESHOLD: f64 = 1.2; pub(crate) type TestRequestId = String; @@ -513,53 +73,6 @@ impl SessionState { } } - pub fn try_transition_to_awaiting_logout( - &mut self, - logout_timeout: Duration, - reconnect: bool, - ) -> bool { - if matches!(self, SessionState::AwaitingLogout(_)) { - debug!("already in awaiting logout state"); - return false; - } - - if let Some(writer) = self.get_writer() { - *self = SessionState::AwaitingLogout(AwaitingLogoutState { - writer: writer.clone(), - logout_timeout: Instant::now() + logout_timeout, - reconnect, - }); - true - } else { - error!("trying to transition to awaiting logout without an established connection"); - false - } - } - - #[cfg(test)] - pub fn try_transition_to_awaiting_resend( - &mut self, - begin: u64, - end: u64, - ) -> AwaitingResendTransitionOutcome { - match self { - SessionState::AwaitingLogon(AwaitingLogonState { writer, .. }) - | SessionState::Active(ActiveState { writer, .. }) => { - let awaiting_resend = AwaitingResendState::new(writer.to_owned(), begin, end); - *self = SessionState::AwaitingResend(awaiting_resend); - AwaitingResendTransitionOutcome::Success - } - SessionState::AwaitingResend(state) => state.update(begin, end), - SessionState::AwaitingLogout(_) => AwaitingResendTransitionOutcome::InvalidState( - "trying to request a resend while we are already logging out".to_string(), - ), - SessionState::Disconnected(_) => AwaitingResendTransitionOutcome::InvalidState( - "trying to transition to awaiting resend without an established connection" - .to_string(), - ), - } - } - pub fn register_session_awaiter( &mut self, responder: oneshot::Sender, diff --git a/crates/hotfix/src/session/state/awaiting_resend.rs b/crates/hotfix/src/session/state/awaiting_resend.rs index d0bcb98c..996b1770 100644 --- a/crates/hotfix/src/session/state/awaiting_resend.rs +++ b/crates/hotfix/src/session/state/awaiting_resend.rs @@ -472,15 +472,13 @@ pub(crate) enum AwaitingResendTransitionOutcome { #[cfg(test)] mod tests { use super::*; - use crate::session::state::AwaitingLogoutState; use tokio::sync::mpsc; - use tokio::time::Instant; #[test] - fn test_awaiting_resend_transition_begin_seq_number_too_low() { + fn test_update_begin_seq_number_too_low() { let writer = create_writer_ref(); - let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); - let result = state.try_transition_to_awaiting_resend(0, 5); + let mut state = AwaitingResendState::new(writer, 1, 5); + let result = state.update(0, 5); assert!(matches!( result, AwaitingResendTransitionOutcome::BeginSeqNumberTooLow @@ -488,18 +486,18 @@ mod tests { } #[test] - fn test_awaiting_resend_transition_attempts_exceeded() { + fn test_update_attempts_exceeded() { let writer = create_writer_ref(); - let mut state = SessionState::AwaitingResend(AwaitingResendState::new(writer, 1, 5)); + let mut state = AwaitingResendState::new(writer, 1, 5); - // we can transition twice more without hitting the limit - let result = state.try_transition_to_awaiting_resend(1, 5); + // we can update twice more without hitting the limit + let result = state.update(1, 5); assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); - let result = state.try_transition_to_awaiting_resend(1, 5); + let result = state.update(1, 5); assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); - // the fourth time we'd get into an AwaitingResendState with the same begin seq number, we get an error - let result = state.try_transition_to_awaiting_resend(1, 5); + // the fourth time with the same begin seq number, we get an error + let result = state.update(1, 5); assert!(matches!( result, AwaitingResendTransitionOutcome::AttemptsExceeded @@ -507,18 +505,20 @@ mod tests { } #[test] - fn test_awaiting_resend_transition_when_awaiting_logout_is_prevented() { - let mut state = SessionState::AwaitingLogout(AwaitingLogoutState { - writer: create_writer_ref(), - logout_timeout: Instant::now(), - reconnect: false, - }); - - let result = state.try_transition_to_awaiting_resend(1, 5); - assert!(matches!( - result, - AwaitingResendTransitionOutcome::InvalidState(_) - )); + fn test_update_resets_attempts_on_new_begin_seq() { + let writer = create_writer_ref(); + let mut state = AwaitingResendState::new(writer, 1, 5); + + // Use up attempts on begin=1 + let result = state.update(1, 5); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + let result = state.update(1, 5); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + + // A new begin_seq resets the counter + let result = state.update(3, 10); + assert!(matches!(result, AwaitingResendTransitionOutcome::Success)); + assert_eq!(state.resend_attempts, 1); } fn create_writer_ref() -> WriterRef { diff --git a/crates/hotfix/src/session/state/ctx.rs b/crates/hotfix/src/session/state/ctx.rs new file mode 100644 index 00000000..2f81dfa0 --- /dev/null +++ b/crates/hotfix/src/session/state/ctx.rs @@ -0,0 +1,441 @@ +use crate::config::SessionConfig; +use crate::message::logout::Logout; +use crate::message::parser::RawFixMessage; +use crate::message::reject::Reject; +use crate::message::sequence_reset::SequenceReset; +use crate::message::verification::verify_message as verify_message_impl; +use crate::message::verification_error::{CompIdType, MessageVerificationError}; +use crate::message::{OutboundMessage, generate_message, is_admin, prepare_message_for_resend}; +use crate::session::error::{InternalSendError, SessionOperationError}; +use crate::session::get_msg_seq_num; +use crate::session::state::SessionState; +use crate::store::StoreError; +use crate::transport::writer::WriterRef; +use hotfix_message::message::{Config as MessageConfig, Message}; +use hotfix_message::session_fields::{MSG_SEQ_NUM, MSG_TYPE, SessionRejectReason}; +use hotfix_message::{MessageBuilder, Part}; +use hotfix_store::MessageStore; +use std::collections::VecDeque; +use tracing::{debug, enabled, error, info, warn}; + +pub(crate) struct SessionCtx<'a, Store> { + pub config: &'a SessionConfig, + pub store: &'a mut Store, + pub message_builder: &'a MessageBuilder, + pub message_config: &'a MessageConfig, +} + +pub(crate) struct PreparedMessage { + pub seq_num: u64, + #[allow(dead_code)] + pub msg_type: String, + pub raw: RawFixMessage, +} + +pub(crate) enum TransitionResult { + Stay, + TransitionTo(SessionState), + TransitionWithBacklog { + new_state: SessionState, + backlog: VecDeque, + }, +} + +pub(crate) enum VerifyResult { + Passed, + SeqTooHigh { expected: u64, actual: u64 }, + ErrorHandled(Option), +} + +impl SessionCtx<'_, Store> { + pub async fn prepare_message( + &mut self, + message: impl OutboundMessage, + ) -> Result { + let seq_num = self.store.next_sender_seq_number(); + let msg_type = message.message_type().to_string(); + let msg = generate_message( + &self.config.begin_string, + &self.config.sender_comp_id, + &self.config.target_comp_id, + seq_num, + message, + ) + .map_err(|e| { + InternalSendError::Persist(StoreError::PersistMessage { + sequence_number: seq_num, + source: e.into(), + }) + })?; + + self.store + .increment_sender_seq_number() + .await + .map_err(InternalSendError::SequenceNumber)?; + self.store + .add(seq_num, &msg) + .await + .map_err(InternalSendError::Persist)?; + + Ok(PreparedMessage { + seq_num, + msg_type, + raw: RawFixMessage::new(msg), + }) + } + + /// Prepare, persist, and send a message via the given writer. + pub async fn send_message( + &mut self, + writer: &WriterRef, + message: impl OutboundMessage, + ) -> Result { + let prepared = self.prepare_message(message).await?; + writer.send_raw_message(prepared.raw).await; + Ok(prepared.seq_num) + } + + pub fn verify_message( + &self, + message: &Message, + check_too_high: bool, + check_too_low: bool, + ) -> Result<(), MessageVerificationError> { + let expected_seq_number = if check_too_high || check_too_low { + Some(self.store.next_target_seq_number()) + } else { + None + }; + verify_message_impl( + message, + self.config, + expected_seq_number, + check_too_high, + check_too_low, + ) + } + + /// Verify a message and handle the error if verification fails. + /// For SeqNumberTooHigh, returns `VerifyResult::SeqTooHigh` instead of handling it, + /// allowing the caller to handle the transition. + pub async fn verify_and_handle( + &mut self, + writer: &WriterRef, + message: &Message, + check_too_high: bool, + check_too_low: bool, + ) -> Result { + match self.verify_message(message, check_too_high, check_too_low) { + Ok(()) => Ok(VerifyResult::Passed), + Err(MessageVerificationError::SeqNumberTooHigh { expected, actual }) => { + Ok(VerifyResult::SeqTooHigh { expected, actual }) + } + Err(err) => { + let transition = self.handle_verification_error(writer, err).await?; + Ok(VerifyResult::ErrorHandled(transition)) + } + } + } + + /// Handle a verification error (excluding SeqNumberTooHigh which is returned separately). + /// Returns `Some(new_state)` if a state transition is needed. + pub async fn handle_verification_error( + &mut self, + writer: &WriterRef, + error: MessageVerificationError, + ) -> Result, SessionOperationError> { + match error { + MessageVerificationError::SeqNumberTooLow { + expected, + actual, + possible_duplicate, + } => Ok(self + .handle_sequence_number_too_low(writer, expected, actual, possible_duplicate) + .await), + MessageVerificationError::SeqNumberTooHigh { expected, actual } => { + // This shouldn't be called for SeqTooHigh anymore (it's returned via VerifyResult), + // but handle gracefully if it is. + warn!( + "handle_verification_error called with SeqNumberTooHigh({expected}, {actual}) - caller should use verify_and_handle" + ); + Ok(None) + } + MessageVerificationError::IncorrectBeginString(begin_string) => Ok(Some( + self.handle_incorrect_begin_string(writer, begin_string) + .await, + )), + MessageVerificationError::IncorrectCompId { + comp_id, + comp_id_type, + msg_seq_num, + } => Ok(Some( + self.handle_incorrect_comp_id(writer, comp_id, comp_id_type, msg_seq_num) + .await, + )), + MessageVerificationError::SendingTimeAccuracyIssue { msg_seq_num } => { + self.handle_sending_time_accuracy_problem( + writer, + msg_seq_num, + "unexpected sending time", + ) + .await; + Ok(None) + } + MessageVerificationError::SendingTimeMissing { msg_seq_num } => { + self.handle_sending_time_accuracy_problem( + writer, + msg_seq_num, + "sending time missing", + ) + .await; + Ok(None) + } + MessageVerificationError::OriginalSendingTimeMissing { msg_seq_num } => { + self.handle_original_sending_time_missing(writer, msg_seq_num) + .await; + Ok(None) + } + MessageVerificationError::OriginalSendingTimeAfterSendingTime { + msg_seq_num, .. + } => { + self.handle_sending_time_accuracy_problem( + writer, + msg_seq_num, + "original sending time is after sending time", + ) + .await; + Ok(None) + } + } + } + + async fn handle_incorrect_begin_string( + &mut self, + writer: &WriterRef, + received_begin_string: String, + ) -> SessionState { + self.logout_and_terminate( + writer, + &format!("beginString={received_begin_string} is not supported"), + ) + .await; + SessionState::new_disconnected(true, "incorrect begin string") + } + + async fn handle_incorrect_comp_id( + &mut self, + writer: &WriterRef, + received_comp_id: String, + comp_id_type: CompIdType, + msg_seq_num: u64, + ) -> SessionState { + error!( + "rejecting message with incorrect comp ID: {received_comp_id} (type: {comp_id_type:?})" + ); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::ValueIsIncorrect) + .text(&format!("invalid comp ID {received_comp_id}")); + if let Err(err) = self.send_message(writer, reject).await { + error!("failed to send reject message with invalid comp ID: {err}"); + } + self.logout_and_terminate(writer, "incorrect comp ID received") + .await; + SessionState::new_disconnected(true, "incorrect comp ID") + } + + async fn handle_sequence_number_too_low( + &mut self, + writer: &WriterRef, + expected: u64, + actual: u64, + possible_duplicate: bool, + ) -> Option { + if possible_duplicate { + warn!( + "sequence number too low (expected {expected}, actual {actual}, but counterparty indicated it's poss duplicate, ignoring" + ); + return None; + } + error!( + "we expected {expected} sequence number, but target sent lower ({actual}), terminating..." + ); + let reason = format!("sequence number too low (actual {actual}, expected {expected})"); + self.logout_and_terminate(writer, &reason).await; + Some(SessionState::new_disconnected(false, &reason)) + } + + async fn handle_sending_time_accuracy_problem( + &mut self, + writer: &WriterRef, + msg_seq_num: u64, + text: &str, + ) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::SendingtimeAccuracyProblem) + .text(text); + if let Err(err) = self.send_message(writer, reject).await { + error!("failed to send reject for time accuracy problem: {err}"); + } + if let Err(err) = self.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } + } + + async fn handle_original_sending_time_missing(&mut self, writer: &WriterRef, msg_seq_num: u64) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("original sending time is required"); + if let Err(err) = self.send_message(writer, reject).await { + error!("failed to send reject for time missing tag: {err}"); + } + if let Err(err) = self.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } + } + + /// Send a logout message and immediately disconnect. + pub(crate) async fn logout_and_terminate(&mut self, writer: &WriterRef, reason: &str) { + let logout = Logout::with_reason(reason.to_string()); + match self.prepare_message(logout).await { + Ok(prepared) => writer.send_raw_message(prepared.raw).await, + Err(err) => warn!("failed to send logout during session termination: {err}"), + } + writer.disconnect().await; + } + + pub async fn resend_messages( + &mut self, + writer: &WriterRef, + begin: u64, + end: u64, + ) -> Result<(), SessionOperationError> { + info!(begin, end, "resending messages as requested"); + let messages = self.store.get_slice(begin as usize, end as usize).await?; + + let no = messages.len(); + debug!(number_of_messages = no, "number of messages"); + + let mut reset_start: Option = None; + let mut sequence_number = 0; + + for msg in messages { + let mut message = self + .message_builder + .build(msg.as_slice()) + .into_message() + .ok_or_else(|| { + SessionOperationError::StoredMessageParse(format!( + "failed to build message for raw message: {msg:?}" + )) + })?; + sequence_number = get_msg_seq_num(&message); + let message_type: String = message + .header() + .get::<&str>(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))? + .to_string(); + + if is_admin(&message_type) { + if reset_start.is_none() { + reset_start = Some(sequence_number); + } + continue; + } + + if let Some(begin) = reset_start { + let end = sequence_number; + Self::log_skipped_admin_messages(begin, end); + self.send_sequence_reset(writer, begin, end).await?; + reset_start = None; + } + + if let Err(e) = prepare_message_for_resend(&mut message) { + error!( + error = e, + "failed to prepare message for resend, sending original" + ); + } + writer + .send_raw_message(RawFixMessage::new(message.encode(self.message_config)?)) + .await; + + if enabled!(tracing::Level::DEBUG) + && let Ok(m) = String::from_utf8(msg.clone()) + { + debug!(sequence_number, message = m, "resent message"); + } + } + + if let Some(begin) = reset_start { + // the final reset if needed + let end = sequence_number; + Self::log_skipped_admin_messages(begin, end); + self.send_sequence_reset(writer, begin, end).await?; + } + + Ok(()) + } + + pub async fn send_sequence_reset( + &mut self, + writer: &WriterRef, + begin: u64, + end: u64, + ) -> Result<(), SessionOperationError> { + let sequence_reset = SequenceReset { + gap_fill: true, + new_seq_no: end, + }; + let raw_message = generate_message( + &self.config.begin_string, + &self.config.sender_comp_id, + &self.config.target_comp_id, + begin, + sequence_reset, + )?; + + writer + .send_raw_message(RawFixMessage::new(raw_message)) + .await; + debug!(begin, end, "sent reset sequence"); + + Ok(()) + } + + fn log_skipped_admin_messages(begin: u64, end: u64) { + info!( + begin, + end, "skipped admin message(s) during resend, requesting reset for these" + ); + } + + pub async fn handle_invalid_msg_type( + &mut self, + writer: &WriterRef, + message: &Message, + msg_type: &str, + ) { + match message.header().get(MSG_SEQ_NUM) { + Ok(msg_seq_num) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::InvalidMsgtype) + .text(&format!("invalid message type {msg_type}")); + if let Err(err) = self.send_message(writer, reject).await { + error!("failed to send reject message for invalid msgtype: {err}"); + } + + #[allow(clippy::collapsible_if)] + if let Ok(seq_num) = message.header().get::(MSG_SEQ_NUM) + && self.store.next_target_seq_number() == seq_num + { + if let Err(err) = self.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } + } + } + Err(err) => { + error!("failed to get message seq num: {:?}", err); + } + } + } +} From 3fc4d4bdaca293a690066569c46c9a84d0d093c7 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 13:00:12 +0100 Subject: [PATCH 10/14] Final touches to make session thin --- crates/hotfix/src/session.rs | 104 +++++++----------- crates/hotfix/src/session/state.rs | 12 ++ .../src/session/state/awaiting_resend.rs | 6 - crates/hotfix/src/session/state/ctx.rs | 41 ++++++- 4 files changed, 92 insertions(+), 71 deletions(-) diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 1aa7e24e..c049b6ca 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -21,7 +21,6 @@ use crate::config::SessionConfig; use crate::message::logon::{Logon, ResetSeqNumConfig}; use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; -use crate::message::reject::Reject; use crate::message::resend_request::ResendRequest; use crate::session::admin_request::AdminRequest; use crate::session::error::SessionCreationError; @@ -41,7 +40,7 @@ use crate::store::MessageStore; use crate::transport::writer::WriterRef; use event::SessionEvent; use hotfix_message::parsed_message::{InvalidReason, ParsedMessage}; -use hotfix_message::session_fields::{MSG_SEQ_NUM, SessionRejectReason}; +use hotfix_message::session_fields::MSG_SEQ_NUM; const SCHEDULE_CHECK_INTERVAL: u64 = 1; @@ -140,56 +139,12 @@ where message: Message, reason: InvalidReason, ) -> Result<(), SessionOperationError> { - let Session { - ref state, - ref mut store, - ref config, - ref message_builder, - ref message_config, - .. - } = *self; - let mut ctx = SessionCtx { - config, - store, - message_builder, - message_config, - }; - let Some(writer) = state.get_writer() else { + let (mut ctx, writer) = self.make_ctx(); + let Some(writer) = writer else { return Ok(()); }; - - match reason { - InvalidReason::InvalidField(tag) | InvalidReason::InvalidGroup(tag) => { - if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::InvalidTagNumber) - .text(&format!("invalid field {tag}")); - ctx.send_message(writer, reject) - .await - .with_send_context("reject for invalid field")?; - } - } - InvalidReason::InvalidComponent(_component_name) => { - warn!("received invalid component"); - } - InvalidReason::InvalidMsgType(msg_type) => { - ctx.handle_invalid_msg_type(writer, &message, &msg_type) - .await; - } - InvalidReason::InvalidOrderInGroup { tag, .. } => { - if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason( - SessionRejectReason::RepeatingGroupFieldsOutOfOrder, - ) - .text(&format!("field appears in incorrect order:{tag}")); - ctx.send_message(writer, reject) - .await - .with_send_context("reject for invalid group order")?; - } - } - } - Ok(()) + ctx.handle_invalid_parsed_message(writer, &message, reason) + .await } async fn dispatch_valid_message( @@ -352,16 +307,6 @@ where Ok(()) } - /// Sends a logout message and immediately disconnects the counterparty. - async fn logout_and_terminate(&mut self, reason: &str) { - if let Err(err) = self.send_logout(reason).await { - warn!("failed to send logout during session termination: {}", err); - } - if let Some(writer) = self.state.get_writer() { - writer.disconnect().await; - } - } - /// Sends a logout message and puts the session state into an AwaitingLogout state. async fn initiate_graceful_logout( &mut self, @@ -396,7 +341,15 @@ where if let Err(err) = self.on_incoming(fix_message).await { let reason = err.to_string(); error!(reason, "fatal error in message processing"); - self.logout_and_terminate("internal error").await; + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; + self.state + .logout_and_terminate(&mut ctx, "internal error") + .await; self.state = SessionState::new_disconnected(true, &reason); } } @@ -548,7 +501,15 @@ where // we are in the same period, nothing needs to be done } Ok(SessionPeriodComparison::DifferentPeriod) => { - self.logout_and_terminate("session period changed").await; + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; + self.state + .logout_and_terminate(&mut ctx, "session period changed") + .await; if let Err(err) = self.store.reset().await { error!("error resetting session store: {err:}"); self.state = @@ -557,7 +518,14 @@ where } Ok(SessionPeriodComparison::OutsideSessionTime { .. }) => { warn!("store creation time is outside session schedule, resetting store"); - self.logout_and_terminate("creation time outside schedule") + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; + self.state + .logout_and_terminate(&mut ctx, "creation time outside schedule") .await; if let Err(err) = self.store.reset().await { error!("error resetting session store: {err:}"); @@ -567,7 +535,15 @@ where } Err(err) => { error!("error checking session period: {err:?}"); - self.logout_and_terminate("internal error").await; + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; + self.state + .logout_and_terminate(&mut ctx, "internal error") + .await; } } } else if self.state.get_writer().is_some() diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index 03414df6..e5bc14c3 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -15,6 +15,7 @@ pub(crate) use disconnected::DisconnectedState; use crate::session::event::AwaitingActiveSessionResponse; use crate::session::info::Status as SessionInfoStatus; use crate::transport::writer::WriterRef; +use hotfix_store::MessageStore; use std::time::Duration; use tokio::sync::oneshot; use tokio::time::Instant; @@ -104,6 +105,17 @@ impl SessionState { } } + /// Send a logout message and immediately disconnect, if connected. + pub(crate) async fn logout_and_terminate( + &self, + ctx: &mut SessionCtx<'_, Store>, + reason: &str, + ) { + if let Some(writer) = self.get_writer() { + ctx.logout_and_terminate(writer, reason).await; + } + } + pub fn heartbeat_deadline(&self) -> Option<&Instant> { match self { Self::Active(state) => Some(state.heartbeat_deadline()), diff --git a/crates/hotfix/src/session/state/awaiting_resend.rs b/crates/hotfix/src/session/state/awaiting_resend.rs index 996b1770..d77083fd 100644 --- a/crates/hotfix/src/session/state/awaiting_resend.rs +++ b/crates/hotfix/src/session/state/awaiting_resend.rs @@ -435,10 +435,6 @@ impl AwaitingResendState { .with_send_context("resend request")?; Ok(TransitionResult::Stay) } - AwaitingResendTransitionOutcome::InvalidState(reason) => { - error!("failed to request resend: {reason}"); - Ok(TransitionResult::Stay) - } AwaitingResendTransitionOutcome::BeginSeqNumberTooLow => { self.writer.disconnect().await; Ok(TransitionResult::TransitionTo( @@ -461,10 +457,8 @@ impl AwaitingResendState { } } -#[allow(dead_code)] // InvalidState is used only by AwaitingResendState::handle_seq_too_high and tests pub(crate) enum AwaitingResendTransitionOutcome { Success, - InvalidState(String), BeginSeqNumberTooLow, AttemptsExceeded, } diff --git a/crates/hotfix/src/session/state/ctx.rs b/crates/hotfix/src/session/state/ctx.rs index 2f81dfa0..55b16277 100644 --- a/crates/hotfix/src/session/state/ctx.rs +++ b/crates/hotfix/src/session/state/ctx.rs @@ -6,12 +6,13 @@ use crate::message::sequence_reset::SequenceReset; use crate::message::verification::verify_message as verify_message_impl; use crate::message::verification_error::{CompIdType, MessageVerificationError}; use crate::message::{OutboundMessage, generate_message, is_admin, prepare_message_for_resend}; -use crate::session::error::{InternalSendError, SessionOperationError}; +use crate::session::error::{InternalSendError, InternalSendResultExt, SessionOperationError}; use crate::session::get_msg_seq_num; use crate::session::state::SessionState; use crate::store::StoreError; use crate::transport::writer::WriterRef; use hotfix_message::message::{Config as MessageConfig, Message}; +use hotfix_message::parsed_message::InvalidReason; use hotfix_message::session_fields::{MSG_SEQ_NUM, MSG_TYPE, SessionRejectReason}; use hotfix_message::{MessageBuilder, Part}; use hotfix_store::MessageStore; @@ -409,6 +410,44 @@ impl SessionCtx<'_, Store> { ); } + pub async fn handle_invalid_parsed_message( + &mut self, + writer: &WriterRef, + message: &Message, + reason: InvalidReason, + ) -> Result<(), SessionOperationError> { + match reason { + InvalidReason::InvalidField(tag) | InvalidReason::InvalidGroup(tag) => { + if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::InvalidTagNumber) + .text(&format!("invalid field {tag}")); + self.send_message(writer, reject) + .await + .with_send_context("reject for invalid field")?; + } + } + InvalidReason::InvalidComponent(_component_name) => { + warn!("received invalid component"); + } + InvalidReason::InvalidMsgType(msg_type) => { + self.handle_invalid_msg_type(writer, message, &msg_type) + .await; + } + InvalidReason::InvalidOrderInGroup { tag, .. } => { + if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RepeatingGroupFieldsOutOfOrder) + .text(&format!("field appears in incorrect order:{tag}")); + self.send_message(writer, reject) + .await + .with_send_context("reject for invalid group order")?; + } + } + } + Ok(()) + } + pub async fn handle_invalid_msg_type( &mut self, writer: &WriterRef, From 891692068d99ea8d0dab9d1d5a784ee5180fbc1b Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 13:14:21 +0100 Subject: [PATCH 11/14] Move context out of state submodule --- crates/hotfix/src/session.rs | 1 + crates/hotfix/src/session/{state => }/ctx.rs | 0 crates/hotfix/src/session/state.rs | 3 +-- 3 files changed, 2 insertions(+), 2 deletions(-) rename crates/hotfix/src/session/{state => }/ctx.rs (100%) diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index c049b6ca..97d48a9f 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -1,4 +1,5 @@ pub(crate) mod admin_request; +mod ctx; pub mod error; pub(crate) mod event; mod info; diff --git a/crates/hotfix/src/session/state/ctx.rs b/crates/hotfix/src/session/ctx.rs similarity index 100% rename from crates/hotfix/src/session/state/ctx.rs rename to crates/hotfix/src/session/ctx.rs diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index e5bc14c3..2e95efac 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -2,14 +2,13 @@ mod active; mod awaiting_logon; mod awaiting_logout; mod awaiting_resend; -mod ctx; mod disconnected; +pub(crate) use crate::session::ctx::{SessionCtx, TransitionResult, VerifyResult}; pub(crate) use active::{ActiveState, calculate_peer_interval}; pub(crate) use awaiting_logon::AwaitingLogonState; pub(crate) use awaiting_logout::AwaitingLogoutState; pub(crate) use awaiting_resend::AwaitingResendState; -pub(crate) use ctx::{SessionCtx, TransitionResult, VerifyResult}; pub(crate) use disconnected::DisconnectedState; use crate::session::event::AwaitingActiveSessionResponse; From 7ee0e87d2a67f4d823b89ca148dc81e2a1c2743e Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 13:32:41 +0100 Subject: [PATCH 12/14] Remove make_ctx as its use is confusing --- crates/hotfix/src/session.rs | 45 +++++++++++-------- crates/hotfix/src/session/state.rs | 4 ++ .../src/session/state/awaiting_logout.rs | 10 ++++- 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 97d48a9f..ac30670b 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -140,10 +140,16 @@ where message: Message, reason: InvalidReason, ) -> Result<(), SessionOperationError> { - let (mut ctx, writer) = self.make_ctx(); + let writer = self.state.get_writer(); let Some(writer) = writer else { return Ok(()); }; + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; ctx.handle_invalid_parsed_message(writer, &message, reason) .await } @@ -256,7 +262,7 @@ where SessionState::Active(s) => Some(s.on_disconnect(&reason).await), SessionState::AwaitingLogon(s) => Some(s.on_disconnect(&reason).await), SessionState::AwaitingResend(s) => Some(s.on_disconnect(&reason).await), - SessionState::AwaitingLogout(s) => Some(s.on_disconnect(&reason)), + SessionState::AwaitingLogout(s) => Some(s.on_disconnect(&reason).await), SessionState::Disconnected(_) => { warn!("disconnect message was received, but the session is already disconnected"); None @@ -267,17 +273,6 @@ where } } - fn make_ctx(&mut self) -> (SessionCtx<'_, Store>, Option<&WriterRef>) { - let writer = self.state.get_writer(); - let ctx = SessionCtx { - config: &self.config, - store: &mut self.store, - message_builder: &self.message_builder, - message_config: &self.message_config, - }; - (ctx, writer) - } - async fn send_logon(&mut self) -> Result<(), SessionOperationError> { let reset_config = if self.config.reset_on_logon || self.reset_on_next_logon { self.store.reset().await?; @@ -288,7 +283,13 @@ where self.reset_on_next_logon = false; let logon = Logon::new(self.config.heartbeat_interval, reset_config); - let (mut ctx, writer) = self.make_ctx(); + let writer = self.state.get_writer(); + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; if let Some(writer) = writer { ctx.send_message(writer, logon) .await @@ -299,7 +300,13 @@ where async fn send_logout(&mut self, reason: &str) -> Result<(), SessionOperationError> { let logout = Logout::with_reason(reason.to_string()); - let (mut ctx, writer) = self.make_ctx(); + let writer = self.state.get_writer(); + let mut ctx = SessionCtx { + config: &self.config, + store: &mut self.store, + message_builder: &self.message_builder, + message_config: &self.message_config, + }; if let Some(writer) = writer { ctx.send_message(writer, logout) .await @@ -324,11 +331,11 @@ where return Ok(()); }; - self.state = SessionState::AwaitingLogout(state::AwaitingLogoutState { + self.state = SessionState::AwaitingLogout(state::AwaitingLogoutState::new( writer, - logout_timeout: Instant::now() + Duration::from_secs(self.config.logout_timeout), + Instant::now() + Duration::from_secs(self.config.logout_timeout), reconnect, - }); + )); self.send_logout(reason).await?; Ok(()) @@ -547,7 +554,7 @@ where .await; } } - } else if self.state.get_writer().is_some() + } else if self.state.is_connected() && let Err(err) = self .initiate_graceful_logout("End of session time", true) .await diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index 2e95efac..f4a21eac 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -63,6 +63,10 @@ impl SessionState { } } + pub(crate) fn is_connected(&self) -> bool { + self.get_writer().is_some() + } + pub(crate) fn get_writer(&self) -> Option<&WriterRef> { match self { Self::Active(ActiveState { writer, .. }) diff --git a/crates/hotfix/src/session/state/awaiting_logout.rs b/crates/hotfix/src/session/state/awaiting_logout.rs index 2466657e..110ecd8d 100644 --- a/crates/hotfix/src/session/state/awaiting_logout.rs +++ b/crates/hotfix/src/session/state/awaiting_logout.rs @@ -16,7 +16,15 @@ pub(crate) struct AwaitingLogoutState { } impl AwaitingLogoutState { - pub(crate) fn on_disconnect(&self, reason: &str) -> SessionState { + pub(crate) fn new(writer: WriterRef, logout_timeout: Instant, reconnect: bool) -> Self { + Self { + writer, + logout_timeout, + reconnect, + } + } + + pub(crate) async fn on_disconnect(&self, reason: &str) -> SessionState { SessionState::new_disconnected(self.reconnect, reason) } From e1671fee33080cec617d1dac75d933941cebb96e Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 13:51:14 +0100 Subject: [PATCH 13/14] Try to clean up some boilerplate in incoming message handling --- crates/hotfix/src/initiator.rs | 2 - crates/hotfix/src/session/ctx.rs | 45 ++++++++++--------- crates/hotfix/src/session/state/active.rs | 35 +++------------ .../src/session/state/awaiting_logon.rs | 5 +-- .../src/session/state/awaiting_logout.rs | 5 +-- .../src/session/state/awaiting_resend.rs | 35 +++------------ 6 files changed, 41 insertions(+), 86 deletions(-) diff --git a/crates/hotfix/src/initiator.rs b/crates/hotfix/src/initiator.rs index f06d1adf..3654f780 100644 --- a/crates/hotfix/src/initiator.rs +++ b/crates/hotfix/src/initiator.rs @@ -387,8 +387,6 @@ mod tests { #[tokio::test] async fn test_send_delegates_to_session_handle() { - use crate::session::error::SendOutcome; - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let port = listener.local_addr().unwrap().port(); let config = create_test_config("127.0.0.1", port); diff --git a/crates/hotfix/src/session/ctx.rs b/crates/hotfix/src/session/ctx.rs index 55b16277..6722676d 100644 --- a/crates/hotfix/src/session/ctx.rs +++ b/crates/hotfix/src/session/ctx.rs @@ -45,7 +45,7 @@ pub(crate) enum TransitionResult { pub(crate) enum VerifyResult { Passed, SeqTooHigh { expected: u64, actual: u64 }, - ErrorHandled(Option), + Handled(TransitionResult), } impl SessionCtx<'_, Store> { @@ -133,18 +133,19 @@ impl SessionCtx<'_, Store> { } Err(err) => { let transition = self.handle_verification_error(writer, err).await?; - Ok(VerifyResult::ErrorHandled(transition)) + Ok(VerifyResult::Handled(transition)) } } } /// Handle a verification error (excluding SeqNumberTooHigh which is returned separately). - /// Returns `Some(new_state)` if a state transition is needed. + /// Returns the `TransitionResult` to use — either `Stay` (error was handled in-place) + /// or `TransitionTo` (a state change is needed). pub async fn handle_verification_error( &mut self, writer: &WriterRef, error: MessageVerificationError, - ) -> Result, SessionOperationError> { + ) -> Result { match error { MessageVerificationError::SeqNumberTooLow { expected, @@ -159,20 +160,24 @@ impl SessionCtx<'_, Store> { warn!( "handle_verification_error called with SeqNumberTooHigh({expected}, {actual}) - caller should use verify_and_handle" ); - Ok(None) + Ok(TransitionResult::Stay) + } + MessageVerificationError::IncorrectBeginString(begin_string) => { + let new_state = self + .handle_incorrect_begin_string(writer, begin_string) + .await; + Ok(TransitionResult::TransitionTo(new_state)) } - MessageVerificationError::IncorrectBeginString(begin_string) => Ok(Some( - self.handle_incorrect_begin_string(writer, begin_string) - .await, - )), MessageVerificationError::IncorrectCompId { comp_id, comp_id_type, msg_seq_num, - } => Ok(Some( - self.handle_incorrect_comp_id(writer, comp_id, comp_id_type, msg_seq_num) - .await, - )), + } => { + let new_state = self + .handle_incorrect_comp_id(writer, comp_id, comp_id_type, msg_seq_num) + .await; + Ok(TransitionResult::TransitionTo(new_state)) + } MessageVerificationError::SendingTimeAccuracyIssue { msg_seq_num } => { self.handle_sending_time_accuracy_problem( writer, @@ -180,7 +185,7 @@ impl SessionCtx<'_, Store> { "unexpected sending time", ) .await; - Ok(None) + Ok(TransitionResult::Stay) } MessageVerificationError::SendingTimeMissing { msg_seq_num } => { self.handle_sending_time_accuracy_problem( @@ -189,12 +194,12 @@ impl SessionCtx<'_, Store> { "sending time missing", ) .await; - Ok(None) + Ok(TransitionResult::Stay) } MessageVerificationError::OriginalSendingTimeMissing { msg_seq_num } => { self.handle_original_sending_time_missing(writer, msg_seq_num) .await; - Ok(None) + Ok(TransitionResult::Stay) } MessageVerificationError::OriginalSendingTimeAfterSendingTime { msg_seq_num, .. @@ -205,7 +210,7 @@ impl SessionCtx<'_, Store> { "original sending time is after sending time", ) .await; - Ok(None) + Ok(TransitionResult::Stay) } } } @@ -250,19 +255,19 @@ impl SessionCtx<'_, Store> { expected: u64, actual: u64, possible_duplicate: bool, - ) -> Option { + ) -> TransitionResult { if possible_duplicate { warn!( "sequence number too low (expected {expected}, actual {actual}, but counterparty indicated it's poss duplicate, ignoring" ); - return None; + return TransitionResult::Stay; } error!( "we expected {expected} sequence number, but target sent lower ({actual}), terminating..." ); let reason = format!("sequence number too low (actual {actual}, expected {expected})"); self.logout_and_terminate(writer, &reason).await; - Some(SessionState::new_disconnected(false, &reason)) + TransitionResult::TransitionTo(SessionState::new_disconnected(false, &reason)) } async fn handle_sending_time_accuracy_problem( diff --git a/crates/hotfix/src/session/state/active.rs b/crates/hotfix/src/session/state/active.rs index 6d87f2ae..0fb7a0c3 100644 --- a/crates/hotfix/src/session/state/active.rs +++ b/crates/hotfix/src/session/state/active.rs @@ -152,10 +152,7 @@ impl ActiveState { .transition_to_awaiting_resend(ctx, expected, actual) .await; } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } if let (Some(expected_req_id), Ok(message_req_id)) = ( @@ -186,10 +183,7 @@ impl ActiveState { .transition_to_awaiting_resend(ctx, expected, actual) .await; } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } let req_id: &str = message.get(TEST_REQ_ID).unwrap_or_else(|_| { @@ -221,10 +215,7 @@ impl ActiveState { // ResendRequest with check_too_high=false should never get SeqTooHigh, // but handle gracefully } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } let msg_seq_num = get_msg_seq_num(message); @@ -292,10 +283,7 @@ impl ActiveState { .transition_to_awaiting_resend(ctx, expected, actual) .await; } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } ctx.store.increment_target_seq_number().await?; @@ -319,10 +307,7 @@ impl ActiveState { .transition_to_awaiting_resend(ctx, expected, actual) .await; } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } let end: u64 = match message.get(NEW_SEQ_NO) { @@ -377,10 +362,7 @@ impl ActiveState { VerifyResult::SeqTooHigh { .. } => { // verify with check_too_high=false, so this shouldn't happen } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } // We are logged on, send logout response @@ -419,10 +401,7 @@ impl ActiveState { .transition_to_awaiting_resend(ctx, expected, actual) .await; } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } match app.on_inbound_message(message).await { diff --git a/crates/hotfix/src/session/state/awaiting_logon.rs b/crates/hotfix/src/session/state/awaiting_logon.rs index 897fcb27..4c29d1f5 100644 --- a/crates/hotfix/src/session/state/awaiting_logon.rs +++ b/crates/hotfix/src/session/state/awaiting_logon.rs @@ -75,10 +75,7 @@ impl AwaitingLogonState { )); Ok(TransitionResult::TransitionTo(new_state)) } - VerifyResult::ErrorHandled(Some(new_state)) => { - Ok(TransitionResult::TransitionTo(new_state)) - } - VerifyResult::ErrorHandled(None) => Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => Ok(transition), } } } diff --git a/crates/hotfix/src/session/state/awaiting_logout.rs b/crates/hotfix/src/session/state/awaiting_logout.rs index 110ecd8d..5df5ee6f 100644 --- a/crates/hotfix/src/session/state/awaiting_logout.rs +++ b/crates/hotfix/src/session/state/awaiting_logout.rs @@ -55,10 +55,7 @@ impl AwaitingLogoutState { VerifyResult::SeqTooHigh { .. } => { // verify with check_too_high=false, shouldn't happen } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } app.on_logout("peer has logged us out").await; diff --git a/crates/hotfix/src/session/state/awaiting_resend.rs b/crates/hotfix/src/session/state/awaiting_resend.rs index d77083fd..528cd9a9 100644 --- a/crates/hotfix/src/session/state/awaiting_resend.rs +++ b/crates/hotfix/src/session/state/awaiting_resend.rs @@ -146,10 +146,7 @@ impl AwaitingResendState { VerifyResult::SeqTooHigh { expected, actual } => { return self.handle_seq_too_high(ctx, expected, actual).await; } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } ctx.store.increment_target_seq_number().await?; @@ -169,10 +166,7 @@ impl AwaitingResendState { VerifyResult::SeqTooHigh { expected, actual } => { return self.handle_seq_too_high(ctx, expected, actual).await; } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } let req_id: &str = message.get(TEST_REQ_ID).unwrap_or_else(|_| todo!()); @@ -199,10 +193,7 @@ impl AwaitingResendState { VerifyResult::SeqTooHigh { .. } => { // check_too_high=false, shouldn't happen } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } let msg_seq_num = get_msg_seq_num(message); @@ -270,10 +261,7 @@ impl AwaitingResendState { VerifyResult::SeqTooHigh { expected, actual } => { return self.handle_seq_too_high(ctx, expected, actual).await; } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } ctx.store.increment_target_seq_number().await?; @@ -295,10 +283,7 @@ impl AwaitingResendState { VerifyResult::SeqTooHigh { expected, actual } => { return self.handle_seq_too_high(ctx, expected, actual).await; } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } let end: u64 = match message.get(NEW_SEQ_NO) { @@ -349,10 +334,7 @@ impl AwaitingResendState { { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { .. } => {} - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } // We are in AwaitingResend (logged on), send logout response @@ -386,10 +368,7 @@ impl AwaitingResendState { VerifyResult::SeqTooHigh { expected, actual } => { return self.handle_seq_too_high(ctx, expected, actual).await; } - VerifyResult::ErrorHandled(Some(new_state)) => { - return Ok(TransitionResult::TransitionTo(new_state)); - } - VerifyResult::ErrorHandled(None) => return Ok(TransitionResult::Stay), + VerifyResult::Handled(transition) => return Ok(transition), } match app.on_inbound_message(message).await { From a843ec9f75c576535d6513cdb36c30685d875cb9 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Tue, 17 Mar 2026 14:18:15 +0100 Subject: [PATCH 14/14] Break out free functions for message sending --- crates/hotfix/src/session.rs | 4 +- crates/hotfix/src/session/ctx.rs | 402 +----------------- crates/hotfix/src/session/message_handling.rs | 402 ++++++++++++++++++ crates/hotfix/src/session/state.rs | 2 +- crates/hotfix/src/session/state/active.rs | 44 +- .../src/session/state/awaiting_logon.rs | 6 +- .../src/session/state/awaiting_logout.rs | 4 +- .../src/session/state/awaiting_resend.rs | 44 +- 8 files changed, 444 insertions(+), 464 deletions(-) create mode 100644 crates/hotfix/src/session/message_handling.rs diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index ac30670b..3d1f5c57 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -3,6 +3,7 @@ mod ctx; pub mod error; pub(crate) mod event; mod info; +mod message_handling; mod session_handle; pub mod session_ref; mod state; @@ -150,8 +151,7 @@ where message_builder: &self.message_builder, message_config: &self.message_config, }; - ctx.handle_invalid_parsed_message(writer, &message, reason) - .await + message_handling::handle_invalid_parsed_message(&mut ctx, writer, &message, reason).await } async fn dispatch_valid_message( diff --git a/crates/hotfix/src/session/ctx.rs b/crates/hotfix/src/session/ctx.rs index 6722676d..e38c467a 100644 --- a/crates/hotfix/src/session/ctx.rs +++ b/crates/hotfix/src/session/ctx.rs @@ -1,23 +1,14 @@ use crate::config::SessionConfig; -use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; -use crate::message::reject::Reject; -use crate::message::sequence_reset::SequenceReset; -use crate::message::verification::verify_message as verify_message_impl; -use crate::message::verification_error::{CompIdType, MessageVerificationError}; -use crate::message::{OutboundMessage, generate_message, is_admin, prepare_message_for_resend}; -use crate::session::error::{InternalSendError, InternalSendResultExt, SessionOperationError}; -use crate::session::get_msg_seq_num; +use crate::message::{OutboundMessage, generate_message}; +use crate::session::error::InternalSendError; use crate::session::state::SessionState; use crate::store::StoreError; use crate::transport::writer::WriterRef; +use hotfix_message::MessageBuilder; use hotfix_message::message::{Config as MessageConfig, Message}; -use hotfix_message::parsed_message::InvalidReason; -use hotfix_message::session_fields::{MSG_SEQ_NUM, MSG_TYPE, SessionRejectReason}; -use hotfix_message::{MessageBuilder, Part}; use hotfix_store::MessageStore; use std::collections::VecDeque; -use tracing::{debug, enabled, error, info, warn}; pub(crate) struct SessionCtx<'a, Store> { pub config: &'a SessionConfig, @@ -95,391 +86,4 @@ impl SessionCtx<'_, Store> { writer.send_raw_message(prepared.raw).await; Ok(prepared.seq_num) } - - pub fn verify_message( - &self, - message: &Message, - check_too_high: bool, - check_too_low: bool, - ) -> Result<(), MessageVerificationError> { - let expected_seq_number = if check_too_high || check_too_low { - Some(self.store.next_target_seq_number()) - } else { - None - }; - verify_message_impl( - message, - self.config, - expected_seq_number, - check_too_high, - check_too_low, - ) - } - - /// Verify a message and handle the error if verification fails. - /// For SeqNumberTooHigh, returns `VerifyResult::SeqTooHigh` instead of handling it, - /// allowing the caller to handle the transition. - pub async fn verify_and_handle( - &mut self, - writer: &WriterRef, - message: &Message, - check_too_high: bool, - check_too_low: bool, - ) -> Result { - match self.verify_message(message, check_too_high, check_too_low) { - Ok(()) => Ok(VerifyResult::Passed), - Err(MessageVerificationError::SeqNumberTooHigh { expected, actual }) => { - Ok(VerifyResult::SeqTooHigh { expected, actual }) - } - Err(err) => { - let transition = self.handle_verification_error(writer, err).await?; - Ok(VerifyResult::Handled(transition)) - } - } - } - - /// Handle a verification error (excluding SeqNumberTooHigh which is returned separately). - /// Returns the `TransitionResult` to use — either `Stay` (error was handled in-place) - /// or `TransitionTo` (a state change is needed). - pub async fn handle_verification_error( - &mut self, - writer: &WriterRef, - error: MessageVerificationError, - ) -> Result { - match error { - MessageVerificationError::SeqNumberTooLow { - expected, - actual, - possible_duplicate, - } => Ok(self - .handle_sequence_number_too_low(writer, expected, actual, possible_duplicate) - .await), - MessageVerificationError::SeqNumberTooHigh { expected, actual } => { - // This shouldn't be called for SeqTooHigh anymore (it's returned via VerifyResult), - // but handle gracefully if it is. - warn!( - "handle_verification_error called with SeqNumberTooHigh({expected}, {actual}) - caller should use verify_and_handle" - ); - Ok(TransitionResult::Stay) - } - MessageVerificationError::IncorrectBeginString(begin_string) => { - let new_state = self - .handle_incorrect_begin_string(writer, begin_string) - .await; - Ok(TransitionResult::TransitionTo(new_state)) - } - MessageVerificationError::IncorrectCompId { - comp_id, - comp_id_type, - msg_seq_num, - } => { - let new_state = self - .handle_incorrect_comp_id(writer, comp_id, comp_id_type, msg_seq_num) - .await; - Ok(TransitionResult::TransitionTo(new_state)) - } - MessageVerificationError::SendingTimeAccuracyIssue { msg_seq_num } => { - self.handle_sending_time_accuracy_problem( - writer, - msg_seq_num, - "unexpected sending time", - ) - .await; - Ok(TransitionResult::Stay) - } - MessageVerificationError::SendingTimeMissing { msg_seq_num } => { - self.handle_sending_time_accuracy_problem( - writer, - msg_seq_num, - "sending time missing", - ) - .await; - Ok(TransitionResult::Stay) - } - MessageVerificationError::OriginalSendingTimeMissing { msg_seq_num } => { - self.handle_original_sending_time_missing(writer, msg_seq_num) - .await; - Ok(TransitionResult::Stay) - } - MessageVerificationError::OriginalSendingTimeAfterSendingTime { - msg_seq_num, .. - } => { - self.handle_sending_time_accuracy_problem( - writer, - msg_seq_num, - "original sending time is after sending time", - ) - .await; - Ok(TransitionResult::Stay) - } - } - } - - async fn handle_incorrect_begin_string( - &mut self, - writer: &WriterRef, - received_begin_string: String, - ) -> SessionState { - self.logout_and_terminate( - writer, - &format!("beginString={received_begin_string} is not supported"), - ) - .await; - SessionState::new_disconnected(true, "incorrect begin string") - } - - async fn handle_incorrect_comp_id( - &mut self, - writer: &WriterRef, - received_comp_id: String, - comp_id_type: CompIdType, - msg_seq_num: u64, - ) -> SessionState { - error!( - "rejecting message with incorrect comp ID: {received_comp_id} (type: {comp_id_type:?})" - ); - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::ValueIsIncorrect) - .text(&format!("invalid comp ID {received_comp_id}")); - if let Err(err) = self.send_message(writer, reject).await { - error!("failed to send reject message with invalid comp ID: {err}"); - } - self.logout_and_terminate(writer, "incorrect comp ID received") - .await; - SessionState::new_disconnected(true, "incorrect comp ID") - } - - async fn handle_sequence_number_too_low( - &mut self, - writer: &WriterRef, - expected: u64, - actual: u64, - possible_duplicate: bool, - ) -> TransitionResult { - if possible_duplicate { - warn!( - "sequence number too low (expected {expected}, actual {actual}, but counterparty indicated it's poss duplicate, ignoring" - ); - return TransitionResult::Stay; - } - error!( - "we expected {expected} sequence number, but target sent lower ({actual}), terminating..." - ); - let reason = format!("sequence number too low (actual {actual}, expected {expected})"); - self.logout_and_terminate(writer, &reason).await; - TransitionResult::TransitionTo(SessionState::new_disconnected(false, &reason)) - } - - async fn handle_sending_time_accuracy_problem( - &mut self, - writer: &WriterRef, - msg_seq_num: u64, - text: &str, - ) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::SendingtimeAccuracyProblem) - .text(text); - if let Err(err) = self.send_message(writer, reject).await { - error!("failed to send reject for time accuracy problem: {err}"); - } - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); - } - } - - async fn handle_original_sending_time_missing(&mut self, writer: &WriterRef, msg_seq_num: u64) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::RequiredTagMissing) - .text("original sending time is required"); - if let Err(err) = self.send_message(writer, reject).await { - error!("failed to send reject for time missing tag: {err}"); - } - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); - } - } - - /// Send a logout message and immediately disconnect. - pub(crate) async fn logout_and_terminate(&mut self, writer: &WriterRef, reason: &str) { - let logout = Logout::with_reason(reason.to_string()); - match self.prepare_message(logout).await { - Ok(prepared) => writer.send_raw_message(prepared.raw).await, - Err(err) => warn!("failed to send logout during session termination: {err}"), - } - writer.disconnect().await; - } - - pub async fn resend_messages( - &mut self, - writer: &WriterRef, - begin: u64, - end: u64, - ) -> Result<(), SessionOperationError> { - info!(begin, end, "resending messages as requested"); - let messages = self.store.get_slice(begin as usize, end as usize).await?; - - let no = messages.len(); - debug!(number_of_messages = no, "number of messages"); - - let mut reset_start: Option = None; - let mut sequence_number = 0; - - for msg in messages { - let mut message = self - .message_builder - .build(msg.as_slice()) - .into_message() - .ok_or_else(|| { - SessionOperationError::StoredMessageParse(format!( - "failed to build message for raw message: {msg:?}" - )) - })?; - sequence_number = get_msg_seq_num(&message); - let message_type: String = message - .header() - .get::<&str>(MSG_TYPE) - .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))? - .to_string(); - - if is_admin(&message_type) { - if reset_start.is_none() { - reset_start = Some(sequence_number); - } - continue; - } - - if let Some(begin) = reset_start { - let end = sequence_number; - Self::log_skipped_admin_messages(begin, end); - self.send_sequence_reset(writer, begin, end).await?; - reset_start = None; - } - - if let Err(e) = prepare_message_for_resend(&mut message) { - error!( - error = e, - "failed to prepare message for resend, sending original" - ); - } - writer - .send_raw_message(RawFixMessage::new(message.encode(self.message_config)?)) - .await; - - if enabled!(tracing::Level::DEBUG) - && let Ok(m) = String::from_utf8(msg.clone()) - { - debug!(sequence_number, message = m, "resent message"); - } - } - - if let Some(begin) = reset_start { - // the final reset if needed - let end = sequence_number; - Self::log_skipped_admin_messages(begin, end); - self.send_sequence_reset(writer, begin, end).await?; - } - - Ok(()) - } - - pub async fn send_sequence_reset( - &mut self, - writer: &WriterRef, - begin: u64, - end: u64, - ) -> Result<(), SessionOperationError> { - let sequence_reset = SequenceReset { - gap_fill: true, - new_seq_no: end, - }; - let raw_message = generate_message( - &self.config.begin_string, - &self.config.sender_comp_id, - &self.config.target_comp_id, - begin, - sequence_reset, - )?; - - writer - .send_raw_message(RawFixMessage::new(raw_message)) - .await; - debug!(begin, end, "sent reset sequence"); - - Ok(()) - } - - fn log_skipped_admin_messages(begin: u64, end: u64) { - info!( - begin, - end, "skipped admin message(s) during resend, requesting reset for these" - ); - } - - pub async fn handle_invalid_parsed_message( - &mut self, - writer: &WriterRef, - message: &Message, - reason: InvalidReason, - ) -> Result<(), SessionOperationError> { - match reason { - InvalidReason::InvalidField(tag) | InvalidReason::InvalidGroup(tag) => { - if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::InvalidTagNumber) - .text(&format!("invalid field {tag}")); - self.send_message(writer, reject) - .await - .with_send_context("reject for invalid field")?; - } - } - InvalidReason::InvalidComponent(_component_name) => { - warn!("received invalid component"); - } - InvalidReason::InvalidMsgType(msg_type) => { - self.handle_invalid_msg_type(writer, message, &msg_type) - .await; - } - InvalidReason::InvalidOrderInGroup { tag, .. } => { - if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::RepeatingGroupFieldsOutOfOrder) - .text(&format!("field appears in incorrect order:{tag}")); - self.send_message(writer, reject) - .await - .with_send_context("reject for invalid group order")?; - } - } - } - Ok(()) - } - - pub async fn handle_invalid_msg_type( - &mut self, - writer: &WriterRef, - message: &Message, - msg_type: &str, - ) { - match message.header().get(MSG_SEQ_NUM) { - Ok(msg_seq_num) => { - let reject = Reject::new(msg_seq_num) - .session_reject_reason(SessionRejectReason::InvalidMsgtype) - .text(&format!("invalid message type {msg_type}")); - if let Err(err) = self.send_message(writer, reject).await { - error!("failed to send reject message for invalid msgtype: {err}"); - } - - #[allow(clippy::collapsible_if)] - if let Ok(seq_num) = message.header().get::(MSG_SEQ_NUM) - && self.store.next_target_seq_number() == seq_num - { - if let Err(err) = self.store.increment_target_seq_number().await { - error!("failed to increment target seq number: {:?}", err); - } - } - } - Err(err) => { - error!("failed to get message seq num: {:?}", err); - } - } - } } diff --git a/crates/hotfix/src/session/message_handling.rs b/crates/hotfix/src/session/message_handling.rs new file mode 100644 index 00000000..90785239 --- /dev/null +++ b/crates/hotfix/src/session/message_handling.rs @@ -0,0 +1,402 @@ +use crate::message::logout::Logout; +use crate::message::parser::RawFixMessage; +use crate::message::reject::Reject; +use crate::message::sequence_reset::SequenceReset; +use crate::message::verification::verify_message as verify_message_impl; +use crate::message::verification_error::{CompIdType, MessageVerificationError}; +use crate::message::{generate_message, is_admin, prepare_message_for_resend}; +use crate::session::ctx::{SessionCtx, TransitionResult, VerifyResult}; +use crate::session::error::{InternalSendResultExt, SessionOperationError}; +use crate::session::get_msg_seq_num; +use crate::session::state::SessionState; +use crate::transport::writer::WriterRef; +use hotfix_message::Part; +use hotfix_message::message::Message; +use hotfix_message::parsed_message::InvalidReason; +use hotfix_message::session_fields::{MSG_SEQ_NUM, MSG_TYPE, SessionRejectReason}; +use hotfix_store::MessageStore; +use tracing::{debug, enabled, error, info, warn}; + +fn verify_message( + ctx: &SessionCtx<'_, Store>, + message: &Message, + check_too_high: bool, + check_too_low: bool, +) -> Result<(), MessageVerificationError> { + let expected_seq_number = if check_too_high || check_too_low { + Some(ctx.store.next_target_seq_number()) + } else { + None + }; + verify_message_impl( + message, + ctx.config, + expected_seq_number, + check_too_high, + check_too_low, + ) +} + +/// Verify a message and handle the error if verification fails. +/// For SeqNumberTooHigh, returns `VerifyResult::SeqTooHigh` instead of handling it, +/// allowing the caller to handle the transition. +pub async fn verify_and_handle( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + message: &Message, + check_too_high: bool, + check_too_low: bool, +) -> Result { + match verify_message(ctx, message, check_too_high, check_too_low) { + Ok(()) => Ok(VerifyResult::Passed), + Err(MessageVerificationError::SeqNumberTooHigh { expected, actual }) => { + Ok(VerifyResult::SeqTooHigh { expected, actual }) + } + Err(err) => { + let transition = handle_verification_error(ctx, writer, err).await?; + Ok(VerifyResult::Handled(transition)) + } + } +} + +/// Handle a verification error (excluding SeqNumberTooHigh which is returned separately). +/// Returns the `TransitionResult` to use — either `Stay` (error was handled in-place) +/// or `TransitionTo` (a state change is needed). +async fn handle_verification_error( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + error: MessageVerificationError, +) -> Result { + match error { + MessageVerificationError::SeqNumberTooLow { + expected, + actual, + possible_duplicate, + } => Ok( + handle_sequence_number_too_low(ctx, writer, expected, actual, possible_duplicate).await, + ), + MessageVerificationError::SeqNumberTooHigh { expected, actual } => { + // This shouldn't be called for SeqTooHigh anymore (it's returned via VerifyResult), + // but handle gracefully if it is. + warn!( + "handle_verification_error called with SeqNumberTooHigh({expected}, {actual}) - caller should use verify_and_handle" + ); + Ok(TransitionResult::Stay) + } + MessageVerificationError::IncorrectBeginString(begin_string) => { + let new_state = handle_incorrect_begin_string(ctx, writer, begin_string).await; + Ok(TransitionResult::TransitionTo(new_state)) + } + MessageVerificationError::IncorrectCompId { + comp_id, + comp_id_type, + msg_seq_num, + } => { + let new_state = + handle_incorrect_comp_id(ctx, writer, comp_id, comp_id_type, msg_seq_num).await; + Ok(TransitionResult::TransitionTo(new_state)) + } + MessageVerificationError::SendingTimeAccuracyIssue { msg_seq_num } => { + handle_sending_time_accuracy_problem( + ctx, + writer, + msg_seq_num, + "unexpected sending time", + ) + .await; + Ok(TransitionResult::Stay) + } + MessageVerificationError::SendingTimeMissing { msg_seq_num } => { + handle_sending_time_accuracy_problem(ctx, writer, msg_seq_num, "sending time missing") + .await; + Ok(TransitionResult::Stay) + } + MessageVerificationError::OriginalSendingTimeMissing { msg_seq_num } => { + handle_original_sending_time_missing(ctx, writer, msg_seq_num).await; + Ok(TransitionResult::Stay) + } + MessageVerificationError::OriginalSendingTimeAfterSendingTime { msg_seq_num, .. } => { + handle_sending_time_accuracy_problem( + ctx, + writer, + msg_seq_num, + "original sending time is after sending time", + ) + .await; + Ok(TransitionResult::Stay) + } + } +} + +async fn handle_incorrect_begin_string( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + received_begin_string: String, +) -> SessionState { + logout_and_terminate( + ctx, + writer, + &format!("beginString={received_begin_string} is not supported"), + ) + .await; + SessionState::new_disconnected(true, "incorrect begin string") +} + +async fn handle_incorrect_comp_id( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + received_comp_id: String, + comp_id_type: CompIdType, + msg_seq_num: u64, +) -> SessionState { + error!("rejecting message with incorrect comp ID: {received_comp_id} (type: {comp_id_type:?})"); + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::ValueIsIncorrect) + .text(&format!("invalid comp ID {received_comp_id}")); + if let Err(err) = ctx.send_message(writer, reject).await { + error!("failed to send reject message with invalid comp ID: {err}"); + } + logout_and_terminate(ctx, writer, "incorrect comp ID received").await; + SessionState::new_disconnected(true, "incorrect comp ID") +} + +async fn handle_sequence_number_too_low( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + expected: u64, + actual: u64, + possible_duplicate: bool, +) -> TransitionResult { + if possible_duplicate { + warn!( + "sequence number too low (expected {expected}, actual {actual}, but counterparty indicated it's poss duplicate, ignoring" + ); + return TransitionResult::Stay; + } + error!( + "we expected {expected} sequence number, but target sent lower ({actual}), terminating..." + ); + let reason = format!("sequence number too low (actual {actual}, expected {expected})"); + logout_and_terminate(ctx, writer, &reason).await; + TransitionResult::TransitionTo(SessionState::new_disconnected(false, &reason)) +} + +async fn handle_sending_time_accuracy_problem( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + msg_seq_num: u64, + text: &str, +) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::SendingtimeAccuracyProblem) + .text(text); + if let Err(err) = ctx.send_message(writer, reject).await { + error!("failed to send reject for time accuracy problem: {err}"); + } + if let Err(err) = ctx.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } +} + +async fn handle_original_sending_time_missing( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + msg_seq_num: u64, +) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text("original sending time is required"); + if let Err(err) = ctx.send_message(writer, reject).await { + error!("failed to send reject for time missing tag: {err}"); + } + if let Err(err) = ctx.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } +} + +/// Send a logout message and immediately disconnect. +pub(crate) async fn logout_and_terminate( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + reason: &str, +) { + let logout = Logout::with_reason(reason.to_string()); + match ctx.prepare_message(logout).await { + Ok(prepared) => writer.send_raw_message(prepared.raw).await, + Err(err) => warn!("failed to send logout during session termination: {err}"), + } + writer.disconnect().await; +} + +pub async fn resend_messages( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + begin: u64, + end: u64, +) -> Result<(), SessionOperationError> { + info!(begin, end, "resending messages as requested"); + let messages = ctx.store.get_slice(begin as usize, end as usize).await?; + + let no = messages.len(); + debug!(number_of_messages = no, "number of messages"); + + let mut reset_start: Option = None; + let mut sequence_number = 0; + + for msg in messages { + let mut message = ctx + .message_builder + .build(msg.as_slice()) + .into_message() + .ok_or_else(|| { + SessionOperationError::StoredMessageParse(format!( + "failed to build message for raw message: {msg:?}" + )) + })?; + sequence_number = get_msg_seq_num(&message); + let message_type: String = message + .header() + .get::<&str>(MSG_TYPE) + .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))? + .to_string(); + + if is_admin(&message_type) { + if reset_start.is_none() { + reset_start = Some(sequence_number); + } + continue; + } + + if let Some(begin) = reset_start { + let end = sequence_number; + log_skipped_admin_messages(begin, end); + send_sequence_reset(ctx, writer, begin, end).await?; + reset_start = None; + } + + if let Err(e) = prepare_message_for_resend(&mut message) { + error!( + error = e, + "failed to prepare message for resend, sending original" + ); + } + writer + .send_raw_message(RawFixMessage::new(message.encode(ctx.message_config)?)) + .await; + + if enabled!(tracing::Level::DEBUG) + && let Ok(m) = String::from_utf8(msg.clone()) + { + debug!(sequence_number, message = m, "resent message"); + } + } + + if let Some(begin) = reset_start { + // the final reset if needed + let end = sequence_number; + log_skipped_admin_messages(begin, end); + send_sequence_reset(ctx, writer, begin, end).await?; + } + + Ok(()) +} + +async fn send_sequence_reset( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + begin: u64, + end: u64, +) -> Result<(), SessionOperationError> { + let sequence_reset = SequenceReset { + gap_fill: true, + new_seq_no: end, + }; + let raw_message = generate_message( + &ctx.config.begin_string, + &ctx.config.sender_comp_id, + &ctx.config.target_comp_id, + begin, + sequence_reset, + )?; + + writer + .send_raw_message(RawFixMessage::new(raw_message)) + .await; + debug!(begin, end, "sent reset sequence"); + + Ok(()) +} + +fn log_skipped_admin_messages(begin: u64, end: u64) { + info!( + begin, + end, "skipped admin message(s) during resend, requesting reset for these" + ); +} + +pub async fn handle_invalid_parsed_message( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + message: &Message, + reason: InvalidReason, +) -> Result<(), SessionOperationError> { + match reason { + InvalidReason::InvalidField(tag) | InvalidReason::InvalidGroup(tag) => { + if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::InvalidTagNumber) + .text(&format!("invalid field {tag}")); + ctx.send_message(writer, reject) + .await + .with_send_context("reject for invalid field")?; + } + } + InvalidReason::InvalidComponent(_component_name) => { + warn!("received invalid component"); + } + InvalidReason::InvalidMsgType(msg_type) => { + handle_invalid_msg_type(ctx, writer, message, &msg_type).await; + } + InvalidReason::InvalidOrderInGroup { tag, .. } => { + if let Ok(msg_seq_num) = message.header().get(MSG_SEQ_NUM) { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RepeatingGroupFieldsOutOfOrder) + .text(&format!("field appears in incorrect order:{tag}")); + ctx.send_message(writer, reject) + .await + .with_send_context("reject for invalid group order")?; + } + } + } + Ok(()) +} + +async fn handle_invalid_msg_type( + ctx: &mut SessionCtx<'_, Store>, + writer: &WriterRef, + message: &Message, + msg_type: &str, +) { + match message.header().get(MSG_SEQ_NUM) { + Ok(msg_seq_num) => { + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::InvalidMsgtype) + .text(&format!("invalid message type {msg_type}")); + if let Err(err) = ctx.send_message(writer, reject).await { + error!("failed to send reject message for invalid msgtype: {err}"); + } + + #[allow(clippy::collapsible_if)] + if let Ok(seq_num) = message.header().get::(MSG_SEQ_NUM) + && ctx.store.next_target_seq_number() == seq_num + { + if let Err(err) = ctx.store.increment_target_seq_number().await { + error!("failed to increment target seq number: {:?}", err); + } + } + } + Err(err) => { + error!("failed to get message seq num: {:?}", err); + } + } +} diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index f4a21eac..0b2f76d6 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -115,7 +115,7 @@ impl SessionState { reason: &str, ) { if let Some(writer) = self.get_writer() { - ctx.logout_and_terminate(writer, reason).await; + super::message_handling::logout_and_terminate(ctx, writer, reason).await; } } diff --git a/crates/hotfix/src/session/state/active.rs b/crates/hotfix/src/session/state/active.rs index 0fb7a0c3..da3b86d3 100644 --- a/crates/hotfix/src/session/state/active.rs +++ b/crates/hotfix/src/session/state/active.rs @@ -10,6 +10,7 @@ use crate::message::sequence_reset::SequenceReset; use crate::message::test_request::TestRequest; use crate::session::error::{InternalSendResultExt, SendError, SendOutcome, SessionOperationError}; use crate::session::get_msg_seq_num; +use crate::session::message_handling; use crate::session::state::{ AwaitingResendState, SessionCtx, SessionState, TestRequestId, TransitionResult, VerifyResult, }; @@ -142,10 +143,7 @@ impl ActiveState { ctx: &mut SessionCtx<'_, Store>, message: &hotfix_message::message::Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, true, true) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { expected, actual } => { return self @@ -173,10 +171,7 @@ impl ActiveState { ctx: &mut SessionCtx<'_, Store>, message: &hotfix_message::message::Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, true, true) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { expected, actual } => { return self @@ -206,10 +201,7 @@ impl ActiveState { ctx: &mut SessionCtx<'_, Store>, message: &hotfix_message::message::Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, false, true) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, true).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { .. } => { // ResendRequest with check_too_high=false should never get SeqTooHigh, @@ -261,7 +253,7 @@ impl ActiveState { ctx.store.increment_target_seq_number().await?; } - ctx.resend_messages(&self.writer, begin_seq_number, end_seq_number) + message_handling::resend_messages(ctx, &self.writer, begin_seq_number, end_seq_number) .await?; self.reset_heartbeat_timer(ctx.config.heartbeat_interval); @@ -273,10 +265,7 @@ impl ActiveState { ctx: &mut SessionCtx<'_, Store>, message: &hotfix_message::message::Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, false, true) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, true).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { expected, actual } => { return self @@ -297,9 +286,14 @@ impl ActiveState { ) -> Result { let msg_seq_num = get_msg_seq_num(message); let is_gap_fill: bool = message.get(GAP_FILL_FLAG).unwrap_or(false); - match ctx - .verify_and_handle(&self.writer, message, is_gap_fill, is_gap_fill) - .await? + match message_handling::verify_and_handle( + ctx, + &self.writer, + message, + is_gap_fill, + is_gap_fill, + ) + .await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { expected, actual } => { @@ -354,10 +348,7 @@ impl ActiveState { app: &mut App, message: &hotfix_message::message::Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, false, false) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, false).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { .. } => { // verify with check_too_high=false, so this shouldn't happen @@ -391,10 +382,7 @@ impl ActiveState { app: &mut App, message: &hotfix_message::message::Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, true, true) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { expected, actual } => { return self diff --git a/crates/hotfix/src/session/state/awaiting_logon.rs b/crates/hotfix/src/session/state/awaiting_logon.rs index 4c29d1f5..7c5f4850 100644 --- a/crates/hotfix/src/session/state/awaiting_logon.rs +++ b/crates/hotfix/src/session/state/awaiting_logon.rs @@ -1,6 +1,7 @@ use crate::Application; use crate::message::logon::Logon; use crate::session::error::SessionOperationError; +use crate::session::message_handling; use crate::session::state::{SessionCtx, SessionState, TransitionResult, VerifyResult}; use crate::transport::writer::WriterRef; use hotfix_message::Part; @@ -42,10 +43,7 @@ impl AwaitingLogonState { } // process logon - match ctx - .verify_and_handle(&self.writer, &message, true, true) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, &message, true, true).await? { VerifyResult::Passed => { // happy logon flow, the session is now active let new_state = diff --git a/crates/hotfix/src/session/state/awaiting_logout.rs b/crates/hotfix/src/session/state/awaiting_logout.rs index 5df5ee6f..544d2e01 100644 --- a/crates/hotfix/src/session/state/awaiting_logout.rs +++ b/crates/hotfix/src/session/state/awaiting_logout.rs @@ -1,6 +1,7 @@ use crate::Application; use crate::message::logout::Logout; use crate::session::error::SessionOperationError; +use crate::session::message_handling; use crate::session::state::{SessionCtx, SessionState, TransitionResult, VerifyResult}; use crate::transport::writer::WriterRef; use hotfix_message::Part; @@ -47,8 +48,7 @@ impl AwaitingLogoutState { if message_type == Logout::MSG_TYPE { // Process the logout - match ctx - .verify_and_handle(&self.writer, &message, false, false) + match message_handling::verify_and_handle(ctx, &self.writer, &message, false, false) .await? { VerifyResult::Passed => {} diff --git a/crates/hotfix/src/session/state/awaiting_resend.rs b/crates/hotfix/src/session/state/awaiting_resend.rs index 528cd9a9..9a17a100 100644 --- a/crates/hotfix/src/session/state/awaiting_resend.rs +++ b/crates/hotfix/src/session/state/awaiting_resend.rs @@ -10,6 +10,7 @@ use crate::message::sequence_reset::SequenceReset; use crate::message::test_request::TestRequest; use crate::session::error::{InternalSendResultExt, SessionOperationError}; use crate::session::get_msg_seq_num; +use crate::session::message_handling; use crate::session::state::{SessionCtx, SessionState, TransitionResult, VerifyResult}; use crate::transport::writer::WriterRef; use hotfix_message::Part; @@ -138,10 +139,7 @@ impl AwaitingResendState { ctx: &mut SessionCtx<'_, Store>, message: &Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, true, true) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { expected, actual } => { return self.handle_seq_too_high(ctx, expected, actual).await; @@ -158,10 +156,7 @@ impl AwaitingResendState { ctx: &mut SessionCtx<'_, Store>, message: &Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, true, true) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { expected, actual } => { return self.handle_seq_too_high(ctx, expected, actual).await; @@ -185,10 +180,7 @@ impl AwaitingResendState { ctx: &mut SessionCtx<'_, Store>, message: &Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, false, true) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, true).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { .. } => { // check_too_high=false, shouldn't happen @@ -242,7 +234,7 @@ impl AwaitingResendState { ctx.store.increment_target_seq_number().await?; } - ctx.resend_messages(&self.writer, begin_seq_number, end_seq_number) + message_handling::resend_messages(ctx, &self.writer, begin_seq_number, end_seq_number) .await?; Ok(TransitionResult::Stay) @@ -253,10 +245,7 @@ impl AwaitingResendState { ctx: &mut SessionCtx<'_, Store>, message: &Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, false, true) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, true).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { expected, actual } => { return self.handle_seq_too_high(ctx, expected, actual).await; @@ -275,9 +264,14 @@ impl AwaitingResendState { ) -> Result { let msg_seq_num = get_msg_seq_num(message); let is_gap_fill: bool = message.get(GAP_FILL_FLAG).unwrap_or(false); - match ctx - .verify_and_handle(&self.writer, message, is_gap_fill, is_gap_fill) - .await? + match message_handling::verify_and_handle( + ctx, + &self.writer, + message, + is_gap_fill, + is_gap_fill, + ) + .await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { expected, actual } => { @@ -328,10 +322,7 @@ impl AwaitingResendState { app: &mut App, message: &Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, false, false) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, false, false).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { .. } => {} VerifyResult::Handled(transition) => return Ok(transition), @@ -360,10 +351,7 @@ impl AwaitingResendState { app: &mut App, message: &Message, ) -> Result { - match ctx - .verify_and_handle(&self.writer, message, true, true) - .await? - { + match message_handling::verify_and_handle(ctx, &self.writer, message, true, true).await? { VerifyResult::Passed => {} VerifyResult::SeqTooHigh { expected, actual } => { return self.handle_seq_too_high(ctx, expected, actual).await;