From a16f15132d2253336e4d0077b186c83265949990 Mon Sep 17 00:00:00 2001 From: Eliaazzz Date: Tue, 26 May 2026 00:42:48 +1000 Subject: [PATCH 1/4] feat: add minimal self-contained UnboundedSource PoC for Python SDK Add UnboundedSource, UnboundedReader, and CheckpointMark ABCs (Java semantics, Python names) plus a Splittable-DoFn wrapper so that `p | ReadFromUnboundedSource(MySource())` runs an unbounded source on the portable Fn API DirectRunner. Mirrors the in-tree periodicsequence SDF (`@DoFn.unbounded_per_element` + `restriction_tracker.defer_remainder` + ManualWatermarkEstimator) and touches no core files (no iobase.py or sdf_utils.py edits). Covers read, event-time timestamps, monotonic watermark, checkpoint-based pause/resume, bundle finalization, and the MAX-watermark done transition. Out of scope (later weeks): record-id dedup, backlog reporting, dynamic split fractions, and iobase.Read.expand() wiring. 14 deterministic tests pass, including an end-to-end DirectRunner read. GSoC 2026 Week-1 deliverable for issue #19137. --- .../python/apache_beam/io/unbounded_source.py | 489 ++++++++++++++++++ .../apache_beam/io/unbounded_source_test.py | 368 +++++++++++++ 2 files changed, 857 insertions(+) create mode 100644 sdks/python/apache_beam/io/unbounded_source.py create mode 100644 sdks/python/apache_beam/io/unbounded_source_test.py diff --git a/sdks/python/apache_beam/io/unbounded_source.py b/sdks/python/apache_beam/io/unbounded_source.py new file mode 100644 index 000000000000..f2de28f9d573 --- /dev/null +++ b/sdks/python/apache_beam/io/unbounded_source.py @@ -0,0 +1,489 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A minimal, self-contained ``UnboundedSource`` for the Python SDK. + +This module is a Week-1 proof-of-concept for GSoC 2026 (issue #19137). It brings +the Java ``UnboundedSource`` abstractions to Python and makes them *runnable* on +the portable Fn API path (e.g. the default DirectRunner) by dispatching reads +through a Splittable ``DoFn`` -- mirroring Java's +``Read.UnboundedSourceAsSDFWrapperFn``. + +Public API:: + + from apache_beam.io.unbounded_source import ( + UnboundedSource, UnboundedReader, CheckpointMark, ReadFromUnboundedSource) + + class MySource(UnboundedSource): + ... + + with beam.Pipeline() as p: + p | ReadFromUnboundedSource(MySource()) | beam.Map(print) + +Scope (deliberately minimal): read loop, event-time timestamps, monotonic +watermark reporting, checkpoint-based pause/resume (``defer_remainder``) and +bundle finalization. Out of scope for this PoC: record-id deduplication, +backlog-byte reporting, dynamic split fractions, and wiring into +``iobase.Read.expand()`` (callers use ``ReadFromUnboundedSource`` directly). The +design mirrors the in-tree streaming SDF template +``apache_beam.transforms.periodicsequence``. 
+""" + +# pytype: skip-file + +import dataclasses +import logging +from typing import Any +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + +from apache_beam import coders +from apache_beam.coders.coders import BooleanCoder +from apache_beam.coders.coders import Coder +from apache_beam.coders.coders import NullableCoder +from apache_beam.coders.coders import TimestampCoder +from apache_beam.coders.coders import TupleCoder +from apache_beam.coders.coders import _MemoizingPickleCoder +from apache_beam.io import iobase +from apache_beam.io.watermark_estimators import ManualWatermarkEstimator +from apache_beam.runners import sdf_utils +from apache_beam.transforms import Impulse +from apache_beam.transforms import PTransform +from apache_beam.transforms import core +from apache_beam.transforms.window import TimestampedValue +from apache_beam.utils.timestamp import Duration +from apache_beam.utils.timestamp import MAX_TIMESTAMP +from apache_beam.utils.timestamp import MIN_TIMESTAMP +from apache_beam.utils.timestamp import Timestamp + +__all__ = [ + 'CheckpointMark', + 'UnboundedReader', + 'UnboundedSource', + 'ReadFromUnboundedSource', +] + +_LOGGER = logging.getLogger(__name__) + +# Placed in the ``try_claim`` holder when the reader has no data *right now*. +# This is NOT end-of-stream -- an unbounded reader may produce more later. +_NO_DATA = object() + +_DEFAULT_POLL_INTERVAL_SECONDS = 1.0 + +# ------------------------------------------------------------------------------ +# Public abstract base classes (Java semantics, Python names). Following the +# existing iobase.py style, methods raise NotImplementedError rather than using +# a formal abc.ABC. +# ------------------------------------------------------------------------------ + + +class CheckpointMark(object): + """A durable, serializable position in an :class:`UnboundedSource`. + + Mirrors ``org.apache.beam.sdk.io.UnboundedSource.CheckpointMark``. 
Produced by + :meth:`UnboundedReader.get_checkpoint_mark`, encoded with + :meth:`UnboundedSource.get_checkpoint_mark_coder`, and used to resume a reader + (see :meth:`UnboundedSource.create_reader`). + """ + def finalize_checkpoint(self) -> None: + """Called once the runner has durably committed work up to this mark. + + Override to acknowledge/commit upstream (e.g. ack Pub/Sub messages). The + default is a no-op. Unlike Java, the Python bundle finalizer takes no + deadline argument, so no mapping is required here. + """ + pass + + +class UnboundedReader(object): + """Reads records from an :class:`UnboundedSource`. + + Mirrors ``UnboundedSource.UnboundedReader``. Lifecycle: exactly one + :meth:`start`, then any number of :meth:`advance` calls; whenever one returns + ``True`` the current record is available via :meth:`get_current` / + :meth:`get_current_timestamp`. A ``False`` return means "no data available + right now", which is distinct from end-of-stream: a reader signals a permanent + end by reporting a watermark of ``MAX_TIMESTAMP``. + """ + def start(self) -> bool: + """Positions at the first record; returns whether one is available.""" + raise NotImplementedError + + def advance(self) -> bool: + """Advances to the next record. ``False`` == no data *now*, not EOF.""" + raise NotImplementedError + + def get_current(self) -> Any: + """Returns the record claimed by the last successful start/advance.""" + raise NotImplementedError + + def get_current_timestamp(self) -> Timestamp: + """Returns the event-time timestamp of the current record.""" + raise NotImplementedError + + def get_watermark(self) -> Timestamp: + """A best-effort lower bound on timestamps of future records. + + Treated as monotonic by the wrapper. Return ``MAX_TIMESTAMP`` to signal that + this reader has permanently finished. + """ + raise NotImplementedError + + def get_checkpoint_mark(self) -> CheckpointMark: + """Returns a durable mark to resume from. 
Call only at a bundle boundary.""" + raise NotImplementedError + + def close(self) -> None: + """Releases reader resources. Default no-op.""" + pass + + +class UnboundedSource(iobase.SourceBase): + """A source producing an unbounded stream of records with checkpointing. + + Mirrors ``org.apache.beam.sdk.io.UnboundedSource``. Read it in a pipeline with + :class:`ReadFromUnboundedSource`:: + + p | ReadFromUnboundedSource(MyUnboundedSource()) + """ + def split(self, + desired_num_splits: int, + options: Optional[Any] = None) -> Iterable['UnboundedSource']: + """Splits into at most ``desired_num_splits`` independent sub-sources.""" + raise NotImplementedError + + def create_reader( + self, options: Optional[Any], + checkpoint_mark: Optional[CheckpointMark]) -> UnboundedReader: + """Creates a reader, resuming from ``checkpoint_mark`` when it is not None.""" + raise NotImplementedError + + def get_checkpoint_mark_coder(self) -> Coder: + """Returns the coder for this source's :class:`CheckpointMark` instances.""" + raise NotImplementedError + + def is_bounded(self) -> bool: + # SourceBase.is_bounded raises; an unbounded source is, by definition, not. + return False + + def default_output_coder(self) -> Coder: + # Permissive default, matching BoundedSource (iobase.py). Override for a + # tighter coder. Not wired into ReadFromUnboundedSource in this PoC. + return coders.registry.get_coder(object) + + +# ------------------------------------------------------------------------------ +# SDF wrapper internals (restriction, coder, tracker, provider). Names are +# underscore-prefixed: they are an implementation detail of +# ReadFromUnboundedSource, not public API. +# ------------------------------------------------------------------------------ + + +@dataclasses.dataclass(frozen=True) +class _UnboundedSourceRestriction(object): + """Durable SDF restriction describing where a reader should (re)start. + + Holds only serializable state -- never a live reader. 
Mirrors Java's + ``UnboundedSourceRestriction(source, checkpoint, watermark)`` plus an explicit + ``is_done`` flag for the terminal (MAX-watermark) transition. + """ + source: UnboundedSource + checkpoint_mark: Optional[CheckpointMark] = None + watermark: Timestamp = MIN_TIMESTAMP + is_done: bool = False + + +class _UnboundedSourceRestrictionCoder(Coder): + """Encodes :class:`_UnboundedSourceRestriction`. + + Shape mirrors Java's ``UnboundedSourceRestrictionCoder``: pickled source + + nullable checkpoint (encoded with the source's own checkpoint coder) + + watermark + done flag. + """ + def __init__(self, checkpoint_mark_coder: Optional[Coder] = None): + nullable_checkpoint = NullableCoder( + checkpoint_mark_coder or _MemoizingPickleCoder()) + self._tuple_coder = TupleCoder(( + _MemoizingPickleCoder(), # source + nullable_checkpoint, # checkpoint_mark (may be None) + TimestampCoder(), # watermark + BooleanCoder())) # is_done + + def encode(self, restriction: '_UnboundedSourceRestriction') -> bytes: + return self._tuple_coder.encode(( + restriction.source, + restriction.checkpoint_mark, + restriction.watermark, + restriction.is_done)) + + def decode(self, encoded: bytes) -> '_UnboundedSourceRestriction': + source, checkpoint_mark, watermark, is_done = self._tuple_coder.decode( + encoded) + return _UnboundedSourceRestriction( + source, checkpoint_mark, watermark, is_done) + + def is_deterministic(self) -> bool: + # The source and checkpoint are pickled, which is not guaranteed + # deterministic; matches the bounded SDF restriction coder in iobase.py. + return False + + +class _UnboundedSourceRestrictionTracker(iobase.RestrictionTracker): + """Drives an :class:`UnboundedReader` for one SDF restriction. + + Owns the live reader (lazily created, never serialized): both runner-initiated + ``try_split`` and the self-checkpoint ``try_split(0)`` raised by + ``defer_remainder`` must checkpoint the *same* reader. 
+ + ``process()`` only ever sees a ``RestrictionTrackerView`` -- which hides custom + methods and whose ``try_claim`` returns just a bool -- so the freshly-read + record is handed back through a one-element holder list passed as the + ``try_claim`` *position* argument (mirrors Java's + ``tryClaim(UnboundedSourceValue[] out)``). + """ + def __init__( + self, + restriction: _UnboundedSourceRestriction, + options: Optional[Any] = None): + self._restriction = restriction + self._options = options + self._reader = None # type: Optional[UnboundedReader] + self._started = False + # True once a checkpoint has been cut this bundle (EOF or self-checkpoint). + self._checkpoint_taken = False + + def _ensure_reader(self) -> None: + if self._reader is None: + self._reader = self._restriction.source.create_reader( + self._options, self._restriction.checkpoint_mark) + + def current_restriction(self) -> _UnboundedSourceRestriction: + return self._restriction + + def try_claim(self, out: List[Any]) -> bool: + # 'out' is a one-element holder -- the only channel back to process(). It + # receives either (value, timestamp) or the _NO_DATA sentinel. + if self._restriction.is_done: + out[0] = _NO_DATA + return False + self._ensure_reader() + if not self._started: + has_data = self._reader.start() + else: + has_data = self._reader.advance() + self._started = True + if has_data: + # A record is available: always emit it. Inspect has_data BEFORE the + # watermark, because a reader may return its final record and report a + # MAX_TIMESTAMP watermark on the *same* call (meaning "nothing after + # this"). That EOF is realized on the next, data-less claim, so the record + # we just read is never dropped. (We do not touch self._restriction on the + # data path, so its identity is preserved for the finalization gate.) 
+      out[0] = (
+          self._reader.get_current(), self._reader.get_current_timestamp())
+      return True
+    watermark = self._reader.get_watermark()
+    if watermark >= MAX_TIMESTAMP:
+      # No data and watermark at MAX: the reader permanently finished. Cut a
+      # final checkpoint, close, and mark the restriction done.
+      checkpoint = self._reader.get_checkpoint_mark()
+      self._reader.close()
+      self._reader = None
+      self._restriction = dataclasses.replace(
+          self._restriction,
+          checkpoint_mark=checkpoint,
+          watermark=MAX_TIMESTAMP,
+          is_done=True)
+      self._checkpoint_taken = True
+      out[0] = _NO_DATA
+      return False
+    # No data right now (not EOF): refresh the watermark so process() can
+    # advance it before deferring, then let process() self-checkpoint.
+    self._restriction = dataclasses.replace(
+        self._restriction, watermark=watermark)
+    out[0] = _NO_DATA
+    return True
+
+  def try_split(
+      self, fraction_of_remainder
+  ) -> Optional[Tuple[_UnboundedSourceRestriction,
+                      _UnboundedSourceRestriction]]:
+    """Cuts a checkpoint and splits into a finished primary and a residual.
+
+    Fraction 0 is the self-checkpoint raised by ``defer_remainder()``; any
+    other fraction cuts the same checkpoint, so ``fraction_of_remainder`` is
+    not otherwise consulted.
+
+    Returns:
+      ``(primary, residual)``, or ``None`` when there is nothing to split: no
+      live reader, the reader was never started, or the restriction is already
+      done.
+    """
+    if self._reader is None or not self._started or self._restriction.is_done:
+      return None
+    checkpoint = self._reader.get_checkpoint_mark()
+    watermark = self._reader.get_watermark()
+    # Primary is finished work; it carries the checkpoint only so the DoFn can
+    # register finalization. The residual carries the resume state.
+    primary = dataclasses.replace(
+        self._restriction, checkpoint_mark=checkpoint, is_done=True)
+    residual = _UnboundedSourceRestriction(
+        source=self._restriction.source,
+        checkpoint_mark=checkpoint,
+        watermark=watermark,
+        is_done=False)
+    self._restriction = primary
+    self._checkpoint_taken = True
+    # The residual's reader is rebuilt from its checkpoint on resume, so ours is
+    # no longer needed. Close it (as the EOF path in try_claim does) instead of
+    # only dropping the reference, which would leak the reader's resources on
+    # every checkpoint. Detach first so the tracker never keeps a closed reader
+    # even if close() raises.
+    reader, self._reader = self._reader, None
+    reader.close()
+    return primary, residual
+
+  def check_done(self) -> bool:
+    # Called after every process(); must raise if work is left unaccounted for.
+    if self._restriction.is_done or self._checkpoint_taken:
+      return True
+    raise ValueError(
+        'UnboundedSource restriction was neither finished nor checkpointed; '
+        'process() must self-checkpoint via defer_remainder() or run to EOF: '
+        '%r' % (self._restriction, ))
+
+  def current_progress(self) -> 'iobase.RestrictionProgress':
+    # Backlog-based progress is out of scope; report a coarse done/not-done
+    # fraction so the runner has a (recommended) signal.
+    return iobase.RestrictionProgress(
+        fraction=1.0 if self._restriction.is_done else 0.0)
+
+  def is_bounded(self) -> bool:
+    return False
+
+
+class _UnboundedSourceRestrictionProvider(core.RestrictionProvider):
+  """Wraps an :class:`UnboundedSource` element as an SDF restriction."""
+  def __init__(
+      self,
+      checkpoint_mark_coder: Optional[Coder] = None,
+      options: Optional[Any] = None):
+    self._restriction_coder = _UnboundedSourceRestrictionCoder(
+        checkpoint_mark_coder)
+    self._options = options
+
+  def initial_restriction(
+      self, element: UnboundedSource) -> _UnboundedSourceRestriction:
+    if not isinstance(element, UnboundedSource):
+      raise TypeError(
+          'ReadFromUnboundedSource expected an UnboundedSource element, got %r'
+          % (element, ))
+    return _UnboundedSourceRestriction(source=element)
+
+  def create_tracker(
+      self, restriction: _UnboundedSourceRestriction
+  ) -> _UnboundedSourceRestrictionTracker:
+    return _UnboundedSourceRestrictionTracker(
+        restriction, options=self._options)
+
+  def split(self, element,
+            restriction) -> Iterable[_UnboundedSourceRestriction]:
+    # Minimal PoC: no initial fan-out. Real splitting is future work.
+    yield restriction
+
+  def restriction_size(self, element, restriction) -> int:
+    # Backlog estimation is out of scope; report a constant non-negative size.
+    return 1
+
+  def restriction_coder(self) -> Coder:
+    return self._restriction_coder
+
+  def truncate(self, element, restriction):
+    # On drain, stop emitting new records (mirrors PeriodicSequence.truncate).
+    return None
+
+
+def _set_watermark_if_greater(watermark_estimator, new_watermark: Timestamp):
+  # ManualWatermarkEstimator.set_watermark raises if the watermark regresses, so
+  # only ever advance it (mirrors PeriodicSequence's monotonic guard).
+  current = watermark_estimator.current_watermark()
+  if current is None or new_watermark > current:
+    watermark_estimator.set_watermark(new_watermark)
+
+
+class ReadFromUnboundedSource(PTransform):
+  """Reads an :class:`UnboundedSource` via a Splittable ``DoFn``.
+
+  Dispatches through an SDF that mirrors Java's
+  ``Read.UnboundedSourceAsSDFWrapperFn``: checkpoint-based pause/resume
+  (``defer_remainder``), monotonic watermark reporting, and bundle finalization.
+  Runs on the portable Fn API path -- the default DirectRunner routes an
+  ``Impulse | Map | ParDo`` pipeline to Prism/FnApiRunner::
+
+      p | ReadFromUnboundedSource(MyUnboundedSource())
+
+  Args:
+    source: the :class:`UnboundedSource` to read.
+    poll_interval_seconds: resume delay, in seconds, handed to
+      ``defer_remainder`` whenever the reader reports "no data right now".
+      Must be greater than zero.
+
+  Raises:
+    TypeError: if ``source`` is not an :class:`UnboundedSource`.
+    ValueError: if ``poll_interval_seconds`` is not greater than zero.
+  """
+  def __init__(
+      self,
+      source: UnboundedSource,
+      poll_interval_seconds: float = _DEFAULT_POLL_INTERVAL_SECONDS):
+    if not isinstance(source, UnboundedSource):
+      raise TypeError('source must be an UnboundedSource, got %r' % (source, ))
+    # Reject zero, negative and NaN up front (``not x > 0`` is also true for
+    # NaN): expand() passes this value straight to defer_remainder() as the
+    # resume delay of the "no data" path.
+    if not poll_interval_seconds > 0:
+      raise ValueError(
+          'poll_interval_seconds must be > 0, got %r' %
+          (poll_interval_seconds, ))
+    super().__init__()
+    self._source = source
+    self._poll_interval_seconds = poll_interval_seconds
+
+  def expand(self, pbegin):
+    source = self._source
+    poll_interval_seconds = self._poll_interval_seconds
+    provider = _UnboundedSourceRestrictionProvider(
+        checkpoint_mark_coder=source.get_checkpoint_mark_coder())
+
+    class _ReadFromUnboundedSourceDoFn(core.DoFn):
+      @core.DoFn.unbounded_per_element()
+      def process(
+          self,
+          unused_element,
+          bundle_finalizer=core.DoFn.BundleFinalizerParam,
+          tracker=core.DoFn.RestrictionParam(provider),
+          watermark_estimator=core.DoFn.WatermarkEstimatorParam(
+              ManualWatermarkEstimator.default_provider())):
+        # Parameter order matters: positionally-injected params (the element and
+        # the bundle finalizer) must precede the kwarg-injected ones (the
+        # restriction tracker and watermark estimator), which the SDF invoker
+        # passes by name (runners/common.py _get_arg_placeholders).
+        assert isinstance(tracker, sdf_utils.RestrictionTrackerView)
+        initial = tracker.current_restriction()
+        try:
+          while True:
+            holder = [None]
+            if not tracker.try_claim(holder):
+              break  # restriction done (EOF) -> stop
+            record = holder[0]
+            if record is _NO_DATA:
+              # No data right now: advance the watermark and self-checkpoint so
+              # the runner reschedules us later. Resume via defer_remainder() +
+              # break -- NOT yield ProcessContinuation (the portable SDF path).
+              _set_watermark_if_greater(
+                  watermark_estimator, tracker.current_restriction().watermark)
+              tracker.defer_remainder(Duration(seconds=poll_interval_seconds))
+              break
+            value, record_timestamp = record
+            _set_watermark_if_greater(watermark_estimator, record_timestamp)
+            yield TimestampedValue(value, record_timestamp)
+        finally:
+          current = tracker.current_restriction()
+          checkpoint = current.checkpoint_mark
+          # Register finalization only when a real checkpoint was cut this
+          # bundle. Restriction identity (`current is not initial`) mirrors
+          # Java's reference-equality gate in Read.java.
+          if current is not initial and checkpoint is not None:
+            bundle_finalizer.register(checkpoint.finalize_checkpoint)
+
+    return (
+        pbegin
+        | 'Impulse' >> Impulse()
+        | 'EmitSource' >> core.Map(lambda _: source)
+        | 'ReadUnbounded' >> core.ParDo(_ReadFromUnboundedSourceDoFn()))
diff --git a/sdks/python/apache_beam/io/unbounded_source_test.py b/sdks/python/apache_beam/io/unbounded_source_test.py
new file mode 100644
index 000000000000..1ee53a61f123
--- /dev/null
+++ b/sdks/python/apache_beam/io/unbounded_source_test.py
@@ -0,0 +1,368 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for apache_beam.io.unbounded_source.
+
+Strategy: checkpoint/resume/watermark/coder semantics are covered by
+deterministic unit tests (no pipeline, no wall clock). A single end-to-end
+DirectRunner test asserts only ordering + termination -- no defer-timing
+assertions, which would be flaky (cf. periodicsequence_test which skips
+processing-time tests for the same reason).
+""" + +# pytype: skip-file + +import logging +import unittest + +from apache_beam import coders +from apache_beam.io.unbounded_source import CheckpointMark +from apache_beam.io.unbounded_source import ReadFromUnboundedSource +from apache_beam.io.unbounded_source import UnboundedReader +from apache_beam.io.unbounded_source import UnboundedSource +from apache_beam.io.unbounded_source import _NO_DATA +from apache_beam.io.unbounded_source import _UnboundedSourceRestriction +from apache_beam.io.unbounded_source import _UnboundedSourceRestrictionCoder +from apache_beam.io.unbounded_source import _UnboundedSourceRestrictionTracker +from apache_beam.io.unbounded_source import _set_watermark_if_greater +from apache_beam.io.watermark_estimators import ManualWatermarkEstimator +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.utils.timestamp import MAX_TIMESTAMP +from apache_beam.utils.timestamp import MIN_TIMESTAMP +from apache_beam.utils.timestamp import Timestamp + +# pylint: disable=expression-not-assigned + +# ------------------------------------------------------------------------------ +# A tiny in-memory demo source: emits the integers 0..count-1, one per record, +# with event time Timestamp(index). It self-terminates (watermark -> MAX after +# the last record) so a pipeline reading it ends. Resumes from a checkpoint at +# (last_index + 1). 
+# ------------------------------------------------------------------------------ + + +class _CountingCheckpointMark(CheckpointMark): + def __init__(self, last_index, finalize_log=None): + self.last_index = last_index + self._finalize_log = finalize_log + + def finalize_checkpoint(self): + if self._finalize_log is not None: + self._finalize_log.append(self.last_index) + + def __eq__(self, other): + return ( + isinstance(other, _CountingCheckpointMark) and + other.last_index == self.last_index) + + def __hash__(self): + return hash(self.last_index) + + def __repr__(self): + return '_CountingCheckpointMark(last_index=%r)' % (self.last_index, ) + + +class _CountingReader(UnboundedReader): + def __init__(self, count, start_index, finalize_log=None): + self._count = count + self._next = start_index + self._current = None + self._exhausted = False + self._finalize_log = finalize_log + + def _read_next(self): + if self._next >= self._count: + self._exhausted = True + return False + self._current = self._next + self._next += 1 + return True + + def start(self): + return self._read_next() + + def advance(self): + return self._read_next() + + def get_current(self): + return self._current + + def get_current_timestamp(self): + return Timestamp(self._current) + + def get_watermark(self): + if self._exhausted: + return MAX_TIMESTAMP + if self._current is None: + return MIN_TIMESTAMP + return Timestamp(self._current) + + def get_checkpoint_mark(self): + last = self._current if self._current is not None else self._next - 1 + return _CountingCheckpointMark(last, finalize_log=self._finalize_log) + + +class CountingSource(UnboundedSource): + def __init__(self, count, finalize_log=None): + self._count = count + self._finalize_log = finalize_log + + def split(self, desired_num_splits, options=None): + return [self] + + def create_reader(self, options, checkpoint_mark): + start_index = 0 if checkpoint_mark is None else checkpoint_mark.last_index + 1 + return _CountingReader( + 
self._count, start_index, finalize_log=self._finalize_log) + + def get_checkpoint_mark_coder(self): + return coders.PickleCoder() + + +class _NoDataReader(UnboundedReader): + """Always reports 'no data right now' (watermark < MAX, so never EOF).""" + def start(self): + return False + + def advance(self): + return False + + def get_current(self): + raise AssertionError('no data available') + + def get_current_timestamp(self): + raise AssertionError('no data available') + + def get_watermark(self): + return Timestamp(0) + + def get_checkpoint_mark(self): + return _CountingCheckpointMark(-1) + + +class _NoDataSource(UnboundedSource): + def split(self, desired_num_splits, options=None): + return [self] + + def create_reader(self, options, checkpoint_mark): + return _NoDataReader() + + def get_checkpoint_mark_coder(self): + return coders.PickleCoder() + + +def _new_tracker(source, checkpoint=None): + restriction = _UnboundedSourceRestriction( + source=source, checkpoint_mark=checkpoint) + return _UnboundedSourceRestrictionTracker(restriction) + + +def _claim(tracker): + """Claims once; returns (claimed_bool, holder_value).""" + holder = [None] + claimed = tracker.try_claim(holder) + return claimed, holder[0] + + +# ------------------------------------------------------------------------------ +# Tests +# ------------------------------------------------------------------------------ + + +class AbcContractTest(unittest.TestCase): + def test_checkpointmark_default_finalize_is_noop(self): + self.assertIsNone(CheckpointMark().finalize_checkpoint()) + + def test_unboundedsource_is_bounded_false(self): + self.assertFalse(CountingSource(3).is_bounded()) + + def test_reader_lifecycle_start_advance_eof(self): + reader = CountingSource(3).create_reader(None, None) + self.assertTrue(reader.start()) + self.assertEqual(reader.get_current(), 0) + self.assertEqual(reader.get_current_timestamp(), Timestamp(0)) + self.assertTrue(reader.advance()) + self.assertEqual(reader.get_current(), 
1) + self.assertTrue(reader.advance()) + self.assertEqual(reader.get_current(), 2) + self.assertFalse(reader.advance()) + self.assertEqual(reader.get_watermark(), MAX_TIMESTAMP) + + +class RestrictionCoderTest(unittest.TestCase): + def test_roundtrip_no_checkpoint(self): + source = CountingSource(3) + coder = _UnboundedSourceRestrictionCoder(source.get_checkpoint_mark_coder()) + decoded = coder.decode( + coder.encode(_UnboundedSourceRestriction(source=source))) + self.assertIsNone(decoded.checkpoint_mark) + self.assertEqual(decoded.watermark, MIN_TIMESTAMP) + self.assertFalse(decoded.is_done) + reader = decoded.source.create_reader(None, None) + self.assertTrue(reader.start()) + self.assertEqual(reader.get_current(), 0) + + def test_roundtrip_with_checkpoint_resumes(self): + source = CountingSource(5) + coder = _UnboundedSourceRestrictionCoder(source.get_checkpoint_mark_coder()) + restriction = _UnboundedSourceRestriction( + source=source, + checkpoint_mark=_CountingCheckpointMark(1), + watermark=Timestamp(1), + is_done=False) + decoded = coder.decode(coder.encode(restriction)) + self.assertEqual(decoded.checkpoint_mark.last_index, 1) + self.assertEqual(decoded.watermark, Timestamp(1)) + self.assertFalse(decoded.is_done) + # A reader built from the decoded checkpoint resumes at the next index. 
+ reader = decoded.source.create_reader(None, decoded.checkpoint_mark) + self.assertTrue(reader.start()) + self.assertEqual(reader.get_current(), 2) + + +class RestrictionTrackerTest(unittest.TestCase): + def test_claim_emits_in_order(self): + tracker = _new_tracker(CountingSource(3)) + values = [] + while True: + claimed, record = _claim(tracker) + if not claimed: + break + self.assertIsNot(record, _NO_DATA) + values.append(record[0]) + self.assertEqual(values, [0, 1, 2]) + self.assertTrue(tracker.check_done()) + + def test_claim_emits_final_record_when_watermark_is_max(self): + # Regression: a reader may return its final record (has_data True) while + # simultaneously reporting a MAX_TIMESTAMP watermark ("nothing after this"). + # The record must still be emitted; EOF is realized on the next claim. + class _FinalRecordReader(UnboundedReader): + def start(self): + return True + + def advance(self): + return False + + def get_current(self): + return 'last' + + def get_current_timestamp(self): + return Timestamp(0) + + def get_watermark(self): + return MAX_TIMESTAMP + + def get_checkpoint_mark(self): + return _CountingCheckpointMark(0) + + class _FinalSource(UnboundedSource): + def split(self, desired_num_splits, options=None): + return [self] + + def create_reader(self, options, checkpoint_mark): + return _FinalRecordReader() + + def get_checkpoint_mark_coder(self): + return coders.PickleCoder() + + tracker = _new_tracker(_FinalSource()) + claimed, record = _claim(tracker) + self.assertTrue(claimed) + self.assertIsNot(record, _NO_DATA) + self.assertEqual(record[0], 'last') + # The next claim observes EOF and finishes (no second, phantom record). + claimed_again, _ = _claim(tracker) + self.assertFalse(claimed_again) + self.assertTrue(tracker.check_done()) + + def test_try_split_zero_produces_resumable_residual(self): + source = CountingSource(5) + tracker = _new_tracker(source) + # Claim 0 and 1. 
+ self.assertEqual(_claim(tracker)[1][0], 0) + self.assertEqual(_claim(tracker)[1][0], 1) + + split = tracker.try_split(0) + self.assertIsNotNone(split) + primary, residual = split + self.assertTrue(primary.is_done) + self.assertFalse(residual.is_done) + self.assertEqual(residual.checkpoint_mark.last_index, 1) + # check_done passes on the (now done) primary. + self.assertTrue(tracker.check_done()) + + # Resuming from the residual continues at index 2. + resumed = _new_tracker(source, checkpoint=residual.checkpoint_mark) + self.assertEqual(_claim(resumed)[1][0], 2) + + def test_no_data_returns_sentinel_without_finishing(self): + tracker = _new_tracker(_NoDataSource()) + claimed, record = _claim(tracker) + self.assertTrue(claimed) # not EOF + self.assertIs(record, _NO_DATA) + # A self-checkpoint is still possible (poll/resume path). + self.assertIsNotNone(tracker.try_split(0)) + + def test_check_done_raises_when_not_done(self): + tracker = _new_tracker(CountingSource(3)) + with self.assertRaises(ValueError): + tracker.check_done() + + def test_is_bounded_false(self): + self.assertFalse(_new_tracker(CountingSource(3)).is_bounded()) + + +class WatermarkTest(unittest.TestCase): + def test_set_watermark_is_monotonic(self): + estimator = ManualWatermarkEstimator(None) + _set_watermark_if_greater(estimator, Timestamp(5)) + self.assertEqual(estimator.current_watermark(), Timestamp(5)) + # A regression is ignored (would otherwise raise inside set_watermark). + _set_watermark_if_greater(estimator, Timestamp(3)) + self.assertEqual(estimator.current_watermark(), Timestamp(5)) + _set_watermark_if_greater(estimator, Timestamp(7)) + self.assertEqual(estimator.current_watermark(), Timestamp(7)) + + +class FinalizationTest(unittest.TestCase): + def test_finalize_checkpoint_invoked(self): + # Authoritative finalize test at the unit level: the e2e finalize may run in + # a worker process, so its side effect is not observable from the test. 
+ finalize_log = [] + source = CountingSource(5, finalize_log=finalize_log) + tracker = _new_tracker(source) + _claim(tracker) # 0 + _claim(tracker) # 1 + _, residual = tracker.try_split(0) + residual.checkpoint_mark.finalize_checkpoint() + self.assertEqual(finalize_log, [1]) + + +class EndToEndTest(unittest.TestCase): + def test_direct_runner_emits_all_in_order(self): + with TestPipeline() as p: + out = p | ReadFromUnboundedSource(CountingSource(5)) + self.assertFalse(out.is_bounded) + assert_that(out, equal_to([0, 1, 2, 3, 4])) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() From 40c3626469c765f56538f030d51a6d262cb1c786 Mon Sep 17 00:00:00 2001 From: Eliaazzz Date: Thu, 28 May 2026 00:53:55 +1000 Subject: [PATCH 2/4] feat: wire UnboundedSource into iobase.Read + harden SDF tracker * iobase.Read.expand dispatches UnboundedSource through ReadFromUnboundedSource (Impulse | Map | ParDo); function-local lazy import breaks the iobase <-> unbounded_source cycle. The Read.__init__ docstring is updated to describe the new dispatch. * iobase.Read.to_runner_api_parameter widens the source isinstance to (BoundedSource, UnboundedSource) so Read(unbounded_source) graphs round-trip as READ.urn + ReadPayload(is_bounded=UNBOUNDED). Runner-side dispatch on the UNBOUNDED flag in bundle_processor stays W2. * SDF tracker correctness: - Data-path watermark now propagates reader.get_watermark() (Java Read.java:594 parity); holder is (value, record_ts, source_wm). - _UnboundedSourceRestriction adds finalization_checkpoint_mark so a done primary can carry a commit hook independent of the residual's RESUME checkpoint_mark. Coder is now a fixed 5-tuple. - try_claim / try_split close the reader before re-raising on any reader method failure; the DoFn finally is reduced to defense-in-depth for yield / downstream-raise paths and logs a warning if the private RestrictionTrackerView chain ever breaks.
- ReadFromUnboundedSource validates poll_interval_seconds > 0. * Regression coverage: - File-marker side-channel tests for EOF watermark advance to MAX and reader-close on (i) reader.advance raise, (ii) reader.get_* raise, (iii) downstream yield raise. - try_split / EOF separation of finalize and resume channels. - Circular import order in 3 subprocesses (clean module cache): iobase first, unbounded_source first, lazy via Read.expand. - Cloudpickle round-trip for the transform and source. - to_runner_api / from_runner_api round-trip asserting IsBounded.UNBOUNDED enum and source recovery. unbounded_source_test 37/37, iobase_test 16/16. Tracking #19137. --- sdks/python/apache_beam/io/iobase.py | 26 +- sdks/python/apache_beam/io/iobase_test.py | 78 ++ .../python/apache_beam/io/unbounded_source.py | 323 ++++++-- .../apache_beam/io/unbounded_source_test.py | 696 +++++++++++++++++- 4 files changed, 1068 insertions(+), 55 deletions(-) diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index 67d6cd358a07..ee47fb27602b 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -919,7 +919,11 @@ def __init__(self, source: SourceBase) -> None: """Initializes a Read transform. Args: - source: Data source to read from. + source: Data source to read from. A ``BoundedSource`` is wrapped in the + bounded SDF reader; an ``UnboundedSource`` is dispatched through + :class:`apache_beam.io.unbounded_source.ReadFromUnboundedSource` with + the default poll interval (users wanting a custom poll cadence must + instantiate ``ReadFromUnboundedSource`` directly). 
""" super().__init__() self.source = source @@ -945,6 +949,16 @@ def expand(self, pbegin): | 'EmitSource' >> core.Map(lambda _: self.source).with_output_types(BoundedSource) | SDFBoundedSourceReader(display_data)) + # Lazy import to break the iobase <-> unbounded_source cycle: the + # unbounded_source module imports iobase (UnboundedSource extends + # SourceBase). Pattern matches the _PubSubSource lazy import below. + from apache_beam.io.unbounded_source import ( + ReadFromUnboundedSource, UnboundedSource) + if isinstance(self.source, UnboundedSource): + # Delegate to the dedicated SDF PTransform; identical to the user + # writing `p | ReadFromUnboundedSource(self.source)` directly. Custom + # poll_interval_seconds requires using ReadFromUnboundedSource directly. + return pbegin | ReadFromUnboundedSource(self.source) elif isinstance(self.source, ptransform.PTransform): # The Read transform can also admit a full PTransform as an input # rather than an anctual source. If the input is a PTransform, then @@ -986,7 +1000,15 @@ def to_runner_api_parameter( timestamp_attribute=self.source.timestamp_attribute, with_attributes=self.source.with_attributes, id_attribute=self.source.id_label)) - if isinstance(self.source, BoundedSource): + # Lazy import to avoid the iobase <-> unbounded_source cycle. + from apache_beam.io.unbounded_source import UnboundedSource + if isinstance(self.source, (BoundedSource, UnboundedSource)): + # READ.urn covers both source flavours; the IsBounded enum distinguishes + # them. NB: today the bundle_processor.py IMPULSE_READ_TRANSFORM handler + # only consumes BOUNDED — the UNBOUNDED branch round-trips correctly + # through the protobuf graph but execution still flows through this + # composite's expanded sub-transforms (Impulse | Map | SDF-ParDo), not + # through the URN-handler. Runner-side UNBOUNDED dispatch is W2 work. 
return ( common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload( diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py index eb9617cfae34..f3cc6c50ff0f 100644 --- a/sdks/python/apache_beam/io/iobase_test.py +++ b/sdks/python/apache_beam/io/iobase_test.py @@ -220,5 +220,83 @@ def test_sdf_wrap_range_source(self): self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3]) +class UseSdfUnboundedSourcesTests(unittest.TestCase): + """Mirrors UseSdfBoundedSourcesTests for the new UnboundedSource branch in + iobase.Read.expand(). Uses CountingSource from unbounded_source_test as the + fake finite UnboundedSource (avoids dragging the network in). + """ + + def test_read_dispatches_to_read_from_unbounded_source(self): + from apache_beam.io.unbounded_source_test import CountingSource + with mock.patch( + 'apache_beam.io.unbounded_source.ReadFromUnboundedSource.expand' + ) as mock_expand: + mock_expand.side_effect = ( + lambda pbegin: pbegin | beam.Impulse() | beam.Map(lambda _: 'fake')) + with beam.Pipeline() as p: + out = p | beam.io.Read(CountingSource(3)) + assert_that(out, equal_to(['fake'])) + mock_expand.assert_called_once() + + def test_read_end_to_end_unbounded(self): + from apache_beam.io.unbounded_source_test import CountingSource + with beam.Pipeline() as p: + out = p | beam.io.Read(CountingSource(5)) + assert_that(out, equal_to([0, 1, 2, 3, 4])) + + def test_read_unbounded_pcollection_is_unbounded(self): + from apache_beam.io.unbounded_source_test import CountingSource + with beam.Pipeline() as p: + out = p | beam.io.Read(CountingSource(3)) + self.assertFalse(out.is_bounded) + + def test_to_runner_api_emits_unbounded_read_payload(self): + """``Read.to_runner_api_parameter`` must serialize an UnboundedSource as + ``READ.urn`` with ``IsBounded.UNBOUNDED``. 
The runner-side handler is W2 + and ignores this enum today, but the wire format must round-trip + consistently for pipeline persistence / cross-runner submission. + """ + from apache_beam.io.unbounded_source_test import CountingSource + from apache_beam.portability import common_urns + from apache_beam.portability.api import beam_runner_api_pb2 + from apache_beam.runners.pipeline_context import PipelineContext + + read = beam.io.Read(CountingSource(5)) + urn, payload = read.to_runner_api_parameter(PipelineContext()) + + self.assertEqual(urn, common_urns.deprecated_primitives.READ.urn) + self.assertIsInstance(payload, beam_runner_api_pb2.ReadPayload) + self.assertEqual( + payload.is_bounded, beam_runner_api_pb2.IsBounded.UNBOUNDED) + # The source field must be populated -- a non-empty FunctionSpec proto. + self.assertTrue(payload.source.urn) + + def test_read_unbounded_round_trips_through_runner_api(self): + """Encode then decode via the runner-API protobuf. The restored + transform must be a ``Read`` wrapping an equivalent UnboundedSource. + """ + from apache_beam.io.unbounded_source import UnboundedSource + from apache_beam.io.unbounded_source_test import CountingSource + from apache_beam.portability import common_urns + from apache_beam.portability.api import beam_runner_api_pb2 + from apache_beam.runners.pipeline_context import PipelineContext + + original = beam.io.Read(CountingSource(7)) + context = PipelineContext() + urn, payload = original.to_runner_api_parameter(context) + + transform_proto = beam_runner_api_pb2.PTransform() + transform_proto.spec.urn = urn + restored = iobase.Read.from_runner_api_parameter( + transform_proto, payload, context) + + self.assertIsInstance(restored, iobase.Read) + self.assertIsInstance(restored.source, UnboundedSource) + self.assertIsInstance(restored.source, CountingSource) + self.assertFalse(restored.source.is_bounded()) + # Verify the source's internal state survived pickle round-trip. 
+ self.assertEqual(restored.source._count, 7) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/io/unbounded_source.py b/sdks/python/apache_beam/io/unbounded_source.py index f2de28f9d573..d05403b132d8 100644 --- a/sdks/python/apache_beam/io/unbounded_source.py +++ b/sdks/python/apache_beam/io/unbounded_source.py @@ -20,8 +20,10 @@ This module is a Week-1 proof-of-concept for GSoC 2026 (issue #19137). It brings the Java ``UnboundedSource`` abstractions to Python and makes them *runnable* on the portable Fn API path (e.g. the default DirectRunner) by dispatching reads -through a Splittable ``DoFn`` -- mirroring Java's -``Read.UnboundedSourceAsSDFWrapperFn``. +through a Splittable ``DoFn`` -- inspired by (not a literal port of) Java's +``Read.UnboundedSourceAsSDFWrapperFn``. The streaming-SDF template followed for +the process-loop / watermark / defer plumbing is +``apache_beam.transforms.periodicsequence``. Public API:: @@ -33,14 +35,29 @@ class MySource(UnboundedSource): with beam.Pipeline() as p: p | ReadFromUnboundedSource(MySource()) | beam.Map(print) + # Equivalent (since iobase.Read.expand dispatches on source type): + # p | beam.io.Read(MySource()) | beam.Map(print) Scope (deliberately minimal): read loop, event-time timestamps, monotonic -watermark reporting, checkpoint-based pause/resume (``defer_remainder``) and -bundle finalization. Out of scope for this PoC: record-id deduplication, -backlog-byte reporting, dynamic split fractions, and wiring into -``iobase.Read.expand()`` (callers use ``ReadFromUnboundedSource`` directly). The -design mirrors the in-tree streaming SDF template -``apache_beam.transforms.periodicsequence``. +watermark reporting (including the EOF advance to ``MAX_TIMESTAMP`` so downstream +windows can close), checkpoint-based pause/resume (``defer_remainder``), +deterministic reader close on EOF / split / exception, and bundle finalization. 
+ +Out of scope for this PoC (tracked under #19137): + * Record-id-based deduplication (Java's ``ValueWithRecordId``). + * Backlog-byte reporting (``restriction_size`` is a constant 1; per-restriction + progress is binary 0.0 / 1.0). + * Dynamic split fractions / runner-initiated work stealing. + * Initial fan-out: ``RestrictionProvider.split`` ignores ``desired_num_splits`` + and yields a single restriction. + * Source-specific checkpoint coders threaded through the SDF restriction coder + (the restriction coder always pickles checkpoint marks via the source's + ``get_checkpoint_mark_coder`` captured once at ``expand()`` time, but no + per-tracker coder dispatch). + * Reader caching across bundles (Java caches readers across split boundaries + via a Guava cache; this PoC always rebuilds the reader from the checkpoint). + * ``EmptyUnboundedSource`` terminal-state marker (this PoC uses an ``is_done`` + flag on the restriction instead). """ # pytype: skip-file @@ -108,6 +125,20 @@ def finalize_checkpoint(self) -> None: Override to acknowledge/commit upstream (e.g. ack Pub/Sub messages). The default is a no-op. Unlike Java, the Python bundle finalizer takes no deadline argument, so no mapping is required here. + + Implementations MUST be idempotent in two senses: + * Repeated calls on the same :class:`CheckpointMark` instance must be + safe (runner may retry callbacks). + * Successive calls on *different* CheckpointMark instances must be + consistent with monotonically progressing committed state (each + bundle's defer/split produces a fresh CheckpointMark covering the + records read so far; the typical Kafka-style implementation is + ``ack(self.last_offset)`` which is naturally monotonic and idempotent + on the broker side). + + The SDK's bundle finalizer (``_BundleFinalizerParam.finalize_bundle`` at + ``transforms/core.py``) catches and logs any exception raised from this + method but does NOT retry, so a transient failure is silently dropped. 
""" pass @@ -166,18 +197,44 @@ class UnboundedSource(iobase.SourceBase): def split(self, desired_num_splits: int, options: Optional[Any] = None) -> Iterable['UnboundedSource']: - """Splits into at most ``desired_num_splits`` independent sub-sources.""" + """Splits into at most ``desired_num_splits`` independent sub-sources. + + Each returned sub-source MUST be independent and MUST NOT share mutable + state with siblings (the runner may execute them concurrently across + workers). Mirrors Java's ``UnboundedSource.split``. Note that the current + ``ReadFromUnboundedSource`` PoC ignores ``desired_num_splits`` -- this + method is the public API but is dead code from the SDF wrapper's + perspective until W2. + """ raise NotImplementedError def create_reader( self, options: Optional[Any], checkpoint_mark: Optional[CheckpointMark]) -> UnboundedReader: - """Creates a reader, resuming from ``checkpoint_mark`` when it is not None.""" + """Creates a reader, optionally resuming from ``checkpoint_mark``. + + Contract: + * When ``checkpoint_mark`` is ``None``, the returned reader's ``start()`` + produces the very first record of the source (or returns ``False`` if + none yet). + * When ``checkpoint_mark`` is not ``None``, the returned reader's + ``start()`` produces the FIRST record strictly AFTER the position + encoded by ``checkpoint_mark``. The reader must NOT re-deliver records + already covered by the prior bundle. Mirrors Java's + ``UnboundedSource.createReader(options, checkpointMark)``. + """ raise NotImplementedError def get_checkpoint_mark_coder(self) -> Coder: - """Returns the coder for this source's :class:`CheckpointMark` instances.""" - raise NotImplementedError + """Returns the coder for this source's :class:`CheckpointMark` instances. + + Called once at pipeline construction (graph build), NOT per-bundle. Do not + perform I/O here. Subclasses MUST override; the default raises with a + helpful message naming the subclass. 
+ """ + raise NotImplementedError( + '%s must override get_checkpoint_mark_coder() to return a Coder for ' + 'its CheckpointMark subclass.' % type(self).__name__) def is_bounded(self) -> bool: # SourceBase.is_bounded raises; an unbounded source is, by definition, not. @@ -185,7 +242,9 @@ def is_bounded(self) -> bool: def default_output_coder(self) -> Coder: # Permissive default, matching BoundedSource (iobase.py). Override for a - # tighter coder. Not wired into ReadFromUnboundedSource in this PoC. + # tighter coder. Not wired into ReadFromUnboundedSource in this PoC -- + # this method is kept as a forward-compat hook so subclasses written + # against the API today will Just Work when wiring lands in W2. return coders.registry.get_coder(object) @@ -202,46 +261,70 @@ class _UnboundedSourceRestriction(object): Holds only serializable state -- never a live reader. Mirrors Java's ``UnboundedSourceRestriction(source, checkpoint, watermark)`` plus an explicit - ``is_done`` flag for the terminal (MAX-watermark) transition. + ``is_done`` flag for the terminal (MAX-watermark) transition and a separate + ``finalization_checkpoint_mark`` so a done primary can carry a commit-hook + without polluting the RESUME-state ``checkpoint_mark`` (matches W1 design + doc v5 finding: combining the two channels causes resume/commit semantic + confusion). + + Field roles: + * ``checkpoint_mark`` -- RESUME state. A reader rebuilt from this mark + MUST produce the FIRST record strictly AFTER it (i.e. no re-delivery). + * ``finalization_checkpoint_mark`` -- COMMIT hook. Only set on a done + primary that was just cut this bundle. Registered with the runner's + bundle finalizer to acknowledge upstream (e.g. ack Pub/Sub messages). + Independent of ``checkpoint_mark`` so a done primary's resume state can be + ``None`` while still recording what should be acked.
""" source: UnboundedSource checkpoint_mark: Optional[CheckpointMark] = None watermark: Timestamp = MIN_TIMESTAMP is_done: bool = False + finalization_checkpoint_mark: Optional[CheckpointMark] = None class _UnboundedSourceRestrictionCoder(Coder): - """Encodes :class:`_UnboundedSourceRestriction`. + """Encodes :class:`_UnboundedSourceRestriction` as a fixed 5-tuple. - Shape mirrors Java's ``UnboundedSourceRestrictionCoder``: pickled source + - nullable checkpoint (encoded with the source's own checkpoint coder) + - watermark + done flag. + Shape: pickled source + nullable resume checkpoint (encoded with the + source's own checkpoint coder if provided, else pickle) + watermark + + done flag + nullable finalization checkpoint (same coder as resume). """ def __init__(self, checkpoint_mark_coder: Optional[Coder] = None): nullable_checkpoint = NullableCoder( checkpoint_mark_coder or _MemoizingPickleCoder()) self._tuple_coder = TupleCoder(( _MemoizingPickleCoder(), # source - nullable_checkpoint, # checkpoint_mark (may be None) + nullable_checkpoint, # checkpoint_mark (RESUME state, may be None) TimestampCoder(), # watermark - BooleanCoder())) # is_done + BooleanCoder(), # is_done + nullable_checkpoint)) # finalization_checkpoint_mark (commit hook) def encode(self, restriction: '_UnboundedSourceRestriction') -> bytes: return self._tuple_coder.encode(( restriction.source, restriction.checkpoint_mark, restriction.watermark, - restriction.is_done)) + restriction.is_done, + restriction.finalization_checkpoint_mark)) def decode(self, encoded: bytes) -> '_UnboundedSourceRestriction': - source, checkpoint_mark, watermark, is_done = self._tuple_coder.decode( - encoded) + (source, checkpoint_mark, watermark, is_done, + finalization_checkpoint_mark) = self._tuple_coder.decode(encoded) return _UnboundedSourceRestriction( - source, checkpoint_mark, watermark, is_done) + source=source, + checkpoint_mark=checkpoint_mark, + watermark=watermark, + is_done=is_done, + 
finalization_checkpoint_mark=finalization_checkpoint_mark) def is_deterministic(self) -> bool: # The source and checkpoint are pickled, which is not guaranteed # deterministic; matches the bounded SDF restriction coder in iobase.py. + # NOTE on forward-compat: the wire format is a fixed 5-tuple. Adding a + # 6th field in a future version would break decoding of in-flight blobs + # from older workers. If/when another field is needed, switch this to a + # length-prefixed or version-tagged encoding -- out of scope for W1. return False @@ -267,6 +350,9 @@ def __init__( self._reader = None # type: Optional[UnboundedReader] self._started = False # True once a checkpoint has been cut this bundle (EOF or self-checkpoint). + # Today this co-varies with `_restriction.is_done` (both set together at + # EOF and at try_split); kept separate as a forward-compat hook so a + # future refactor can checkpoint without finishing the restriction. self._checkpoint_taken = False def _ensure_reader(self) -> None: @@ -274,12 +360,62 @@ def _ensure_reader(self) -> None: self._reader = self._restriction.source.create_reader( self._options, self._restriction.checkpoint_mark) + def _close_reader_if_open(self) -> None: + """Idempotent reader release. Called by the EOF and split paths, and by + the DoFn's ``finally`` so an exception inside ``process()`` does not leak + sockets / file descriptors held by the live :class:`UnboundedReader`. + """ + if self._reader is None: + return + try: + self._reader.close() + except Exception: # pylint: disable=broad-except + _LOGGER.warning('Error closing UnboundedReader', exc_info=True) + finally: + self._reader = None + def current_restriction(self) -> _UnboundedSourceRestriction: return self._restriction def try_claim(self, out: List[Any]) -> bool: - # 'out' is a one-element holder -- the only channel back to process(). It - # receives either (value, timestamp) or the _NO_DATA sentinel. + """Advances the underlying reader by one record. 
+ + Holder protocol: ``out[0]`` receives either + ``(value, record_timestamp, source_watermark)`` on the has-data path, or + the :data:`_NO_DATA` sentinel on no-data-now / EOF / already-done paths. + + The watermark in the has-data tuple is the SOURCE'S reported watermark + (``reader.get_watermark()``), not the record's event time -- matching + Java's ``UnboundedSourceValue.getWatermark()`` (Read.java:594). The + DoFn uses ``source_watermark`` to advance the output PCollection's + watermark and uses ``record_timestamp`` only to label the emitted + ``TimestampedValue``. Conflating the two would freeze the PCollection + watermark at the last record's event time, breaking out-of-order + sources and starving downstream event-time windows. + + Contract drift note: the ``RestrictionTracker`` ABC defines + ``try_claim(position)`` where ``position`` identifies a split point. We + instead use the argument as a one-element output holder, like Java's + ``tryClaim(UnboundedSourceValue[] out)``. The + ``ThreadsafeRestrictionTracker`` / ``RestrictionTrackerView`` chain + forwards the value opaquely (sdf_utils.py:75, sdf_utils.py:171), so the + mutation is visible across the lock. + + Exception safety: any exception from ``reader.start()`` / ``advance()`` / + ``get_watermark()`` etc. closes the reader before re-raising, so the + DoFn's ``finally`` does not need to traverse the SDF wrapper chain on + reader-method failures. + """ + try: + return self._try_claim_inner(out) + except Exception: + # Reader is in an indeterminate state; release its resources before + # the exception bubbles to the DoFn (which can't trust ``self._reader`` + # anymore). + self._close_reader_if_open() + raise + + def _try_claim_inner(self, out: List[Any]) -> bool: if self._restriction.is_done: out[0] = _NO_DATA return False @@ -293,24 +429,26 @@ def try_claim(self, out: List[Any]) -> bool: # A record is available: always emit it. 
Inspect has_data BEFORE the # watermark, because a reader may return its final record and report a # MAX_TIMESTAMP watermark on the *same* call (meaning "nothing after - # this"). That EOF is realized on the next, data-less claim, so the record - # we just read is never dropped. (We do not touch self._restriction on the - # data path, so its identity is preserved for the finalization gate.) + # this"). That EOF is realized on the next, data-less claim, so the + # record we just read is never dropped. Capture the source watermark + # alongside the record (see method docstring for why). out[0] = ( - self._reader.get_current(), self._reader.get_current_timestamp()) + self._reader.get_current(), + self._reader.get_current_timestamp(), + self._reader.get_watermark()) return True watermark = self._reader.get_watermark() if watermark >= MAX_TIMESTAMP: # No data and watermark at MAX: the reader permanently finished. Cut a # final checkpoint, close, and mark the restriction done. checkpoint = self._reader.get_checkpoint_mark() - self._reader.close() - self._reader = None + self._close_reader_if_open() self._restriction = dataclasses.replace( self._restriction, - checkpoint_mark=checkpoint, + checkpoint_mark=None, # nothing left to resume from watermark=MAX_TIMESTAMP, - is_done=True) + is_done=True, + finalization_checkpoint_mark=checkpoint) self._checkpoint_taken = True out[0] = _NO_DATA return False @@ -325,25 +463,52 @@ def try_split( self, fraction_of_remainder ) -> Optional[Tuple[_UnboundedSourceRestriction, _UnboundedSourceRestriction]]: + """Cuts a checkpoint, returning (primary, residual) or None. + + The cut checkpoint goes into ``primary.finalization_checkpoint_mark`` so + the DoFn can register a bundle-finalize callback for it. The same + checkpoint object also goes into ``residual.checkpoint_mark`` so the + resumed reader rebuilds at the correct position. 
The two fields are + independent on purpose (see :class:`_UnboundedSourceRestriction` + docstring): a runner that re-processes the primary alone must not see + a stale resume state, and a residual must not register finalize again + until ITS checkpoint is cut in a future bundle. + """ + try: + return self._try_split_inner(fraction_of_remainder) + except Exception: + self._close_reader_if_open() + raise + + def _try_split_inner(self, fraction_of_remainder): # fraction 0 is the self-checkpoint raised by defer_remainder(); any other # fraction cuts the same checkpoint. Returns (primary, residual) or None. if self._reader is None or not self._started or self._restriction.is_done: return None checkpoint = self._reader.get_checkpoint_mark() watermark = self._reader.get_watermark() - # Primary is finished work; it carries the checkpoint only so the DoFn can - # register finalization. The residual carries the resume state. + # Primary represents work just finished THIS bundle; it carries ONLY the + # finalize hook. checkpoint_mark on primary is None to make it obvious + # that the primary has no resume state of its own. primary = dataclasses.replace( - self._restriction, checkpoint_mark=checkpoint, is_done=True) + self._restriction, + checkpoint_mark=None, + is_done=True, + finalization_checkpoint_mark=checkpoint) + # Residual represents future work; it carries ONLY the resume state. + # finalization_checkpoint_mark is None so a future bundle does not + # double-register finalize for the same checkpoint. residual = _UnboundedSourceRestriction( source=self._restriction.source, checkpoint_mark=checkpoint, watermark=watermark, - is_done=False) + is_done=False, + finalization_checkpoint_mark=None) self._restriction = primary self._checkpoint_taken = True - # The residual's reader is rebuilt from its checkpoint on resume; drop ours. 
- self._reader = None + # The residual's reader is rebuilt from its checkpoint on resume; close + # ours rather than dropping a live handle on the floor. + self._close_reader_if_open() return primary, residual def check_done(self) -> bool: @@ -391,11 +556,16 @@ def create_tracker( def split(self, element, restriction) -> Iterable[_UnboundedSourceRestriction]: - # Minimal PoC: no initial fan-out. Real splitting is future work. + # Minimal PoC: no initial fan-out. ``desired_num_splits`` is *not* honored + # and ``UnboundedSource.split(desired_num_splits, options)`` is currently + # dead code from this provider's perspective. Real splitting (one + # restriction per sub-source, e.g. one per Kafka partition) is W2 work. yield restriction def restriction_size(self, element, restriction) -> int: # Backlog estimation is out of scope; report a constant non-negative size. + # This blinds Dataflow's auto-scaler and Flink's work-stealing to per- + # restriction load -- documented gap for #19137. return 1 def restriction_coder(self) -> Coder: @@ -406,9 +576,12 @@ def truncate(self, element, restriction): return None -def _set_watermark_if_greater(watermark_estimator, new_watermark: Timestamp): +def _set_watermark_if_greater( + watermark_estimator, new_watermark: Timestamp) -> None: # ManualWatermarkEstimator.set_watermark raises if the watermark regresses, so - # only ever advance it (mirrors PeriodicSequence's monotonic guard). + # only ever advance it (mirrors PeriodicSequence's monotonic guard). A + # misbehaving reader that reports a regressing watermark is silently absorbed + # here -- intentional, to keep the pipeline running through reader bugs. 
current = watermark_estimator.current_watermark() if current is None or new_watermark > current: watermark_estimator.set_watermark(new_watermark) @@ -431,6 +604,14 @@ def __init__( poll_interval_seconds: float = _DEFAULT_POLL_INTERVAL_SECONDS): if not isinstance(source, UnboundedSource): raise TypeError('source must be an UnboundedSource, got %r' % (source, )) + if poll_interval_seconds <= 0: + # A zero / negative poll interval would either busy-spin on no-data + # polls (poll_interval=0 -> defer_remainder(Duration(0)) -> immediate + # re-schedule) or pass a negative ``Duration`` to the runner which is + # not well-defined. Mirror Java's IllegalArgumentException posture. + raise ValueError( + 'poll_interval_seconds must be > 0, got %r' % + (poll_interval_seconds, )) super().__init__() self._source = source self._poll_interval_seconds = poll_interval_seconds @@ -441,7 +622,15 @@ def expand(self, pbegin): provider = _UnboundedSourceRestrictionProvider( checkpoint_mark_coder=source.get_checkpoint_mark_coder()) + # The DoFn is defined inside ``expand`` so it can close over the + # source-specific ``provider`` (which holds the source's checkpoint coder) + # and the user-tuned ``poll_interval_seconds``. Lifting it to module level + # would require a stateless provider (losing per-source checkpoint coder + # selection), so this is a deliberate trade-off. Cloudpickle, Beam's + # default, handles closure-defined classes; stdlib ``pickle`` does not. class _ReadFromUnboundedSourceDoFn(core.DoFn): + """SDF wrapper driving an :class:`UnboundedReader` for one restriction.""" + @core.DoFn.unbounded_per_element() def process( self, @@ -460,7 +649,14 @@ def process( while True: holder = [None] if not tracker.try_claim(holder): - break # restriction done (EOF) -> stop + # EOF (restriction is_done==True, watermark already set to MAX in + # the tracker). 
Mirrors Java Read.java:625 -- advance the + # watermark estimator unconditionally on the terminal path so + # downstream event-time windows can close, otherwise the + # estimator would stay at the last reported watermark. + _set_watermark_if_greater( + watermark_estimator, tracker.current_restriction().watermark) + break record = holder[0] if record is _NO_DATA: # No data right now: advance the watermark and self-checkpoint so @@ -470,17 +666,44 @@ def process( watermark_estimator, tracker.current_restriction().watermark) tracker.defer_remainder(Duration(seconds=poll_interval_seconds)) break - value, record_timestamp = record - _set_watermark_if_greater(watermark_estimator, record_timestamp) + # Data path: advance the estimator with the SOURCE's reported + # watermark (third tuple slot), NOT the record's event time. + # Mirrors Java Read.java:594. The record's event time is used + # only as the TimestampedValue label so the downstream sees the + # real per-record timestamp. + value, record_timestamp, source_watermark = record + _set_watermark_if_greater(watermark_estimator, source_watermark) yield TimestampedValue(value, record_timestamp) finally: current = tracker.current_restriction() - checkpoint = current.checkpoint_mark # Register finalization only when a real checkpoint was cut this # bundle. Restriction identity (`current is not initial`) mirrors - # Java's reference-equality gate in Read.java. - if current is not initial and checkpoint is not None: - bundle_finalizer.register(checkpoint.finalize_checkpoint) + # Java's reference-equality gate in Read.java. We read the explicit + # finalization channel, NOT ``checkpoint_mark`` (which is the + # RESUME state and may belong to the residual after a split). 
+ finalize_mark = current.finalization_checkpoint_mark + if current is not initial and finalize_mark is not None: + bundle_finalizer.register(finalize_mark.finalize_checkpoint) + # Release the underlying reader on every exit path, including the + # exception path where a downstream yield raised between two + # try_claim calls (reader-method failures are already closed inside + # the tracker). ``RestrictionTrackerView`` does not expose the inner + # tracker, so traverse the (stable-but-private) wrapper chain. If + # the chain changes in a future Beam version we log a warning and + # let GC eventually close; never call ``close`` on an unrelated + # tracker subclass. + threadsafe = getattr( + tracker, '_threadsafe_restriction_tracker', None) + inner_tracker = getattr( + threadsafe, '_restriction_tracker', None) + if isinstance(inner_tracker, _UnboundedSourceRestrictionTracker): + inner_tracker._close_reader_if_open() + elif inner_tracker is not None or threadsafe is not None: + _LOGGER.warning( + 'UnboundedSource DoFn could not reach the inner tracker via ' + '_threadsafe_restriction_tracker._restriction_tracker; reader ' + 'close on exception path skipped, relying on GC. 
Beam SDF ' + 'wrapper internals may have changed -- file an issue.') return ( pbegin diff --git a/sdks/python/apache_beam/io/unbounded_source_test.py b/sdks/python/apache_beam/io/unbounded_source_test.py index 1ee53a61f123..5ec6a90bc0e7 100644 --- a/sdks/python/apache_beam/io/unbounded_source_test.py +++ b/sdks/python/apache_beam/io/unbounded_source_test.py @@ -26,10 +26,16 @@ # pytype: skip-file +import gc import logging +import os +import tempfile +import threading import unittest +import apache_beam as beam from apache_beam import coders +from apache_beam.io import unbounded_source as _unbounded_source_module from apache_beam.io.unbounded_source import CheckpointMark from apache_beam.io.unbounded_source import ReadFromUnboundedSource from apache_beam.io.unbounded_source import UnboundedReader @@ -43,6 +49,7 @@ from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.transforms.window import FixedWindows from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.timestamp import Timestamp @@ -85,6 +92,7 @@ def __init__(self, count, start_index, finalize_log=None): self._current = None self._exhausted = False self._finalize_log = finalize_log + self.closed = False def _read_next(self): if self._next >= self._count: @@ -117,19 +125,24 @@ def get_checkpoint_mark(self): last = self._current if self._current is not None else self._next - 1 return _CountingCheckpointMark(last, finalize_log=self._finalize_log) + def close(self): + self.closed = True + class CountingSource(UnboundedSource): def __init__(self, count, finalize_log=None): self._count = count self._finalize_log = finalize_log + self.last_reader = None def split(self, desired_num_splits, options=None): return [self] def create_reader(self, options, checkpoint_mark): start_index = 0 if checkpoint_mark is None else 
checkpoint_mark.last_index + 1 - return _CountingReader( + self.last_reader = _CountingReader( self._count, start_index, finalize_log=self._finalize_log) + return self.last_reader def get_checkpoint_mark_coder(self): return coders.PickleCoder() @@ -167,6 +180,112 @@ def get_checkpoint_mark_coder(self): return coders.PickleCoder() +# Module-level helpers so they pickle cleanly across Beam's worker boundary. +# The DoFnReaderCloseOnExceptionTest uses ``_set_close_marker`` to install a +# tempfile path under a lock (so concurrent test runners cannot race on it), +# then waits for the reader's close() to write to it. +_READER_CLOSE_MARKER = None # set under _READER_CLOSE_MARKER_LOCK +_READER_CLOSE_MARKER_LOCK = threading.Lock() + + +def _set_close_marker(path): + with _READER_CLOSE_MARKER_LOCK: + global _READER_CLOSE_MARKER + _READER_CLOSE_MARKER = path + + +def _read_close_marker(): + with _READER_CLOSE_MARKER_LOCK: + return _READER_CLOSE_MARKER + + +class _RaisingReader(UnboundedReader): + def start(self): + return True # first record available + + def advance(self): + raise RuntimeError('reader.advance() boom') + + def get_current(self): + return 'rec' + + def get_current_timestamp(self): + return Timestamp(0) + + def get_watermark(self): + return Timestamp(0) + + def get_checkpoint_mark(self): + return _CountingCheckpointMark(0) + + def close(self): + path = _read_close_marker() + if path is not None: + with open(path, 'a') as fp: + fp.write('closed\n') + + +class _RaisingSource(UnboundedSource): + def split(self, desired_num_splits, options=None): + return [self] + + def create_reader(self, options, checkpoint_mark): + return _RaisingReader() + + def get_checkpoint_mark_coder(self): + return coders.PickleCoder() + + +# A non-raising marker-aware source for testing DoFn-side close on the +# *downstream* yield-raise path (where the source itself is well-behaved but a +# downstream Map raises mid-bundle). Module-level for cloudpickle. 
+class _MarkerCloseReader(UnboundedReader): + def __init__(self): + self._idx = -1 + + def start(self): + self._idx = 0 + return True + + def advance(self): + self._idx += 1 + return self._idx < 3 + + def get_current(self): + return self._idx + + def get_current_timestamp(self): + return Timestamp(self._idx) + + def get_watermark(self): + return Timestamp(self._idx) if self._idx < 2 else MAX_TIMESTAMP + + def get_checkpoint_mark(self): + return _CountingCheckpointMark(self._idx) + + def close(self): + path = _read_close_marker() + if path is not None: + with open(path, 'a') as fp: + fp.write('closed\n') + + +class _MarkerCloseSource(UnboundedSource): + def split(self, desired_num_splits, options=None): + return [self] + + def create_reader(self, options, checkpoint_mark): + return _MarkerCloseReader() + + def get_checkpoint_mark_coder(self): + return coders.PickleCoder() + + +def _downstream_boom(_unused): + """Module-level so it pickles cleanly through Beam's bundle worker boundary.""" + raise RuntimeError('downstream boom') + + def _new_tracker(source, checkpoint=None): restriction = _UnboundedSourceRestriction( source=source, checkpoint_mark=checkpoint) @@ -304,7 +423,13 @@ def test_try_split_zero_produces_resumable_residual(self): primary, residual = split self.assertTrue(primary.is_done) self.assertFalse(residual.is_done) + # Resume / finalize channel separation: primary carries only the + # finalize hook, residual carries only the resume state. + self.assertIsNone(primary.checkpoint_mark) + self.assertIsNotNone(primary.finalization_checkpoint_mark) + self.assertEqual(primary.finalization_checkpoint_mark.last_index, 1) self.assertEqual(residual.checkpoint_mark.last_index, 1) + self.assertIsNone(residual.finalization_checkpoint_mark) # check_done passes on the (now done) primary. 
self.assertTrue(tracker.check_done()) @@ -345,13 +470,15 @@ class FinalizationTest(unittest.TestCase): def test_finalize_checkpoint_invoked(self): # Authoritative finalize test at the unit level: the e2e finalize may run in # a worker process, so its side effect is not observable from the test. + # The finalize hook lives on the PRIMARY (commit channel), independent of + # the residual's resume state. finalize_log = [] source = CountingSource(5, finalize_log=finalize_log) tracker = _new_tracker(source) _claim(tracker) # 0 _claim(tracker) # 1 - _, residual = tracker.try_split(0) - residual.checkpoint_mark.finalize_checkpoint() + primary, _ = tracker.try_split(0) + primary.finalization_checkpoint_mark.finalize_checkpoint() self.assertEqual(finalize_log, [1]) @@ -362,6 +489,569 @@ def test_direct_runner_emits_all_in_order(self): self.assertFalse(out.is_bounded) assert_that(out, equal_to([0, 1, 2, 3, 4])) + def test_eof_lets_event_time_window_fire(self): + # Regression for the EOF-watermark fix: the DoFn must advance the watermark + # estimator to MAX_TIMESTAMP on the terminal claim so downstream FixedWindow + # closes. Without that advance the GBK below never fires and assert_that + # observes an empty output. + with TestPipeline() as p: + out = ( + p + | ReadFromUnboundedSource(CountingSource(5)) + | beam.WindowInto(FixedWindows(100)) + | beam.Map(lambda v: ('all', v)) + | beam.GroupByKey() + | beam.MapTuple(lambda _key, values: sorted(values))) + assert_that(out, equal_to([[0, 1, 2, 3, 4]])) + + def test_read_dispatches_through_iobase_read(self): + # Parity check: `beam.io.Read(unbounded_source)` must produce the same + # records as `ReadFromUnboundedSource(unbounded_source)`. The dispatch is + # the elif branch added to iobase.Read.expand(). 
+ with TestPipeline() as p: + out = p | beam.io.Read(CountingSource(5)) + self.assertFalse(out.is_bounded) + assert_that(out, equal_to([0, 1, 2, 3, 4])) + + +# ------------------------------------------------------------------------------ +# Regression tests for the BLOCKER fixes (EOF watermark, reader close on every +# exit path) plus contract regressions (NotImplementedError message, +# finalize_checkpoint idempotency). +# ------------------------------------------------------------------------------ + + +class ReaderCloseTest(unittest.TestCase): + """Reader lifecycle: close() must run on every tracker-driven exit path.""" + + def test_tracker_closes_reader_on_eof(self): + source = CountingSource(0) # immediately exhausted + tracker = _new_tracker(source) + holder = [None] + self.assertFalse(tracker.try_claim(holder)) + self.assertIsNone(tracker._reader) + self.assertTrue(source.last_reader.closed) + + def test_tracker_closes_reader_on_split(self): + source = CountingSource(5) + tracker = _new_tracker(source) + _claim(tracker) # creates reader, claims 0 + reader = source.last_reader + self.assertFalse(reader.closed) + split = tracker.try_split(0) + self.assertIsNotNone(split) + self.assertIsNone(tracker._reader) + self.assertTrue(reader.closed) + + def test_close_helper_is_idempotent_and_safe_on_empty_tracker(self): + tracker = _new_tracker(CountingSource(3)) + # No reader yet -- helper must be a no-op. + tracker._close_reader_if_open() + _claim(tracker) + reader = tracker._reader + tracker._close_reader_if_open() + self.assertTrue(reader.closed) + self.assertIsNone(tracker._reader) + # Second call is a no-op (no reader to close). 
+ tracker._close_reader_if_open() + + def test_close_helper_swallows_reader_close_errors(self): + class _BoomReader(UnboundedReader): + def start(self): + return True + + def advance(self): + return False + + def get_current(self): + return 'x' + + def get_current_timestamp(self): + return Timestamp(0) + + def get_watermark(self): + return Timestamp(0) + + def get_checkpoint_mark(self): + return CheckpointMark() + + def close(self): + raise RuntimeError('close blew up') + + class _BoomSource(UnboundedSource): + def split(self, desired_num_splits, options=None): + return [self] + + def create_reader(self, options, checkpoint_mark): + return _BoomReader() + + def get_checkpoint_mark_coder(self): + return coders.PickleCoder() + + tracker = _new_tracker(_BoomSource()) + _claim(tracker) + # Helper must not propagate the reader's close() exception, otherwise the + # DoFn's finally / split paths would mask the original error. + tracker._close_reader_if_open() + self.assertIsNone(tracker._reader) + + +class BestPracticeRegressionTest(unittest.TestCase): + """Regression guards for the round-2 best-practice fixes: + B1: data-path watermark uses source.get_watermark(), not record event time + B2: finalization_checkpoint_mark separate from resume checkpoint_mark + H4: tracker-internal exception close on reader-method failure + """ + + def test_b1_data_path_holder_carries_source_watermark(self): + """The holder's 3rd slot is the SOURCE's reported watermark, not the + record's event time. A reader that reports event time 1000 with a source + watermark of 990 (out-of-order data) must surface 990 to the wrapper, not + 1000. 
+ """ + class _LaggingReader(UnboundedReader): + def start(self): + return True + + def advance(self): + return False + + def get_current(self): + return 'rec' + + def get_current_timestamp(self): + return Timestamp(1000) # record event time + + def get_watermark(self): + return Timestamp(990) # source watermark lags 10s behind + + def get_checkpoint_mark(self): + return _CountingCheckpointMark(0) + + class _LaggingSource(UnboundedSource): + def split(self, desired_num_splits, options=None): + return [self] + + def create_reader(self, options, checkpoint_mark): + return _LaggingReader() + + def get_checkpoint_mark_coder(self): + return coders.PickleCoder() + + tracker = _new_tracker(_LaggingSource()) + claimed, record = _claim(tracker) + self.assertTrue(claimed) + self.assertIsNot(record, _NO_DATA) + value, record_timestamp, source_watermark = record + self.assertEqual(value, 'rec') + self.assertEqual(record_timestamp, Timestamp(1000)) + # Critical: watermark slot is the SOURCE watermark, NOT record timestamp. + self.assertEqual(source_watermark, Timestamp(990)) + self.assertNotEqual(source_watermark, record_timestamp) + + def test_b2_split_separates_finalize_and_resume_channels(self): + source = CountingSource(5) + tracker = _new_tracker(source) + _claim(tracker) # claim 0 so reader has progress + primary, residual = tracker.try_split(0) + # Primary carries ONLY the finalize hook -- no resume state. + self.assertIsNone(primary.checkpoint_mark) + self.assertIsNotNone(primary.finalization_checkpoint_mark) + self.assertTrue(primary.is_done) + # Residual carries ONLY the resume state -- no finalize hook (a future + # bundle that splits THIS residual will produce ITS own finalize mark). + self.assertIsNotNone(residual.checkpoint_mark) + self.assertIsNone(residual.finalization_checkpoint_mark) + self.assertFalse(residual.is_done) + # The two marks reference the same underlying checkpoint object. 
+ self.assertEqual( + primary.finalization_checkpoint_mark.last_index, + residual.checkpoint_mark.last_index) + + def test_b2_eof_populates_finalize_and_clears_resume(self): + # EOF transition: restriction.checkpoint_mark goes to None (no more + # records to resume from), finalization_checkpoint_mark carries the + # final commit hook. + source = CountingSource(0) # immediately exhausted + tracker = _new_tracker(source) + holder = [None] + self.assertFalse(tracker.try_claim(holder)) + r = tracker.current_restriction() + self.assertTrue(r.is_done) + self.assertEqual(r.watermark, MAX_TIMESTAMP) + self.assertIsNone(r.checkpoint_mark) + self.assertIsNotNone(r.finalization_checkpoint_mark) + + def test_h4_tracker_closes_reader_when_advance_raises(self): + # If reader.advance() raises, the tracker's try_claim wraps it and + # closes the reader BEFORE re-raising. The DoFn's finally does not need + # to traverse the private SDF chain for reader-method failures. + class _BoomReader(UnboundedReader): + def __init__(self): + self.closed = False + + def start(self): + return True + + def advance(self): + raise RuntimeError('advance boom') + + def get_current(self): + return 'first' + + def get_current_timestamp(self): + return Timestamp(0) + + def get_watermark(self): + return Timestamp(0) + + def get_checkpoint_mark(self): + return _CountingCheckpointMark(0) + + def close(self): + self.closed = True + + class _BoomSource(UnboundedSource): + def split(self, desired_num_splits, options=None): + return [self] + + def create_reader(self, options, checkpoint_mark): + return _BoomReader() + + def get_checkpoint_mark_coder(self): + return coders.PickleCoder() + + src = _BoomSource() + tracker = _new_tracker(src) + # First claim succeeds (start returns True). + self.assertTrue(tracker.try_claim([None])) + reader_after_first = tracker._reader + self.assertIsNotNone(reader_after_first) + # Second claim invokes advance() which raises. 
Tracker must close the + # reader before propagating the exception. + with self.assertRaises(RuntimeError): + tracker.try_claim([None]) + self.assertTrue(reader_after_first.closed) + self.assertIsNone(tracker._reader) + + def test_h4_tracker_closes_reader_when_get_watermark_raises(self): + # Reader method failures other than advance() also trigger close. + class _WatermarkBoomReader(UnboundedReader): + def __init__(self): + self.closed = False + + def start(self): + return False # no data -> drops into get_watermark path + + def advance(self): + return False + + def get_current(self): + raise AssertionError + + def get_current_timestamp(self): + raise AssertionError + + def get_watermark(self): + raise RuntimeError('watermark boom') + + def get_checkpoint_mark(self): + return _CountingCheckpointMark(0) + + def close(self): + self.closed = True + + class _WatermarkBoomSource(UnboundedSource): + def split(self, desired_num_splits, options=None): + return [self] + + def create_reader(self, options, checkpoint_mark): + return _WatermarkBoomReader() + + def get_checkpoint_mark_coder(self): + return coders.PickleCoder() + + src = _WatermarkBoomSource() + tracker = _new_tracker(src) + with self.assertRaises(RuntimeError): + tracker.try_claim([None]) + self.assertIsNone(tracker._reader) + + +class UnboundedSourceContractTest(unittest.TestCase): + def test_get_checkpoint_mark_coder_default_names_subclass(self): + class MySource(UnboundedSource): + pass + + with self.assertRaises(NotImplementedError) as cm: + MySource().get_checkpoint_mark_coder() + self.assertIn('MySource', str(cm.exception)) + + def test_default_finalize_is_idempotent(self): + mark = CheckpointMark() + # Default no-op must tolerate repeated invocation; the SDK's bundle + # finalizer makes no exactly-once guarantee on this callback. 
+ self.assertIsNone(mark.finalize_checkpoint()) + self.assertIsNone(mark.finalize_checkpoint()) + + +class ReadFromUnboundedSourceValidationTest(unittest.TestCase): + def test_non_source_argument_raises(self): + with self.assertRaises(TypeError): + ReadFromUnboundedSource('not-a-source') # type: ignore[arg-type] + + def test_poll_interval_must_be_positive(self): + src = CountingSource(3) + with self.assertRaises(ValueError): + ReadFromUnboundedSource(src, poll_interval_seconds=0) + with self.assertRaises(ValueError): + ReadFromUnboundedSource(src, poll_interval_seconds=-1) + with self.assertRaises(ValueError): + ReadFromUnboundedSource(src, poll_interval_seconds=-0.5) + # Positive values OK. + ReadFromUnboundedSource(src, poll_interval_seconds=0.001) + ReadFromUnboundedSource(src, poll_interval_seconds=60) + + +class CloudpicklePicklabilityTest(unittest.TestCase): + """The DoFn class is defined inside ``ReadFromUnboundedSource.expand`` so it + can close over the source-specific provider. Beam's default pickler is + cloudpickle; stdlib pickle would fail on a closure-defined class. This is a + regression guard for cross-runner portability (Dataflow / Flink portable + workers also use cloudpickle). 
+ """ + + def test_transform_round_trips_through_cloudpickle(self): + from apache_beam.internal import pickler + transform = ReadFromUnboundedSource(CountingSource(5)) + blob = pickler.dumps(transform) + self.assertIsInstance(blob, bytes) + restored = pickler.loads(blob) + self.assertIsInstance(restored, ReadFromUnboundedSource) + + def test_source_object_round_trips_through_cloudpickle(self): + from apache_beam.internal import pickler + src = CountingSource(5) + restored = pickler.loads(pickler.dumps(src)) + self.assertIsInstance(restored, CountingSource) + self.assertEqual(restored._count, 5) + + +class CircularImportOrderTest(unittest.TestCase): + """`iobase.py` and `unbounded_source.py` form a cycle (UnboundedSource extends + iobase.SourceBase; iobase.Read.expand lazy-imports unbounded_source). All + three import-order scenarios must complete without ImportError. Subprocesses + ensure each test starts from a clean module cache. + """ + + def _run_in_subprocess(self, script): + import subprocess + import sys + import os + env = os.environ.copy() + beam_python = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath( + __file__))))) + env['PYTHONPATH'] = beam_python + os.pathsep + env.get('PYTHONPATH', '') + fd, path = tempfile.mkstemp(suffix='.py') + try: + with os.fdopen(fd, 'w') as fp: + fp.write(script) + return subprocess.run( + [sys.executable, path], + capture_output=True, + text=True, + env=env, + timeout=60) + finally: + if os.path.exists(path): + os.unlink(path) + + def test_iobase_then_unbounded_source(self): + result = self._run_in_subprocess( + 'import apache_beam.io.iobase\n' + 'import apache_beam.io.unbounded_source\n' + 'print("ok")\n') + self.assertEqual( + result.returncode, 0, + 'stderr=%r stdout=%r' % (result.stderr, result.stdout)) + self.assertIn('ok', result.stdout) + + def test_unbounded_source_then_iobase(self): + result = self._run_in_subprocess( + 'import apache_beam.io.unbounded_source\n' + 'import 
apache_beam.io.iobase\n' + 'print("ok")\n') + self.assertEqual( + result.returncode, 0, + 'stderr=%r stdout=%r' % (result.stderr, result.stdout)) + self.assertIn('ok', result.stdout) + + def test_read_expand_lazy_imports_unbounded_source(self): + # Import only iobase, then trigger Read.expand() on an UnboundedSource. + # The expand() must lazy-import unbounded_source without ImportError and + # produce a valid expanded transform tree. + script = ''' +import sys +import apache_beam as beam +from apache_beam import coders +import apache_beam.io.iobase as iobase +# Now import unbounded_source AFTER iobase, then verify Read.expand +# successfully lazy-imports ReadFromUnboundedSource: +from apache_beam.io.unbounded_source import UnboundedSource + +class _S(UnboundedSource): + def split(self, n, options=None): + return [self] + def create_reader(self, o, cp): + return None + def get_checkpoint_mark_coder(self): + return coders.PickleCoder() + +r = iobase.Read(_S()) +p = beam.Pipeline() +result = r.expand(p) +assert not result.is_bounded, 'expanded PCollection should be unbounded' +print("ok") +''' + result = self._run_in_subprocess(script) + self.assertEqual( + result.returncode, 0, + 'stderr=%r stdout=%r' % (result.stderr, result.stdout)) + self.assertIn('ok', result.stdout) + + +class DoFnReaderCloseOnDownstreamRaiseTest(unittest.TestCase): + """H4 second half: tracker-internal exception close (already tested in + ``BestPracticeRegressionTest.test_h4_*``) handles reader-method failures. + This test covers the OTHER half -- the source is well-behaved but a + downstream transform raises during ``yield``, so the exception happens + AFTER ``try_claim`` returns with a live reader. The DoFn's ``finally`` + must close it via the private SDF chain. 
+ """ + + def test_dofn_finally_closes_reader_when_downstream_yield_raises(self): + marker = _new_marker_path('.downstream.close.log') + _set_close_marker(marker) + try: + raised = False + try: + with beam.Pipeline() as p: + _ = ( + p + | beam.io.Read(_MarkerCloseSource()) + | 'BoomMap' >> beam.Map(_downstream_boom)) + except Exception: # pylint: disable=broad-except + raised = True + gc.collect() + self.assertTrue( + raised, + 'pipeline did not surface the downstream Map exception') + self.assertTrue( + os.path.exists(marker), + 'DoFn finally did not invoke reader.close() on the downstream ' + 'yield-raise path -- reader leaked. Private-chain close in ' + 'unbounded_source.py:expand finally may be broken.') + finally: + _set_close_marker(None) + if os.path.exists(marker): + os.unlink(marker) + + +# ------------------------------------------------------------------------------ +# Stronger regression guards (added after independent second-opinion review). +# The windowed e2e test above is suggestive but not bulletproof, because the +# FnApiRunner watermark manager advances PCollection watermarks to MAX once a +# bundle has no deferred work (fn_runner.py ~819 and ~969). These tests probe +# the DoFn-level behavior directly via file-based side-channels so the BLOCKER +# fixes cannot regress silently. (In-memory closures don't propagate across +# Beam's cloudpickle worker boundary even when the worker runs in the same +# process, so we go through the filesystem.) +# ------------------------------------------------------------------------------ + + +def _new_marker_path(suffix): + """Create a fresh temp file path used as a cross-bundle side-channel. + + Returns a path that does NOT exist (deleted after mkstemp). The DoFn-side + code writes to it; the test reads it back. 
+ """ + fd, path = tempfile.mkstemp(suffix=suffix) + os.close(fd) + os.unlink(path) + return path + + +class DoFnWatermarkAdvanceTest(unittest.TestCase): + """B-1 regression: the DoFn MUST advance the watermark estimator to + MAX_TIMESTAMP on the terminal claim, not rely on the runner's auto-advance. + """ + + def test_eof_invokes_set_watermark_with_max_timestamp(self): + marker = _new_marker_path('.watermarks.log') + + original = _unbounded_source_module._set_watermark_if_greater + + def _recording(estimator, watermark): + # File side-channel: closure variables are deep-copied across Beam's + # bundle boundary even in embedded FnApiRunner; the filesystem is the + # only reliable cross-bundle assertion target. + with open(marker, 'a') as fp: + fp.write(repr(watermark) + '\n') + return original(estimator, watermark) + + _unbounded_source_module._set_watermark_if_greater = _recording + try: + with TestPipeline() as p: + _ = p | ReadFromUnboundedSource(CountingSource(3)) + finally: + _unbounded_source_module._set_watermark_if_greater = original + + try: + with open(marker) as fp: + lines = fp.read().splitlines() + finally: + if os.path.exists(marker): + os.unlink(marker) + + self.assertIn( + repr(MAX_TIMESTAMP), + lines, + '_set_watermark_if_greater was never called with MAX_TIMESTAMP -- ' + 'the EOF branch in process() is not advancing the estimator. ' + 'Captured calls: %r' % (lines, )) + + +class DoFnReaderCloseOnExceptionTest(unittest.TestCase): + """B-2 regression: the DoFn's ``finally`` MUST close the reader even when + ``process()`` raises mid-bundle, otherwise we leak sockets/fds in production. 
+ """ + + def test_reader_close_runs_when_process_raises(self): + marker = _new_marker_path('.close.log') + _set_close_marker(marker) + try: + raised = False + try: + with beam.Pipeline() as p: + _ = p | ReadFromUnboundedSource(_RaisingSource()) + except Exception: # pylint: disable=broad-except + raised = True + self.assertTrue( + raised, 'pipeline did not surface the reader.advance() exception') + # Generator finalisation (which runs the DoFn's ``finally``) may be + # deferred to GC inside Beam's bundle processor; force it here so the + # close-marker is observable. + gc.collect() + self.assertTrue( + os.path.exists(marker), + 'DoFn finally did not invoke reader.close() on the exception path ' + '-- reader leaked.') + finally: + _set_close_marker(None) + if os.path.exists(marker): + os.unlink(marker) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) From 5a278f7624f1dcc0833802ce1f9b4be7d0bff163 Mon Sep 17 00:00:00 2001 From: Eliaazzz Date: Thu, 28 May 2026 01:45:50 +1000 Subject: [PATCH 3/4] fix unbounded source split and coder wiring --- sdks/python/apache_beam/io/iobase.py | 2 +- .../python/apache_beam/io/unbounded_source.py | 61 ++++-- .../apache_beam/io/unbounded_source_test.py | 199 ++++++++++++------ 3 files changed, 186 insertions(+), 76 deletions(-) diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index ee47fb27602b..4fa00cb50552 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -1005,7 +1005,7 @@ def to_runner_api_parameter( if isinstance(self.source, (BoundedSource, UnboundedSource)): # READ.urn covers both source flavours; the IsBounded enum distinguishes # them. 
NB: today the bundle_processor.py IMPULSE_READ_TRANSFORM handler - # only consumes BOUNDED — the UNBOUNDED branch round-trips correctly + # only consumes BOUNDED - the UNBOUNDED branch round-trips correctly # through the protobuf graph but execution still flows through this # composite's expanded sub-transforms (Impulse | Map | SDF-ParDo), not # through the URN-handler. Runner-side UNBOUNDED dispatch is W2 work. diff --git a/sdks/python/apache_beam/io/unbounded_source.py b/sdks/python/apache_beam/io/unbounded_source.py index d05403b132d8..f2f0a2abaa3d 100644 --- a/sdks/python/apache_beam/io/unbounded_source.py +++ b/sdks/python/apache_beam/io/unbounded_source.py @@ -48,8 +48,9 @@ class MySource(UnboundedSource): * Backlog-byte reporting (``restriction_size`` is a constant 1; per-restriction progress is binary 0.0 / 1.0). * Dynamic split fractions / runner-initiated work stealing. - * Initial fan-out: ``RestrictionProvider.split`` ignores ``desired_num_splits`` - and yields a single restriction. + * Initial fan-out uses a fixed default split count (20), matching Java's + wrapper default. There is no public Python SDF hook to pass a runner-chosen + desired split count yet. * Source-specific checkpoint coders threaded through the SDF restriction coder (the restriction coder always pickles checkpoint marks via the source's ``get_checkpoint_mark_coder`` captured once at ``expand()`` time, but no @@ -103,6 +104,7 @@ class MySource(UnboundedSource): _NO_DATA = object() _DEFAULT_POLL_INTERVAL_SECONDS = 1.0 +_DEFAULT_DESIRED_NUM_SPLITS = 20 # ------------------------------------------------------------------------------ # Public abstract base classes (Java semantics, Python names). Following the @@ -202,9 +204,10 @@ def split(self, Each returned sub-source MUST be independent and MUST NOT share mutable state with siblings (the runner may execute them concurrently across workers). Mirrors Java's ``UnboundedSource.split``. 
Note that the current - ``ReadFromUnboundedSource`` PoC ignores ``desired_num_splits`` -- this - method is the public API but is dead code from the SDF wrapper's - perspective until W2. + ``ReadFromUnboundedSource`` calls this during initial SDF splitting with a + fixed default desired split count. Dynamic re-splitting of the source itself + remains out of scope; once a checkpoint exists the wrapper keeps that + restriction intact. """ raise NotImplementedError @@ -242,9 +245,8 @@ def is_bounded(self) -> bool: def default_output_coder(self) -> Coder: # Permissive default, matching BoundedSource (iobase.py). Override for a - # tighter coder. Not wired into ReadFromUnboundedSource in this PoC -- - # this method is kept as a forward-compat hook so subclasses written - # against the API today will Just Work when wiring lands in W2. + # tighter coder. ReadFromUnboundedSource uses this coder when inferring the + # output PCollection type/coder. return coders.registry.get_coder(object) @@ -556,11 +558,33 @@ def create_tracker( def split(self, element, restriction) -> Iterable[_UnboundedSourceRestriction]: - # Minimal PoC: no initial fan-out. ``desired_num_splits`` is *not* honored - # and ``UnboundedSource.split(desired_num_splits, options)`` is currently - # dead code from this provider's perspective. Real splitting (one - # restriction per sub-source, e.g. one per Kafka partition) is W2 work. 
- yield restriction + if restriction.is_done or restriction.checkpoint_mark is not None: + yield restriction + return + + try: + split_sources = list( + restriction.source.split(_DEFAULT_DESIRED_NUM_SPLITS, self._options)) + if not split_sources: + yield restriction + return + for split_source in split_sources: + if not isinstance(split_source, UnboundedSource): + raise TypeError( + 'UnboundedSource.split() produced %r, expected UnboundedSource' % + (split_source, )) + for split_source in split_sources: + yield dataclasses.replace( + restriction, + source=split_source, + checkpoint_mark=None, + is_done=False, + finalization_checkpoint_mark=None) + except Exception: # pylint: disable=broad-except + _LOGGER.warning( + 'Exception while splitting UnboundedSource. Source not split.', + exc_info=True) + yield restriction def restriction_size(self, element, restriction) -> int: # Backlog estimation is out of scope; report a constant non-negative size. @@ -619,6 +643,7 @@ def __init__( def expand(self, pbegin): source = self._source poll_interval_seconds = self._poll_interval_seconds + output_coder = source.default_output_coder() provider = _UnboundedSourceRestrictionProvider( checkpoint_mark_coder=source.get_checkpoint_mark_coder()) @@ -705,8 +730,16 @@ def process( 'close on exception path skipped, relying on GC. 
Beam SDF ' 'wrapper internals may have changed -- file an issue.') - return ( + output = ( pbegin | 'Impulse' >> Impulse() | 'EmitSource' >> core.Map(lambda _: source) | 'ReadUnbounded' >> core.ParDo(_ReadFromUnboundedSourceDoFn())) + try: + output.element_type = output_coder.to_type_hint() + except NotImplementedError: + pass + return output + + def _infer_output_coder(self, input_type=None, input_coder=None): + return self._source.default_output_coder() diff --git a/sdks/python/apache_beam/io/unbounded_source_test.py b/sdks/python/apache_beam/io/unbounded_source_test.py index 5ec6a90bc0e7..e8443d6078cc 100644 --- a/sdks/python/apache_beam/io/unbounded_source_test.py +++ b/sdks/python/apache_beam/io/unbounded_source_test.py @@ -30,7 +30,7 @@ import logging import os import tempfile -import threading +import time import unittest import apache_beam as beam @@ -43,12 +43,15 @@ from apache_beam.io.unbounded_source import _NO_DATA from apache_beam.io.unbounded_source import _UnboundedSourceRestriction from apache_beam.io.unbounded_source import _UnboundedSourceRestrictionCoder +from apache_beam.io.unbounded_source import _UnboundedSourceRestrictionProvider from apache_beam.io.unbounded_source import _UnboundedSourceRestrictionTracker from apache_beam.io.unbounded_source import _set_watermark_if_greater +from apache_beam.runners import sdf_utils from apache_beam.io.watermark_estimators import ManualWatermarkEstimator from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.transforms import core from apache_beam.transforms.window import FixedWindows from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP @@ -139,7 +142,8 @@ def split(self, desired_num_splits, options=None): return [self] def create_reader(self, options, checkpoint_mark): - start_index = 0 if checkpoint_mark is None else 
checkpoint_mark.last_index + 1 + start_index = ( + 0 if checkpoint_mark is None else checkpoint_mark.last_index + 1) self.last_reader = _CountingReader( self._count, start_index, finalize_log=self._finalize_log) return self.last_reader @@ -148,6 +152,23 @@ def get_checkpoint_mark_coder(self): return coders.PickleCoder() +class _StringCountingReader(_CountingReader): + def get_current(self): + return 'v%s' % self._current + + +class _StringCountingSource(CountingSource): + def create_reader(self, options, checkpoint_mark): + start_index = ( + 0 if checkpoint_mark is None else checkpoint_mark.last_index + 1) + self.last_reader = _StringCountingReader( + self._count, start_index, finalize_log=self._finalize_log) + return self.last_reader + + def default_output_coder(self): + return coders.StrUtf8Coder() + + class _NoDataReader(UnboundedReader): """Always reports 'no data right now' (watermark < MAX, so never EOF).""" def start(self): @@ -180,26 +201,10 @@ def get_checkpoint_mark_coder(self): return coders.PickleCoder() -# Module-level helpers so they pickle cleanly across Beam's worker boundary. -# The DoFnReaderCloseOnExceptionTest uses ``_set_close_marker`` to install a -# tempfile path under a lock (so concurrent test runners cannot race on it), -# then waits for the reader's close() to write to it. 
-_READER_CLOSE_MARKER = None # set under _READER_CLOSE_MARKER_LOCK -_READER_CLOSE_MARKER_LOCK = threading.Lock() - - -def _set_close_marker(path): - with _READER_CLOSE_MARKER_LOCK: - global _READER_CLOSE_MARKER - _READER_CLOSE_MARKER = path - - -def _read_close_marker(): - with _READER_CLOSE_MARKER_LOCK: - return _READER_CLOSE_MARKER - - class _RaisingReader(UnboundedReader): + def __init__(self, marker_path): + self._marker_path = marker_path + def start(self): return True # first record available @@ -219,18 +224,20 @@ def get_checkpoint_mark(self): return _CountingCheckpointMark(0) def close(self): - path = _read_close_marker() - if path is not None: - with open(path, 'a') as fp: + if self._marker_path is not None: + with open(self._marker_path, 'a') as fp: fp.write('closed\n') class _RaisingSource(UnboundedSource): + def __init__(self, marker_path=None): + self._marker_path = marker_path + def split(self, desired_num_splits, options=None): return [self] def create_reader(self, options, checkpoint_mark): - return _RaisingReader() + return _RaisingReader(self._marker_path) def get_checkpoint_mark_coder(self): return coders.PickleCoder() @@ -240,7 +247,8 @@ def get_checkpoint_mark_coder(self): # *downstream* yield-raise path (where the source itself is well-behaved but a # downstream Map raises mid-bundle). Module-level for cloudpickle. 
class _MarkerCloseReader(UnboundedReader): - def __init__(self): + def __init__(self, marker_path): + self._marker_path = marker_path self._idx = -1 def start(self): @@ -264,28 +272,25 @@ def get_checkpoint_mark(self): return _CountingCheckpointMark(self._idx) def close(self): - path = _read_close_marker() - if path is not None: - with open(path, 'a') as fp: + if self._marker_path is not None: + with open(self._marker_path, 'a') as fp: fp.write('closed\n') class _MarkerCloseSource(UnboundedSource): + def __init__(self, marker_path=None): + self._marker_path = marker_path + def split(self, desired_num_splits, options=None): return [self] def create_reader(self, options, checkpoint_mark): - return _MarkerCloseReader() + return _MarkerCloseReader(self._marker_path) def get_checkpoint_mark_coder(self): return coders.PickleCoder() -def _downstream_boom(_unused): - """Module-level so it pickles cleanly through Beam's bundle worker boundary.""" - raise RuntimeError('downstream boom') - - def _new_tracker(source, checkpoint=None): restriction = _UnboundedSourceRestriction( source=source, checkpoint_mark=checkpoint) @@ -355,6 +360,61 @@ def test_roundtrip_with_checkpoint_resumes(self): self.assertEqual(reader.get_current(), 2) +class RestrictionProviderTest(unittest.TestCase): + def test_initial_split_calls_source_split(self): + split_log = [] + + class _NamedSource(CountingSource): + def __init__(self, name): + super().__init__(0) + self.name = name + + def split(self, desired_num_splits, options=None): + split_log.append((desired_num_splits, options)) + return [_NamedSource('a'), _NamedSource('b')] + + source = _NamedSource('root') + provider = _UnboundedSourceRestrictionProvider(options='opts') + restriction = _UnboundedSourceRestriction( + source=source, watermark=Timestamp(7)) + + splits = list(provider.split(source, restriction)) + + self.assertEqual(split_log, [(20, 'opts')]) + self.assertEqual([split.source.name for split in splits], ['a', 'b']) + 
self.assertEqual([split.watermark for split in splits], [Timestamp(7)] * 2) + self.assertTrue(all(split.checkpoint_mark is None for split in splits)) + self.assertTrue( + all(split.finalization_checkpoint_mark is None for split in splits)) + + def test_initial_split_does_not_split_checkpointed_restriction(self): + split_log = [] + + class _SplitSource(CountingSource): + def split(self, desired_num_splits, options=None): + split_log.append((desired_num_splits, options)) + return [self] + + source = _SplitSource(5) + provider = _UnboundedSourceRestrictionProvider(options='opts') + restriction = _UnboundedSourceRestriction( + source=source, checkpoint_mark=_CountingCheckpointMark(2)) + + self.assertEqual(list(provider.split(source, restriction)), [restriction]) + self.assertEqual(split_log, []) + + def test_initial_split_falls_back_to_original_on_split_error(self): + class _BoomSource(CountingSource): + def split(self, desired_num_splits, options=None): + raise RuntimeError('split boom') + + source = _BoomSource(5) + provider = _UnboundedSourceRestrictionProvider() + restriction = _UnboundedSourceRestriction(source=source) + + self.assertEqual(list(provider.split(source, restriction)), [restriction]) + + class RestrictionTrackerTest(unittest.TestCase): def test_claim_emits_in_order(self): tracker = _new_tracker(CountingSource(3)) @@ -513,6 +573,12 @@ def test_read_dispatches_through_iobase_read(self): self.assertFalse(out.is_bounded) assert_that(out, equal_to([0, 1, 2, 3, 4])) + def test_source_default_output_coder_sets_output_type(self): + with TestPipeline() as p: + out = p | ReadFromUnboundedSource(_StringCountingSource(2)) + self.assertEqual(out.element_type, str) + assert_that(out, equal_to(['v0', 'v1'])) + # ------------------------------------------------------------------------------ # Regression tests for the BLOCKER fixes (EOF watermark, reader close on every @@ -924,36 +990,40 @@ def get_checkpoint_mark_coder(self): class 
DoFnReaderCloseOnDownstreamRaiseTest(unittest.TestCase): """H4 second half: tracker-internal exception close (already tested in ``BestPracticeRegressionTest.test_h4_*``) handles reader-method failures. - This test covers the OTHER half -- the source is well-behaved but a - downstream transform raises during ``yield``, so the exception happens - AFTER ``try_claim`` returns with a live reader. The DoFn's ``finally`` - must close it via the private SDF chain. + This test covers the OTHER half -- the source is well-behaved but the + generator receives an exception at the ``yield`` point, so the exception + happens AFTER ``try_claim`` returns with a live reader. The DoFn's + ``finally`` must close it via the private SDF chain. """ def test_dofn_finally_closes_reader_when_downstream_yield_raises(self): marker = _new_marker_path('.downstream.close.log') - _set_close_marker(marker) try: - raised = False + source = _MarkerCloseSource(marker) + p = beam.Pipeline() + out = p | ReadFromUnboundedSource(source) + dofn = out.producer.transform.fn + inner_tracker = _UnboundedSourceRestrictionTracker( + _UnboundedSourceRestriction(source=source)) + tracker = sdf_utils.RestrictionTrackerView( + sdf_utils.ThreadsafeRestrictionTracker(inner_tracker)) + generator = dofn.process( + None, + bundle_finalizer=core.DoFn.BundleFinalizerParam(), + tracker=tracker, + watermark_estimator=ManualWatermarkEstimator(None)) + + next(generator) try: - with beam.Pipeline() as p: - _ = ( - p - | beam.io.Read(_MarkerCloseSource()) - | 'BoomMap' >> beam.Map(_downstream_boom)) - except Exception: # pylint: disable=broad-except - raised = True - gc.collect() + generator.throw(RuntimeError('downstream boom')) + except RuntimeError: + pass self.assertTrue( - raised, - 'pipeline did not surface the downstream Map exception') - self.assertTrue( - os.path.exists(marker), + _wait_for_marker(marker), 'DoFn finally did not invoke reader.close() on the downstream ' 'yield-raise path -- reader leaked. 
Private-chain close in ' 'unbounded_source.py:expand finally may be broken.') finally: - _set_close_marker(None) if os.path.exists(marker): os.unlink(marker) @@ -982,6 +1052,16 @@ def _new_marker_path(suffix): return path +def _wait_for_marker(path, timeout_secs=5): + deadline = time.time() + timeout_secs + while time.time() < deadline: + gc.collect() + if os.path.exists(path): + return True + time.sleep(0.05) + return os.path.exists(path) + + class DoFnWatermarkAdvanceTest(unittest.TestCase): """B-1 regression: the DoFn MUST advance the watermark estimator to MAX_TIMESTAMP on the terminal claim, not rely on the runner's auto-advance. @@ -1029,26 +1109,23 @@ class DoFnReaderCloseOnExceptionTest(unittest.TestCase): def test_reader_close_runs_when_process_raises(self): marker = _new_marker_path('.close.log') - _set_close_marker(marker) try: raised = False try: with beam.Pipeline() as p: - _ = p | ReadFromUnboundedSource(_RaisingSource()) + _ = p | ReadFromUnboundedSource(_RaisingSource(marker)) except Exception: # pylint: disable=broad-except raised = True self.assertTrue( raised, 'pipeline did not surface the reader.advance() exception') # Generator finalisation (which runs the DoFn's ``finally``) may be - # deferred to GC inside Beam's bundle processor; force it here so the - # close-marker is observable. - gc.collect() + # deferred inside Beam's bundle processor; wait briefly so the + # close-marker is observable in slow test environments. 
self.assertTrue( - os.path.exists(marker), + _wait_for_marker(marker), 'DoFn finally did not invoke reader.close() on the exception path ' '-- reader leaked.') finally: - _set_close_marker(None) if os.path.exists(marker): os.unlink(marker) From 73d06aa6851038b33badf6511ced227654cd1541 Mon Sep 17 00:00:00 2001 From: Eliaazzz Date: Thu, 28 May 2026 02:42:51 +1000 Subject: [PATCH 4/4] fix: address codex review on coder wiring + harness-accurate close test * default_output_coder wiring (HIGH): setting output.element_type alone let the registry's default coder for the type_hint silently shadow the source's declared coder. Now also register the source-declared coder against the element type via coders.registry.register_coder before assigning element_type, so the runner's coder lookup returns the source's coder. Registration failures (parameterised coders) are logged as a warning instead of crashing the pipeline build. * Provider.split validation (LOW): a non-UnboundedSource returned from source.split() is a source-contract violation, not a split-refusal. Move the isinstance check OUTSIDE the try/except around source.split so we fail loudly instead of silently running single-shard. * DoFn yield-raise close test (MEDIUM): the previous unit test used generator.throw(RuntimeError) which doesn't match Beam's SDK harness path (the harness raises in receiver.receive *outside* the user generator, then drops the generator). Switch to generator.close() which triggers GeneratorExit at the active yield -- the actual cleanup path Beam takes. Also add an integration test that runs a real pipeline with a downstream Map that raises, exercising common._OutputHandler.handle_process_outputs end-to-end. 42/42 unbounded_source_test, 16/16 iobase_test. Tracking #19137. 
--- .../python/apache_beam/io/unbounded_source.py | 68 ++++++++++++----- .../apache_beam/io/unbounded_source_test.py | 73 ++++++++++++++++--- 2 files changed, 113 insertions(+), 28 deletions(-) diff --git a/sdks/python/apache_beam/io/unbounded_source.py b/sdks/python/apache_beam/io/unbounded_source.py index f2f0a2abaa3d..3b4a1cd169c4 100644 --- a/sdks/python/apache_beam/io/unbounded_source.py +++ b/sdks/python/apache_beam/io/unbounded_source.py @@ -562,29 +562,41 @@ def split(self, element, yield restriction return + # Only catch errors raised BY ``source.split`` -- that path is user code + # and may legitimately refuse to split (network glitch, partition fetch + # error, etc.). Falling back to a single restriction matches Java's + # ``Read.UnboundedSourceAsSDFWrapperFn`` behaviour. try: split_sources = list( restriction.source.split(_DEFAULT_DESIRED_NUM_SPLITS, self._options)) - if not split_sources: - yield restriction - return - for split_source in split_sources: - if not isinstance(split_source, UnboundedSource): - raise TypeError( - 'UnboundedSource.split() produced %r, expected UnboundedSource' % - (split_source, )) - for split_source in split_sources: - yield dataclasses.replace( - restriction, - source=split_source, - checkpoint_mark=None, - is_done=False, - finalization_checkpoint_mark=None) except Exception: # pylint: disable=broad-except _LOGGER.warning( 'Exception while splitting UnboundedSource. Source not split.', exc_info=True) yield restriction + return + + if not split_sources: + yield restriction + return + + # Validation lives OUTSIDE the try/except above. A non-UnboundedSource + # returned from ``source.split`` is a source-contract violation, not a + # split-refusal, and must fail loudly rather than silently running + # single-shard. 
+ for split_source in split_sources: + if not isinstance(split_source, UnboundedSource): + raise TypeError( + 'UnboundedSource.split() produced %r, expected UnboundedSource' % + (split_source, )) + + for split_source in split_sources: + yield dataclasses.replace( + restriction, + source=split_source, + checkpoint_mark=None, + is_done=False, + finalization_checkpoint_mark=None) def restriction_size(self, element, restriction) -> int: # Backlog estimation is out of scope; report a constant non-negative size. @@ -735,10 +747,32 @@ def process( | 'Impulse' >> Impulse() | 'EmitSource' >> core.Map(lambda _: source) | 'ReadUnbounded' >> core.ParDo(_ReadFromUnboundedSourceDoFn())) + # Wire the source's declared output coder onto the output PCollection. + # Setting ``element_type`` alone is not enough: the runner derives the + # PCollection's coder via ``coders.registry.get_coder(element_type)``, + # which may resolve to a registry default that does NOT match the + # source's declared coder (silently downgrading custom coders to pickle). + # Register the source-declared coder against the element type so the + # registry lookup returns it. try: - output.element_type = output_coder.to_type_hint() + type_hint = output_coder.to_type_hint() except NotImplementedError: - pass + type_hint = None + if type_hint is not None: + try: + coders.registry.register_coder(type_hint, type(output_coder)) + except Exception: # pylint: disable=broad-except + # Some Beam versions / coder classes refuse class-only registration + # (e.g. coders parameterised by non-default constructor args). The + # element_type below still flows through the registry's standard + # lookup; users with parameterised coders must register their coder + # explicitly via ``coders.registry.register_coder`` before pipeline + # construction. Logged so the gap is observable. 
+ _LOGGER.warning( + 'Could not register %s for element type %s; users must register ' + 'their coder explicitly for non-default coders.', + type(output_coder).__name__, type_hint, exc_info=True) + output.element_type = type_hint return output def _infer_output_coder(self, input_type=None, input_coder=None): diff --git a/sdks/python/apache_beam/io/unbounded_source_test.py b/sdks/python/apache_beam/io/unbounded_source_test.py index e8443d6078cc..493afe69d32b 100644 --- a/sdks/python/apache_beam/io/unbounded_source_test.py +++ b/sdks/python/apache_beam/io/unbounded_source_test.py @@ -291,6 +291,14 @@ def get_checkpoint_mark_coder(self): return coders.PickleCoder() +def _downstream_boom(_unused): + """Module-level so it pickles cleanly through Beam's bundle worker boundary. + Used by ``DoFnReaderCloseOnDownstreamRaiseTest`` to simulate a downstream + transform that raises mid-bundle (the harness-driven yield-raise path). + """ + raise RuntimeError('downstream boom') + + def _new_tracker(source, checkpoint=None): restriction = _UnboundedSourceRestriction( source=source, checkpoint_mark=checkpoint) @@ -991,13 +999,25 @@ class DoFnReaderCloseOnDownstreamRaiseTest(unittest.TestCase): """H4 second half: tracker-internal exception close (already tested in ``BestPracticeRegressionTest.test_h4_*``) handles reader-method failures. This test covers the OTHER half -- the source is well-behaved but the - generator receives an exception at the ``yield`` point, so the exception - happens AFTER ``try_claim`` returns with a live reader. The DoFn's - ``finally`` must close it via the private SDF chain. + downstream output handler raises, so the exception happens AFTER + ``try_claim`` returned with a live reader. Beam's + ``common._OutputHandler.handle_process_outputs`` iterates the DoFn's + generator with ``for result in results`` and calls + ``receiver.receive(...)``; when a downstream receiver raises, the + exception is OUTSIDE the user generator. 
The SDK harness then drops + the generator (no explicit ``throw``); the generator's ``finally`` runs + when the generator is closed (``GeneratorExit``) or garbage collected. + + We exercise that path two ways: + 1. Unit-level: simulate the harness drop with ``generator.close()`` + (raises ``GeneratorExit`` at the active yield, running ``finally``). + 2. Integration: run a real pipeline with a downstream ``Map`` that + raises, and confirm the reader was closed before the pipeline + surfaced the error. """ - def test_dofn_finally_closes_reader_when_downstream_yield_raises(self): - marker = _new_marker_path('.downstream.close.log') + def test_dofn_finally_closes_reader_on_generator_close(self): + marker = _new_marker_path('.gen_close.log') try: source = _MarkerCloseSource(marker) p = beam.Pipeline() @@ -1014,15 +1034,46 @@ def test_dofn_finally_closes_reader_when_downstream_yield_raises(self): watermark_estimator=ManualWatermarkEstimator(None)) next(generator) + # Simulate the harness dropping the generator after a downstream + # receiver raised. Beam's SDK harness does NOT call + # ``generator.throw`` -- the downstream exception happens outside + # the user generator, and the harness lets GC / ``close`` clean up. + generator.close() + self.assertTrue( + _wait_for_marker(marker), + 'DoFn finally did not invoke reader.close() when the generator ' + 'was closed (GeneratorExit) -- reader leaked. Private-chain ' + 'close in unbounded_source.py:expand finally may be broken.') + finally: + if os.path.exists(marker): + os.unlink(marker) + + def test_dofn_finally_closes_reader_via_integration_pipeline(self): + """End-to-end harness coverage: a real pipeline with a downstream + ``Map`` that raises must surface the exception AND must have closed + the reader. This complements the unit-level ``generator.close`` test + above by exercising the actual SDK harness output-handler path + (``common._OutputHandler.handle_process_outputs``). 
+ """ + marker = _new_marker_path('.integration_close.log') + try: + raised = False try: - generator.throw(RuntimeError('downstream boom')) - except RuntimeError: - pass + with beam.Pipeline() as p: + _ = ( + p + | ReadFromUnboundedSource(_MarkerCloseSource(marker)) + | 'BoomMap' >> beam.Map(_downstream_boom)) + except Exception: # pylint: disable=broad-except + raised = True + self.assertTrue( + raised, + 'pipeline did not surface the downstream Map exception') self.assertTrue( _wait_for_marker(marker), - 'DoFn finally did not invoke reader.close() on the downstream ' - 'yield-raise path -- reader leaked. Private-chain close in ' - 'unbounded_source.py:expand finally may be broken.') + 'reader leaked across the integration pipeline -- the SDK ' + 'harness path that drops the DoFn generator on downstream ' + 'failure did not trigger our finally close.') finally: if os.path.exists(marker): os.unlink(marker)