Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion plugboard-schemas/plugboard_schemas/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from ._validator_registry import validator


_SYSTEM_STOP_EVENT = "system_stop"


def _build_component_graph(
connectors: dict[str, dict[str, _t.Any]],
) -> dict[str, set[str]]:
Expand Down Expand Up @@ -98,9 +101,11 @@ def validate_all_inputs_connected(
for comp_name, comp_data in components.items():
io = comp_data.get("io", {})
all_inputs = set(io.get("inputs", []))
input_events = set(io.get("input_events", []))
has_non_system_input_events = bool(input_events - {_SYSTEM_STOP_EVENT})
connected = connected_inputs.get(comp_name, set())
unconnected = all_inputs - connected
if unconnected:
if unconnected and not has_non_system_input_events:
errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}")
return errors

Expand Down
16 changes: 12 additions & 4 deletions plugboard/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ async def _wrapper() -> None:
raise e
self._bind_outputs()
await self.io.write()
self._field_inputs_ready = False
self._reset_input_trackers()
await self._set_status(Status.WAITING, publish=not self._is_running)

return _wrapper
Expand All @@ -365,6 +365,11 @@ async def _wrapper() -> None:
def _has_field_inputs(self) -> bool:
return len(self.io.inputs) > 0

@property
def _has_connected_field_inputs(self) -> bool:
    """Whether any declared field inputs are connected via input channels.

    Delegates to the `IOController`'s `has_connected_field_inputs` property.
    """
    return self.io.has_connected_field_inputs

@cached_property
def _has_event_inputs(self) -> bool:
input_events = set([evt.safe_type() for evt in self.io.input_events])
Expand Down Expand Up @@ -409,7 +414,7 @@ async def _io_read_with_status_check(self) -> None:
task.cancel()
for task in done:
exc = task.exception()
if isinstance(exc, EventStreamClosedError) and len(self.io.inputs) == 0:
if isinstance(exc, EventStreamClosedError) and not self._has_connected_field_inputs:
await self.io.close() # Call close for final wait and flush event buffer
elif exc is not None:
raise exc
Expand All @@ -422,7 +427,7 @@ async def _periodic_status_check(self) -> None:
# TODO : Eventually producer graph update will be event driven. For now,
# : the update is performed periodically, so it's called here along
# : with the status check.
if len(self.io.inputs) == 0:
if not self._has_connected_field_inputs:
await self._update_producer_graph()

async def _status_check(self) -> None:
Expand Down Expand Up @@ -455,8 +460,11 @@ def _bind_inputs(self) -> None:
for field in self.io.inputs:
field_default = getattr(self, field, None)
value = self._field_inputs.get(field, field_default)
setattr(self, field, value)
super().__setattr__(field, value)

def _reset_input_trackers(self) -> None:
    """Clears buffered field-input values and resets the inputs-ready flag.

    Called after a step completes so stale field values are not reused on the
    next step.
    """
    self._field_inputs = {}
    self._field_inputs_ready = False

def _bind_outputs(self) -> None:
"""Binds component fields to output fields."""
Expand Down
11 changes: 6 additions & 5 deletions plugboard/component/io_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ def is_closed(self) -> bool:
"""Returns `True` if the `IOController` is closed, `False` otherwise."""
return self._is_closed

@cached_property
def _has_field_inputs(self) -> bool:
@property
def has_connected_field_inputs(self) -> bool:
    """Returns `True` when at least one field input is connected via a channel."""
    # An empty channel mapping is falsy, so plain truthiness suffices here.
    return bool(self._input_channels)

@cached_property
Expand All @@ -96,7 +97,7 @@ def _has_event_inputs(self) -> bool:

@cached_property
def _has_inputs(self) -> bool:
return self._has_field_inputs or self._has_event_inputs
return self.has_connected_field_inputs or self._has_event_inputs

async def read(self, timeout: float | None = None) -> None:
"""Reads data and/or events from input channels.
Expand Down Expand Up @@ -139,7 +140,7 @@ async def read(self, timeout: float | None = None) -> None:

def _set_read_tasks(self) -> list[asyncio.Task]:
read_tasks: list[asyncio.Task] = []
if self._has_field_inputs:
if self.has_connected_field_inputs:
if _fields_read_task not in self._read_tasks:
read_fields_task = asyncio.create_task(self._read_fields(), name=_fields_read_task)
self._read_tasks[_fields_read_task] = read_fields_task
Expand Down Expand Up @@ -374,7 +375,7 @@ def _add_channel_for_event(

def _create_input_field_group_tasks(self) -> None:
"""Groups input field channels by field name and launches read tasks for group inputs."""
if not self._has_field_inputs:
if not self.has_connected_field_inputs:
return
field_channels: dict[str, list[tuple[_t_field_key, Channel]]] = defaultdict(list)
for key, chan in self._input_channels.items():
Expand Down
32 changes: 27 additions & 5 deletions plugboard/library/data_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
**kwargs: Additional keyword arguments for [`Component`][plugboard.component.Component].
"""
super().__init__(**kwargs)
# Use a single buffer to track everything
self._buffer: dict[str, deque] = defaultdict(deque)
self._chunk_size = chunk_size
self.io = IOController(
Expand Down Expand Up @@ -76,18 +77,39 @@ async def _convert(self, data: dict[str, deque]) -> _t.Any:
def _bind_inputs(self) -> None:
"""Binds input fields to component fields and append to internal buffer."""
super()._bind_inputs()
for field in self.io.inputs:
for field in self._field_inputs:
value = getattr(self, field, None)
self._buffer[field].append(value)

@property
def _completed_rows(self) -> int:
    """Number of fully formed rows currently held in the buffer.

    A row is complete only when every declared input field has a buffered
    value, so the count is the minimum buffered length across all fields.
    `min(..., default=0)` already returns 0 when there are no declared
    inputs, so the previous explicit empty-inputs guard was redundant.
    """
    return min((len(self._buffer[f]) for f in self.io.inputs), default=0)

@property
def _can_step(self) -> bool:
    """Stepping is possible once at least one complete row is buffered."""
    # `_completed_rows` is a non-negative count, so truthiness equals `> 0`.
    return bool(self._completed_rows)

async def _save_chunk(self) -> None:
"""Write data from the buffer."""
"""Write completed data rows from the buffer."""
completed_rows = self._completed_rows
if completed_rows == 0:
return

if self._task is not None:
await self._task
# Create task to save next chunk of data
chunk = await self._convert(self._buffer)

# Extract only the completed rows into a new chunk
chunk_data = {
field: deque([self._buffer[field].popleft() for _ in range(completed_rows)])
for field in self.io.inputs
}

chunk = await self._convert(chunk_data)
self._task = asyncio.create_task(self._save(chunk))
self._buffer = defaultdict(deque)

async def step(self) -> None:
"""Trigger save when buffer is at target size."""
Expand Down
179 changes: 179 additions & 0 deletions tests/integration/test_process_with_components_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from plugboard.events import Event
from plugboard.exceptions import ConstraintError, NotInitialisedError, ProcessStatusError
from plugboard.library import FileWriter
from plugboard.process import LocalProcess, Process, RayProcess
from plugboard.schemas import ConnectorSpec, Status
from tests.conftest import ComponentTestHelper, zmq_connector_cls
Expand Down Expand Up @@ -459,6 +460,99 @@ async def test_event_driven_process_shutdown(
await process.destroy()


class MessageEventData(BaseModel):
    """Data for a message event."""

    # Message text carried by the event (generators emit strings like "Message 0").
    message: str


class MessageEvent(Event):
    """Event carrying a file-writer message."""

    # Event type name; presumably the key used to register/route this event
    # class — NOTE(review): confirm against `Event`'s contract.
    type: _t.ClassVar[str] = "message_event"
    # Structured payload holding the message text.
    data: MessageEventData


class MessageEventGenerator(ComponentTestHelper):
    """Produces a fixed number of message events.

    Emits one `MessageEvent` per step with messages following the arithmetic
    sequence `start, start + stride, ...` for `iters` steps, then closes its
    IO. An optional per-step `delay` simulates staggered event arrival.
    """

    io = IO(output_events=[MessageEvent])

    def __init__(
        self,
        iters: int,
        *args: _t.Any,
        delay: float = 0.0,
        start: int = 0,
        stride: int = 1,
        **kwargs: _t.Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self._iters = iters
        self._delay = delay
        self._start = start
        self._stride = stride

    async def init(self) -> None:
        await super().init()
        # Iterator over the message indices this generator will emit.
        stop = self._start + self._iters * self._stride
        self._seq = iter(range(self._start, stop, self._stride))

    async def step(self) -> None:
        # Optional delay to simulate staggered event arrival.
        if self._delay > 0.0:
            await asyncio.sleep(self._delay)
        index = next(self._seq, None)
        if index is None:
            # Sequence exhausted: close IO instead of emitting another event.
            await self.io.close()
            return
        self.io.queue_event(
            MessageEvent(
                source=self.name,
                data=MessageEventData(message=f"Message {index}"),
            )
        )
        await super().step()


class EventReaderFileWriter(FileWriter):
    """`FileWriter` variant that adds event handling instead of a connector for `message`."""

    io = IO(input_events=[MessageEvent])

    @MessageEvent.handler
    async def handle_message(self, event: MessageEvent) -> None:
        """Stores the event's message text on the component's `message` field."""
        self.message = event.data.message


@pytest.mark.asyncio
async def test_event_driven_file_writer_reuse(tmp_path: Path) -> None:
    """Test that field-input components can be reused in event-driven processes.

    The writer's `message` field has no connector; it is populated solely via
    `MessageEvent`s, and the resulting rows must still be written to file.
    """
    output_path = tmp_path / "output_messages.csv"
    components = [
        MessageEventGenerator(iters=3, name="message_event_generator"),
        EventReaderFileWriter(
            path=output_path,
            name="event_reader_file_writer",
            field_names=["message"],
        ),
    ]
    event_connectors = AsyncioConnector.builder().build_event_connectors(components)
    process = LocalProcess(components=components, connectors=event_connectors)

    await process.init()
    try:
        await process.run()

        assert process.status == Status.COMPLETED
        assert output_path.read_text().splitlines() == [
            "message",
            "Message 0",
            "Message 1",
            "Message 2",
        ]
    finally:
        # Tear the process down even when an assertion above fails, so a
        # failing test does not leak process resources (previously `destroy`
        # was only reached on success).
        await process.destroy()


_SHORT_TIMEOUT = 0.1


Expand Down Expand Up @@ -536,3 +630,88 @@ async def test_constraint_error_stops_background_status_check() -> None:
)

await process.destroy()


class StaggeredEventFileWriter(FileWriter):
    """`FileWriter` variant fed `mg1`/`mg2`/`mg3` values via events from three sources.

    Records the step index at which each message arrived so tests can verify
    that messages from different generators were delivered in different steps.
    """

    io = IO(input_events=[MessageEvent])

    def __init__(self, *args: _t.Any, field_names: list[str], **kwargs: _t.Any) -> None:
        super().__init__(*args, field_names=field_names, **kwargs)
        self.step_count: int = 0
        self.step_for_message: dict[str, int] = {}

    @MessageEvent.handler
    async def handle_message(self, event: MessageEvent) -> None:
        """Stores the message under the field named after its source and logs the step."""
        message = event.data.message
        if event.source not in ("mg1", "mg2", "mg3"):
            raise ValueError(f"Unexpected event source: {event.source}")
        # The event source names match the writer's field names directly.
        setattr(self, event.source, message)
        self.step_for_message[message] = self.step_count
        self.step_count += 1


@pytest.mark.asyncio
@pytest_cases.parametrize(
    "process_cls, connector_cls",
    [
        (LocalProcess, AsyncioConnector),
    ],
)
async def test_data_writer_handles_staggered_input_events(
    process_cls: type[Process], connector_cls: type[Connector], tmp_path: Path, ray_ctx: None
) -> None:
    """Test that a FileWriter can handle input events arriving in different steps.

    Input messages with data for different fields may arrive in different steps. The FileWriter
    should write out a new row only when all required fields have received data, and should not
    overwrite field values if only a subset of fields receive new data in a step.
    """
    output_path = tmp_path / "staggered_output_messages.csv"

    writer = StaggeredEventFileWriter(
        path=output_path, field_names=["mg1", "mg2", "mg3"], name="writer"
    )
    components = [
        # 3 inputs with different delays to force staggered event arrival
        MessageEventGenerator(iters=10, delay=0.005, start=0, stride=3, name="mg1"),
        MessageEventGenerator(iters=10, delay=0.010, start=1, stride=3, name="mg2"),
        MessageEventGenerator(iters=10, delay=0.020, start=2, stride=3, name="mg3"),
        writer,
    ]

    async with process_cls(
        components=components,
        # Use the parametrized connector class — previously hard-coded to
        # AsyncioConnector, which silently ignored `connector_cls`.
        connectors=connector_cls.builder().build_event_connectors(components),
    ) as process:
        await process.run()

    content = output_path.read_text().splitlines()

    # header + 10 rows of data; row r holds messages 3r, 3r+1, 3r+2
    assert len(content) == 11
    assert content[0] == "mg1,mg2,mg3"
    expected_rows = [
        ",".join(f"Message {3 * row + col}" for col in range(3)) for row in range(10)
    ]
    assert content[1:] == expected_rows

    # Verify that messages from different generators were received in different steps
    assert writer.step_count == 30
    assert len(writer.step_for_message) == 30
    assert len(set(writer.step_for_message.values())) == 30, (
        "Expected each message to be received in a different step"
    )
15 changes: 15 additions & 0 deletions tests/unit/test_process_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,21 @@ def test_no_inputs_no_errors(self) -> None:
errors = validate_all_inputs_connected(pd)
assert errors == []

def test_missing_inputs_allowed_for_event_driven_component_reuse(self) -> None:
    """Unconnected inputs are allowed when non-system input events can populate them."""
    producer = _make_component("producer", output_events=["message_event"])
    # `message` has no connector, but a non-system input event may supply it.
    writer = _make_component(
        "writer",
        inputs=["message"],
        input_events=["system_stop", "message_event"],
    )
    pd = _make_process_dict(components={"producer": producer, "writer": writer})
    assert validate_all_inputs_connected(pd) == []


# ---------------------------------------------------------------------------
# Tests for validate_input_events
Expand Down
Loading