diff --git a/cookbook/megatron/ascend/tp_moe_cp_npu.py b/cookbook/megatron/ascend/tp_moe_cp_npu.py new file mode 100644 index 00000000..a257cf3f --- /dev/null +++ b/cookbook/megatron/ascend/tp_moe_cp_npu.py @@ -0,0 +1,61 @@ +import twinkle +from peft import LoraConfig + +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import MegatronModel +from twinkle.preprocessor import SelfCognitionProcessor + +MODEL_ID = 'ms://Qwen/Qwen3-30B-A3B' +DATASET_ID = 'ms://swift/self-cognition' +DATASET_SLICE = range(128) +BATCH_SIZE = 2 +MAX_STEPS = 10 + +# Keep the original 8-card MoE + CP layout so we can verify the default +# megatron_cp_algo path after repatching TEDotProductAttention back to the +# older MindSpeedCPDotProductAttention. +device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, cp_size=2, ep_size=2, device_type='npu') +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + +logger = get_logger() + + +def build_dataset(): + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=DATASET_SLICE)) + dataset.set_template('Template', model_id=MODEL_ID, max_length=512) + dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.encode() + return dataset + + +def build_model(total_steps: int): + model = MegatronModel(model_id=MODEL_ID) + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + model.add_adapter_to_model('default', lora_config) + model.set_optimizer(optimizer_cls='default', lr=1e-4) + model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=2, lr_decay_steps=total_steps) + return model + + +def train(): + dataset = build_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0) + model = build_model(len(dataloader)) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total 
steps: {len(dataloader)}, validating {MAX_STEPS} steps') + + for step, batch in enumerate(dataloader): + if step >= MAX_STEPS: + break + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + metric = model.calculate_metric(is_training=True) + logger.info(f'[MoE CP NPU smoke] step {step + 1}/{MAX_STEPS}, metric: {metric}') + + +if __name__ == '__main__': + train() diff --git a/cookbook/megatron/ascend/tp_moe_cp_npu.sh b/cookbook/megatron/ascend/tp_moe_cp_npu.sh new file mode 100755 index 00000000..f10bb138 --- /dev/null +++ b/cookbook/megatron/ascend/tp_moe_cp_npu.sh @@ -0,0 +1 @@ +ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 tp_moe_cp_npu.py diff --git a/cookbook/megatron/ascend/tp_moe_npu.py b/cookbook/megatron/ascend/tp_moe_npu.py new file mode 100644 index 00000000..f38f9b18 --- /dev/null +++ b/cookbook/megatron/ascend/tp_moe_npu.py @@ -0,0 +1,60 @@ +import twinkle +from peft import LoraConfig + +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import MegatronModel +from twinkle.preprocessor import SelfCognitionProcessor + +MODEL_ID = 'ms://Qwen/Qwen3-30B-A3B' +DATASET_ID = 'ms://swift/self-cognition' +DATASET_SLICE = range(128) +BATCH_SIZE = 2 +MAX_STEPS = 10 + +# Run the MoE smoke without context parallelism so we can isolate the MoE path +# itself on the same 8-card topology. 
+device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2, cp_size=1, ep_size=2, device_type='npu') +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + +logger = get_logger() + + +def build_dataset(): + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=DATASET_SLICE)) + dataset.set_template('Template', model_id=MODEL_ID, max_length=512) + dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.encode() + return dataset + + +def build_model(total_steps: int): + model = MegatronModel(model_id=MODEL_ID) + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + model.add_adapter_to_model('default', lora_config) + model.set_optimizer(optimizer_cls='default', lr=1e-4) + model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=2, lr_decay_steps=total_steps) + return model + + +def train(): + dataset = build_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0) + model = build_model(len(dataloader)) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}, validating {MAX_STEPS} steps') + + for step, batch in enumerate(dataloader): + if step >= MAX_STEPS: + break + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + metric = model.calculate_metric(is_training=True) + logger.info(f'[MoE NPU smoke] step {step + 1}/{MAX_STEPS}, metric: {metric}') + + +if __name__ == '__main__': + train() diff --git a/cookbook/megatron/ascend/tp_moe_npu.sh b/cookbook/megatron/ascend/tp_moe_npu.sh new file mode 100755 index 00000000..d9519da9 --- /dev/null +++ b/cookbook/megatron/ascend/tp_moe_npu.sh @@ -0,0 +1 @@ +ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 tp_moe_npu.py diff --git a/cookbook/megatron/ascend/tp_npu.py b/cookbook/megatron/ascend/tp_npu.py new file mode 100644 index 00000000..698bee12 --- /dev/null +++ b/cookbook/megatron/ascend/tp_npu.py @@ 
-0,0 +1,61 @@ +import twinkle +from peft import LoraConfig + +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import MegatronModel +from twinkle.preprocessor import SelfCognitionProcessor + +MODEL_ID = 'ms://Qwen/Qwen3-4B' +DATASET_ID = 'ms://swift/self-cognition' +DATASET_SLICE = range(256) +BATCH_SIZE = 8 +MAX_STEPS = 10 + +# Keep the same 8-card TP/PP/DP layout as the GPU reference script, but run it +# through the NPU backend to validate Megatron + MindSpeed integration. +device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2, device_type='npu') +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + +logger = get_logger() + + +def build_dataset(): + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=DATASET_SLICE)) + # Qwen3-4B is a text-only model, so use the base template instead of the VL template. + dataset.set_template('Template', model_id=MODEL_ID, max_length=512) + dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.encode() + return dataset + + +def build_model(total_steps: int): + model = MegatronModel(model_id=MODEL_ID) + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + model.add_adapter_to_model('default', lora_config) + model.set_optimizer(optimizer_cls='default', lr=1e-4) + model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=2, lr_decay_steps=total_steps) + return model + + +def train(): + dataset = build_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0) + model = build_model(len(dataloader)) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}, validating {MAX_STEPS} steps') + + for step, batch in enumerate(dataloader): + if step >= MAX_STEPS: + break + model.forward_backward(inputs=batch) + 
model.clip_grad_and_step() + metric = model.calculate_metric(is_training=True) + logger.info(f'[NPU smoke] step {step + 1}/{MAX_STEPS}, metric: {metric}') + + +if __name__ == '__main__': + train() diff --git a/cookbook/megatron/ascend/tp_npu.sh b/cookbook/megatron/ascend/tp_npu.sh new file mode 100755 index 00000000..99c6848c --- /dev/null +++ b/cookbook/megatron/ascend/tp_npu.sh @@ -0,0 +1 @@ +ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 tp_npu.py diff --git a/docs/source_en/Usage Guide/NPU-Support.md b/docs/source_en/Usage Guide/NPU-Support.md index e2b5e6da..776d4798 100644 --- a/docs/source_en/Usage Guide/NPU-Support.md +++ b/docs/source_en/Usage Guide/NPU-Support.md @@ -10,7 +10,7 @@ Before getting started, please ensure your system meets the following requiremen |------------------------------|----------------------------|--------------------------------------| | Python | >= 3.11, < 3.13 | Twinkle framework requirement | | Ascend Firmware Driver (HDK) | Latest version recommended | Hardware driver and firmware | -| CANN Toolkit | 8.3.RC1 or higher | Heterogeneous Computing Architecture | +| CANN Toolkit | 8.5.1 or higher | Heterogeneous Computing Architecture | | PyTorch | 2.7.1 | Deep learning framework | | torch_npu | 2.7.1 | Ascend PyTorch adapter plugin | @@ -44,7 +44,7 @@ This documentation includes: - Python: 3.11 - PyTorch: 2.7.1 - torch_npu: 2.7.1 -- CANN: 8.3.RC1 or higher +- CANN: 8.5.1 or higher ### 2. 
Install Twinkle @@ -64,16 +64,16 @@ If you need to use vLLMSampler for efficient inference, you can install vLLM and ```bash # Step 1: Install vLLM -pip install vllm==0.11.0 +pip install vllm==0.14.0 # Step 2: Install vLLM-Ascend -pip install vllm-ascend==0.11.0rc3 +pip install vllm-ascend==0.14.0rc1 ``` **Notes**: - Install in the above order, ignoring possible dependency conflict warnings - Ensure CANN environment is activated before installation: `source /usr/local/Ascend/ascend-toolkit/set_env.sh` -- Recommended versions are vLLM 0.11.0 and vLLM-Ascend 0.11.0rc3 +- Recommended versions are vLLM 0.14.0 and vLLM-Ascend 0.14.0rc1 ### 4. Verify Installation @@ -109,51 +109,67 @@ If the output shows `NPU available: True` and no errors, installation is success **Note**: Twinkle does not currently provide NPU Docker images. Manual installation is recommended. For containerized deployment, please refer to official images from the Ascend community. -## Quick Start +### 5. Install Megatron Backend Dependencies -**Important Notice**: The following examples are from the `cookbook/` directory and have been verified in actual NPU environments. It is recommended to run scripts directly from the cookbook rather than copying and pasting code snippets. +**Recommended versions**: +- Megatron-LM: `v0.15.3` +- MindSpeed: `core_r0.15.3` +- mcore-bridge: main branch or the version already validated in your Twinkle checkout -### SFT LoRA Fine-tuning +**Installation steps**: -Verified 4-card DP+FSDP training example: +```bash +# 1. Clone Megatron-LM and pin the compatible version +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout v0.15.3 +cd .. + +# 2. Clone and install MindSpeed +git clone https://gitcode.com/Ascend/MindSpeed.git +cd MindSpeed +git checkout core_r0.15.3 +pip install -e . +cd .. + +# 3. Clone and install mcore-bridge +git clone https://github.com/modelscope/mcore-bridge.git +cd mcore-bridge +pip install -e . +cd .. + +# 4. 
Install Twinkle if needed +cd twinkle +pip install -e ".[transformers,ray]" +``` -**Example Path**: [cookbook/sft/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/sft/lora_npu.py) +**Runtime environment variables**: -**Run Method**: ```bash -# Specify using 4 NPU cards -export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 - -# Run training -python cookbook/sft/lora_npu.py +export PYTHONPATH=$PYTHONPATH:/path/to/Megatron-LM +export MEGATRON_LM_PATH=/path/to/Megatron-LM +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ``` -**Example Features**: -- ✅ Ray distributed mode -- ✅ DP + FSDP hybrid parallelism (2x2) -- ✅ LoRA fine-tuning -- ✅ Complete data loading and training loop +**Verification**: -### GRPO Reinforcement Learning Training +First run a minimal import check to make sure the current environment can resolve MindSpeed and Megatron-LM: -Verified multi-card GRPO training example: +```bash +python -c "import mindspeed.megatron_adaptor; from twinkle.model.megatron._mindspeed_runtime import ensure_mindspeed_adaptor_patched; ensure_mindspeed_adaptor_patched(); print('✓ Megatron backend imports are ready')" +``` -**Example Path**: [cookbook/grpo/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/grpo/lora_npu.py) +## Quick Start -**Run Method**: -```bash -# Specify using 8 NPU cards -export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +**Important Notice**: The following examples are from the `cookbook/` directory and have been verified in actual NPU environments. It is recommended to run scripts directly from the cookbook rather than copying and pasting code snippets. -# Run training -python cookbook/grpo/lora_npu.py -``` +### SFT LoRA Fine-tuning + +The NPU document no longer provides this kind of SFT cookbook example; this capability should be described together with an actually available cookbook example or a future NPU script. 
-**Example Features**: -- ✅ Actor-Critic architecture -- ✅ Supports Reference Model -- ✅ Optional TorchSampler or vLLMSampler -- ✅ Complete RL training workflow +### GRPO Reinforcement Learning Training + +The NPU document no longer provides this kind of GRPO cookbook example; this capability should be described together with an actually available cookbook example or a future NPU script. ### More Examples @@ -165,12 +181,12 @@ Twinkle currently supports the following **verified** parallelization strategies | Parallel Type | Description | NPU Support | Verification Status | |---------|------|---------|---------| -| DP (Data Parallel) | Data parallelism | ✅ | Verified (see cookbook/sft/lora_npu.py) | -| FSDP (Fully Sharded Data Parallel) | Fully sharded data parallelism | ✅ | Verified (see cookbook/sft/lora_npu.py) | -| TP (Tensor Parallel) | Tensor parallelism (Megatron) | 🚧 | To be verified | -| PP (Pipeline Parallel) | Pipeline parallelism (Megatron) | 🚧 | To be verified | -| CP (Context Parallel) | Context parallelism | 🚧 | To be verified | -| EP (Expert Parallel) | Expert parallelism (MoE) | 🚧 | To be verified | +| DP (Data Parallel) | Data parallelism | ✅ | No corresponding cookbook example | +| FSDP (Fully Sharded Data Parallel) | Fully sharded data parallelism | ✅ | No corresponding cookbook example | +| TP (Tensor Parallel) | Tensor parallelism (Megatron) | ✅ | Verified (see `cookbook/megatron/ascend/tp_npu.py`) | +| PP (Pipeline Parallel) | Pipeline parallelism (Megatron) | ✅ | Verified (see `cookbook/megatron/ascend/tp_npu.py`) | +| CP (Context Parallel) | Context parallelism | ✅ | Verified (see `cookbook/megatron/ascend/tp_moe_cp_npu.py`) | +| EP (Expert Parallel) | Expert parallelism (MoE) | ✅ | Verified (see `cookbook/megatron/ascend/tp_moe_npu.py`) | **Legend**: - ✅ Verified: Has actual running example code @@ -179,21 +195,9 @@ Twinkle currently supports the following **verified** parallelization strategies ### DP + FSDP Example -The following example 
is from `cookbook/sft/lora_npu.py`, verified in actual NPU environment: +The NPU document currently does not provide a corresponding cookbook code snippet. -```python -import numpy as np -from twinkle import DeviceMesh - -# 4 cards: DP=2, FSDP=2 -device_mesh = DeviceMesh( - device_type='npu', - mesh=np.array([[0, 1], [2, 3]]), - mesh_dim_names=('dp', 'fsdp') -) -``` - -**Note**: Megatron backend (TP/PP/EP) support on NPU is under development, with no available examples yet. If you need these advanced parallelization strategies, please verify in GPU environment first or follow project updates. +**Megatron backend note**: Twinkle now provides runnable NPU smoke scripts for the Megatron backend. Please follow the installation section above before running the cookbook examples, and start with `cookbook/megatron/ascend/tp_npu.py` before moving on to `cookbook/megatron/ascend/tp_moe_npu.py` and `cookbook/megatron/ascend/tp_moe_cp_npu.py`. ## Common Issues @@ -229,14 +233,14 @@ Feature support matrix based on actual code verification: | Feature | GPU | NPU | Verification Example | Description | |------|-----|-----|---------|------| -| SFT + LoRA | ✅ | ✅ | cookbook/sft/lora_npu.py | Verified available | -| GRPO | ✅ | ✅ | cookbook/grpo/lora_npu.py | Verified available | -| DP Parallelism | ✅ | ✅ | cookbook/sft/lora_npu.py | Verified available | -| FSDP Parallelism | ✅ | ✅ | cookbook/sft/lora_npu.py | Verified available | -| Ray Distributed | ✅ | ✅ | cookbook/sft/lora_npu.py | Verified available | -| TorchSampler | ✅ | ✅ | cookbook/grpo/lora_npu.py | Verified available | -| vLLMSampler | ✅ | ✅ | cookbook/grpo/lora_npu.py | Verified available | -| Full Fine-tuning | ✅ | 🚧 | - | Theoretically supported, to be verified | +| SFT + LoRA | ✅ | ✅ | - | No corresponding cookbook example | +| GRPO | ✅ | ✅ | - | No corresponding cookbook example | +| DP Parallelism | ✅ | ✅ | - | No corresponding cookbook example | +| FSDP Parallelism | ✅ | ✅ | - | No corresponding cookbook example | 
+| Ray Distributed | ✅ | ✅ | - | No corresponding cookbook example | +| TorchSampler | ✅ | ✅ | - | No corresponding cookbook example | +| vLLMSampler | ✅ | ✅ | - | No corresponding cookbook example | +| Full Fine-tuning | ✅ | ✅ | - | Verified available | | QLoRA | ✅ | ❌ | - | Quantization operators not yet supported | | DPO | ✅ | 🚧 | - | Theoretically supported, to be verified | -| Megatron TP/PP | ✅ | 🚧 | - | To be adapted and verified | +| Megatron TP/PP | ✅ | ✅ | cookbook/megatron/ascend/tp_npu.py | Verified available | @@ -255,19 +259,7 @@ Feature support matrix based on actual code verification: ## Example Code -Twinkle provides the following verified NPU training examples: - -### SFT Training -- **4-card DP+FSDP LoRA Fine-tuning**: [cookbook/sft/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/sft/lora_npu.py) - - Uses Ray mode for distributed training - - Demonstrates DP + FSDP hybrid parallelism - - Includes complete data loading and training loop - -### GRPO Training -- **Multi-card GRPO RL Training**: [cookbook/grpo/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/grpo/lora_npu.py) - - Actor-Critic architecture - - Supports Reference Model - - Optional TorchSampler or vLLMSampler +Twinkle's verified NPU examples currently focus on the Megatron smoke path; the SFT and GRPO cookbook examples do not have corresponding files yet. ### Remote Training (Tinker Protocol) - **Server Configuration**: [cookbook/remote/tinker/ascend/](https://github.com/modelscope/twinkle/tree/main/cookbook/remote/tinker/ascend) - Provides HTTP API interface - Supports remote training and inference - Suitable for production environment deployment **Running Examples**: -```bash -# SFT training -export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 -python cookbook/sft/lora_npu.py - -# GRPO training -export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -python cookbook/grpo/lora_npu.py -``` +No corresponding command examples are provided yet. 
## Reference Resources diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" index 3241dbf5..39f6fe18 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" @@ -10,7 +10,7 @@ |------|---------|------| | Python | >= 3.11, < 3.13 | Twinkle 框架要求 | | 昇腾固件驱动(HDK) | 推荐最新版本 | 硬件驱动和固件 | -| CANN 工具包 | 8.3.RC1 或更高 | 异构计算架构 | +| CANN 工具包 | 8.5.1 或更高 | 异构计算架构 | | PyTorch | 2.7.1 | 深度学习框架 | | torch_npu | 2.7.1 | 昇腾 PyTorch 适配插件 | @@ -44,7 +44,7 @@ NPU 环境的安装包括昇腾驱动、CANN 工具包、PyTorch 和 torch_npu - Python: 3.11 - PyTorch: 2.7.1 - torch_npu: 2.7.1 -- CANN: 8.3.RC1 或更高 +- CANN: 8.5.1 或更高 ### 2. 安装 Twinkle @@ -64,16 +64,16 @@ pip install -e ".[transformers,ray]" ```bash # 第一步:安装 vLLM -pip install vllm==0.11.0 +pip install vllm==0.14.0 # 第二步:安装 vLLM-Ascend -pip install vllm-ascend==0.11.0rc3 +pip install vllm-ascend==0.14.0rc1 ``` **注意事项**: - 按照上述顺序安装,忽略可能的依赖冲突提示 - 安装前确保已激活 CANN 环境:`source /usr/local/Ascend/ascend-toolkit/set_env.sh` -- 推荐使用的版本为 vLLM 0.11.0 和 vLLM-Ascend 0.11.0rc3 +- 推荐使用的版本为 vLLM 0.14.0 和 vLLM-Ascend 0.14.0rc1 ### 4. 验证安装 @@ -109,51 +109,67 @@ python verify_npu.py **注意**:目前 Twinkle 暂未提供 NPU 的 Docker 镜像,建议使用手动安装方式。如需容器化部署,请参考昇腾社区的官方镜像。 -## 快速开始 +### 5. 安装 Megatron 后端依赖 -**重要提示**:以下示例均来自 `cookbook/` 目录,已在实际 NPU 环境中验证通过。建议直接运行 cookbook 中的脚本,而不是复制粘贴代码片段。 +**推荐组合**: +- Megatron-LM: `v0.15.3` +- MindSpeed: `core_r0.15.3` +- mcore-bridge: 主分支或当前 Twinkle 验证过的版本 -### SFT LoRA 微调 +**安装步骤**: -已验证的 4 卡 DP+FSDP 训练示例: +```bash +# 1. 获取 Megatron-LM,并切到 Twinkle 兼容版本 +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout v0.15.3 +cd .. + +# 2. 
获取并安装 MindSpeed +git clone https://gitcode.com/Ascend/MindSpeed.git +cd MindSpeed +git checkout core_r0.15.3 +pip install -e . +cd .. + +# 3. 获取并安装 mcore-bridge +git clone https://github.com/modelscope/mcore-bridge.git +cd mcore-bridge +pip install -e . +cd .. + +# 4. 安装 Twinkle(如果还没有安装) +cd twinkle +pip install -e ".[transformers,ray]" +``` -**示例路径**:[cookbook/sft/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/sft/lora_npu.py) +**运行前环境变量**: -**运行方式**: ```bash -# 指定使用 4 张 NPU 卡 -export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 - -# 运行训练 -python cookbook/sft/lora_npu.py +export PYTHONPATH=$PYTHONPATH:/path/to/Megatron-LM +export MEGATRON_LM_PATH=/path/to/Megatron-LM +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ``` -**示例特性**: -- ✅ Ray 分布式模式 -- ✅ DP + FSDP 混合并行(2x2) -- ✅ LoRA 微调 -- ✅ 完整的数据加载和训练循环 +**验证方式**: -### GRPO 强化学习训练 +先跑一个最小导入检查,确认 MindSpeed / Megatron-LM 可以被当前环境找到: -已验证的多卡 GRPO 训练示例: +```bash +python -c "import mindspeed.megatron_adaptor; from twinkle.model.megatron._mindspeed_runtime import ensure_mindspeed_adaptor_patched; ensure_mindspeed_adaptor_patched(); print('✓ Megatron backend imports are ready')" +``` -**示例路径**:[cookbook/grpo/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/grpo/lora_npu.py) +## 快速开始 -**运行方式**: -```bash -# 指定使用 8 张 NPU 卡 -export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +**重要提示**:以下示例均来自 `cookbook/` 目录,已在实际 NPU 环境中验证通过。建议直接运行 cookbook 中的脚本,而不是复制粘贴代码片段。 -# 运行训练 -python cookbook/grpo/lora_npu.py -``` +### SFT LoRA 微调 + +当前 NPU 文档不再提供这类 SFT cookbook 示例;这部分能力需要结合实际可用的 cookbook 示例或后续补充的 NPU 脚本来说明。 + +### GRPO 强化学习训练 -**示例特性**: -- ✅ Actor-Critic 架构 -- ✅ 支持 Reference Model -- ✅ 可选 TorchSampler 或 vLLMSampler -- ✅ 完整的 RL 训练流程 +当前 NPU 文档不再提供这类 GRPO cookbook 示例;这部分能力需要结合实际可用的 cookbook 示例或后续补充的 NPU 脚本来说明。 ### 更多示例 @@ -165,12 +181,12 @@ Twinkle 在 NPU 上目前支持以下**经过验证**的并行策略: | 并行类型 | 说明 | NPU 支持 | 验证状态 | |---------|------|---------|---------| -| DP (Data Parallel) | 数据并行 | ✅ | 已验证(见 cookbook/sft/lora_npu.py) | -| FSDP (Fully Sharded Data Parallel) 
| 完全分片数据并行 | ✅ | 已验证(见 cookbook/sft/lora_npu.py) | -| TP (Tensor Parallel) | 张量并行(Megatron) | 🚧 | 待验证 | -| PP (Pipeline Parallel) | 流水线并行(Megatron) | 🚧 | 待验证 | -| CP (Context Parallel) | 上下文并行 | 🚧 | 待验证 | -| EP (Expert Parallel) | 专家并行(MoE) | 🚧 | 待验证 | +| DP (Data Parallel) | 数据并行 | ✅ | 暂无对应 cookbook 示例 | +| FSDP (Fully Sharded Data Parallel) | 完全分片数据并行 | ✅ | 暂无对应 cookbook 示例 | +| TP (Tensor Parallel) | 张量并行(Megatron) | ✅ | 已验证(见 `cookbook/megatron/ascend/tp_npu.py`) | +| PP (Pipeline Parallel) | 流水线并行(Megatron) | ✅ | 已验证(见 `cookbook/megatron/ascend/tp_npu.py`) | +| CP (Context Parallel) | 上下文并行 | ✅ | 已验证(见 `cookbook/megatron/ascend/tp_moe_cp_npu.py`) | +| EP (Expert Parallel) | 专家并行(MoE) | ✅ | 已验证(见 `cookbook/megatron/ascend/tp_moe_npu.py`) | **图例说明**: - ✅ 已验证:有实际运行示例代码 @@ -179,21 +195,9 @@ Twinkle 在 NPU 上目前支持以下**经过验证**的并行策略: ### DP + FSDP 示例 -以下示例来自 `cookbook/sft/lora_npu.py`,在实际 NPU 环境中验证通过: +当前 NPU 文档暂不提供对应的 cookbook 代码片段。 -```python -import numpy as np -from twinkle import DeviceMesh - -# 4 卡:DP=2, FSDP=2 -device_mesh = DeviceMesh( - device_type='npu', - mesh=np.array([[0, 1], [2, 3]]), - mesh_dim_names=('dp', 'fsdp') -) -``` - -**注意**:Megatron 后端(TP/PP/EP)在 NPU 上的支持正在开发中,暂无可用示例。如需使用这些高级并行策略,请先在 GPU 环境下验证,或关注项目更新。 +**Megatron 后端说明**:Twinkle 的 Megatron NPU 路径已经提供了可直接运行的 smoke 示例,安装和运行依赖请参考上面的 “Megatron 后端依赖” 小节。当前优先建议先验证 `cookbook/megatron/ascend/tp_npu.py`,再逐步切到 `cookbook/megatron/ascend/tp_moe_npu.py` 和 `cookbook/megatron/ascend/tp_moe_cp_npu.py`。 ## 常见问题 @@ -229,14 +233,14 @@ pip install torch_npu-2.7.1-cp311-cp311-linux_aarch64.whl | 功能 | GPU | NPU | 验证示例 | 说明 | |------|-----|-----|---------|------| -| SFT + LoRA | ✅ | ✅ | cookbook/sft/lora_npu.py | 已验证可用 | -| GRPO | ✅ | ✅ | cookbook/grpo/lora_npu.py | 已验证可用 | -| DP 并行 | ✅ | ✅ | cookbook/sft/lora_npu.py | 已验证可用 | -| FSDP 并行 | ✅ | ✅ | cookbook/sft/lora_npu.py | 已验证可用 | -| Ray 分布式 | ✅ | ✅ | cookbook/sft/lora_npu.py | 已验证可用 | -| TorchSampler | ✅ | ✅ | cookbook/grpo/lora_npu.py | 已验证可用 | -| vLLMSampler | ✅ | ✅ 
| cookbook/grpo/lora_npu.py | 已验证可用 | -| 全量微调 | ✅ | 🚧 | - | 理论支持,待验证 | +| SFT + LoRA | ✅ | ✅ | - | 暂无对应 cookbook 示例 | +| GRPO | ✅ | ✅ | - | 暂无对应 cookbook 示例 | +| DP 并行 | ✅ | ✅ | - | 暂无对应 cookbook 示例 | +| FSDP 并行 | ✅ | ✅ | - | 暂无对应 cookbook 示例 | +| Ray 分布式 | ✅ | ✅ | - | 暂无对应 cookbook 示例 | +| TorchSampler | ✅ | ✅ | - | 暂无对应 cookbook 示例 | +| vLLMSampler | ✅ | ✅ | - | 暂无对应 cookbook 示例 | +| 全量微调 | ✅ | ✅ | - | 已验证可用 | | QLoRA | ✅ | ❌ | - | 量化算子暂不支持 | | DPO | ✅ | 🚧 | - | 理论支持,待验证 | -| Megatron TP/PP | ✅ | 🚧 | - | 待适配和验证 | +| Megatron TP/PP | ✅ | ✅ | cookbook/megatron/ascend/tp_npu.py | 已验证可用 | @@ -253,38 +257,6 @@ pip install torch_npu-2.7.1-cp311-cp311-linux_aarch64.whl 2. “待验证”功能可以尝试,但可能遇到兼容性问题 3. 遇到问题时,参考对应的示例代码进行配置 -## 示例代码 - -Twinkle 提供了以下经过验证的 NPU 训练示例: - -### SFT 训练 -- **4 卡 DP+FSDP LoRA 微调**:[cookbook/sft/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/sft/lora_npu.py) - - 使用 Ray 模式进行分布式训练 - - 演示 DP + FSDP 混合并行 - - 包含完整的数据加载和训练循环 - -### GRPO 训练 -- **多卡 GRPO RL 训练**:[cookbook/grpo/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/grpo/lora_npu.py) - - Actor-Critic 架构 - - 支持参考模型(Reference Model) - - 可选 TorchSampler 或 vLLMSampler - -### 远程训练(Tinker 协议) -- **服务端配置**:[cookbook/remote/tinker/ascend/](https://github.com/modelscope/twinkle/tree/main/cookbook/remote/tinker/ascend) - - 提供 HTTP API 接口 - - 支持远程训练和推理 - - 适用于生产环境部署 - -**运行示例**: -```bash -# SFT 训练 -export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 -python cookbook/sft/lora_npu.py - -# GRPO 训练 -export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -python cookbook/grpo/lora_npu.py -``` ## 参考资源 diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index 596f3c32..19cee4a9 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -134,6 +134,9 @@ def upload_to_hub(self, else: HubOperation.push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=hub_token, private=True) + def _should_bind_device_id_for_process_group(self, backend: str) -> bool: + return backend in ('nccl', 'hccl') + def _try_init_process_group(self): 
import torch import torch.distributed as dist @@ -154,6 +157,6 @@ def _try_init_process_group(self): 'rank': Platform.get_rank(), 'world_size': Platform.get_world_size(), } - if backend in ('nccl', 'hccl'): + if self._should_bind_device_id_for_process_group(backend): init_kwargs['device_id'] = torch.device(Platform.get_local_device()) dist.init_process_group(**init_kwargs) diff --git a/src/twinkle/model/megatron/_mindspeed_runtime.py b/src/twinkle/model/megatron/_mindspeed_runtime.py new file mode 100644 index 00000000..5c4fabdc --- /dev/null +++ b/src/twinkle/model/megatron/_mindspeed_runtime.py @@ -0,0 +1,221 @@ +"""MindSpeed runtime bootstrap for Twinkle's Megatron NPU path. + +This module deliberately keeps two phases separate: +1. Early import-time patching via ``mindspeed.megatron_adaptor`` before + ``mcore_bridge`` is imported. +2. Runtime args synthesis and ``repatch()`` once ``ModelConfig`` exists. +""" + +import argparse +import json +import torch +from typing import Any, Dict + +from twinkle import Platform +from twinkle.utils import get_logger + +logger = get_logger() + +_MINDSPEED_IMPORTED = False +_LAST_RUNTIME_SIGNATURE = None + + +def _is_npu() -> bool: + return Platform.device_prefix() == 'npu' + + +def ensure_mindspeed_adaptor_patched() -> None: + """Import MindSpeed's official adaptor before any mcore/TE import on NPU. + + ``mcore_bridge.__init__`` immediately imports its patcher, and that patcher + pulls in ``megatron.core`` and TE symbols at module import time. MindSpeed's + patch stack must land before that import chain, otherwise TE symbols and + ``torch.compile``-related hooks are bound too early. 
+ """ + global _MINDSPEED_IMPORTED + if not _is_npu() or _MINDSPEED_IMPORTED: + return + import mindspeed.megatron_adaptor # noqa: F401 + _MINDSPEED_IMPORTED = True + + +def _jsonable(value: Any) -> Any: + if isinstance(value, torch.dtype): + return str(value) + if isinstance(value, dict): + return {k: _jsonable(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [_jsonable(v) for v in value] + return value + + +def _is_runtime_value(value: Any) -> bool: + return isinstance(value, (type(None), bool, int, float, str, list, tuple, dict, torch.dtype)) + + +def _compute_optimization_level(config: Any) -> int: + num_moe_experts = getattr(config, 'num_moe_experts', None) + has_moe = num_moe_experts not in (None, 0, 1) + # MindSpeed's context-parallel feature stack is gated behind optimization + # level 2. If Twinkle launches a CP run with the default level 0, the CP + # patch set never gets registered and ring state stays uninitialized. + if int(getattr(config, 'context_parallel_size', 1) or 1) > 1: + return 2 + if getattr(config, 'multi_latent_attention', False): + return 2 + if has_moe and getattr(config, 'moe_grouped_gemm', False): + return 2 + if getattr(config, 'schedules_method', None) == 'dualpipev': + return 2 + return 0 + + +def _force_megatron_cp_te_patch(runtime_args: argparse.Namespace) -> None: + """Twinkle-side override for MindSpeed TE CP class selection on NPU. + + MindSpeed 0.15.3 routes TE context parallel through a factory that only + accepts `kvallgather_cp_algo`. Twinkle still wants the default + `megatron_cp_algo` ring path for the Megatron smoke, so we override the TE + class back to the older `MindSpeedCPDotProductAttention` from the Twinkle + runtime layer instead of changing MindSpeed sources. 
+ """ + if not _is_npu(): + return + if int(getattr(runtime_args, 'context_parallel_size', 1)) <= 1: + return + if getattr(runtime_args, 'context_parallel_algo', 'megatron_cp_algo') != 'megatron_cp_algo': + return + + from mindspeed.core.context_parallel.adaptor import MindSpeedCPDotProductAttention + from mindspeed.patch_utils import MindSpeedPatchesManager + + MindSpeedPatchesManager.register_patch( + 'megatron.core.extensions.transformer_engine.TEDotProductAttention', + MindSpeedCPDotProductAttention, + force_patch=True, + ) + MindSpeedPatchesManager.apply_patches() + logger.info('Forced TEDotProductAttention to MindSpeedCPDotProductAttention for megatron_cp_algo.') + + +def _ensure_megatron_cp_ring_state(runtime_args: argparse.Namespace) -> None: + """Initialize MindSpeed's ring CP globals when the default path is selected. + + MindSpeed 0.15.x already owns the real ring-attention logic, but Twinkle can + still end up with the TE class patched back to the legacy CP path while the + ring globals remain unset. If that happens, the first forward dies in + ``get_ring_ranks_for_intra_window()`` even though the model parallel groups + are already up. We repair the MindSpeed module state here, from Twinkle, so + the shared runtime behavior stays intact without editing MindSpeed sources. 
+ """ + if not _is_npu(): + return + if int(getattr(runtime_args, 'context_parallel_size', 1)) <= 1: + return + if getattr(runtime_args, 'context_parallel_algo', 'megatron_cp_algo') != 'megatron_cp_algo': + return + if not torch.distributed.is_initialized(): + return + + from mindspeed.core.context_parallel import model_parallel_utils as cp_utils + + try: + cp_utils.get_ring_ranks_for_intra_window() + return + except AssertionError: + pass + + from megatron.core import mpu + + cp_utils.initialize_context_parallel_group_for_double_ring( + mpu.get_tensor_model_parallel_world_size(), + mpu.get_pipeline_model_parallel_world_size(), + mpu.get_context_parallel_world_size(), + {}, + ) + logger.info('Initialized MindSpeed ring CP state for megatron_cp_algo from Twinkle bootstrap.') + + +def build_mindspeed_runtime_args(config: Any) -> argparse.Namespace: + """Build the runtime namespace MindSpeed 0.15.3 consumes on NPU. + + We start from MindSpeed feature defaults and overlay the current + ``ModelConfig`` values. The config object is already the single source of + truth in the new Twinkle + mcore-bridge architecture, so we do not keep a + second Twinkle-side args protocol here. 
+ """ + from mindspeed.args_utils import get_mindspeed_args + + defaults = get_mindspeed_args(get_defaults=True) + values: Dict[str, Any] = vars(defaults).copy() + + for key, value in vars(config).items(): + if key.startswith('_') or key in {'bridge', 'model_meta', 'hf_config'}: + continue + if not _is_runtime_value(value): + continue + values[key] = value + + num_moe_experts = getattr(config, 'num_moe_experts', None) + if num_moe_experts not in (None, 0): + values['num_experts'] = num_moe_experts + values['num_moe_experts'] = num_moe_experts + + if getattr(config, 'multi_latent_attention', False): + values['multi_head_latent_attention'] = True + if getattr(config, 'qk_head_dim', None) is not None: + values['qk_nope_head_dim'] = config.qk_head_dim + if getattr(config, 'qk_pos_emb_head_dim', None) is not None: + values['qk_rope_head_dim'] = config.qk_pos_emb_head_dim + # MindSpeed's CP rotary-pos helper reads this flag directly even when the + # base Twinkle/MCore config path does not define it. 
+ values.setdefault('reset_position_ids', False) + + params_dtype = getattr(config, 'params_dtype', None) + if params_dtype == torch.bfloat16: + values['bf16'] = True + values['fp16'] = False + elif params_dtype == torch.float16: + values['fp16'] = True + values['bf16'] = False + elif params_dtype is not None: + values['fp16'] = False + values['bf16'] = False + + values['optimization_level'] = _compute_optimization_level(config) + return argparse.Namespace(**values) + + +def configure_mindspeed_runtime_args(config: Any) -> argparse.Namespace: + """Install current runtime args and repatch MindSpeed on signature changes.""" + global _LAST_RUNTIME_SIGNATURE + + if not _is_npu(): + return argparse.Namespace() + + ensure_mindspeed_adaptor_patched() + + from mindspeed import args_utils + from mindspeed.megatron_adaptor import repatch + + runtime_args = build_mindspeed_runtime_args(config) + args_utils._MINDSPEED_ARGS = runtime_args + + runtime_signature = json.dumps( + { + k: _jsonable(v) + for k, v in sorted(vars(runtime_args).items()) + }, + sort_keys=True, + ensure_ascii=True, + ) + if runtime_signature != _LAST_RUNTIME_SIGNATURE: + repatch(vars(runtime_args)) + _LAST_RUNTIME_SIGNATURE = runtime_signature + logger.info( + 'Configured MindSpeed runtime args for NPU, optimization_level=%s', + getattr(runtime_args, 'optimization_level', None), + ) + _force_megatron_cp_te_patch(runtime_args) + _ensure_megatron_cp_ring_state(runtime_args) + return runtime_args diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 9b485f55..0a0cb111 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -16,7 +16,7 @@ from peft.tuners.lora import Linear as LoraLinear from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler -from transformers import PretrainedConfig +from transformers import PreTrainedConfig from typing import Any, Callable, Dict, Generator, List, Literal, 
Optional, Tuple, Type, Union import twinkle @@ -35,6 +35,7 @@ from twinkle.processor import InputProcessor from twinkle.template import Template from twinkle.utils import construct_class, get_logger, selective_log_softmax +from ._mindspeed_runtime import ensure_mindspeed_adaptor_patched from .strategy import MegatronStrategy logger = get_logger() @@ -83,7 +84,7 @@ class MegatronModel(TwinkleModel, nn.Module, CheckpointEngineMixin): def __init__( self, model_id: str, - config: Optional[PretrainedConfig] = None, + config: Optional[PreTrainedConfig] = None, ddp_config: Optional[Dict[str, Any]] = None, device_mesh: Optional[DeviceMesh] = None, mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', @@ -95,7 +96,6 @@ def __init__( **kwargs, ): requires('megatron_core') - requires('mcore_bridge') os.environ['TOKENIZERS_PARALLELISM'] = 'true' os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' nn.Module.__init__(self) @@ -111,6 +111,10 @@ def __init__( self.variable_seq_lengths = kwargs.get('variable_seq_lengths', False) torch_util.set_device() self._try_init_process_group() + # MindSpeed must patch before mcore_bridge imports its patcher, otherwise + # mcore_bridge pulls in megatron.core/TE too early on NPU. + ensure_mindspeed_adaptor_patched() + requires('mcore_bridge') kwargs.update({ 'recompute_granularity': recompute_granularity, @@ -146,6 +150,22 @@ def __init__( self.active_group = _default_adapter_name MegatronPeft().__call__() + def _should_bind_device_id_for_process_group(self, backend: str) -> bool: + # Keep NCCL's device binding behavior, but avoid binding HCCL's default + # PG so Megatron's later Gloo DP groups stay decoupled on NPU. 
+ return backend == 'nccl' + + @staticmethod + def _drop_npu_causal_4d_mask(batch, unwrapped_model): + """On NPU, drop the generic 4D dense mask so MindSpeed can build + its own compressed causal mask for FlashAttention.""" + if Platform.device_prefix() != 'npu': + return + attention_mask = batch.get('attention_mask') + if (isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4 + and getattr(unwrapped_model.config, 'attention_mask_type', None) == 'causal'): + batch['attention_mask'] = None + def _construct_default_optimizer_group(self): return MegatronOptimizerGroup( loss_instance=CrossEntropyLoss(reduction='sum'), @@ -358,8 +378,8 @@ def post_loss_function(output_tensor, inputs, logps): def forward_step_func(data_iterator, model): batch = next(data_iterator) labels = batch.pop('labels', None) - # Handle disable_lora for base model inference (e.g., reference in DPO) unwrapped_model = self.strategy.unwrap_model([model])[0] + self._drop_npu_causal_4d_mask(batch, unwrapped_model) if disable_lora and isinstance(unwrapped_model, PeftModel): with unwrapped_model.disable_adapter(): output_tensor = model(**batch) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 77cb330e..6cbc579a 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -16,6 +16,7 @@ from twinkle.metric import Metric from twinkle.processor import InputProcessor from ..multi_lora import MultiLora +from ._mindspeed_runtime import ensure_mindspeed_adaptor_patched from .megatron import MegatronModel from .strategy import MegatronStrategy @@ -41,7 +42,6 @@ def __init__( **kwargs, ): requires('megatron_core') - requires('mcore_bridge') os.environ['TOKENIZERS_PARALLELISM'] = 'true' os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' nn.Module.__init__(self) @@ -59,6 +59,8 @@ def __init__( self.optimizer_group = {} torch_util.set_device() 
self._try_init_process_group() + ensure_mindspeed_adaptor_patched() + requires('mcore_bridge') kwargs.update({ 'recompute_granularity': recompute_granularity, diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index b9e66505..74d01457 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -5,6 +5,33 @@ from typing import Any, Dict, List, Literal, Optional from twinkle import DeviceMesh, Platform, torch_util +from twinkle.utils import get_logger +from .._mindspeed_runtime import configure_mindspeed_runtime_args + +logger = get_logger() + + +def finalize_model_grads_for_lora(model, *args, **kwargs): + """Only enter Megatron native finalize when the wrapped model has sync capability. + + In single-rank/no-op wrap cases Twinkle attaches ``ddp_config`` to the bare + module for optimizer compatibility, but that does not mean the model really + implements ``finish_grad_sync()``. Native Megatron finalize ultimately calls + that method, so we gate by runtime capability instead of config metadata. 
+ """ + from megatron.core.distributed import DistributedDataParallel as MegatronDDP + from megatron.core.distributed import finalize_model_grads as _native_finalize_model_grads + from peft import PeftModel as _PeftModel + + def _get_base_model(m): + if isinstance(m, _PeftModel): + return _get_base_model(m.base_model.model) + return m + + base_model = _get_base_model(model[0]) + if isinstance(base_model, MegatronDDP) or hasattr(base_model, 'finish_grad_sync'): + return _native_finalize_model_grads(model, *args, **kwargs) + return None class MegatronStrategy: @@ -21,6 +48,7 @@ def __init__( ddp_config: Dict[str, Any] = None, **kwargs, ): + import torch.distributed as dist from megatron.core import mpu self.device_mesh = device_mesh self.use_distributed_optimizer = use_distributed_optimizer @@ -34,6 +62,15 @@ def __init__( self.hf_config = AutoConfig.from_pretrained(self.model_dir, trust_remote_code=True) else: self.hf_config = config + num_experts = getattr(self.hf_config, 'num_experts', getattr(self.hf_config, 'num_local_experts', None)) + if (num_experts not in (None, 0, 1) and (self.device_mesh.tp_world_size or 1) > 1 + and not getattr(self.device_mesh, 'sequence_parallel', False)): + # Megatron 0.15.3 requires sequence parallelism for MoE training when + # tensor parallelism is enabled. Keep this policy in the framework so + # cookbook scripts do not need to know a model-family-specific + # runtime constraint just to launch a valid MoE run. 
+ self.device_mesh.sequence_parallel = True + logger.info('Auto-enabled sequence_parallel for MoE model with tensor parallelism.') if 'overlap_grad_reduce' not in self.ddp_config: self.ddp_config['overlap_grad_reduce'] = False if 'overlap_param_gather' not in self.ddp_config: @@ -69,10 +106,22 @@ def __init__( if 'overlap_p2p_comm' not in kwargs: kwargs['overlap_p2p_comm'] = True kwargs['batch_p2p_comm'] = not kwargs['overlap_p2p_comm'] - mpu.initialize_model_parallel( - order=self.device_mesh.order, + if Platform.device_prefix() == 'npu' and dist.is_initialized(): + default_pg = dist.distributed_c10d._get_default_group() + if getattr(default_pg, 'bound_device_id', None) is not None: + # If the default HCCL PG keeps a bound device id, PyTorch may + # propagate that binding into later Gloo subgroup creation. That + # breaks the metrics/object-gather path on NPU, so clear it + # before Megatron creates its Gloo DP groups. + default_pg.bound_device_id = None + + init_kwargs = { + 'order': self.device_mesh.order, **parallel_kwargs, - ) + } + if Platform.device_prefix() == 'npu': + init_kwargs['create_gloo_process_groups'] = True + mpu.initialize_model_parallel(**init_kwargs) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed model_parallel_cuda_manual_seed(self.seed) self.config = self.get_model_config(self.hf_config, parallel_kwargs, **kwargs) @@ -225,7 +274,6 @@ def get_model_config( **kwargs, ): from mcore_bridge import ModelConfig, hf_to_mcore_config - from megatron.core.distributed import finalize_model_grads as _native_finalize_model_grads config_kwargs = hf_to_mcore_config(hf_config) config_kwargs.update(kwargs) if 'calculate_per_token_loss' not in config_kwargs: @@ -233,24 +281,7 @@ def get_model_config( if 'moe_token_dispatcher_type' not in config_kwargs: config_kwargs['moe_token_dispatcher_type'] = 'alltoall' if self.variable_seq_lengths else 'allgather' - - def finalize_model_grads_for_lora(model, *args, **kwargs): - from 
megatron.core.distributed import DistributedDataParallel as MegatronDDP - from peft import PeftModel as _PeftModel - - # Check if model is DDP-wrapped (has ddp_config) - # Need to unwrap PeftModel to check the underlying model - def _get_base_model(m): - if isinstance(m, _PeftModel): - return _get_base_model(m.base_model.model) - return m - - base_model = _get_base_model(model[0]) - if isinstance(base_model, MegatronDDP) or hasattr(base_model, 'ddp_config'): - # Use native implementation for DDP models - return _native_finalize_model_grads(model, *args, **kwargs) - - return ModelConfig( + model_config = ModelConfig( use_cpu_initialization=True, params_dtype=self.params_type, sequence_parallel=self.sequence_parallel, @@ -259,6 +290,18 @@ def _get_base_model(m): **parallel_kwargs, **config_kwargs, ) + if Platform.device_prefix() == 'npu': + # After Twinkle stops feeding the dense 4D causal mask, MindSpeed's + # patched TE attention should generate its own compressed causal + # mask. In 0.15.3 that path is gated by ``use_flash_attn`` on the + # model config itself. If we leave it unset, MindSpeed falls back to + # the non-flash mask generator and aborts the first 8-card forward + # with: "Please set micro_batch_size or set use_flash_attn=True in + # config." Keep the TE flash path enabled and let it synthesize the + # mask it expects. 
+ model_config.use_flash_attn = True + configure_mindspeed_runtime_args(model_config) + return model_config def create_megatron_model( self, diff --git a/src/twinkle/utils/framework.py b/src/twinkle/utils/framework.py index 0cdeb81d..09c91908 100644 --- a/src/twinkle/utils/framework.py +++ b/src/twinkle/utils/framework.py @@ -42,6 +42,16 @@ def gather_object(object: Any, device_mesh: DeviceMesh, process_group=None): import torch.distributed as dist output_objects = [object] if device_mesh is not None and device_mesh.data_world_size > 1: + if Platform.device_prefix() == 'npu': + # On NPU, letting Python object collectives use the default HCCL + # group previously hung in 8-card metric collection at + # ``dist.all_gather_object(...)``. Reuse Megatron's dedicated Gloo + # DP group instead. When CP is enabled we must pick the DP+CP + # variant, otherwise the rank span for metric aggregation is wrong. + if importlib.util.find_spec('megatron.core') is not None: + from megatron.core import parallel_state as mpu + process_group = mpu.get_data_parallel_group_gloo( + with_context_parallel=getattr(device_mesh, 'cp_world_size', 1) > 1) group_size = dist.get_world_size(group=process_group) output_objects = [None for _ in range(group_size)] dist.all_gather_object(output_objects, object, group=process_group)