[Bugfix]Fix SP path overriding NPU-patched chunk_gated_delta_rule#208

Open
ys2025-AI wants to merge 3 commits into
modelscope:mainfrom
ys2025-AI:main

Conversation

@ys2025-AI
Collaborator

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Problem

When sequence parallel (SP) is enabled on Ascend NPU, two issues occur:

  1. The NPU patch is silently overridden: monkey_patch_npu.py replaces chunk_gated_delta_rule with the MindSpeed Triton implementation, but _ensure_linear_attention_kernels() in linear_attention_sp.py unconditionally resets it to _FLA_CHUNK_GATED_DELTA_RULE on every forward pass. This silently disables NPU acceleration for the GDN compute core.

  2. UB overflow in the causal_conv1d backward pass: causal_conv1d_fn is also set to _FLA_CAUSAL_CONV1D_FN (Triton). When SP reshapes tensors (local_H = H // sp_size, global_S = S * sp_size), the autotune configs of FLA's causal_conv1d_bwd_kernel exceed the Ascend 910 Unified Buffer limit (192 KB, i.e. 1572864 bits), causing:

error: ub overflow, requires 1840384 bits while 1572864 bits available!
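The figures in the error message and the SP reshape can be checked with a few lines of arithmetic. This is only a sanity-check sketch: the bit counts are copied from the message above, while the helper names and the head count, sequence length, and sp_size passed to sp_reshape are hypothetical values chosen for illustration.

```python
# Bit counts copied from the "ub overflow" error message.
REQUIRED_BITS = 1_840_384
AVAILABLE_BITS = 1_572_864


def bits_to_kib(bits: int) -> float:
    """Convert a size in bits to KiB (1 KiB = 1024 bytes)."""
    return bits / 8 / 1024


available_kib = bits_to_kib(AVAILABLE_BITS)  # 192.0, the stated UB limit
required_kib = bits_to_kib(REQUIRED_BITS)
overshoot_kib = required_kib - available_kib


def sp_reshape(num_heads: int, seq_len: int, sp_size: int) -> tuple[int, int]:
    """Return (local_H, global_S) as described above: heads are split
    across SP ranks while each rank sees the full gathered sequence."""
    if num_heads % sp_size != 0:
        raise ValueError("num_heads must be divisible by sp_size")
    return num_heads // sp_size, seq_len * sp_size


# Hypothetical sizes, not taken from the PR.
local_h, global_s = sp_reshape(num_heads=32, seq_len=2048, sp_size=4)

print(f"available: {available_kib} KiB, required: {required_kib} KiB")
print(f"overshoot: {overshoot_kib} KiB")
print(f"local_H={local_h}, global_S={global_s}")
```

The kernel asks for roughly 17% more buffer than is available, so it fails at compile time rather than running slowly.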

Root Cause

# In linear_attention_sp.py _ensure_linear_attention_kernels():
if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None:
    mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN       # UB overflow on NPU
    mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE  # Overrides NPU patch!

This function is called at the start of every _run_forward, so the NPU patch is effectively disabled in SP mode.
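The effect can be reproduced with plain Python stand-ins, without any NPU or Triton dependency. In this sketch the two kernel functions, ensure_kernels_unconditional, and run_forward are hypothetical placeholders; only the assignment pattern mirrors the snippet above.

```python
from types import SimpleNamespace


def fla_chunk_gated_delta_rule(*args, **kwargs):
    """Stand-in for the FLA Triton kernel (hypothetical)."""
    return "fla"


def npu_chunk_gated_delta_rule(*args, **kwargs):
    """Stand-in for the MindSpeed Triton kernel (hypothetical)."""
    return "npu"


def fla_causal_conv1d_fn(*args, **kwargs):
    """Stand-in for the FLA causal conv kernel (hypothetical)."""
    return "fla_conv"


_FLA_CAUSAL_CONV1D_FN = fla_causal_conv1d_fn
_FLA_CHUNK_GATED_DELTA_RULE = fla_chunk_gated_delta_rule


def ensure_kernels_unconditional(mod):
    """Same assignment pattern as the original code: no check for an
    existing patch, so whatever was on the module is overwritten."""
    if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None:
        mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN
        mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE


def run_forward(mod):
    """Each forward pass re-runs the kernel setup before computing."""
    ensure_kernels_unconditional(mod)
    return mod.chunk_gated_delta_rule()


mod = SimpleNamespace(causal_conv1d_fn=None, chunk_gated_delta_rule=None)
mod.chunk_gated_delta_rule = npu_chunk_gated_delta_rule  # the NPU monkey patch

before = mod.chunk_gated_delta_rule()  # patched kernel is in place
after = run_forward(mod)               # first forward pass replaces it

print(before, after)
```

Because the reassignment happens inside the forward path, re-applying the patch once at startup cannot help: the next forward pass would undo it again.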

Fix

  • Detect the NPU patch state via the _twinkle_npu_patched attribute (set by monkey_patch_npu.py)
  • On NPU: use the torch fallback (F.conv1d) for causal_conv1d_fn to avoid the UB overflow, and leave chunk_gated_delta_rule untouched so the MindSpeed Triton patch remains effective
  • Non-NPU behavior is unchanged

Changes

| File | Change |
| --- | --- |
| monkey_patch_npu.py | Set _twinkle_npu_patched = True when patching instances |
| linear_attention_sp.py | Extract _torch_causal_conv1d_fn to module level; add NPU branch in _ensure_linear_attention_kernels |

Experiment results

Verified on 8× Ascend A3 NPUs with Qwen3.5-27B (DP=4, FSDP=4, Ulysses=4).

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces NPU-specific patches and optimizations for sequence parallel training. Specifically, it adds NPU detection and model kernelization in the FSDP dense cookbook, marks NPU-patched modules with a _twinkle_npu_patched flag during monkey patching, and updates the linear attention sequence parallel strategy to respect this flag. When running on Ascend NPU, it falls back to the PyTorch implementation for causal_conv1d to prevent UB overflow in FLA Triton backward kernels while preserving the MindSpeed Triton chunk_gated_delta_rule. There are no review comments, and we have no additional feedback to provide.
