Skip to content

[JAX] Remove shard_map from MoEBlock to support quant before FSDP AG using Grouped quant+GEMM custom partitioning rules#3131

Draft
jberchtold-nvidia wants to merge 38 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/quant-before-fsdp-ag-gmm
Draft

[JAX] Remove shard_map from MoEBlock to support quant before FSDP AG using Grouped quant+GEMM custom partitioning rules#3131
jberchtold-nvidia wants to merge 38 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/quant-before-fsdp-ag-gmm

Conversation

@jberchtold-nvidia

Copy link
Copy Markdown
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jberchtold-nvidia and others added 30 commits May 26, 2026 16:12
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…em_reloc gating

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rce at dispatch

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… static layer registration

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…s, MoE example)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ache

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
[JAX] MoE: soft re-pin inbound activations sharding at moe() entry
[JAX] MoE: scope gate_logits 2D reshape to topk primitive call
[JAX] MoE: add apply_topk_weights_early flag (TE EP backend only)
[JAX] MoE: stack wi_0/wi_1 on new axis (4D) instead of concat

Signed-off-by: tdophung <tdophung@nvidia.com>
…y step paths. change tests to collapse in 1 bigger one with different parameters instead of smaller meaningless dtypes/shapes/finite chhecks

Signed-off-by: tdophung <tdophung@nvidia.com>
…per-call

ep_bootstrap allgathers a NCCL UID via the JAX runtime, which traces under
jax.jit and fails with TracerArrayConversionError. Move the bootstrap to
the test fixture (matching the test_multi_process_ep.py pattern from the
TE EP JAX PR): caller invokes ep_bootstrap once per process, then calls
record_ep_bootstrap_signature_for_moe with the same params. _moe_fwd_rule
now only asserts that the recorded bootstrap signature is wide enough
(num_experts/hidden_dim/ep_size exact match; per-call max_tokens_per_rank
and recv_capacity_per_rank <= bootstrap values). Test mesh fixture
bootstraps with the worst-case recv_pr across _CONFIGS so every
parametrized config is compatible with a single per-process bootstrap.
tdophung and others added 7 commits June 3, 2026 16:55
The cpp_extensions/ep.py API (post the per-layer EpHandle refactor in
e927903) expects an EpHandle object plus a separate handle_mem buffer
for every dispatch/combine call. The MoE wrapper was still passing the
raw slots_per_expert int as the second positional and unpacking
ep_dispatch_fwd as a 3-tuple, which now blows up with
"AttributeError: 'int' object has no attribute 'handle_id'".

Changes:
- Cache one EpHandle per (top_k, alignment) at module scope so repeated
  jit traces don't burn the NVTE_EP_HANDLE_CACHE_SIZE pool.
- _moe_fwd_rule: mint/lookup the handle, call ep_prepare(topk_idx, handle)
  -> (token_counts, handle_mem), and pass (handle, handle_mem) into the
  fwd dispatch/combine calls. ep_dispatch_fwd now returns a 2-tuple.
- _Ctx: stash handle_mem alongside handle so the bwd can hand both back
  to ep_combine_bwd and ep_dispatch_bwd.
- _moe_bwd_rule: thread ctx.handle_mem into the bwd dispatch/combine
  calls.
te-ep-fixes plumbs NVTEEpGroupConfig.max_token_dtype through ep_bootstrap.
Tests dispatch bf16 tokens; without this arg the group lands with the
legacy kByte default (1 byte) and every dispatch aborts at the
ep_backend.cpp:349 dtype check.
…ill fix for real in later commits

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps

greptile-apps Bot commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR removes shard_map from MoEBlock and replaces it with NCCL-backed Expert Parallelism (EP) primitives (ep_dispatch / ep_combine) and JAX custom_partitioning rules on grouped GEMM and grouped quantization, enabling quantization before the FSDP all-gather.

  • transformer_engine/jax/moe.py is substantially rewritten: the old shard_map + ragged_all_to_all + Triton permutation path is replaced by a single custom_vjp (_moe) that calls into new ep_dispatch_fwd/bwd and ep_combine_fwd/bwd C++ primitives, with grouped GEMM and MXFP8 quantization via custom-partitioning.
  • transformer_engine/jax/cpp_extensions/ep.py (new) implements EpPreparePrimitive, EpDispatchPrimitive, EpCombinePrimitive, and their backward counterparts as JAX custom primitives with full partition / infer_sharding_from_operands support for SPMD.
  • transformer_engine/jax/ep.py (new) adds the public ep_bootstrap / ep_dispatch / ep_combine API, bootstrapping NCCL communicators via a cross-host UID allgather with a fallback path for problematic launchers.

Confidence Score: 5/5

Safe to merge — only P2 findings; no correctness-breaking bugs in the hot path for the primary MXFP8 grouped-quantizer use case.

All three flagged issues are P2: the custom_vjp re-creation is a performance concern, the wrong cotangent only affects subclasses with JAX-array leaves (DelayedScaleQuantizer) which may not be exercised in the current call sites, and the 128-alignment floor is a documentation/API-contract issue. No P0 or P1 defects were found.

transformer_engine/jax/moe.py (lines 61-75, 648-650, 1069) warrants a second look for the three P2 issues.

Important Files Changed

Filename Overview
transformer_engine/jax/moe.py Major rewrite replacing shard_map with NCCL EP primitives and custom_vjp; three P2 issues: custom_vjp re-created per call, wrong cotangent for quantizer pytree leaves, and hardcoded 128-token alignment floor.
transformer_engine/jax/cpp_extensions/ep.py New file implementing EpPrepare/Dispatch/Combine as JAX custom primitives with full partition and infer_sharding_from_operands support; logic looks correct.
transformer_engine/jax/ep.py New public EP API: ep_bootstrap (with NCCL UID allgather and fallback), ep_dispatch and ep_combine as custom_vjp wrappers; bootstrap validation guards are thorough.
transformer_engine/jax/sharding.py Adds spec_axes, filter_spec_axes, local_shape_from_spec, and merge_axis_specs helpers; with_sharding_constraint refactored to filter auto-axes only.
transformer_engine/jax/flax/moe.py Removes permutation_backend/expert_bias_init, adds ffn_quantizer_set/align_size/apply_topk_weights_early params; straightforward parameter-forwarding change.
transformer_engine/jax/cpp_extensions/gemm.py Adds GroupedGemmPrimitive._parse_partition_specs for EP/FSDP custom partitioning, handling RHS FSDP all-gather and EP axis propagation on weights.
transformer_engine/jax/cpp_extensions/quantization.py GroupedQuantizePrimitive gains infer_sharding_from_operands and partition; out_shape changed from flat prod to x_aval.shape to support ND outputs.
transformer_engine/jax/dense.py Removes kernel_fsdp_info parameter and _all_gather_kernel/_psum_scatter_kernel (superseded by custom partitioning); clean removal with no dangling references.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant C as Caller
    participant M as _moe (custom_vjp)
    participant G as Gate+TopK
    participant D as ep_dispatch_fwd
    participant F as _ffn_fwd_global
    participant CB as ep_combine_fwd

    C->>M: tokens, topk_idx, topk_weights, ffn_quantizer_set
    M->>G: tokens
    G-->>M: topk_idx, topk_weights
    M->>D: handle, topk_idx, tokens, topk_weights
    D-->>M: recv_tokens, recv_topk_weights, handle_mem, token_counts
    M->>F: recv_tokens (grouped GEMM + quant)
    F-->>M: expert_out
    M->>CB: handle, handle_mem, token_counts, expert_out, recv_topk_weights
    CB-->>M: combined_output
    M-->>C: output, residuals

    Note over M,CB: Backward mirrors each step:
    Note over M,CB: _combine_bwd → _ffn_bwd_global → _dispatch_bwd
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant C as Caller
    participant M as _moe (custom_vjp)
    participant G as Gate+TopK
    participant D as ep_dispatch_fwd
    participant F as _ffn_fwd_global
    participant CB as ep_combine_fwd

    C->>M: tokens, topk_idx, topk_weights, ffn_quantizer_set
    M->>G: tokens
    G-->>M: topk_idx, topk_weights
    M->>D: handle, topk_idx, tokens, topk_weights
    D-->>M: recv_tokens, recv_topk_weights, handle_mem, token_counts
    M->>F: recv_tokens (grouped GEMM + quant)
    F-->>M: expert_out
    M->>CB: handle, handle_mem, token_counts, expert_out, recv_topk_weights
    CB-->>M: combined_output
    M-->>C: output, residuals

    Note over M,CB: Backward mirrors each step:
    Note over M,CB: _combine_bwd → _ffn_bwd_global → _dispatch_bwd
Loading

Reviews (2): Last reviewed commit: "Use 2D output from grouped quant and sup..." | Re-trigger Greptile

Comment on lines +497 to 500
ep_resource: str = None
pp_resource: str = None
cp_resource: str = None
ep_resource: str = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The MeshResource dataclass declares ep_resource twice — once at line 497 (the newly-inserted position before pp_resource) and again at line 500 (the old trailing position). Python silently handles this because __annotations__ is a dict and the second annotation is a no-op update to an already-present key, so the field ends up at the correct new position. However, the duplicate is misleading to readers and also duplicates the docstring entry (lines 482 and 485-490). The trailing declaration at line 500 should be removed.

Suggested change
ep_resource: str = None
pp_resource: str = None
cp_resource: str = None
ep_resource: str = None
ep_resource: str = None
pp_resource: str = None
cp_resource: str = None

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +482 to +490
ep_resource: Axis name for expert parallelism (expert sharding), default is None
pp_resource: Axis name for pipeline parallelism (layer sharding), default is None
cp_resource: Axis name for context parallelism (sequence sharding), default is None
ep_resource: Axis name for expert parallelism (MoE expert sharding), default is None
ep_resource: Axis name for expert parallelism. Dispatch input tokens
must be sharded on their leading dim by ``ep_resource`` (alone or
compound with ``dp_resource`` / ``fsdp_resource`` as outer, e.g.
``PartitionSpec(("dp", "ep"), None, None)``). Dispatch output
``[ep_size, recv_capacity, H]`` is always sharded by ``ep_resource``
on the leading ``ep_size`` dim.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The docstring for MeshResource lists ep_resource twice (lines 482 and 485). The old one-liner should be removed so the expanded description is the only entry.

Suggested change
ep_resource: Axis name for expert parallelism (expert sharding), default is None
pp_resource: Axis name for pipeline parallelism (layer sharding), default is None
cp_resource: Axis name for context parallelism (sequence sharding), default is None
ep_resource: Axis name for expert parallelism (MoE expert sharding), default is None
ep_resource: Axis name for expert parallelism. Dispatch input tokens
must be sharded on their leading dim by ``ep_resource`` (alone or
compound with ``dp_resource`` / ``fsdp_resource`` as outer, e.g.
``PartitionSpec(("dp", "ep"), None, None)``). Dispatch output
``[ep_size, recv_capacity, H]`` is always sharded by ``ep_resource``
on the leading ``ep_size`` dim.
pp_resource: Axis name for pipeline parallelism (layer sharding), default is None
cp_resource: Axis name for context parallelism (sequence sharding), default is None
ep_resource: Axis name for expert parallelism. Dispatch input tokens
must be sharded on their leading dim by ``ep_resource`` (alone or
compound with ``dp_resource`` / ``fsdp_resource`` as outer, e.g.
``PartitionSpec(("dp", "ep"), None, None)``). Dispatch output
``[ep_size, recv_capacity, H]`` is always sharded by ``ep_resource``
on the leading ``ep_size`` dim.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +121 to +130
def _get_or_make_ep_handle(top_k: int, dispatch_output_per_expert_alignment: int):
key = (int(top_k), int(dispatch_output_per_expert_alignment))
h = _te_ep_handle_cache.get(key)
if h is None:
h = tex.ep_make_handle(
top_k=key[0],
dispatch_output_per_expert_alignment=key[1],
)
_te_ep_handle_cache[key] = h
return h

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Handle cache keyed only on config, not layer identity

_get_or_make_ep_handle uses (top_k, alignment) as the cache key, so two distinct MoE layers in the same model that happen to share the same top_k and alignment will receive the same EpHandle. The ep_make_handle docstring itself says "distinct layers must hold distinct handles." If the C++ EP backend stores any mutable per-layer routing state indexed by handle_id — rather than relying entirely on the per-call handle_mem buffer — these two layers would collide on that state and produce silent incorrect results or data corruption, especially during backward passes that interleave the two layers' ep_dispatch_bwd/ep_combine_bwd calls. Is the C++ EP backend fully stateless once handle_mem is allocated per call? If mutable per-layer routing state is kept under handle_id, two layers with identical (top_k, alignment) sharing a handle would corrupt each other. Could you clarify whether sharing handles across distinct layers is intentional and safe?

Comment on lines +351 to +366
if wi_0_bias is not None:
wi_0_bias = jnp.broadcast_to(
wi_0_bias.reshape(1, num_ep, num_local_experts, *wi_0_bias.shape[1:]),
(dp_size, num_ep, num_local_experts, *wi_0_bias.shape[1:]),
).reshape(num_groups, *wi_0_bias.shape[1:])
wi_1_bias = jnp.broadcast_to(
wi_1_bias.reshape(1, num_ep, num_local_experts, *wi_1_bias.shape[1:]),
(dp_size, num_ep, num_local_experts, *wi_1_bias.shape[1:]),
).reshape(num_groups, *wi_1_bias.shape[1:])
wo_bias = jnp.broadcast_to(
wo_bias.reshape(1, num_ep, num_local_experts, *wo_bias.shape[1:]),
(dp_size, num_ep, num_local_experts, *wo_bias.shape[1:]),
).reshape(num_groups, *wo_bias.shape[1:])
wi_0_bias = jax.lax.with_sharding_constraint(wi_0_bias, bias_sharding)
wi_1_bias = jax.lax.with_sharding_constraint(wi_1_bias, bias_sharding)
wo_bias = jax.lax.with_sharding_constraint(wo_bias, bias_sharding)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Unchecked bias consistency assumption

_ffn_fwd_global guards the bias reshape block on wi_0_bias is not None, then unconditionally accesses wi_1_bias.reshape(...) and wo_bias.reshape(...) inside that branch. If a caller passes only wi_0_bias without the other two, this raises AttributeError: 'NoneType' object has no attribute 'reshape' at trace time. The public moe() signature accepts all three as independently optional and contains no validation that they are all-or-nothing. Adding a consistency check at the moe() boundary would make the failure mode clear rather than crashing inside the FFN helper.

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft June 15, 2026 22:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants