Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cookbook/transformers/sp_fsdp_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor
from twinkle.utils.framework import Torch
from twinkle.kernel import kernelize_model

logger = get_logger()
MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
Expand Down Expand Up @@ -68,7 +70,9 @@ def train():
device_mesh=device_mesh,
strategy='native_fsdp',
)

# NPU patch: when an NPU is available, kernelize the model for NPU training.
if Torch.is_npu_available():
model = kernelize_model(model, mode='train', device='npu')
lora_config = LoraConfig(target_modules='all-linear')
model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1)
model.set_optimizer('AdamW', lr=1e-4, adapter_name='default')
Expand Down
2 changes: 2 additions & 0 deletions src/twinkle/kernel/monkey_patch_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,8 @@ def _is_fla_available() -> bool:
continue

_module.chunk_gated_delta_rule = mindspeed_fla
# Mark this module as NPU-patched so the SP kernel binding does not
# overwrite the chunk_gated_delta_rule replacement installed above.
_module._twinkle_npu_patched = True
patched_instances += 1
logger.debug(
'[NPU] [FLA] Replaced %s(%s).chunk_gated_delta_rule -> MindSpeed',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,7 @@ def _apply_conv_activation(x: torch.Tensor, activation) -> torch.Tensor:


def _ensure_linear_attention_kernels(mod: torch.nn.Module):
if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None:
mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN
mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE
return False

from transformers.models.qwen3_5.modeling_qwen3_5 import torch_chunk_gated_delta_rule
"""Bind causal_conv1d_fn and chunk_gated_delta_rule for SP forward."""

def _torch_causal_conv1d_fn(
*,
Expand Down Expand Up @@ -110,6 +105,19 @@ def _torch_causal_conv1d_fn(
out = _apply_conv_activation(out[:, :, :seq_len], activation)
return out.transpose(1, 2).contiguous()

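# NPU: when monkey_patch_npu has already replaced chunk_gated_delta_rule with
# the MindSpeed Triton kernel, keep that kernel and bind only the torch
# fallback for causal_conv1d. This avoids a UB overflow (UB presumably means
# the Ascend Unified Buffer, not "undefined behavior" -- TODO confirm) in the
# FLA Triton backward kernels on Ascend NPU.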
if getattr(mod, '_twinkle_npu_patched', False):
mod.causal_conv1d_fn = _torch_causal_conv1d_fn
return False

if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None:
mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN
mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE
return False

from transformers.models.qwen3_5.modeling_qwen3_5 import torch_chunk_gated_delta_rule
mod.causal_conv1d_fn = _torch_causal_conv1d_fn
mod.chunk_gated_delta_rule = torch_chunk_gated_delta_rule
warnings.warn(_SP_LINEAR_KERNEL_FALLBACK_WARNING, stacklevel=2)
Expand Down
Loading