fix(moe): avoid cross-warp stale read in ep_scatter prefix sum#1362
fix(moe): avoid cross-warp stale read in ep_scatter prefix sum#1362Anai-Guo wants to merge 2 commits into
Conversation
_fwd_kernel_ep_scatter_1 stores the full exclusive prefix sum to expert_start_loc with a vectorized tl.store, then immediately reads back expert_start_loc[cur_expert] with a scalar tl.load. The vectorized store is split across the program warps, so under CUDA weak memory ordering the scalar read can observe a stale/uninitialized value written by another warp, producing a garbage offset and a cudaErrorIllegalAddress in the following m_indices write. Extract cur_expert_start (and cur_expert_token_num) directly from the in-register cumsum / tokens_per_expert vectors via tl.where + tl.sum, which reduces through shared memory with proper synchronization. The global store to expert_start_loc is kept since _fwd_kernel_ep_scatter_2 consumes it. Fixes ModelTC#1361
There was a problem hiding this comment.
Code Review
This pull request modifies the Triton kernel _fwd_kernel_ep_scatter_1 in deepep_scatter_gather.py to read the prefix-sum offset from registers instead of the global buffer, preventing stale reads across warps. The reviewer points out that while cur_expert_start needs this register-based calculation, cur_expert_token_num is read from a read-only input tensor and can be loaded directly to avoid unnecessary overhead. They also suggest simplifying the tl.zeros_like call to 0.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| expert_mask = offset_cumsum == cur_expert | ||
| cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum))) | ||
| cur_expert_token_num = tl.sum(tl.where(expert_mask, tokens_per_expert, tl.zeros_like(tokens_per_expert))) |
There was a problem hiding this comment.
While cur_expert_start needs to be read from registers to avoid the cross-warp stale read on expert_start_loc, num_recv_tokens_per_expert is a read-only input tensor that is never written to in this kernel. Therefore, there is no risk of a stale read or race condition for cur_expert_token_num.
Using tl.sum with tl.where and tl.zeros_like on tokens_per_expert introduces unnecessary overhead (shared memory allocation, synchronization, and reduction across warps) for a value that can be loaded directly and efficiently via a single scalar global load (which is highly cached).
Additionally, we can simplify tl.zeros_like(cumsum) to 0 in the cur_expert_start calculation to make the code cleaner and avoid creating an extra zero tensor.
| expert_mask = offset_cumsum == cur_expert | |
| cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum))) | |
| cur_expert_token_num = tl.sum(tl.where(expert_mask, tokens_per_expert, tl.zeros_like(tokens_per_expert))) | |
| expert_mask = offset_cumsum == cur_expert | |
| cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, 0)) | |
| cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) |
num_recv_tokens_per_expert is a read-only input and is never written in this kernel, so reading it directly carries no cross-warp stale-read risk. Only the expert_start_loc read-back (written just above) needs the register-based extraction. Reverts the unnecessary reduction for cur_expert_token_num per review feedback.
|
Good catch — |
Summary
Fixes #1361.
_fwd_kernel_ep_scatter_1computes the full exclusive prefix sum over experts in registers, stores the whole array toexpert_start_locwith a vectorizedtl.store, and then immediately reads its own slot back:With
num_experts=256the vectorized store is spread across the program's warps. The subsequent scalar read ofexpert_start_loc[cur_expert]may target a slot written by a different warp, and there is no barrier between the store and the load. Under CUDA's weak memory model the read can observe the stale/uninitialized contents of thetorch.empty-allocated buffer, producing a garbage offset that makes the following unmaskedm_indiceswrite land in unmapped memory:Fix
Read
cur_expert_start(andcur_expert_token_num) directly from the in-registercumsum/tokens_per_expertvectors instead of from the just-written global buffer:Exactly one lane matches
cur_expert, so eachtl.sumreturns that lane's value.tl.sumreduces through shared memory with proper synchronization, so the result is always correct and the racy global round-trip is removed. Thetl.storetoexpert_start_locis kept because_fwd_kernel_ep_scatter_2consumes it downstream.The produced values are identical to the original code on any correct execution — this only removes the data race, so behaviour is unchanged when the race happened not to fire.
🤖 Generated with Claude Code