Skip to content

fix(moe): avoid cross-warp stale read in ep_scatter prefix sum#1362

Open
Anai-Guo wants to merge 2 commits into
ModelTC:mainfrom
Anai-Guo:fix-ep-scatter-cross-warp-stale-read
Open

fix(moe): avoid cross-warp stale read in ep_scatter prefix sum#1362
Anai-Guo wants to merge 2 commits into
ModelTC:mainfrom
Anai-Guo:fix-ep-scatter-cross-warp-stale-read

Conversation

@Anai-Guo

Copy link
Copy Markdown

Summary

Fixes #1361.

_fwd_kernel_ep_scatter_1 computes the full exclusive prefix sum over experts in registers, stores the whole array to expert_start_loc with a vectorized tl.store, and then immediately reads its own slot back:

tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)

cur_expert_start = tl.load(expert_start_loc + cur_expert)        # racy read-back
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

With num_experts=256 the vectorized store is spread across the program's warps. The subsequent scalar read of expert_start_loc[cur_expert] may target a slot written by a different warp, and there is no barrier between the store and the load. Under CUDA's weak memory model the read can observe the stale/uninitialized contents of the torch.empty-allocated buffer, producing a garbage offset that makes the following unmasked m_indices write land in unmapped memory:

RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
  File "deepep_scatter_gather.py", line 25, in _fwd_kernel_ep_scatter_1
    cur_expert_start = tl.load(expert_start_loc + cur_expert)

Fix

Read cur_expert_start (and cur_expert_token_num) directly from the in-register cumsum / tokens_per_expert vectors instead of from the just-written global buffer:

expert_mask = offset_cumsum == cur_expert
cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum)))
cur_expert_token_num = tl.sum(tl.where(expert_mask, tokens_per_expert, tl.zeros_like(tokens_per_expert)))

Exactly one lane matches cur_expert, so each tl.sum returns that lane's value. tl.sum reduces through shared memory with proper synchronization, so the result is always correct and the racy global round-trip is removed. The tl.store to expert_start_loc is kept because _fwd_kernel_ep_scatter_2 consumes it downstream.

The produced values are identical to the original code on any correct execution — this only removes the data race, so behaviour is unchanged when the race happened not to fire.

🤖 Generated with Claude Code

_fwd_kernel_ep_scatter_1 stores the full exclusive prefix sum to
expert_start_loc with a vectorized tl.store, then immediately reads back
expert_start_loc[cur_expert] with a scalar tl.load. The vectorized store is
split across the program warps, so under CUDA weak memory ordering the scalar
read can observe a stale/uninitialized value written by another warp, producing
a garbage offset and a cudaErrorIllegalAddress in the following m_indices write.

Extract cur_expert_start (and cur_expert_token_num) directly from the in-register
cumsum / tokens_per_expert vectors via tl.where + tl.sum, which reduces through
shared memory with proper synchronization. The global store to expert_start_loc
is kept since _fwd_kernel_ep_scatter_2 consumes it.

Fixes ModelTC#1361

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request modifies the Triton kernel _fwd_kernel_ep_scatter_1 in deepep_scatter_gather.py to read the prefix-sum offset from registers instead of the global buffer, preventing stale reads across warps. The reviewer points out that while cur_expert_start needs this register-based calculation, cur_expert_token_num is read from a read-only input tensor and can be loaded directly to avoid unnecessary overhead. They also suggest simplifying the tl.zeros_like call to 0.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +31 to +33
expert_mask = offset_cumsum == cur_expert
cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum)))
cur_expert_token_num = tl.sum(tl.where(expert_mask, tokens_per_expert, tl.zeros_like(tokens_per_expert)))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While cur_expert_start needs to be read from registers to avoid the cross-warp stale read on expert_start_loc, num_recv_tokens_per_expert is a read-only input tensor that is never written to in this kernel. Therefore, there is no risk of a stale read or race condition for cur_expert_token_num.

Using tl.sum with tl.where and tl.zeros_like on tokens_per_expert introduces unnecessary overhead (shared memory allocation, synchronization, and reduction across warps) for a value that can be loaded directly and efficiently via a single scalar global load (which is highly cached).

Additionally, we can simplify tl.zeros_like(cumsum) to 0 in the cur_expert_start calculation to make the code cleaner and avoid creating an extra zero tensor.

Suggested change
expert_mask = offset_cumsum == cur_expert
cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum)))
cur_expert_token_num = tl.sum(tl.where(expert_mask, tokens_per_expert, tl.zeros_like(tokens_per_expert)))
expert_mask = offset_cumsum == cur_expert
cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, 0))
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

num_recv_tokens_per_expert is a read-only input and is never written in
this kernel, so reading it directly carries no cross-warp stale-read
risk. Only the expert_start_loc read-back (written just above) needs the
register-based extraction. Reverts the unnecessary reduction for
cur_expert_token_num per review feedback.
@Anai-Guo

Copy link
Copy Markdown
Author

Good catch — num_recv_tokens_per_expert is a read-only input and is never written inside this kernel, so its load was never at risk of a cross-warp stale read. I've reverted cur_expert_token_num to the original direct load and kept only the register-based extraction for cur_expert_start (which is read back from expert_start_loc immediately after the split-warp store). This narrows the change to exactly the line that needed it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG]

1 participant