Conversation
Code Review
This pull request introduces comprehensive NPU support for the Megatron backend, featuring documentation updates for environment requirements and a new MindSpeed runtime bootstrap for NPU-specific patching and argument synthesis. Key technical changes include refined process group initialization for single-rank environments, optimized attention mask handling for NPU FlashAttention, and the use of Gloo groups for object gathering on NPU to prevent hangs. Review feedback pointed out a potential initialization error regarding invalid arguments in `init_process_group`, a hard dependency on `megatron-core` in utility functions, hardcoded paths in the documentation, and suggested expanding the mask-dropping logic to all causal NPU configurations.
```python
if backend == 'nccl':
    init_kwargs['device_id'] = torch.device(Platform.get_local_device())
```
The `torch.distributed.init_process_group` function only accepts a `device_id` argument in newer PyTorch releases; on older versions, including it in `init_kwargs` will raise a `TypeError` during initialization. For the `nccl` backend, device binding is already handled by calling `torch.cuda.set_device()` before initialization, which is performed via `torch_util.set_device()` on line 112.
Suggested change:

```diff
-if backend == 'nccl':
-    init_kwargs['device_id'] = torch.device(Platform.get_local_device())
+if backend == 'nccl':
+    # NCCL binding is handled by torch.cuda.set_device() called earlier
+    pass
```
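The reviewed pattern can be sketched in isolation. This is a minimal, hypothetical helper, not the project's actual code; `bind_device` stands in for `torch.cuda.set_device` so the sketch runs anywhere:

```python
def build_init_kwargs(backend, local_device, bind_device):
    """Sketch: bind the accelerator up front instead of passing device_id
    (only newer PyTorch releases accept that keyword)."""
    if backend == 'nccl':
        # NCCL binding happens here, before init_process_group is called,
        # so no device_id entry is ever placed in the kwargs.
        bind_device(local_device)
    return {'backend': backend}

bound = []
kwargs = build_init_kwargs('nccl', 3, bound.append)
```

With this shape, older PyTorch versions never see an unsupported keyword, while the device is still bound before the process group comes up.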
src/twinkle/utils/framework.py (outdated)
```python
from megatron.core import parallel_state as mpu
process_group = mpu.get_data_parallel_group_gloo(
    with_context_parallel=getattr(device_mesh, 'cp_world_size', 1) > 1)
```
This block introduces a hard dependency on megatron-core whenever the platform is NPU. Since twinkle can be used on NPU with other backends (such as pure transformers), this will cause an ImportError for users who have not installed Megatron dependencies. Consider wrapping this import in a try...except block to allow a graceful fallback to the default process group.
Suggested change:

```diff
-from megatron.core import parallel_state as mpu
-process_group = mpu.get_data_parallel_group_gloo(
-    with_context_parallel=getattr(device_mesh, 'cp_world_size', 1) > 1)
+try:
+    from megatron.core import parallel_state as mpu
+    process_group = mpu.get_data_parallel_group_gloo(
+        with_context_parallel=getattr(device_mesh, 'cp_world_size', 1) > 1)
+except (ImportError, ModuleNotFoundError):
+    pass
```
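The graceful-fallback pattern can also be factored into a standalone helper. This is a hypothetical sketch (the function name and `None`-as-fallback convention are assumptions, not project API):

```python
def get_megatron_gloo_dp_group(cp_enabled=False):
    """Return Megatron's Gloo DP group, or None when megatron-core is not
    installed, so callers can fall back to the default process group."""
    try:
        from megatron.core import parallel_state as mpu
    except (ImportError, ModuleNotFoundError):
        # megatron-core absent: signal the caller to use the default group
        return None
    return mpu.get_data_parallel_group_gloo(with_context_parallel=cp_enabled)
```

Centralizing the optional import keeps the NPU branch usable with non-Megatron backends such as pure transformers.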
```python
if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4:
    unwrapped_model = self.strategy.unwrap_model([model])[0]
    attention_mask_type = getattr(unwrapped_model.config, 'attention_mask_type', None)
    if attention_mask_type == 'causal' and self.device_mesh.cp_world_size > 1:
```
The condition self.device_mesh.cp_world_size > 1 might be too restrictive. The comment explains that the 4D dense mask causes failures in aclnnFlashAttentionScore on NPU and is redundant for causal training. If the NPU FlashAttention implementation requires a specific mask shape (or no mask for causal models), this issue likely affects TP-only configurations as well. Removing the CP world size check ensures FlashAttention works correctly on NPU for all parallel configurations where causal attention is used.
Suggested change:

```diff
-if attention_mask_type == 'causal' and self.device_mesh.cp_world_size > 1:
+if attention_mask_type == 'causal':
```
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Pull request overview
This PR completes Twinkle’s NPU Megatron integration targeting the Megatron-LM 0.15.3 + MindSpeed 0.15.3 + mcore-bridge stack, focusing on stabilizing 8-card dense/LoRA training on NPU by fixing MindSpeed bootstrap timing, distributed/metric collectives, and NPU FlashAttention mask handling.
Changes:
- Add an NPU MindSpeed bootstrap layer to ensure adaptor patching happens before `mcore_bridge` imports Megatron/TE, and synthesize/refresh MindSpeed runtime args from `ModelConfig`.
- Adjust Megatron initialization for NPU (default PG fallback, Gloo process groups, metrics/object-gather behavior) and fix causal mask handling for NPU FlashAttention.
- Update NPU documentation and add Megatron NPU smoke cookbooks/scripts.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| src/twinkle/utils/framework.py | Prefer Megatron’s Gloo DP group for all_gather_object on NPU to avoid HCCL hangs during metric/object collection. |
| src/twinkle/model/megatron/strategy/megatron.py | NPU-specific Megatron init tweaks (Gloo PG creation, device binding cleanup), MoE sequence-parallel auto-enable, and MindSpeed runtime arg configuration. |
| src/twinkle/model/megatron/multi_lora_megatron.py | Reorder MindSpeed patching ahead of mcore_bridge import for NPU multi-LoRA Megatron path. |
| src/twinkle/model/megatron/megatron.py | Add default-PG fallback for single-rank smoke, ensure early MindSpeed patching, and drop dense 4D causal masks on NPU causal TE flash path. |
| src/twinkle/model/megatron/_mindspeed_runtime.py | New module implementing early MindSpeed adaptor patching + runtime args synthesis + conditional repatching. |
| docs/source_en/Usage Guide/NPU-Support.md | Update NPU dependency guidance, add Megatron backend install steps, and point to Megatron NPU smoke cookbooks. |
| cookbook/megatron/ascend/tp_npu.py (+ .sh) | Add 8-card TP/PP/DP NPU Megatron smoke script. |
| cookbook/megatron/ascend/tp_moe_npu.py (+ .sh) | Add 8-card MoE NPU smoke script. |
| cookbook/megatron/ascend/tp_moe_cp_npu.py (+ .sh) | Add 8-card MoE+CP NPU smoke script (megatron_cp_algo path). |
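The conditional repatching described for `src/twinkle/model/megatron/_mindspeed_runtime.py` can be sketched as follows. All names here are hypothetical illustrations of the idea, not the module's actual API:

```python
class MindSpeedBootstrap:
    """Repatch only when the runtime parallel topology actually changes."""

    def __init__(self):
        self._signature = None

    @staticmethod
    def _runtime_signature(cfg):
        # The fields whose change would invalidate previously applied patches.
        return (cfg.get('tp', 1), cfg.get('pp', 1),
                cfg.get('cp', 1), cfg.get('seq_len'))

    def ensure_patched(self, cfg):
        sig = self._runtime_signature(cfg)
        if sig == self._signature:
            return False  # already patched for this topology, nothing to do
        self._signature = sig
        # a real implementation would invoke the MindSpeed repatch hook here
        return True

boot = MindSpeedBootstrap()
first = boot.ensure_patched({'tp': 2, 'pp': 1, 'cp': 1, 'seq_len': 4096})
second = boot.ensure_patched({'tp': 2, 'pp': 1, 'cp': 1, 'seq_len': 4096})
third = boot.ensure_patched({'tp': 4, 'pp': 1, 'cp': 1, 'seq_len': 4096})
```

Keying the decision on a signature tuple keeps repatching idempotent across repeated calls with the same topology.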
```diff
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
-from transformers import PretrainedConfig
+from transformers import PreTrainedConfig
```
transformers exposes PretrainedConfig (lowercase “t”), not PreTrainedConfig. Importing PreTrainedConfig will raise ImportError at runtime. Please switch the import (and corresponding type hints) back to PretrainedConfig to match the Transformers API and the rest of the codebase.
Suggested change:

```diff
-from transformers import PreTrainedConfig
+from transformers import PretrainedConfig
```
```diff
 def __init__(
     self,
     model_id: str,
-    config: Optional[PretrainedConfig] = None,
+    config: Optional[PreTrainedConfig] = None,
     ddp_config: Optional[Dict[str, Any]] = None,
     device_mesh: Optional[DeviceMesh] = None,
     mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16',
```
The config parameter is annotated as PreTrainedConfig, but Transformers’ config base class is PretrainedConfig. With the current import this will fail to import at runtime; please update the annotation to use PretrainedConfig after fixing the import.
```python
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
import torch.nn as nn
from transformers import PreTrainedConfig
from typing import Any, Dict, List, Literal, Optional
```
transformers uses PretrainedConfig (lowercase “t”), not PreTrainedConfig. This import will fail at runtime and also makes the Megatron strategy inconsistent with other Twinkle model code that uses PretrainedConfig.
```python
if Platform.device_prefix() == 'npu':
    # On NPU, letting Python object collectives use the default HCCL
    # group previously hung in 8-card metric collection at
    # ``dist.all_gather_object(...)``. Reuse Megatron's dedicated Gloo
    # DP group instead. When CP is enabled we must pick the DP+CP
    # variant, otherwise the rank span for metric aggregation is wrong.
    try:
        from megatron.core import parallel_state as mpu
        process_group = mpu.get_data_parallel_group_gloo(
            with_context_parallel=getattr(device_mesh, 'cp_world_size', 1) > 1)
    except (ImportError, ModuleNotFoundError):
        pass
group_size = dist.get_world_size(group=process_group)
output_objects = [None for _ in range(group_size)]
dist.all_gather_object(output_objects, object, group=process_group)
```
On NPU this tries to switch all_gather_object onto Megatron’s Gloo DP group, but if Megatron isn’t importable (or its groups aren’t initialized yet) the code silently falls back to the default group and still calls dist.all_gather_object(...), which your comment notes can hang on HCCL. Consider adding a safe fallback (e.g., create/use a dedicated Gloo process group for object collectives) or raise/log a clear error instead of silently using the default group.
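One shape the suggested fallback could take is a small resolver that at least warns before using the risky default group. This is a hypothetical sketch; the function name and string placeholders for groups are illustrative only:

```python
import logging

logger = logging.getLogger('twinkle.npu')

def resolve_object_gather_group(default_group, gloo_group):
    """Pick the process group for Python-object collectives on NPU.

    Prefer a Gloo group; if none is available, warn loudly instead of
    silently gathering over the default HCCL group, which can hang.
    """
    if gloo_group is not None:
        return gloo_group
    logger.warning('No Gloo group available for all_gather_object; '
                   'falling back to the default group may hang on HCCL.')
    return default_group

chosen = resolve_object_gather_group('default-hccl', 'gloo-dp')
fallback = resolve_object_gather_group('default-hccl', None)
```

An even safer variant would create a dedicated Gloo group with `dist.new_group(backend='gloo')` at init time so the fallback branch never runs.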
```python
default_pg = dist.distributed_c10d._get_default_group()
if getattr(default_pg, 'bound_device_id', None) is not None:
    # If the default HCCL PG keeps a bound device id, PyTorch may
    # propagate that binding into later Gloo subgroup creation. That
    # breaks the metrics/object-gather path on NPU, so clear it
    # before Megatron creates its Gloo DP groups.
    default_pg.bound_device_id = None
```
This relies on private PyTorch APIs (dist.distributed_c10d._get_default_group()) and unconditionally assigns to default_pg.bound_device_id. Depending on the torch version/backend, bound_device_id may be read-only or absent, which would raise at init time. Please guard both retrieval and assignment with try/except (or hasattr + safe setter) so Megatron init can’t crash on versions where this attribute isn’t writable.
Suggested change:

```diff
-default_pg = dist.distributed_c10d._get_default_group()
-if getattr(default_pg, 'bound_device_id', None) is not None:
-    # If the default HCCL PG keeps a bound device id, PyTorch may
-    # propagate that binding into later Gloo subgroup creation. That
-    # breaks the metrics/object-gather path on NPU, so clear it
-    # before Megatron creates its Gloo DP groups.
-    default_pg.bound_device_id = None
+try:
+    default_pg = dist.distributed_c10d._get_default_group()
+except Exception as e:
+    logger.debug(f'Failed to get default process group for NPU bound_device_id workaround: {e}')
+else:
+    if getattr(default_pg, 'bound_device_id', None) is not None:
+        # If the default HCCL PG keeps a bound device id, PyTorch may
+        # propagate that binding into later Gloo subgroup creation. That
+        # breaks the metrics/object-gather path on NPU, so clear it
+        # before Megatron creates its Gloo DP groups.
+        try:
+            default_pg.bound_device_id = None
+        except Exception as e:
+            logger.debug(
+                f'Failed to clear default process group bound_device_id for NPU workaround: {e}')
```
```python
        self.active_group = _default_adapter_name
        MegatronPeft().__call__()


    def _ensure_megatron_process_group(self):
```
The `self._try_init_process_group()` call above should already cover process group initialization; why does it need to be handled a second time here?
There was indeed an issue here; the related logic has been rewritten.
```python
# and padded query positions are ignored by labels == -100. So on
# the NPU TE path, drop this dense mask and let MindSpeed build the
# compressed causal mask it requires.
if Platform.device_prefix() == 'npu':
```
Would it be more appropriate to move this `if` into `InputProcessor`?
This check depends on `unwrapped_model.config.attention_mask_type`, which comes from the Megatron runtime model config and can only be read from the model instance at forward time. `InputProcessor` is a pure data-processing component that holds no model reference, so it cannot make this decision. I have wrapped this block into a helper to keep it more concise.
src/twinkle/utils/framework.py (outdated)
```python
        from megatron.core import parallel_state as mpu
        process_group = mpu.get_data_parallel_group_gloo(
            with_context_parallel=getattr(device_mesh, 'cp_world_size', 1) > 1)
    except (ImportError, ModuleNotFoundError):
```
This try/except seems pointless; if initialization fails, it should just raise the error.
PR Type
Summary
This PR completes Twinkle's NPU Megatron adaptation and targets the Twinkle + Megatron-LM 0.15.3 + MindSpeed 0.15.3 + mcore-bridge stack. The goal is to make the dense / LoRA 8-card training path stable on NPU.
Main changes:
- Ensure MindSpeed adaptor patching runs before `mcore_bridge` is imported, to avoid late patching and early binding of TE / Megatron symbols.
- Synthesize MindSpeed runtime args from `ModelConfig` and the runtime parallel topology, then call `repatch()` when the runtime signature changes.

What Changed
1. MindSpeed runtime bootstrap
- MindSpeed adaptors are patched before the `mcore_bridge` import.

2. Process group / metric gather
- Switch `gather_object()` to prefer Megatron's Gloo DP group to avoid hangs in metrics / Python object gathering.

3. NPU FlashAttention
4. LoRA / Multi-LoRA
- Ensure `ddp_config` is not incorrectly treated as a model that can run native finalize.

5. Documentation
Notes
This PR targets the following version stack: