diff --git a/crates/hotfix/src/initiator.rs b/crates/hotfix/src/initiator.rs index 7e64adf3..b86bdca0 100644 --- a/crates/hotfix/src/initiator.rs +++ b/crates/hotfix/src/initiator.rs @@ -15,6 +15,7 @@ use crate::application::Application; use crate::config::SessionConfig; use crate::message::OutboundMessage; use crate::session::error::{SendError, SendOutcome, SessionCreationError}; +use crate::session::event::ScheduleResponse; use crate::session::{InternalSessionRef, SessionHandle}; use crate::store::MessageStore; use crate::transport::connect; @@ -107,9 +108,18 @@ async fn establish_connection( completion_tx: watch::Sender, ) { loop { - if session_ref.await_in_schedule().await.is_err() { - warn!("session task terminated when checking schedule"); - break; + match session_ref.await_in_schedule().await { + Ok(ScheduleResponse::InSchedule) => { + debug!("resuming connection as schedule is active"); + } + Ok(ScheduleResponse::Shutdown) => { + warn!("session indicated shutdown during schedule check"); + break; + } + Err(_) => { + warn!("session task terminated when checking schedule"); + break; + } } match connect(&config, session_ref.clone()).await { diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 0165324b..fc28d81a 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -9,7 +9,7 @@ mod session_handle; pub mod session_ref; mod state; #[cfg(test)] -mod test_utils; +pub(crate) mod test_utils; use chrono::Utc; use hotfix_message::dict::Dictionary; diff --git a/crates/hotfix/src/session/session_ref.rs b/crates/hotfix/src/session/session_ref.rs index 15aa3957..a6971465 100644 --- a/crates/hotfix/src/session/session_ref.rs +++ b/crates/hotfix/src/session/session_ref.rs @@ -82,16 +82,13 @@ impl InternalSessionRef { Ok(receiver.await?) } - pub async fn await_in_schedule(&self) -> Result<(), SessionGone> { + pub async fn await_in_schedule(&self) -> Result { debug!("awaiting in-schedule time"); let (sender, receiver) = oneshot::channel::(); self.event_sender .send(SessionEvent::AwaitSchedule(sender)) .await?; - receiver.await?; - - debug!("resuming connection as schedule is active"); - Ok(()) + Ok(receiver.await?) } } @@ -110,3 +107,52 @@ impl From for SessionGone { Self(err.to_string()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::session::test_utils::create_test_session_ref; + + #[tokio::test] + async fn await_in_schedule_returns_in_schedule_when_session_responds_in_schedule() { + let (session_ref, mut event_receiver) = create_test_session_ref(); + + tokio::spawn(async move { + match event_receiver.recv().await { + Some(SessionEvent::AwaitSchedule(responder)) => { + let _ = responder.send(ScheduleResponse::InSchedule); + } + other => panic!("unexpected event: {other:?}"), + } + }); + + let result = session_ref.await_in_schedule().await; + assert!(matches!(result, Ok(ScheduleResponse::InSchedule))); + } + + #[tokio::test] + async fn await_in_schedule_returns_shutdown_when_session_responds_shutdown() { + let (session_ref, mut event_receiver) = create_test_session_ref(); + + tokio::spawn(async move { + match event_receiver.recv().await { + Some(SessionEvent::AwaitSchedule(responder)) => { + let _ = responder.send(ScheduleResponse::Shutdown); + } + other => panic!("unexpected event: {other:?}"), + } + }); + + let result = session_ref.await_in_schedule().await; + assert!(matches!(result, Ok(ScheduleResponse::Shutdown))); + } + + #[tokio::test] + async fn await_in_schedule_returns_err_when_event_channel_closed() { + let (session_ref, event_receiver) = create_test_session_ref(); + drop(event_receiver); + + let result = session_ref.await_in_schedule().await; + assert!(matches!(result, Err(SessionGone(_)))); + } +} diff --git a/crates/hotfix/src/session/test_utils.rs b/crates/hotfix/src/session/test_utils.rs index d71becc0..a7e6cfae 100644 --- a/crates/hotfix/src/session/test_utils.rs +++ b/crates/hotfix/src/session/test_utils.rs @@ -1,5 +1,9 @@ use crate::config::SessionConfig; +use crate::message::{Message, OutboundMessage}; +use crate::session::admin_request::AdminRequest; use crate::session::ctx::SessionCtx; +use crate::session::event::SessionEvent; +use crate::session::session_ref::{InternalSessionRef, OutboundRequest}; use crate::store::{MessageStore, Result as StoreResult}; use crate::transport::writer::{WriterMessage, WriterRef}; use chrono::{DateTime, Utc}; @@ -117,3 +121,31 @@ pub(crate) fn extract_field(raw: &[u8], tag: u32) -> Option { } None } + +#[derive(Clone)] +pub(crate) struct TestMessage; + +impl OutboundMessage for TestMessage { + fn write(&self, _msg: &mut Message) {} + fn message_type(&self) -> &str { + "TEST" + } +} + +pub(crate) fn create_test_session_ref() -> ( + InternalSessionRef, + mpsc::Receiver, +) { + let (event_sender, event_receiver) = mpsc::channel::(100); + let (outbound_message_sender, _outbound_receiver) = + mpsc::channel::>(10); + let (admin_request_sender, _admin_receiver) = mpsc::channel::(10); + + let session_ref = InternalSessionRef { + event_sender, + outbound_message_sender, + admin_request_sender, + }; + + (session_ref, event_receiver) +} diff --git a/crates/hotfix/src/transport/socket/socket_reader.rs b/crates/hotfix/src/transport/socket/socket_reader.rs index 4e9ad0c9..3cdbcaab 100644 --- a/crates/hotfix/src/transport/socket/socket_reader.rs +++ b/crates/hotfix/src/transport/socket/socket_reader.rs @@ -84,42 +84,9 @@ where #[cfg(test)] mod tests { use super::*; - use crate::message::Message; - use crate::session::admin_request::AdminRequest; use crate::session::event::SessionEvent; - use crate::session::session_ref::OutboundRequest; + use crate::session::test_utils::create_test_session_ref; use tokio::io::{AsyncWriteExt, duplex}; - use tokio::sync::mpsc; - - #[derive(Clone, Debug, PartialEq)] - struct TestMessage; - - impl OutboundMessage for TestMessage { - fn write(&self, _msg: &mut Message) {} - - fn message_type(&self) -> &str { - "TEST" - } - } - - /// Creates a test InternalSessionRef that captures events for verification - fn create_test_session_ref() -> ( - InternalSessionRef, - mpsc::Receiver, - ) { - let (event_sender, event_receiver) = mpsc::channel::(100); - let (outbound_message_sender, _outbound_receiver) = - mpsc::channel::>(10); - let (admin_request_sender, _admin_receiver) = mpsc::channel::(10); - - let session_ref = InternalSessionRef { - event_sender, - outbound_message_sender, - admin_request_sender, - }; - - (session_ref, event_receiver) - } /// Test that the reader correctly parses a valid FIX message and sends it to the session /// for processing.