diff --git a/crates/hotfix/src/transport/connection.rs b/crates/hotfix/src/transport/connection.rs index bb2cf19..5c4590e 100644 --- a/crates/hotfix/src/transport/connection.rs +++ b/crates/hotfix/src/transport/connection.rs @@ -1,20 +1,157 @@ +use std::time::Duration; + +use tokio::sync::oneshot; +use tracing::warn; + use crate::transport::reader::ReaderRef; use crate::transport::writer::WriterRef; +const FORCE_CLOSE_TIMEOUT: Duration = Duration::from_secs(10); + pub struct FixConnection { writer: WriterRef, reader: ReaderRef, + writer_exit: oneshot::Receiver<()>, } impl FixConnection { - pub fn new(writer: WriterRef, reader: ReaderRef) -> Self { - Self { writer, reader } + pub fn new(writer: WriterRef, reader: ReaderRef, writer_exit: oneshot::Receiver<()>) -> Self { + Self { + writer, + reader, + writer_exit, + } } + pub fn get_writer(&self) -> WriterRef { self.writer.clone() } pub async fn run_until_disconnect(self) { - self.reader.wait_for_disconnect().await + let Self { + reader, + mut writer_exit, + .. + } = self; + let ReaderRef { + mut disconnect_signal, + kill, + } = reader; + + tokio::select! { + _ = &mut disconnect_signal => return, + _ = &mut writer_exit => {} + } + + match tokio::time::timeout(FORCE_CLOSE_TIMEOUT, &mut disconnect_signal).await { + Ok(_) => {} + Err(_) => { + warn!( + "reader did not observe EOF within {:?}, forcing close", + FORCE_CLOSE_TIMEOUT + ); + let _ = kill.send(()); + let _ = disconnect_signal.await; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::transport::writer::WriterMessage; + use tokio::sync::mpsc; + + /// Build a `FixConnection` and return the ends the test controls: + /// dc_sender to fire from the "reader", writer_exit_tx to fire from the "writer", + /// and kill_rx so the test can observe or simulate the reader being killed. + fn test_connection() -> ( + FixConnection, + oneshot::Sender<()>, + oneshot::Sender<()>, + oneshot::Receiver<()>, + ) { + let (dc_tx, dc_rx) = oneshot::channel::<()>(); + let (kill_tx, kill_rx) = oneshot::channel::<()>(); + let reader_ref = ReaderRef::new(dc_rx, kill_tx); + + let (writer_mpsc_tx, _writer_mpsc_rx) = mpsc::channel::(1); + let writer_ref = WriterRef::new(writer_mpsc_tx); + + let (writer_exit_tx, writer_exit_rx) = oneshot::channel::<()>(); + + let conn = FixConnection::new(writer_ref, reader_ref, writer_exit_rx); + (conn, dc_tx, writer_exit_tx, kill_rx) + } + + /// Reader signals disconnect first — return immediately, kill is never sent. + #[tokio::test(start_paused = true)] + async fn returns_on_reader_disconnect_before_writer_exit() { + let (conn, dc_tx, _writer_exit_tx, mut kill_rx) = test_connection(); + + dc_tx.send(()).expect("dc receiver dropped"); + + conn.run_until_disconnect().await; + + // Kill should not have been sent. The sender has been dropped by now + // (scope ended inside run_until_disconnect), so try_recv returns Closed + // rather than Empty. Either way, an Ok(()) would mean kill was sent. + assert!( + !matches!(kill_rx.try_recv(), Ok(())), + "kill signal should not have been sent" + ); + } + + /// Writer exits first, reader disconnects within the watchdog window — no kill. + #[tokio::test(start_paused = true)] + async fn returns_when_reader_disconnects_after_writer_exit_within_timeout() { + let (conn, dc_tx, writer_exit_tx, mut kill_rx) = test_connection(); + + writer_exit_tx + .send(()) + .expect("writer_exit receiver dropped"); + + // Fire the reader disconnect from a task that runs on the same paused clock. + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(1)).await; + let _ = dc_tx.send(()); + }); + + conn.run_until_disconnect().await; + + assert!( + !matches!(kill_rx.try_recv(), Ok(())), + "kill signal should not have been sent when reader disconnected within timeout" + ); + } + + /// Writer exits first, reader stays blocked past the watchdog — kill fires, + /// and a simulated reader fires dc once it sees the kill. + #[tokio::test(start_paused = true)] + async fn watchdog_fires_kill_when_reader_stuck() { + let (conn, dc_tx, writer_exit_tx, kill_rx) = test_connection(); + + writer_exit_tx + .send(()) + .expect("writer_exit receiver dropped"); + + // Stand in for the reader: when the watchdog kills us, we publish dc. + tokio::spawn(async move { + if kill_rx.await.is_ok() { + let _ = dc_tx.send(()); + } + }); + + let start = tokio::time::Instant::now(); + conn.run_until_disconnect().await; + let elapsed = start.elapsed(); + + assert!( + elapsed >= FORCE_CLOSE_TIMEOUT, + "expected watchdog to take at least {:?}, took {:?}", + FORCE_CLOSE_TIMEOUT, + elapsed + ); } } diff --git a/crates/hotfix/src/transport/reader.rs b/crates/hotfix/src/transport/reader.rs index 46ee36d..2e8caac 100644 --- a/crates/hotfix/src/transport/reader.rs +++ b/crates/hotfix/src/transport/reader.rs @@ -6,12 +6,16 @@ use tracing::warn; pub struct ReaderMessage; pub struct ReaderRef { - disconnect_signal: oneshot::Receiver<()>, + pub(crate) disconnect_signal: oneshot::Receiver<()>, + pub(crate) kill: oneshot::Sender<()>, } impl ReaderRef { - pub fn new(disconnect_signal: oneshot::Receiver<()>) -> Self { - Self { disconnect_signal } + pub fn new(disconnect_signal: oneshot::Receiver<()>, kill: oneshot::Sender<()>) -> Self { + Self { + disconnect_signal, + kill, + } } pub async fn wait_for_disconnect(self) { diff --git a/crates/hotfix/src/transport/socket.rs b/crates/hotfix/src/transport/socket.rs index 99481ab..e8ba208 100644 --- a/crates/hotfix/src/transport/socket.rs +++ b/crates/hotfix/src/transport/socket.rs @@ -47,8 +47,8 @@ where { let (reader, writer) = tokio::io::split(stream); - let writer_ref = spawn_socket_writer(writer); + let (writer_ref, writer_exit) = spawn_socket_writer(writer); let reader_ref = spawn_socket_reader(reader, session_ref); - FixConnection::new(writer_ref, reader_ref) + FixConnection::new(writer_ref, reader_ref, writer_exit) } diff --git a/crates/hotfix/src/transport/socket/socket_reader.rs b/crates/hotfix/src/transport/socket/socket_reader.rs index 3cdbcaa..418b341 100644 --- a/crates/hotfix/src/transport/socket/socket_reader.rs +++ b/crates/hotfix/src/transport/socket/socket_reader.rs @@ -12,10 +12,11 @@ pub fn spawn_socket_reader( session_ref: InternalSessionRef, ) -> ReaderRef { let (dc_sender, dc_receiver) = oneshot::channel(); + let (kill_sender, kill_receiver) = oneshot::channel(); let actor = ReaderActor::new(reader, session_ref, dc_sender); - tokio::spawn(run_reader(actor)); + tokio::spawn(run_reader(actor, kill_receiver)); - ReaderRef::new(dc_receiver) + ReaderRef::new(dc_receiver, kill_sender) } struct ReaderActor { @@ -38,8 +39,10 @@ impl ReaderActor { } } -async fn run_reader(mut actor: ReaderActor) -where +async fn run_reader( + mut actor: ReaderActor, + mut kill_rx: oneshot::Receiver<()>, +) where Outbound: OutboundMessage, R: AsyncRead, { @@ -47,31 +50,41 @@ where loop { let mut buf = vec![]; - match actor.reader.read_buf(&mut buf).await { - Ok(0) => { - let _ = actor - .session_ref - .disconnect("received EOF".to_string()) - .await; - break; - } - Err(err) => { - let _ = actor.session_ref.disconnect(err.to_string()).await; - break; - } - Ok(_) => { - let messages = parser.parse(&buf); - - for msg in messages { - if actor + tokio::select! { + result = actor.reader.read_buf(&mut buf) => match result { + Ok(0) => { + let _ = actor .session_ref - .new_fix_message_received(msg) - .await - .is_err() - { - debug!("reader received message but session has been terminated"); + .disconnect("received EOF".to_string()) + .await; + break; + } + Err(err) => { + let _ = actor.session_ref.disconnect(err.to_string()).await; + break; + } + Ok(_) => { + let messages = parser.parse(&buf); + + for msg in messages { + if actor + .session_ref + .new_fix_message_received(msg) + .await + .is_err() + { + debug!("reader received message but session has been terminated"); + } } } + }, + res = &mut kill_rx => { + let reason = match res { + Ok(()) => "forced close by watchdog", + Err(_) => "reader handle dropped", + }; + let _ = actor.session_ref.disconnect(reason.to_string()).await; + break; } } } @@ -228,4 +241,42 @@ mod tests { // wait for disconnect signal let _ = reader_ref.wait_for_disconnect().await; } + + /// Kill signal terminates the reader even when the peer is silent, and + /// the session observes the watchdog-sourced disconnect reason. + #[tokio::test] + async fn kill_signal_terminates_reader() { + let (_writer, reader) = duplex(1024); + let (reader_half, _writer_half) = tokio::io::split(reader); + + let (session_ref, mut event_receiver) = create_test_session_ref(); + let reader_ref = spawn_socket_reader(reader_half, session_ref); + + // Destructure so we can both fire the kill and later await the disconnect signal. + let ReaderRef { + disconnect_signal, + kill, + } = reader_ref; + + kill.send(()).expect("kill receiver dropped"); + + // Reader should publish the watchdog reason to the session. + match tokio::time::timeout( + tokio::time::Duration::from_millis(100), + event_receiver.recv(), + ) + .await + { + Ok(Some(SessionEvent::Disconnected(reason))) => { + assert_eq!(reason, "forced close by watchdog"); + } + other => panic!("expected Disconnected(\"forced close by watchdog\"), got {other:?}"), + } + + // And the disconnect signal should fire shortly after. + tokio::time::timeout(tokio::time::Duration::from_millis(100), disconnect_signal) + .await + .expect("disconnect signal not fired within timeout") + .expect("disconnect sender dropped without signalling"); + } } diff --git a/crates/hotfix/src/transport/socket/socket_writer.rs b/crates/hotfix/src/transport/socket/socket_writer.rs index 3a1e877..930b2c6 100644 --- a/crates/hotfix/src/transport/socket/socket_writer.rs +++ b/crates/hotfix/src/transport/socket/socket_writer.rs @@ -1,14 +1,22 @@ +use std::time::Duration; + use crate::transport::writer::{WriterMessage, WriterRef}; use tokio::io::{AsyncWrite, AsyncWriteExt, WriteHalf}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; use tracing::{debug, warn}; -pub fn spawn_socket_writer(writer: WriteHalf) -> WriterRef { +const WRITER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); + +pub fn spawn_socket_writer(writer: WriteHalf) -> (WriterRef, oneshot::Receiver<()>) +where + W: AsyncWrite + Send + 'static, +{ let (sender, mailbox) = mpsc::channel(10); + let (exit_tx, exit_rx) = oneshot::channel(); let actor = WriterActor::new(writer, mailbox); - tokio::spawn(run_writer(actor)); + tokio::spawn(run_writer(actor, exit_tx)); - WriterRef::new(sender) + (WriterRef::new(sender), exit_rx) } struct WriterActor { @@ -37,13 +45,23 @@ impl WriterActor { } } -async fn run_writer(mut actor: WriterActor) { +async fn run_writer(mut actor: WriterActor, exit_tx: oneshot::Sender<()>) { while let Some(msg) = actor.mailbox.recv().await { if !actor.handle(msg).await { break; } } + match tokio::time::timeout(WRITER_SHUTDOWN_TIMEOUT, actor.writer.shutdown()).await { + Ok(Ok(())) => debug!("writer half closed cleanly"), + Ok(Err(err)) => warn!("writer shutdown returned error: {err}"), + Err(_) => warn!( + "writer shutdown timed out after {:?}", + WRITER_SHUTDOWN_TIMEOUT + ), + } + + let _ = exit_tx.send(()); debug!("writer loop is shutting down"); } @@ -58,7 +76,7 @@ mod tests { async fn test_send_single_message() { let (reader, writer) = duplex(1024); let (_reader_half, writer_half) = tokio::io::split(writer); - let writer_ref = spawn_socket_writer(writer_half); + let (writer_ref, _exit_rx) = spawn_socket_writer(writer_half); let fix_message = b"8=FIX.4.4\x019=77\x0135=A\x0134=1\x0149=sender\x0152=20230908-08:24:56.574\x0156=target\x0198=0\x01108=30\x01141=Y\x0110=037\x01"; let raw_message = RawFixMessage::new(fix_message.to_vec()); @@ -84,7 +102,7 @@ mod tests { async fn test_send_multiple_messages() { let (reader, writer) = duplex(2048); let (_reader_half, writer_half) = tokio::io::split(writer); - let writer_ref = spawn_socket_writer(writer_half); + let (writer_ref, _exit_rx) = spawn_socket_writer(writer_half); let msg1 = b"8=FIX.4.4\x019=77\x0135=A\x0134=1\x0149=sender\x0152=20230908-08:24:56.574\x0156=target\x0198=0\x01108=30\x01141=Y\x0110=037\x01"; let msg2 = b"8=FIX.4.4\x019=77\x0135=A\x0134=2\x0149=sender\x0152=20230908-08:24:58.574\x0156=target\x0198=0\x01108=30\x01141=Y\x0110=040\x01"; @@ -125,7 +143,7 @@ mod tests { async fn test_disconnect() { let (reader, writer) = duplex(1024); let (_reader_half, writer_half) = tokio::io::split(writer); - let writer_ref = spawn_socket_writer(writer_half); + let (writer_ref, _exit_rx) = spawn_socket_writer(writer_half); // send a message first let fix_message = b"8=FIX.4.4\x019=77\x0135=A\x0134=1\x0149=sender\x0152=20230908-08:24:56.574\x0156=target\x0198=0\x01108=30\x01141=Y\x0110=037\x01"; @@ -155,7 +173,7 @@ mod tests { async fn test_send_empty_message() { let (reader, writer) = duplex(1024); let (_reader_half, writer_half) = tokio::io::split(writer); - let writer_ref = spawn_socket_writer(writer_half); + let (writer_ref, _exit_rx) = spawn_socket_writer(writer_half); let empty_message = RawFixMessage::new(vec![]); writer_ref.send_raw_message(empty_message).await; @@ -187,7 +205,7 @@ mod tests { async fn test_writer_shutdown_on_mailbox_close() { let (_reader, writer) = duplex(1024); let (_reader_half, writer_half) = tokio::io::split(writer); - let writer_ref = spawn_socket_writer(writer_half); + let (writer_ref, _exit_rx) = spawn_socket_writer(writer_half); // send a message to ensure the writer is running let fix_message = b"8=FIX.4.4\x019=77\x0135=A\x0134=1\x0149=sender\x0152=20230908-08:24:56.574\x0156=target\x0198=0\x01108=30\x01141=Y\x0110=037\x01"; @@ -210,7 +228,7 @@ mod tests { async fn test_write_error_handling() { let (reader, writer) = duplex(1024); let (_reader_half, writer_half) = tokio::io::split(writer); - let writer_ref = spawn_socket_writer(writer_half); + let (writer_ref, _exit_rx) = spawn_socket_writer(writer_half); // close the reader end, which should cause write errors drop(reader); @@ -231,4 +249,94 @@ mod tests { // and continued running (as per the code comment that it only shuts down // when explicitly requested) } + + /// After processing Disconnect, the actor calls shutdown() on its WriteHalf, + /// which for a duplex stream surfaces as EOF on the peer read side. + #[tokio::test] + async fn shutdown_called_on_disconnect() { + let (reader, writer) = duplex(1024); + let (_reader_half, writer_half) = tokio::io::split(writer); + let (writer_ref, exit_rx) = spawn_socket_writer(writer_half); + + writer_ref.disconnect().await; + + tokio::time::timeout(tokio::time::Duration::from_millis(200), exit_rx) + .await + .expect("exit signal not fired within timeout") + .expect("exit sender dropped without signalling"); + + // Peer side of the duplex should observe EOF after shutdown. + let mut reader = reader; + let mut buf = vec![0u8; 16]; + let n = tokio::time::timeout( + tokio::time::Duration::from_millis(200), + reader.read(&mut buf), + ) + .await + .expect("read timed out — shutdown did not surface as EOF") + .expect("read failed"); + assert_eq!(n, 0, "expected EOF after writer shutdown, read {n} bytes"); + } + + /// Fallback exit path: all WriterRef clones dropped without sending Disconnect. + /// The actor's mailbox closes, the loop exits, shutdown() runs, and exit fires. + #[tokio::test] + async fn exit_signal_fires_when_all_senders_dropped() { + let (_reader, writer) = duplex(1024); + let (_reader_half, writer_half) = tokio::io::split(writer); + let (writer_ref, exit_rx) = spawn_socket_writer(writer_half); + + drop(writer_ref); + + tokio::time::timeout(tokio::time::Duration::from_millis(200), exit_rx) + .await + .expect("exit signal not fired within timeout") + .expect("exit sender dropped without signalling"); + } + + use std::pin::Pin; + use std::task::{Context, Poll}; + use tokio::io::AsyncWrite; + + /// `AsyncWrite` where `poll_write` succeeds but `poll_shutdown` hangs forever. + struct StuckShutdownWriter; + + impl AsyncWrite for StuckShutdownWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } + } + + /// If shutdown() never resolves, the writer still exits after WRITER_SHUTDOWN_TIMEOUT. + /// Virtual time via `start_paused = true` keeps the test fast. + #[tokio::test(start_paused = true)] + async fn shutdown_timeout_does_not_block_exit() { + // Build a split pair around StuckShutdownWriter. It only implements AsyncWrite; + // we wrap with `tokio::io::join` to supply a dummy AsyncRead. + let stuck = tokio::io::join(tokio::io::empty(), StuckShutdownWriter); + let (_read_half, write_half) = tokio::io::split(stuck); + let (writer_ref, exit_rx) = spawn_socket_writer(write_half); + + writer_ref.disconnect().await; + + // Advance virtual time past the shutdown timeout. + tokio::time::advance(WRITER_SHUTDOWN_TIMEOUT + std::time::Duration::from_millis(100)).await; + + // Exit should have fired by now. + exit_rx + .await + .expect("exit sender dropped without signalling"); + } } diff --git a/crates/hotfix/tests/session_test_cases/common/fakes/fake_counterparty.rs b/crates/hotfix/tests/session_test_cases/common/fakes/fake_counterparty.rs index 219d7a1..28314ae 100644 --- a/crates/hotfix/tests/session_test_cases/common/fakes/fake_counterparty.rs +++ b/crates/hotfix/tests/session_test_cases/common/fakes/fake_counterparty.rs @@ -37,7 +37,8 @@ where ) -> Result { let (writer_ref, receiver) = Self::create_writer(); let (reader_ref, dc_sender) = Self::create_reader(); - let connection = FixConnection::new(writer_ref, reader_ref); + let (_writer_exit_tx, writer_exit_rx) = oneshot::channel(); + let connection = FixConnection::new(writer_ref, reader_ref, writer_exit_rx); let message_config = MessageConfig::default(); let message_builder = MessageBuilder::new(Dictionary::fix44(), message_config)?; @@ -61,7 +62,8 @@ where pub async fn reconnect(&mut self, reset_store: bool) -> Result<()> { let (writer_ref, receiver) = Self::create_writer(); let (reader_ref, dc_sender) = Self::create_reader(); - let connection = FixConnection::new(writer_ref, reader_ref); + let (_writer_exit_tx, writer_exit_rx) = oneshot::channel(); + let connection = FixConnection::new(writer_ref, reader_ref, writer_exit_rx); self.receiver = receiver; self._dc_sender = dc_sender; @@ -250,6 +252,7 @@ where fn create_reader() -> (ReaderRef, oneshot::Sender<()>) { let (dc_sender, dc_receiver) = oneshot::channel(); - (ReaderRef::new(dc_receiver), dc_sender) + let (kill_sender, _kill_receiver) = oneshot::channel(); + (ReaderRef::new(dc_receiver, kill_sender), dc_sender) } }