Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 140 additions & 3 deletions crates/hotfix/src/transport/connection.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,157 @@
use std::time::Duration;

use tokio::sync::oneshot;
use tracing::warn;

use crate::transport::reader::ReaderRef;
use crate::transport::writer::WriterRef;

/// How long `FixConnection::run_until_disconnect` waits for the reader's
/// disconnect signal after the writer-exit receiver has resolved, before it
/// logs a warning and sends the reader's kill signal.
const FORCE_CLOSE_TIMEOUT: Duration = Duration::from_secs(10);

/// A transport connection bundling a writer handle, a reader handle, and a
/// one-shot receiver that `run_until_disconnect` uses to notice writer exit.
pub struct FixConnection {
// Writer handle; `get_writer` hands out clones of it.
writer: WriterRef,
// Reader handle; `run_until_disconnect` destructures it into its
// `disconnect_signal` receiver and `kill` sender.
reader: ReaderRef,
// One-shot receiver raced against the reader's disconnect signal.
// Presumably resolved by the writer task when it exits — confirm against
// `spawn_socket_writer`, which is not shown here.
writer_exit: oneshot::Receiver<()>,
}

impl FixConnection {
// NOTE(review): this text comes from a rendered diff; the next two lines look
// like the removed two-argument constructor, superseded by the three-argument
// `new` that follows — confirm against the real file.
pub fn new(writer: WriterRef, reader: ReaderRef) -> Self {
Self { writer, reader }
/// Builds a connection from a writer handle, a reader handle, and the
/// one-shot receiver that `run_until_disconnect` races against the reader's
/// disconnect signal.
pub fn new(writer: WriterRef, reader: ReaderRef, writer_exit: oneshot::Receiver<()>) -> Self {
Self {
writer,
reader,
writer_exit,
}
}

/// Returns a clone of the writer handle.
pub fn get_writer(&self) -> WriterRef {
self.writer.clone()
}

/// Consumes the connection and returns once the reader's disconnect signal
/// resolves.
///
/// If `writer_exit` resolves first, the reader gets `FORCE_CLOSE_TIMEOUT` to
/// signal disconnect on its own; after that the kill signal is sent and the
/// disconnect signal is awaited again.
pub async fn run_until_disconnect(self) {
// NOTE(review): the next line looks like the removed pre-change body from
// the diff view (it has no trailing `;`) — confirm against the real file.
self.reader.wait_for_disconnect().await
// Take the parts out of `self`; the writer handle is dropped via `..`.
let Self {
reader,
mut writer_exit,
..
} = self;
let ReaderRef {
mut disconnect_signal,
kill,
} = reader;

// Both arms use a `_` pattern, so each completes on either `Ok` or `Err`
// from its one-shot receiver (i.e. also when the sender was dropped).
tokio::select! {
// Reader finished first: nothing left to force, return right away.
_ = &mut disconnect_signal => return,
// Writer-exit receiver resolved first: fall through to the watchdog.
_ = &mut writer_exit => {}
}

// Watchdog: bounded wait for the reader to signal disconnect by itself.
match tokio::time::timeout(FORCE_CLOSE_TIMEOUT, &mut disconnect_signal).await {
Ok(_) => {}
Err(_) => {
warn!(
"reader did not observe EOF within {:?}, forcing close",
FORCE_CLOSE_TIMEOUT
);
// Send failure is ignored (it means the kill receiver is already gone).
let _ = kill.send(());
// NOTE(review): this final wait has no timeout; it completes only
// when the disconnect sender fires or is dropped.
let _ = disconnect_signal.await;
}
}
}
}

// Unit tests for `FixConnection::run_until_disconnect`, driven purely through
// one-shot channels; every test starts with the tokio clock paused.
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::writer::WriterMessage;
use tokio::sync::mpsc;

/// Build a `FixConnection` and return the ends the test controls:
/// dc_sender to fire from the "reader", writer_exit_tx to fire from the "writer",
/// and kill_rx so the test can observe or simulate the reader being killed.
fn test_connection() -> (
FixConnection,
oneshot::Sender<()>,
oneshot::Sender<()>,
oneshot::Receiver<()>,
) {
let (dc_tx, dc_rx) = oneshot::channel::<()>();
let (kill_tx, kill_rx) = oneshot::channel::<()>();
let reader_ref = ReaderRef::new(dc_rx, kill_tx);

// The mpsc receiver is not returned, so it is dropped when this helper
// returns; these tests never send through the writer handle.
let (writer_mpsc_tx, _writer_mpsc_rx) = mpsc::channel::<WriterMessage>(1);
let writer_ref = WriterRef::new(writer_mpsc_tx);

let (writer_exit_tx, writer_exit_rx) = oneshot::channel::<()>();

let conn = FixConnection::new(writer_ref, reader_ref, writer_exit_rx);
(conn, dc_tx, writer_exit_tx, kill_rx)
}

/// Reader signals disconnect first — return immediately, kill is never sent.
#[tokio::test(start_paused = true)]
async fn returns_on_reader_disconnect_before_writer_exit() {
// `_writer_exit_tx` is a named binding (not `_`), so the sender stays
// alive for the whole test and the writer-exit receiver stays pending.
let (conn, dc_tx, _writer_exit_tx, mut kill_rx) = test_connection();

dc_tx.send(()).expect("dc receiver dropped");

conn.run_until_disconnect().await;

// Kill should not have been sent. The sender has been dropped by now
// (scope ended inside run_until_disconnect), so try_recv returns Closed
// rather than Empty. Either way, an Ok(()) would mean kill was sent.
assert!(
!matches!(kill_rx.try_recv(), Ok(())),
"kill signal should not have been sent"
);
}

/// Writer exits first, reader disconnects within the watchdog window — no kill.
#[tokio::test(start_paused = true)]
async fn returns_when_reader_disconnects_after_writer_exit_within_timeout() {
let (conn, dc_tx, writer_exit_tx, mut kill_rx) = test_connection();

writer_exit_tx
.send(())
.expect("writer_exit receiver dropped");

// Fire the reader disconnect from a task that runs on the same paused clock.
// The 1s sleep is shorter than FORCE_CLOSE_TIMEOUT (10s), so the disconnect
// is expected to land before the watchdog timer.
tokio::spawn(async move {
tokio::time::sleep(Duration::from_secs(1)).await;
let _ = dc_tx.send(());
});

conn.run_until_disconnect().await;

assert!(
!matches!(kill_rx.try_recv(), Ok(())),
"kill signal should not have been sent when reader disconnected within timeout"
);
}

/// Writer exits first, reader stays blocked past the watchdog — kill fires,
/// and a simulated reader fires dc once it sees the kill.
#[tokio::test(start_paused = true)]
async fn watchdog_fires_kill_when_reader_stuck() {
let (conn, dc_tx, writer_exit_tx, kill_rx) = test_connection();

writer_exit_tx
.send(())
.expect("writer_exit receiver dropped");

// Stand in for the reader: when the watchdog kills us, we publish dc.
// If the kill sender is dropped instead, `dc_tx` is dropped without sending.
tokio::spawn(async move {
if kill_rx.await.is_ok() {
let _ = dc_tx.send(());
}
});

// Measured on tokio's clock, which this test starts paused; presumably the
// runtime auto-advances it to the watchdog deadline — see tokio's docs on
// paused time.
let start = tokio::time::Instant::now();
conn.run_until_disconnect().await;
let elapsed = start.elapsed();

// dc is only published after the kill, so returning implies the full
// watchdog window elapsed first.
assert!(
elapsed >= FORCE_CLOSE_TIMEOUT,
"expected watchdog to take at least {:?}, took {:?}",
FORCE_CLOSE_TIMEOUT,
elapsed
);
}
}
10 changes: 7 additions & 3 deletions crates/hotfix/src/transport/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ use tracing::warn;
pub struct ReaderMessage;

/// Handle to a reader: a one-shot receiver for its disconnect signal plus a
/// one-shot sender used to ask it to stop.
pub struct ReaderRef {
// NOTE(review): this text comes from a rendered diff; the next line looks
// like the removed private field, superseded by the `pub(crate)` version
// below — confirm against the real file.
disconnect_signal: oneshot::Receiver<()>,
// Receiver for the reader's disconnect signal; crate-visible so that
// `FixConnection::run_until_disconnect` can destructure and await it.
pub(crate) disconnect_signal: oneshot::Receiver<()>,
// Sender for the kill signal; `run_until_disconnect` fires it when its
// watchdog timeout expires.
pub(crate) kill: oneshot::Sender<()>,
}

impl ReaderRef {
pub fn new(disconnect_signal: oneshot::Receiver<()>) -> Self {
Self { disconnect_signal }
pub fn new(disconnect_signal: oneshot::Receiver<()>, kill: oneshot::Sender<()>) -> Self {
Self {
disconnect_signal,
kill,
}
}

pub async fn wait_for_disconnect(self) {
Expand Down
4 changes: 2 additions & 2 deletions crates/hotfix/src/transport/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ where
{
let (reader, writer) = tokio::io::split(stream);

let writer_ref = spawn_socket_writer(writer);
let (writer_ref, writer_exit) = spawn_socket_writer(writer);
let reader_ref = spawn_socket_reader(reader, session_ref);

FixConnection::new(writer_ref, reader_ref)
FixConnection::new(writer_ref, reader_ref, writer_exit)
}
103 changes: 77 additions & 26 deletions crates/hotfix/src/transport/socket/socket_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ pub fn spawn_socket_reader(
session_ref: InternalSessionRef<impl OutboundMessage>,
) -> ReaderRef {
let (dc_sender, dc_receiver) = oneshot::channel();
let (kill_sender, kill_receiver) = oneshot::channel();
let actor = ReaderActor::new(reader, session_ref, dc_sender);
tokio::spawn(run_reader(actor));
tokio::spawn(run_reader(actor, kill_receiver));

ReaderRef::new(dc_receiver)
ReaderRef::new(dc_receiver, kill_sender)
}

struct ReaderActor<M, R> {
Expand All @@ -38,40 +39,52 @@ impl<M, R: AsyncRead> ReaderActor<M, R> {
}
}

async fn run_reader<Outbound, R>(mut actor: ReaderActor<Outbound, R>)
where
async fn run_reader<Outbound, R>(
mut actor: ReaderActor<Outbound, R>,
mut kill_rx: oneshot::Receiver<()>,
) where
Outbound: OutboundMessage,
R: AsyncRead,
{
let mut parser = Parser::default();
loop {
let mut buf = vec![];

match actor.reader.read_buf(&mut buf).await {
Ok(0) => {
let _ = actor
.session_ref
.disconnect("received EOF".to_string())
.await;
break;
}
Err(err) => {
let _ = actor.session_ref.disconnect(err.to_string()).await;
break;
}
Ok(_) => {
let messages = parser.parse(&buf);

for msg in messages {
if actor
tokio::select! {
result = actor.reader.read_buf(&mut buf) => match result {
Ok(0) => {
let _ = actor
.session_ref
.new_fix_message_received(msg)
.await
.is_err()
{
debug!("reader received message but session has been terminated");
.disconnect("received EOF".to_string())
.await;
break;
}
Err(err) => {
let _ = actor.session_ref.disconnect(err.to_string()).await;
break;
}
Ok(_) => {
let messages = parser.parse(&buf);

for msg in messages {
if actor
.session_ref
.new_fix_message_received(msg)
.await
.is_err()
{
debug!("reader received message but session has been terminated");
}
}
}
},
res = &mut kill_rx => {
let reason = match res {
Ok(()) => "forced close by watchdog",
Err(_) => "reader handle dropped",
};
let _ = actor.session_ref.disconnect(reason.to_string()).await;
break;
}
}
}
Expand Down Expand Up @@ -228,4 +241,42 @@ mod tests {
// wait for disconnect signal
let _ = reader_ref.wait_for_disconnect().await;
}

/// Kill signal terminates the reader even when the peer is silent, and
/// the session observes the watchdog-sourced disconnect reason.
#[tokio::test]
async fn kill_signal_terminates_reader() {
// `_writer` and `_writer_half` are named bindings, so both stay alive for
// the whole test; presumably this keeps the reader from seeing EOF so that
// only the kill signal can end it — confirm against tokio's `duplex` docs.
let (_writer, reader) = duplex(1024);
let (reader_half, _writer_half) = tokio::io::split(reader);

let (session_ref, mut event_receiver) = create_test_session_ref();
let reader_ref = spawn_socket_reader(reader_half, session_ref);

// Destructure so we can both fire the kill and later await the disconnect signal.
let ReaderRef {
disconnect_signal,
kill,
} = reader_ref;

kill.send(()).expect("kill receiver dropped");

// Reader should publish the watchdog reason to the session.
// The 100ms timeout bounds the wait so a missing event fails the test
// instead of hanging it.
match tokio::time::timeout(
tokio::time::Duration::from_millis(100),
event_receiver.recv(),
)
.await
{
Ok(Some(SessionEvent::Disconnected(reason))) => {
assert_eq!(reason, "forced close by watchdog");
}
other => panic!("expected Disconnected(\"forced close by watchdog\"), got {other:?}"),
}

// And the disconnect signal should fire shortly after.
// First `expect` covers the timeout, second covers a dropped sender.
tokio::time::timeout(tokio::time::Duration::from_millis(100), disconnect_signal)
.await
.expect("disconnect signal not fired within timeout")
.expect("disconnect sender dropped without signalling");
}
}
Loading
Loading