[JAX] Remove shard_map from MoEBlock to support quant before FSDP AG using Grouped quant+GEMM custom partitioning rules#3131
Conversation
…mm-custom-partition-rules
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…em_reloc gating Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rce at dispatch Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… static layer registration Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…s, MoE example) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ache Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype field. The C++ backend (ep_backend.cpp:349) enforces typeToSize(tok_dtype) <= typeToSize(max_token_dtype) at every dispatch, and the field is also used at group create to size the NCCL EP staging buffers (ep_backend.cpp:221-222). PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written before this field existed and never set it, so any JAX EP group landed with the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from JAX then failed immediately with: tokens dtype (6) wider than group max_token_dtype (0) This commit threads max_token_dtype end-to-end: - transformer_engine/jax/csrc/extensions.h update SetEpBootstrapParams declaration to match the new arity. - transformer_engine/jax/csrc/extensions/ep.cpp add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams; forward it into NVTEEpGroupConfig in the EpResources ctor. - transformer_engine/jax/csrc/extensions/pybind.cpp add the matching pybind11::arg("max_token_dtype") = 0. - transformer_engine/jax/ep.py add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to NVTEDType int, forward to the C++ setter. Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream. See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
[JAX] MoE: soft re-pin inbound activations sharding at moe() entry [JAX] MoE: scope gate_logits 2D reshape to topk primitive call [JAX] MoE: add apply_topk_weights_early flag (TE EP backend only) [JAX] MoE: stack wi_0/wi_1 on new axis (4D) instead of concat Signed-off-by: tdophung <tdophung@nvidia.com>
…y step paths. change tests to collapse in 1 bigger one with different parameters instead of smaller meaningless dtypes/shapes/finite chhecks Signed-off-by: tdophung <tdophung@nvidia.com>
…per-call ep_bootstrap allgathers a NCCL UID via the JAX runtime, which traces under jax.jit and fails with TracerArrayConversionError. Move the bootstrap to the test fixture (matching the test_multi_process_ep.py pattern from the TE EP JAX PR): caller invokes ep_bootstrap once per process, then calls record_ep_bootstrap_signature_for_moe with the same params. _moe_fwd_rule now only asserts that the recorded bootstrap signature is wide enough (num_experts/hidden_dim/ep_size exact match; per-call max_tokens_per_rank and recv_capacity_per_rank <= bootstrap values). Test mesh fixture bootstraps with the worst-case recv_pr across _CONFIGS so every parametrized config is compatible with a single per-process bootstrap.
The cpp_extensions/ep.py API (post the per-layer EpHandle refactor in e927903) expects an EpHandle object plus a separate handle_mem buffer for every dispatch/combine call. The MoE wrapper was still passing the raw slots_per_expert int as the second positional and unpacking ep_dispatch_fwd as a 3-tuple, which now blows up with "AttributeError: 'int' object has no attribute 'handle_id'". Changes: - Cache one EpHandle per (top_k, alignment) at module scope so repeated jit traces don't burn the NVTE_EP_HANDLE_CACHE_SIZE pool. - _moe_fwd_rule: mint/lookup the handle, call ep_prepare(topk_idx, handle) -> (token_counts, handle_mem), and pass (handle, handle_mem) into the fwd dispatch/combine calls. ep_dispatch_fwd now returns a 2-tuple. - _Ctx: stash handle_mem alongside handle so the bwd can hand both back to ep_combine_bwd and ep_dispatch_bwd. - _moe_bwd_rule: thread ctx.handle_mem into the bwd dispatch/combine calls.
te-ep-fixes plumbs NVTEEpGroupConfig.max_token_dtype through ep_bootstrap. Tests dispatch bf16 tokens; without this arg the group lands with the legacy kByte default (1 byte) and every dispatch aborts at the ep_backend.cpp:349 dtype check.
…ill fix for real in later commits Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile SummaryThis PR removes
Confidence Score: 5/5Safe to merge — only P2 findings; no correctness-breaking bugs in the hot path for the primary MXFP8 grouped-quantizer use case. All three flagged issues are P2: the custom_vjp re-creation is a performance concern, the wrong cotangent only affects subclasses with JAX-array leaves (DelayedScaleQuantizer) which may not be exercised in the current call sites, and the 128-alignment floor is a documentation/API-contract issue. No P0 or P1 defects were found. transformer_engine/jax/moe.py (lines 61-75, 648-650, 1069) warrants a second look for the three P2 issues. Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant C as Caller
participant M as _moe (custom_vjp)
participant G as Gate+TopK
participant D as ep_dispatch_fwd
participant F as _ffn_fwd_global
participant CB as ep_combine_fwd
C->>M: tokens, topk_idx, topk_weights, ffn_quantizer_set
M->>G: tokens
G-->>M: topk_idx, topk_weights
M->>D: handle, topk_idx, tokens, topk_weights
D-->>M: recv_tokens, recv_topk_weights, handle_mem, token_counts
M->>F: recv_tokens (grouped GEMM + quant)
F-->>M: expert_out
M->>CB: handle, handle_mem, token_counts, expert_out, recv_topk_weights
CB-->>M: combined_output
M-->>C: output, residuals
Note over M,CB: Backward mirrors each step:
Note over M,CB: _combine_bwd → _ffn_bwd_global → _dispatch_bwd
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant C as Caller
participant M as _moe (custom_vjp)
participant G as Gate+TopK
participant D as ep_dispatch_fwd
participant F as _ffn_fwd_global
participant CB as ep_combine_fwd
C->>M: tokens, topk_idx, topk_weights, ffn_quantizer_set
M->>G: tokens
G-->>M: topk_idx, topk_weights
M->>D: handle, topk_idx, tokens, topk_weights
D-->>M: recv_tokens, recv_topk_weights, handle_mem, token_counts
M->>F: recv_tokens (grouped GEMM + quant)
F-->>M: expert_out
M->>CB: handle, handle_mem, token_counts, expert_out, recv_topk_weights
CB-->>M: combined_output
M-->>C: output, residuals
Note over M,CB: Backward mirrors each step:
Note over M,CB: _combine_bwd → _ffn_bwd_global → _dispatch_bwd
Reviews (2): Last reviewed commit: "Use 2D output from grouped quant and sup..." | Re-trigger Greptile |
| ep_resource: str = None | ||
| pp_resource: str = None | ||
| cp_resource: str = None | ||
| ep_resource: str = None |
There was a problem hiding this comment.
The
MeshResource dataclass declares ep_resource twice — once at line 497 (the newly-inserted position before pp_resource) and again at line 500 (the old trailing position). Python silently handles this because __annotations__ is a dict and the second annotation is a no-op update to an already-present key, so the field ends up at the correct new position. However, the duplicate is misleading to readers and also duplicates the docstring entry (lines 482 and 485-490). The trailing declaration at line 500 should be removed.
| ep_resource: str = None | |
| pp_resource: str = None | |
| cp_resource: str = None | |
| ep_resource: str = None | |
| ep_resource: str = None | |
| pp_resource: str = None | |
| cp_resource: str = None |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| ep_resource: Axis name for expert parallelism (expert sharding), default is None | ||
| pp_resource: Axis name for pipeline parallelism (layer sharding), default is None | ||
| cp_resource: Axis name for context parallelism (sequence sharding), default is None | ||
| ep_resource: Axis name for expert parallelism (MoE expert sharding), default is None | ||
| ep_resource: Axis name for expert parallelism. Dispatch input tokens | ||
| must be sharded on their leading dim by ``ep_resource`` (alone or | ||
| compound with ``dp_resource`` / ``fsdp_resource`` as outer, e.g. | ||
| ``PartitionSpec(("dp", "ep"), None, None)``). Dispatch output | ||
| ``[ep_size, recv_capacity, H]`` is always sharded by ``ep_resource`` | ||
| on the leading ``ep_size`` dim. |
There was a problem hiding this comment.
The docstring for
MeshResource lists ep_resource twice (lines 482 and 485). The old one-liner should be removed so the expanded description is the only entry.
| ep_resource: Axis name for expert parallelism (expert sharding), default is None | |
| pp_resource: Axis name for pipeline parallelism (layer sharding), default is None | |
| cp_resource: Axis name for context parallelism (sequence sharding), default is None | |
| ep_resource: Axis name for expert parallelism (MoE expert sharding), default is None | |
| ep_resource: Axis name for expert parallelism. Dispatch input tokens | |
| must be sharded on their leading dim by ``ep_resource`` (alone or | |
| compound with ``dp_resource`` / ``fsdp_resource`` as outer, e.g. | |
| ``PartitionSpec(("dp", "ep"), None, None)``). Dispatch output | |
| ``[ep_size, recv_capacity, H]`` is always sharded by ``ep_resource`` | |
| on the leading ``ep_size`` dim. | |
| pp_resource: Axis name for pipeline parallelism (layer sharding), default is None | |
| cp_resource: Axis name for context parallelism (sequence sharding), default is None | |
| ep_resource: Axis name for expert parallelism. Dispatch input tokens | |
| must be sharded on their leading dim by ``ep_resource`` (alone or | |
| compound with ``dp_resource`` / ``fsdp_resource`` as outer, e.g. | |
| ``PartitionSpec(("dp", "ep"), None, None)``). Dispatch output | |
| ``[ep_size, recv_capacity, H]`` is always sharded by ``ep_resource`` | |
| on the leading ``ep_size`` dim. |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| def _get_or_make_ep_handle(top_k: int, dispatch_output_per_expert_alignment: int): | ||
| key = (int(top_k), int(dispatch_output_per_expert_alignment)) | ||
| h = _te_ep_handle_cache.get(key) | ||
| if h is None: | ||
| h = tex.ep_make_handle( | ||
| top_k=key[0], | ||
| dispatch_output_per_expert_alignment=key[1], | ||
| ) | ||
| _te_ep_handle_cache[key] = h | ||
| return h |
There was a problem hiding this comment.
Handle cache keyed only on config, not layer identity
_get_or_make_ep_handle uses (top_k, alignment) as the cache key, so two distinct MoE layers in the same model that happen to share the same top_k and alignment will receive the same EpHandle. The ep_make_handle docstring itself says "distinct layers must hold distinct handles." If the C++ EP backend stores any mutable per-layer routing state indexed by handle_id — rather than relying entirely on the per-call handle_mem buffer — these two layers would collide on that state and produce silent incorrect results or data corruption, especially during backward passes that interleave the two layers' ep_dispatch_bwd/ep_combine_bwd calls. Is the C++ EP backend fully stateless once handle_mem is allocated per call? If mutable per-layer routing state is kept under handle_id, two layers with identical (top_k, alignment) sharing a handle would corrupt each other. Could you clarify whether sharing handles across distinct layers is intentional and safe?
| if wi_0_bias is not None: | ||
| wi_0_bias = jnp.broadcast_to( | ||
| wi_0_bias.reshape(1, num_ep, num_local_experts, *wi_0_bias.shape[1:]), | ||
| (dp_size, num_ep, num_local_experts, *wi_0_bias.shape[1:]), | ||
| ).reshape(num_groups, *wi_0_bias.shape[1:]) | ||
| wi_1_bias = jnp.broadcast_to( | ||
| wi_1_bias.reshape(1, num_ep, num_local_experts, *wi_1_bias.shape[1:]), | ||
| (dp_size, num_ep, num_local_experts, *wi_1_bias.shape[1:]), | ||
| ).reshape(num_groups, *wi_1_bias.shape[1:]) | ||
| wo_bias = jnp.broadcast_to( | ||
| wo_bias.reshape(1, num_ep, num_local_experts, *wo_bias.shape[1:]), | ||
| (dp_size, num_ep, num_local_experts, *wo_bias.shape[1:]), | ||
| ).reshape(num_groups, *wo_bias.shape[1:]) | ||
| wi_0_bias = jax.lax.with_sharding_constraint(wi_0_bias, bias_sharding) | ||
| wi_1_bias = jax.lax.with_sharding_constraint(wi_1_bias, bias_sharding) | ||
| wo_bias = jax.lax.with_sharding_constraint(wo_bias, bias_sharding) |
There was a problem hiding this comment.
Unchecked bias consistency assumption
_ffn_fwd_global guards the bias reshape block on wi_0_bias is not None, then unconditionally accesses wi_1_bias.reshape(...) and wo_bias.reshape(...) inside that branch. If a caller passes only wi_0_bias without the other two, this raises AttributeError: 'NoneType' object has no attribute 'reshape' at trace time. The public moe() signature accepts all three as independently optional and contains no validation that they are all-or-nothing. Adding a consistency check at the moe() boundary would make the failure mode clear rather than crashing inside the FFN helper.
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: