diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index a12ac06f..a6fd0bdc 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -8,6 +8,8 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.utils.framework import Torch +from twinkle.kernel import kernelize_model logger = get_logger() MODEL_ID = 'ms://Qwen/Qwen3.5-4B' @@ -68,7 +70,9 @@ def train(): device_mesh=device_mesh, strategy='native_fsdp', ) - + # npu patch + if Torch.is_npu_available(): + model = kernelize_model(model, mode='train', device='npu') lora_config = LoraConfig(target_modules='all-linear') model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1) model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') diff --git a/src/twinkle/kernel/monkey_patch_npu.py b/src/twinkle/kernel/monkey_patch_npu.py index acb2a6aa..4ae8771a 100644 --- a/src/twinkle/kernel/monkey_patch_npu.py +++ b/src/twinkle/kernel/monkey_patch_npu.py @@ -510,6 +510,8 @@ def _is_fla_available() -> bool: continue _module.chunk_gated_delta_rule = mindspeed_fla + # Mark as NPU-patched to prevent it from being overwritten by SP + _module._twinkle_npu_patched = True patched_instances += 1 logger.debug( '[NPU] [FLA] Replaced %s(%s).chunk_gated_delta_rule -> MindSpeed', diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py index ec1f6dab..4a033212 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py @@ -66,12 +66,7 @@ def _apply_conv_activation(x: torch.Tensor, activation) -> torch.Tensor: def _ensure_linear_attention_kernels(mod: torch.nn.Module): - if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None: - mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN - mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE - return False - - from transformers.models.qwen3_5.modeling_qwen3_5 import torch_chunk_gated_delta_rule + """Bind causal_conv1d_fn and chunk_gated_delta_rule for SP forward.""" def _torch_causal_conv1d_fn( *, @@ -110,6 +105,19 @@ def _torch_causal_conv1d_fn( out = _apply_conv_activation(out[:, :, :seq_len], activation) return out.transpose(1, 2).contiguous() + # NPU: keep MindSpeed Triton chunk_gated_delta_rule (patched by + # monkey_patch_npu), use torch fallback for causal_conv1d to avoid + # UB overflow in FLA Triton backward kernels on Ascend NPU. + if getattr(mod, '_twinkle_npu_patched', False): + mod.causal_conv1d_fn = _torch_causal_conv1d_fn + return False + + if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None: + mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN + mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE + return False + + from transformers.models.qwen3_5.modeling_qwen3_5 import torch_chunk_gated_delta_rule mod.causal_conv1d_fn = _torch_causal_conv1d_fn mod.chunk_gated_delta_rule = torch_chunk_gated_delta_rule warnings.warn(_SP_LINEAR_KERNEL_FALLBACK_WARNING, stacklevel=2)