From af6d166f7b0228e25f16fa30bde5f7580ab9c189 Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Fri, 22 May 2026 21:03:19 +0000 Subject: [PATCH 01/11] Created an Asynchronous Wrapper for DoFn as well as JUnit tests for the Apache Beam Java SDK (#38529) --- .../apache/beam/sdk/transforms/AsyncDoFn.java | 689 ++++++++++++++++ .../beam/sdk/transforms/AsyncDoFnTest.java | 733 ++++++++++++++++++ 2 files changed, 1422 insertions(+) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java create mode 100644 sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java new file mode 100644 index 000000000000..31c7e6c3d78c --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -0,0 +1,689 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Class that wraps a dofn and converts it from one which process elements synchronously to one + * which processes them asynchronously. + * + *

For synchronous dofns the default settings mean that many (100s) of elements will be processed + * in parallel and that processing an element will block all other work on that key. In addition + * runners are optimized for latencies less than a few seconds and longer operations can result in + * high retry rates. Async should be considered when the default parallelism is not correct and/or + * items are expected to take longer than a few seconds to process. + */ +public class AsyncDoFn extends DoFn, OutputT> { + + private static final Logger LOG = LoggerFactory.getLogger(AsyncDoFn.class); + + private static final int DEFAULT_MIN_BUFFER_CAPACITY = 10; + private static final int DEFAULT_TIMEOUT_SEC = 1; + private static final int DEFAULT_MAX_WAIT_TIME_MS = 500; + private static final int TEARDOWN_AWAIT_SEC = 5; + private static final int INITIAL_BACKOFF_SLEEP_MS = 10; + private static final int BACKPRESSURE_LOG_THRESHOLD_MS = 10000; + + @StateId("to_process") + private final StateSpec>> toProcessSpec; + + @TimerId("timer") + private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + private final DoFn syncFn; + private final int parallelism; + private final Duration timerFrequency; + private final int maxItemsToBuffer; + private final Duration timeout; + private final Duration maxWaitTime; + private final SerializableFunction idFn; + private final boolean useThreadPool; + private final String uuid; + + private transient @Nullable PipelineOptions pipelineOptions; + + // Shared JVM-Wide States (Static Registries) + // Map-backed registry holding shared resources across serialized worker instances. Since runners + // clone DoFn instances on the same worker node, static maps ensure safe JVM-wide resource reuse. + private static final ConcurrentHashMap pool = new ConcurrentHashMap<>(); + // activeElements (processingElements) is global JVM memory (all keys) + private static final ConcurrentHashMap< + String, ConcurrentHashMap>> + processingElements = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap itemsInBuffer = + new ConcurrentHashMap<>(); + + private static final ReentrantLock lock = new ReentrantLock(); + private static final boolean verboseLogging = false; + + private static class InFlightElement { + final KV element; + final CompletableFuture> future; + + InFlightElement(KV element, CompletableFuture> future) { + this.element = element; + this.future = future; + } + } + + // The In-Memory Accumulating Receiver + // Accumulates elements in-memory during asynchronous background worker execution. + // Buffered elements are only committed downstream once the parent task completes successfully + // and the timer fires. + private static class AccumulatingOutputReceiver implements OutputReceiver { + private final List outputs = Collections.synchronizedList(new ArrayList<>()); + + @Override + public org.apache.beam.sdk.values.OutputBuilder builder(T value) { + return org.apache.beam.sdk.values.WindowedValues.builder() + .setValue(value) + .setTimestamp(Instant.now()) + .setWindows(java.util.Collections.singletonList(GlobalWindow.INSTANCE)) + .setPaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING) + .setReceiver(windowedValue -> outputs.add(windowedValue.getValue())); + } + + // Bypasses the nested anonymous OutputBuilder instantiation for standard outputs. + // JVM optimization to prevent garbage collection pressure under high pipeline throughput. + @Override + public void output(T output) { + outputs.add(output); + } + + @Override + public void outputWithTimestamp(T output, Instant timestamp) { + outputs.add(output); + } + + public List getOutputs() { + return outputs; + } + } + + public AsyncDoFn( + DoFn syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + @Nullable SerializableFunction idFn, + boolean useThreadPool) { + this( + syncFn, + parallelism, + timerFrequency, + maxItemsToBuffer, + timeout, + maxWaitTime, + idFn, + useThreadPool, + null); + } + + public AsyncDoFn( + DoFn syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + @Nullable SerializableFunction idFn, + boolean useThreadPool, + @Nullable Coder> coder) { + this.syncFn = syncFn; + this.parallelism = parallelism; + this.timerFrequency = timerFrequency; + this.maxItemsToBuffer = + (maxItemsToBuffer != null) + ? maxItemsToBuffer + : Math.max(parallelism * 2, DEFAULT_MIN_BUFFER_CAPACITY); + this.timeout = (timeout != null) ? timeout : Duration.standardSeconds(DEFAULT_TIMEOUT_SEC); + this.maxWaitTime = + (maxWaitTime != null) ? maxWaitTime : Duration.millis(DEFAULT_MAX_WAIT_TIME_MS); + this.idFn = + (idFn != null) + ? idFn + : (SerializableFunction) + input -> java.util.Objects.requireNonNull(input); + this.useThreadPool = useThreadPool; + this.uuid = UUID.randomUUID().toString(); + this.toProcessSpec = (coder != null) ? StateSpecs.bag(coder) : StateSpecs.bag(); + } + + private ExecutorService getThreadPool() { + ExecutorService threadPool = pool.get(uuid); + if (threadPool == null) { + throw new IllegalStateException("Thread pool not initialized for UUID: " + uuid); + } + return threadPool; + } + + @SuppressWarnings("unchecked") + private ConcurrentHashMap> getProcessingElements() { + ConcurrentHashMap> elements = processingElements.get(uuid); + if (elements == null) { + throw new IllegalStateException("Processing elements map not initialized for UUID: " + uuid); + } + return (ConcurrentHashMap>) + (ConcurrentHashMap) elements; + } + + private AtomicInteger getItemsInBuffer() { + AtomicInteger buffer = itemsInBuffer.get(uuid); + if (buffer == null) { + throw new IllegalStateException("Buffer counter not initialized for UUID: " + uuid); + } + return buffer; + } + + @Setup + public void setup(PipelineOptions options) { + this.pipelineOptions = options; + + // Setup the wrapped DoFn + DoFnInvokers.invokerFor(syncFn) + .invokeSetup( + new DoFnInvoker.BaseArgumentProvider() { + @Override + public PipelineOptions pipelineOptions() { + return options; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Setup"; + } + }); + + if (useThreadPool) { + LOG.info("Using thread pool for asynchronous execution with parallelism {}", parallelism); + } + + lock.lock(); + try { + pool.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism)); + processingElements.computeIfAbsent(uuid, k -> new ConcurrentHashMap<>()); + itemsInBuffer.computeIfAbsent(uuid, k -> new AtomicInteger(0)); + } finally { + lock.unlock(); + } + } + + // Clean up JVM-wide shared resources to prevent thread leaks on the worker + @Teardown + public void teardown() { + DoFnInvokers.invokerFor(syncFn).invokeTeardown(); + + ExecutorService threadPool; + lock.lock(); + try { + threadPool = pool.remove(uuid); + processingElements.remove(uuid); + itemsInBuffer.remove(uuid); + } finally { + lock.unlock(); + } + + if (threadPool != null) { + threadPool.shutdown(); + try { + if (!threadPool.awaitTermination(TEARDOWN_AWAIT_SEC, TimeUnit.SECONDS)) { + threadPool.shutdownNow(); + } + } catch (InterruptedException e) { + threadPool.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + } + + // Asynchronous Scheduling & Deduplication + // Submits tasks to the background thread pool. If an element with the same ID is already + // in-flight, + // the submission is silently ignored to enforce exactly-once semantics. + private boolean scheduleIfRoom( + KV element, BoundedWindow window, Instant timestamp, boolean ignoreBuffer) { + lock.lock(); + try { + ConcurrentHashMap> activeElements = + getProcessingElements(); + Object elementId = idFn.apply(element.getValue()); + + if (activeElements.containsKey(elementId)) { + LOG.info("Item {} already in processing elements", element); + return true; + } + + int currentBuffer = getItemsInBuffer().get(); + if (currentBuffer < maxItemsToBuffer || ignoreBuffer) { + java.util.concurrent.Executor executor = + useThreadPool ? getThreadPool() : java.util.concurrent.ForkJoinPool.commonPool(); + + // Pending asynchronous task that will produce a list of outputs + CompletableFuture> future = + CompletableFuture.supplyAsync( + () -> { + try { + AccumulatingOutputReceiver receiver = + new AccumulatingOutputReceiver<>(); + DoFnInvoker invoker = DoFnInvokers.invokerFor(syncFn); + + DoFnInvoker.ArgumentProvider bundleArgProvider = + new DoFnInvoker.BaseArgumentProvider() { + @Override + public PipelineOptions pipelineOptions() { + PipelineOptions options = pipelineOptions; + if (options == null) { + throw new IllegalStateException("PipelineOptions not set"); + } + return options; + } + + @Override + public DoFn.FinishBundleContext finishBundleContext( + DoFn doFn) { + return doFn.new FinishBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions(); + } + + @Override + public void output( + OutputT output, Instant timestamp, BoundedWindow window) { + receiver.output(output); + } + + @Override + public void output( + TupleTag tag, + T output, + Instant timestamp, + BoundedWindow window) { + throw new UnsupportedOperationException( + "Tagged output not supported in FinishBundleContext for AsyncDoFn"); + } + }; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Bundle"; + } + }; + + invoker.invokeStartBundle(bundleArgProvider); + + DoFnInvoker.ArgumentProvider processArgProvider = + new DoFnInvoker.BaseArgumentProvider() { + @Override + public InputT element(DoFn doFn) { + return element.getValue(); + } + + @Override + public OutputReceiver outputReceiver( + DoFn doFn) { + return receiver; + } + + @Override + public BoundedWindow window() { + return window; + } + + @Override + public Instant timestamp(DoFn doFn) { + return timestamp; + } + + @Override + public PipelineOptions pipelineOptions() { + PipelineOptions options = pipelineOptions; + if (options == null) { + throw new IllegalStateException("PipelineOptions not set"); + } + return options; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Process"; + } + }; + + invoker.invokeProcessElement(processArgProvider); + invoker.invokeFinishBundle(bundleArgProvider); + + return receiver.getOutputs(); + } catch (Exception e) { + throw new CompletionException(e); + } + }, + executor); + + // Assigned to 'unused' to satisfy ErrorProne while preserving parent future for + // cancellation + CompletableFuture> unused = + future.whenComplete( + (res, ex) -> { + lock.lock(); + try { + getItemsInBuffer().decrementAndGet(); + } finally { + lock.unlock(); + } + }); + + activeElements.put(elementId, new InFlightElement<>(element, future)); + getItemsInBuffer().incrementAndGet(); + return true; + } + + return false; + } finally { + lock.unlock(); + } + } + + private void scheduleItem(KV element, BoundedWindow window, Instant timestamp) { + boolean done = false; + long sleepTime = INITIAL_BACKOFF_SLEEP_MS; + long totalSleep = 0; + long timeoutMs = timeout.getMillis(); + + while (!done && totalSleep < timeoutMs) { + done = scheduleIfRoom(element, window, timestamp, false); + if (!done) { + long sleep = Math.min(maxWaitTime.getMillis(), sleepTime); + if (verboseLogging || totalSleep > BACKPRESSURE_LOG_THRESHOLD_MS) { + LOG.info( + "buffer is full for item {}, {} waiting {} ms. Have waited for {} ms.", + element, + getItemsInBuffer().get(), + sleep, + totalSleep); + } + try { + Thread.sleep(sleep); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for space in buffer", e); + } + sleepTime *= 2; + totalSleep += sleep; + } + } + // Timeout: element skips JVM pool but stays in BagState for timer to reschedule later. + } + + private Instant nextTimeToFire(@Nullable K key) { + long seed = (key == null) ? 0 : key.hashCode(); + Random random = new Random(seed); + double timerFrequencySec = timerFrequency.getMillis() / 1000.0; + double nowSec = System.currentTimeMillis() / 1000.0; + + double base = Math.floor((nowSec + timerFrequencySec) / timerFrequencySec) * timerFrequencySec; + double offset = random.nextDouble() * timerFrequencySec; + + return Instant.ofEpochMilli((long) ((base + offset) * 1000)); + } + + @ProcessElement + public void processElement( + ProcessContext c, + BoundedWindow window, + @StateId("to_process") BagState> toProcessState, + @TimerId("timer") Timer timer) { + + KV element = c.element(); + scheduleItem(element, window, c.timestamp()); + toProcessState.add(element); + + Instant timeToFire = nextTimeToFire(element.getKey()); + timer.set(timeToFire); + } + + @OnTimer("timer") + public void onTimer( + OnTimerContext c, + @StateId("to_process") BagState> toProcessState, + @TimerId("timer") Timer timer, + OutputReceiver receiver) { + + commitFinishedItems(c.fireTimestamp(), toProcessState, timer, receiver); + } + + // Synchronizes local task results with the runner's persistent state container. + // Emits successfully completed elements, cancels rolled-back tasks, and reschedules lost work. + private void commitFinishedItems( + Instant fireTimestamp, + BagState> toProcessState, + Timer timer, + OutputReceiver receiver) { + + Iterable> toProcessLocal = toProcessState.read(); + if (toProcessLocal == null || !toProcessLocal.iterator().hasNext()) { + // Early Exit: if BagState is empty, we skip checking activeElements for this key. + return; + } + + // Since fireTimestamp is key-scoped, we determine the current key from the first element in + // state + List> stateList = new ArrayList<>(); + K key = null; + for (KV element : toProcessLocal) { + stateList.add(element); + if (key == null) { + key = element.getKey(); + } + } + + if (verboseLogging) { + LOG.info("processing timer for key: {}", key); + } + + ConcurrentHashMap> activeElements = + getProcessingElements(); + Set stateIds = new HashSet<>(); + for (KV element : stateList) { + stateIds.add(idFn.apply(element.getValue())); + } + + List toCancel = new ArrayList<>(); + lock.lock(); + try { + // Cancel any active elements for this key that are no longer in runner's state + for (Map.Entry> entry : + activeElements.entrySet()) { + Object elementId = entry.getKey(); + InFlightElement inFlight = entry.getValue(); + + if (Objects.equals(inFlight.element.getKey(), key) && !stateIds.contains(elementId)) { + inFlight.future.cancel(true); + toCancel.add(elementId); + LOG.info("Cancelling item {} which is no longer in state", inFlight.element); + } + } + for (Object elementId : toCancel) { + activeElements.remove(elementId); + } + } finally { + lock.unlock(); + } + + List> toReturn = new ArrayList<>(); + List> finishedItems = new ArrayList<>(); + List> toReschedule = new ArrayList<>(); + + int itemsFinished = 0; + int itemsNotYetFinished = 0; + int itemsRescheduled = 0; + int itemsCancelled = toCancel.size(); + + lock.lock(); + try { + for (KV element : stateList) { + Object elementId = idFn.apply(element.getValue()); + if (activeElements.containsKey(elementId)) { + InFlightElement inFlight = activeElements.get(elementId); + if (inFlight.future.isDone()) { + try { + if (!inFlight.future.isCancelled()) { + toReturn.add(inFlight.future.get()); + } + finishedItems.add(element); + activeElements.remove(elementId); + itemsFinished++; + } catch (Exception e) { + LOG.error("Error executing async task for element {}", element, e); + finishedItems.add(element); + activeElements.remove(elementId); + } + } else { + itemsNotYetFinished++; + } + } else { + LOG.info( + "Item {} found in state but not in local active elements, scheduling now", element); + toReschedule.add(element); + itemsRescheduled++; + } + } + } finally { + lock.unlock(); + } + + // Reschedule missing elements + for (KV element : toReschedule) { + scheduleItem(element, GlobalWindow.INSTANCE, fireTimestamp); + } + + // Update State: keep only unfinished items + toProcessState.clear(); + int itemsInProcessingState = 0; + for (KV element : stateList) { + if (!finishedItems.contains(element)) { + toProcessState.add(element); + itemsInProcessingState++; + } + } + + // Emit completed outputs (Emit completed tasks immediately; do not wait for all active tasks to + // finish). + for (List outputs : toReturn) { + for (OutputT out : outputs) { + receiver.output(out); + } + } + + LOG.info( + "Items finished: {}, not yet finished: {}, rescheduled: {}, cancelled: {}, in processing state: {}", + itemsFinished, + itemsNotYetFinished, + itemsRescheduled, + itemsCancelled, + itemsInProcessingState); + + if (itemsInProcessingState > 0) { + Instant timeToFire = nextTimeToFire(key); + timer.set(timeToFire); + } + } + + // Package-private helper methods for testing direct execution without Pipeline / ProcessContext + // boilerplate + void processDirect( + KV element, + BoundedWindow window, + Instant timestamp, + BagState> toProcessState, + Timer timer) { + scheduleItem(element, window, timestamp); + toProcessState.add(element); + Instant timeToFire = nextTimeToFire(element.getKey()); + timer.set(timeToFire); + } + + List commitFinishedItemsDirect( + Instant fireTimestamp, BagState> toProcessState, Timer timer) { + AccumulatingOutputReceiver receiver = new AccumulatingOutputReceiver<>(); + commitFinishedItems(fireTimestamp, toProcessState, timer, receiver); + return receiver.getOutputs(); + } + + boolean isEmpty() { + return getItemsInBuffer().get() == 0; + } + + int getItemsInBufferCount() { + return getItemsInBuffer().get(); + } + + static void resetState() { + lock.lock(); + try { + for (Map.Entry entry : pool.entrySet()) { + entry.getValue().shutdownNow(); + } + pool.clear(); + processingElements.clear(); + itemsInBuffer.clear(); + } finally { + lock.unlock(); + } + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java new file mode 100644 index 000000000000..912aca3f309c --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java @@ -0,0 +1,733 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for verifying async processing structures and logic. */ +@RunWith(JUnit4.class) +public class AsyncDoFnTest implements Serializable { + + @Rule public final transient TestPipeline p = TestPipeline.create(); + private final boolean useThreadPool = true; + + // Used for testing basic DoFn processing logic with optional latency. + private static class BasicDofn extends DoFn { + private final long sleepTimeMs; + private int processed = 0; + private final ReentrantLock lock = new ReentrantLock(); + + BasicDofn(long sleepTimeMs) { + this.sleepTimeMs = sleepTimeMs; + } + + BasicDofn() { + this(0); + } + + @ProcessElement + public void processElement(@Element String element, OutputReceiver receiver) { + if (sleepTimeMs > 0) { + try { + Thread.sleep(sleepTimeMs); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + lock.lock(); + try { + processed += 1; + } finally { + lock.unlock(); + } + receiver.output(element); + } + + int getProcessed() { + lock.lock(); + try { + return processed; + } finally { + lock.unlock(); + } + } + } + + // Used for testing multi element processing with optional finish bundle call. + private static class MultiElementDoFn extends DoFn { + @ProcessElement + public void processElement(@Element String element, OutputReceiver receiver) { + receiver.output(element); + receiver.output(element); + } + + @FinishBundle + public void finishBundle(FinishBundleContext c) { + c.output("bundle end", Instant.now(), GlobalWindow.INSTANCE); + } + } + + // Used for testing BagState thread safety. + private static class FakeBagState implements BagState { + private final List items; + private final ReentrantLock lock = new ReentrantLock(); + + FakeBagState(List initialItems) { + this.items = new ArrayList<>(initialItems); + } + + FakeBagState(T initialItem) { + this(new ArrayList<>(List.of(initialItem))); + } + + FakeBagState() { + this(new ArrayList<>()); + } + + @Override + public void add(T item) { + lock.lock(); + try { + items.add(item); + } finally { + lock.unlock(); + } + } + + @Override + public void clear() { + lock.lock(); + try { + items.clear(); + } finally { + lock.unlock(); + } + } + + @Override + public Iterable read() { + lock.lock(); + try { + return new ArrayList<>(items); + } finally { + lock.unlock(); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + lock.lock(); + try { + return items.isEmpty(); + } finally { + lock.unlock(); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public BagState readLater() { + return this; + } + } + + // 4. Used for testing Timer mock implementations. + private static class FakeTimer implements Timer { + private Instant time = Instant.EPOCH; + + @Override + public void set(Instant absoluteTime) { + this.time = absoluteTime; + } + + @Override + public void setRelative() {} + + @Override + public void clear() { + this.time = Instant.EPOCH; + } + + @Override + public Timer offset(Duration offset) { + return this; + } + + @Override + public Timer align(Duration period) { + return this; + } + + @Override + public Timer withOutputTimestamp(Instant outputTime) { + return this; + } + + @Override + public Timer withNoOutputTimestamp() { + return this; + } + + @Override + public Instant getCurrentRelativeTime() { + return time; + } + } + + @Before + public void setUp() { + AsyncDoFn.resetState(); + } + + private void waitForEmpty(AsyncDoFn asyncDoFn) { + waitForEmpty(asyncDoFn, 10); + } + + private void waitForEmpty(AsyncDoFn asyncDoFn, int timeoutSeconds) { + int count = 0; + while (!asyncDoFn.isEmpty()) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + count += 1; + if (count > timeoutSeconds) { + throw new RuntimeException("Timed out waiting for async dofn to be empty"); + } + } + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + private void checkOutput(List result, List expectedOutput) { + List resultStr = new ArrayList<>(); + for (T val : result) { + resultStr.add(val.toString()); + } + List expectedStr = new ArrayList<>(); + for (T val : expectedOutput) { + expectedStr.add(val.toString()); + } + Collections.sort(resultStr); + Collections.sort(expectedStr); + assertEquals(expectedStr, resultStr); + } + + private void checkItemsInBuffer(AsyncDoFn asyncDoFn, int expectedCount) { + assertEquals(expectedCount, asyncDoFn.getItemsInBufferCount()); + } + + // Test 1: testCustomIdFn + // Verifies key extraction custom logic. Duplicate elements (same custom ID but different payload) + // should be recognized as already in-flight and deduplicated. + @Test + public void testCustomIdFn() { + class CustomIdObject implements Serializable { + final int elementId; + final String value; + + CustomIdObject(int elementId, String value) { + this.elementId = elementId; + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof CustomIdObject)) return false; + CustomIdObject that = (CustomIdObject) o; + return elementId == that.elementId; + } + + @Override + public int hashCode() { + return java.util.Objects.hash(elementId); + } + + @Override + public String toString() { + return "CustomIdObject{id=" + elementId + ", val=" + value + "}"; + } + } + + class CustomIdDofn extends DoFn { + @ProcessElement + public void processElement(@Element CustomIdObject element, OutputReceiver receiver) { + receiver.output(element.value); + } + } + + CustomIdDofn dofn = new CustomIdDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, + 1, + Duration.standardSeconds(5), + null, + null, + null, + x -> x.elementId, + useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + + KV msg1 = KV.of("key1", new CustomIdObject(1, "a")); + KV msg2 = KV.of("key1", new CustomIdObject(1, "b")); + + asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("a")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 2: testBasic + // Verifies the standard end-to-end execution flow. Elements should be queued in persistent state + // and output correctly upon completion. + @Test + public void testBasic() { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + assertEquals(1, fakeBagState.items.size()); + assertNotEquals(Instant.EPOCH, fakeTimer.getCurrentRelativeTime()); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(1, dofn.getProcessed()); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 3: testMultiKey + // Verifies key grouping isolation. Firing a timer for one partition key must not release + // or interfere with elements queued under a different partition key. + @Test + public void testMultiKey() { + for (boolean useThreadPool : new boolean[] {true, false}) { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagStateKey1 = new FakeBagState<>(); + FakeBagState> fakeBagStateKey2 = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + + KV msg1 = KV.of("key1", "1"); + KV msg2 = KV.of("key2", "2"); + + asyncDoFn.processDirect( + msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagStateKey1, fakeTimer); + asyncDoFn.processDirect( + msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagStateKey2, fakeTimer); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagStateKey2, fakeTimer); + checkOutput(result, Collections.singletonList("2")); + assertEquals(1, fakeBagStateKey1.items.size()); + assertEquals(0, fakeBagStateKey2.items.size()); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagStateKey1, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(0, fakeBagStateKey1.items.size()); + assertEquals(0, fakeBagStateKey2.items.size()); + } + } + + // Test 4: testLongItem + // Verifies that outputs are kept in-flight and not committed prematurely if the background + // execution task has not finished processing yet. + @Test + public void testLongItem() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.emptyList()); + assertEquals(0, dofn.getProcessed()); + assertEquals(1, fakeBagState.items.size()); + + waitForEmpty(asyncDoFn, 20); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(1, dofn.getProcessed()); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 5: testLostItem + // Verifies if the local worker's in-memory cache is empty but the runner's + // persistent state contains pending items. + // The wrapper must automatically detect the mismatch and reschedule execution. + @Test + public void testLostItem() { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + FakeBagState> fakeBagState = new FakeBagState<>(msg); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.emptyList()); + + waitForEmpty(asyncDoFn); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + } + + // Test 6: testCancelledItem + // Verifies active task cancellation. If a pending element is deleted from the runner's persistent + // state prior to a commit (e.g., due to a rollback), the background future task must be actively + // cancelled. + @Test + public void testCancelledItem() { + BasicDofn dofn = new BasicDofn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + KV msg1 = KV.of("key1", "1"); + KV msg2 = KV.of("key1", "2"); + FakeTimer fakeTimer = new FakeTimer(); + FakeBagState> fakeBagState = new FakeBagState<>(); + + asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn); + + fakeBagState.clear(); + fakeBagState.add(msg2); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("2")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 7: testMultiElementDofn + // Verifies support for DoFns that emit multiple outputs per element, and correctly aggregates + // outputs produced during the finishBundle stage of the sync DoFn's lifecycle. + @Test + public void testMultiElementDofn() { + MultiElementDoFn dofn = new MultiElementDoFn(); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Arrays.asList("1", "1", "bundle end")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 8: testDuplicates + // Verifies deduplication of duplicate elements under active processing. + // Identical elements should not spawn multiple concurrent background executions. + @Test + public void testDuplicates() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + fakeBagState.clear(); + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + assertEquals(1, fakeBagState.items.size()); + + waitForEmpty(asyncDoFn); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 9: testBufferCount + // Verifies accurate in-flight metrics tracking. + // The item count in the buffer must increment on task scheduling + // and decrement immediately upon execution completion. + @Test + public void testBufferCount() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + KV msg = KV.of("key1", "1"); + FakeTimer fakeTimer = new FakeTimer(); + FakeBagState> fakeBagState = new FakeBagState<>(); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + checkItemsInBuffer(asyncDoFn, 1); + + waitForEmpty(asyncDoFn); + checkItemsInBuffer(asyncDoFn, 0); + + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkItemsInBuffer(asyncDoFn, 0); + } + + // Test 10: testBufferStopsAcceptingItems + // Verifies queue boundaries and backpressure throttling. + // When concurrent threads push elements exceeding the capacity limit, + // the scheduler must block and delay submissions appropriately. + @Test + public void testBufferStopsAcceptingItems() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, + 1, + Duration.standardSeconds(5), + 5, // max buffer capacity + null, + null, + null, + useThreadPool); + asyncDoFn.setup(null); + + FakeTimer fakeTimer = new FakeTimer(); + FakeBagState> fakeBagState = new FakeBagState<>(); + + ExecutorService poolExecutor = Executors.newFixedThreadPool(10); + List expectedOutput = new ArrayList<>(); + List> futures = new ArrayList<>(); + + for (int i = 0; i < 10; i++) { + final int idx = i; + expectedOutput.add(String.valueOf(idx)); + futures.add( + poolExecutor.submit( + () -> { + KV item = KV.of("key", String.valueOf(idx)); + asyncDoFn.processDirect( + item, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + })); + } + + try { + Thread.sleep(200); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + assertEquals(5, asyncDoFn.getItemsInBufferCount()); + + waitForEmpty(asyncDoFn, 100); + + // Verify that all background tasks completed successfully without throwing exceptions + for (Future future : futures) { + try { + future.get(); // This will re-throw any exception that occurred in the background thread + } catch (Exception e) { + throw new AssertionError("Background task failed", e); + } + } + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn, 100); + + result.addAll( + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer)); + + checkOutput(result, expectedOutput); + checkItemsInBuffer(asyncDoFn, 0); + poolExecutor.shutdown(); + } + + // Test 11: testBufferWithCancellation + // Verifies backpressure behavior in conjunction with element cancellation. + // Elements that are actively cancelled during queue throttling should be dropped cleanly from the + // buffer. + @Test + public void testBufferWithCancellation() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + KV msg1 = KV.of("key1", "1"); + KV msg2 = KV.of("key1", "2"); + FakeTimer fakeTimer = new FakeTimer(); + FakeBagState> fakeBagState = new FakeBagState<>(); + + asyncDoFn.processDirect(msg1, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + asyncDoFn.processDirect(msg2, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + checkItemsInBuffer(asyncDoFn, 2); + + fakeBagState.clear(); + fakeBagState.add(msg2); + + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.emptyList()); + assertEquals(1, fakeBagState.items.size()); + + waitForEmpty(asyncDoFn); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkItemsInBuffer(asyncDoFn, 0); + checkOutput(result, Collections.singletonList("2")); + } + + // Test 12: testResetStateConcurrentTeardown + // Verifies safe resource cleanup during concurrent shutdown. + // Resetting the global shared execution state while workers are running + // must complete cleanly without thread or lock deadlocks. + @Test + public void testResetStateConcurrentTeardown() { + BasicDofn dofn = new BasicDofn(500); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + + asyncDoFn.processDirect( + KV.of("key1", "1"), GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Verify calling resetState() while background tasks are running finishes cleanly + AsyncDoFn.resetState(); + } +} From 218b24de97f64e3082798546cbcbe9072672ace4 Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Fri, 22 May 2026 22:52:59 +0000 Subject: [PATCH 02/11] Optimize State reconciliation loop and eliminate O(N^2) complexity. Removed O(N) global activeElements scan. Fixed logic bug where duplicate elements were incorrectly marked for rescheduling. Optimized lookups by converting finishedItems from a list to a HashSet. --- .../apache/beam/sdk/transforms/AsyncDoFn.java | 74 ++++++++----------- 1 file changed, 29 insertions(+), 45 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java index 31c7e6c3d78c..e499cbdf2c1e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -22,7 +22,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Random; import java.util.Set; import java.util.UUID; @@ -99,8 +98,7 @@ public class AsyncDoFn extends DoFn, OutputT> // clone DoFn instances on the same worker node, static maps ensure safe JVM-wide resource reuse. private static final ConcurrentHashMap pool = new ConcurrentHashMap<>(); // activeElements (processingElements) is global JVM memory (all keys) - private static final ConcurrentHashMap< - String, ConcurrentHashMap>> + private static final ConcurrentHashMap>> processingElements = new ConcurrentHashMap<>(); private static final ConcurrentHashMap itemsInBuffer = new ConcurrentHashMap<>(); @@ -108,12 +106,10 @@ public class AsyncDoFn extends DoFn, OutputT> private static final ReentrantLock lock = new ReentrantLock(); private static final boolean verboseLogging = false; - private static class InFlightElement { - final KV element; + private static class InFlightElement { final CompletableFuture> future; - InFlightElement(KV element, CompletableFuture> future) { - this.element = element; + InFlightElement(CompletableFuture> future) { this.future = future; } } @@ -212,13 +208,12 @@ private ExecutorService getThreadPool() { } @SuppressWarnings("unchecked") - private ConcurrentHashMap> getProcessingElements() { - ConcurrentHashMap> elements = processingElements.get(uuid); + private ConcurrentHashMap> getProcessingElements() { + ConcurrentHashMap> elements = processingElements.get(uuid); if (elements == null) { throw new IllegalStateException("Processing elements map not initialized for UUID: " + uuid); } - return (ConcurrentHashMap>) - (ConcurrentHashMap) elements; + return (ConcurrentHashMap>) (ConcurrentHashMap) elements; } private AtomicInteger getItemsInBuffer() { @@ -298,8 +293,7 @@ private boolean scheduleIfRoom( KV element, BoundedWindow window, Instant timestamp, boolean ignoreBuffer) { lock.lock(); try { - ConcurrentHashMap> activeElements = - getProcessingElements(); + ConcurrentHashMap> activeElements = getProcessingElements(); Object elementId = idFn.apply(element.getValue()); if (activeElements.containsKey(elementId)) { @@ -428,7 +422,7 @@ public String getErrorContext() { } }); - activeElements.put(elementId, new InFlightElement<>(element, future)); + activeElements.put(elementId, new InFlightElement<>(future)); getItemsInBuffer().incrementAndGet(); return true; } @@ -536,70 +530,60 @@ private void commitFinishedItems( LOG.info("processing timer for key: {}", key); } - ConcurrentHashMap> activeElements = - getProcessingElements(); - Set stateIds = new HashSet<>(); - for (KV element : stateList) { - stateIds.add(idFn.apply(element.getValue())); - } - - List toCancel = new ArrayList<>(); - lock.lock(); - try { - // Cancel any active elements for this key that are no longer in runner's state - for (Map.Entry> entry : - activeElements.entrySet()) { - Object elementId = entry.getKey(); - InFlightElement inFlight = entry.getValue(); - - if (Objects.equals(inFlight.element.getKey(), key) && !stateIds.contains(elementId)) { - inFlight.future.cancel(true); - toCancel.add(elementId); - LOG.info("Cancelling item {} which is no longer in state", inFlight.element); - } - } - for (Object elementId : toCancel) { - activeElements.remove(elementId); - } - } finally { - lock.unlock(); - } + ConcurrentHashMap> activeElements = getProcessingElements(); List> toReturn = new ArrayList<>(); - List> finishedItems = new ArrayList<>(); + Set> finishedItems = new HashSet<>(); List> toReschedule = new ArrayList<>(); int itemsFinished = 0; int itemsNotYetFinished = 0; int itemsRescheduled = 0; - int itemsCancelled = toCancel.size(); + int itemsCancelled = 0; + + Set finishedElementIds = new HashSet<>(); + Set inFlightElementIds = new HashSet<>(); + Set rescheduledElementIds = new HashSet<>(); lock.lock(); try { for (KV element : stateList) { Object elementId = idFn.apply(element.getValue()); + + // Skip processing if we already completed, rescheduled, or found this elementId active in + // this cycle + if (finishedElementIds.contains(elementId) + || rescheduledElementIds.contains(elementId) + || inFlightElementIds.contains(elementId)) { + continue; + } + if (activeElements.containsKey(elementId)) { - InFlightElement inFlight = activeElements.get(elementId); + InFlightElement inFlight = activeElements.get(elementId); if (inFlight.future.isDone()) { try { if (!inFlight.future.isCancelled()) { toReturn.add(inFlight.future.get()); } finishedItems.add(element); + finishedElementIds.add(elementId); activeElements.remove(elementId); itemsFinished++; } catch (Exception e) { LOG.error("Error executing async task for element {}", element, e); finishedItems.add(element); + finishedElementIds.add(elementId); activeElements.remove(elementId); } } else { + inFlightElementIds.add(elementId); itemsNotYetFinished++; } } else { LOG.info( "Item {} found in state but not in local active elements, scheduling now", element); toReschedule.add(element); + rescheduledElementIds.add(elementId); itemsRescheduled++; } } From bb15e1ded3597fcad34dd4ae9529144feccf7e18 Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Wed, 27 May 2026 20:47:07 +0000 Subject: [PATCH 03/11] Added check for long overflow possibility when exponentially increasing sleep. Added two more tests to match Python SDK. Fixed formatting issues. --- .../apache/beam/sdk/transforms/AsyncDoFn.java | 17 +- .../beam/sdk/transforms/AsyncDoFnTest.java | 168 ++++++++++++++++-- 2 files changed, 168 insertions(+), 17 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java index e499cbdf2c1e..24834355fbf3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -348,7 +348,8 @@ public void output( Instant timestamp, BoundedWindow window) { throw new UnsupportedOperationException( - "Tagged output not supported in FinishBundleContext for AsyncDoFn"); + "Tagged output not supported in " + + "FinishBundleContext for AsyncDoFn"); } }; } @@ -457,7 +458,12 @@ private void scheduleItem(KV element, BoundedWindow window, Instant t Thread.currentThread().interrupt(); throw new RuntimeException("Interrupted while waiting for space in buffer", e); } - sleepTime *= 2; + + // Prevents long overflow possibility + if (sleepTime < maxWaitTime.getMillis()) { + sleepTime *= 2; + } + totalSleep += sleep; } } @@ -606,8 +612,8 @@ private void commitFinishedItems( } } - // Emit completed outputs (Emit completed tasks immediately; do not wait for all active tasks to - // finish). + // Emit completed outputs + // (Emit completed tasks immediately; do not wait for all active tasks to finish). for (List outputs : toReturn) { for (OutputT out : outputs) { receiver.output(out); @@ -615,7 +621,8 @@ private void commitFinishedItems( } LOG.info( - "Items finished: {}, not yet finished: {}, rescheduled: {}, cancelled: {}, in processing state: {}", + "Items finished: {}, not yet finished: {}, " + + "rescheduled: {}, cancelled: {}, in processing state: {}", itemsFinished, itemsNotYetFinished, itemsRescheduled, diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java index 912aca3f309c..4d9a0a5c0d72 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java @@ -25,6 +25,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Random; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -271,8 +272,7 @@ private void checkItemsInBuffer(AsyncDoFn asyncDoFn, int expectedCount) } // Test 1: testCustomIdFn - // Verifies key extraction custom logic. Duplicate elements (same custom ID but different payload) - // should be recognized as already in-flight and deduplicated. + // Verifies custom ID extraction and deduplication of in-flight duplicate elements. @Test public void testCustomIdFn() { class CustomIdObject implements Serializable { @@ -477,9 +477,7 @@ public void testLostItem() { } // Test 6: testCancelledItem - // Verifies active task cancellation. If a pending element is deleted from the runner's persistent - // state prior to a commit (e.g., due to a rollback), the background future task must be actively - // cancelled. + // Verifies active task cancellation if a pending element is deleted from the runner's state. @Test public void testCancelledItem() { BasicDofn dofn = new BasicDofn(); @@ -564,7 +562,48 @@ public void testDuplicates() { assertEquals(0, fakeBagState.items.size()); } - // Test 9: testBufferCount + // Test 9: testSlowDuplicates + // Verifies that duplicate elements sent after the in-memory buffer + // has cleared are correctly tracked and processed. + @Test + public void testSlowDuplicates() { + BasicDofn dofn = new BasicDofn(5000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + KV msg = KV.of("key1", "1"); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + try { + Thread.sleep(10000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + fakeBagState.clear(); + List result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.emptyList()); + assertEquals(0, fakeBagState.items.size()); + + asyncDoFn.processDirect(msg, GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + assertEquals(1, fakeBagState.items.size()); + waitForEmpty(asyncDoFn); + + result = + asyncDoFn.commitFinishedItemsDirect( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer); + checkOutput(result, Collections.singletonList("1")); + assertEquals(0, fakeBagState.items.size()); + } + + // Test 10: testBufferCount // Verifies accurate in-flight metrics tracking. // The item count in the buffer must increment on task scheduling // and decrement immediately upon execution completion. @@ -591,7 +630,7 @@ public void testBufferCount() { checkItemsInBuffer(asyncDoFn, 0); } - // Test 10: testBufferStopsAcceptingItems + // Test 11: testBufferStopsAcceptingItems // Verifies queue boundaries and backpressure throttling. // When concurrent threads push elements exceeding the capacity limit, // the scheduler must block and delay submissions appropriately. @@ -663,10 +702,8 @@ public void testBufferStopsAcceptingItems() { poolExecutor.shutdown(); } - // Test 11: testBufferWithCancellation - // Verifies backpressure behavior in conjunction with element cancellation. - // Elements that are actively cancelled during queue throttling should be dropped cleanly from the - // buffer. + // Test 12: testBufferWithCancellation + // Verifies actively cancelled elements are cleanly dropped from the buffer during throttling. @Test public void testBufferWithCancellation() { BasicDofn dofn = new BasicDofn(1000); @@ -703,7 +740,114 @@ public void testBufferWithCancellation() { checkOutput(result, Collections.singletonList("2")); } - // Test 12: testResetStateConcurrentTeardown + // Test 13: testLoadCorrectness + // Verifies that the async wrapper processes large concurrent volumes + // across multiple keys correctly under heavy multi-threaded load. + @Test + public void testLoadCorrectness() { + BasicDofn dofn = new BasicDofn(1000); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, + 1, + Duration.standardSeconds(5), + null, + null, + Duration.millis(10), + null, + useThreadPool); + asyncDoFn.setup(null); + + java.util.Map>> bagStates = new java.util.HashMap<>(); + java.util.Map timers = new java.util.HashMap<>(); + java.util.Map> expectedOutputs = new java.util.HashMap<>(); + + for (int i = 0; i < 10; i++) { + String key = "key" + i; + bagStates.put(key, new FakeBagState<>()); + timers.put(key, new FakeTimer()); + expectedOutputs.put(key, new ArrayList<>()); + } + + ExecutorService poolExecutor = Executors.newFixedThreadPool(10); + List> futures = new ArrayList<>(); + Random random = new Random(); + + for (int i = 0; i < 100; i++) { + final int val = i; + final String key = "key" + random.nextInt(10); + expectedOutputs.get(key).add(String.valueOf(val)); + + futures.add( + poolExecutor.submit( + () -> { + KV item = KV.of(key, String.valueOf(val)); + asyncDoFn.processDirect( + item, + GlobalWindow.INSTANCE, + Instant.now(), + bagStates.get(key), + timers.get(key)); + })); + try { + Thread.sleep(random.nextInt(200)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + try { + Thread.sleep(3000 + random.nextInt(2000)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Verify that all background tasks completed successfully + for (Future future : futures) { + try { + future.get(); + } catch (Exception e) { + throw new AssertionError("Background task failed", e); + } + } + + boolean done = false; + java.util.Map> results = new java.util.HashMap<>(); + for (int i = 0; i < 10; i++) { + results.put("key" + i, new ArrayList<>()); + } + + while (!done) { + done = true; + for (int i = 0; i < 10; i++) { + String key = "key" + i; + results + .get(key) + .addAll( + asyncDoFn.commitFinishedItemsDirect( + timers.get(key).getCurrentRelativeTime(), bagStates.get(key), timers.get(key))); + if (!bagStates.get(key).items.isEmpty()) { + done = false; + } else { + checkOutput(results.get(key), expectedOutputs.get(key)); + } + } + try { + Thread.sleep(1000 + random.nextInt(2000)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + for (int i = 0; i < 10; i++) { + String key = "key" + i; + checkOutput(results.get(key), expectedOutputs.get(key)); + assertEquals(0, bagStates.get(key).items.size()); + } + poolExecutor.shutdown(); + } + + // Test 14: testResetStateConcurrentTeardown // Verifies safe resource cleanup during concurrent shutdown. // Resetting the global shared execution state while workers are running // must complete cleanly without thread or lock deadlocks. From 6ce9ce90c71e42fb5ad1c6e072e97e1ab4fa0af2 Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Wed, 27 May 2026 21:58:09 +0000 Subject: [PATCH 04/11] Fix Timestamp propagation and add relevant test too. Spotless Apply fixes. Spot Bugs potential fixes. --- .../apache/beam/sdk/transforms/AsyncDoFn.java | 57 +++++++++++---- .../beam/sdk/transforms/AsyncDoFnTest.java | 72 ++++++++++++++++++- 2 files changed, 112 insertions(+), 17 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java index 24834355fbf3..b8ef4bc8c970 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -106,10 +106,20 @@ public class AsyncDoFn extends DoFn, OutputT> private static final ReentrantLock lock = new ReentrantLock(); private static final boolean verboseLogging = false; + private static class TimestampedOutput { + final T value; + final @Nullable Instant timestamp; + + TimestampedOutput(T value, @Nullable Instant timestamp) { + this.value = value; + this.timestamp = timestamp; + } + } + private static class InFlightElement { - final CompletableFuture> future; + final CompletableFuture>> future; - InFlightElement(CompletableFuture> future) { + InFlightElement(CompletableFuture>> future) { this.future = future; } } @@ -119,7 +129,8 @@ private static class InFlightElement { // Buffered elements are only committed downstream once the parent task completes successfully // and the timer fires. private static class AccumulatingOutputReceiver implements OutputReceiver { - private final List outputs = Collections.synchronizedList(new ArrayList<>()); + private final List> outputs = + Collections.synchronizedList(new ArrayList<>()); @Override public org.apache.beam.sdk.values.OutputBuilder builder(T value) { @@ -128,22 +139,34 @@ public org.apache.beam.sdk.values.OutputBuilder builder(T value) { .setTimestamp(Instant.now()) .setWindows(java.util.Collections.singletonList(GlobalWindow.INSTANCE)) .setPaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING) - .setReceiver(windowedValue -> outputs.add(windowedValue.getValue())); + .setReceiver( + windowedValue -> + outputs.add( + new TimestampedOutput<>( + windowedValue.getValue(), windowedValue.getTimestamp()))); } // Bypasses the nested anonymous OutputBuilder instantiation for standard outputs. // JVM optimization to prevent garbage collection pressure under high pipeline throughput. @Override public void output(T output) { - outputs.add(output); + outputs.add(new TimestampedOutput<>(output, null)); } @Override public void outputWithTimestamp(T output, Instant timestamp) { - outputs.add(output); + outputs.add(new TimestampedOutput<>(output, timestamp)); } public List getOutputs() { + List rawOutputs = new ArrayList<>(); + for (TimestampedOutput out : outputs) { + rawOutputs.add(out.value); + } + return rawOutputs; + } + + public List> getTimestampedOutputs() { return outputs; } } @@ -307,7 +330,7 @@ private boolean scheduleIfRoom( useThreadPool ? getThreadPool() : java.util.concurrent.ForkJoinPool.commonPool(); // Pending asynchronous task that will produce a list of outputs - CompletableFuture> future = + CompletableFuture>> future = CompletableFuture.supplyAsync( () -> { try { @@ -338,7 +361,7 @@ public PipelineOptions getPipelineOptions() { @Override public void output( OutputT output, Instant timestamp, BoundedWindow window) { - receiver.output(output); + receiver.outputWithTimestamp(output, timestamp); } @Override @@ -403,7 +426,7 @@ public String getErrorContext() { invoker.invokeProcessElement(processArgProvider); invoker.invokeFinishBundle(bundleArgProvider); - return receiver.getOutputs(); + return receiver.getTimestampedOutputs(); } catch (Exception e) { throw new CompletionException(e); } @@ -412,7 +435,7 @@ public String getErrorContext() { // Assigned to 'unused' to satisfy ErrorProne while preserving parent future for // cancellation - CompletableFuture> unused = + CompletableFuture>> unused = future.whenComplete( (res, ex) -> { lock.lock(); @@ -509,7 +532,7 @@ public void onTimer( // Synchronizes local task results with the runner's persistent state container. // Emits successfully completed elements, cancels rolled-back tasks, and reschedules lost work. - private void commitFinishedItems( + void commitFinishedItems( Instant fireTimestamp, BagState> toProcessState, Timer timer, @@ -538,7 +561,7 @@ private void commitFinishedItems( ConcurrentHashMap> activeElements = getProcessingElements(); - List> toReturn = new ArrayList<>(); + List>> toReturn = new ArrayList<>(); Set> finishedItems = new HashSet<>(); List> toReschedule = new ArrayList<>(); @@ -614,9 +637,13 @@ private void commitFinishedItems( // Emit completed outputs // (Emit completed tasks immediately; do not wait for all active tasks to finish). - for (List outputs : toReturn) { - for (OutputT out : outputs) { - receiver.output(out); + for (List> outputs : toReturn) { + for (TimestampedOutput out : outputs) { + if (out.timestamp != null) { + receiver.outputWithTimestamp(out.value, out.timestamp); + } else { + receiver.output(out.value); + } } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java index 4d9a0a5c0d72..9bd2d07a3558 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java @@ -107,6 +107,19 @@ public void finishBundle(FinishBundleContext c) { } } + private static class TimestampingDoFn extends DoFn { + private final Instant outputTimestamp; + + TimestampingDoFn(Instant outputTimestamp) { + this.outputTimestamp = outputTimestamp; + } + + @ProcessElement + public void processElement(@Element String element, OutputReceiver receiver) { + receiver.outputWithTimestamp(element, outputTimestamp); + } + } + // Used for testing BagState thread safety. private static class FakeBagState implements BagState { private final List items; @@ -286,8 +299,12 @@ class CustomIdObject implements Serializable { @Override public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof CustomIdObject)) return false; + if (this == o) { + return true; + } + if (!(o instanceof CustomIdObject)) { + return false; + } CustomIdObject that = (CustomIdObject) o; return elementId == that.elementId; } @@ -874,4 +891,55 @@ public void testResetStateConcurrentTeardown() { // Verify calling resetState() while background tasks are running finishes cleanly AsyncDoFn.resetState(); } + + // Test 15: testTimestampPropagation + // Verifies that custom timestamps output by the wrapped DoFn are correctly propagated + // and not lost or replaced during async execution. + @Test + public void testTimestampPropagation() { + Instant customTimestamp = new Instant(123456789000L); + TimestampingDoFn dofn = new TimestampingDoFn(customTimestamp); + AsyncDoFn asyncDoFn = + new AsyncDoFn<>( + dofn, 1, Duration.standardSeconds(5), null, null, null, null, useThreadPool); + asyncDoFn.setup(null); + + FakeBagState> fakeBagState = new FakeBagState<>(); + FakeTimer fakeTimer = new FakeTimer(); + + class CapturingReceiver implements DoFn.OutputReceiver { + final List values = new ArrayList<>(); + final List timestamps = new ArrayList<>(); + + @Override + public org.apache.beam.sdk.values.OutputBuilder builder(String value) { + throw new UnsupportedOperationException(); + } + + @Override + public void output(String output) { + values.add(output); + timestamps.add(null); + } + + @Override + public void outputWithTimestamp(String output, Instant timestamp) { + values.add(output); + timestamps.add(timestamp); + } + } + + CapturingReceiver capturingReceiver = new CapturingReceiver(); + + asyncDoFn.processDirect( + KV.of("key1", "val1"), GlobalWindow.INSTANCE, Instant.now(), fakeBagState, fakeTimer); + + waitForEmpty(asyncDoFn); + + asyncDoFn.commitFinishedItems( + fakeTimer.getCurrentRelativeTime(), fakeBagState, fakeTimer, capturingReceiver); + + assertEquals(Collections.singletonList("val1"), capturingReceiver.values); + assertEquals(Collections.singletonList(customTimestamp), capturingReceiver.timestamps); + } } From 4cdcfa4756619e575ef564cea845f7eb1329017e Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Thu, 28 May 2026 08:07:58 +0000 Subject: [PATCH 05/11] Resolve SpotBugs DMI_RANDOM_USED_ONLY_ONCE replacing new Random(seed) that preserves deterministic jitter behavior and avoids pressure on garbage collector (apache#38529) --- .../org/apache/beam/sdk/transforms/AsyncDoFn.java | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java index b8ef4bc8c970..5cf2fca9f350 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -22,7 +22,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Random; import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; @@ -91,7 +90,7 @@ public class AsyncDoFn extends DoFn, OutputT> private final boolean useThreadPool; private final String uuid; - private transient @Nullable PipelineOptions pipelineOptions; + private transient volatile @Nullable PipelineOptions pipelineOptions; // Shared JVM-Wide States (Static Registries) // Map-backed registry holding shared resources across serialized worker instances. Since runners @@ -495,12 +494,12 @@ private void scheduleItem(KV element, BoundedWindow window, Instant t private Instant nextTimeToFire(@Nullable K key) { long seed = (key == null) ? 0 : key.hashCode(); - Random random = new Random(seed); + double fractionalOffset = Math.abs(seed % 1000000) / 1000000.0; double timerFrequencySec = timerFrequency.getMillis() / 1000.0; double nowSec = System.currentTimeMillis() / 1000.0; double base = Math.floor((nowSec + timerFrequencySec) / timerFrequencySec) * timerFrequencySec; - double offset = random.nextDouble() * timerFrequencySec; + double offset = fractionalOffset * timerFrequencySec; return Instant.ofEpochMilli((long) ((base + offset) * 1000)); } @@ -568,7 +567,6 @@ void commitFinishedItems( int itemsFinished = 0; int itemsNotYetFinished = 0; int itemsRescheduled = 0; - int itemsCancelled = 0; Set finishedElementIds = new HashSet<>(); Set inFlightElementIds = new HashSet<>(); @@ -648,12 +646,10 @@ void commitFinishedItems( } LOG.info( - "Items finished: {}, not yet finished: {}, " - + "rescheduled: {}, cancelled: {}, in processing state: {}", + "Items finished: {}, not yet finished: {}, " + "rescheduled: {}, in processing state: {}", itemsFinished, itemsNotYetFinished, itemsRescheduled, - itemsCancelled, itemsInProcessingState); if (itemsInProcessingState > 0) { From 5259bf9dd89bfa487e42074d9c863cafa179349f Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Thu, 28 May 2026 18:58:18 +0000 Subject: [PATCH 06/11] Improve AsyncDoFn robustness and fix critical warnings provided by gemini-code-assist - Propagates asynchronous task exceptions as RuntimeExceptions to prevent silent data loss and enable runner-level retries. - Implements a static refCounts registry to safely tear down the shared executor service only when the last cloned instance is destroyed. - Validates timerFrequency in the constructor to prevent zero/negative values from entering infinite loops. - Documents multi-threading requirements, multi-output limitations, and bundle lifecycle behaviors in a class-level comment. --- .../apache/beam/sdk/transforms/AsyncDoFn.java | 35 +++++++++++++------ 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java index 5cf2fca9f350..619001e230dd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -62,7 +62,15 @@ * runners are optimized for latencies less than a few seconds and longer operations can result in * high retry rates. Async should be considered when the default parallelism is not correct and/or * items are expected to take longer than a few seconds to process. - */ + +/* + * NOTE: + * 1) The wrapped syncFn requires thread-safety ONLY if BOTH parallelism > 1 AND + * the DoFn is stateful (keeps instance state). + * 2) Tagged output multi-outputs are unsupported. + * 3) StartBundle/finishBundle are invoked per element so any batching or + * aggregation logic will not behave as expected. +*/ public class AsyncDoFn extends DoFn, OutputT> { private static final Logger LOG = LoggerFactory.getLogger(AsyncDoFn.class); @@ -101,6 +109,8 @@ public class AsyncDoFn extends DoFn, OutputT> processingElements = new ConcurrentHashMap<>(); private static final ConcurrentHashMap itemsInBuffer = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap refCounts = + new ConcurrentHashMap<>(); private static final ReentrantLock lock = new ReentrantLock(); private static final boolean verboseLogging = false; @@ -203,6 +213,9 @@ public AsyncDoFn( @Nullable Coder> coder) { this.syncFn = syncFn; this.parallelism = parallelism; + if (timerFrequency.getMillis() <= 0) { + throw new IllegalArgumentException("timerFrequency must be greater than zero"); + } this.timerFrequency = timerFrequency; this.maxItemsToBuffer = (maxItemsToBuffer != null) @@ -274,6 +287,7 @@ public String getErrorContext() { pool.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism)); processingElements.computeIfAbsent(uuid, k -> new ConcurrentHashMap<>()); itemsInBuffer.computeIfAbsent(uuid, k -> new AtomicInteger(0)); + refCounts.computeIfAbsent(uuid, k -> new AtomicInteger(0)).incrementAndGet(); } finally { lock.unlock(); } @@ -283,17 +297,19 @@ public String getErrorContext() { @Teardown public void teardown() { DoFnInvokers.invokerFor(syncFn).invokeTeardown(); - - ExecutorService threadPool; + ExecutorService threadPool = null; lock.lock(); try { - threadPool = pool.remove(uuid); - processingElements.remove(uuid); - itemsInBuffer.remove(uuid); + AtomicInteger refCount = refCounts.get(uuid); + if (refCount != null && refCount.decrementAndGet() == 0) { + refCounts.remove(uuid); + threadPool = pool.remove(uuid); + processingElements.remove(uuid); + itemsInBuffer.remove(uuid); + } } finally { lock.unlock(); } - if (threadPool != null) { threadPool.shutdown(); try { @@ -598,9 +614,7 @@ void commitFinishedItems( itemsFinished++; } catch (Exception e) { LOG.error("Error executing async task for element {}", element, e); - finishedItems.add(element); - finishedElementIds.add(elementId); - activeElements.remove(elementId); + throw new RuntimeException("Error executing async task for element " + element, e); } } else { inFlightElementIds.add(elementId); @@ -696,6 +710,7 @@ static void resetState() { pool.clear(); processingElements.clear(); itemsInBuffer.clear(); + refCounts.clear(); } finally { lock.unlock(); } From 6c486213ed72eb867f9b67cafb8c573e3138a82c Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Thu, 28 May 2026 20:35:52 +0000 Subject: [PATCH 07/11] Fixed formatting issue --- .../org/apache/beam/sdk/transforms/AsyncDoFn.java | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java index 619001e230dd..4a422ab4d773 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -62,15 +62,12 @@ * runners are optimized for latencies less than a few seconds and longer operations can result in * high retry rates. Async should be considered when the default parallelism is not correct and/or * items are expected to take longer than a few seconds to process. - -/* - * NOTE: - * 1) The wrapped syncFn requires thread-safety ONLY if BOTH parallelism > 1 AND - * the DoFn is stateful (keeps instance state). - * 2) Tagged output multi-outputs are unsupported. - * 3) StartBundle/finishBundle are invoked per element so any batching or - * aggregation logic will not behave as expected. -*/ + * + *

/* NOTE: 1) The wrapped syncFn requires thread-safety ONLY if BOTH parallelism > 1 AND the + * DoFn is stateful (keeps instance state). 2) Tagged output multi-outputs are unsupported. 3) + * StartBundle/finishBundle are invoked per element so any batching or aggregation logic will not + * behave as expected. + */ public class AsyncDoFn extends DoFn, OutputT> { private static final Logger LOG = LoggerFactory.getLogger(AsyncDoFn.class); From 08e1d5f63e30886361c3adbf0e8f0587730a3061 Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Fri, 29 May 2026 18:32:04 +0000 Subject: [PATCH 08/11] Implement cross-key task cancellation in AsyncDoFn. Stores the partition key inside InFlightElement and cancels/purges orphaned futures inside commitFinishedItems. This prevents silent memory leaks on bundle rollbacks (#38529) --- .../apache/beam/sdk/transforms/AsyncDoFn.java | 49 ++++++++++++++----- .../beam/sdk/transforms/AsyncDoFnTest.java | 3 -- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java index 4a422ab4d773..ae2bbe9874d0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -63,10 +63,9 @@ * high retry rates. Async should be considered when the default parallelism is not correct and/or * items are expected to take longer than a few seconds to process. * - *

/* NOTE: 1) The wrapped syncFn requires thread-safety ONLY if BOTH parallelism > 1 AND the - * DoFn is stateful (keeps instance state). 2) Tagged output multi-outputs are unsupported. 3) - * StartBundle/finishBundle are invoked per element so any batching or aggregation logic will not - * behave as expected. + *

/* NOTE: 1) The wrapped syncFn REQUIRES thread-safety if BOTH parallelism > 1 and the DoFn is + * stateful. 2) Tagged output multi-outputs are unsupported. 3) StartBundle/finishBundle are invoked + * per element so any batching or aggregation logic will not behave as expected. */ public class AsyncDoFn extends DoFn, OutputT> { @@ -106,9 +105,13 @@ public class AsyncDoFn extends DoFn, OutputT> processingElements = new ConcurrentHashMap<>(); private static final ConcurrentHashMap itemsInBuffer = new ConcurrentHashMap<>(); + // Reference counts for cloned instances sharing the same UUID. Coordinates safe, + // leak-free thread pool shutdown during teardown without crashing active sibling clones. private static final ConcurrentHashMap refCounts = new ConcurrentHashMap<>(); + // If contention becomes a bottleneck, this can be replaced with per-uuid locks + // in a ConcurrentHashMap private static final ReentrantLock lock = new ReentrantLock(); private static final boolean verboseLogging = false; @@ -123,9 +126,12 @@ private static class TimestampedOutput { } private static class InFlightElement { + final @Nullable Object key; final CompletableFuture>> future; - InFlightElement(CompletableFuture>> future) { + InFlightElement( + @Nullable Object key, CompletableFuture>> future) { + this.key = key; this.future = future; } } @@ -450,15 +456,10 @@ public String getErrorContext() { CompletableFuture>> unused = future.whenComplete( (res, ex) -> { - lock.lock(); - try { - getItemsInBuffer().decrementAndGet(); - } finally { - lock.unlock(); - } + getItemsInBuffer().decrementAndGet(); }); - activeElements.put(elementId, new InFlightElement<>(future)); + activeElements.put(elementId, new InFlightElement<>(element.getKey(), future)); getItemsInBuffer().incrementAndGet(); return true; } @@ -505,6 +506,8 @@ private void scheduleItem(KV element, BoundedWindow window, Instant t // Timeout: element skips JVM pool but stays in BagState for timer to reschedule later. } + // Uses hashcode based jitter instead of random for deterministic rescheduling + // Satisfies lint check private Instant nextTimeToFire(@Nullable K key) { long seed = (key == null) ? 0 : key.hashCode(); double fractionalOffset = Math.abs(seed % 1000000) / 1000000.0; @@ -587,6 +590,28 @@ void commitFinishedItems( lock.lock(); try { + Set stateElementIds = new HashSet<>(); + for (KV element : stateList) { + stateElementIds.add(idFn.apply(element.getValue())); + } + + List toCancelIds = new ArrayList<>(); + for (Map.Entry> entry : activeElements.entrySet()) { + InFlightElement inFlight = entry.getValue(); + if (java.util.Objects.equals(inFlight.key, key) + && !stateElementIds.contains(entry.getKey())) { + toCancelIds.add(entry.getKey()); + } + } + + for (Object cancelId : toCancelIds) { + InFlightElement inFlight = activeElements.get(cancelId); + if (inFlight != null) { + inFlight.future.cancel(true); + activeElements.remove(cancelId); + } + } + for (KV element : stateList) { Object elementId = idFn.apply(element.getValue()); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java index 9bd2d07a3558..ec2a940a817a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/AsyncDoFnTest.java @@ -33,13 +33,11 @@ import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.ReadableState; import org.apache.beam.sdk.state.Timer; -import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.values.KV; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -48,7 +46,6 @@ @RunWith(JUnit4.class) public class AsyncDoFnTest implements Serializable { - @Rule public final transient TestPipeline p = TestPipeline.create(); private final boolean useThreadPool = true; // Used for testing basic DoFn processing logic with optional latency. From c899da28a211c4c5f2172d4f5036075f30b7b139 Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Fri, 29 May 2026 20:59:23 +0000 Subject: [PATCH 09/11] Passes original element's inputTimestamp to AccumulatingOutputReceiver to preserve event-time downstream. Initializes the ExecutorService thread pool only when useThreadPool is true. Refactors state filtering to use finishedElementIds instead of finishedItems, preventing duplicate processing (#38529) --- .../apache/beam/sdk/transforms/AsyncDoFn.java | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java index ae2bbe9874d0..0309593b8d3b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -144,11 +144,17 @@ private static class AccumulatingOutputReceiver implements OutputReceiver private final List> outputs = Collections.synchronizedList(new ArrayList<>()); + private final Instant inputTimestamp; // <-- Store original timestamp + + AccumulatingOutputReceiver(Instant inputTimestamp) { + this.inputTimestamp = inputTimestamp; + } + @Override public org.apache.beam.sdk.values.OutputBuilder builder(T value) { return org.apache.beam.sdk.values.WindowedValues.builder() .setValue(value) - .setTimestamp(Instant.now()) + .setTimestamp(inputTimestamp) .setWindows(java.util.Collections.singletonList(GlobalWindow.INSTANCE)) .setPaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING) .setReceiver( @@ -162,7 +168,7 @@ public org.apache.beam.sdk.values.OutputBuilder builder(T value) { // JVM optimization to prevent garbage collection pressure under high pipeline throughput. @Override public void output(T output) { - outputs.add(new TimestampedOutput<>(output, null)); + outputs.add(new TimestampedOutput<>(output, inputTimestamp)); } @Override @@ -287,7 +293,9 @@ public String getErrorContext() { lock.lock(); try { - pool.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism)); + if (useThreadPool) { + pool.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism)); + } processingElements.computeIfAbsent(uuid, k -> new ConcurrentHashMap<>()); itemsInBuffer.computeIfAbsent(uuid, k -> new AtomicInteger(0)); refCounts.computeIfAbsent(uuid, k -> new AtomicInteger(0)).incrementAndGet(); @@ -353,7 +361,7 @@ private boolean scheduleIfRoom( () -> { try { AccumulatingOutputReceiver receiver = - new AccumulatingOutputReceiver<>(); + new AccumulatingOutputReceiver<>(timestamp); DoFnInvoker invoker = DoFnInvokers.invokerFor(syncFn); DoFnInvoker.ArgumentProvider bundleArgProvider = @@ -577,7 +585,7 @@ void commitFinishedItems( ConcurrentHashMap> activeElements = getProcessingElements(); List>> toReturn = new ArrayList<>(); - Set> finishedItems = new HashSet<>(); + List> toReschedule = new ArrayList<>(); int itemsFinished = 0; @@ -630,7 +638,7 @@ void commitFinishedItems( if (!inFlight.future.isCancelled()) { toReturn.add(inFlight.future.get()); } - finishedItems.add(element); + finishedElementIds.add(elementId); activeElements.remove(elementId); itemsFinished++; @@ -663,7 +671,8 @@ void commitFinishedItems( toProcessState.clear(); int itemsInProcessingState = 0; for (KV element : stateList) { - if (!finishedItems.contains(element)) { + Object elementId = idFn.apply(element.getValue()); + if (!finishedElementIds.contains(elementId)) { toProcessState.add(element); itemsInProcessingState++; } @@ -710,7 +719,7 @@ void processDirect( List commitFinishedItemsDirect( Instant fireTimestamp, BagState> toProcessState, Timer timer) { - AccumulatingOutputReceiver receiver = new AccumulatingOutputReceiver<>(); + AccumulatingOutputReceiver receiver = new AccumulatingOutputReceiver<>(fireTimestamp); commitFinishedItems(fireTimestamp, toProcessState, timer, receiver); return receiver.getOutputs(); } From 501e15596408dd1aa9731af6899acbab78d84da4 Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Fri, 29 May 2026 21:28:10 +0000 Subject: [PATCH 10/11] Changed outputs from List type to ConcurrentLinkedQueue to prevent lock contention (#38529) --- .../java/org/apache/beam/sdk/transforms/AsyncDoFn.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java index 0309593b8d3b..9b24d64bf9a5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.transforms; import java.util.ArrayList; -import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -141,10 +140,11 @@ private static class InFlightElement { // Buffered elements are only committed downstream once the parent task completes successfully // and the timer fires. private static class AccumulatingOutputReceiver implements OutputReceiver { - private final List> outputs = - Collections.synchronizedList(new ArrayList<>()); + private final java.util.concurrent.ConcurrentLinkedQueue> outputs = + new java.util.concurrent.ConcurrentLinkedQueue<>(); - private final Instant inputTimestamp; // <-- Store original timestamp + // Store original timestamp + private final Instant inputTimestamp; AccumulatingOutputReceiver(Instant inputTimestamp) { this.inputTimestamp = inputTimestamp; @@ -185,7 +185,7 @@ public List getOutputs() { } public List> getTimestampedOutputs() { - return outputs; + return new ArrayList<>(outputs); } } From 2180622b2de5a401c799eec98967a2c66c70d0c5 Mon Sep 17 00:00:00 2001 From: Tejas Iyer Date: Fri, 29 May 2026 22:18:56 +0000 Subject: [PATCH 11/11] Reverted outputs back to list type. Takes original Bounded Window as parameter without Hardcoding GlobalWindow.Instance (#38529) --- .../apache/beam/sdk/transforms/AsyncDoFn.java | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java index 9b24d64bf9a5..fab9c2279d94 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java @@ -140,14 +140,13 @@ private static class InFlightElement { // Buffered elements are only committed downstream once the parent task completes successfully // and the timer fires. private static class AccumulatingOutputReceiver implements OutputReceiver { - private final java.util.concurrent.ConcurrentLinkedQueue> outputs = - new java.util.concurrent.ConcurrentLinkedQueue<>(); - - // Store original timestamp + private final List> outputs = new ArrayList<>(); private final Instant inputTimestamp; + private final BoundedWindow window; - AccumulatingOutputReceiver(Instant inputTimestamp) { + AccumulatingOutputReceiver(Instant inputTimestamp, BoundedWindow window) { this.inputTimestamp = inputTimestamp; + this.window = window; } @Override @@ -155,7 +154,7 @@ public org.apache.beam.sdk.values.OutputBuilder builder(T value) { return org.apache.beam.sdk.values.WindowedValues.builder() .setValue(value) .setTimestamp(inputTimestamp) - .setWindows(java.util.Collections.singletonList(GlobalWindow.INSTANCE)) + .setWindows(java.util.Collections.singletonList(window)) .setPaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING) .setReceiver( windowedValue -> @@ -185,7 +184,7 @@ public List getOutputs() { } public List> getTimestampedOutputs() { - return new ArrayList<>(outputs); + return outputs; } } @@ -361,7 +360,7 @@ private boolean scheduleIfRoom( () -> { try { AccumulatingOutputReceiver receiver = - new AccumulatingOutputReceiver<>(timestamp); + new AccumulatingOutputReceiver<>(timestamp, window); DoFnInvoker invoker = DoFnInvokers.invokerFor(syncFn); DoFnInvoker.ArgumentProvider bundleArgProvider = @@ -719,7 +718,8 @@ void processDirect( List commitFinishedItemsDirect( Instant fireTimestamp, BagState> toProcessState, Timer timer) { - AccumulatingOutputReceiver receiver = new AccumulatingOutputReceiver<>(fireTimestamp); + AccumulatingOutputReceiver receiver = + new AccumulatingOutputReceiver<>(fireTimestamp, GlobalWindow.INSTANCE); commitFinishedItems(fireTimestamp, toProcessState, timer, receiver); return receiver.getOutputs(); }