@@ -223,7 +223,7 @@ where
223223 . handle_verification_issue ( & mut self . ctx , & message, flags)
224224 . await ?
225225 {
226- self . apply_transition ( result) ;
226+ self . apply_transition ( result) . await ;
227227 return Ok ( ( ) ) ;
228228 }
229229
@@ -286,25 +286,28 @@ where
286286 }
287287
288288 async fn check_end_of_resend ( & mut self ) -> Result < ( ) , SessionOperationError > {
289- let ended_state = if let SessionState :: AwaitingResend ( state) = & mut self . state {
289+ let backlog = if let SessionState :: AwaitingResend ( state) = & mut self . state {
290290 if self . ctx . store . next_target_seq_number ( ) > state. end_seq_number {
291+ let inbound_queue = std:: mem:: take ( & mut state. inbound_queue ) ;
291292 let new_state = SessionState :: new_active (
292293 state. writer . clone ( ) ,
293294 self . ctx . config . heartbeat_interval ,
294295 ) ;
295- Some ( std:: mem:: replace ( & mut self . state , new_state) )
296+ self . apply_transition ( TransitionResult :: TransitionTo ( new_state) )
297+ . await ;
298+ Some ( inbound_queue)
296299 } else {
297300 None
298301 }
299302 } else {
300303 None
301304 } ;
302305
303- if let Some ( SessionState :: AwaitingResend ( mut state ) ) = ended_state {
306+ if let Some ( mut inbound_queue ) = backlog {
304307 // we have reached the end of the resend,
305308 // process queued messages and resume normal operation
306309 debug ! ( "resend is done, processing backlog" ) ;
307- while let Some ( msg) = state . inbound_queue . pop_front ( ) {
310+ while let Some ( msg) = inbound_queue. pop_front ( ) {
308311 let seq_number: u64 = msg. get ( MSG_SEQ_NUM ) . unwrap_or_else ( |e| {
309312 error ! ( "failed to get seq number: {:?}" , e) ;
310313 0
@@ -328,39 +331,48 @@ where
328331 }
329332
330333 async fn on_connect ( & mut self , writer : WriterRef ) -> Result < ( ) , SessionOperationError > {
331- self . state = SessionState :: AwaitingLogon ( AwaitingLogonState {
332- writer,
333- logon_sent : false ,
334- logon_timeout : Instant :: now ( ) + Duration :: from_secs ( self . ctx . config . logon_timeout ) ,
335- } ) ;
334+ self . apply_transition ( TransitionResult :: TransitionTo ( SessionState :: AwaitingLogon (
335+ AwaitingLogonState {
336+ writer,
337+ logon_sent : false ,
338+ logon_timeout : Instant :: now ( ) + Duration :: from_secs ( self . ctx . config . logon_timeout ) ,
339+ } ,
340+ ) ) )
341+ . await ;
336342 self . reset_peer_timer ( None ) ;
337343 self . send_logon ( ) . await ?;
338344
339345 Ok ( ( ) )
340346 }
341347
342348 async fn on_disconnect ( & mut self , reason : String ) {
343- match self . state {
349+ let transition = match self . state {
344350 SessionState :: Active ( _)
345351 | SessionState :: AwaitingLogon ( _)
346352 | SessionState :: AwaitingResend ( _) => {
347353 self . state . disconnect_writer ( ) . await ;
348- self . state = SessionState :: new_disconnected ( true , & reason) ;
354+ TransitionResult :: TransitionTo ( SessionState :: new_disconnected ( true , & reason) )
349355 }
350356 SessionState :: Disconnected ( _) => {
351- warn ! ( "disconnect message was received, but the session is already disconnected" )
357+ warn ! ( "disconnect message was received, but the session is already disconnected" ) ;
358+ TransitionResult :: Stay
352359 }
353360 SessionState :: AwaitingLogout ( AwaitingLogoutState { reconnect, .. } ) => {
354- self . state = SessionState :: new_disconnected ( reconnect, & reason) ;
361+ TransitionResult :: TransitionTo ( SessionState :: new_disconnected ( reconnect, & reason) )
355362 }
356- }
363+ } ;
364+ self . apply_transition ( transition) . await ;
357365 }
358366
359367 async fn on_logon ( & mut self ) -> Result < ( ) , SessionOperationError > {
360368 if let SessionState :: AwaitingLogon ( AwaitingLogonState { writer, .. } ) = & self . state {
361369 let writer = writer. clone ( ) ;
362370 // happy logon flow, the session is now active
363- self . state = SessionState :: new_active ( writer, self . ctx . config . heartbeat_interval ) ;
371+ self . apply_transition ( TransitionResult :: TransitionTo ( SessionState :: new_active (
372+ writer,
373+ self . ctx . config . heartbeat_interval ,
374+ ) ) )
375+ . await ;
364376 self . ctx . application . on_logon ( ) . await ;
365377 self . ctx . store . increment_target_seq_number ( ) . await ?;
366378 } else {
@@ -388,12 +400,18 @@ where
388400 // if we initiated the logout, preserve the reconnect flag
389401 SessionState :: AwaitingLogout ( AwaitingLogoutState { reconnect, .. } ) => {
390402 self . state . disconnect_writer ( ) . await ;
391- self . state = SessionState :: new_disconnected ( reconnect, "logout completed" ) ;
403+ self . apply_transition ( TransitionResult :: TransitionTo (
404+ SessionState :: new_disconnected ( reconnect, "logout completed" ) ,
405+ ) )
406+ . await ;
392407 }
393408 // otherwise assume it makes sense to try to reconnect
394409 _ => {
395410 self . state . disconnect_writer ( ) . await ;
396- self . state = SessionState :: new_disconnected ( true , "peer has logged us out" )
411+ self . apply_transition ( TransitionResult :: TransitionTo (
412+ SessionState :: new_disconnected ( true , "peer has logged us out" ) ,
413+ ) )
414+ . await ;
397415 }
398416 }
399417
@@ -462,9 +480,17 @@ where
462480 Ok ( ( ) )
463481 }
464482
465- fn apply_transition ( & mut self , result : TransitionResult ) {
483+ async fn apply_transition ( & mut self , result : TransitionResult ) {
466484 if let TransitionResult :: TransitionTo ( new_state) = result {
485+ let old_status = self . state . as_status ( ) ;
467486 self . state = new_state;
487+ let new_status = self . state . as_status ( ) ;
488+ if old_status != new_status {
489+ self . ctx
490+ . application
491+ . on_state_change ( & old_status, & new_status)
492+ . await ;
493+ }
468494 }
469495 }
470496
@@ -532,7 +558,10 @@ where
532558 self . state
533559 . logout_and_terminate ( & mut self . ctx , "internal error" )
534560 . await ;
535- self . state = SessionState :: new_disconnected ( true , & reason) ;
561+ self . apply_transition ( TransitionResult :: TransitionTo (
562+ SessionState :: new_disconnected ( true , & reason) ,
563+ ) )
564+ . await ;
536565 }
537566 }
538567 SessionEvent :: Disconnected ( reason) => {
@@ -575,12 +604,13 @@ where
575604 match request {
576605 AdminRequest :: InitiateGracefulShutdown { reconnect } => {
577606 warn ! ( "initiating shutdown on request from admin.." ) ;
578- if let Err ( err ) = self
607+ match self
579608 . state
580609 . initiate_graceful_logout ( & mut self . ctx , "explicitly requested" , reconnect)
581610 . await
582611 {
583- error ! ( err = ?err, "initiating graceful shutdown" ) ;
612+ Ok ( result) => self . apply_transition ( result) . await ,
613+ Err ( err) => error ! ( err = ?err, "initiating graceful shutdown" ) ,
584614 }
585615 }
586616 AdminRequest :: RequestSessionInfo ( responder) => {
@@ -646,8 +676,10 @@ where
646676 . await ;
647677 if let Err ( err) = self . ctx . store . reset ( ) . await {
648678 error ! ( "error resetting session store: {err:}" ) ;
649- self . state =
650- SessionState :: new_disconnected ( false , "unexpected error in reset" ) ;
679+ self . apply_transition ( TransitionResult :: TransitionTo (
680+ SessionState :: new_disconnected ( false , "unexpected error in reset" ) ,
681+ ) )
682+ . await ;
651683 }
652684 }
653685 Ok ( SessionPeriodComparison :: OutsideSessionTime { .. } ) => {
@@ -659,8 +691,10 @@ where
659691 . await ;
660692 if let Err ( err) = self . ctx . store . reset ( ) . await {
661693 error ! ( "error resetting session store: {err:}" ) ;
662- self . state =
663- SessionState :: new_disconnected ( false , "unexpected error in reset" ) ;
694+ self . apply_transition ( TransitionResult :: TransitionTo (
695+ SessionState :: new_disconnected ( false , "unexpected error in reset" ) ,
696+ ) )
697+ . await ;
664698 }
665699 }
666700 Err ( err) => {
@@ -673,12 +707,13 @@ where
673707 }
674708 } else if self . state . is_connected ( ) {
675709 // we are currently outside scheduled session time
676- if let Err ( err ) = self
710+ match self
677711 . state
678712 . initiate_graceful_logout ( & mut self . ctx , "End of session time" , true )
679713 . await
680714 {
681- error ! ( err = ?err, "failed to initiate graceful logout" ) ;
715+ Ok ( result) => self . apply_transition ( result) . await ,
716+ Err ( err) => error ! ( err = ?err, "failed to initiate graceful logout" ) ,
682717 }
683718 }
684719
@@ -887,6 +922,8 @@ mod tests {
887922 }
888923 async fn on_logout ( & mut self , _: & str ) { }
889924 async fn on_logon ( & mut self ) { }
925+
926+ async fn on_state_change ( & self , _from : & Status , _to : & Status ) { }
890927 }
891928
892929 fn create_writer_ref ( ) -> WriterRef {
0 commit comments