[Bugfix] Fix SP path overriding NPU-patched chunk_gated_delta_rule #208
Open
ys2025-AI wants to merge 3 commits into
Code Review
This pull request introduces NPU-specific patches and optimizations for sequence parallel training. Specifically, it adds NPU detection and model kernelization in the FSDP dense cookbook, marks NPU-patched modules with a _twinkle_npu_patched flag during monkey patching, and updates the linear attention sequence parallel strategy to respect this flag. When running on Ascend NPU, it falls back to the PyTorch implementation for causal_conv1d to prevent UB overflow in FLA Triton backward kernels while preserving the MindSpeed Triton chunk_gated_delta_rule. There are no review comments, and we have no additional feedback to provide.
PR type
PR information
Problem
When sequence parallel (SP) is enabled on Ascend NPU, two issues occur:
1. NPU patch is silently overridden: monkey_patch_npu.py replaces chunk_gated_delta_rule with the MindSpeed Triton implementation. However, _ensure_linear_attention_kernels() in linear_attention_sp.py unconditionally reassigns it to FLA_CHUNK_GATED_DELTA_RULE on every forward pass, which silently defeats the NPU acceleration for the GDN compute core.
2. UB overflow in causal_conv1d backward: causal_conv1d_fn is also set to FLA_CAUSAL_CONV1D_FN (Triton). When SP reshapes tensors (local_H = H // sp_size, global_S = S * sp_size), the autotune configs of FLA's causal_conv1d_bwd_kernel exceed the Ascend 910 Unified Buffer (UB) limit of 192KB, causing a UB overflow error in the backward pass.
Root Cause
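The override can be illustrated with a minimal, pure-Python sketch. Every name below (the stand-in kernels, LinearAttention, apply_npu_patch, ensure_kernels_unconditional, run_forward) is hypothetical and only mimics the shape of the problem; it is not the actual code from monkey_patch_npu.py or linear_attention_sp.py.

```python
# Hypothetical stand-ins: these are NOT the real FLA or MindSpeed kernels.
def fla_chunk_gated_delta_rule(x):
    return ("fla", x)


def mindspeed_chunk_gated_delta_rule(x):
    return ("mindspeed", x)


class LinearAttention:
    """Toy module that holds a reference to its compute kernel."""

    def __init__(self):
        self.chunk_gated_delta_rule = fla_chunk_gated_delta_rule


def apply_npu_patch(module):
    # Assumed behaviour of the NPU monkey patch: swap in the MindSpeed kernel.
    module.chunk_gated_delta_rule = mindspeed_chunk_gated_delta_rule


def ensure_kernels_unconditional(module):
    # Assumed pre-fix behaviour: always reassign the FLA kernel.
    module.chunk_gated_delta_rule = fla_chunk_gated_delta_rule


def run_forward(module, x):
    # The kernel check runs at the start of every forward pass.
    ensure_kernels_unconditional(module)
    return module.chunk_gated_delta_rule(x)


module = LinearAttention()
apply_npu_patch(module)
backend_before = module.chunk_gated_delta_rule(1)[0]  # patch is active here
backend_during = run_forward(module, 1)[0]            # patch is lost here
```

In this toy setup the patch is visible right after patching but is replaced as soon as the first forward pass runs, which matches the symptom described in this PR.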
_ensure_linear_attention_kernels() is called at the start of every _run_forward, so the NPU patch is effectively disabled in SP mode.
Fix
1. Detect NPU-patched modules via the _twinkle_npu_patched attribute (set by monkey_patch_npu.py).
2. On NPU, use the PyTorch fallback (F.conv1d) for causal_conv1d_fn to avoid UB overflow; do not touch chunk_gated_delta_rule so the MindSpeed Triton patch remains effective.
Changes
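As a rough illustration of the guard described under Fix, the sketch below checks the _twinkle_npu_patched flag before touching any kernel. Only the flag name comes from this PR; the stand-in kernels, the LinearAttention class, and the ensure_kernels helper are hypothetical, and the real _ensure_linear_attention_kernels may be structured differently.

```python
# Hypothetical stand-ins: these are NOT the real FLA, MindSpeed, or PyTorch kernels.
def fla_chunk_gated_delta_rule(x):
    return ("fla", x)


def mindspeed_chunk_gated_delta_rule(x):
    return ("mindspeed", x)


def fla_causal_conv1d_fn(x):
    return ("fla", x)


def torch_causal_conv1d_fn(x):
    # Stands in for a module-level PyTorch (F.conv1d based) fallback.
    return ("torch", x)


class LinearAttention:
    """Toy module holding references to its two kernels."""

    def __init__(self):
        self.chunk_gated_delta_rule = fla_chunk_gated_delta_rule
        self.causal_conv1d_fn = fla_causal_conv1d_fn


def apply_npu_patch(module):
    # Assumed patch behaviour: swap the kernel and mark the instance.
    module.chunk_gated_delta_rule = mindspeed_chunk_gated_delta_rule
    module._twinkle_npu_patched = True


def ensure_kernels(module):
    if getattr(module, "_twinkle_npu_patched", False):
        # NPU branch: use the PyTorch conv fallback and leave
        # chunk_gated_delta_rule alone so the patch stays effective.
        module.causal_conv1d_fn = torch_causal_conv1d_fn
        return
    # Non-NPU branch: keep the FLA kernels for both operations.
    module.causal_conv1d_fn = fla_causal_conv1d_fn
    module.chunk_gated_delta_rule = fla_chunk_gated_delta_rule


npu_module = LinearAttention()
apply_npu_patch(npu_module)
ensure_kernels(npu_module)

gpu_module = LinearAttention()
ensure_kernels(gpu_module)
```

Using getattr with a default of False means instances that were never patched fall through to the existing FLA path without needing the attribute to be defined on them.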
monkey_patch_npu.py: set _twinkle_npu_patched = True when patching instances
linear_attention_sp.py: move _torch_causal_conv1d_fn to module level; add NPU branch in _ensure_linear_attention_kernels
Experiment results
Verified on 8× Ascend A3 NPUs with Qwen3.5-27B (DP=4, FSDP=4, Ulysses=4).