[SPARK-57491][CORE] Prevent and detect stale push-based shuffle data from duplicate map task attempts #56559

Open
gaoyajun02 wants to merge 19 commits into apache:master from gaoyajun02:SPARK-33235

@gaoyajun02 commented Jun 17, 2026

Contributor

What changes were proposed in this pull request?

Fix push-based shuffle serving stale or incorrect data when two or more task attempts for the same map partition push data to the external merger. Three layers of defense:

  • ShuffleWriteProcessor — Defer shuffle block push via PostStatusUpdateListener so that killed/failed tasks never push. The listener checks context.isInterrupted() / context.isFailed() before initiating push, and logs which task was skipped and why.

  • TaskSetManager + MapOutputTracker — Track stale (duplicate) partitionIds on the driver side. When a speculative or retried ShuffleMapTask result arrives after another attempt for the same partition has already committed, TaskSetManager.detectStalePushIfShuffleTask marks the partition's ID as stale in ShuffleStatus.staleMapIndexes. The stale set is serialized alongside existing map/merge output metadata and propagated to executor-side MapOutputTrackerWorker.

  • PushBasedFetchHelper (ShuffleBlockFetcherIterator) — Before reading any merged block on the reducer side, check chunk-level RoaringBitmaps for stale partitionIds via the new checkStaleMapIdInMergedBlock method. If a stale mapIndex is found in any chunk bitmap, log a WARN and fall back to fetching original unmerged blocks. When the stale set is empty (the common case), this check short-circuits immediately with zero overhead.
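
The reduce-side check in the third layer can be sketched roughly as follows. This is a minimal model, not the PR's code: `StaleCheckSketch.hasStaleMapIndex` is a hypothetical stand-in for checkStaleMapIdInMergedBlock (whose real signature is not shown in this description), and `scala.collection.immutable.BitSet` replaces RoaringBitmap so the example has no external dependencies.

```scala
import scala.collection.immutable.BitSet

// Hypothetical stand-in for the reduce-side check; BitSet is used in place of
// RoaringBitmap purely to keep the sketch dependency-free.
object StaleCheckSketch {
  /** Returns true if any chunk bitmap contains a map index that was marked stale. */
  def hasStaleMapIndex(staleMapIndexes: Set[Int], chunkBitmaps: Seq[BitSet]): Boolean = {
    if (staleMapIndexes.isEmpty) {
      // Common case: nothing was marked stale, so no bitmap is inspected.
      false
    } else {
      chunkBitmaps.exists(bitmap => staleMapIndexes.exists(bitmap.contains))
    }
  }
}
```

With chunks `Seq(BitSet(0, 1), BitSet(2, 3))`, an empty stale set lets the reducer read the merged block, a stale set of `Set(3)` would trigger the fallback to unmerged blocks, and `Set(7)` would not, because index 7 never contributed to this block.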

Why are the changes needed?

With push-based shuffle enabled, indeterminate stages can produce incorrect results when speculation or task retry causes two attempts of the same map partition to both push data to the external merger. The merged block ends up with interleaved data from both attempts, but only one attempt's MapStatus is committed on the driver. Downstream reducers silently read corrupted/duplicated data.

This is triggered by non-deterministic functions (rand(), UUID(), etc.) in shuffle keys or join conditions where row ordering differs between attempts. Even with deferred push (layer 1), there is a race window: if an original attempt succeeds and pushes, then gets killed by a speculative winner that also pushes, the merger receives data from both attempts but only one MapStatus is registered.

Why does the existing checksum-mismatch / rollback machinery not cover this case?

The existing checksum-mismatch detection in DAGScheduler.handleTaskCompletion relies on comparing MapStatus.checksumValue values via MapOutputTracker.registerMapOutput(). This only works for task results reported as Success — the MapStatus is registered and compared against any previously registered status for the same partition.

In speculation with push-based shuffle, the losing attempt's result never reaches registerMapOutput(): when a later-arriving or killed speculative attempt is detected in TaskSetManager.handleSuccessfulTask, it is handled as TaskKilled (not Success) and forwarded to DAGScheduler as a kill event, which bypasses the registerMapOutput path entirely. Therefore no checksum comparison occurs for the loser's output.

Furthermore, even if this path were reachable, checksumMismatchFullRetryEnabled is force-disabled when push-based shuffle is active (Dependency.scala:148-149), since the two features are designed for mutually exclusive scenarios: checksum mismatch triggers stage-level rollback with full retry of all downstream stages, while stale push detection handles the problem at the partition level with a lightweight fetch fallback.

Does this PR introduce any user-facing change?

No

How was this patch tested?

  • Unit tests
  • Manual verification on an internal cluster confirmed that speculative duplicate pushes are either prevented by the deferred push or detected on the reduce side, which then falls back to unmerged blocks and logs the expected WARN.

Was this patch authored or co-authored using generative AI tooling?

Yes, co-authored with GLM-5V-Turbo.

@gaoyajun02

Contributor Author

PTAL @cloud-fan @Ngone51 @mridulm @otterc

@cloud-fan cloud-fan left a comment

Contributor

2 blocking, 5 non-blocking, 3 nits.
A well-motivated correctness fix filling a real gap in the existing checksum-mismatch path; two blocking items below.

Design / architecture (2)

  • MapOutputTracker.scala:121: stale*MapId* naming holds partitionId (== mapIndex), not MapStatus.mapId — a load-bearing invariant the names hide — see inline
  • (non-blocking, question) overlaps the existing checksumMismatchIndices/rollback path and the merger's own stale-block handling; document the interaction or consider unifying

Correctness (4)

  • MapOutputTrackerSuite.scala:640 (existing, unchanged line): the "SPARK-32921" test still binds the reply with askSync[(Array[Byte], Array[Byte])], but the endpoint now returns a 3-tuple → ClassCastException. This test needs updating.
  • MapOutputTracker.scala:114: staleMapIds is never cleared except by unregisterShuffle — permanent fallback even after a clean recompute — see inline
  • TaskSetManager.scala:966: over-broad marking (winner re-delivery; attempt may never have pushed) — see inline
  • PushBasedFetchHelper.scala:313: @return overstates precision ("stale data" vs "marked partition present") — see inline

Suggestions (1)

  • No tests exercise the reduce-side checkStaleMapIdInMergedBlock fallback or the deferred-push skip-on-kill

Nits: 3 minor items (see inline comments).

Verification

Traced the marked-value invariant: ShuffleMapTask pushes with mapIndex = partitionId, RemoteBlockPushResolver records chunkTracker.add(mapIndex), and the driver marks tasks(index).partitionId. So the reduce-side bitmap.contains(id) works only because partitionId == mapIndex — never because it uses MapStatus.mapId. Detection is functionally correct today, but per-partition rather than per-attempt: both attempts share the mapIndex, so a marked block always falls back regardless of whether duplicate bytes are actually present (conservative and safe).

PR description suggestions

  • Document: why the existing checksum-mismatch / rollback machinery does not cover this case (the loser's MapStatus is never registered, so no checksum comparison happens).

* detects that multiple task attempts for the same map output pushed data to the merger.
* @param staleMapId the mapId of the stale (redundant) attempt
*/
def markStalePushedMap(staleMapId: Int): Unit = withWriteLock {

Contributor

This set holds the partitionId, which equals the mapIndex the merger records in its chunk bitmaps — not MapStatus.mapId. ShuffleMapTask pushes with mapIndex = partitionId, RemoteBlockPushResolver does chunkTracker.add(mapIndex), and the driver marks tasks(index).partitionId, so the reduce-side contains() check works only because partitionId == mapIndex. Naming all of this *MapId* (and logging mapId=${mapStatus.mapId} in TaskSetManager) hides that invariant: someone "fixing" the marked value to the real mapId would break the bitmap match and serve stale data as clean. Suggest renaming to staleMapIndexes / markStalePushedPartition and logging the partitionId you actually mark.
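
To make the invariant concrete, here is a small illustration of why the marked value has to be the mapIndex. It is a sketch under stated assumptions: `Attempt`, `chunkBitmap` and `isDetected` are hypothetical names, `BitSet` stands in for RoaringBitmap, and mapId is assumed to be a per-attempt identifier that differs from the partition index.

```scala
import scala.collection.immutable.BitSet

// Hypothetical model: two attempts of the same map partition share a mapIndex
// but carry different per-attempt mapIds (an assumption made for this sketch).
final case class Attempt(mapId: Long, mapIndex: Int)

object MapIndexInvariantSketch {
  /** Simulates the merger recording which map indexes contributed to a chunk. */
  def chunkBitmap(pushed: Seq[Attempt]): BitSet = BitSet(pushed.map(_.mapIndex): _*)

  /** The reduce-side lookup only matches values that live in the same key space. */
  def isDetected(markedValue: Int, bitmap: BitSet): Boolean = bitmap.contains(markedValue)
}
```

If the original attempt is `Attempt(mapId = 40L, mapIndex = 2)` and the speculative one is `Attempt(mapId = 57L, mapIndex = 2)`, marking `2` matches the chunk bitmap while marking `57` does not, which is the failure mode the rename is meant to prevent.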

Contributor Author

Agreed rename plan:

The codebase already uses staleMapIndexes / markStalePushedPartition / getStaleMapIndexes in most places. The remaining inconsistencies are:

  1. MapOutputTrackerWorker: staleMapIds → should be staleMapIndexes (for consistency with ShuffleStatus and PushBasedFetchHelper)
  2. TaskSetManager log line 967: mapId=${mapStatus.mapId} → should be partitionId=$partitionId (log what we actually mark, not the MapStatus field)
  3. Test names / comments: any remaining *MapId* references to the stale set should use *PartitionId* or *MapIndex*

We'll clean these up to make the invariant obvious and prevent future "corrections" that break the bitmap match. Thanks for catching this.

// partition already has a successful attempt registered. Any MapStatus arriving
// here is from a stale (redundant) attempt that also pushed data.
// Mark its mapId as stale so reducers can detect it in merged block chunks.
shuffleStatus.markStalePushedMap(partitionId)

Contributor

This marks the partition stale unconditionally, without checking whether this attempt actually pushed. With the new deferred push, a killed attempt's completion listener skips the push, so it usually pushed nothing — yet it is still marked, forcing reducers to fall back. The info.finished call site is sharper: it can fire on a re-delivered winning attempt, marking the legitimate partition. The net effect is unnecessary fallback on most speculation/retry, including deterministic stages. Can the marking be gated on evidence that the attempt pushed (and skip the winner re-delivery case)?

@gaoyajun02, Jun 18, 2026

Contributor Author

Thank you for the review. You raise an excellent point about the unconditional marking. After tracing through the code, here's my analysis:

  1. The info.finished path: winner re-delivery, not a stale push.
    handleSuccessfulTask is serialized under TaskSchedulerImpl.synchronized (line 896), so two results for the same partition are never processed concurrently; they are strictly ordered. The info.finished guard exists to handle retransmission of the same winning attempt's result, not a concurrent late arrival from a different attempt. Marking it as stale is incorrect: this is the canonical result arriving twice.
  2. The killedByOtherAttempt path: a push has already occurred.
    For a task result to reach handleSuccessfulTask, the executor must have sent it with state FINISHED (Executor.scala:1029). A task only reports FINISHED after runTask() completes successfully and the completion listener fires. At that point, ShuffleWriteProcessor's listener checks !context.isInterrupted() && !context.isFailed(); isInterrupted() and isFailed() are both still false because the task finished normally. So initiateBlockPush() is always called before the result is sent back. The kill signal arrives later (or concurrently on a different thread) but cannot retroactively undo the already-submitted push job.
    We have observed this in production (with this PR's changes applied): when the driver receives two task results with the same mapIndex and rejects one, our debug logs confirm that the rejected task had already started pushing before its result arrived at the driver. This validates that the killedByOtherAttempt path always represents a real stale push.

Therefore, only the killedByOtherAttempt path legitimately needs markStalePushedPartition. The info.finished path should skip it entirely.
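
The distinction between the two paths can be modelled in a few lines. This is a deliberately simplified sketch, not Spark's TaskSetManager: `AttemptInfo`, `StaleMarkingModel` and the returned labels are hypothetical, and "another attempt already won" is approximated by a set of committed partitions.

```scala
import scala.collection.mutable

// Simplified, hypothetical model of the driver-side decision.
final case class AttemptInfo(taskId: Long, partitionId: Int, var finished: Boolean = false)

class StaleMarkingModel {
  private val successfulPartitions = mutable.Set.empty[Int]
  val staleMapIndexes: mutable.Set[Int] = mutable.Set.empty[Int]

  /** Handles a FINISHED result and returns a label naming the path that was taken. */
  def handleSuccessfulResult(info: AttemptInfo): String = synchronized {
    if (info.finished) {
      // Same attempt delivered twice: this is the canonical result, so nothing is marked.
      "duplicate-delivery"
    } else if (successfulPartitions.contains(info.partitionId)) {
      // A different attempt already won; this one ran to completion, so mark the partition.
      info.finished = true
      staleMapIndexes += info.partitionId
      "killed-by-other-attempt"
    } else {
      info.finished = true
      successfulPartitions += info.partitionId
      "committed"
    }
  }
}
```

Feeding the model an original attempt, a re-delivery of that same attempt, and then a speculative attempt for the same partition yields "committed", "duplicate-delivery" and "killed-by-other-attempt" in that order, and only the last call adds the partition to the stale set.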

* We record the stale mapIds here so the reduce side can check chunkBitmaps and fallback
* if stale data is present in a merged block.
*/
private[this] val staleMapIds = new java.util.HashSet[Int]()

Contributor

staleMapIds is only ever added to — addMapOutput/updateMapOutput and the output-invalidation paths never clear it; only unregisterShuffle drops it. So a partition stays marked, and its merged blocks keep falling back to unmerged, for the whole life of the shuffle, even after a clean recompute re-registers a valid output. Clear the relevant id when its mapIndex is (re-)registered.

Contributor Author

We've carefully evaluated whether staleMapIndexes should be cleared in addMapOutput, and believe the current design is correct as-is. Here's the reasoning:

Core principle: staleMapIndexes cleanup should be tied to merged block lifecycle, not map output lifecycle.

The stale set protects against reading merged blocks that contain stale chunks. A partition's stale mark can only be safely removed when we know all corresponding merged blocks have been refreshed or discarded — which means all merger locations must be updated (either new shuffleMergeId, or explicit merge result invalidation).

Why not clear it in addMapOutput:

  1. addMapOutput only updates map-side metadata (mapStatuses) — it does not invalidate any existing merged blocks on mergers. The physical merged files on the external shuffle service may still contain stale chunks from duplicate pushes.

  2. Within the same stageAttempt / shuffleMergeId, a late-arriving speculative addMapOutput call is exactly the event whose MapStatus is the stale one — clearing the stale mark here would defeat layers 2/3.

Stage retry (full reset) — confirmed handled:

When unregisterAllMapAndMergeOutput is called (stage retry / rollback / barrier stage abort), it triggers:

  • removeMergeResultsByFilter(x => true): all merge results cleared
  • removeShuffleMergerLocations(): all merger locations cleared
  • newShuffleMergeState(), which does shuffleMergeId += 1: new merge ID
  • incrementEpoch(): workers fetch fresh ShuffleStatus on the next getStatuses call

The worker-side staleMapIndexes for this shuffleId is either cleared by unregisterShuffle or replaced entirely when workers re-fetch with the new epoch and get a fresh serialized snapshot where the driver-side staleMapIndexes set is empty (new ShuffleStatus).

Partial merge cleanup (single FetchFailed) — intentional keep:

For non-barrier stages, unregisterMergeResult only removes a specific <reduceId, bmAddress> merge entry without incrementing shuffleMergeId. Other reduces' merged blocks on the same mergers may still contain stale chunks, so keeping the stale mark and falling back is the safe choice.

Future optimization (not blocking):

If we want to avoid unnecessary fallbacks in the partial-cleanup case, we could track stale at merger-location granularity (e.g., staleMapIndexesByMerger[mergerAddress][partitionId]) and clear entries when that specific merger's merge results are all invalidated. But this adds complexity for marginal gain — the fallback path is correct and only costs performance, not correctness.

In summary: the current "only-add, clear-on-shuffle-end" semantics are intentional and match the merged block lifecycle. We're happy to add location-level tracking as a follow-up optimization if desired.

@cloud-fan cloud-fan left a comment

Contributor

9 addressed, 0 remaining, 5 new. (5 = 1 newly introduced, 4 late catches.)

1 blocking, 2 non-blocking, 2 nits.
Strong iteration — the substantive prior feedback is well addressed; the one blocking item is a stale test that contradicts the agreed design and breaks CI.

Design / architecture (1)

  • TaskContext.scala:208: addPostStatusUpdateListener is public @Experimental but its param trait is private[spark] — uncallable externally; make the method private[spark] — see inline

Correctness (1)

  • TaskSetManagerSuite.scala:2999: the "finished task" test asserts the info.finished path marks stale, but production skips it (the agreed behavior) -> fails CI; remove or invert — see inline

Suggestions (1)

  • PushBasedFetchHelper.scala:170: readChunkBitmaps() is called twice on both meta paths; reuse the already-read bitmaps — see inline

Nits: 2 items — @param partitionId should be @param mapIndex (see inline); and several new tests use the SPARK-33235 prefix instead of SPARK-57491 (TaskContextSuite, the finished-task TaskSetManagerSuite test, ShuffleBlockFetcherIteratorSuite).

Verification

Traced the stale-marking control flow: detectStalePushIfShuffleTask has exactly one production call site (TaskSetManager.scala:828, the killedByOtherAttempt branch); the info.finished early-return (816-822) does not call it. So the finished-task test's assertion that info.finished marks the partition stale cannot hold — confirmed by the failing CI annotation on head e7319fa (Set() did not contain 0).

PR description suggestions

  • Document: why the existing checksum-mismatch / rollback machinery does not cover this case (the loser attempt's MapStatus is never registered, so no checksum comparison happens for it).

s"Expected staleMapIndexes to contain mapIndex 0, got $staleMapIndexes")
}

test("SPARK-33235: late-arriving result for finished task marks stale partitionId") {

Contributor

This test drives the info.finished path (a duplicate result for the same tid) and asserts the partition is marked stale. But production handleSuccessfulTask returns early on info.finished without calling detectStalePushIfShuffleTask — the sole call site is the killedByOtherAttempt branch (line 828), which is exactly the design you agreed to (winner re-delivery must not mark stale). So this test asserts behavior the code intentionally does not implement, and it fails CI:

Set() did not contain 0  Expected staleMapIndexes to contain mapIndex 0 on finished-task path, got Set()

It looks left over from before the info.finished call was removed. Either delete it, or invert it to assert the finished-task path does not mark stale — which documents the intentional decision.

Contributor Author

Thanks for catching this — you're absolutely right. This test is a leftover from before the info.finished early-return was finalized. The production code in handleSuccessfulTask (line 816-822) returns immediately on info.finished without calling detectStalePushIfShuffleTask, and the sole call site for detectStalePushIfShuffleTask is the killedByOtherAttempt branch at line 828 — exactly as designed, since a re-delivery of an already-committed winner result must not mark anything stale.

I've removed this test entirely. The speculative duplicate path (killedByOtherAttempt) is already covered by the SPARK-57491: speculative duplicate result marks stale partitionId test in the same suite.

* The callback runs on the same executor thread that sends the status update.
*/
@Experimental
def addPostStatusUpdateListener(listener: PostStatusUpdateListener): TaskContext

Contributor

This method is public @Experimental, but its parameter trait PostStatusUpdateListener is private[spark]. External code can't name the type, so the method is uncallable outside Spark while still leaking an internal trait into the public API signature. Since the mechanism is entirely internal to push-based shuffle, make the method private[spark] to match the trait.

Contributor Author

Agreed — changed addPostStatusUpdateListener from public @Experimental to private[spark] to match the visibility of PostStatusUpdateListener. The mechanism is entirely internal to push-based shuffle.

try {
iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, shuffleMergeId,
reduceId, sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(), address))
val chunkBitmaps = meta.readChunkBitmaps().toIndexedSeq

Contributor

meta.readChunkBitmaps() is called twice here — once to build chunkBitmaps for checkStaleMapIdInMergedBlock, then again when constructing PushMergedRemoteMetaFetchResult. Each call re-deserializes the RoaringBitmaps, on the common (non-stale) path. Since bitmaps is Array[RoaringBitmap] and the check accepts Seq[RoaringBitmap], reuse the already-read value instead of re-reading. Same pattern on the local meta path (line 282).

Contributor Author

Good catch — reused the already-read chunkBitmaps variable instead of calling meta.readChunkBitmaps() / chunksMeta.readChunkBitmaps() a second time when constructing the fetch result. This was a partial fix in the previous commit; now fully cleaned up for both remote and local meta paths.
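
The read-once pattern discussed in this thread looks roughly like the sketch below. All names here (`CountingMeta`, `MetaFetchResult`, `BitmapReuseSketch.buildResult`) are hypothetical, `BitSet` stands in for RoaringBitmap, and the read counter exists only to make the single deserialization observable.

```scala
import scala.collection.immutable.BitSet

// Hypothetical stand-in for the merged-block meta; it counts how often bitmaps are decoded.
class CountingMeta(chunks: Seq[BitSet]) {
  var reads: Int = 0
  def readChunkBitmaps(): Array[BitSet] = { reads += 1; chunks.toArray }
}

final case class MetaFetchResult(reduceId: Int, bitmaps: Array[BitSet])

object BitmapReuseSketch {
  /** Decodes the bitmaps once, runs the stale check on them, then reuses them in the result. */
  def buildResult(
      meta: CountingMeta,
      reduceId: Int,
      stale: Set[Int]): Either[String, MetaFetchResult] = {
    val chunkBitmaps = meta.readChunkBitmaps()
    val hasStale = stale.nonEmpty && chunkBitmaps.exists(b => stale.exists(b.contains))
    if (hasStale) Left("fallback") else Right(MetaFetchResult(reduceId, chunkBitmaps))
  }
}
```

On the common non-stale path the meta is decoded exactly once, and the same array feeds both the check and the fetch result.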

gaoyajun02 and others added 2 commits June 24, 2026 14:42
@gaoyajun02

Contributor Author

All review feedback addressed — visibility fix, stale test removal, readChunkBitmaps dedup, javadoc nit, test renaming, and PR description update included. Thanks for the thorough review!
PTAL @cloud-fan

@cloud-fan cloud-fan left a comment

Contributor

5 addressed, 0 remaining, 3 new. (3 = 1 newly introduced, 2 late catches.)

1 blocking, 1 non-blocking, 1 nit.
Strong iteration — all 5 prior findings are addressed; the one blocking item is a mechanical import-order break that fails scalastyle.

Correctness (2)

  • PushBasedFetchHelper.scala:20: import block reordered to spark/3rdParty/java/scala; scalastyle ImportOrderChecker (error level) will fail the lint CI gate — see inline
  • MapOutputTracker.scala:130: getStaleMapIndexes returns a live .asScala view of a non-thread-safe HashSet with no read lock, racing markStalePushedPartition — see inline (non-blocking)

Nits: 1 minor item (see inline comments).

Verification

Traced layer 1's skip-on-kill guard: invokePostStatusUpdateListeners runs only on the success path, after statusUpdate(FINISHED) (Executor.scala:1030), and setTaskFinishedAndClearInterruptStatus (Executor.scala:738) clears only the thread interrupt and TaskRunner.finished — not TaskContextImpl.reasonIfKilled — so the listener's !isInterrupted() && !isFailed() check stays accurate and a killed-but-completed attempt correctly skips its push. Layer 2's marking has its sole call site in the killedByOtherAttempt branch (TaskSetManager.scala:828); the info.finished winner-re-delivery path does not mark, as agreed. The serialize/deserialize stale-set round-trip is format-consistent.
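
The layer 1 guard traced above can be illustrated with a self-contained model. It is only a sketch: `FakeTaskContext` and `DeferredPushSketch.deferPush` are hypothetical stand-ins and do not reproduce Spark's TaskContext or ShuffleWriteProcessor; the point is only that the push callback runs after the status update and re-checks the kill and failure flags at that moment.

```scala
// Hypothetical, minimal stand-in for the task context and its post-status-update hook.
final class FakeTaskContext(val taskAttemptId: Long) {
  @volatile private var killReason: Option[String] = None
  @volatile private var failed: Boolean = false
  private var listeners = List.empty[FakeTaskContext => Unit]

  def markInterrupted(reason: String): Unit = { killReason = Some(reason) }
  def markFailed(): Unit = { failed = true }
  def isInterrupted: Boolean = killReason.isDefined
  def isFailed: Boolean = failed

  def addPostStatusUpdateListener(listener: FakeTaskContext => Unit): Unit = {
    listeners = listener :: listeners
  }

  /** Meant to be called after the FINISHED status update has been sent. */
  def invokePostStatusUpdateListeners(): Unit =
    listeners.reverse.foreach(listener => listener(this))
}

object DeferredPushSketch {
  /** Registers a push that only runs if the attempt was neither killed nor failed. */
  def deferPush(context: FakeTaskContext, push: () => Unit): Unit =
    context.addPostStatusUpdateListener { ctx =>
      if (!ctx.isInterrupted && !ctx.isFailed) push()
      else println(s"Skipping push for attempt ${ctx.taskAttemptId}: killed or failed")
    }
}
```

An attempt that finishes untouched pushes when its listeners are invoked; one that is marked interrupted before the listeners run logs the skip and pushes nothing.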

@gaoyajun02

Contributor Author

PTAL @cloud-fan

@cloud-fan cloud-fan left a comment

Contributor

3 addressed, 0 remaining, 2 new. (2 = 0 newly introduced, 2 late catches.)

0 blocking, 1 non-blocking, 1 nit.
The substantive prior feedback is fully addressed; the delta this round (read-lock fix, import order, mapIndex doc) is clean. Two late catches remain — both pre-date my earlier reviews.

Design / architecture (1)

  • (non-blocking) MapOutputTracker.scala:123: markStalePushedPartition doesn't incrementEpoch, and the worker caches the stale set per-shuffle — a reducer that fetched statuses before a late attempt's mark won't see it (no cache refresh), so layer-3 fallback can be missed in the layer-1-failed race — see inline

Nits: 1 minor item (see inline comments).

Verification

Confirmed all 3 prior-round items are fixed: getStaleMapIndexes now takes withReadLock and returns a defensive copy (130-132); the PushBasedFetchHelper imports are restored to java/scala/3rdParty/spark order; the serialize-format diagram reads mapIndex. Traced the visibility gap behind the design finding: incrementEpoch is called only on output-mutation paths (none on the markStalePushedPartition path), and the worker refreshes its cached stale set only on cache-miss / epoch-bump / fetch-failure — so a stale mark added after a reducer's first status fetch is not observed by that reducer.
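
The read-lock plus defensive-copy fix confirmed here follows a common pattern, sketched below. `StaleSetHolder` is a hypothetical reduction of the driver-side state to the locking logic only; the helper names mirror those mentioned in the thread, but the bodies are assumptions rather than the PR's code.

```scala
import java.util.concurrent.locks.ReentrantReadWriteLock

// Hypothetical reduction of the driver-side stale set to the locking pattern only.
class StaleSetHolder {
  private val lock = new ReentrantReadWriteLock()
  private val staleMapIndexes = new java.util.HashSet[Int]()

  private def withReadLock[A](body: => A): A = {
    lock.readLock().lock()
    try body finally lock.readLock().unlock()
  }

  private def withWriteLock[A](body: => A): A = {
    lock.writeLock().lock()
    try body finally lock.writeLock().unlock()
  }

  def markStalePushedPartition(mapIndex: Int): Unit = withWriteLock {
    staleMapIndexes.add(mapIndex)
    ()
  }

  /** Returns an immutable snapshot, so callers never iterate the live HashSet. */
  def getStaleMapIndexes: Set[Int] = withReadLock {
    val builder = Set.newBuilder[Int]
    val it = staleMapIndexes.iterator()
    while (it.hasNext) builder += it.next()
    builder.result()
  }
}
```

Because the getter copies under the read lock, a snapshot taken before a later mark stays unchanged, and a concurrent writer cannot invalidate an iteration in progress.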

* @param mapIndex the partition index (== mapIndex) of the stale (redundant) attempt;
* this is NOT MapStatus.mapId
*/
def markStalePushedPartition(mapIndex: Int): Unit = withWriteLock {

Contributor

Non-blocking design question. markStalePushedPartition records the stale partition but does not incrementEpoch(), and the worker (MapOutputTrackerWorker.staleMapIndexes) caches the stale set per-shuffle, refreshing it only on a cache miss, an epoch bump, or a fetch failure. So a reducer that fetched its merge statuses before a late attempt's mark keeps its already-cached (empty) stale set and never sees the mark — meaning layer-3 fallback does not fire for that reducer in exactly the layer-1-failed race this PR targets (original attempt pushed, then killed by a speculative winner that also pushed).

The window is narrow and layer 1 (deferred push) is the primary defense, so this is a question rather than a blocker: is there an ordering guarantee that the mark always precedes any reducer's first status fetch for the shuffle? If not, should the mark advance the epoch (the checksum-mismatch path effectively forces a refresh via stage recompute)? A sentence in the code or PR description either way would help. Distinct from the "never cleared" thread above, which is about removing marks, not propagating new ones.
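
The visibility gap can be reproduced in a toy model. This is only a sketch of the concern, not of Spark's MapOutputTracker: `DriverModel`, `WorkerModel` and the `bumpEpoch` flag are hypothetical, and the worker's refresh rule is reduced to "re-fetch only when a newer epoch is observed".

```scala
// Hypothetical model of the driver/worker interaction; names do not mirror Spark's classes.
class DriverModel {
  private var epoch = 0L
  private var stale = Set.empty[Int]

  def currentEpoch: Long = synchronized(epoch)
  def snapshot: Set[Int] = synchronized(stale)

  def markStale(mapIndex: Int, bumpEpoch: Boolean): Unit = synchronized {
    stale += mapIndex
    if (bumpEpoch) epoch += 1
  }
}

class WorkerModel(driver: DriverModel) {
  private var cachedEpoch = -1L
  private var cachedStale = Set.empty[Int]

  /** Re-fetches from the driver only when the observed epoch is newer than the cached one. */
  def staleMapIndexes(observedEpoch: Long): Set[Int] = synchronized {
    if (observedEpoch > cachedEpoch) {
      cachedStale = driver.snapshot
      cachedEpoch = observedEpoch
    }
    cachedStale
  }
}
```

In this model, a worker that has already cached an empty set keeps returning it after a mark made without an epoch bump, and only picks up the marks once a later mark advances the epoch.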

gaoyajun02 and others added 2 commits July 1, 2026 18:57
…per.scala

Co-authored-by: Wenchen Fan <cloud0fan@gmail.com>