From b14df81d33f3522bfecdaa9ad8570430642d0764 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 26 May 2026 16:12:16 -0700 Subject: [PATCH 01/36] Add grouped GMM custom partitioning rules --- tests/jax/test_grouped_gemm_partitioning.py | 248 +++++++++++ ..._multi_process_distributed_grouped_gemm.py | 146 +++++-- transformer_engine/jax/cpp_extensions/gemm.py | 392 ++++++++++++++++++ .../jax/cpp_extensions/quantization.py | 192 +++++++++ transformer_engine/jax/dense.py | 56 ++- transformer_engine/jax/flax/module.py | 2 +- transformer_engine/jax/sharding.py | 2 + 7 files changed, 989 insertions(+), 49 deletions(-) create mode 100644 tests/jax/test_grouped_gemm_partitioning.py diff --git a/tests/jax/test_grouped_gemm_partitioning.py b/tests/jax/test_grouped_gemm_partitioning.py new file mode 100644 index 0000000000..5fb93e30ae --- /dev/null +++ b/tests/jax/test_grouped_gemm_partitioning.py @@ -0,0 +1,248 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Partitioning tests for grouped quantize and grouped GEMM.""" + +from types import SimpleNamespace + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.cpp_extensions.gemm import GroupedGemmPrimitive +from transformer_engine.jax.cpp_extensions.quantization import GroupedQuantizePrimitive +from transformer_engine.jax.dense import grouped_dense +from transformer_engine.jax.quantize import QuantizeLayout, QuantizerFactory, ScalingMode +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +def _mesh(): + return Mesh(np.asarray(jax.devices()[:1]).reshape(1, 1), ("expert", "fsdp")) + + +def _arg_info(mesh, shape, spec): + return SimpleNamespace( + shape=shape, + ndim=len(shape), + size=int(np.prod(shape)), + sharding=NamedSharding(mesh, PartitionSpec(*spec)), + ) + + +def _normalize_spec(spec): + if isinstance(spec, PartitionSpec): + return tuple(spec) + return spec + + +def _mxfp8_grouped_quantizer_set(n_groups): + return QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=True, + n_groups=n_groups, + ) + + +def test_grouped_quantize_specs_preserve_ep_and_fsdp_for_block_scales(): + mesh = _mesh() + with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + _, _, out_shardings, _ = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 64), ("expert", None, "fsdp")), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (8,), ("expert",)), + ), + (), + ) + + specs = tuple(tuple(sharding.spec) for sharding in out_shardings) + assert _normalize_spec(specs[0]) == (("expert", "fsdp"),) + assert _normalize_spec(specs[2]) == (("expert", "fsdp"),) + assert _normalize_spec(specs[4]) == ("expert",) + + +def test_grouped_quantize_mxfp8_colwise_specs_preserve_ep_and_fsdp(): + mesh = _mesh() + with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + _, _, out_shardings, _ = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE_COLWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", None, "fsdp")), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (8,), ("expert",)), + ), + (), + ) + + specs = tuple(tuple(sharding.spec) for sharding in out_shardings) + assert _normalize_spec(specs[0]) == (("expert", "fsdp"),) + assert _normalize_spec(specs[1]) == (("expert", "fsdp"),) + assert _normalize_spec(specs[2]) == (("expert", "fsdp"),) + assert _normalize_spec(specs[3]) == (("expert", "fsdp"),) + assert _normalize_spec(specs[4]) == ("expert",) + + +def test_grouped_gemm_rhs_weight_specs_gather_fsdp_but_preserve_ep(): + mesh = _mesh() + arg_infos = ( + _arg_info(mesh, (8192,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (65536,), (("expert", "fsdp"),)), + _arg_info(mesh, (2048,), (("expert", "fsdp"),)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (1,), (None,)), + _arg_info(mesh, (0,), (None,)), + ) + with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + _, _, out_sharding, arg_shardings = GroupedGemmPrimitive.partition( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 1, + (1, 128, 64), + 128, + 64, + 128, + 64, + mesh, + arg_infos, + (), + ) + + assert tuple(arg_shardings[2].spec) == ("expert",) + assert tuple(arg_shardings[3].spec) == ("expert",) + assert tuple(out_sharding[0].spec) == (None, None, None) + + +def test_grouped_partitioning_shardy_rules_smoke(): + mesh = _mesh() + quantize_rule = GroupedQuantizePrimitive.shardy_sharding_rule( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + SimpleNamespace(shape=(8, 128, 64)), + SimpleNamespace(shape=(8,)), + SimpleNamespace(shape=(8,)), + ), + ( + SimpleNamespace(shape=(8 * 128 * 64,)), + SimpleNamespace(shape=(1,)), + SimpleNamespace(shape=(8 * 128 * 64,)), + SimpleNamespace(shape=(1,)), + SimpleNamespace(shape=(8,)), + ), + ) + gemm_rule = GroupedGemmPrimitive.shardy_sharding_rule( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 2, + (128, 64), + 128, + 64, + 128, + 64, + mesh, + tuple(SimpleNamespace(shape=(1,)) for _ in range(13)), + (SimpleNamespace(shape=(128, 64)),), + ) + + assert quantize_rule is not None + assert gemm_rule is not None + + +def test_grouped_dense_mxfp8_ep_fsdp_outside_shard_map_single_process(): + mesh = _mesh() + n_groups = 2 + group_tokens = 128 + hidden = 128 + out_hidden = 128 + x_shape = (n_groups * group_tokens, hidden) + w_shape = (n_groups, hidden, out_hidden) + + x_sharding = NamedSharding(mesh, PartitionSpec("expert", None)) + w_sharding = NamedSharding(mesh, PartitionSpec("expert", "fsdp", None)) + group_sharding = NamedSharding(mesh, PartitionSpec("expert")) + out_sharding = NamedSharding(mesh, PartitionSpec("expert", None)) + + quantizer_set = _mxfp8_grouped_quantizer_set(n_groups) + + with mesh, global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + x = jax.device_put( + jax.random.normal(jax.random.PRNGKey(20), x_shape, dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16), + x_sharding, + ) + w = jax.device_put( + jax.random.normal(jax.random.PRNGKey(21), w_shape, dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16), + w_sharding, + ) + group_sizes = jax.device_put( + jnp.full((n_groups,), group_tokens, dtype=jnp.int32), + group_sharding, + ) + + def apply_with_vjp(x, w, group_sizes): + def apply(x, w): + return grouped_dense( + x, + w, + group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + kernel_fsdp_info=("fsdp", 1), + ) + + out, vjp_fn = jax.vjp(apply, x, w) + dx, dw = vjp_fn(out) + return out, dx, dw + + out, dx, dw = jax.jit( + apply_with_vjp, + in_shardings=(x_sharding, w_sharding, group_sharding), + out_shardings=(out_sharding, x_sharding, w_sharding), + )(x, w, group_sizes) + out, dx, dw = jax.block_until_ready((out, dx, dw)) + + assert tuple(out.sharding.spec) == ("expert", None) + assert tuple(dx.sharding.spec) == ("expert", None) + assert tuple(dw.sharding.spec) == ("expert", "fsdp", None) + for value in (out, dx, dw): + local_value = np.asarray(jax.device_get(value.addressable_data(0))) + assert np.all(np.isfinite(local_value)) + assert np.any(local_value != 0.0) diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py index 94fed0859f..30a1452a07 100644 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -7,18 +7,34 @@ import jax import jax.numpy as jnp import jax.experimental.multihost_utils as jem +import numpy as np +from jax.experimental import shard_map +from jax.sharding import NamedSharding, PartitionSpec from transformer_engine.jax.dense import grouped_dense as te_grouped_dense from transformer_engine.jax.quantize import ( QuantizerFactory, ScalingMode, ) +from transformer_engine.jax.sharding import MeshResource, global_shard_guard from utils import assert_allclose, dtype_tols N_GROUP = 8 -MESH_AXIS_NAME = "fsdp" +EP_AXIS_NAME = "ep" +FSDP_AXIS_NAME = "fsdp" +MESH_AXIS_NAME = FSDP_AXIS_NAME + + +def _mxfp8_grouped_quantizer_set(n_groups): + return QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=True, + n_groups=n_groups, + ) def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis): @@ -36,13 +52,18 @@ def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis): def init_data(): x_key = jax.random.PRNGKey(0) w_key = jax.random.PRNGKey(1) - x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) - w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) - w_amax = jnp.max(jnp.abs(w), axis=range(1, w.ndim)) - return x, w, w, w_amax + x = ( + jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16) + ) + w = ( + jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16) + ) + return x, w, w - def test_func(outter_x, outter_w, outter_w_amax): - in_specs = (x_sharding.spec, w_sharding.spec, None) + def test_func(outter_x, outter_w): + in_specs = (x_sharding.spec, w_sharding.spec) out_specs = x_sharding.spec @partial( @@ -52,36 +73,29 @@ def test_func(outter_x, outter_w, outter_w_amax): out_specs=out_specs, check_rep=False, ) - def sharded_group_gemm(x, w, w_amax): + def sharded_group_gemm(x, w): group_size = x.shape[0] x_reshaped = x.reshape(-1, x.shape[-1]) n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, - is_2x2x=True, - n_groups=group_size, - ) + quantizer_set = _mxfp8_grouped_quantizer_set(group_size) output = te_grouped_dense( x_reshaped, w, n_groups, - kernel_amax=w_amax, quantizer_set=quantizer_set, kernel_fsdp_info=(MESH_AXIS_NAME, kernel_fsdp_axis), ) output = output.reshape(*x.shape[:-1], -1) return output - def run(x, w, w_amax): - output = sharded_group_gemm(x, w, w_amax) + def run(x, w): + output = sharded_group_gemm(x, w) return output - output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_w_amax) - dx, dw, _ = vjp_fn(output) + output, vjp_fn = jax.vjp(run, outter_x, outter_w) + dx, dw = vjp_fn(output) return output, dx, dw def ref_func(outter_x, outter_w): @@ -101,13 +115,7 @@ def sharded_group_gemm(x, w): x_reshaped = x.reshape(-1, x.shape[-1]) n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, - is_2x2x=True, - n_groups=group_size, - ) + quantizer_set = _mxfp8_grouped_quantizer_set(group_size) output = te_grouped_dense(x_reshaped, w, n_groups, quantizer_set=quantizer_set) output = output.reshape(*x.shape[:-1], -1) return output @@ -120,13 +128,13 @@ def run(x, w): dx, dw = vjp_fn(output) return output, dx, dw - init_func = jax.jit(init_data, out_shardings=(x_sharding, w_sharding, w_no_sharding, None)) - x, w, w_global, w_amax = init_func() + init_func = jax.jit(init_data, out_shardings=(x_sharding, w_sharding, w_no_sharding)) + x, w, w_global = init_func() o_sharding = x_sharding test_func_jitted = jax.jit( test_func, - in_shardings=(x_sharding, w_sharding, None), + in_shardings=(x_sharding, w_sharding), out_shardings=(o_sharding, x_sharding, w_sharding), ) ref_func_jitted = jax.jit( @@ -135,24 +143,76 @@ def run(x, w): out_shardings=(o_sharding, x_sharding, w_no_sharding), ) - out, dx, dw = test_func_jitted(x, w, w_amax) + out, dx, dw = test_func_jitted(x, w) ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global) e4m3_tols = dtype_tols(jnp.float8_e4m3fn) - e5m2_tols = dtype_tols(jnp.float8_e5m2) - out, ref_out = jem.process_allgather((out, ref_out)) - dx, ref_dx = jem.process_allgather((dx, ref_dx)) - dw, ref_dw = jem.process_allgather((dw, ref_dw)) + out, ref_out = jem.process_allgather((out, ref_out), tiled=True) + dx, ref_dx = jem.process_allgather((dx, ref_dx), tiled=True) + dw, ref_dw = jem.process_allgather((dw, ref_dw), tiled=True) + + assert_allclose(out, ref_out, **e4m3_tols) + assert_allclose(dx, ref_dx, **e4m3_tols) + assert_allclose(dw, ref_dw, **e4m3_tols) + + +def run_grouped_dense_mxfp8_ep_fsdp_outside_shard_map(): + n_groups = 4 + group_tokens = 128 + hidden = 256 + out_hidden = 128 + x_shape = (n_groups * group_tokens, hidden) + w_shape = (n_groups, hidden, out_hidden) + quantizer_set = _mxfp8_grouped_quantizer_set(n_groups) + + x_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME, None)) + w_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME, FSDP_AXIS_NAME, None)) + group_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME)) + out_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME, None)) + + with mesh, global_shard_guard( + MeshResource(ep_resource=EP_AXIS_NAME, fsdp_resource=FSDP_AXIS_NAME) + ): + x = jax.device_put( + jax.random.normal(jax.random.PRNGKey(20), x_shape, dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16), + x_sharding, + ) + w = jax.device_put( + jax.random.normal(jax.random.PRNGKey(21), w_shape, dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16), + w_sharding, + ) + group_sizes = jax.device_put( + jnp.full((n_groups,), group_tokens, dtype=jnp.int32), + group_sharding, + ) - jnp.allclose(out, ref_out, **e4m3_tols) - jnp.allclose(dx, ref_dx, **e5m2_tols) - jnp.allclose(dw, ref_dw, **e5m2_tols) + def apply(x, w, group_sizes): + return te_grouped_dense( + x, + w, + group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + kernel_fsdp_info=(FSDP_AXIS_NAME, 1), + ) + + out = jax.jit( + apply, + in_shardings=(x_sharding, w_sharding, group_sharding), + out_shardings=out_sharding, + )(x, w, group_sizes) + jax.block_until_ready(out) + + local_out = np.asarray(jax.device_get(out.addressable_data(0))) + assert tuple(out.sharding.spec) == (EP_AXIS_NAME, None) + assert np.all(np.isfinite(local_out)) + assert np.any(local_out != 0.0) if __name__ == "__main__": - from jax.sharding import NamedSharding, PartitionSpec - from jax.experimental import shard_map import sys coord_addr = sys.argv[1] @@ -163,10 +223,14 @@ def run(x, w): coordinator_address=coord_addr, num_processes=num_procs, process_id=proc_id ) - mesh = jax.make_mesh((num_procs,), (MESH_AXIS_NAME,)) + mesh = jax.make_mesh((num_procs,), (FSDP_AXIS_NAME,)) with mesh: data_shapes = [((4, 16, 128, 7168), (7168, 2048))] for data_shape in data_shapes: for kernel_fsdp_axis in [1, 2]: test_grouped_gemm_fp8_allgather(data_shape, kernel_fsdp_axis) + + if num_procs == 4: + mesh = jax.make_mesh((2, 2), (EP_AXIS_NAME, FSDP_AXIS_NAME)) + run_grouped_dense_mxfp8_ep_fsdp_outside_shard_map() diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4ff6d07986..b0d971ee25 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -211,6 +211,115 @@ def _get_nvfp4_tensor_scale_inv(amax): return amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX) +def _axis_spec_contains(axis_spec, axis): + if axis is None or axis_spec is None: + return False + if isinstance(axis_spec, tuple): + return axis in axis_spec + return axis_spec == axis + + +def _spec_contains_axis(spec, axis): + return any(_axis_spec_contains(axis_spec, axis) for axis_spec in spec) + + +def _strip_axis_from_axis_spec(axis_spec, axis): + if axis is None or axis_spec is None: + return axis_spec + if isinstance(axis_spec, tuple): + stripped = tuple(a for a in axis_spec if a != axis) + if len(stripped) == 0: + return None + return stripped[0] if len(stripped) == 1 else stripped + return None if axis_spec == axis else axis_spec + + +def _strip_axis_from_spec(spec, axis): + return tuple(_strip_axis_from_axis_spec(axis_spec, axis) for axis_spec in spec) + + +def _common_axis(spec_a, spec_b): + axes = [] + for spec in (spec_a, spec_b): + for axis_spec in spec: + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + axes.append(axis) + for axis in axes: + if _spec_contains_axis(spec_a, axis) and _spec_contains_axis(spec_b, axis): + return axis + return None + + +def _merge_axis_spec(axis_spec_a, axis_spec_b): + axes = [] + for axis_spec in (axis_spec_a, axis_spec_b): + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + axes.append(axis) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else tuple(axes) + + +def _partition_spec_from_result(mesh, result_info, fallback_spec): + if result_info is not None and result_info.sharding is not None: + return result_info.sharding + return NamedSharding(mesh, PartitionSpec(*fallback_spec)) + + +def _local_shape_from_spec(global_shape, spec, mesh): + local_shape = [] + for dim, axis_spec in zip(global_shape, spec): + axis_size = _axis_spec_size(axis_spec, mesh) + local_shape.append(dim // axis_size) + return tuple(local_shape) + + +def _axis_spec_size(axis_spec, mesh): + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axis_size = 1 + for axis in axis_tuple: + if axis is not None: + axis_size *= mesh.shape[axis] + return axis_size + + +def _spec_size(spec, mesh): + axis_size = 1 + for axis_spec in spec: + axis_size *= _axis_spec_size(axis_spec, mesh) + return axis_size + + +def _local_2d_sizes_from_spec(shape, spec, axis_boundary, left_size, right_size, mesh): + if len(shape) == len(spec) and len(shape) > 1: + local_shape = _local_shape_from_spec(shape, spec, mesh) + return ( + math.prod(local_shape[:axis_boundary]), + math.prod(local_shape[axis_boundary:]), + ) + + spec_size = _spec_size(spec, mesh) + if spec_size == 1: + return left_size, right_size + if left_size % spec_size == 0: + return left_size // spec_size, right_size + if right_size % spec_size == 0: + return left_size, right_size // spec_size + raise ValueError( + "Cannot derive local grouped GEMM 2D sizes from sharding spec. " + f"shape={shape}, spec={spec}, axis_boundary={axis_boundary}, " + f"left_size={left_size}, right_size={right_size}, spec_size={spec_size}" + ) + + def collective_gemm_bootstrap( num_total_devices, num_devices_per_process, @@ -1738,6 +1847,289 @@ def impl( ) return (out,) + @staticmethod + def _parse_partition_specs( + mesh, + arg_infos, + result_infos, + out_shape=None, + lhs_is_trans=None, + lhs_axis_boundary=None, + ): + del mesh + gsr = global_mesh_resource() + fsdp_axis = gsr.fsdp_resource + + lhs_data_spec = get_padded_spec(arg_infos[0]) + lhs_scale_spec = get_padded_spec(arg_infos[1]) + rhs_data_spec = get_padded_spec(arg_infos[2]) + rhs_scale_spec = get_padded_spec(arg_infos[3]) + bias_spec = get_padded_spec(arg_infos[4]) + + lhs_first_dims_spec = get_padded_spec(arg_infos[5]) + lhs_last_dims_spec = get_padded_spec(arg_infos[6]) + rhs_first_dims_spec = get_padded_spec(arg_infos[7]) + rhs_last_dims_spec = get_padded_spec(arg_infos[8]) + out_first_dims_spec = get_padded_spec(arg_infos[9]) + out_last_dims_spec = get_padded_spec(arg_infos[10]) + additional_arg_0_spec = get_padded_spec(arg_infos[11]) + additional_arg_1_spec = get_padded_spec(arg_infos[12]) + + grouped_dim_specs = ( + lhs_first_dims_spec, + lhs_last_dims_spec, + rhs_first_dims_spec, + rhs_last_dims_spec, + out_first_dims_spec, + out_last_dims_spec, + ) + grouped_dim_infos = arg_infos[5:11] + active_group_spec = next( + (spec for spec, info in zip(grouped_dim_specs, grouped_dim_infos) if info.size > 0), + (None,), + ) + if arg_infos[11].size > 1: + additional_arg_0_spec = active_group_spec + if arg_infos[12].size > 1: + additional_arg_1_spec = active_group_spec + + rhs_is_ragged = arg_infos[7].size > 0 or arg_infos[8].size > 0 + ep_axis = gsr.ep_resource + if ep_axis is not None and not rhs_is_ragged and _spec_contains_axis(active_group_spec, ep_axis): + if len(rhs_data_spec) > 0 and not _spec_contains_axis(rhs_data_spec, ep_axis): + rhs_data_spec = ( + _merge_axis_spec(rhs_data_spec[0], ep_axis), + *rhs_data_spec[1:], + ) + if len(rhs_scale_spec) > 0 and not _spec_contains_axis(rhs_scale_spec, ep_axis): + rhs_scale_spec = ( + _merge_axis_spec(rhs_scale_spec[0], ep_axis), + *rhs_scale_spec[1:], + ) + if len(bias_spec) > 0 and not _spec_contains_axis(bias_spec, ep_axis): + bias_spec = (_merge_axis_spec(bias_spec[0], ep_axis), *bias_spec[1:]) + + gather_rhs_fsdp = ( + fsdp_axis is not None + and not rhs_is_ragged + and ( + _spec_contains_axis(rhs_data_spec, fsdp_axis) + or _spec_contains_axis(rhs_scale_spec, fsdp_axis) + or _spec_contains_axis(bias_spec, fsdp_axis) + ) + ) + + if gather_rhs_fsdp: + rhs_data_spec = _strip_axis_from_spec(rhs_data_spec, fsdp_axis) + rhs_scale_spec = _strip_axis_from_spec(rhs_scale_spec, fsdp_axis) + bias_spec = _strip_axis_from_spec(bias_spec, fsdp_axis) + + reduce_axis = _common_axis(lhs_data_spec, rhs_data_spec) + if reduce_axis not in (gsr.dp_resource, gsr.fsdp_resource): + reduce_axis = None + if reduce_axis is not None and gather_rhs_fsdp: + reduce_axis = None + + if result_infos: + out_spec = get_padded_spec(result_infos[0]) + else: + out_spec = (None,) * (len(out_shape) if out_shape is not None else 1) + + if rhs_is_ragged and lhs_is_trans is not None and lhs_axis_boundary is not None: + lhs_non_contracting_dims = ( + range(lhs_axis_boundary, len(lhs_data_spec)) + if lhs_is_trans + else range(0, lhs_axis_boundary) + ) + lhs_data_spec = list(lhs_data_spec) + for out_idx, lhs_dim in enumerate(lhs_non_contracting_dims, start=1): + if out_idx < len(out_spec): + lhs_data_spec[lhs_dim] = _merge_axis_spec( + lhs_data_spec[lhs_dim], out_spec[out_idx] + ) + lhs_data_spec = tuple(lhs_data_spec) + + return ( + ( + lhs_data_spec, + lhs_scale_spec, + rhs_data_spec, + rhs_scale_spec, + bias_spec, + lhs_first_dims_spec, + lhs_last_dims_spec, + rhs_first_dims_spec, + rhs_last_dims_spec, + out_first_dims_spec, + out_last_dims_spec, + additional_arg_0_spec, + additional_arg_1_spec, + ), + out_spec, + reduce_axis, + ) + + @staticmethod + def partition( + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + use_async_d2h_group_sizes, + use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, + mesh, + arg_infos, + result_infos, + ): + arg_specs, out_spec, reduce_axis = GroupedGemmPrimitive._parse_partition_specs( + mesh, + arg_infos, + result_infos, + out_shape, + lhs_is_trans=lhs_is_trans, + lhs_axis_boundary=lhs_axis_boundary, + ) + arg_shardings = tuple(NamedSharding(mesh, PartitionSpec(*spec)) for spec in arg_specs) + result_info = result_infos[0] if result_infos else None + out_sharding = (_partition_spec_from_result(mesh, result_info, out_spec),) + local_out_shape = _local_shape_from_spec(out_shape, out_spec, mesh) + local_lhs_left_size, local_lhs_right_size = _local_2d_sizes_from_spec( + arg_infos[0].shape, + arg_specs[0], + lhs_axis_boundary, + lhs_left_size, + lhs_right_size, + mesh, + ) + local_rhs_left_size, local_rhs_right_size = _local_2d_sizes_from_spec( + arg_infos[2].shape, + arg_specs[2], + rhs_axis_boundary, + rhs_left_size, + rhs_right_size, + mesh, + ) + + def sharded_impl( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, + additional_arg_0, + additional_arg_1, + ): + (out,) = GroupedGemmPrimitive.impl( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, + additional_arg_0, + additional_arg_1, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode, + out_dtype=out_dtype, + has_bias=has_bias, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, + use_v2_ffi=use_v2_ffi, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + out_shape=local_out_shape, + lhs_left_size=local_lhs_left_size, + lhs_right_size=local_lhs_right_size, + rhs_left_size=local_rhs_left_size, + rhs_right_size=local_rhs_right_size, + ) + + if reduce_axis is not None: + if is_all_reduce_in_float32(): + out = jax.lax.psum(out.astype(jnp.float32), reduce_axis).astype(out_dtype) + else: + out = jax.lax.psum(out, reduce_axis) + return (out,) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule( + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + use_async_d2h_group_sizes, + use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, + mesh, + operand_types, + result_types, + ): + del ( + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + use_async_d2h_group_sizes, + use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, + mesh, + ) + + prefix = "GroupedGemm" + + def spec_for(name, rank): + if rank == 0: + return () + return tuple(f"{prefix}_{name}_{i}" for i in range(rank)) + + operand_mappings = tuple( + spec_for(f"arg{i}", len(operand_type.shape)) + for i, operand_type in enumerate(operand_types) + ) + result_mappings = tuple( + spec_for(f"out{i}", len(result_type.shape)) + for i, result_type in enumerate(result_types) + ) + return SdyShardingRule( + operand_mappings=operand_mappings, + result_mappings=result_mappings, + ) + register_primitive(GroupedGemmPrimitive) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 7138cfcf40..761933b3f9 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -32,6 +32,8 @@ all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp, get_num_devices_in_mesh, + global_mesh_resource, + lax_paral_op, ) from ..quantize import ( ScaledTensor2x, @@ -52,6 +54,59 @@ __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] +def _merge_axis_specs(axis_specs): + axes = [] + for axis_spec in axis_specs: + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + axes.append(axis) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else tuple(axes) + + +def _flat_data_spec(input_spec): + return (_merge_axis_specs(input_spec),) + + +def _axis_spec_size(axis_spec, mesh): + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axis_size = 1 + for axis in axis_tuple: + if axis is not None: + axis_size *= mesh.shape[axis] + return axis_size + + +def _local_shape_from_spec(global_shape, spec, mesh): + local_shape = [] + for dim, axis_spec in zip(global_shape, spec): + local_shape.append(dim // _axis_spec_size(axis_spec, mesh)) + return tuple(local_shape) + + +def _pad_or_slice_to_shape(x, target_shape): + if target_shape is None or x.shape == target_shape: + return x + target_size = math.prod(target_shape) + current_size = math.prod(x.shape) + x = x.reshape(-1) + if current_size > target_size: + return x[:target_size].reshape(target_shape) + return jnp.pad(x, (0, target_size - current_size)).reshape(target_shape) + + +def _all_reduce_grouped_amax_along_dp_fsdp(amax, mesh): + gsr = global_mesh_resource() + for axis in (gsr.dp_resource, gsr.fsdp_resource): + if axis is not None and axis in mesh.axis_names: + amax = lax_paral_op(amax, jax.lax.pmax, axis, mesh) + return amax + + class BaseDBiasQuantizePrimitive(BasePrimitive): """ Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias @@ -1236,6 +1291,143 @@ def impl( ) return rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax + @staticmethod + def _parse_partition_specs(scaling_mode, q_layout, mesh, arg_infos): + del mesh + x_spec = get_padded_spec(arg_infos[0]) + group_spec = get_padded_spec(arg_infos[2]) + if group_spec == (None,) and len(x_spec) > 0: + group_spec = (x_spec[0],) + flat_spec = _flat_data_spec(x_spec) + replicated_spec = (None,) + + rowwise_out_spec = flat_spec if q_layout.has_rowwise else replicated_spec + colwise_out_spec = flat_spec if q_layout.has_colwise else replicated_spec + + rowwise_scale_inv_spec = replicated_spec + colwise_scale_inv_spec = replicated_spec + if ScalingMode(scaling_mode).is_block_scaling: + rowwise_scale_inv_spec = flat_spec if q_layout.has_rowwise else replicated_spec + colwise_scale_inv_spec = flat_spec if q_layout.has_colwise else replicated_spec + elif ScalingMode(scaling_mode).is_tensor_scaling(): + rowwise_scale_inv_spec = group_spec if q_layout.has_rowwise else replicated_spec + colwise_scale_inv_spec = group_spec if q_layout.has_colwise else replicated_spec + + updated_amax_spec = group_spec + return ( + x_spec, + group_spec, + ( + rowwise_out_spec, + colwise_out_spec, + rowwise_scale_inv_spec, + colwise_scale_inv_spec, + updated_amax_spec, + ), + ) + + @staticmethod + def partition( + out_dtype, + scaling_mode, + q_layout, + flatten_axis, + scale_dtype, + mesh, + arg_infos, + result_infos, + ): + x_spec, group_spec, out_specs = GroupedQuantizePrimitive._parse_partition_specs( + scaling_mode, q_layout, mesh, arg_infos + ) + local_out_shapes = ( + tuple(_local_shape_from_spec(info.shape, spec, mesh) for info, spec in zip(result_infos, out_specs)) + if result_infos + else (None,) * len(out_specs) + ) + + arg_shardings = ( + NamedSharding(mesh, PartitionSpec(*x_spec)), + NamedSharding(mesh, PartitionSpec(*group_spec)), + NamedSharding(mesh, PartitionSpec(*group_spec)), + ) + out_shardings = tuple(NamedSharding(mesh, PartitionSpec(*spec)) for spec in out_specs) + + def sharded_impl(x, scale, group_sizes): + ( + rowwise_out, + colwise_out, + rowwise_scale_inv, + colwise_scale_inv, + updated_amax, + ) = GroupedQuantizePrimitive.impl( + x, + scale, + group_sizes, + out_dtype=out_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + flatten_axis=flatten_axis, + scale_dtype=scale_dtype, + ) + if ScalingMode(scaling_mode).is_block_scaling: + rowwise_scale_inv = _pad_or_slice_to_shape(rowwise_scale_inv, local_out_shapes[2]) + colwise_scale_inv = _pad_or_slice_to_shape(colwise_scale_inv, local_out_shapes[3]) + if ScalingMode(scaling_mode).is_tensor_scaling(): + updated_amax = _all_reduce_grouped_amax_along_dp_fsdp(updated_amax, mesh) + return ( + rowwise_out, + colwise_out, + rowwise_scale_inv, + colwise_scale_inv, + updated_amax, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule( + out_dtype, + scaling_mode, + q_layout, + flatten_axis, + scale_dtype, + mesh, + value_types, + result_types, + ): + del out_dtype, scale_dtype, mesh, result_types, flatten_axis + + prefix = "GroupedQuantize" + input_spec = tuple(f"{prefix}_x_{i}" for i in range(len(value_types[0].shape))) + flat_spec = (f"{prefix}_flat",) + group_spec = (BATCHING + f"{prefix}_group",) + scalar_spec = (BATCHING + f"{prefix}_scalar",) + + rowwise_out_spec = flat_spec if q_layout.has_rowwise else scalar_spec + colwise_out_spec = flat_spec if q_layout.has_colwise else scalar_spec + + if ScalingMode(scaling_mode).is_block_scaling: + rowwise_scale_spec = flat_spec if q_layout.has_rowwise else scalar_spec + colwise_scale_spec = flat_spec if q_layout.has_colwise else scalar_spec + elif ScalingMode(scaling_mode).is_tensor_scaling(): + rowwise_scale_spec = group_spec if q_layout.has_rowwise else scalar_spec + colwise_scale_spec = group_spec if q_layout.has_colwise else scalar_spec + else: + rowwise_scale_spec = scalar_spec + colwise_scale_spec = scalar_spec + + return SdyShardingRule( + operand_mappings=(input_spec, group_spec, group_spec), + result_mappings=( + rowwise_out_spec, + colwise_out_spec, + rowwise_scale_spec, + colwise_scale_spec, + group_spec, + ), + ) + register_primitive(GroupedQuantizePrimitive) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index f8c30ffccb..70151a44b7 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -14,9 +14,11 @@ import warnings import jax import jax.numpy as jnp +from jax.sharding import PartitionSpec from . import cpp_extensions as tex from .cpp_extensions.amax import AmaxScope +from .sharding import global_mesh_resource, with_sharding_constraint from .quantize import ( ScaledTensor, QuantizerSet, @@ -54,6 +56,10 @@ def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx): return kernel +def _is_manual_mesh_axis(mesh_axis): + return mesh_axis is not None and mesh_axis in jax.sharding.get_abstract_mesh().manual_axes + + def dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -349,6 +355,18 @@ def grouped_dense( Returns: A jnp.ndarray containing the result of the grouped linear operation """ + x_contracting_dims, kernel_contracting_dims = contracting_dims + x_contracting_dims = tex.sanitize_dims(x.ndim, x_contracting_dims) + kernel_contracting_dims = tex.sanitize_dims(kernel.ndim, kernel_contracting_dims) + contracting_dims = (x_contracting_dims, kernel_contracting_dims) + + restore_leading_ep_axis = False + if x.ndim == 3 and x.shape[0] == 1: + if x_contracting_dims == (x.ndim - 1,): + restore_leading_ep_axis = True + x = x.reshape(*x.shape[1:]) + contracting_dims = ((x.ndim - 1,), kernel_contracting_dims) + output = _grouped_dense( x, kernel, @@ -361,6 +379,8 @@ def grouped_dense( quantizer_set, kernel_fsdp_info, ) + if restore_leading_ep_axis: + output = output.reshape(1, *output.shape) return output @@ -406,10 +426,7 @@ def _grouped_dense_fwd_rule( ): use_bias = bias is not None - kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info - kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None - assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." - del kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx, kernel_fsdp_info, kernel_fsdp_enabled + del kernel_fsdp_info x_contracting_dims, k_contracting_dims = contracting_dims flatten_axis_x = -len(x_contracting_dims) @@ -478,9 +495,7 @@ def _grouped_dense_fwd_rule( def _grouped_dense_bwd_rule( contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad ): - kernel_fsdp_mesh_axis, _ = kernel_fsdp_info - kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None - assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." + kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims @@ -530,6 +545,14 @@ def _grouped_dense_bwd_rule( preferred_element_type=preferred_element_type, group_offset=group_offset, ) + if _is_manual_mesh_axis(kernel_fsdp_mesh_axis): + if kernel_fsdp_axis_idx in fwd_k_contracting_dims: + dgrad_axis_idx = fwd_x_contracting_dims[ + fwd_k_contracting_dims.index(kernel_fsdp_axis_idx) + ] + dgrad = _all_gather_kernel(dgrad, kernel_fsdp_mesh_axis, dgrad_axis_idx) + else: + dgrad = jax.lax.psum(dgrad, kernel_fsdp_mesh_axis) wgrad = tex.grouped_gemm( wgrad_x_T, @@ -539,6 +562,25 @@ def _grouped_dense_bwd_rule( preferred_element_type=preferred_element_type, group_offset=group_offset, ) + if _is_manual_mesh_axis(kernel_fsdp_mesh_axis): + if kernel_fsdp_axis_idx in fwd_k_contracting_dims: + wgrad = _psum_scatter_kernel( + wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx + ) + else: + wgrad = jax.lax.psum(wgrad, kernel_fsdp_mesh_axis) + if kernel_fsdp_mesh_axis is not None: + wgrad_spec = [None] * len(kernel_shape) + ep_resource = None + try: + ep_resource = global_mesh_resource().ep_resource + except AssertionError: + pass + if len(wgrad_spec) > 0: + wgrad_spec[0] = ep_resource + if 0 <= kernel_fsdp_axis_idx < len(wgrad_spec): + wgrad_spec[kernel_fsdp_axis_idx] = kernel_fsdp_mesh_axis + wgrad = with_sharding_constraint(wgrad, PartitionSpec(*wgrad_spec)) group_sizes_grad = None dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 17c9a242f0..14783ecbe2 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1471,7 +1471,7 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa x, kernel, group_sizes=group_sizes, - contracting_dims=((1,), (1,)), + contracting_dims=((-1,), (1,)), quantizer_set=quantizer_set, ) return out diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 9b13412c14..16a10c860d 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -330,6 +330,7 @@ class MeshResource: tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None tpsp_resource: Axis name for tensor sequence parallelism (hidden and sequence sharding), default is None fsdp_resource: Axis name for full-sharded data parallelism, default is None + ep_resource: Axis name for expert parallelism (expert sharding), default is None pp_resource: Axis name for pipeline parallelism (layer sharding), default is None cp_resource: Axis name for context parallelism (sequence sharding), default is None """ @@ -338,6 +339,7 @@ class MeshResource: tp_resource: str = None tpsp_resource: str = None fsdp_resource: str = None + ep_resource: str = None pp_resource: str = None cp_resource: str = None From 60a0b50400b2309c27bd8b3d8566115004176bcc Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 26 May 2026 16:13:53 -0700 Subject: [PATCH 02/36] Add outside shard_map grouped GMM backward test --- ..._multi_process_distributed_grouped_gemm.py | 42 +++++++++++-------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py index 30a1452a07..de5e031566 100644 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -189,27 +189,35 @@ def run_grouped_dense_mxfp8_ep_fsdp_outside_shard_map(): group_sharding, ) - def apply(x, w, group_sizes): - return te_grouped_dense( - x, - w, - group_sizes, - contracting_dims=((1,), (1,)), - quantizer_set=quantizer_set, - kernel_fsdp_info=(FSDP_AXIS_NAME, 1), - ) - - out = jax.jit( - apply, + def apply_with_vjp(x, w, group_sizes): + def apply(x, w): + return te_grouped_dense( + x, + w, + group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + kernel_fsdp_info=(FSDP_AXIS_NAME, 1), + ) + + out, vjp_fn = jax.vjp(apply, x, w) + dx, dw = vjp_fn(out) + return out, dx, dw + + out, dx, dw = jax.jit( + apply_with_vjp, in_shardings=(x_sharding, w_sharding, group_sharding), - out_shardings=out_sharding, + out_shardings=(out_sharding, x_sharding, w_sharding), )(x, w, group_sizes) - jax.block_until_ready(out) + out, dx, dw = jax.block_until_ready((out, dx, dw)) - local_out = np.asarray(jax.device_get(out.addressable_data(0))) assert tuple(out.sharding.spec) == (EP_AXIS_NAME, None) - assert np.all(np.isfinite(local_out)) - assert np.any(local_out != 0.0) + assert tuple(dx.sharding.spec) == (EP_AXIS_NAME, None) + assert tuple(dw.sharding.spec) == (EP_AXIS_NAME, FSDP_AXIS_NAME, None) + for value in (out, dx, dw): + local_value = np.asarray(jax.device_get(value.addressable_data(0))) + assert np.all(np.isfinite(local_value)) + assert np.any(local_value != 0.0) if __name__ == "__main__": From 786fa1d961de8d5ae7acac09b544241adda9dc7d Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 28 May 2026 16:07:07 -0700 Subject: [PATCH 03/36] Use 4 GPU mesh for grouped GEMM partitioning test --- tests/jax/test_grouped_gemm_partitioning.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/jax/test_grouped_gemm_partitioning.py b/tests/jax/test_grouped_gemm_partitioning.py index 5fb93e30ae..d69f0c29b8 100644 --- a/tests/jax/test_grouped_gemm_partitioning.py +++ b/tests/jax/test_grouped_gemm_partitioning.py @@ -8,6 +8,7 @@ import jax import jax.numpy as jnp import numpy as np +import pytest from jax.sharding import Mesh, NamedSharding, PartitionSpec from transformer_engine.jax.cpp_extensions.gemm import GroupedGemmPrimitive @@ -18,7 +19,10 @@ def _mesh(): - return Mesh(np.asarray(jax.devices()[:1]).reshape(1, 1), ("expert", "fsdp")) + devices = jax.devices() + if len(devices) < 4: + pytest.skip("Grouped GEMM partitioning tests require at least 4 visible GPUs.") + return Mesh(np.asarray(devices[:4]).reshape(2, 2), ("expert", "fsdp")) def _arg_info(mesh, shape, spec): @@ -187,9 +191,9 @@ def test_grouped_partitioning_shardy_rules_smoke(): def test_grouped_dense_mxfp8_ep_fsdp_outside_shard_map_single_process(): mesh = _mesh() - n_groups = 2 + n_groups = 4 group_tokens = 128 - hidden = 128 + hidden = 256 out_hidden = 128 x_shape = (n_groups * group_tokens, hidden) w_shape = (n_groups, hidden, out_hidden) From ff0407dada98fed561347146162720fb94d4c51f Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 1 Jun 2026 08:46:14 -0700 Subject: [PATCH 04/36] progress Signed-off-by: Jeremy Berchtold --- tests/jax/test_grouped_gemm_partitioning.py | 242 +++++++++++++++++- ..._multi_process_distributed_grouped_gemm.py | 78 ++++-- transformer_engine/jax/cpp_extensions/gemm.py | 73 ++++-- .../jax/cpp_extensions/quantization.py | 45 +++- transformer_engine/jax/dense.py | 69 ++++- transformer_engine/jax/sharding.py | 26 +- 6 files changed, 448 insertions(+), 85 deletions(-) diff --git a/tests/jax/test_grouped_gemm_partitioning.py b/tests/jax/test_grouped_gemm_partitioning.py index d69f0c29b8..831e245c0e 100644 --- a/tests/jax/test_grouped_gemm_partitioning.py +++ b/tests/jax/test_grouped_gemm_partitioning.py @@ -25,6 +25,23 @@ def _mesh(): return Mesh(np.asarray(devices[:4]).reshape(2, 2), ("expert", "fsdp")) +def _mesh_with_dp_tp(): + devices = jax.devices() + if len(devices) < 4: + pytest.skip("Grouped GEMM partitioning tests require at least 4 visible GPUs.") + return Mesh(np.asarray(devices[:4]).reshape(2, 1, 2, 1), ("expert", "dp", "fsdp", "tp")) + + +def _mesh_with_arbitrary_axis(): + devices = jax.devices() + if len(devices) < 4: + pytest.skip("Grouped GEMM partitioning tests require at least 4 visible GPUs.") + return Mesh( + np.asarray(devices[:4]).reshape(2, 1, 2, 1), + ("expert", "dp", "fsdp", "myaxis123"), + ) + + def _arg_info(mesh, shape, spec): return SimpleNamespace( shape=shape, @@ -40,6 +57,14 @@ def _normalize_spec(spec): return spec +def _spec_contains_axis(spec, axis): + for axis_spec in spec: + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + if axis in axis_tuple: + return True + return False + + def _mxfp8_grouped_quantizer_set(n_groups): return QuantizerFactory.create_set( scaling_mode=ScalingMode.MXFP8_1D_SCALING, @@ -50,10 +75,10 @@ def _mxfp8_grouped_quantizer_set(n_groups): ) -def test_grouped_quantize_specs_preserve_ep_and_fsdp_for_block_scales(): +def test_grouped_quantize_gathers_hidden_axis_for_block_scales(): mesh = _mesh() with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): - _, _, out_shardings, _ = GroupedQuantizePrimitive.partition( + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( jnp.float8_e4m3fn, ScalingMode.MXFP8_1D_SCALING.value, QuantizeLayout.ROWWISE, @@ -68,16 +93,17 @@ def test_grouped_quantize_specs_preserve_ep_and_fsdp_for_block_scales(): (), ) + assert tuple(arg_shardings[0].spec) == ("expert", None, None) specs = tuple(tuple(sharding.spec) for sharding in out_shardings) - assert _normalize_spec(specs[0]) == (("expert", "fsdp"),) - assert _normalize_spec(specs[2]) == (("expert", "fsdp"),) + assert _normalize_spec(specs[0]) == ("expert",) + assert _normalize_spec(specs[2]) == ("expert",) assert _normalize_spec(specs[4]) == ("expert",) -def test_grouped_quantize_mxfp8_colwise_specs_preserve_ep_and_fsdp(): +def test_grouped_quantize_mxfp8_colwise_specs_gather_hidden_axis(): mesh = _mesh() with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): - _, _, out_shardings, _ = GroupedQuantizePrimitive.partition( + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( jnp.float8_e4m3fn, ScalingMode.MXFP8_1D_SCALING.value, QuantizeLayout.ROWWISE_COLWISE, @@ -92,14 +118,47 @@ def test_grouped_quantize_mxfp8_colwise_specs_preserve_ep_and_fsdp(): (), ) + assert tuple(arg_shardings[0].spec) == ("expert", None, None) specs = tuple(tuple(sharding.spec) for sharding in out_shardings) - assert _normalize_spec(specs[0]) == (("expert", "fsdp"),) - assert _normalize_spec(specs[1]) == (("expert", "fsdp"),) - assert _normalize_spec(specs[2]) == (("expert", "fsdp"),) - assert _normalize_spec(specs[3]) == (("expert", "fsdp"),) + assert _normalize_spec(specs[0]) == ("expert",) + assert _normalize_spec(specs[1]) == ("expert",) + assert _normalize_spec(specs[2]) == ("expert",) + assert _normalize_spec(specs[3]) == ("expert",) assert _normalize_spec(specs[4]) == ("expert",) +def test_grouped_quantize_strips_unsupported_axes_and_gathers_hidden_axes(): + mesh = _mesh_with_dp_tp() + with jax.set_mesh(mesh), global_shard_guard( + MeshResource(dp_resource="dp", tp_resource="tp", fsdp_resource="fsdp", ep_resource="expert") + ): + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", "dp", ("fsdp", "tp"))), + _arg_info(mesh, (8,), (("expert", "tp"),)), + _arg_info(mesh, (8,), (("expert", "tp"),)), + ), + (), + ) + + assert tuple(arg_shardings[0].spec) == ("expert", None, None) + assert tuple(arg_shardings[1].spec) == ("expert",) + assert tuple(arg_shardings[2].spec) == ("expert",) + + out_specs = tuple(tuple(sharding.spec) for sharding in out_shardings) + assert _normalize_spec(out_specs[0]) == ("expert",) + assert _normalize_spec(out_specs[2]) == ("expert",) + assert _normalize_spec(out_specs[4]) == ("expert",) + for spec in (*out_specs, *(tuple(sharding.spec) for sharding in arg_shardings)): + assert not _spec_contains_axis(spec, "tp") + + def test_grouped_gemm_rhs_weight_specs_gather_fsdp_but_preserve_ep(): mesh = _mesh() arg_infos = ( @@ -143,6 +202,169 @@ def test_grouped_gemm_rhs_weight_specs_gather_fsdp_but_preserve_ep(): assert tuple(out_sharding[0].spec) == (None, None, None) +def test_grouped_gemm_strips_unsupported_axes_preserves_dp_and_gathers_rhs_fsdp(): + mesh = _mesh_with_dp_tp() + arg_infos = ( + _arg_info(mesh, (8192,), (("dp", "tp"),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (65536,), (("expert", "fsdp", "tp"),)), + _arg_info(mesh, (2048,), (("expert", "fsdp", "tp"),)), + _arg_info(mesh, (0,), (("fsdp", "tp"),)), + _arg_info(mesh, (8,), (("expert", "tp"),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (8,), (("expert", "tp"),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (1,), (("tp",),)), + _arg_info(mesh, (0,), (("tp",),)), + ) + result_infos = (_arg_info(mesh, (1, 128, 64), ("expert", "tp", None)),) + with jax.set_mesh(mesh), global_shard_guard( + MeshResource(dp_resource="dp", tp_resource="tp", fsdp_resource="fsdp", ep_resource="expert") + ): + _, _, out_sharding, arg_shardings = GroupedGemmPrimitive.partition( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 1, + (1, 128, 64), + 128, + 64, + 128, + 64, + mesh, + arg_infos, + result_infos, + ) + + assert tuple(arg_shardings[0].spec) == ("dp",) + assert tuple(arg_shardings[2].spec) == ("expert",) + assert tuple(arg_shardings[3].spec) == ("expert",) + assert tuple(arg_shardings[5].spec) == ("expert",) + assert tuple(out_sharding[0].spec) == ("expert", None, None) + for spec in ( + *(tuple(sharding.spec) for sharding in arg_shardings), + tuple(out_sharding[0].spec), + ): + assert not _spec_contains_axis(spec, "tp") + + +def test_grouped_gemm_reduce_axis_skips_ep_and_uses_dp(): + mesh = _mesh_with_dp_tp() + arg_infos = ( + _arg_info(mesh, (8192,), (("expert", "dp"),)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8192,), (("expert", "dp"),)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (1,), (None,)), + _arg_info(mesh, (0,), (None,)), + ) + + with jax.set_mesh(mesh), global_shard_guard( + MeshResource(dp_resource="dp", fsdp_resource="fsdp", ep_resource="expert") + ): + _, _, reduce_axis = GroupedGemmPrimitive._parse_partition_specs( + mesh, + arg_infos, + (), + out_shape=(1, 128, 64), + lhs_is_trans=False, + lhs_axis_boundary=1, + ) + + assert reduce_axis == "dp" + + +def test_grouped_partitioning_strips_arbitrary_unsupported_axis(): + mesh = _mesh_with_arbitrary_axis() + mesh_resource = MeshResource(dp_resource="dp", fsdp_resource="fsdp", ep_resource="expert") + + with jax.set_mesh(mesh), global_shard_guard(mesh_resource): + _, _, quantize_out_shardings, quantize_arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", "myaxis123", ("dp", "fsdp"))), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + ), + (), + ) + + gemm_arg_infos = ( + _arg_info(mesh, (8192,), (("dp", "myaxis123"),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (65536,), (("expert", "fsdp", "myaxis123"),)), + _arg_info(mesh, (2048,), (("expert", "fsdp", "myaxis123"),)), + _arg_info(mesh, (0,), (("fsdp", "myaxis123"),)), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (1,), (("myaxis123",),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + ) + gemm_result_infos = (_arg_info(mesh, (1, 128, 64), ("expert", "myaxis123", None)),) + _, _, gemm_out_sharding, gemm_arg_shardings = GroupedGemmPrimitive.partition( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 1, + (1, 128, 64), + 128, + 64, + 128, + 64, + mesh, + gemm_arg_infos, + gemm_result_infos, + ) + + assert tuple(quantize_arg_shardings[0].spec) == ("expert", None, None) + assert tuple(quantize_arg_shardings[1].spec) == ("expert",) + quantize_out_specs = tuple(tuple(sharding.spec) for sharding in quantize_out_shardings) + assert _normalize_spec(quantize_out_specs[0]) == ("expert",) + assert _normalize_spec(quantize_out_specs[2]) == ("expert",) + + assert tuple(gemm_arg_shardings[0].spec) == ("dp",) + assert tuple(gemm_arg_shardings[2].spec) == ("expert",) + assert tuple(gemm_arg_shardings[3].spec) == ("expert",) + assert tuple(gemm_out_sharding[0].spec) == ("expert", None, None) + + all_specs = ( + *quantize_out_specs, + *(tuple(sharding.spec) for sharding in quantize_arg_shardings), + *(tuple(sharding.spec) for sharding in gemm_arg_shardings), + tuple(gemm_out_sharding[0].spec), + ) + for spec in all_specs: + assert not _spec_contains_axis(spec, "myaxis123") + + def test_grouped_partitioning_shardy_rules_smoke(): mesh = _mesh() quantize_rule = GroupedQuantizePrimitive.shardy_sharding_rule( diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py index de5e031566..ce52126dea 100644 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -47,11 +47,18 @@ def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis): if kernel_fsdp_axis == 2 else NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None)) ) + b_sharding = ( + NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME)) + if kernel_fsdp_axis == 2 + else NamedSharding(mesh, PartitionSpec(None, None)) + ) w_no_sharding = NamedSharding(mesh, PartitionSpec(None, None, None)) + b_no_sharding = NamedSharding(mesh, PartitionSpec(None, None)) def init_data(): x_key = jax.random.PRNGKey(0) w_key = jax.random.PRNGKey(1) + b_key = jax.random.PRNGKey(2) x = ( jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) * jnp.asarray(0.01, dtype=jnp.bfloat16) @@ -60,10 +67,14 @@ def init_data(): jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) * jnp.asarray(0.01, dtype=jnp.bfloat16) ) - return x, w, w + b = ( + jax.random.normal(b_key, shape=(N_GROUP, w_shape[-1]), dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16) + ) + return x, w, w, b, b - def test_func(outter_x, outter_w): - in_specs = (x_sharding.spec, w_sharding.spec) + def test_func(outter_x, outter_w, outter_b): + in_specs = (x_sharding.spec, w_sharding.spec, b_sharding.spec) out_specs = x_sharding.spec @partial( @@ -73,7 +84,7 @@ def test_func(outter_x, outter_w): out_specs=out_specs, check_rep=False, ) - def sharded_group_gemm(x, w): + def sharded_group_gemm(x, w, b): group_size = x.shape[0] x_reshaped = x.reshape(-1, x.shape[-1]) n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) @@ -84,23 +95,24 @@ def sharded_group_gemm(x, w): x_reshaped, w, n_groups, + bias=b, quantizer_set=quantizer_set, kernel_fsdp_info=(MESH_AXIS_NAME, kernel_fsdp_axis), ) output = output.reshape(*x.shape[:-1], -1) return output - def run(x, w): - output = sharded_group_gemm(x, w) + def run(x, w, b): + output = sharded_group_gemm(x, w, b) return output - output, vjp_fn = jax.vjp(run, outter_x, outter_w) - dx, dw = vjp_fn(output) - return output, dx, dw + output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_b) + dx, dw, db = vjp_fn(output) + return output, dx, dw, db - def ref_func(outter_x, outter_w): + def ref_func(outter_x, outter_w, outter_b): - in_specs = (x_sharding.spec, w_no_sharding.spec) + in_specs = (x_sharding.spec, w_no_sharding.spec, b_no_sharding.spec) out_specs = x_sharding.spec @partial( @@ -110,51 +122,63 @@ def ref_func(outter_x, outter_w): out_specs=out_specs, check_rep=False, ) - def sharded_group_gemm(x, w): + def sharded_group_gemm(x, w, b): group_size = x.shape[0] x_reshaped = x.reshape(-1, x.shape[-1]) n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) quantizer_set = _mxfp8_grouped_quantizer_set(group_size) - output = te_grouped_dense(x_reshaped, w, n_groups, quantizer_set=quantizer_set) + output = te_grouped_dense( + x_reshaped, + w, + n_groups, + bias=b, + quantizer_set=quantizer_set, + ) output = output.reshape(*x.shape[:-1], -1) return output - def run(x, w): - output = sharded_group_gemm(x, w) + def run(x, w, b): + output = sharded_group_gemm(x, w, b) return output - output, vjp_fn = jax.vjp(run, outter_x, outter_w) - dx, dw = vjp_fn(output) - return output, dx, dw + output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_b) + dx, dw, db = vjp_fn(output) + return output, dx, dw, db - init_func = jax.jit(init_data, out_shardings=(x_sharding, w_sharding, w_no_sharding)) - x, w, w_global = init_func() + init_func = jax.jit( + init_data, + out_shardings=(x_sharding, w_sharding, w_no_sharding, b_sharding, b_no_sharding), + ) + x, w, w_global, b, b_global = init_func() o_sharding = x_sharding test_func_jitted = jax.jit( test_func, - in_shardings=(x_sharding, w_sharding), - out_shardings=(o_sharding, x_sharding, w_sharding), + in_shardings=(x_sharding, w_sharding, b_sharding), + out_shardings=(o_sharding, x_sharding, w_sharding, b_sharding), ) ref_func_jitted = jax.jit( ref_func, - in_shardings=(x_sharding, w_no_sharding), - out_shardings=(o_sharding, x_sharding, w_no_sharding), + in_shardings=(x_sharding, w_no_sharding, b_no_sharding), + out_shardings=(o_sharding, x_sharding, w_no_sharding, b_no_sharding), ) - out, dx, dw = test_func_jitted(x, w) - ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global) + out, dx, dw, db = test_func_jitted(x, w, b) + ref_out, ref_dx, ref_dw, ref_db = ref_func_jitted(x, w_global, b_global) - e4m3_tols = dtype_tols(jnp.float8_e4m3fn) + # Avoid creating a host scalar JAX array under the multi-process mesh in dtype_tols. + e4m3_tols = dtype_tols(jnp.float8_e4m3fn, rtol=0.25, atol=0.25) out, ref_out = jem.process_allgather((out, ref_out), tiled=True) dx, ref_dx = jem.process_allgather((dx, ref_dx), tiled=True) dw, ref_dw = jem.process_allgather((dw, ref_dw), tiled=True) + db, ref_db = jem.process_allgather((db, ref_db), tiled=True) assert_allclose(out, ref_out, **e4m3_tols) assert_allclose(dx, ref_dx, **e4m3_tols) assert_allclose(dw, ref_dw, **e4m3_tols) + assert_allclose(db, ref_db, **e4m3_tols) def run_grouped_dense_mxfp8_ep_fsdp_outside_shard_map(): diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index b0d971ee25..61521873d2 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -238,7 +238,30 @@ def _strip_axis_from_spec(spec, axis): return tuple(_strip_axis_from_axis_spec(axis_spec, axis) for axis_spec in spec) -def _common_axis(spec_a, spec_b): +def _filter_axis_spec(axis_spec, allowed_axes): + if axis_spec is None: + return None + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axes = tuple(axis for axis in axis_tuple if axis in allowed_axes) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else axes + + +def _filter_spec_axes(spec, allowed_axes): + return tuple(_filter_axis_spec(axis_spec, allowed_axes) for axis_spec in spec) + + +def _supported_grouped_gemm_axes(mesh): + gsr = global_mesh_resource(validate=False) + return { + axis + for axis in (gsr.ep_resource, gsr.dp_resource, gsr.fsdp_resource) + if axis is not None and axis in mesh.axis_names + } + + +def _common_axis(spec_a, spec_b, allowed_axes=None): axes = [] for spec in (spec_a, spec_b): for axis_spec in spec: @@ -249,6 +272,8 @@ def _common_axis(spec_a, spec_b): if axis is not None and axis not in axes: axes.append(axis) for axis in axes: + if allowed_axes is not None and axis not in allowed_axes: + continue if _spec_contains_axis(spec_a, axis) and _spec_contains_axis(spec_b, axis): return axis return None @@ -1856,24 +1881,24 @@ def _parse_partition_specs( lhs_is_trans=None, lhs_axis_boundary=None, ): - del mesh - gsr = global_mesh_resource() + gsr = global_mesh_resource(validate=False) fsdp_axis = gsr.fsdp_resource - - lhs_data_spec = get_padded_spec(arg_infos[0]) - lhs_scale_spec = get_padded_spec(arg_infos[1]) - rhs_data_spec = get_padded_spec(arg_infos[2]) - rhs_scale_spec = get_padded_spec(arg_infos[3]) - bias_spec = get_padded_spec(arg_infos[4]) - - lhs_first_dims_spec = get_padded_spec(arg_infos[5]) - lhs_last_dims_spec = get_padded_spec(arg_infos[6]) - rhs_first_dims_spec = get_padded_spec(arg_infos[7]) - rhs_last_dims_spec = get_padded_spec(arg_infos[8]) - out_first_dims_spec = get_padded_spec(arg_infos[9]) - out_last_dims_spec = get_padded_spec(arg_infos[10]) - additional_arg_0_spec = get_padded_spec(arg_infos[11]) - additional_arg_1_spec = get_padded_spec(arg_infos[12]) + allowed_axes = _supported_grouped_gemm_axes(mesh) + + lhs_data_spec = _filter_spec_axes(get_padded_spec(arg_infos[0]), allowed_axes) + lhs_scale_spec = _filter_spec_axes(get_padded_spec(arg_infos[1]), allowed_axes) + rhs_data_spec = _filter_spec_axes(get_padded_spec(arg_infos[2]), allowed_axes) + rhs_scale_spec = _filter_spec_axes(get_padded_spec(arg_infos[3]), allowed_axes) + bias_spec = _filter_spec_axes(get_padded_spec(arg_infos[4]), allowed_axes) + + lhs_first_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[5]), allowed_axes) + lhs_last_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[6]), allowed_axes) + rhs_first_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[7]), allowed_axes) + rhs_last_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[8]), allowed_axes) + out_first_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[9]), allowed_axes) + out_last_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[10]), allowed_axes) + additional_arg_0_spec = _filter_spec_axes(get_padded_spec(arg_infos[11]), allowed_axes) + additional_arg_1_spec = _filter_spec_axes(get_padded_spec(arg_infos[12]), allowed_axes) grouped_dim_specs = ( lhs_first_dims_spec, @@ -1924,14 +1949,15 @@ def _parse_partition_specs( rhs_scale_spec = _strip_axis_from_spec(rhs_scale_spec, fsdp_axis) bias_spec = _strip_axis_from_spec(bias_spec, fsdp_axis) - reduce_axis = _common_axis(lhs_data_spec, rhs_data_spec) - if reduce_axis not in (gsr.dp_resource, gsr.fsdp_resource): - reduce_axis = None + reducible_axes = tuple( + axis for axis in (gsr.dp_resource, gsr.fsdp_resource) if axis is not None + ) + reduce_axis = _common_axis(lhs_data_spec, rhs_data_spec, reducible_axes) if reduce_axis is not None and gather_rhs_fsdp: reduce_axis = None if result_infos: - out_spec = get_padded_spec(result_infos[0]) + out_spec = _filter_spec_axes(get_padded_spec(result_infos[0]), allowed_axes) else: out_spec = (None,) * (len(out_shape) if out_shape is not None else 1) @@ -1998,8 +2024,7 @@ def partition( lhs_axis_boundary=lhs_axis_boundary, ) arg_shardings = tuple(NamedSharding(mesh, PartitionSpec(*spec)) for spec in arg_specs) - result_info = result_infos[0] if result_infos else None - out_sharding = (_partition_spec_from_result(mesh, result_info, out_spec),) + out_sharding = (NamedSharding(mesh, PartitionSpec(*out_spec)),) local_out_shape = _local_shape_from_spec(out_shape, out_spec, mesh) local_lhs_left_size, local_lhs_right_size = _local_2d_sizes_from_spec( arg_infos[0].shape, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 761933b3f9..f74ca21a38 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -72,6 +72,40 @@ def _flat_data_spec(input_spec): return (_merge_axis_specs(input_spec),) +def _normalize_flatten_axis(flatten_axis, ndim): + return flatten_axis + ndim if flatten_axis < 0 else flatten_axis + + +def _contiguous_flat_input_spec(input_spec, flatten_axis): + flatten_axis = _normalize_flatten_axis(flatten_axis, len(input_spec)) + if flatten_axis <= 0 or len(input_spec) == 0: + return (None,) * len(input_spec) + return (input_spec[0], *((None,) * (len(input_spec) - 1))) + + +def _filter_axis_spec(axis_spec, allowed_axes): + if axis_spec is None: + return None + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axes = tuple(axis for axis in axis_tuple if axis in allowed_axes) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else axes + + +def _filter_spec_axes(spec, allowed_axes): + return tuple(_filter_axis_spec(axis_spec, allowed_axes) for axis_spec in spec) + + +def _supported_grouped_quantize_axes(mesh): + gsr = global_mesh_resource(validate=False) + return { + axis + for axis in (gsr.ep_resource, gsr.dp_resource, gsr.fsdp_resource) + if axis is not None and axis in mesh.axis_names + } + + def _axis_spec_size(axis_spec, mesh): axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) axis_size = 1 @@ -1292,10 +1326,11 @@ def impl( return rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax @staticmethod - def _parse_partition_specs(scaling_mode, q_layout, mesh, arg_infos): - del mesh - x_spec = get_padded_spec(arg_infos[0]) - group_spec = get_padded_spec(arg_infos[2]) + def _parse_partition_specs(scaling_mode, q_layout, flatten_axis, mesh, arg_infos): + allowed_axes = _supported_grouped_quantize_axes(mesh) + x_spec = _filter_spec_axes(get_padded_spec(arg_infos[0]), allowed_axes) + x_spec = _contiguous_flat_input_spec(x_spec, flatten_axis) + group_spec = _filter_spec_axes(get_padded_spec(arg_infos[2]), allowed_axes) if group_spec == (None,) and len(x_spec) > 0: group_spec = (x_spec[0],) flat_spec = _flat_data_spec(x_spec) @@ -1338,7 +1373,7 @@ def partition( result_infos, ): x_spec, group_spec, out_specs = GroupedQuantizePrimitive._parse_partition_specs( - scaling_mode, q_layout, mesh, arg_infos + scaling_mode, q_layout, flatten_axis, mesh, arg_infos ) local_out_shapes = ( tuple(_local_shape_from_spec(info.shape, spec, mesh) for info, spec in zip(result_infos, out_specs)) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 70151a44b7..35eed4d6cc 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -18,7 +18,7 @@ from . import cpp_extensions as tex from .cpp_extensions.amax import AmaxScope -from .sharding import global_mesh_resource, with_sharding_constraint +from .sharding import global_mesh_resource, get_mesh_axis_size, with_sharding_constraint from .quantize import ( ScaledTensor, QuantizerSet, @@ -60,6 +60,16 @@ def _is_manual_mesh_axis(mesh_axis): return mesh_axis is not None and mesh_axis in jax.sharding.get_abstract_mesh().manual_axes +def _kernel_non_contracting_axis_to_bias_axis(kernel_axis_idx, kernel_contracting_dims): + if kernel_axis_idx in kernel_contracting_dims: + return None + bias_axis_idx = 1 + for dim in range(1, kernel_axis_idx): + if dim not in kernel_contracting_dims: + bias_axis_idx += 1 + return bias_axis_idx + + def dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -426,9 +436,36 @@ def _grouped_dense_fwd_rule( ): use_bias = bias is not None - del kernel_fsdp_info - x_contracting_dims, k_contracting_dims = contracting_dims + local_kernel_shape = kernel.shape + kernel_was_gathered = False + bias_shape = bias.shape if use_bias else None + bias_fsdp_axis_idx = -1 + bias_was_gathered = False + + kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info + if ( + _is_manual_mesh_axis(kernel_fsdp_mesh_axis) + and 0 < kernel_fsdp_axis_idx < kernel.ndim + ): + kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx) + kernel_was_gathered = True + + if use_bias and kernel_fsdp_axis_idx not in k_contracting_dims: + bias_fsdp_axis_idx = _kernel_non_contracting_axis_to_bias_axis( + kernel_fsdp_axis_idx, k_contracting_dims + ) + mesh_axis_size = get_mesh_axis_size(kernel_fsdp_mesh_axis) + if ( + bias_fsdp_axis_idx is not None + and 0 < bias_fsdp_axis_idx < bias.ndim + and mesh_axis_size > 1 + and bias.shape[bias_fsdp_axis_idx] * mesh_axis_size + == kernel.shape[kernel_fsdp_axis_idx] + ): + bias = _all_gather_kernel(bias, kernel_fsdp_mesh_axis, bias_fsdp_axis_idx) + bias_was_gathered = True + flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis @@ -484,10 +521,14 @@ def _grouped_dense_fwd_rule( else ctx_kernel ), x.shape, - kernel.shape, + local_kernel_shape, use_bias, quantizer_set, flatten_axis_k, + kernel_was_gathered, + bias_shape, + bias_fsdp_axis_idx, + bias_was_gathered, ) return output, ctx @@ -508,6 +549,10 @@ def _grouped_dense_bwd_rule( use_bias, quantizer_set, flatten_axis_k, + kernel_was_gathered, + bias_shape, + bias_fsdp_axis_idx, + bias_was_gathered, ) = ctx # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) @@ -545,7 +590,7 @@ def _grouped_dense_bwd_rule( preferred_element_type=preferred_element_type, group_offset=group_offset, ) - if _is_manual_mesh_axis(kernel_fsdp_mesh_axis): + if _is_manual_mesh_axis(kernel_fsdp_mesh_axis) and not kernel_was_gathered: if kernel_fsdp_axis_idx in fwd_k_contracting_dims: dgrad_axis_idx = fwd_x_contracting_dims[ fwd_k_contracting_dims.index(kernel_fsdp_axis_idx) @@ -563,7 +608,11 @@ def _grouped_dense_bwd_rule( group_offset=group_offset, ) if _is_manual_mesh_axis(kernel_fsdp_mesh_axis): - if kernel_fsdp_axis_idx in fwd_k_contracting_dims: + if kernel_was_gathered: + wgrad = _psum_scatter_kernel( + wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx + ) + elif kernel_fsdp_axis_idx in fwd_k_contracting_dims: wgrad = _psum_scatter_kernel( wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx ) @@ -584,6 +633,14 @@ def _grouped_dense_bwd_rule( group_sizes_grad = None dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None + if ( + dbias is not None + and _is_manual_mesh_axis(kernel_fsdp_mesh_axis) + and bias_was_gathered + ): + dbias = _psum_scatter_kernel( + dbias, bias_shape, kernel_fsdp_mesh_axis, bias_fsdp_axis_idx + ) return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 16a10c860d..8dffb71196 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -133,8 +133,8 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): """ A wrapper function to jax.lax.with_sharding_constraint 1. Does nothing if mesh is empty. - 2. If all mesh axes are manual axes, replaces pspec with all Nones. - 3. Otherwise, strips only the manual axes. + 2. Keeps only auto axes in pspec. + 3. Returns x unchanged if no auto axes remain. """ if pspec is None: return x @@ -143,22 +143,21 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): if mesh.empty: return x - # We want to exclude the axes that already used by shard_map and shard_map - # only sets those in the abstract_mesh, not the physical one - manual_axis_names = get_abstract_mesh().manual_axes + # with_sharding_constraint can only refer to auto axes. Explicit axes are + # already fixed by the active mesh, and manual axes are managed by shard_map. + abstract_mesh = get_abstract_mesh() + auto_axis_names = set(abstract_mesh.auto_axes) # Multiple mesh axes can be mapped to a single shape axis, so we need to unpack and process tuples here too - def filter_manual_axes(name_or_tuple): + def filter_non_auto_axes(name_or_tuple): if isinstance(name_or_tuple, tuple): - out = tuple(n for n in name_or_tuple if n not in manual_axis_names) + out = tuple(n for n in name_or_tuple if n in auto_axis_names) if len(out) == 0: return None return out - if name_or_tuple in manual_axis_names: - return None - return name_or_tuple + return name_or_tuple if name_or_tuple in auto_axis_names else None - cleaned_axis_names = tuple(filter_manual_axes(name_or_tuple) for name_or_tuple in pspec) + cleaned_axis_names = tuple(filter_non_auto_axes(name_or_tuple) for name_or_tuple in pspec) if cleaned_axis_names == (None,) * len(cleaned_axis_names): return x @@ -366,7 +365,7 @@ def global_shard_guard(resource: MeshResource): _GLOBAL_MESH_RESOURCE = old_resources -def global_mesh_resource() -> MeshResource: +def global_mesh_resource(validate: bool = True) -> MeshResource: """Get the current global mesh resource configuration. Returns: @@ -377,7 +376,8 @@ def global_mesh_resource() -> MeshResource: " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) - _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) + if validate: + _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) return _GLOBAL_MESH_RESOURCE From 1bd6b54db95fbd7b6f02d855d6854c94a70adf82 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 2 Jun 2026 13:10:53 -0700 Subject: [PATCH 05/36] Add warnings Signed-off-by: Jeremy Berchtold --- tests/jax/test_grouped_gemm_partitioning.py | 136 +++++++++--------- transformer_engine/jax/cpp_extensions/gemm.py | 109 ++++++++++---- .../jax/cpp_extensions/quantization.py | 34 ++++- 3 files changed, 182 insertions(+), 97 deletions(-) diff --git a/tests/jax/test_grouped_gemm_partitioning.py b/tests/jax/test_grouped_gemm_partitioning.py index 831e245c0e..e62e33a616 100644 --- a/tests/jax/test_grouped_gemm_partitioning.py +++ b/tests/jax/test_grouped_gemm_partitioning.py @@ -132,20 +132,21 @@ def test_grouped_quantize_strips_unsupported_axes_and_gathers_hidden_axes(): with jax.set_mesh(mesh), global_shard_guard( MeshResource(dp_resource="dp", tp_resource="tp", fsdp_resource="fsdp", ep_resource="expert") ): - _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( - jnp.float8_e4m3fn, - ScalingMode.MXFP8_1D_SCALING.value, - QuantizeLayout.ROWWISE, - -1, - jnp.float8_e8m0fnu, - mesh, - ( - _arg_info(mesh, (8, 128, 128), ("expert", "dp", ("fsdp", "tp"))), - _arg_info(mesh, (8,), (("expert", "tp"),)), - _arg_info(mesh, (8,), (("expert", "tp"),)), - ), - (), - ) + with pytest.warns(RuntimeWarning, match="Grouped quantize.*tp"): + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", "dp", ("fsdp", "tp"))), + _arg_info(mesh, (8,), (("expert", "tp"),)), + _arg_info(mesh, (8,), (("expert", "tp"),)), + ), + (), + ) assert tuple(arg_shardings[0].spec) == ("expert", None, None) assert tuple(arg_shardings[1].spec) == ("expert",) @@ -223,25 +224,26 @@ def test_grouped_gemm_strips_unsupported_axes_preserves_dp_and_gathers_rhs_fsdp( with jax.set_mesh(mesh), global_shard_guard( MeshResource(dp_resource="dp", tp_resource="tp", fsdp_resource="fsdp", ep_resource="expert") ): - _, _, out_sharding, arg_shardings = GroupedGemmPrimitive.partition( - False, - False, - ScalingMode.NO_SCALING.value, - jnp.bfloat16, - False, - False, - False, - 1, - 1, - (1, 128, 64), - 128, - 64, - 128, - 64, - mesh, - arg_infos, - result_infos, - ) + with pytest.warns(RuntimeWarning, match="Grouped GEMM.*tp"): + _, _, out_sharding, arg_shardings = GroupedGemmPrimitive.partition( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 1, + (1, 128, 64), + 128, + 64, + 128, + 64, + mesh, + arg_infos, + result_infos, + ) assert tuple(arg_shardings[0].spec) == ("dp",) assert tuple(arg_shardings[2].spec) == ("expert",) @@ -293,20 +295,21 @@ def test_grouped_partitioning_strips_arbitrary_unsupported_axis(): mesh_resource = MeshResource(dp_resource="dp", fsdp_resource="fsdp", ep_resource="expert") with jax.set_mesh(mesh), global_shard_guard(mesh_resource): - _, _, quantize_out_shardings, quantize_arg_shardings = GroupedQuantizePrimitive.partition( - jnp.float8_e4m3fn, - ScalingMode.MXFP8_1D_SCALING.value, - QuantizeLayout.ROWWISE, - -1, - jnp.float8_e8m0fnu, - mesh, - ( - _arg_info(mesh, (8, 128, 128), ("expert", "myaxis123", ("dp", "fsdp"))), - _arg_info(mesh, (8,), (("expert", "myaxis123"),)), - _arg_info(mesh, (8,), (("expert", "myaxis123"),)), - ), - (), - ) + with pytest.warns(RuntimeWarning, match="Grouped quantize.*myaxis123"): + _, _, quantize_out_shardings, quantize_arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", "myaxis123", ("dp", "fsdp"))), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + ), + (), + ) gemm_arg_infos = ( _arg_info(mesh, (8192,), (("dp", "myaxis123"),)), @@ -324,25 +327,26 @@ def test_grouped_partitioning_strips_arbitrary_unsupported_axis(): _arg_info(mesh, (0,), (("myaxis123",),)), ) gemm_result_infos = (_arg_info(mesh, (1, 128, 64), ("expert", "myaxis123", None)),) - _, _, gemm_out_sharding, gemm_arg_shardings = GroupedGemmPrimitive.partition( - False, - False, - ScalingMode.NO_SCALING.value, - jnp.bfloat16, - False, - False, - False, - 1, - 1, - (1, 128, 64), - 128, - 64, - 128, - 64, - mesh, - gemm_arg_infos, - gemm_result_infos, - ) + with pytest.warns(RuntimeWarning, match="Grouped GEMM.*myaxis123"): + _, _, gemm_out_sharding, gemm_arg_shardings = GroupedGemmPrimitive.partition( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 1, + (1, 128, 64), + 128, + 64, + 128, + 64, + mesh, + gemm_arg_infos, + gemm_result_infos, + ) assert tuple(quantize_arg_shardings[0].spec) == ("expert", None, None) assert tuple(quantize_arg_shardings[1].spec) == ("expert",) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 61521873d2..ce364597c6 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -252,6 +252,30 @@ def _filter_spec_axes(spec, allowed_axes): return tuple(_filter_axis_spec(axis_spec, allowed_axes) for axis_spec in spec) +def _spec_axes(spec): + axes = [] + for axis_spec in spec: + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + axes.append(axis) + return axes + + +def _warn_if_axes_ignored(arg_name, original_spec, partition_spec): + ignored_axes = tuple(axis for axis in _spec_axes(original_spec) if axis not in _spec_axes(partition_spec)) + if ignored_axes: + warnings.warn( + "Grouped GEMM custom partitioning will ignore/replicate sharding " + f"axes {ignored_axes} from {arg_name}; only DP/FSDP/EP grouped " + "partitioning axes are preserved.", + RuntimeWarning, + stacklevel=3, + ) + + def _supported_grouped_gemm_axes(mesh): gsr = global_mesh_resource(validate=False) return { @@ -1885,20 +1909,21 @@ def _parse_partition_specs( fsdp_axis = gsr.fsdp_resource allowed_axes = _supported_grouped_gemm_axes(mesh) - lhs_data_spec = _filter_spec_axes(get_padded_spec(arg_infos[0]), allowed_axes) - lhs_scale_spec = _filter_spec_axes(get_padded_spec(arg_infos[1]), allowed_axes) - rhs_data_spec = _filter_spec_axes(get_padded_spec(arg_infos[2]), allowed_axes) - rhs_scale_spec = _filter_spec_axes(get_padded_spec(arg_infos[3]), allowed_axes) - bias_spec = _filter_spec_axes(get_padded_spec(arg_infos[4]), allowed_axes) - - lhs_first_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[5]), allowed_axes) - lhs_last_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[6]), allowed_axes) - rhs_first_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[7]), allowed_axes) - rhs_last_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[8]), allowed_axes) - out_first_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[9]), allowed_axes) - out_last_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[10]), allowed_axes) - additional_arg_0_spec = _filter_spec_axes(get_padded_spec(arg_infos[11]), allowed_axes) - additional_arg_1_spec = _filter_spec_axes(get_padded_spec(arg_infos[12]), allowed_axes) + original_arg_specs = tuple(get_padded_spec(arg_info) for arg_info in arg_infos) + lhs_data_spec = _filter_spec_axes(original_arg_specs[0], allowed_axes) + lhs_scale_spec = _filter_spec_axes(original_arg_specs[1], allowed_axes) + rhs_data_spec = _filter_spec_axes(original_arg_specs[2], allowed_axes) + rhs_scale_spec = _filter_spec_axes(original_arg_specs[3], allowed_axes) + bias_spec = _filter_spec_axes(original_arg_specs[4], allowed_axes) + + lhs_first_dims_spec = _filter_spec_axes(original_arg_specs[5], allowed_axes) + lhs_last_dims_spec = _filter_spec_axes(original_arg_specs[6], allowed_axes) + rhs_first_dims_spec = _filter_spec_axes(original_arg_specs[7], allowed_axes) + rhs_last_dims_spec = _filter_spec_axes(original_arg_specs[8], allowed_axes) + out_first_dims_spec = _filter_spec_axes(original_arg_specs[9], allowed_axes) + out_last_dims_spec = _filter_spec_axes(original_arg_specs[10], allowed_axes) + additional_arg_0_spec = _filter_spec_axes(original_arg_specs[11], allowed_axes) + additional_arg_1_spec = _filter_spec_axes(original_arg_specs[12], allowed_axes) grouped_dim_specs = ( lhs_first_dims_spec, @@ -1957,8 +1982,10 @@ def _parse_partition_specs( reduce_axis = None if result_infos: - out_spec = _filter_spec_axes(get_padded_spec(result_infos[0]), allowed_axes) + original_out_spec = get_padded_spec(result_infos[0]) + out_spec = _filter_spec_axes(original_out_spec, allowed_axes) else: + original_out_spec = None out_spec = (None,) * (len(out_shape) if out_shape is not None else 1) if rhs_is_ragged and lhs_is_trans is not None and lhs_axis_boundary is not None: @@ -1975,22 +2002,46 @@ def _parse_partition_specs( ) lhs_data_spec = tuple(lhs_data_spec) - return ( + final_arg_specs = ( + lhs_data_spec, + lhs_scale_spec, + rhs_data_spec, + rhs_scale_spec, + bias_spec, + lhs_first_dims_spec, + lhs_last_dims_spec, + rhs_first_dims_spec, + rhs_last_dims_spec, + out_first_dims_spec, + out_last_dims_spec, + additional_arg_0_spec, + additional_arg_1_spec, + ) + for arg_name, original_spec, partition_spec in zip( ( - lhs_data_spec, - lhs_scale_spec, - rhs_data_spec, - rhs_scale_spec, - bias_spec, - lhs_first_dims_spec, - lhs_last_dims_spec, - rhs_first_dims_spec, - rhs_last_dims_spec, - out_first_dims_spec, - out_last_dims_spec, - additional_arg_0_spec, - additional_arg_1_spec, + "lhs_data", + "lhs_scale_inv", + "rhs_data", + "rhs_scale_inv", + "bias", + "lhs_first_dims", + "lhs_last_dims", + "rhs_first_dims", + "rhs_last_dims", + "out_first_dims", + "out_last_dims", + "additional_arg_0", + "additional_arg_1", ), + original_arg_specs, + final_arg_specs, + ): + _warn_if_axes_ignored(arg_name, original_spec, partition_spec) + if original_out_spec is not None: + _warn_if_axes_ignored("output", original_out_spec, out_spec) + + return ( + final_arg_specs, out_spec, reduce_axis, ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index f74ca21a38..31884fe354 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -6,6 +6,7 @@ from functools import reduce from typing import Tuple, Optional, Union import math +import warnings import jax @@ -97,6 +98,30 @@ def _filter_spec_axes(spec, allowed_axes): return tuple(_filter_axis_spec(axis_spec, allowed_axes) for axis_spec in spec) +def _spec_axes(spec): + axes = [] + for axis_spec in spec: + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + axes.append(axis) + return axes + + +def _warn_if_axes_ignored(arg_name, original_spec, partition_spec): + ignored_axes = tuple(axis for axis in _spec_axes(original_spec) if axis not in _spec_axes(partition_spec)) + if ignored_axes: + warnings.warn( + "Grouped quantize custom partitioning will ignore/replicate sharding " + f"axes {ignored_axes} from {arg_name}; only supported packed grouped " + "data axes are preserved.", + RuntimeWarning, + stacklevel=3, + ) + + def _supported_grouped_quantize_axes(mesh): gsr = global_mesh_resource(validate=False) return { @@ -1328,11 +1353,16 @@ def impl( @staticmethod def _parse_partition_specs(scaling_mode, q_layout, flatten_axis, mesh, arg_infos): allowed_axes = _supported_grouped_quantize_axes(mesh) - x_spec = _filter_spec_axes(get_padded_spec(arg_infos[0]), allowed_axes) + original_x_spec = get_padded_spec(arg_infos[0]) + x_spec = _filter_spec_axes(original_x_spec, allowed_axes) x_spec = _contiguous_flat_input_spec(x_spec, flatten_axis) - group_spec = _filter_spec_axes(get_padded_spec(arg_infos[2]), allowed_axes) + _warn_if_axes_ignored("x", original_x_spec, x_spec) + + original_group_spec = get_padded_spec(arg_infos[2]) + group_spec = _filter_spec_axes(original_group_spec, allowed_axes) if group_spec == (None,) and len(x_spec) > 0: group_spec = (x_spec[0],) + _warn_if_axes_ignored("group_sizes", original_group_spec, group_spec) flat_spec = _flat_data_spec(x_spec) replicated_spec = (None,) From 3c30c9b3912f0a4b092874e9a5916e64c48a75ec Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 2 Jun 2026 13:25:35 -0700 Subject: [PATCH 06/36] Remove kernel_fsdp_info Signed-off-by: Jeremy Berchtold --- tests/jax/test_grouped_gemm_partitioning.py | 31 +++- ..._multi_process_distributed_grouped_gemm.py | 2 - .../jax/cpp_extensions/quantization.py | 2 +- transformer_engine/jax/dense.py | 136 +----------------- 4 files changed, 31 insertions(+), 140 deletions(-) diff --git a/tests/jax/test_grouped_gemm_partitioning.py b/tests/jax/test_grouped_gemm_partitioning.py index e62e33a616..9fa34a8f9f 100644 --- a/tests/jax/test_grouped_gemm_partitioning.py +++ b/tests/jax/test_grouped_gemm_partitioning.py @@ -127,6 +127,30 @@ def test_grouped_quantize_mxfp8_colwise_specs_gather_hidden_axis(): assert _normalize_spec(specs[4]) == ("expert",) +def test_grouped_quantize_preserves_row_side_fsdp_for_kernel(): + mesh = _mesh() + with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 64), ("expert", "fsdp", None)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (8,), ("expert",)), + ), + (), + ) + + assert tuple(arg_shardings[0].spec) == ("expert", "fsdp", None) + specs = tuple(tuple(sharding.spec) for sharding in out_shardings) + assert _normalize_spec(specs[0]) == (("expert", "fsdp"),) + assert _normalize_spec(specs[2]) == (("expert", "fsdp"),) + + def test_grouped_quantize_strips_unsupported_axes_and_gathers_hidden_axes(): mesh = _mesh_with_dp_tp() with jax.set_mesh(mesh), global_shard_guard( @@ -148,13 +172,13 @@ def test_grouped_quantize_strips_unsupported_axes_and_gathers_hidden_axes(): (), ) - assert tuple(arg_shardings[0].spec) == ("expert", None, None) + assert tuple(arg_shardings[0].spec) == ("expert", "dp", None) assert tuple(arg_shardings[1].spec) == ("expert",) assert tuple(arg_shardings[2].spec) == ("expert",) out_specs = tuple(tuple(sharding.spec) for sharding in out_shardings) - assert _normalize_spec(out_specs[0]) == ("expert",) - assert _normalize_spec(out_specs[2]) == ("expert",) + assert _normalize_spec(out_specs[0]) == (("expert", "dp"),) + assert _normalize_spec(out_specs[2]) == (("expert", "dp"),) assert _normalize_spec(out_specs[4]) == ("expert",) for spec in (*out_specs, *(tuple(sharding.spec) for sharding in arg_shardings)): assert not _spec_contains_axis(spec, "tp") @@ -455,7 +479,6 @@ def apply(x, w): group_sizes, contracting_dims=((1,), (1,)), quantizer_set=quantizer_set, - kernel_fsdp_info=("fsdp", 1), ) out, vjp_fn = jax.vjp(apply, x, w) diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py index ce52126dea..cb7ea2cd60 100644 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -97,7 +97,6 @@ def sharded_group_gemm(x, w, b): n_groups, bias=b, quantizer_set=quantizer_set, - kernel_fsdp_info=(MESH_AXIS_NAME, kernel_fsdp_axis), ) output = output.reshape(*x.shape[:-1], -1) return output @@ -221,7 +220,6 @@ def apply(x, w): group_sizes, contracting_dims=((1,), (1,)), quantizer_set=quantizer_set, - kernel_fsdp_info=(FSDP_AXIS_NAME, 1), ) out, vjp_fn = jax.vjp(apply, x, w) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 31884fe354..34b2cadfab 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -81,7 +81,7 @@ def _contiguous_flat_input_spec(input_spec, flatten_axis): flatten_axis = _normalize_flatten_axis(flatten_axis, len(input_spec)) if flatten_axis <= 0 or len(input_spec) == 0: return (None,) * len(input_spec) - return (input_spec[0], *((None,) * (len(input_spec) - 1))) + return (*input_spec[:flatten_axis], *((None,) * (len(input_spec) - flatten_axis))) def _filter_axis_spec(axis_spec, allowed_axes): diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 35eed4d6cc..13ee446fbb 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -14,11 +14,9 @@ import warnings import jax import jax.numpy as jnp -from jax.sharding import PartitionSpec from . import cpp_extensions as tex from .cpp_extensions.amax import AmaxScope -from .sharding import global_mesh_resource, get_mesh_axis_size, with_sharding_constraint from .quantize import ( ScaledTensor, QuantizerSet, @@ -28,48 +26,6 @@ ) -def _all_gather_kernel(kernel, mesh_axis, axis_idx): - assert mesh_axis is not None - assert 0 < axis_idx < len(kernel.shape) - - # TODO(Ming Hunag): Add a condition branch for with/without shmap. - kernel_shape = kernel.shape - kernel_whole_shape = (*kernel_shape[:axis_idx], -1, *kernel_shape[axis_idx + 1 :]) - global_kernel = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx) - global_kernel = global_kernel.reshape(*kernel_whole_shape) - return global_kernel - - -def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx): - assert mesh_axis is not None - assert 0 < axis_idx < len(scattered_kernel_shape) - - # TODO(Ming Hunag): Add a condition branch for with/without shmap. - kernel = kernel.reshape( - *scattered_kernel_shape[:axis_idx], - -1, - scattered_kernel_shape[axis_idx], - *scattered_kernel_shape[axis_idx + 1 :], - ) - kernel = jax.lax.psum_scatter(kernel, mesh_axis, scatter_dimension=axis_idx) - kernel = kernel.reshape(scattered_kernel_shape) - return kernel - - -def _is_manual_mesh_axis(mesh_axis): - return mesh_axis is not None and mesh_axis in jax.sharding.get_abstract_mesh().manual_axes - - -def _kernel_non_contracting_axis_to_bias_axis(kernel_axis_idx, kernel_contracting_dims): - if kernel_axis_idx in kernel_contracting_dims: - return None - bias_axis_idx = 1 - for dim in range(1, kernel_axis_idx): - if dim not in kernel_contracting_dims: - bias_axis_idx += 1 - return bias_axis_idx - - def dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -341,7 +297,6 @@ def grouped_dense( preferred_element_type: jnp.dtype = None, group_offset: jnp.array = None, quantizer_set: QuantizerSet = noop_quantizer_set, - kernel_fsdp_info: Tuple[str, int] = (None, -1), ): """ Perform grouped dense (linear) layer transformation with optional quantization. @@ -357,10 +312,6 @@ def grouped_dense( preferred_element_type: Preferred data type for the output tensor group_offset: 1D array containing offsets for each group (not yet implemented) quantizer_set: Set of quantizers for FP8 quantization of the input and output - kernel_fsdp_info: A tuple containing FSDP-related information for a weight matrix - represented in the format (str, int). The first element is the - FSDP mesh axis, and the second element is the dimension along - which the weight is sharded. Returns: A jnp.ndarray containing the result of the grouped linear operation @@ -387,14 +338,13 @@ def grouped_dense( preferred_element_type, group_offset, quantizer_set, - kernel_fsdp_info, ) if restore_leading_ep_axis: output = output.reshape(1, *output.shape) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7, 9)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7)) def _grouped_dense( x, kernel, @@ -405,7 +355,6 @@ def _grouped_dense( preferred_element_type, group_offset, quantizer_set, - kernel_fsdp_info, ): output, _ = _grouped_dense_fwd_rule( x, @@ -417,7 +366,6 @@ def _grouped_dense( preferred_element_type, group_offset, quantizer_set, - kernel_fsdp_info, ) return output @@ -432,39 +380,10 @@ def _grouped_dense_fwd_rule( preferred_element_type, group_offset, quantizer_set, - kernel_fsdp_info, ): use_bias = bias is not None x_contracting_dims, k_contracting_dims = contracting_dims - local_kernel_shape = kernel.shape - kernel_was_gathered = False - bias_shape = bias.shape if use_bias else None - bias_fsdp_axis_idx = -1 - bias_was_gathered = False - - kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info - if ( - _is_manual_mesh_axis(kernel_fsdp_mesh_axis) - and 0 < kernel_fsdp_axis_idx < kernel.ndim - ): - kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx) - kernel_was_gathered = True - - if use_bias and kernel_fsdp_axis_idx not in k_contracting_dims: - bias_fsdp_axis_idx = _kernel_non_contracting_axis_to_bias_axis( - kernel_fsdp_axis_idx, k_contracting_dims - ) - mesh_axis_size = get_mesh_axis_size(kernel_fsdp_mesh_axis) - if ( - bias_fsdp_axis_idx is not None - and 0 < bias_fsdp_axis_idx < bias.ndim - and mesh_axis_size > 1 - and bias.shape[bias_fsdp_axis_idx] * mesh_axis_size - == kernel.shape[kernel_fsdp_axis_idx] - ): - bias = _all_gather_kernel(bias, kernel_fsdp_mesh_axis, bias_fsdp_axis_idx) - bias_was_gathered = True flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis @@ -521,23 +440,17 @@ def _grouped_dense_fwd_rule( else ctx_kernel ), x.shape, - local_kernel_shape, + kernel.shape, use_bias, quantizer_set, flatten_axis_k, - kernel_was_gathered, - bias_shape, - bias_fsdp_axis_idx, - bias_was_gathered, ) return output, ctx def _grouped_dense_bwd_rule( - contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad + contracting_dims, precision, preferred_element_type, group_offset, ctx, grad ): - kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info - fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims ( @@ -549,10 +462,6 @@ def _grouped_dense_bwd_rule( use_bias, quantizer_set, flatten_axis_k, - kernel_was_gathered, - bias_shape, - bias_fsdp_axis_idx, - bias_was_gathered, ) = ctx # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) @@ -590,14 +499,6 @@ def _grouped_dense_bwd_rule( preferred_element_type=preferred_element_type, group_offset=group_offset, ) - if _is_manual_mesh_axis(kernel_fsdp_mesh_axis) and not kernel_was_gathered: - if kernel_fsdp_axis_idx in fwd_k_contracting_dims: - dgrad_axis_idx = fwd_x_contracting_dims[ - fwd_k_contracting_dims.index(kernel_fsdp_axis_idx) - ] - dgrad = _all_gather_kernel(dgrad, kernel_fsdp_mesh_axis, dgrad_axis_idx) - else: - dgrad = jax.lax.psum(dgrad, kernel_fsdp_mesh_axis) wgrad = tex.grouped_gemm( wgrad_x_T, @@ -607,40 +508,9 @@ def _grouped_dense_bwd_rule( preferred_element_type=preferred_element_type, group_offset=group_offset, ) - if _is_manual_mesh_axis(kernel_fsdp_mesh_axis): - if kernel_was_gathered: - wgrad = _psum_scatter_kernel( - wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx - ) - elif kernel_fsdp_axis_idx in fwd_k_contracting_dims: - wgrad = _psum_scatter_kernel( - wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx - ) - else: - wgrad = jax.lax.psum(wgrad, kernel_fsdp_mesh_axis) - if kernel_fsdp_mesh_axis is not None: - wgrad_spec = [None] * len(kernel_shape) - ep_resource = None - try: - ep_resource = global_mesh_resource().ep_resource - except AssertionError: - pass - if len(wgrad_spec) > 0: - wgrad_spec[0] = ep_resource - if 0 <= kernel_fsdp_axis_idx < len(wgrad_spec): - wgrad_spec[kernel_fsdp_axis_idx] = kernel_fsdp_mesh_axis - wgrad = with_sharding_constraint(wgrad, PartitionSpec(*wgrad_spec)) group_sizes_grad = None dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None - if ( - dbias is not None - and _is_manual_mesh_axis(kernel_fsdp_mesh_axis) - and bias_was_gathered - ): - dbias = _psum_scatter_kernel( - dbias, bias_shape, kernel_fsdp_mesh_axis, bias_fsdp_axis_idx - ) return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set From 273e066e5be8341b329988defe331c3e5f24e2c3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 20:26:43 +0000 Subject: [PATCH 07/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_grouped_gemm_partitioning.py | 28 ++++++++++--------- ..._multi_process_distributed_grouped_gemm.py | 17 +++++------ transformer_engine/jax/cpp_extensions/gemm.py | 10 +++++-- .../jax/cpp_extensions/quantization.py | 9 ++++-- 4 files changed, 37 insertions(+), 27 deletions(-) diff --git a/tests/jax/test_grouped_gemm_partitioning.py b/tests/jax/test_grouped_gemm_partitioning.py index 9fa34a8f9f..be2487a8df 100644 --- a/tests/jax/test_grouped_gemm_partitioning.py +++ b/tests/jax/test_grouped_gemm_partitioning.py @@ -320,19 +320,21 @@ def test_grouped_partitioning_strips_arbitrary_unsupported_axis(): with jax.set_mesh(mesh), global_shard_guard(mesh_resource): with pytest.warns(RuntimeWarning, match="Grouped quantize.*myaxis123"): - _, _, quantize_out_shardings, quantize_arg_shardings = GroupedQuantizePrimitive.partition( - jnp.float8_e4m3fn, - ScalingMode.MXFP8_1D_SCALING.value, - QuantizeLayout.ROWWISE, - -1, - jnp.float8_e8m0fnu, - mesh, - ( - _arg_info(mesh, (8, 128, 128), ("expert", "myaxis123", ("dp", "fsdp"))), - _arg_info(mesh, (8,), (("expert", "myaxis123"),)), - _arg_info(mesh, (8,), (("expert", "myaxis123"),)), - ), - (), + _, _, quantize_out_shardings, quantize_arg_shardings = ( + GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", "myaxis123", ("dp", "fsdp"))), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + ), + (), + ) ) gemm_arg_infos = ( diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py index cb7ea2cd60..b2978922e1 100644 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -59,18 +59,15 @@ def init_data(): x_key = jax.random.PRNGKey(0) w_key = jax.random.PRNGKey(1) b_key = jax.random.PRNGKey(2) - x = ( - jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) - * jnp.asarray(0.01, dtype=jnp.bfloat16) + x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) * jnp.asarray( + 0.01, dtype=jnp.bfloat16 ) - w = ( - jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) - * jnp.asarray(0.01, dtype=jnp.bfloat16) - ) - b = ( - jax.random.normal(b_key, shape=(N_GROUP, w_shape[-1]), dtype=jnp.bfloat16) - * jnp.asarray(0.01, dtype=jnp.bfloat16) + w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) * jnp.asarray( + 0.01, dtype=jnp.bfloat16 ) + b = jax.random.normal( + b_key, shape=(N_GROUP, w_shape[-1]), dtype=jnp.bfloat16 + ) * jnp.asarray(0.01, dtype=jnp.bfloat16) return x, w, w, b, b def test_func(outter_x, outter_w, outter_b): diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ce364597c6..9bd0a1700e 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -265,7 +265,9 @@ def _spec_axes(spec): def _warn_if_axes_ignored(arg_name, original_spec, partition_spec): - ignored_axes = tuple(axis for axis in _spec_axes(original_spec) if axis not in _spec_axes(partition_spec)) + ignored_axes = tuple( + axis for axis in _spec_axes(original_spec) if axis not in _spec_axes(partition_spec) + ) if ignored_axes: warnings.warn( "Grouped GEMM custom partitioning will ignore/replicate sharding " @@ -1945,7 +1947,11 @@ def _parse_partition_specs( rhs_is_ragged = arg_infos[7].size > 0 or arg_infos[8].size > 0 ep_axis = gsr.ep_resource - if ep_axis is not None and not rhs_is_ragged and _spec_contains_axis(active_group_spec, ep_axis): + if ( + ep_axis is not None + and not rhs_is_ragged + and _spec_contains_axis(active_group_spec, ep_axis) + ): if len(rhs_data_spec) > 0 and not _spec_contains_axis(rhs_data_spec, ep_axis): rhs_data_spec = ( _merge_axis_spec(rhs_data_spec[0], ep_axis), diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 34b2cadfab..828bdb6067 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -111,7 +111,9 @@ def _spec_axes(spec): def _warn_if_axes_ignored(arg_name, original_spec, partition_spec): - ignored_axes = tuple(axis for axis in _spec_axes(original_spec) if axis not in _spec_axes(partition_spec)) + ignored_axes = tuple( + axis for axis in _spec_axes(original_spec) if axis not in _spec_axes(partition_spec) + ) if ignored_axes: warnings.warn( "Grouped quantize custom partitioning will ignore/replicate sharding " @@ -1406,7 +1408,10 @@ def partition( scaling_mode, q_layout, flatten_axis, mesh, arg_infos ) local_out_shapes = ( - tuple(_local_shape_from_spec(info.shape, spec, mesh) for info, spec in zip(result_infos, out_specs)) + tuple( + _local_shape_from_spec(info.shape, spec, mesh) + for info, spec in zip(result_infos, out_specs) + ) if result_infos else (None,) * len(out_specs) ) From fe906fda11d93486e4e9fbd35d648c844a51865a Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 2 Jun 2026 13:42:33 -0700 Subject: [PATCH 08/36] Remove unnecessary reshape Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/dense.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 13ee446fbb..b60810f5eb 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -321,13 +321,6 @@ def grouped_dense( kernel_contracting_dims = tex.sanitize_dims(kernel.ndim, kernel_contracting_dims) contracting_dims = (x_contracting_dims, kernel_contracting_dims) - restore_leading_ep_axis = False - if x.ndim == 3 and x.shape[0] == 1: - if x_contracting_dims == (x.ndim - 1,): - restore_leading_ep_axis = True - x = x.reshape(*x.shape[1:]) - contracting_dims = ((x.ndim - 1,), kernel_contracting_dims) - output = _grouped_dense( x, kernel, @@ -339,8 +332,6 @@ def grouped_dense( group_offset, quantizer_set, ) - if restore_leading_ep_axis: - output = output.reshape(1, *output.shape) return output From e76d20aff1ae112f89fd1094ad4e7806320495d4 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 2 Jun 2026 14:06:13 -0700 Subject: [PATCH 09/36] Remove multi-process grouped GEMM tests Signed-off-by: Jeremy Berchtold --- qa/L1_jax_distributed_unittest/test.sh | 1 - ..._multi_process_distributed_grouped_gemm.py | 263 ------------------ 2 files changed, 264 deletions(-) delete mode 100644 tests/jax/test_multi_process_distributed_grouped_gemm.py diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 4f92d1c783..031bb72995 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -38,7 +38,6 @@ XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_nccl_comm_splitting=false" python3 -m pyt python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_fused_attn.xml $TE_PATH/tests/jax/test_distributed_fused_attn.py || test_fail "test_distributed_fused_attn.py" # TODO(Phuong): add this test back after it is verified -# SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh || test_fail "test_multi_process_distributed_grouped_gemm.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py deleted file mode 100644 index b2978922e1..0000000000 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -from functools import partial - -import jax -import jax.numpy as jnp -import jax.experimental.multihost_utils as jem -import numpy as np -from jax.experimental import shard_map -from jax.sharding import NamedSharding, PartitionSpec - -from transformer_engine.jax.dense import grouped_dense as te_grouped_dense -from transformer_engine.jax.quantize import ( - QuantizerFactory, - ScalingMode, -) -from transformer_engine.jax.sharding import MeshResource, global_shard_guard - -from utils import assert_allclose, dtype_tols - - -N_GROUP = 8 -EP_AXIS_NAME = "ep" -FSDP_AXIS_NAME = "fsdp" -MESH_AXIS_NAME = FSDP_AXIS_NAME - - -def _mxfp8_grouped_quantizer_set(n_groups): - return QuantizerFactory.create_set( - scaling_mode=ScalingMode.MXFP8_1D_SCALING, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e4m3fn, - is_2x2x=True, - n_groups=n_groups, - ) - - -def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis): - assert kernel_fsdp_axis in [1, 2] - x_shape, w_shape = data_shapes - - x_sharding = NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None, None, None)) - w_sharding = ( - NamedSharding(mesh, PartitionSpec(None, None, MESH_AXIS_NAME)) - if kernel_fsdp_axis == 2 - else NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None)) - ) - b_sharding = ( - NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME)) - if kernel_fsdp_axis == 2 - else NamedSharding(mesh, PartitionSpec(None, None)) - ) - w_no_sharding = NamedSharding(mesh, PartitionSpec(None, None, None)) - b_no_sharding = NamedSharding(mesh, PartitionSpec(None, None)) - - def init_data(): - x_key = jax.random.PRNGKey(0) - w_key = jax.random.PRNGKey(1) - b_key = jax.random.PRNGKey(2) - x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) * jnp.asarray( - 0.01, dtype=jnp.bfloat16 - ) - w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) * jnp.asarray( - 0.01, dtype=jnp.bfloat16 - ) - b = jax.random.normal( - b_key, shape=(N_GROUP, w_shape[-1]), dtype=jnp.bfloat16 - ) * jnp.asarray(0.01, dtype=jnp.bfloat16) - return x, w, w, b, b - - def test_func(outter_x, outter_w, outter_b): - in_specs = (x_sharding.spec, w_sharding.spec, b_sharding.spec) - out_specs = x_sharding.spec - - @partial( - shard_map.shard_map, - mesh=mesh, - in_specs=in_specs, - out_specs=out_specs, - check_rep=False, - ) - def sharded_group_gemm(x, w, b): - group_size = x.shape[0] - x_reshaped = x.reshape(-1, x.shape[-1]) - n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) - - quantizer_set = _mxfp8_grouped_quantizer_set(group_size) - - output = te_grouped_dense( - x_reshaped, - w, - n_groups, - bias=b, - quantizer_set=quantizer_set, - ) - output = output.reshape(*x.shape[:-1], -1) - return output - - def run(x, w, b): - output = sharded_group_gemm(x, w, b) - return output - - output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_b) - dx, dw, db = vjp_fn(output) - return output, dx, dw, db - - def ref_func(outter_x, outter_w, outter_b): - - in_specs = (x_sharding.spec, w_no_sharding.spec, b_no_sharding.spec) - out_specs = x_sharding.spec - - @partial( - shard_map.shard_map, - mesh=mesh, - in_specs=in_specs, - out_specs=out_specs, - check_rep=False, - ) - def sharded_group_gemm(x, w, b): - group_size = x.shape[0] - x_reshaped = x.reshape(-1, x.shape[-1]) - n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) - - quantizer_set = _mxfp8_grouped_quantizer_set(group_size) - output = te_grouped_dense( - x_reshaped, - w, - n_groups, - bias=b, - quantizer_set=quantizer_set, - ) - output = output.reshape(*x.shape[:-1], -1) - return output - - def run(x, w, b): - output = sharded_group_gemm(x, w, b) - return output - - output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_b) - dx, dw, db = vjp_fn(output) - return output, dx, dw, db - - init_func = jax.jit( - init_data, - out_shardings=(x_sharding, w_sharding, w_no_sharding, b_sharding, b_no_sharding), - ) - x, w, w_global, b, b_global = init_func() - - o_sharding = x_sharding - test_func_jitted = jax.jit( - test_func, - in_shardings=(x_sharding, w_sharding, b_sharding), - out_shardings=(o_sharding, x_sharding, w_sharding, b_sharding), - ) - ref_func_jitted = jax.jit( - ref_func, - in_shardings=(x_sharding, w_no_sharding, b_no_sharding), - out_shardings=(o_sharding, x_sharding, w_no_sharding, b_no_sharding), - ) - - out, dx, dw, db = test_func_jitted(x, w, b) - ref_out, ref_dx, ref_dw, ref_db = ref_func_jitted(x, w_global, b_global) - - # Avoid creating a host scalar JAX array under the multi-process mesh in dtype_tols. - e4m3_tols = dtype_tols(jnp.float8_e4m3fn, rtol=0.25, atol=0.25) - - out, ref_out = jem.process_allgather((out, ref_out), tiled=True) - dx, ref_dx = jem.process_allgather((dx, ref_dx), tiled=True) - dw, ref_dw = jem.process_allgather((dw, ref_dw), tiled=True) - db, ref_db = jem.process_allgather((db, ref_db), tiled=True) - - assert_allclose(out, ref_out, **e4m3_tols) - assert_allclose(dx, ref_dx, **e4m3_tols) - assert_allclose(dw, ref_dw, **e4m3_tols) - assert_allclose(db, ref_db, **e4m3_tols) - - -def run_grouped_dense_mxfp8_ep_fsdp_outside_shard_map(): - n_groups = 4 - group_tokens = 128 - hidden = 256 - out_hidden = 128 - x_shape = (n_groups * group_tokens, hidden) - w_shape = (n_groups, hidden, out_hidden) - quantizer_set = _mxfp8_grouped_quantizer_set(n_groups) - - x_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME, None)) - w_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME, FSDP_AXIS_NAME, None)) - group_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME)) - out_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME, None)) - - with mesh, global_shard_guard( - MeshResource(ep_resource=EP_AXIS_NAME, fsdp_resource=FSDP_AXIS_NAME) - ): - x = jax.device_put( - jax.random.normal(jax.random.PRNGKey(20), x_shape, dtype=jnp.bfloat16) - * jnp.asarray(0.01, dtype=jnp.bfloat16), - x_sharding, - ) - w = jax.device_put( - jax.random.normal(jax.random.PRNGKey(21), w_shape, dtype=jnp.bfloat16) - * jnp.asarray(0.01, dtype=jnp.bfloat16), - w_sharding, - ) - group_sizes = jax.device_put( - jnp.full((n_groups,), group_tokens, dtype=jnp.int32), - group_sharding, - ) - - def apply_with_vjp(x, w, group_sizes): - def apply(x, w): - return te_grouped_dense( - x, - w, - group_sizes, - contracting_dims=((1,), (1,)), - quantizer_set=quantizer_set, - ) - - out, vjp_fn = jax.vjp(apply, x, w) - dx, dw = vjp_fn(out) - return out, dx, dw - - out, dx, dw = jax.jit( - apply_with_vjp, - in_shardings=(x_sharding, w_sharding, group_sharding), - out_shardings=(out_sharding, x_sharding, w_sharding), - )(x, w, group_sizes) - out, dx, dw = jax.block_until_ready((out, dx, dw)) - - assert tuple(out.sharding.spec) == (EP_AXIS_NAME, None) - assert tuple(dx.sharding.spec) == (EP_AXIS_NAME, None) - assert tuple(dw.sharding.spec) == (EP_AXIS_NAME, FSDP_AXIS_NAME, None) - for value in (out, dx, dw): - local_value = np.asarray(jax.device_get(value.addressable_data(0))) - assert np.all(np.isfinite(local_value)) - assert np.any(local_value != 0.0) - - -if __name__ == "__main__": - import sys - - coord_addr = sys.argv[1] - proc_id = int(sys.argv[2]) - num_procs = int(sys.argv[3]) - - jax.distributed.initialize( - coordinator_address=coord_addr, num_processes=num_procs, process_id=proc_id - ) - - mesh = jax.make_mesh((num_procs,), (FSDP_AXIS_NAME,)) - - with mesh: - data_shapes = [((4, 16, 128, 7168), (7168, 2048))] - for data_shape in data_shapes: - for kernel_fsdp_axis in [1, 2]: - test_grouped_gemm_fp8_allgather(data_shape, kernel_fsdp_axis) - - if num_procs == 4: - mesh = jax.make_mesh((2, 2), (EP_AXIS_NAME, FSDP_AXIS_NAME)) - run_grouped_dense_mxfp8_ep_fsdp_outside_shard_map() From 1aa1e827597535cbd13dfafc9da618cfe421d0fe Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 2 Jun 2026 14:24:09 -0700 Subject: [PATCH 10/36] Refactor helpers into sharding.py Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 225 ++++-------------- .../jax/cpp_extensions/quantization.py | 82 +------ transformer_engine/jax/sharding.py | 150 ++++++++++++ 3 files changed, 205 insertions(+), 252 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 9bd0a1700e..655aa6e3f9 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -49,7 +49,16 @@ ) from .misc import get_padded_spec, is_all_reduce_in_float32, get_min_device_compute_capability from ..sharding import ( + common_spec_axis, + filter_spec_axes, global_mesh_resource, + local_2d_sizes_from_spec, + local_shape_from_spec, + merge_axis_specs, + spec_axes, + spec_contains_axis, + strip_axis_from_spec, + supported_grouped_partition_axes, tpsp_axis_size, dp_or_fsdp_axis_size, ) @@ -211,62 +220,9 @@ def _get_nvfp4_tensor_scale_inv(amax): return amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX) -def _axis_spec_contains(axis_spec, axis): - if axis is None or axis_spec is None: - return False - if isinstance(axis_spec, tuple): - return axis in axis_spec - return axis_spec == axis - - -def _spec_contains_axis(spec, axis): - return any(_axis_spec_contains(axis_spec, axis) for axis_spec in spec) - - -def _strip_axis_from_axis_spec(axis_spec, axis): - if axis is None or axis_spec is None: - return axis_spec - if isinstance(axis_spec, tuple): - stripped = tuple(a for a in axis_spec if a != axis) - if len(stripped) == 0: - return None - return stripped[0] if len(stripped) == 1 else stripped - return None if axis_spec == axis else axis_spec - - -def _strip_axis_from_spec(spec, axis): - return tuple(_strip_axis_from_axis_spec(axis_spec, axis) for axis_spec in spec) - - -def _filter_axis_spec(axis_spec, allowed_axes): - if axis_spec is None: - return None - axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) - axes = tuple(axis for axis in axis_tuple if axis in allowed_axes) - if len(axes) == 0: - return None - return axes[0] if len(axes) == 1 else axes - - -def _filter_spec_axes(spec, allowed_axes): - return tuple(_filter_axis_spec(axis_spec, allowed_axes) for axis_spec in spec) - - -def _spec_axes(spec): - axes = [] - for axis_spec in spec: - if axis_spec is None: - continue - axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) - for axis in axis_tuple: - if axis is not None and axis not in axes: - axes.append(axis) - return axes - - def _warn_if_axes_ignored(arg_name, original_spec, partition_spec): ignored_axes = tuple( - axis for axis in _spec_axes(original_spec) if axis not in _spec_axes(partition_spec) + axis for axis in spec_axes(original_spec) if axis not in spec_axes(partition_spec) ) if ignored_axes: warnings.warn( @@ -278,99 +234,6 @@ def _warn_if_axes_ignored(arg_name, original_spec, partition_spec): ) -def _supported_grouped_gemm_axes(mesh): - gsr = global_mesh_resource(validate=False) - return { - axis - for axis in (gsr.ep_resource, gsr.dp_resource, gsr.fsdp_resource) - if axis is not None and axis in mesh.axis_names - } - - -def _common_axis(spec_a, spec_b, allowed_axes=None): - axes = [] - for spec in (spec_a, spec_b): - for axis_spec in spec: - if axis_spec is None: - continue - axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) - for axis in axis_tuple: - if axis is not None and axis not in axes: - axes.append(axis) - for axis in axes: - if allowed_axes is not None and axis not in allowed_axes: - continue - if _spec_contains_axis(spec_a, axis) and _spec_contains_axis(spec_b, axis): - return axis - return None - - -def _merge_axis_spec(axis_spec_a, axis_spec_b): - axes = [] - for axis_spec in (axis_spec_a, axis_spec_b): - if axis_spec is None: - continue - axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) - for axis in axis_tuple: - if axis is not None and axis not in axes: - axes.append(axis) - if len(axes) == 0: - return None - return axes[0] if len(axes) == 1 else tuple(axes) - - -def _partition_spec_from_result(mesh, result_info, fallback_spec): - if result_info is not None and result_info.sharding is not None: - return result_info.sharding - return NamedSharding(mesh, PartitionSpec(*fallback_spec)) - - -def _local_shape_from_spec(global_shape, spec, mesh): - local_shape = [] - for dim, axis_spec in zip(global_shape, spec): - axis_size = _axis_spec_size(axis_spec, mesh) - local_shape.append(dim // axis_size) - return tuple(local_shape) - - -def _axis_spec_size(axis_spec, mesh): - axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) - axis_size = 1 - for axis in axis_tuple: - if axis is not None: - axis_size *= mesh.shape[axis] - return axis_size - - -def _spec_size(spec, mesh): - axis_size = 1 - for axis_spec in spec: - axis_size *= _axis_spec_size(axis_spec, mesh) - return axis_size - - -def _local_2d_sizes_from_spec(shape, spec, axis_boundary, left_size, right_size, mesh): - if len(shape) == len(spec) and len(shape) > 1: - local_shape = _local_shape_from_spec(shape, spec, mesh) - return ( - math.prod(local_shape[:axis_boundary]), - math.prod(local_shape[axis_boundary:]), - ) - - spec_size = _spec_size(spec, mesh) - if spec_size == 1: - return left_size, right_size - if left_size % spec_size == 0: - return left_size // spec_size, right_size - if right_size % spec_size == 0: - return left_size, right_size // spec_size - raise ValueError( - "Cannot derive local grouped GEMM 2D sizes from sharding spec. " - f"shape={shape}, spec={spec}, axis_boundary={axis_boundary}, " - f"left_size={left_size}, right_size={right_size}, spec_size={spec_size}" - ) - - def collective_gemm_bootstrap( num_total_devices, num_devices_per_process, @@ -1909,23 +1772,23 @@ def _parse_partition_specs( ): gsr = global_mesh_resource(validate=False) fsdp_axis = gsr.fsdp_resource - allowed_axes = _supported_grouped_gemm_axes(mesh) + allowed_axes = supported_grouped_partition_axes(mesh) original_arg_specs = tuple(get_padded_spec(arg_info) for arg_info in arg_infos) - lhs_data_spec = _filter_spec_axes(original_arg_specs[0], allowed_axes) - lhs_scale_spec = _filter_spec_axes(original_arg_specs[1], allowed_axes) - rhs_data_spec = _filter_spec_axes(original_arg_specs[2], allowed_axes) - rhs_scale_spec = _filter_spec_axes(original_arg_specs[3], allowed_axes) - bias_spec = _filter_spec_axes(original_arg_specs[4], allowed_axes) - - lhs_first_dims_spec = _filter_spec_axes(original_arg_specs[5], allowed_axes) - lhs_last_dims_spec = _filter_spec_axes(original_arg_specs[6], allowed_axes) - rhs_first_dims_spec = _filter_spec_axes(original_arg_specs[7], allowed_axes) - rhs_last_dims_spec = _filter_spec_axes(original_arg_specs[8], allowed_axes) - out_first_dims_spec = _filter_spec_axes(original_arg_specs[9], allowed_axes) - out_last_dims_spec = _filter_spec_axes(original_arg_specs[10], allowed_axes) - additional_arg_0_spec = _filter_spec_axes(original_arg_specs[11], allowed_axes) - additional_arg_1_spec = _filter_spec_axes(original_arg_specs[12], allowed_axes) + lhs_data_spec = filter_spec_axes(original_arg_specs[0], allowed_axes) + lhs_scale_spec = filter_spec_axes(original_arg_specs[1], allowed_axes) + rhs_data_spec = filter_spec_axes(original_arg_specs[2], allowed_axes) + rhs_scale_spec = filter_spec_axes(original_arg_specs[3], allowed_axes) + bias_spec = filter_spec_axes(original_arg_specs[4], allowed_axes) + + lhs_first_dims_spec = filter_spec_axes(original_arg_specs[5], allowed_axes) + lhs_last_dims_spec = filter_spec_axes(original_arg_specs[6], allowed_axes) + rhs_first_dims_spec = filter_spec_axes(original_arg_specs[7], allowed_axes) + rhs_last_dims_spec = filter_spec_axes(original_arg_specs[8], allowed_axes) + out_first_dims_spec = filter_spec_axes(original_arg_specs[9], allowed_axes) + out_last_dims_spec = filter_spec_axes(original_arg_specs[10], allowed_axes) + additional_arg_0_spec = filter_spec_axes(original_arg_specs[11], allowed_axes) + additional_arg_1_spec = filter_spec_axes(original_arg_specs[12], allowed_axes) grouped_dim_specs = ( lhs_first_dims_spec, @@ -1950,46 +1813,46 @@ def _parse_partition_specs( if ( ep_axis is not None and not rhs_is_ragged - and _spec_contains_axis(active_group_spec, ep_axis) + and spec_contains_axis(active_group_spec, ep_axis) ): - if len(rhs_data_spec) > 0 and not _spec_contains_axis(rhs_data_spec, ep_axis): + if len(rhs_data_spec) > 0 and not spec_contains_axis(rhs_data_spec, ep_axis): rhs_data_spec = ( - _merge_axis_spec(rhs_data_spec[0], ep_axis), + merge_axis_specs(rhs_data_spec[0], ep_axis), *rhs_data_spec[1:], ) - if len(rhs_scale_spec) > 0 and not _spec_contains_axis(rhs_scale_spec, ep_axis): + if len(rhs_scale_spec) > 0 and not spec_contains_axis(rhs_scale_spec, ep_axis): rhs_scale_spec = ( - _merge_axis_spec(rhs_scale_spec[0], ep_axis), + merge_axis_specs(rhs_scale_spec[0], ep_axis), *rhs_scale_spec[1:], ) - if len(bias_spec) > 0 and not _spec_contains_axis(bias_spec, ep_axis): - bias_spec = (_merge_axis_spec(bias_spec[0], ep_axis), *bias_spec[1:]) + if len(bias_spec) > 0 and not spec_contains_axis(bias_spec, ep_axis): + bias_spec = (merge_axis_specs(bias_spec[0], ep_axis), *bias_spec[1:]) gather_rhs_fsdp = ( fsdp_axis is not None and not rhs_is_ragged and ( - _spec_contains_axis(rhs_data_spec, fsdp_axis) - or _spec_contains_axis(rhs_scale_spec, fsdp_axis) - or _spec_contains_axis(bias_spec, fsdp_axis) + spec_contains_axis(rhs_data_spec, fsdp_axis) + or spec_contains_axis(rhs_scale_spec, fsdp_axis) + or spec_contains_axis(bias_spec, fsdp_axis) ) ) if gather_rhs_fsdp: - rhs_data_spec = _strip_axis_from_spec(rhs_data_spec, fsdp_axis) - rhs_scale_spec = _strip_axis_from_spec(rhs_scale_spec, fsdp_axis) - bias_spec = _strip_axis_from_spec(bias_spec, fsdp_axis) + rhs_data_spec = strip_axis_from_spec(rhs_data_spec, fsdp_axis) + rhs_scale_spec = strip_axis_from_spec(rhs_scale_spec, fsdp_axis) + bias_spec = strip_axis_from_spec(bias_spec, fsdp_axis) reducible_axes = tuple( axis for axis in (gsr.dp_resource, gsr.fsdp_resource) if axis is not None ) - reduce_axis = _common_axis(lhs_data_spec, rhs_data_spec, reducible_axes) + reduce_axis = common_spec_axis(lhs_data_spec, rhs_data_spec, reducible_axes) if reduce_axis is not None and gather_rhs_fsdp: reduce_axis = None if result_infos: original_out_spec = get_padded_spec(result_infos[0]) - out_spec = _filter_spec_axes(original_out_spec, allowed_axes) + out_spec = filter_spec_axes(original_out_spec, allowed_axes) else: original_out_spec = None out_spec = (None,) * (len(out_shape) if out_shape is not None else 1) @@ -2003,7 +1866,7 @@ def _parse_partition_specs( lhs_data_spec = list(lhs_data_spec) for out_idx, lhs_dim in enumerate(lhs_non_contracting_dims, start=1): if out_idx < len(out_spec): - lhs_data_spec[lhs_dim] = _merge_axis_spec( + lhs_data_spec[lhs_dim] = merge_axis_specs( lhs_data_spec[lhs_dim], out_spec[out_idx] ) lhs_data_spec = tuple(lhs_data_spec) @@ -2082,8 +1945,8 @@ def partition( ) arg_shardings = tuple(NamedSharding(mesh, PartitionSpec(*spec)) for spec in arg_specs) out_sharding = (NamedSharding(mesh, PartitionSpec(*out_spec)),) - local_out_shape = _local_shape_from_spec(out_shape, out_spec, mesh) - local_lhs_left_size, local_lhs_right_size = _local_2d_sizes_from_spec( + local_out_shape = local_shape_from_spec(out_shape, out_spec, mesh) + local_lhs_left_size, local_lhs_right_size = local_2d_sizes_from_spec( arg_infos[0].shape, arg_specs[0], lhs_axis_boundary, @@ -2091,7 +1954,7 @@ def partition( lhs_right_size, mesh, ) - local_rhs_left_size, local_rhs_right_size = _local_2d_sizes_from_spec( + local_rhs_left_size, local_rhs_right_size = local_2d_sizes_from_spec( arg_infos[2].shape, arg_specs[2], rhs_axis_boundary, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 828bdb6067..9266ab08f0 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -32,9 +32,14 @@ from ..sharding import ( all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp, + filter_spec_axes, get_num_devices_in_mesh, global_mesh_resource, lax_paral_op, + local_shape_from_spec, + merge_axis_specs, + spec_axes, + supported_grouped_partition_axes, ) from ..quantize import ( ScaledTensor2x, @@ -55,22 +60,8 @@ __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] -def _merge_axis_specs(axis_specs): - axes = [] - for axis_spec in axis_specs: - if axis_spec is None: - continue - axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) - for axis in axis_tuple: - if axis is not None and axis not in axes: - axes.append(axis) - if len(axes) == 0: - return None - return axes[0] if len(axes) == 1 else tuple(axes) - - def _flat_data_spec(input_spec): - return (_merge_axis_specs(input_spec),) + return (merge_axis_specs(*input_spec),) def _normalize_flatten_axis(flatten_axis, ndim): @@ -84,35 +75,9 @@ def _contiguous_flat_input_spec(input_spec, flatten_axis): return (*input_spec[:flatten_axis], *((None,) * (len(input_spec) - flatten_axis))) -def _filter_axis_spec(axis_spec, allowed_axes): - if axis_spec is None: - return None - axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) - axes = tuple(axis for axis in axis_tuple if axis in allowed_axes) - if len(axes) == 0: - return None - return axes[0] if len(axes) == 1 else axes - - -def _filter_spec_axes(spec, allowed_axes): - return tuple(_filter_axis_spec(axis_spec, allowed_axes) for axis_spec in spec) - - -def _spec_axes(spec): - axes = [] - for axis_spec in spec: - if axis_spec is None: - continue - axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) - for axis in axis_tuple: - if axis is not None and axis not in axes: - axes.append(axis) - return axes - - def _warn_if_axes_ignored(arg_name, original_spec, partition_spec): ignored_axes = tuple( - axis for axis in _spec_axes(original_spec) if axis not in _spec_axes(partition_spec) + axis for axis in spec_axes(original_spec) if axis not in spec_axes(partition_spec) ) if ignored_axes: warnings.warn( @@ -124,31 +89,6 @@ def _warn_if_axes_ignored(arg_name, original_spec, partition_spec): ) -def _supported_grouped_quantize_axes(mesh): - gsr = global_mesh_resource(validate=False) - return { - axis - for axis in (gsr.ep_resource, gsr.dp_resource, gsr.fsdp_resource) - if axis is not None and axis in mesh.axis_names - } - - -def _axis_spec_size(axis_spec, mesh): - axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) - axis_size = 1 - for axis in axis_tuple: - if axis is not None: - axis_size *= mesh.shape[axis] - return axis_size - - -def _local_shape_from_spec(global_shape, spec, mesh): - local_shape = [] - for dim, axis_spec in zip(global_shape, spec): - local_shape.append(dim // _axis_spec_size(axis_spec, mesh)) - return tuple(local_shape) - - def _pad_or_slice_to_shape(x, target_shape): if target_shape is None or x.shape == target_shape: return x @@ -1354,14 +1294,14 @@ def impl( @staticmethod def _parse_partition_specs(scaling_mode, q_layout, flatten_axis, mesh, arg_infos): - allowed_axes = _supported_grouped_quantize_axes(mesh) + allowed_axes = supported_grouped_partition_axes(mesh) original_x_spec = get_padded_spec(arg_infos[0]) - x_spec = _filter_spec_axes(original_x_spec, allowed_axes) + x_spec = filter_spec_axes(original_x_spec, allowed_axes) x_spec = _contiguous_flat_input_spec(x_spec, flatten_axis) _warn_if_axes_ignored("x", original_x_spec, x_spec) original_group_spec = get_padded_spec(arg_infos[2]) - group_spec = _filter_spec_axes(original_group_spec, allowed_axes) + group_spec = filter_spec_axes(original_group_spec, allowed_axes) if group_spec == (None,) and len(x_spec) > 0: group_spec = (x_spec[0],) _warn_if_axes_ignored("group_sizes", original_group_spec, group_spec) @@ -1409,7 +1349,7 @@ def partition( ) local_out_shapes = ( tuple( - _local_shape_from_spec(info.shape, spec, mesh) + local_shape_from_spec(info.shape, spec, mesh) for info, spec in zip(result_infos, out_specs) ) if result_infos diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 8dffb71196..86ddf72ff7 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -12,6 +12,7 @@ from contextlib import contextmanager from dataclasses import dataclass from typing import Callable, Optional +import math import warnings import jax @@ -233,6 +234,155 @@ def get_padded_spec(spec, ndim): return spec + (None,) * (ndim - len(spec)) +def spec_axes(spec): + """Return unique non-None mesh axes used by a PartitionSpec-like tuple.""" + axes = [] + for axis_spec in spec: + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + axes.append(axis) + return axes + + +def axis_spec_contains(axis_spec, axis): + """Return whether one dimension's axis spec contains a mesh axis.""" + if axis is None or axis_spec is None: + return False + if isinstance(axis_spec, tuple): + return axis in axis_spec + return axis_spec == axis + + +def spec_contains_axis(spec, axis): + """Return whether a PartitionSpec-like tuple contains a mesh axis.""" + return any(axis_spec_contains(axis_spec, axis) for axis_spec in spec) + + +def filter_axis_spec(axis_spec, allowed_axes): + """Keep only allowed axes in one dimension's axis spec.""" + if axis_spec is None: + return None + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axes = tuple(axis for axis in axis_tuple if axis in allowed_axes) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else axes + + +def filter_spec_axes(spec, allowed_axes): + """Keep only allowed axes in a PartitionSpec-like tuple.""" + return tuple(filter_axis_spec(axis_spec, allowed_axes) for axis_spec in spec) + + +def supported_grouped_partition_axes(mesh): + """Return mesh axes supported by grouped quantize/GEMM custom partitioning.""" + gsr = global_mesh_resource(validate=False) + return { + axis + for axis in (gsr.ep_resource, gsr.dp_resource, gsr.fsdp_resource) + if axis is not None and axis in mesh.axis_names + } + + +def strip_axis_from_axis_spec(axis_spec, axis): + """Remove one mesh axis from one dimension's axis spec.""" + if axis is None or axis_spec is None: + return axis_spec + if isinstance(axis_spec, tuple): + stripped = tuple(a for a in axis_spec if a != axis) + if len(stripped) == 0: + return None + return stripped[0] if len(stripped) == 1 else stripped + return None if axis_spec == axis else axis_spec + + +def strip_axis_from_spec(spec, axis): + """Remove one mesh axis from a PartitionSpec-like tuple.""" + return tuple(strip_axis_from_axis_spec(axis_spec, axis) for axis_spec in spec) + + +def merge_axis_specs(*axis_specs): + """Merge dimension axis specs while preserving first-seen axis order.""" + axes = [] + for axis_spec in axis_specs: + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + axes.append(axis) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else tuple(axes) + + +def common_spec_axis(spec_a, spec_b, allowed_axes=None): + """Return the first mesh axis that appears in both specs.""" + axes = [] + for spec in (spec_a, spec_b): + for axis in spec_axes(spec): + if axis not in axes: + axes.append(axis) + for axis in axes: + if allowed_axes is not None and axis not in allowed_axes: + continue + if spec_contains_axis(spec_a, axis) and spec_contains_axis(spec_b, axis): + return axis + return None + + +def axis_spec_size(axis_spec, mesh): + """Return the device count represented by one dimension's axis spec.""" + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axis_size = 1 + for axis in axis_tuple: + if axis is not None: + axis_size *= mesh.shape[axis] + return axis_size + + +def spec_size(spec, mesh): + """Return the total device count represented by a PartitionSpec-like tuple.""" + axis_size = 1 + for axis_spec in spec: + axis_size *= axis_spec_size(axis_spec, mesh) + return axis_size + + +def local_shape_from_spec(global_shape, spec, mesh): + """Derive a local shape from a global shape and PartitionSpec-like tuple.""" + local_shape = [] + for dim, axis_spec in zip(global_shape, spec): + local_shape.append(dim // axis_spec_size(axis_spec, mesh)) + return tuple(local_shape) + + +def local_2d_sizes_from_spec(shape, spec, axis_boundary, left_size, right_size, mesh): + """Derive local collapsed 2D dimensions from a global shape and sharding spec.""" + if len(shape) == len(spec) and len(shape) > 1: + local_shape = local_shape_from_spec(shape, spec, mesh) + return ( + math.prod(local_shape[:axis_boundary]), + math.prod(local_shape[axis_boundary:]), + ) + + size = spec_size(spec, mesh) + if size == 1: + return left_size, right_size + if left_size % size == 0: + return left_size // size, right_size + if right_size % size == 0: + return left_size, right_size // size + raise ValueError( + "Cannot derive local 2D sizes from sharding spec. " + f"shape={shape}, spec={spec}, axis_boundary={axis_boundary}, " + f"left_size={left_size}, right_size={right_size}, spec_size={size}" + ) + + def lax_paral_op( x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs ): From 6fd04b573446e27e0ac9d1d7666f24a423a971bf Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 2 Jun 2026 15:00:31 -0700 Subject: [PATCH 11/36] Rename distributed grouped GEMM tests Signed-off-by: Jeremy Berchtold --- qa/L1_jax_distributed_unittest/test.sh | 4 ++-- ..._gemm_partitioning.py => test_distributed_grouped_gemm.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename tests/jax/{test_grouped_gemm_partitioning.py => test_distributed_grouped_gemm.py} (100%) diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 031bb72995..9d481d4332 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -25,6 +25,8 @@ export XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_dense.xml $TE_PATH/tests/jax/test_distributed_dense.py || test_fail "test_distributed_dense.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_grouped_gemm.xml $TE_PATH/tests/jax/test_distributed_grouped_gemm.py || test_fail "test_distributed_grouped_gemm.py" + python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_helper.xml $TE_PATH/tests/jax/test_distributed_helper.py || test_fail "test_distributed_helper.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_layernorm.xml $TE_PATH/tests/jax/test_distributed_layernorm.py || test_fail "test_distributed_layernorm.py" @@ -37,8 +39,6 @@ XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_nccl_comm_splitting=false" python3 -m pyt python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_fused_attn.xml $TE_PATH/tests/jax/test_distributed_fused_attn.py || test_fail "test_distributed_fused_attn.py" -# TODO(Phuong): add this test back after it is verified - if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" exit 1 diff --git a/tests/jax/test_grouped_gemm_partitioning.py b/tests/jax/test_distributed_grouped_gemm.py similarity index 100% rename from tests/jax/test_grouped_gemm_partitioning.py rename to tests/jax/test_distributed_grouped_gemm.py From 1ebffa3ccb2bed9eb79a431d8a08ca902a207a05 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 22 May 2026 23:05:06 +0000 Subject: [PATCH 12/36] Expert Parallelism: common C API + NCCL EP v0.1 backend Signed-off-by: Phuong Nguyen --- .gitmodules | 4 + 3rdparty/nccl | 1 + qa/L1_cpp_distributed/test.sh | 3 + setup.py | 127 +++ tests/cpp_distributed/CMakeLists.txt | 91 +- tests/cpp_distributed/run_test_ep.sh | 137 +++ tests/cpp_distributed/test_ep_common.h | 308 ++++++ tests/cpp_distributed/test_ep_coverage.cu | 379 ++++++++ tests/cpp_distributed/test_ep_init.cu | 64 ++ tests/cpp_distributed/test_ep_pipeline.cu | 890 ++++++++++++++++++ transformer_engine/common/CMakeLists.txt | 90 ++ transformer_engine/common/ep/ep_api.cpp | 76 ++ transformer_engine/common/ep/ep_api_stub.cpp | 61 ++ transformer_engine/common/ep/ep_backend.cpp | 514 ++++++++++ transformer_engine/common/ep/ep_backend.h | 114 +++ .../include/transformer_engine/comm_window.h | 32 + .../common/include/transformer_engine/ep.h | 161 ++++ 17 files changed, 3050 insertions(+), 2 deletions(-) create mode 160000 3rdparty/nccl create mode 100755 tests/cpp_distributed/run_test_ep.sh create mode 100644 tests/cpp_distributed/test_ep_common.h create mode 100644 tests/cpp_distributed/test_ep_coverage.cu create mode 100644 tests/cpp_distributed/test_ep_init.cu create mode 100644 tests/cpp_distributed/test_ep_pipeline.cu create mode 100644 transformer_engine/common/ep/ep_api.cpp create mode 100644 transformer_engine/common/ep/ep_api_stub.cpp create mode 100644 transformer_engine/common/ep/ep_backend.cpp create mode 100644 transformer_engine/common/ep/ep_backend.h create mode 100644 transformer_engine/common/include/transformer_engine/comm_window.h create mode 100644 transformer_engine/common/include/transformer_engine/ep.h diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..e531c95507 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,7 @@ [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "3rdparty/nccl"] + path = 3rdparty/nccl + url = https://github.com/NVIDIA/nccl.git + branch = v2.30u1 diff --git a/3rdparty/nccl b/3rdparty/nccl new file mode 160000 index 0000000000..6a9bc953ac --- /dev/null +++ b/3rdparty/nccl @@ -0,0 +1 @@ +Subproject commit 6a9bc953ac1c4eef92d5adbe3092d4c2cb0a4c98 diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh index 8d767a4efb..7e5ce2cf0d 100755 --- a/qa/L1_cpp_distributed/test.sh +++ b/qa/L1_cpp_distributed/test.sh @@ -14,4 +14,7 @@ if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then cmake -GNinja -S. -Bbuild cmake --build build mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm + + # EP suites; runner self-skips on pre-Hopper GPUs. + bash ./run_test_ep.sh 4 ./build fi diff --git a/setup.py b/setup.py index ec277b6349..db360c8a29 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,34 @@ def setup_common_extension() -> CMakeExtension: cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr") cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}") + # NCCL EP: on by default; auto-disabled if no arch >= 90. + # Set NVTE_BUILD_WITH_NCCL_EP=0/1 to force off/on. + nccl_ep_env = os.getenv("NVTE_BUILD_WITH_NCCL_EP") + explicit_nccl_ep = nccl_ep_env is not None + build_with_nccl_ep = bool(int(nccl_ep_env)) if explicit_nccl_ep else True + + if build_with_nccl_ep: + arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()] + has_hopper_or_newer = any(t.lower() == "native" for t in arch_tokens) or any( + int(t.rstrip("af")) >= 90 for t in arch_tokens if t.rstrip("af").isdigit() + ) + if not has_hopper_or_newer: + if explicit_nccl_ep: + raise RuntimeError( + "NVTE_BUILD_WITH_NCCL_EP=1 requires at least one CUDA arch >= 90 in " + f"NVTE_CUDA_ARCHS (got '{archs}'). Add '90' or unset NVTE_BUILD_WITH_NCCL_EP." + ) + print( + "[NCCL EP] No CUDA arch >= 90 in NVTE_CUDA_ARCHS" + f" ('{archs}'); auto-disabling NCCL EP (nvte_ep_* will throw at runtime)." + ) + build_with_nccl_ep = False + + if build_with_nccl_ep: + build_nccl_ep_submodule() + else: + cmake_flags.append("-DNVTE_WITH_NCCL_EP=OFF") + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: @@ -128,6 +156,105 @@ def setup_requirements() -> Tuple[List[str], List[str]]: return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]] +def _discover_nccl_home() -> str: + """Resolve NCCL_HOME: honor env var, else probe well-known prefixes, else ldconfig.""" + env_home = os.environ.get("NCCL_HOME") + if env_home: + if (Path(env_home) / "include" / "nccl.h").exists(): + return env_home + print( + f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but " + f"'{env_home}/include/nccl.h' was not found; falling back to system probes." + ) + + for cand in ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr"): + p = Path(cand) + if (p / "include" / "nccl.h").exists() and any( + (p / "lib" / name).exists() or (p / "lib64" / name).exists() + for name in ("libnccl.so", "libnccl.so.2") + ): + return str(p) + + try: + out = subprocess.check_output(["ldconfig", "-p"], stderr=subprocess.DEVNULL).decode() + for line in out.splitlines(): + if "libnccl.so" in line and "=>" in line: + lib_path = Path(line.split("=>")[-1].strip()) + root = lib_path.parent.parent + if (root / "include" / "nccl.h").exists(): + return str(root) + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + raise RuntimeError( + "Could not locate NCCL core (nccl.h + libnccl.so). Set NCCL_HOME to the install prefix." + ) + + +def build_nccl_ep_submodule() -> str: + """Build libnccl_ep.so from the 3rdparty/nccl submodule. + + NCCL EP is on by default; the system NCCL core (libnccl.so) supplies the + headers and runtime symbols. Returns the submodule build directory. + """ + nccl_root = current_file_path / "3rdparty" / "nccl" + if not (nccl_root / "Makefile").exists(): + raise RuntimeError( + f"NCCL submodule not found at {nccl_root}. " + "Run `git submodule update --init --recursive`." + ) + + build_dir = nccl_root / "build" + nccl_ep_lib = build_dir / "lib" / "libnccl_ep.so" + + archs = cuda_archs() or "90" + arch_list = [] + for a in str(archs).split(";"): + a = a.strip().rstrip("af") + if a and a.isdigit() and int(a) >= 90: + arch_list.append(a) + if not arch_list: + arch_list = ["90"] + gencode = " ".join(f"-gencode=arch=compute_{a},code=sm_{a}" for a in arch_list) + + nproc = os.cpu_count() or 8 + env = os.environ.copy() + env["NVCC_GENCODE"] = gencode + # NCCL EP needs the core NCCL headers + libnccl.so; write NCCL EP build + # outputs to the submodule's local build/ tree. + nccl_home = _discover_nccl_home() + env["NCCL_HOME"] = nccl_home + env["NCCL_EP_BUILDDIR"] = str(build_dir) + + if not nccl_ep_lib.exists(): + print(f"[NCCL EP] Building libnccl_ep.so (gencode='{gencode}')") + subprocess.check_call( + ["make", "-j", str(nproc), "-C", "contrib/nccl_ep", "lib"], + cwd=str(nccl_root), + env=env, + ) + + # TE's CMake expects nccl.h under 3rdparty/nccl/build/include/ for its + # version check. Mirror the top-level host headers from the system NCCL + # install — DON'T mirror nccl_device/ because the submodule ships its own + # newer copy at src/include/nccl_device/ with device-side templates that + # conflict with older system versions, and the JIT include path picks the + # submodule's. + nccl_include = build_dir / "include" + nccl_include.mkdir(parents=True, exist_ok=True) + for cand in (Path(nccl_home) / "include", Path("/usr/include")): + p = Path(cand) + if (p / "nccl.h").exists(): + for name in ("nccl.h", "nccl_net.h", "nccl_tuner.h"): + src = p / name + dst = nccl_include / name + if src.exists() and not dst.exists(): + dst.symlink_to(src) + break + + return str(build_dir) + + def git_check_submodules() -> None: """ Attempt to checkout git submodules automatically during setup. diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt index 0d7258a81d..3870f57911 100644 --- a/tests/cpp_distributed/CMakeLists.txt +++ b/tests/cpp_distributed/CMakeLists.txt @@ -30,7 +30,7 @@ if(NOT DEFINED TE_LIB_PATH) get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) endif() -find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED NO_CMAKE_SYSTEM_PATH) message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) @@ -46,12 +46,99 @@ add_executable(test_comm_gemm find_package(OpenMP REQUIRED) find_package(MPI REQUIRED) + +# ── NCCL library ────────────────────────────────────────────────────────────── +# Search order: NCCL_HOME env → 3rdparty/nccl submodule build → system paths. +set(NCCL_SUBMODULE_BUILD "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build") find_library(NCCL_LIB NAMES nccl libnccl - PATH_SUFFIXES lib + HINTS $ENV{NCCL_HOME}/lib ${NCCL_SUBMODULE_BUILD}/lib + PATH_SUFFIXES lib lib64 REQUIRED) + +# NCCL headers: prefer submodule build output (has the handle_init API), +# then submodule src, then system (CUDA toolkit). +set(NCCL_SUBMODULE_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build/include") +set(NCCL_SUBMODULE_SRC_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/src/include") +if(EXISTS "${NCCL_SUBMODULE_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIR "${NCCL_SUBMODULE_INCLUDE}") +elseif(EXISTS "${NCCL_SUBMODULE_SRC_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIR "${NCCL_SUBMODULE_SRC_INCLUDE}") +elseif(DEFINED ENV{NCCL_HOME}) + set(NCCL_INCLUDE_DIR "$ENV{NCCL_HOME}/include") +endif() target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) include(GoogleTest) gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) + +# ── EP distributed tests (HT mode) ───────────────────────────────────────── +# No MPI dependency — processes are spawned by run_test_ep.sh with +# --rank / --nranks flags. ncclUniqueId exchange uses a +# shared temp file (see test_ep_common.h for details). +# Headers + libs come from the in-tree 3rdparty/nccl submodule build. +set(NCCL_EP_SUBMODULE_ROOT + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") +find_library(NCCL_EP_LIB + NAMES nccl_ep libnccl_ep + HINTS ${NCCL_EP_SUBMODULE_ROOT}/build/lib + NO_DEFAULT_PATH + REQUIRED) + +set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/contrib/nccl_ep/include") +if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h") + message(FATAL_ERROR + "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. " + "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl.") +endif() +message(STATUS "EP test: NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}") + +# Collect NCCL include dirs shared by all EP test targets (nccl_ep.h + nccl.h). +set(EP_TEST_NCCL_INCLUDES ${NCCL_EP_INCLUDE_DIR}) +if(DEFINED NCCL_INCLUDE_DIR) + list(APPEND EP_TEST_NCCL_INCLUDES ${NCCL_INCLUDE_DIR}) + message(STATUS "EP test: NCCL headers: ${NCCL_INCLUDE_DIR}") +endif() + +set(EP_TEST_COMMON_INCLUDES + ${EP_TEST_NCCL_INCLUDES} + ../../transformer_engine/common/include + ../../transformer_engine/common + ${CMAKE_CURRENT_SOURCE_DIR}) + +set(EP_TEST_COMMON_LIBS + CUDA::cuda_driver + CUDA::cudart + CUDA::nvrtc + GTest::gtest + ${TE_LIB} + ${NCCL_LIB} + ${NCCL_EP_LIB}) + +# nvrtc symbols are referenced from libtransformer_engine.so but not in its +# DT_NEEDED list (loaded via dlopen in Python). For cpp tests we link nvrtc +# explicitly with --no-as-needed so the linker keeps the dependency. +set(EP_TEST_LINK_OPTS "LINKER:--no-as-needed") + +# ── EP init tests (InitPath, HandleMemSizeQuery) ───────────────────────────── +add_executable(test_ep_init test_ep_init.cu) +target_include_directories(test_ep_init PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_init PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_init PUBLIC ${EP_TEST_LINK_OPTS}) + +# ── EP pipeline tests (dispatch, combine, bwd, integrated) ─────────────────── +add_executable(test_ep_pipeline test_ep_pipeline.cu) +target_include_directories(test_ep_pipeline PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_pipeline PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_pipeline PUBLIC ${EP_TEST_LINK_OPTS}) + +# ── EP coverage tests (multi-handle, top_k=1, empty experts, negatives, threading) ── +add_executable(test_ep_coverage test_ep_coverage.cu) +target_include_directories(test_ep_coverage PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep_coverage PUBLIC ${EP_TEST_COMMON_LIBS}) +target_link_options(test_ep_coverage PUBLIC ${EP_TEST_LINK_OPTS}) + +# Do NOT use gtest_discover_tests — these binaries require multi-process +# launch via run_test_ep.sh, not direct single-process execution. +message(STATUS "EP distributed tests enabled: ${NCCL_EP_LIB}") diff --git a/tests/cpp_distributed/run_test_ep.sh b/tests/cpp_distributed/run_test_ep.sh new file mode 100755 index 0000000000..017d3f807b --- /dev/null +++ b/tests/cpp_distributed/run_test_ep.sh @@ -0,0 +1,137 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Run TE EP distributed unit tests across multiple GPUs. +# +# Spawns one background bash process per GPU (no MPI dependency), matching the +# JAX multi-process launcher style. ncclUniqueId is exchanged via a shared +# temp file (see test_ep_common.h). Each rank builds its own ncclComm_t and +# passes it to nvte_ep_initialize. +# +# Usage: +# bash run_test_ep.sh [num_gpus] [build_dir] +# +# Defaults: +# num_gpus = number of GPUs visible to nvidia-smi +# build_dir = /build +# +# Environment variables: +# GTEST_FILTER — forwarded to all processes (e.g., "EPDispatchTest.*") +# TEST_TIMEOUT_S — per-process timeout in seconds (default: 180) + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${2:-${SCRIPT_DIR}/build}" +NUM_GPUS="${1:-$(nvidia-smi -L 2>/dev/null | wc -l)}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}" + +# Skip cleanly on pre-Hopper: NCCL EP requires SM>=90. +MIN_SM=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ + | awk -F. 'NR==1 || ($1*10+$2) 0 && MIN_SM < 90 )); then + echo "NCCL EP requires SM>=90 (lowest visible GPU is SM${MIN_SM}); SKIPPING." + exit 0 +fi + +GTEST_ARGS="${GTEST_FILTER:+--gtest_filter=${GTEST_FILTER}}" +OVERALL_FAIL=0 + +# --------------------------------------------------------------------------- +# run_suite BINARY SUITE_NAME MIN_GPUS +# --------------------------------------------------------------------------- +run_suite() { + local BINARY="$1" + local SUITE_NAME="$2" + local MIN_GPUS="${3:-2}" + + local TEST_BIN="${BUILD_DIR}/${BINARY}" + + if [[ ! -x "${TEST_BIN}" ]]; then + echo "ERROR: binary not found: ${TEST_BIN}" + echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make" + OVERALL_FAIL=1 + return + fi + + if (( NUM_GPUS < MIN_GPUS )); then + echo "${SUITE_NAME}: requires ${MIN_GPUS} GPUs, found ${NUM_GPUS}. Skipping." + return + fi + + local TMPDIR_L="${TMPDIR:-/tmp}" + local UID_FILE="${TMPDIR_L}/te_ep_uid_${BINARY}_$$" + rm -f "${UID_FILE}" + + local LOG_DIR + LOG_DIR=$(mktemp -d) + local FAIL=0 + + echo "=== ${SUITE_NAME} ===" + echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}" + echo + + # Spawn one background process per GPU. ncclUniqueId is exchanged via the + # shared UID_FILE. Each process is wrapped in `timeout` to detect hangs early. + local PIDS=() + for i in $(seq 0 $((NUM_GPUS - 1))); do + timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + "${TEST_BIN}" \ + --rank="${i}" \ + --nranks="${NUM_GPUS}" \ + --uid-file="${UID_FILE}" \ + ${GTEST_ARGS} \ + > "${LOG_DIR}/rank_${i}.log" 2>&1 & + PIDS+=($!) + done + for i in $(seq 0 $((NUM_GPUS - 1))); do + if ! wait "${PIDS[$i]}"; then + local rc=$? + FAIL=1 + if [[ $rc -eq 137 || $rc -eq 124 ]]; then + echo " rank ${i}: TIMEOUT after ${TEST_TIMEOUT_S}s (rc=${rc})" + fi + fi + done + + echo "--- Rank 0 output ---" + cat "${LOG_DIR}/rank_0.log" + + if (( FAIL )); then + for i in $(seq 1 $((NUM_GPUS - 1))); do + echo "--- Rank ${i} output ---" + cat "${LOG_DIR}/rank_${i}.log" + done + echo "=== ${SUITE_NAME}: FAILED ===" + OVERALL_FAIL=1 + else + echo "=== ${SUITE_NAME}: ALL PASSED ===" + fi + + rm -rf "${LOG_DIR}" + rm -f "${UID_FILE}" +} + +# --------------------------------------------------------------------------- +# Cleanup on abort +# --------------------------------------------------------------------------- +cleanup() { rm -f "${TMPDIR:-/tmp}"/te_ep_uid_*_"$$" 2>/dev/null || true; } +trap cleanup EXIT INT TERM + +# --------------------------------------------------------------------------- +# Run all suites +# --------------------------------------------------------------------------- +run_suite "test_ep_init" "EP Init Tests" 2 +run_suite "test_ep_pipeline" "EP Pipeline Tests" 2 +run_suite "test_ep_coverage" "EP Coverage Tests" 2 + +echo +if (( OVERALL_FAIL )); then + echo "=== SOME SUITES FAILED ===" +else + echo "=== ALL SUITES PASSED ===" +fi + +exit "${OVERALL_FAIL}" diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h new file mode 100644 index 0000000000..77baa92b0c --- /dev/null +++ b/tests/cpp_distributed/test_ep_common.h @@ -0,0 +1,308 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * Shared TE EP test infrastructure. Include once per TU; ep_bootstrap() in + * each test binary's main() populates process-level globals. + * Defaults: 4 experts/rank, hidden_dim=256, max_tokens_per_rank=64. + */ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// ── Error-checking macros ───────────────────────────────────────────────────── + +#define CHECK_NCCL(expr) \ + do { \ + ncclResult_t _err = (expr); \ + if (_err != ncclSuccess) \ + FAIL() << "NCCL error " << _err << ": " << ncclGetErrorString(_err); \ + } while (false) + +#define CHECK_CUDA(expr) \ + do { \ + cudaError_t _err = (expr); \ + if (_err != cudaSuccess) \ + FAIL() << "CUDA error " << _err << ": " << cudaGetErrorString(_err); \ + } while (false) + +#define ASSERT_CUDA_OK(expr) \ + do { \ + cudaError_t _err = (expr); \ + if (_err != cudaSuccess) { \ + fprintf(stderr, "CUDA error %d: %s\n", _err, cudaGetErrorString(_err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (false) + +#define ASSERT_NCCL_OK(expr) \ + do { \ + ncclResult_t _err = (expr); \ + if (_err != ncclSuccess) { \ + fprintf(stderr, "NCCL error %d: %s\n", _err, ncclGetErrorString(_err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (false) + +// ── Process-level state ─────────────────────────────────────────────────────── + +static int g_process_id = -1; +static int g_num_processes = -1; +static std::string g_uid_file; + +static int g_sm_major = -1; // set by ep_bootstrap; -1 until then +static int g_ep_size = -1; +static int g_num_experts = -1; +static int g_hidden_dim = 256; +static int g_max_tokens_per_rank = 64; +static bool g_ep_initialized = false; +static ncclComm_t g_ep_comm = nullptr; // owned by harness, destroyed in ep_teardown + +// ── TensorHandle RAII wrapper ───────────────────────────────────────────────── + +// View over a caller-owned device buffer; owns NVTETensor metadata only. Move-only. +struct TensorHandle { + NVTETensor tensor = nullptr; + void* dev_ptr = nullptr; + + ~TensorHandle() { + if (tensor) nvte_destroy_tensor(tensor); + } + + TensorHandle() = default; + TensorHandle(const TensorHandle&) = delete; + TensorHandle& operator=(const TensorHandle&) = delete; + + TensorHandle(TensorHandle&& o) noexcept : tensor(o.tensor), dev_ptr(o.dev_ptr) { + o.tensor = nullptr; o.dev_ptr = nullptr; + } + TensorHandle& operator=(TensorHandle&& o) noexcept { + if (this != &o) { + if (tensor) nvte_destroy_tensor(tensor); + tensor = o.tensor; dev_ptr = o.dev_ptr; + o.tensor = nullptr; o.dev_ptr = nullptr; + } + return *this; + } +}; + +static TensorHandle make_nvte_tensor(void* dev_ptr, + const std::vector& shape, + NVTEDType dtype) { + TensorHandle h; + h.dev_ptr = dev_ptr; + h.tensor = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING); + + NVTEShape s; + s.ndim = shape.size(); + for (size_t i = 0; i < shape.size(); ++i) s.data[i] = shape[i]; + + NVTEBasicTensor bt; + bt.data_ptr = dev_ptr; + bt.dtype = dtype; + bt.shape = s; + nvte_set_tensor_param_v2(h.tensor, kNVTERowwiseData, &bt, sizeof(bt)); + + return h; +} + +// RAII owner for a cudaMalloc'd device buffer; frees on destruction. +template +struct DevBuf { + T* ptr = nullptr; + size_t count = 0; + + DevBuf() = default; + explicit DevBuf(size_t n) { alloc(n); } + ~DevBuf() { reset(); } + + DevBuf(const DevBuf&) = delete; + DevBuf& operator=(const DevBuf&) = delete; + DevBuf(DevBuf&& o) noexcept : ptr(o.ptr), count(o.count) { o.ptr = nullptr; o.count = 0; } + DevBuf& operator=(DevBuf&& o) noexcept { + if (this != &o) { reset(); ptr = o.ptr; count = o.count; o.ptr = nullptr; o.count = 0; } + return *this; + } + + void alloc(size_t n) { + reset(); + count = n; + if (n > 0) { + cudaError_t e = cudaMalloc(&ptr, n * sizeof(T)); + if (e != cudaSuccess) { + fprintf(stderr, "DevBuf cudaMalloc(%zu) failed: %s\n", n * sizeof(T), + cudaGetErrorString(e)); + ptr = nullptr; + count = 0; + } + } + } + + void reset() { + if (ptr) { cudaFree(ptr); ptr = nullptr; } + count = 0; + } + + T* get() const { return ptr; } + size_t bytes() const { return count * sizeof(T); } +}; + +// ── Shared routing helper ───────────────────────────────────────────────────── + +// Balanced round-robin routing: token t on rank r maps top_k experts to +// (r * num_local_experts + t * top_k + k) % num_experts +static inline std::vector routing_balanced( + int rank, int num_tokens, int top_k, int num_experts, int num_local_experts) { + std::vector idx(num_tokens * top_k); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) + idx[t * top_k + k] = (rank * num_local_experts + t * top_k + k) % num_experts; + return idx; +} + +// ── File-based ncclUniqueId exchange ───────────────────────────────────────── + +static void exchange_unique_id(ncclUniqueId* uid) { + const size_t sz = sizeof(ncclUniqueId); + + if (g_process_id == 0) { + ASSERT_NCCL_OK(ncclGetUniqueId(uid)); + FILE* f = fopen(g_uid_file.c_str(), "wb"); + if (!f) { fprintf(stderr, "Cannot open uid file: %s\n", g_uid_file.c_str()); exit(EXIT_FAILURE); } + fwrite(uid, 1, sz, f); + fclose(f); + } else { + auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(60); + while (true) { + FILE* f = fopen(g_uid_file.c_str(), "rb"); + if (f) { + fseek(f, 0, SEEK_END); + if (static_cast(ftell(f)) >= sz) { + fseek(f, 0, SEEK_SET); + size_t n = fread(uid, 1, sz, f); + fclose(f); + if (n == sz) break; + } else { + fclose(f); + } + } + if (std::chrono::steady_clock::now() > deadline) { + fprintf(stderr, "Process %d: timed out waiting for uid file\n", g_process_id); + exit(EXIT_FAILURE); + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + } +} + +// ── CLI parsing ─────────────────────────────────────────────────────────────── + +static void ep_parse_args(int argc, char* argv[]) { + for (int i = 1; i < argc; ++i) { + std::string a(argv[i]); + if (a.rfind("--process-id=", 0) == 0) g_process_id = std::stoi(a.substr(13)); + else if (a.rfind("--rank=", 0) == 0) g_process_id = std::stoi(a.substr(7)); + else if (a.rfind("--num-processes=",0)==0) g_num_processes = std::stoi(a.substr(16)); + else if (a.rfind("--nranks=", 0) == 0) g_num_processes = std::stoi(a.substr(9)); + else if (a.rfind("--uid-file=", 0) == 0) g_uid_file = a.substr(11); + } + + if (g_process_id < 0 || g_num_processes <= 0) { + fprintf(stderr, + "Usage: %s --rank=N --nranks=N [--uid-file=path] [gtest flags]\n" + " Aliases: --process-id=N, --num-processes=N\n", + argc > 0 ? argv[0] : "test_ep"); + exit(EXIT_FAILURE); + } + + if (g_uid_file.empty()) { + const char* t = getenv("TMPDIR"); if (!t) t = "/tmp"; + g_uid_file = std::string(t) + "/te_ep_uid_" + std::to_string(g_process_id); + } +} + +// ── Bootstrap / teardown ────────────────────────────────────────────────────── + +// Returns false if the binary should exit without running tests (wrong SM, etc.). +static bool ep_bootstrap(int argc, char* argv[]) { + ep_parse_args(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + + int device_count; + cudaGetDeviceCount(&device_count); + cudaSetDevice(g_process_id % device_count); + + int device, major; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + g_sm_major = major; + if (major < 9) { + if (g_process_id == 0) + printf("SKIP: EP requires SM_90+ (device is SM_%d0)\n", major); + return false; + } + if (g_num_processes < 2) { + if (g_process_id == 0) + printf("SKIP: at least 2 processes required\n"); + return false; + } + + g_ep_size = g_num_processes; + g_num_experts = g_ep_size * 4; // 4 experts per rank + + ncclUniqueId uid{}; + exchange_unique_id(&uid); + + NVTEEpGroupConfig group_config{}; + group_config.ep_size = g_ep_size; + group_config.num_experts = g_num_experts; + group_config.max_tokens_per_rank = g_max_tokens_per_rank; + // Worst-case for top_k fan-out: ep_size * max_tokens_per_rank * 2. + group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; + group_config.hidden_dim = g_hidden_dim; + + ASSERT_NCCL_OK(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); + nvte_ep_initialize(static_cast(g_ep_comm), group_config); + + if (g_process_id == 0) { + printf("EP initialized: ep_size=%d num_experts=%d " + "hidden_dim=%d max_tokens_per_rank=%d\n", + g_ep_size, g_num_experts, g_hidden_dim, g_max_tokens_per_rank); + } + + g_ep_initialized = true; + return true; +} + +// Tear down in dependency order: backend's ep_group reads from ep_comm, +// so destroy the group first, then the comm. +static void ep_teardown() { + if (g_ep_initialized) { + nvte_ep_shutdown(); + if (g_ep_comm != nullptr) { + ncclCommDestroy(g_ep_comm); + g_ep_comm = nullptr; + } + g_ep_initialized = false; + } + if (g_process_id == 0) remove(g_uid_file.c_str()); +} diff --git a/tests/cpp_distributed/test_ep_coverage.cu b/tests/cpp_distributed/test_ep_coverage.cu new file mode 100644 index 0000000000..ef7941905d --- /dev/null +++ b/tests/cpp_distributed/test_ep_coverage.cu @@ -0,0 +1,379 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * EP C-API coverage tests (paths not exercised by the pipeline suite). + * + * MultiHandleAllocTest — distinct handle ids; each works end-to-end. + * TopK1Test — top_k=1 dispatch/combine/bwd round-trip. + * EmptyExpertsTest — alignment ∈ {0, 2, 8, 16} with experts receiving 0 tokens. + * NegativeTests — alignment mismatch and null handle_mem must throw. + */ + +#include "test_ep_common.h" + +#include +#include + +// top1 -> expert 0, top2 -> expert 2; leaves local-expert 1 empty between two +// full experts. Requires top_k >= 2 and num_experts >= 3. +static std::vector routing_skip_middle(int num_tokens, int top_k) { + std::vector idx(num_tokens * top_k); + for (int t = 0; t < num_tokens; ++t) { + idx[t * top_k + 0] = 0; + if (top_k >= 2) idx[t * top_k + 1] = 2; + for (int k = 2; k < top_k; ++k) idx[t * top_k + k] = 2 + k; // distinct stragglers + } + return idx; +} + +static std::vector tokens_constant(int num_tokens, int hidden_dim, float val) { + std::vector v(num_tokens * hidden_dim); + nv_bfloat16 b = __float2bfloat16(val); + std::fill(v.begin(), v.end(), b); + return v; +} + +namespace { + +class EpCoverageBase : public ::testing::Test { + protected: + int ep_size_, num_experts_, num_local_experts_, hidden_dim_; + int max_tokens_per_rank_; + + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2); + ASSERT_TRUE(g_ep_initialized); + ep_size_ = g_ep_size; + num_experts_ = g_num_experts; + num_local_experts_ = num_experts_ / ep_size_; + hidden_dim_ = g_hidden_dim; + max_tokens_per_rank_ = g_max_tokens_per_rank; + } + + // Helper: allocate buffers + tensor views for a single dispatch+combine. + struct Bundle { + DevBuf topk_idx; + DevBuf topk_weights; + DevBuf tokens; + DevBuf token_counts; + DevBuf handle_mem; + DevBuf recv_tokens; + DevBuf recv_topk_weights; + DevBuf result; + uint64_t handle_id = 0; + size_t handle_mem_size = 0; + size_t recv_capacity = 0; + }; + + Bundle make_bundle(int num_tokens, int top_k, int num_local_experts, + size_t alignment) { + Bundle b; + b.recv_capacity = static_cast(ep_size_) * max_tokens_per_rank_ * 2; + b.topk_idx.alloc(num_tokens * top_k); + b.topk_weights.alloc(num_tokens * top_k); + b.tokens.alloc(num_tokens * hidden_dim_); + b.token_counts.alloc(num_local_experts); + b.recv_tokens.alloc(b.recv_capacity * hidden_dim_); + b.recv_topk_weights.alloc(b.recv_capacity); + b.result.alloc(num_tokens * hidden_dim_); + NVTEEpLayerConfig cfg{num_local_experts, top_k, alignment}; + b.handle_id = nvte_ep_register_layer(cfg, &b.handle_mem_size); + b.handle_mem.alloc(b.handle_mem_size); + return b; + } +}; + +} // namespace + +// ============================================================================= +// MultiHandleAllocTest: ids are distinct and each is independently usable. +// ============================================================================= + +class MultiHandleAllocTest : public EpCoverageBase {}; + +TEST_F(MultiHandleAllocTest, IdsAreDistinct) { + NVTEEpLayerConfig cfg{num_local_experts_, /*top_k=*/2, /*alignment=*/0}; + const int kN = 8; + std::vector ids(kN); + for (int i = 0; i < kN; ++i) { + size_t sz = 0; + ids[i] = nvte_ep_register_layer(cfg, &sz); + } + for (int i = 0; i < kN; ++i) { + EXPECT_NE(ids[i], 0u) << "handle_id 0 is reserved as \"no id\""; + for (int j = i + 1; j < kN; ++j) + EXPECT_NE(ids[i], ids[j]) << "duplicate id " << ids[i] << " at indices " << i << ", " << j; + } +} + +TEST_F(MultiHandleAllocTest, TwoHandlesCoexist) { + const int num_tokens = 16, top_k = 2; + Bundle a = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); + for (Bundle* x : {&a, &b}) { + CHECK_CUDA(cudaMemcpy(x->topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(x->topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(x->tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + } + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NE(a.handle_id, b.handle_id); + + auto run_one = [&](Bundle& x) { + auto topk_idx = make_nvte_tensor(x.topk_idx.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights = make_nvte_tensor(x.topk_weights.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts = make_nvte_tensor(x.token_counts.get(), {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem = make_nvte_tensor(x.handle_mem.get(), {x.handle_mem_size}, kNVTEByte); + auto tokens = make_nvte_tensor(x.tokens.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens = make_nvte_tensor(x.recv_tokens.get(), {x.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w = make_nvte_tensor(x.recv_topk_weights.get(), {x.recv_capacity}, kNVTEFloat32); + auto result = make_nvte_tensor(x.result.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + NVTEEpHandle h{x.handle_id, handle_mem.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx.tensor, token_counts.tensor, + /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx.tensor, tokens.tensor, + NVTECommWindow{}, topk_weights.tensor, NVTECommWindow{}, + recv_tokens.tensor, NVTECommWindow{}, recv_w.tensor, + NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens.tensor, NVTECommWindow{}, + result.tensor, stream)); + }; + run_one(a); + run_one(b); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // Both round-trips must produce result == top_k * 0.5 = 1.0. + for (Bundle* x : {&a, &b}) { + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), x->result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), + static_cast(top_k) * 0.5f, 1e-2f); + } + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// TopK1Test: top_k=1 dispatch/combine round-trip, including dispatch_bwd. +// ============================================================================= + +class TopK1Test : public EpCoverageBase {}; + +TEST_F(TopK1Test, RoundTrip) { + const int num_tokens = 16, top_k = 1; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f); // top_k=1: weight is unity + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.25f); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + auto tokens_t = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + auto result_t = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, + tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, + NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, + recv_w_t.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, + NVTECommWindow{}, result_t.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // top_k=1: combine is unweighted gather, so result[t] == tokens[t]. + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), 0.25f, 1e-2f) + << "tok " << t << " hidden " << p; + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EmptyExpertsTest: alignment ∈ {0, 2, 8, 16}, only local-expert 0 receives +// tokens. Round-trip must produce result == top_k * tokens regardless of the +// per-expert padding choice. +// ============================================================================= + +class EmptyExpertsTest : public EpCoverageBase, + public ::testing::WithParamInterface {}; + +TEST_P(EmptyExpertsTest, RoundTripCorrect) { + // routing_skip_middle needs experts {0, 2, ...}; smallest viable num_experts is 3. + ASSERT_GE(num_experts_, 3); + const size_t alignment = GetParam(); + const int num_tokens = 16, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, alignment); + + // top1 -> expert 0, top2 -> expert 2; rank 0's local-expert 1 receives 0 + // tokens between two non-empty experts. + std::vector h_idx = routing_skip_middle(num_tokens, top_k); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.3f); + + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + auto tokens_t = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + auto result_t = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + alignment, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, + tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, + NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, + recv_w_t.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, + NVTECommWindow{}, result_t.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // Identity expert + uniform weights: result[t] == top_k * tokens[t]. + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float expected = static_cast(top_k) * 0.3f; + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), expected, 1e-2f) + << "alignment=" << alignment << " tok=" << t << " hidden=" << p; + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +INSTANTIATE_TEST_SUITE_P(Alignments, EmptyExpertsTest, + ::testing::Values(0, 2, 8, 16)); + +// ============================================================================= +// NegativeTests: prepare/dispatch must surface bad inputs as exceptions. +// ============================================================================= + +class NegativeTests : public EpCoverageBase {}; + +TEST_F(NegativeTests, AlignmentMismatchThrows) { + const int num_tokens = 8, top_k = 2; + // Allocate handle for alignment=0, then call prepare with alignment=16. + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/16, stream), + std::exception); + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +TEST_F(NegativeTests, NullHandleMemThrows) { + const int num_tokens = 8, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + // Construct a tensor view backed by a null device pointer. + auto null_hm_t = make_nvte_tensor(nullptr, {b.handle_mem_size}, kNVTEByte); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + NVTEEpHandle h{b.handle_id, null_hm_t.tensor}; + EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream), + std::exception); + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/tests/cpp_distributed/test_ep_init.cu b/tests/cpp_distributed/test_ep_init.cu new file mode 100644 index 0000000000..08744dfee5 --- /dev/null +++ b/tests/cpp_distributed/test_ep_init.cu @@ -0,0 +1,64 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * Unit tests for EP initialization paths. + * + * Tests: + * EPInitTest/InitPath — backend is live after init, handle_mem_size > 0 + * EPInitTest/NumLocalExperts — handle_mem_size is consistent across num_local_experts values + * + * Run via run_test_ep.sh (both uid and comm init paths are tested by the script). + */ + +#include "test_ep_common.h" + +// ── Fixture ─────────────────────────────────────────────────────────────────── + +class EPInitTest : public ::testing::Test { + protected: + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2) << "EP tests require at least 2 processes"; + ASSERT_TRUE(g_ep_initialized) << "EP not initialized"; + } +}; + +// ── Tests ───────────────────────────────────────────────────────────────────── + +TEST_F(EPInitTest, InitPath) { + int nle = g_num_experts / g_ep_size; + NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; + size_t sz = 0; + (void)nvte_ep_register_layer(cfg, &sz); + ASSERT_GT(sz, 0u) << "handle_mem_size must be > 0 after init"; + + if (g_process_id == 0) { + printf(" handle_mem : %zu bytes\n", sz); + } +} + +TEST_F(EPInitTest, NumLocalExperts) { + // handle_mem_size should be > 0 for any valid num_local_experts value. + for (int nle : {1, g_num_experts / g_ep_size}) { + NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; + size_t sz = 0; + (void)nvte_ep_register_layer(cfg, &sz); + ASSERT_GT(sz, 0u) << "num_local_experts=" << nle; + if (g_process_id == 0) + printf(" nle=%-3d handle_mem_size=%zu bytes\n", nle, sz); + } +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/tests/cpp_distributed/test_ep_pipeline.cu b/tests/cpp_distributed/test_ep_pipeline.cu new file mode 100644 index 0000000000..41f83a6d11 --- /dev/null +++ b/tests/cpp_distributed/test_ep_pipeline.cu @@ -0,0 +1,890 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * EP pipeline tests: smallest-scope first. + * + * EPDispatchTest/PrepareAndDispatch — exact recv values + per-expert counts + * EPCombineTest/Combine — round-trip: out == top_k * tokens + * EPCombineBwdTest/CombineBwdCheck — exact grad_expert values + * EPDispatchBwdTest/DispatchBwdCheck — exact grad_tokens + * EPDispatchBwdGradWeightsTest/RoundTrip — exact per-(t, k) grad_topk_weights + * EPPipelineTest/FullForwardBackward — fwd + bwd NaN/Inf check + * + * Routing: token t on rank r → expert (r * num_local_experts + t * top_k + k) % num_experts + * Token values: rank r, token t → all hidden dims = (r+1)*0.01 + t*0.001 + * + * Closed-form expected values: + * dispatch recv: multiset of source-token values routed to this rank's experts + * combine: result[t] == top_k * tokens[t] + * combine_bwd: grad_expert[slot] == d_result[t] (no weighting) + * dispatch_bwd: grad_tokens[t] == top_k * d_result[t] + */ + +#include "test_ep_common.h" + +#include +#include +#include +#include + +// ── Deterministic routing helpers ───────────────────────────────────────────── + +// Token value for (rank, t): (rank * num_tokens + t + 1) / 256. Step 1/256 is +// bf16-exact and unique across (rank, t) when rank * num_tokens + t < 256. +static inline float token_value(int rank, int t, int num_tokens) { + return static_cast(rank * num_tokens + t + 1) * (1.0f / 256.0f); +} + +static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) { + std::vector v(num_tokens * hidden_dim); + for (int t = 0; t < num_tokens; ++t) { + nv_bfloat16 val = __float2bfloat16(token_value(rank, t, num_tokens)); + for (int h = 0; h < hidden_dim; ++h) + v[t * hidden_dim + h] = val; + } + return v; +} + +static std::vector expected_token_counts( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector cnt(num_local_experts, 0); + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) ++cnt[e - base]; + } + } + return cnt; +} + +static std::vector expected_recv_values_sorted( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector vals; + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) { + float raw = token_value(src, t, num_tokens); + vals.push_back(__bfloat162float(__float2bfloat16(raw))); + } + } + } + std::sort(vals.begin(), vals.end()); + return vals; +} + +// BF16 has 7 mantissa bits; relative ULP ≈ 2^-7. Use 4× headroom for +// accumulation noise inside dispatch/combine. +static float bf16_tol(float magnitude) { + return 4.f * std::ldexp(std::fabs(magnitude) + 1e-3f, -7); +} + +static bool check_no_nan_inf(const nv_bfloat16* dev, int count, const char* name) { + std::vector h(count); + cudaMemcpy(h.data(), dev, count * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost); + for (int i = 0; i < count; ++i) { + float v = __bfloat162float(h[i]); + if (std::isnan(v) || std::isinf(v)) { + fprintf(stderr, "Rank %d: %s in %s[%d]\n", + g_process_id, std::isnan(v) ? "NaN" : "Inf", name, i); + return false; + } + } + return true; +} + +// ── Forward buffer set with RAII ────────────────────────────────────────────── + +struct EPBuffers { + // Forward + DevBuf topk_idx; + DevBuf topk_weights; + DevBuf tokens; + DevBuf token_counts; + DevBuf handle_mem; + DevBuf recv_tokens; + DevBuf recv_topk_weights; + DevBuf result; + // Backward + DevBuf grad_result; + DevBuf grad_expert; + DevBuf grad_tokens; + DevBuf g_recv_topk_weights; + DevBuf grad_topk_weights; + + uint64_t handle_id = 0; + size_t handle_mem_size = 0; + size_t recv_capacity = 0; + int top_k_ = 0; + + void alloc(int num_tokens, int top_k, int hidden_dim, int num_local_experts, + int ep_size, int max_tokens_per_rank, size_t alignment = 0) { + top_k_ = top_k; + recv_capacity = static_cast(ep_size) * max_tokens_per_rank * 2; + + topk_idx.alloc(num_tokens * top_k); + topk_weights.alloc(num_tokens * top_k); + tokens.alloc(num_tokens * hidden_dim); + token_counts.alloc(num_local_experts); + recv_tokens.alloc(recv_capacity * hidden_dim); + recv_topk_weights.alloc(recv_capacity); + result.alloc(num_tokens * hidden_dim); + + NVTEEpLayerConfig cfg{num_local_experts, top_k, alignment}; + handle_id = nvte_ep_register_layer(cfg, &handle_mem_size); + handle_mem.alloc(handle_mem_size); + + grad_result.alloc(num_tokens * hidden_dim); + grad_expert.alloc(recv_capacity * hidden_dim); + grad_tokens.alloc(num_tokens * hidden_dim); + g_recv_topk_weights.alloc(recv_capacity); + grad_topk_weights.alloc(num_tokens * top_k); + } +}; + +// Bundled NVTETensor views over an EPBuffers — one place to update the shape +// conventions when the C-API evolves. +struct EPTensors { + TensorHandle topk_idx, topk_weights, token_counts, handle_mem, tokens; + TensorHandle recv_tokens, recv_topk_weights, result; + TensorHandle grad_result, grad_expert, grad_tokens; + TensorHandle g_recv_topk_weights, grad_topk_weights; + + EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, + int num_local_experts) { + topk_idx = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + topk_weights = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + token_counts = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts}, kNVTEInt32); + handle_mem = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + tokens = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + recv_tokens = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); + recv_topk_weights = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + result = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + grad_result = make_nvte_tensor(b.grad_result.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + grad_expert = make_nvte_tensor(b.grad_expert.get(), + {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); + grad_tokens = make_nvte_tensor(b.grad_tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + g_recv_topk_weights = make_nvte_tensor(b.g_recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + grad_topk_weights = make_nvte_tensor(b.grad_topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + } +}; + +// ── Shared fixture base ─────────────────────────────────────────────────────── + +class EpOpTestBase : public ::testing::Test { + protected: + int ep_size_, num_experts_, num_local_experts_, hidden_dim_; + int max_tokens_per_rank_, top_k_, num_tokens_; + + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2); + ASSERT_TRUE(g_ep_initialized); + + ep_size_ = g_ep_size; + num_experts_ = g_num_experts; + num_local_experts_ = num_experts_ / ep_size_; + hidden_dim_ = g_hidden_dim; + max_tokens_per_rank_ = g_max_tokens_per_rank; + top_k_ = 2; + num_tokens_ = 32; + } + + void upload_inputs(EPBuffers& buf, int rank = -1) { + if (rank < 0) rank = g_process_id; + auto h_idx = routing_balanced(rank, num_tokens_, top_k_, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens_ * top_k_, 1.0f / top_k_); + auto h_tok = generate_tokens(rank, num_tokens_, hidden_dim_); + + CHECK_CUDA(cudaMemcpy(buf.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(buf.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + } + + NVTEEpLayerConfig layer_config(size_t alignment = 0) const { + return NVTEEpLayerConfig{num_local_experts_, top_k_, alignment}; + } + + // ASSERT_CUDA_OK (fprintf+exit) so this non-void helper stays legal. + int read_total_recv(const EPBuffers& buf) const { + std::vector cnt(num_local_experts_); + ASSERT_CUDA_OK(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + int total = 0; + for (int c : cnt) total += c; + return total; + } +}; + +// ============================================================================= +// EPDispatchTest: exact recv values and per-expert counts. +// ============================================================================= + +class EPDispatchTest : public EpOpTestBase {}; + +TEST_F(EPDispatchTest, PrepareAndDispatch) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + CHECK_CUDA(cudaMemset(buf.recv_tokens.get(), 0, buf.recv_tokens.bytes())); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + // 1. Per-expert counts. + std::vector got_counts(num_local_experts_); + CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_, + num_experts_, num_local_experts_); + int total_recv = 0; + for (int i = 0; i < num_local_experts_; ++i) { + EXPECT_EQ(got_counts[i], exp_counts[i]) << "local expert " << i; + total_recv += exp_counts[i]; + } + ASSERT_LE(total_recv, static_cast(buf.recv_capacity)) + << "total_recv exceeded recv_capacity — overflow would corrupt downstream memory"; + + // 2. Recv values: read only the filled prefix per local-expert zone, not the + // whole recv buffer — avoids false positives from legitimate-zero token values. + std::vector h_recv(buf.recv_capacity * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_recv.data(), buf.recv_tokens.get(), + h_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + std::vector got_vals; + got_vals.reserve(total_recv); + size_t slot = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < got_counts[e]; ++i) { + got_vals.push_back(__bfloat162float(h_recv[slot * hidden_dim_])); + ++slot; + } + } + std::sort(got_vals.begin(), got_vals.end()); + + auto exp_vals = expected_recv_values_sorted(g_process_id, g_num_processes, num_tokens_, + top_k_, num_experts_, num_local_experts_); + + ASSERT_EQ(got_vals.size(), exp_vals.size()); + for (size_t i = 0; i < exp_vals.size(); ++i) + EXPECT_NEAR(got_vals[i], exp_vals[i], bf16_tol(exp_vals[i])) + << "recv value mismatch at sorted index " << i; + + // 3. recv_topk_weights: every filled slot must equal the per-token weight (1/top_k). + std::vector h_w(buf.recv_capacity); + CHECK_CUDA(cudaMemcpy(h_w.data(), buf.recv_topk_weights.get(), + h_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + const float exp_w = 1.0f / static_cast(top_k_); + for (int i = 0; i < total_recv; ++i) + EXPECT_NEAR(h_w[i], exp_w, 1e-6f) << "recv_topk_weights[" << i << "]"; + + if (g_process_id == 0) + printf(" PrepareAndDispatch: passed (recv=%d, values + weights exact)\n", total_recv); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineTest: round-trip identity expert → result == top_k * tokens. +// ============================================================================= + +class EPCombineTest : public EpOpTestBase {}; + +TEST_F(EPCombineTest, Combine) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(), + h_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + // Spot-check 3 hidden-dim positions per token to catch partial-row writes. + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int tok = 0; tok < num_tokens_; ++tok) { + float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); + for (int p : probes) { + float got = __bfloat162float(h_result[tok * hidden_dim_ + p]); + EXPECT_NEAR(got, exp, bf16_tol(exp)) + << "token " << tok << " rank " << g_process_id << " hidden " << p; + } + } + + if (g_process_id == 0) + printf(" Combine: passed (result == top_k * tokens)\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineBwdTest: filled slots in grad_expert == d_result (unweighted). +// ============================================================================= + +class EPCombineBwdTest : public EpOpTestBase {}; + +TEST_F(EPCombineBwdTest, CombineBwdCheck) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad_r(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad_r.data(), + h_grad_r.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + int total_recv = read_total_recv(buf); + + std::vector cnt(num_local_experts_); + CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + std::vector h_ge(buf.recv_capacity * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), + h_ge.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Walk filled slots by per-expert zone (no v != 0 heuristic). + const float kExpGrad = 0.1f; + size_t slot = 0; + int filled = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < cnt[e]; ++i) { + float v = __bfloat162float(h_ge[slot * hidden_dim_]); + EXPECT_NEAR(v, kExpGrad, bf16_tol(kExpGrad)) + << "grad_expert expert " << e << " slot " << i << " (linear " << slot << ")"; + ++filled; ++slot; + } + } + EXPECT_EQ(filled, total_recv); + + if (g_process_id == 0) + printf(" CombineBwdCheck: passed (filled=%d)\n", filled); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdTest: grad_tokens == top_k * d_result. +// ============================================================================= + +class EPDispatchBwdTest : public EpOpTestBase {}; + +TEST_F(EPDispatchBwdTest, DispatchBwdCheck) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_gt(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float kExpGrad = static_cast(top_k_) * 0.1f; + for (int tok = 0; tok < num_tokens_; ++tok) + EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) + << "grad_tokens token " << tok; + + if (g_process_id == 0) + printf(" DispatchBwdCheck: passed (grad_tokens == %.2f)\n", kExpGrad); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdGradWeightsTest: round-trip per-(t, k) weights. +// ============================================================================= + +class EPDispatchBwdGradWeightsTest : public EpOpTestBase {}; + +TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + // Distinct per-(rank, t, k) weights so each slot carries a unique value. + std::vector h_w(num_tokens_ * top_k_); + for (int tok = 0; tok < num_tokens_; ++tok) + for (int k = 0; k < top_k_; ++k) + h_w[tok * top_k_ + k] = 0.1f + 0.01f * tok + 0.001f * k + + 0.0001f * (g_process_id + 1); + CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, + buf.recv_topk_weights.bytes(), stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + + // Sentinel: NaN so any (t, k) the bwd kernel fails to write is immediately visible. + std::vector h_nan(num_tokens_ * top_k_, + std::numeric_limits::quiet_NaN()); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_topk_weights.get(), h_nan.data(), + h_nan.size() * sizeof(float), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + // g_recv_topk_weights := recv_topk_weights (the round-trip input). + auto g_recv_t = make_nvte_tensor(buf.recv_topk_weights.get(), + {buf.recv_capacity}, kNVTEFloat32); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, + NVTECommWindow{}, g_recv_t.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_grad_w(num_tokens_ * top_k_); + CHECK_CUDA(cudaMemcpy(h_grad_w.data(), buf.grad_topk_weights.get(), + h_grad_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + + const float kTol = 1e-5f; + int errs = 0, k0_eq_k1 = 0; + for (int tok = 0; tok < num_tokens_; ++tok) { + for (int k = 0; k < top_k_; ++k) { + float got = h_grad_w[tok * top_k_ + k]; + float exp = h_w[tok * top_k_ + k]; + if (std::isnan(got) || std::fabs(got - exp) > kTol) { + if (errs < 8) + fprintf(stderr, "Rank %d: grad_topk_weights[%d, %d]: got %.6f, expected %.6f\n", + g_process_id, tok, k, got, exp); + ++errs; + } + } + if (top_k_ >= 2 && + std::fabs(h_grad_w[tok * top_k_ + 0] - h_grad_w[tok * top_k_ + 1]) < 1e-7f) + ++k0_eq_k1; + } + EXPECT_EQ(errs, 0); + EXPECT_EQ(k0_eq_k1, 0) << "per-token-average regression: grad[t, 0] == grad[t, 1]"; + + if (g_process_id == 0 && errs == 0 && k0_eq_k1 == 0) + printf(" RoundTrip: passed (%d (t, k) gradients)\n", num_tokens_ * top_k_); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// Integrated FwdBwd: NaN/Inf check end-to-end. +// ============================================================================= + +class EPPipelineTest : public EpOpTestBase {}; + +TEST_F(EPPipelineTest, FullForwardBackward) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, + NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, + t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, + t.grad_expert.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + if (g_process_id == 0) printf(" FullForwardBackward: passed\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPZeroCopyTest: dispatch/combine with NCCL symmetric-memory windows attached +// to payload tensors (zero-copy fast path via ncclEpTensorCreateFromWindow). +// Symm-mem requirements per spec: input&output of Dispatch, input of Combine, +// input&output of Combine bwd, input of Dispatch bwd. +// ============================================================================= + +namespace { + +// Caller-owned ncclMemAlloc'd buffer with a registered symmetric window. +// Frees in destructor (deregister + ncclMemFree). Non-copyable, move-only. +struct SymmBuf { + void* ptr = nullptr; + size_t bytes = 0; + ncclWindow_t win = nullptr; + + SymmBuf() = default; + SymmBuf(const SymmBuf&) = delete; + SymmBuf& operator=(const SymmBuf&) = delete; + SymmBuf(SymmBuf&& o) noexcept : ptr(o.ptr), bytes(o.bytes), win(o.win) { + o.ptr = nullptr; o.win = nullptr; o.bytes = 0; + } + ~SymmBuf() { + if (win) ncclCommWindowDeregister(g_ep_comm, win); + if (ptr) ncclMemFree(ptr); + } + + void alloc(size_t n_bytes) { + bytes = n_bytes; + ASSERT_NCCL_OK(ncclMemAlloc(&ptr, bytes)); + CHECK_CUDA(cudaMemset(ptr, 0, bytes)); + ASSERT_NCCL_OK(ncclCommWindowRegister(g_ep_comm, ptr, bytes, &win, + NCCL_WIN_COLL_SYMMETRIC)); + } +}; + +// Build an NVTECommWindow descriptor pointing at a SymmBuf's window (offset 0). +static inline NVTECommWindow symm_window(const SymmBuf& b) { + return NVTECommWindow{b.win, /*offset=*/0}; +} + +} // namespace + +class EPZeroCopyTest : public EpOpTestBase {}; + +// Identity round-trip with symm-mem on dispatch i/o + combine input. Bit-exact +// vs HBM reference (same routing, same input). +TEST_F(EPZeroCopyTest, IdentityAllSymm) { + // HBM reference run. + EPBuffers ref_buf; + ref_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(ref_buf); + EPTensors ref_t(ref_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t ref_hid = ref_buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, ref_t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, + ref_t.tokens.tensor, NVTECommWindow{}, ref_t.topk_weights.tensor, + NVTECommWindow{}, ref_t.recv_tokens.tensor, NVTECommWindow{}, + ref_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.recv_tokens.tensor, NVTECommWindow{}, + ref_t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector ref_recv(ref_buf.recv_capacity * hidden_dim_); + std::vector ref_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(ref_recv.data(), ref_buf.recv_tokens.get(), + ref_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(ref_result.data(), ref_buf.result.get(), + ref_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Symm-mem run: tokens, recv_tokens, combine_input (== recv_tokens) all symm. + EPBuffers sym_buf; // alloc all buffers except the symm ones. + sym_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(sym_buf); + + SymmBuf sym_tokens, sym_recv; + sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(sym_buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + + // Stage same tokens into the symm-mem input. + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors sym_t(sym_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + // Replace the tokens/recv_tokens views with ones pointing at the symm buffers. + sym_t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + sym_t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {sym_buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + uint64_t sym_hid = sym_buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, sym_t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, + sym_t.tokens.tensor, symm_window(sym_tokens), + sym_t.topk_weights.tensor, NVTECommWindow{}, + sym_t.recv_tokens.tensor, symm_window(sym_recv), + sym_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.recv_tokens.tensor, + symm_window(sym_recv), sym_t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector sym_recv_host(sym_buf.recv_capacity * hidden_dim_); + std::vector sym_result(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_recv_host.data(), sym_recv.ptr, + sym_recv_host.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(sym_result.data(), sym_buf.result.get(), + sym_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Compare per filled recv slot (HBM ref vs symm) and full result. + int total_recv = read_total_recv(sym_buf); + for (int i = 0; i < total_recv * hidden_dim_; ++i) + ASSERT_EQ(__bfloat162float(sym_recv_host[i]), __bfloat162float(ref_recv[i])) + << "recv mismatch at " << i; + for (size_t i = 0; i < sym_result.size(); ++i) + ASSERT_EQ(__bfloat162float(sym_result[i]), __bfloat162float(ref_result[i])) + << "result mismatch at " << i; + + if (g_process_id == 0) + printf(" IdentityAllSymm: passed (recv_slots=%d, bit-exact vs HBM)\n", total_recv); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// Same buffers, 2 iterations — catches window-lifecycle regressions where the +// symm-mem registration goes stale between calls. +TEST_F(EPZeroCopyTest, IdentityAllSymmRepeated) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + + SymmBuf sym_tokens, sym_recv; + sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + for (int iter = 0; iter < 2; ++iter) { + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, symm_window(sym_tokens), + t.topk_weights.tensor, NVTECommWindow{}, + t.recv_tokens.tensor, symm_window(sym_recv), + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, + symm_window(sym_recv), t.result.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_res(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), buf.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + for (int tok = 0; tok < num_tokens_; ++tok) { + float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); + float got = __bfloat162float(h_res[tok * hidden_dim_]); + ASSERT_NEAR(got, exp, bf16_tol(exp)) << "iter " << iter << " tok " << tok; + } + } + + if (g_process_id == 0) + printf(" IdentityAllSymmRepeated: passed (2 iters)\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// Full forward+backward with symm-mem on every spec-mandated buffer: +// dispatch i/o, combine input, combine_bwd i/o, dispatch_bwd input. +// TODO: flaky on rank 0 (grad_tokens partial-zero) when run after the prior +// EPZeroCopyTest cases in the same binary; passes in isolation. Re-enable once +// the root cause (likely NCCL EP NVLS write→read coherence on grad_expert) is +// understood. Tracked separately. +TEST_F(EPZeroCopyTest, DISABLED_FullPipelineSymm) { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + + // Symm-mem: tokens (dispatch input), recv_tokens (dispatch output AND + // combine input), grad_result (combine_bwd input), grad_expert + // (combine_bwd output AND dispatch_bwd input). + SymmBuf sym_tokens, sym_recv, sym_grad_result, sym_grad_expert; + sym_tokens .alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + sym_grad_result.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_grad_expert.alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + t.tokens = make_nvte_tensor(sym_tokens.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.recv_tokens = make_nvte_tensor(sym_recv.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + t.grad_result = make_nvte_tensor(sym_grad_result.ptr, + {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); + t.grad_expert = make_nvte_tensor(sym_grad_expert.ptr, + {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, + t.tokens.tensor, symm_window(sym_tokens), + t.topk_weights.tensor, NVTECommWindow{}, + t.recv_tokens.tensor, symm_window(sym_recv), + t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, + symm_window(sym_recv), t.result.tensor, stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + CHECK_CUDA(cudaMemcpyAsync(sym_grad_result.ptr, h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + CHECK_CUDA(cudaMemsetAsync(sym_grad_expert.ptr, 0, sym_grad_expert.bytes, stream)); + CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, + symm_window(sym_grad_result), t.grad_expert.tensor, + symm_window(sym_grad_expert), stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, + symm_window(sym_grad_expert), + t.g_recv_topk_weights.tensor, NVTECommWindow{}, + t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + std::vector h_gt(num_tokens_ * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float kExpGrad = static_cast(top_k_) * 0.1f; + for (int tok = 0; tok < num_tokens_; ++tok) + EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) + << "grad_tokens token " << tok; + + if (g_process_id == 0) printf(" FullPipelineSymm: passed\n"); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ── main ────────────────────────────────────────────────────────────────────── + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 06d85b6d84..7c93f0e1da 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -391,6 +391,96 @@ if (NVTE_WITH_CUSOLVERMP) message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}") endif() +# ── NCCL EP (on by default, HT mode only) ───────────────────────────────── +# Set -DNVTE_WITH_NCCL_EP=OFF (or NVTE_BUILD_WITH_NCCL_EP=0 in setup.py) to +# skip NCCL EP entirely — useful on older images whose system NCCL is below +# the 2.30.4 EP minimum. +option(NVTE_WITH_NCCL_EP "Build NCCL EP into libtransformer_engine.so" ON) +if(NVTE_WITH_NCCL_EP) +# SM>=90 and NCCL>=2.30.4 are gated at runtime in EPBackend::initialize. +# ── NCCL EP headers ──────────────────────────────────────────────────────── +# Headers + libs are produced by the in-tree 3rdparty/nccl submodule build +# (auto-built by setup.py via build_nccl_ep_submodule). +set(NCCL_EP_SUBMODULE_ROOT + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") +set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/contrib/nccl_ep/include") +if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h") + message(FATAL_ERROR + "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. " + "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl.") +endif() +message(STATUS "NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}") + +# ── libnccl_ep.so ────────────────────────────────────────────────────────── +set(NCCL_EP_LIB_DIR "${NCCL_EP_SUBMODULE_ROOT}/build/lib") +find_library(NCCL_EP_LIB + NAMES nccl_ep libnccl_ep + HINTS ${NCCL_EP_LIB_DIR} + NO_DEFAULT_PATH + REQUIRED) + +# ── NCCL + GIN headers ───────────────────────────────────────────────────── +# libnccl.so and all GIN headers (ncclGin.h, ncclWindow_t, ncclDevComm_t) +# ship with the base CUDA Toolkit OR the 3rdparty/nccl submodule build +# (preferred when present; auto-built by setup.py via build_nccl_ep_submodule). +if(NOT NCCL_LIB) + find_library(NCCL_LIB + NAMES nccl libnccl + HINTS ${NCCL_EP_LIB_DIR} ${CUDAToolkit_LIBRARY_DIR} + PATH_SUFFIXES lib lib64 + REQUIRED) +endif() + +set(NCCL_SUBMODULE_INCLUDE + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl/build/include") +if(EXISTS "${NCCL_SUBMODULE_INCLUDE}/nccl.h") + set(NCCL_INCLUDE_DIRS_FOR_TE ${NCCL_SUBMODULE_INCLUDE}) +else() + set(NCCL_INCLUDE_DIRS_FOR_TE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +endif() + +# Diagnostic: log detected NCCL header version (minimum enforced at runtime). +find_file(_nvte_nccl_header_path nccl.h + PATHS ${NCCL_INCLUDE_DIRS_FOR_TE} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + NO_DEFAULT_PATH) +if(_nvte_nccl_header_path) + file(READ "${_nvte_nccl_header_path}" _nvte_nccl_h) + string(REGEX MATCH "#define[ \t]+NCCL_MAJOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_major "${CMAKE_MATCH_1}") + string(REGEX MATCH "#define[ \t]+NCCL_MINOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_minor "${CMAKE_MATCH_1}") + string(REGEX MATCH "#define[ \t]+NCCL_PATCH[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") + set(_nvte_nccl_patch "${CMAKE_MATCH_1}") + if(_nvte_nccl_major AND _nvte_nccl_minor AND _nvte_nccl_patch) + message(STATUS "NCCL header: ${_nvte_nccl_header_path} (version ${_nvte_nccl_major}.${_nvte_nccl_minor}.${_nvte_nccl_patch})") + endif() +endif() + +target_include_directories(transformer_engine PRIVATE + ${NCCL_EP_INCLUDE_DIR} + ${NCCL_INCLUDE_DIRS_FOR_TE}) # covers nccl.h + nccl_device/ + +target_link_libraries(transformer_engine PUBLIC + ${NCCL_EP_LIB} + ${NCCL_LIB}) + +# Embed rpath so the installed wheel finds libnccl_ep.so at runtime. +# libnccl.so is already on the system via the Toolkit — no rpath needed for it. +set_target_properties(transformer_engine PROPERTIES + INSTALL_RPATH "$ORIGIN;${NCCL_EP_LIB_DIR}") + +target_sources(transformer_engine PRIVATE + ep/ep_backend.cpp + ep/ep_api.cpp) + +message(STATUS "NCCL EP enabled: ${NCCL_EP_LIB}") +message(STATUS "NCCL EP include: ${NCCL_EP_INCLUDE_DIR}") +else() + # NCCL EP off: export throwing nvte_ep_* stubs so framework bindings link. + target_sources(transformer_engine PRIVATE ep/ep_api_stub.cpp) + message(STATUS "NCCL EP disabled (NVTE_WITH_NCCL_EP=OFF) — using nvte_ep_* stubs") +endif() + # Number of philox4x32 rounds for stochastic rounding (build-time constant). set(NVTE_BUILD_NUM_PHILOX_ROUNDS_STR $ENV{NVTE_BUILD_NUM_PHILOX_ROUNDS}) if (NOT NVTE_BUILD_NUM_PHILOX_ROUNDS_STR) diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp new file mode 100644 index 0000000000..89d8b38607 --- /dev/null +++ b/transformer_engine/common/ep/ep_api.cpp @@ -0,0 +1,76 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_api.cpp + * \brief nvte_ep_* C API: thin delegations to the EPBackend singleton. + */ + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "ep_backend.h" + +using transformer_engine::ep::EPBackend; + +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + EPBackend::initialize(static_cast(ep_comm), group_config); +} + +void nvte_ep_shutdown(void) { EPBackend::shutdown(); } + +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size) { + NVTE_CHECK(handle_mem_size != nullptr, "handle_mem_size must not be null"); + return EPBackend::get().register_layer(layer_config, handle_mem_size); +} + +void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts, + size_t dispatch_output_per_expert_alignment, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().prepare(handle.id, topk_idx, token_counts, mem_ptr, + dispatch_output_per_expert_alignment, stream); +} + +void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().dispatch(handle.id, mem_ptr, topk_idx, tokens, tokens_win, topk_weights, + topk_weights_win, recv_tokens, recv_tokens_win, recv_topk_weights, + recv_topk_weights_win, stream); +} + +void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().combine(handle.id, mem_ptr, expert_out, expert_out_win, result, stream); +} + +void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, NVTETensor grad_topk_weights, + cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().dispatch_bwd(handle.id, mem_ptr, grad, grad_win, g_recv_topk_weights, + g_recv_topk_weights_win, grad_tokens, grad_topk_weights, stream); +} + +void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream) { + void* mem_ptr = nvte_tensor_data(handle.mem); + NVTE_CHECK(mem_ptr != nullptr, "handle_mem tensor data must not be null"); + EPBackend::get().combine_bwd(handle.id, mem_ptr, grad, grad_win, grad_expert_out, + grad_expert_out_win, stream); +} diff --git a/transformer_engine/common/ep/ep_api_stub.cpp b/transformer_engine/common/ep/ep_api_stub.cpp new file mode 100644 index 0000000000..fe4127d87d --- /dev/null +++ b/transformer_engine/common/ep/ep_api_stub.cpp @@ -0,0 +1,61 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_api_stub.cpp + * \brief Throwing nvte_ep_* stubs compiled when NVTE_WITH_NCCL_EP=OFF. + */ + +#include + +#include "../util/logging.h" + +namespace { +[[noreturn]] void ep_not_built() { + NVTE_ERROR( + "NCCL EP is not built into this TransformerEngine. Rebuild TE with " + "NVTE_BUILD_WITH_NCCL_EP=1 and CUDA arch >= 90 (e.g. NVTE_CUDA_ARCHS=\"90\")."); +} +} // namespace + +void nvte_ep_initialize(void* /*ep_comm*/, NVTEEpGroupConfig /*group_config*/) { ep_not_built(); } + +void nvte_ep_shutdown(void) {} + +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig /*layer_config*/, size_t* /*handle_mem_size*/) { + ep_not_built(); +} + +void nvte_ep_prepare(NVTEEpHandle /*handle*/, NVTETensor /*topk_idx*/, NVTETensor /*token_counts*/, + size_t /*dispatch_output_per_expert_alignment*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch(NVTEEpHandle /*handle*/, NVTETensor /*topk_idx*/, NVTETensor /*tokens*/, + NVTECommWindow /*tokens_win*/, NVTETensor /*topk_weights*/, + NVTECommWindow /*topk_weights_win*/, NVTETensor /*recv_tokens*/, + NVTECommWindow /*recv_tokens_win*/, NVTETensor /*recv_topk_weights*/, + NVTECommWindow /*recv_topk_weights_win*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine(NVTEEpHandle /*handle*/, NVTETensor /*expert_out*/, + NVTECommWindow /*expert_out_win*/, NVTETensor /*result*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch_bwd(NVTEEpHandle /*handle*/, NVTETensor /*grad*/, NVTECommWindow /*grad_win*/, + NVTETensor /*g_recv_topk_weights*/, + NVTECommWindow /*g_recv_topk_weights_win*/, NVTETensor /*grad_tokens*/, + NVTETensor /*grad_topk_weights*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine_bwd(NVTEEpHandle /*handle*/, NVTETensor /*grad*/, NVTECommWindow /*grad_win*/, + NVTETensor /*grad_expert_out*/, NVTECommWindow /*grad_expert_out_win*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp new file mode 100644 index 0000000000..ae0f3ab888 --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -0,0 +1,514 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.cpp + * \brief EPBackend implementation. See ep_backend.h for the op flow. + */ + +#include "ep_backend.h" + +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/logging.h" + +namespace transformer_engine { +namespace ep { + +namespace { + +// Build a by-value ncclEpTensor_t descriptor. `sizes` is caller-owned and must +// outlive any NCCL EP call that consumes the descriptor. +inline ncclEpTensor_t make_tensor(void* data, unsigned int ndim, ncclDataType_t datatype, + size_t* sizes) { + ncclEpTensor_t t = NCCL_EP_TENSOR_INIT; + t.ndim = ndim; + t.datatype = datatype; + t.data = data; + t.sizes = sizes; + return t; +} + +// Payload descriptor: prefer the symmem window when set, else fall back to the +// NVTETensor's raw device pointer. +inline ncclEpTensor_t make_payload_tensor(const NVTETensor t, const NVTECommWindow& win, + unsigned int ndim, ncclDataType_t datatype, + size_t* sizes) { + ncclEpTensor_t desc = NCCL_EP_TENSOR_INIT; + desc.ndim = ndim; + desc.datatype = datatype; + desc.sizes = sizes; + if (win.window != nullptr) { + desc.win_hdl = win.window; + desc.win_offset = win.offset; + } else { + desc.data = nvte_tensor_data(t); + NVTE_CHECK(desc.data != nullptr, "payload tensor data must not be null"); + } + return desc; +} + +// RAII guard for ncclEpHandle_t — destroys on scope exit, leak-free on throw. +class ScopedEpHandle { + public: + ScopedEpHandle() = default; + explicit ScopedEpHandle(ncclEpHandle_t h) : h_(h) {} + ~ScopedEpHandle() { + if (h_ != nullptr) ncclEpHandleDestroy(h_); + } + ScopedEpHandle(const ScopedEpHandle&) = delete; + ScopedEpHandle& operator=(const ScopedEpHandle&) = delete; + ScopedEpHandle(ScopedEpHandle&& other) noexcept : h_(other.h_) { other.h_ = nullptr; } + ScopedEpHandle& operator=(ScopedEpHandle&& other) noexcept { + if (this != &other) { + if (h_ != nullptr) ncclEpHandleDestroy(h_); + h_ = other.h_; + other.h_ = nullptr; + } + return *this; + } + operator ncclEpHandle_t() const { return h_; } + ncclEpHandle_t get() const { return h_; } + + private: + ncclEpHandle_t h_ = nullptr; +}; + +} // namespace + +// --------------------------------------------------------------------------- +// Singleton + bootstrap +// --------------------------------------------------------------------------- + +EPBackend& EPBackend::instance() { + static EPBackend inst; + return inst; +} + +EPBackend& EPBackend::get() { + EPBackend& inst = instance(); + NVTE_CHECK(inst.initialized_, "EPBackend not initialized. Call nvte_ep_initialize() first."); + return inst; +} + +void EPBackend::validate_config(const NVTEEpGroupConfig& config) { + NVTE_CHECK(config.ep_size > 0, "ep_size must be positive, got ", config.ep_size); + NVTE_CHECK(config.num_experts > 0, "num_experts must be positive, got ", config.num_experts); + NVTE_CHECK(config.max_tokens_per_rank > 0, "max_tokens_per_rank must be positive, got ", + config.max_tokens_per_rank); + NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", + config.max_recv_tokens_per_rank); + NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); + NVTE_CHECK(config.hidden_dim * sizeof(nv_bfloat16) >= 16, + "hidden_dim * 2 must be >= 16 (NCCL EP 16B row alignment); got hidden_dim=", + config.hidden_dim); + NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts, + ") must be divisible by ep_size (", config.ep_size, ")"); + NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ", + config.max_num_sms); + + int device, major; + NVTE_CHECK_CUDA(cudaGetDevice(&device)); + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + NVTE_CHECK(major >= 9, + "NCCL EP requires SM_90+ (Hopper or later), " + "but current device has compute capability ", + major, ".x"); + + // NCCL EP needs CUDA multicast (NVLS); init hangs without it. + NVTE_CHECK(cuda::supports_multicast(device), + "NCCL EP requires CUDA multicast (NVLS) support on device ", device, + " but CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED reports 0."); +} + +void EPBackend::initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config) { + EPBackend& inst = instance(); + std::lock_guard lock(inst.mutex_); + NVTE_CHECK(!inst.initialized_, "EP already initialized. Call initialize only once per process."); + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + + // Runtime gate: NCCL >= 2.30.4 (matches the submodule pin). + constexpr int kMinNcclVersion = 23004; + int nccl_version = 0; + NVTE_CHECK_NCCL(ncclGetVersion(&nccl_version)); + NVTE_CHECK(nccl_version >= kMinNcclVersion, "NCCL EP requires NCCL >= 2.30.4, found ", + nccl_version / 10000, ".", (nccl_version / 100) % 100, ".", nccl_version % 100, + " at runtime."); + + validate_config(config); + + int comm_size = 0; + NVTE_CHECK_NCCL(ncclCommCount(ep_comm, &comm_size)); + NVTE_CHECK(comm_size == config.ep_size, "ep_comm size (", comm_size, ") must equal ep_size (", + config.ep_size, "). Pass the EP sub-communicator, not the world comm."); + + inst.init(ep_comm, config); +} + +void EPBackend::shutdown() { + EPBackend& inst = instance(); + std::lock_guard lock(inst.mutex_); + if (!inst.initialized_) return; + inst.handles_.clear(); + // ncclEpGroupDestroy reads from ep_comm_; destroy group while comm is still alive. + if (inst.ep_group_ != nullptr) { + ncclEpGroupDestroy(inst.ep_group_); + inst.ep_group_ = nullptr; + } + inst.ep_comm_ = nullptr; // borrowed — caller destroys + inst.initialized_ = false; +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +ncclDataType_t EPBackend::nvte_dtype_to_nccl(NVTEDType dtype) { + switch (dtype) { + case kNVTEFloat32: + return ncclFloat32; + case kNVTEFloat16: + return ncclFloat16; + case kNVTEBFloat16: + return ncclBfloat16; + case kNVTEInt32: + return ncclInt32; + case kNVTEInt64: + return ncclInt64; + case kNVTEByte: + return ncclUint8; + case kNVTEFloat8E4M3: + return ncclFloat8e4m3; + case kNVTEFloat8E5M2: + return ncclFloat8e5m2; + default: + NVTE_ERROR("Unsupported NVTEDType for NCCL EP conversion: ", static_cast(dtype)); + } + return ncclFloat32; // unreachable +} + +// Open a transient ncclEpHandle over handle_mem. Caller owns the result. +ncclEpHandle_t EPBackend::open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, + size_t dispatch_output_per_expert_alignment) { + size_t hm_sizes[1] = {handle_mem_size}; + ncclEpTensor_t routing_desc = make_tensor(handle_mem, 1, ncclUint8, hm_sizes); + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = dispatch_output_per_expert_alignment; + ncclEpHandle_t handle; + NVTE_CHECK_NCCL(ncclEpInitHandle(&handle, ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, num_topk, + &routing_desc)); + return handle; +} + +// --------------------------------------------------------------------------- +// Lifecycle +// --------------------------------------------------------------------------- + +// Static-dtor teardown: skip NCCL calls (CUDA context / borrowed ep_comm_ may +// already be gone) and release in-memory state only. +EPBackend::~EPBackend() { + std::lock_guard lock(mutex_); + if (!initialized_) return; + handles_.clear(); + ep_group_ = nullptr; + ep_comm_ = nullptr; + initialized_ = false; +} + +void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(!initialized_, "EPBackend already initialized"); + + group_config_ = group_config; + + ncclEpGroupConfig_t cfg = NCCL_EP_GROUP_CONFIG_INIT; + cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT; + cfg.num_experts = static_cast(group_config.num_experts); + cfg.max_dispatch_tokens_per_rank = static_cast(group_config.max_tokens_per_rank); + cfg.max_token_bytes = static_cast(group_config.hidden_dim * sizeof(nv_bfloat16)); + cfg.rdma_buffer_size = NCCL_EP_AUTO; + cfg.num_qp_per_rank = NCCL_EP_AUTO; + cfg.num_channels = NCCL_EP_AUTO; + cfg.max_num_sms = group_config.max_num_sms > 0 + ? static_cast(group_config.max_num_sms) + : NCCL_EP_AUTO; + // Must be > 0; NCCL EP errors out on 0. + cfg.max_recv_tokens_per_rank = static_cast(group_config.max_recv_tokens_per_rank); + + NVTE_CHECK_NCCL(ncclEpCreateGroup(&ep_group_, ep_comm, &cfg)); + + ep_comm_ = ep_comm; + + initialized_ = true; +} + +// --------------------------------------------------------------------------- +// Per-handle_id config cache +// --------------------------------------------------------------------------- + +uint64_t EPBackend::insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment) { + if (handle_cache_cap_ == 0) { + const char* cap_env = std::getenv("NVTE_EP_HANDLE_CACHE_SIZE"); + handle_cache_cap_ = (cap_env != nullptr) ? std::max(1, std::atoi(cap_env)) : 8192; + } + NVTE_CHECK(handles_.size() < handle_cache_cap_, "EP handle cache full (", handle_cache_cap_, + " entries). Raise via NVTE_EP_HANDLE_CACHE_SIZE."); + uint64_t id = next_handle_id_.fetch_add(1, std::memory_order_relaxed); + handles_.emplace(id, HandleEntry{handle_mem_size, alignment, top_k}); + return id; +} + +EPBackend::HandleEntry& EPBackend::lookup_config(uint64_t handle_id) { + auto it = handles_.find(handle_id); + NVTE_CHECK(it != handles_.end(), "ep op on handle_id=", handle_id, + " with no cached config — call ep_prepare first."); + return it->second; +} + +// --------------------------------------------------------------------------- +// Per-step operations +// --------------------------------------------------------------------------- + +uint64_t EPBackend::register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(layer_config.top_k > 0, "NVTEEpLayerConfig.top_k must be > 0"); + NVTE_CHECK(handle_mem_size != nullptr, "handle_mem_size must not be null"); + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = layer_config.dispatch_output_per_expert_alignment; + size_t hm_size = 0; + NVTE_CHECK_NCCL(ncclEpHandleMemSize(ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, &hm_size, + layer_config.top_k)); + *handle_mem_size = hm_size; + std::lock_guard lock(mutex_); + return insert_new_entry(hm_size, layer_config.top_k, + layer_config.dispatch_output_per_expert_alignment); +} + +void EPBackend::prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETensor token_counts, + void* handle_mem, size_t dispatch_output_per_expert_alignment, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape idx_shape = nvte_tensor_shape(topk_idx); + void* idx_data = nvte_tensor_data(topk_idx); + NVTE_CHECK(idx_data != nullptr, "topk_idx data must not be null"); + + const size_t num_tokens = idx_shape.data[0]; + const size_t top_k = idx_shape.ndim > 1 ? idx_shape.data[1] : 1; + const size_t num_local_experts = + static_cast(group_config_.num_experts / group_config_.ep_size); + + size_t idx_sizes[2] = {num_tokens, top_k}; + ncclEpTensor_t nccl_topk_idx = make_tensor(idx_data, 2, ncclInt64, idx_sizes); + + // ncclEpUpdateHandle writes per-expert counts via expert_counters. + size_t cnt_sizes[1] = {num_local_experts}; + ncclEpTensor_t token_counts_desc; + void* token_counts_data = (token_counts != nullptr) ? nvte_tensor_data(token_counts) : nullptr; + if (token_counts_data != nullptr) { + token_counts_desc = make_tensor(token_counts_data, 1, ncclInt32, cnt_sizes); + } + ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; + layout_info.expert_counters = (token_counts_data != nullptr) ? &token_counts_desc : nullptr; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + NVTE_CHECK(cfg.alignment == dispatch_output_per_expert_alignment, + "ep_prepare: alignment mismatch for handle_id=", handle_id, + " (cached=", cfg.alignment, ", got=", dispatch_output_per_expert_alignment, ")"); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpUpdateHandle(transient, &nccl_topk_idx, &layout_info, stream)); +} + +void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx, + const NVTETensor tokens, const NVTECommWindow& tokens_win, + const NVTETensor topk_weights, const NVTECommWindow& topk_weights_win, + NVTETensor recv_tokens, const NVTECommWindow& recv_tokens_win, + NVTETensor recv_topk_weights, const NVTECommWindow& recv_topk_weights_win, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape tok_shape = nvte_tensor_shape(tokens); + NVTEDType tok_dtype = nvte_tensor_type(tokens); + + const size_t num_tokens = tok_shape.data[0]; + const size_t hidden_dim = tok_shape.data[1]; + + size_t tok_sizes[2] = {num_tokens, hidden_dim}; + ncclEpTensor_t nccl_tokens_in = + make_payload_tensor(tokens, tokens_win, 2, nvte_dtype_to_nccl(tok_dtype), tok_sizes); + + const bool is_forward = (topk_weights != nullptr); + + // Routing is cached in handle_mem by ep_prepare; dispatch only needs + // topk_weights to reconstruct the sparse-to-dense prob map. + size_t weights_in_sizes[2] = {0, 0}; + ncclEpTensor_t nccl_topk_weights_in; + if (is_forward) { + NVTE_CHECK(topk_idx != nullptr, "topk_idx required in forward dispatch"); + NVTEShape idx_shape = nvte_tensor_shape(topk_idx); + const size_t top_k = idx_shape.ndim > 1 ? idx_shape.data[1] : 1; + weights_in_sizes[0] = num_tokens; + weights_in_sizes[1] = top_k; + nccl_topk_weights_in = + make_payload_tensor(topk_weights, topk_weights_win, 2, ncclFloat32, weights_in_sizes); + } + + NVTEShape recv_shape = nvte_tensor_shape(recv_tokens); + NVTEDType recv_dtype = nvte_tensor_type(recv_tokens); + + size_t recv_sizes[2] = {recv_shape.data[0], recv_shape.data[1]}; + ncclEpTensor_t nccl_tokens_out = make_payload_tensor(recv_tokens, recv_tokens_win, 2, + nvte_dtype_to_nccl(recv_dtype), recv_sizes); + + size_t weights_out_sizes[1] = {recv_shape.data[0]}; + ncclEpTensor_t nccl_topk_weights_out; + if (is_forward) { + NVTE_CHECK(recv_topk_weights != nullptr, + "recv_topk_weights must not be null in forward dispatch"); + NVTEShape recv_w_shape = nvte_tensor_shape(recv_topk_weights); + NVTE_CHECK(recv_w_shape.ndim == 1, "recv_topk_weights must be 1D [recv_capacity]"); + nccl_topk_weights_out = make_payload_tensor(recv_topk_weights, recv_topk_weights_win, 1, + ncclFloat32, weights_out_sizes); + } + + ncclEpDispatchInputs_t in_struct = NCCL_EP_DISPATCH_INPUTS_INIT; + in_struct.tokens = &nccl_tokens_in; + in_struct.topk_weights = is_forward ? &nccl_topk_weights_in : nullptr; + + ncclEpDispatchOutputs_t out_struct = NCCL_EP_DISPATCH_OUTPUTS_INIT; + out_struct.tokens = &nccl_tokens_out; + out_struct.topk_weights = is_forward ? &nccl_topk_weights_out : nullptr; + + ncclEpDispatchConfig_t dispatch_cfg = NCCL_EP_DISPATCH_CONFIG_INIT; + dispatch_cfg.pass_direction = is_forward ? NCCL_EP_FWD_PASS : NCCL_EP_BWD_PASS; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpDispatch(transient, &in_struct, &out_struct, + /*layout_info=*/nullptr, &dispatch_cfg, stream)); +} + +void EPBackend::combine(uint64_t handle_id, void* handle_mem, const NVTETensor expert_out, + const NVTECommWindow& expert_out_win, NVTETensor result, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape exp_shape = nvte_tensor_shape(expert_out); + NVTEDType exp_dtype = nvte_tensor_type(expert_out); + + size_t exp_sizes[2] = {exp_shape.data[0], exp_shape.data[1]}; + ncclEpTensor_t nccl_expert_in = + make_payload_tensor(expert_out, expert_out_win, 2, nvte_dtype_to_nccl(exp_dtype), exp_sizes); + + NVTEShape res_shape = nvte_tensor_shape(result); + void* res_data = nvte_tensor_data(result); + NVTEDType res_dtype = nvte_tensor_type(result); + NVTE_CHECK(res_data != nullptr, "result data must not be null"); + + size_t res_sizes[2] = {res_shape.data[0], res_shape.data[1]}; + ncclEpTensor_t nccl_result_out = + make_tensor(res_data, 2, nvte_dtype_to_nccl(res_dtype), res_sizes); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_expert_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_result_out; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + transient = + ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); + } + NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, /*config=*/nullptr, stream)); +} + +void EPBackend::dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape g_shape = nvte_tensor_shape(grad); + NVTEDType g_dtype = nvte_tensor_type(grad); + size_t g_sizes[2] = {g_shape.data[0], g_shape.data[1]}; + ncclEpTensor_t nccl_tok_in = + make_payload_tensor(grad, grad_win, 2, nvte_dtype_to_nccl(g_dtype), g_sizes); + + // g_recv_topk_weights must be 1D [recv_capacity] — caller flattens. + NVTEShape gw_shape = nvte_tensor_shape(g_recv_topk_weights); + NVTE_CHECK(gw_shape.ndim == 1, + "g_recv_topk_weights must be 1D [recv_capacity]; caller must flatten leading dims"); + size_t gw_sizes[1] = {gw_shape.data[0]}; + ncclEpTensor_t nccl_w_in = + make_payload_tensor(g_recv_topk_weights, g_recv_topk_weights_win, 1, ncclFloat32, gw_sizes); + + NVTEShape gt_shape = nvte_tensor_shape(grad_tokens); + void* gt_data = nvte_tensor_data(grad_tokens); + NVTE_CHECK(gt_data != nullptr, "grad_tokens data must not be null"); + size_t gt_sizes[2] = {gt_shape.data[0], gt_shape.data[1]}; + ncclEpTensor_t nccl_tok_out = make_tensor(gt_data, 2, nvte_dtype_to_nccl(g_dtype), gt_sizes); + + NVTEShape gtw_shape = nvte_tensor_shape(grad_topk_weights); + void* gtw_data = nvte_tensor_data(grad_topk_weights); + NVTE_CHECK(gtw_data != nullptr, "grad_topk_weights data must not be null"); + NVTE_CHECK(gtw_shape.ndim == 2, "grad_topk_weights must be 2D [T, top_k]"); + size_t gtw_sizes[2] = {gtw_shape.data[0], gtw_shape.data[1]}; + ncclEpTensor_t nccl_w_out = make_tensor(gtw_data, 2, ncclFloat32, gtw_sizes); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_tok_in; + in_struct.topk_weights = &nccl_w_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_tok_out; + out_struct.topk_weights = &nccl_w_out; + + ncclEpCombineConfig_t cfg = NCCL_EP_COMBINE_CONFIG_INIT; + cfg.pass_direction = NCCL_EP_BWD_PASS; + + ScopedEpHandle transient; + { + std::lock_guard lock(mutex_); + HandleEntry& entry = lookup_config(handle_id); + transient = ScopedEpHandle( + open_handle(handle_mem, entry.handle_mem_size, entry.top_k, entry.alignment)); + } + NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, &cfg, stream)); +} + +void EPBackend::combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, NVTETensor grad_expert_out, + const NVTECommWindow& grad_expert_out_win, cudaStream_t stream) { + // Backward of combine = reverse-direction dispatch. + dispatch(handle_id, handle_mem, /*topk_idx=*/nullptr, grad, grad_win, /*topk_weights=*/nullptr, + /*topk_weights_win=*/NVTECommWindow{}, grad_expert_out, grad_expert_out_win, + /*recv_topk_weights=*/nullptr, /*recv_topk_weights_win=*/NVTECommWindow{}, stream); +} + +} // namespace ep +} // namespace transformer_engine diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h new file mode 100644 index 0000000000..18307ebb4f --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.h @@ -0,0 +1,114 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.h + * \brief Internal NCCL EP singleton; not part of the public API. + * + * Per handle_id the cache stores config only (no device pointers), so + * handle_mem may be relocated between ops. Cap: NVTE_EP_HANDLE_CACHE_SIZE + * (default 8192); overflow throws. + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ +#define TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace transformer_engine { +namespace ep { + +/*! \brief EP backend singleton — owns the NCCL EP group; borrows the comm. */ +class EPBackend { + public: + /*! \brief Access the singleton. Aborts if not initialized. */ + static EPBackend& get(); + + /*! \brief Bootstrap from an existing EP sub-communicator. + * ep_comm is borrowed; the caller keeps it alive until shutdown() returns + * and must span exactly config.ep_size ranks. + */ + static void initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + /*! \brief Tear down the backend. Idempotent. Does not destroy ep_comm_. */ + static void shutdown(); + + // Host-only: reserve a fresh handle_id, cache the layer config, and report + // the handle_mem buffer size the caller must allocate. + uint64_t register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size); + + void prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETensor token_counts, + void* handle_mem, size_t dispatch_output_per_expert_alignment, cudaStream_t stream); + + void dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx, + const NVTETensor tokens, const NVTECommWindow& tokens_win, + const NVTETensor topk_weights, const NVTECommWindow& topk_weights_win, + NVTETensor recv_tokens, const NVTECommWindow& recv_tokens_win, + NVTETensor recv_topk_weights, const NVTECommWindow& recv_topk_weights_win, + cudaStream_t stream); + + void combine(uint64_t handle_id, void* handle_mem, const NVTETensor expert_out, + const NVTECommWindow& expert_out_win, NVTETensor result, cudaStream_t stream); + + // g_recv_topk_weights: 1D [recv_capacity] f32; grad_topk_weights: 2D [T, top_k] f32. + void dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream); + + void combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, NVTETensor grad_expert_out, + const NVTECommWindow& grad_expert_out_win, cudaStream_t stream); + + private: + EPBackend() = default; + ~EPBackend(); + EPBackend(const EPBackend&) = delete; + EPBackend& operator=(const EPBackend&) = delete; + + // ep_comm is borrowed — caller retains ownership across the backend lifetime. + void init(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + static EPBackend& instance(); // Meyers singleton accessor + static void validate_config(const NVTEEpGroupConfig& config); + + static ncclDataType_t nvte_dtype_to_nccl(NVTEDType dtype); + // Open a transient ncclEpHandle over handle_mem. num_topk=-1 for paths + // that don't carry per-token weights. + ncclEpHandle_t open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, + size_t dispatch_output_per_expert_alignment); + + ncclEpGroup_t ep_group_{nullptr}; + ncclComm_t ep_comm_{nullptr}; + NVTEEpGroupConfig group_config_{}; + bool initialized_{false}; + std::mutex mutex_; + struct HandleEntry { + size_t handle_mem_size; + size_t alignment; + int top_k; + }; + std::unordered_map handles_; + std::atomic next_handle_id_{1}; // 0 reserved as "no id" + size_t handle_cache_cap_{0}; // set lazily from NVTE_EP_HANDLE_CACHE_SIZE + + // Caller must hold mutex_. Throws on cap overflow. + uint64_t insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment); + HandleEntry& lookup_config(uint64_t handle_id); +}; + +} // namespace ep +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_window.h b/transformer_engine/common/include/transformer_engine/comm_window.h new file mode 100644 index 0000000000..088ea7f0c3 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_window.h @@ -0,0 +1,32 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file comm_window.h + * \brief Borrowed symmetric-memory window + offset for zero-copy one-sided ops. + * Pass ``{NULL, 0}`` to use the raw-pointer path. + */ + +#ifndef TRANSFORMER_ENGINE_COMM_WINDOW_H_ +#define TRANSFORMER_ENGINE_COMM_WINDOW_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief NCCL window + byte offset for a zero-copy payload tensor. */ +typedef struct { + ncclWindow_t window; /*!< NCCL window, or NULL to use the raw data pointer. */ + uint64_t offset; /*!< Byte offset of the payload within ``window``. */ +} NVTECommWindow; + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_COMM_WINDOW_H_ diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h new file mode 100644 index 0000000000..8c3a06b5f0 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -0,0 +1,161 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep.h + * \brief Public C API for Expert Parallelism. Per-step ops are allocation-free + * and CUDA graph-capturable. + */ + +#ifndef TRANSFORMER_ENGINE_EP_H_ +#define TRANSFORMER_ENGINE_EP_H_ + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* ── Config structs ─────────────────────────────────────────────────────── */ + +/*! \brief Group-level EP configuration (fixed for the EP group lifetime). */ +typedef struct { + int ep_size; /*!< EP world size. */ + int num_experts; /*!< Total experts across all ranks. */ + int max_tokens_per_rank; /*!< Upper bound on tokens this rank sends per dispatch. */ + /*! Upper bound on tokens received per dispatch (worst-case top_k fan-out; must be > 0). */ + int max_recv_tokens_per_rank; + int hidden_dim; /*!< Token hidden dimension. */ + int max_num_sms; /*!< Max SMs for EP kernels. 0 = auto. */ + /*! 0 (default): throw on relocated handle_mem for a cached handle_id. 1: silently rebuild. */ + int allow_handle_mem_reloc; +} NVTEEpGroupConfig; + +/*! \brief Per-layer EP configuration. */ +typedef struct { + int num_local_experts; /*!< Reserved for ABI stability (derived from group config). */ + int top_k; /*!< Per-token expert fan-out. Required. */ + size_t dispatch_output_per_expert_alignment; + /*!< Per-expert zone alignment in tokens (pow2; 0/1 = no padding). Must match + * between nvte_ep_register_layer and nvte_ep_prepare. */ +} NVTEEpLayerConfig; + +/* ── Bootstrap ──────────────────────────────────────────────────────────── */ + +/*! \brief Bootstrap from an existing NCCL EP sub-communicator. Requires SM>=90. + * + * ep_comm is borrowed and must span exactly group_config.ep_size ranks. + * Re-init after shutdown is allowed; double-init throws. + * + * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group. + * \param[in] group_config Group-level EP configuration. + */ +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config); + +/*! \brief Tear down the EP backend. Idempotent. Does not destroy ep_comm. */ +void nvte_ep_shutdown(void); + +/* ── Layer registration (host-only, eager) ───────────────────────────────── */ + +/*! \brief Reserve a handle_id for a layer config and report the handle_mem buffer + * size the caller must allocate. Host-only. + * + * \param[in] layer_config Per-layer EP configuration. + * \param[out] handle_mem_size Bytes the caller must allocate for handle_mem. + * \return uint64_t handle_id (non-zero). + */ +uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size); + +/*! \brief Per-step handle: the registered handle_id paired with its handle_mem buffer. */ +typedef struct { + uint64_t id; /*!< Handle id from nvte_ep_register_layer. */ + NVTETensor mem; /*!< Caller-allocated handle_mem buffer (size from nvte_ep_register_layer). */ +} NVTEEpHandle; + +/* ── Per-step ops (all allocation-free, CUDA graph-capturable) ──────────── */ + +/*! \brief AllGather the routing map; write per-expert counts and cache routing + * metadata in handle.mem for the subsequent dispatch/combine. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] topk_idx [T, top_k] int64 routing indices. + * \param[out] token_counts [num_local_experts] int32 counts. + * \param[in] dispatch_output_per_expert_alignment Must match the handle_mem sizing. + * \param[in] stream CUDA stream. + */ +void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts, + size_t dispatch_output_per_expert_alignment, cudaStream_t stream); + +/*! \brief Dispatch tokens (and routing weights) to expert ranks. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] topk_idx [T, top_k] int64 sparse routing indices. + * \param[in] tokens [T, hidden_dim] input tokens. + * \param[in] tokens_win Optional symmem window for ``tokens``. + * \param[in] topk_weights [T, top_k] float32 weights, or null in backward. + * \param[in] topk_weights_win Optional symmem window for ``topk_weights``. + * \param[out] recv_tokens [recv_T, hidden_dim] received tokens. + * \param[in] recv_tokens_win Optional symmem window for ``recv_tokens``. + * \param[out] recv_topk_weights [recv_T] float32 per-slot weights, or null in backward. + * \param[in] recv_topk_weights_win Optional symmem window for ``recv_topk_weights``. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream); + +/*! \brief Scatter-sum expert outputs back to originating ranks. Unweighted — + * caller must pre-multiply expert_out by recv_topk_weights (and the + * valid-slot mask) before calling. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] expert_out [recv_T, hidden_dim] pre-weighted expert outputs. + * \param[in] expert_out_win Optional symmem window for ``expert_out``. + * \param[out] result [T, hidden_dim] combined output. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream); + +/*! \brief Backward of dispatch — routes token and weight grads back to source. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] grad [recv_capacity, hidden_dim] grad w.r.t. recv_tokens. + * \param[in] grad_win Optional symmem window for ``grad``. + * \param[in] g_recv_topk_weights [recv_capacity] f32 grad w.r.t. recv_topk_weights. + * \param[in] g_recv_topk_weights_win Optional symmem window for ``g_recv_topk_weights``. + * \param[out] grad_tokens [T, hidden_dim] grad w.r.t. tokens. + * \param[out] grad_topk_weights [T, top_k] f32 grad w.r.t. topk_weights. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, NVTETensor grad_topk_weights, + cudaStream_t stream); + +/*! \brief Backward of combine. Padded slots in grad_expert_out are zeroed. + * + * \param[in] handle EP handle (id + mem buffer). + * \param[in] grad [T, hidden_dim] grad w.r.t. result. + * \param[in] grad_win Optional symmem window for ``grad``. + * \param[out] grad_expert_out [recv_capacity, hidden_dim] grad w.r.t. expert_out. + * \param[in] grad_expert_out_win Optional symmem window for ``grad_expert_out``. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_EP_H_ From 0b9bf7ec367d3642428469f7e94538d1784ec204 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Sat, 23 May 2026 19:36:55 +0000 Subject: [PATCH 13/36] Expert Parallelism: persistent ncclEpHandle cache with allow_handle_mem_reloc gating Signed-off-by: Phuong Nguyen --- tests/cpp_distributed/test_ep_coverage.cu | 183 ++++++++++++++++++++ transformer_engine/common/ep/ep_backend.cpp | 109 +++++------- transformer_engine/common/ep/ep_backend.h | 8 + 3 files changed, 238 insertions(+), 62 deletions(-) diff --git a/tests/cpp_distributed/test_ep_coverage.cu b/tests/cpp_distributed/test_ep_coverage.cu index ef7941905d..e9e532386c 100644 --- a/tests/cpp_distributed/test_ep_coverage.cu +++ b/tests/cpp_distributed/test_ep_coverage.cu @@ -369,6 +369,189 @@ TEST_F(NegativeTests, NullHandleMemThrows) { CHECK_CUDA(cudaStreamDestroy(stream)); } +// ============================================================================= +// HandleCacheTest: persistent ncclEpHandle is reused across ops on the same +// handle_mem ptr; relocation triggers throw by default and rebuild when +// NVTEEpGroupConfig.allow_handle_mem_reloc=1. +// ============================================================================= + +class HandleCacheTest : public EpCoverageBase {}; + +// Run prepare → dispatch → combine on bundle b. handle_mem_data overrides the +// device ptr used for handle_mem (must be the buffer owned by b unless +// reloc-allowed mode is active). Templated on Bundle because EpCoverageBase:: +// Bundle is declared in a protected section. +template +static void run_round_trip(B& b, void* handle_mem_data, + int num_tokens, int top_k, int num_local_experts, + int hidden_dim, size_t alignment, + cudaStream_t stream) { + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts}, kNVTEInt32); + auto handle_mem_t = make_nvte_tensor(handle_mem_data, + {b.handle_mem_size}, kNVTEByte); + auto tokens_t = make_nvte_tensor(b.tokens.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), + {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); + auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), + {b.recv_capacity}, kNVTEFloat32); + auto result_t = make_nvte_tensor(b.result.get(), + {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); + + NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; + nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, alignment, stream); + nvte_ep_dispatch(h, topk_idx_t.tensor, tokens_t.tensor, NVTECommWindow{}, + topk_weights_t.tensor, NVTECommWindow{}, + recv_tokens_t.tensor, NVTECommWindow{}, + recv_w_t.tensor, NVTECommWindow{}, stream); + nvte_ep_combine(h, recv_tokens_t.tensor, NVTECommWindow{}, result_t.tensor, stream); +} + +// Re-bootstrap EP backend with a different allow_handle_mem_reloc setting. +// Reuses the existing g_ep_comm; caller is responsible for restoring defaults. +static void reinit_ep_with_reloc(int allow_reloc) { + nvte_ep_shutdown(); + NVTEEpGroupConfig cfg{}; + cfg.ep_size = g_ep_size; + cfg.num_experts = g_num_experts; + cfg.max_tokens_per_rank = g_max_tokens_per_rank; + cfg.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; + cfg.hidden_dim = g_hidden_dim; + cfg.allow_handle_mem_reloc = allow_reloc; + nvte_ep_initialize(static_cast(g_ep_comm), cfg); +} + +TEST_F(HandleCacheTest, ReuseSameMemSucceeds) { + const int num_tokens = 16, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + // Two consecutive round-trips on the same handle_mem ptr: first opens the + // cached handle, second hits the cache. Both must succeed and be correct. + for (int iter = 0; iter < 2; ++iter) { + ASSERT_NO_THROW(run_round_trip(b, b.handle_mem.get(), num_tokens, top_k, + num_local_experts_, hidden_dim_, + /*alignment=*/0, stream)); + } + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), + static_cast(top_k) * 0.5f, 1e-2f); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +TEST_F(HandleCacheTest, RelocDefaultThrows) { + // Default bootstrap has allow_handle_mem_reloc=0: a second prepare call on + // the same handle_id with a different handle_mem ptr must throw. + const int num_tokens = 8, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + DevBuf second_hm(b.handle_mem_size); // distinct device buffer + ASSERT_NE(b.handle_mem.get(), second_hm.get()); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + + auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), + {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); + auto token_counts_t = make_nvte_tensor(b.token_counts.get(), + {(size_t)num_local_experts_}, kNVTEInt32); + auto hm1_t = make_nvte_tensor(b.handle_mem.get(), + {b.handle_mem_size}, kNVTEByte); + auto hm2_t = make_nvte_tensor(second_hm.get(), + {b.handle_mem_size}, kNVTEByte); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + // First prepare seeds the cache. + NVTEEpHandle h1{b.handle_id, hm1_t.tensor}; + ASSERT_NO_THROW(nvte_ep_prepare(h1, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + // Same handle_id with a different handle_mem ptr must throw. + NVTEEpHandle h2{b.handle_id, hm2_t.tensor}; + EXPECT_THROW(nvte_ep_prepare(h2, topk_idx_t.tensor, token_counts_t.tensor, + /*alignment=*/0, stream), + std::exception); + CHECK_CUDA(cudaStreamDestroy(stream)); +} + +TEST_F(HandleCacheTest, RelocAllowedRebuilds) { + // Re-init EP backend with allow_handle_mem_reloc=1, run two round-trips with + // distinct handle_mem buffers, verify both succeed numerically, restore. + reinit_ep_with_reloc(/*allow_reloc=*/1); + + struct Restore { ~Restore() { reinit_ep_with_reloc(/*allow_reloc=*/0); } } restore; + + const int num_tokens = 16, top_k = 2; + Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); + DevBuf alt_hm(b.handle_mem_size); + ASSERT_NE(b.handle_mem.get(), alt_hm.get()); + + auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens * top_k, 1.0f / top_k); + auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); + CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + cudaStream_t stream; + CHECK_CUDA(cudaStreamCreate(&stream)); + + // First on the original handle_mem. + ASSERT_NO_THROW(run_round_trip(b, b.handle_mem.get(), num_tokens, top_k, + num_local_experts_, hidden_dim_, + /*alignment=*/0, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + // Then on the relocated handle_mem — must trigger silent rebuild, not throw. + ASSERT_NO_THROW(run_round_trip(b, alt_hm.get(), num_tokens, top_k, + num_local_experts_, hidden_dim_, + /*alignment=*/0, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_res(num_tokens * hidden_dim_); + CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), + h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; + for (int t = 0; t < num_tokens; ++t) + for (int p : probes) + EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), + static_cast(top_k) * 0.5f, 1e-2f); + + CHECK_CUDA(cudaStreamDestroy(stream)); +} + // ── main ────────────────────────────────────────────────────────────────────── int main(int argc, char* argv[]) { diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index ae0f3ab888..6494a86817 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -57,32 +57,6 @@ inline ncclEpTensor_t make_payload_tensor(const NVTETensor t, const NVTECommWind return desc; } -// RAII guard for ncclEpHandle_t — destroys on scope exit, leak-free on throw. -class ScopedEpHandle { - public: - ScopedEpHandle() = default; - explicit ScopedEpHandle(ncclEpHandle_t h) : h_(h) {} - ~ScopedEpHandle() { - if (h_ != nullptr) ncclEpHandleDestroy(h_); - } - ScopedEpHandle(const ScopedEpHandle&) = delete; - ScopedEpHandle& operator=(const ScopedEpHandle&) = delete; - ScopedEpHandle(ScopedEpHandle&& other) noexcept : h_(other.h_) { other.h_ = nullptr; } - ScopedEpHandle& operator=(ScopedEpHandle&& other) noexcept { - if (this != &other) { - if (h_ != nullptr) ncclEpHandleDestroy(h_); - h_ = other.h_; - other.h_ = nullptr; - } - return *this; - } - operator ncclEpHandle_t() const { return h_; } - ncclEpHandle_t get() const { return h_; } - - private: - ncclEpHandle_t h_ = nullptr; -}; - } // namespace // --------------------------------------------------------------------------- @@ -158,6 +132,13 @@ void EPBackend::shutdown() { EPBackend& inst = instance(); std::lock_guard lock(inst.mutex_); if (!inst.initialized_) return; + for (auto& kv : inst.handles_) { + if (kv.second.cached_handle != nullptr) { + ncclEpHandleDestroy(kv.second.cached_handle); + kv.second.cached_handle = nullptr; + kv.second.cached_handle_mem = nullptr; + } + } inst.handles_.clear(); // ncclEpGroupDestroy reads from ep_comm_; destroy group while comm is still alive. if (inst.ep_group_ != nullptr) { @@ -196,7 +177,7 @@ ncclDataType_t EPBackend::nvte_dtype_to_nccl(NVTEDType dtype) { return ncclFloat32; // unreachable } -// Open a transient ncclEpHandle over handle_mem. Caller owns the result. +// Open a fresh ncclEpHandle over handle_mem. Caller (or cache) owns the result. ncclEpHandle_t EPBackend::open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, size_t dispatch_output_per_expert_alignment) { size_t hm_sizes[1] = {handle_mem_size}; @@ -273,6 +254,26 @@ EPBackend::HandleEntry& EPBackend::lookup_config(uint64_t handle_id) { return it->second; } +ncclEpHandle_t EPBackend::get_or_open_handle(HandleEntry& cfg, void* handle_mem) { + if (cfg.cached_handle != nullptr && cfg.cached_handle_mem == handle_mem) { + return cfg.cached_handle; + } + if (cfg.cached_handle != nullptr) { + NVTE_CHECK(group_config_.allow_handle_mem_reloc != 0, + "EP handle_mem relocated for cached handle (old=", + reinterpret_cast(cfg.cached_handle_mem), + ", new=", reinterpret_cast(handle_mem), + "). Set NVTEEpGroupConfig.allow_handle_mem_reloc=1 to allow rebuild."); + ncclEpHandleDestroy(cfg.cached_handle); + cfg.cached_handle = nullptr; + cfg.cached_handle_mem = nullptr; + } + ncclEpHandle_t h = open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment); + cfg.cached_handle = h; + cfg.cached_handle_mem = handle_mem; + return h; +} + // --------------------------------------------------------------------------- // Per-step operations // --------------------------------------------------------------------------- @@ -320,17 +321,13 @@ void EPBackend::prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETenso ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; layout_info.expert_counters = (token_counts_data != nullptr) ? &token_counts_desc : nullptr; - ScopedEpHandle transient; - { - std::lock_guard lock(mutex_); - HandleEntry& cfg = lookup_config(handle_id); - NVTE_CHECK(cfg.alignment == dispatch_output_per_expert_alignment, - "ep_prepare: alignment mismatch for handle_id=", handle_id, - " (cached=", cfg.alignment, ", got=", dispatch_output_per_expert_alignment, ")"); - transient = - ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); - } - NVTE_CHECK_NCCL(ncclEpUpdateHandle(transient, &nccl_topk_idx, &layout_info, stream)); + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + NVTE_CHECK(cfg.alignment == dispatch_output_per_expert_alignment, + "ep_prepare: alignment mismatch for handle_id=", handle_id, + " (cached=", cfg.alignment, ", got=", dispatch_output_per_expert_alignment, ")"); + ncclEpHandle_t h = get_or_open_handle(cfg, handle_mem); + NVTE_CHECK_NCCL(ncclEpUpdateHandle(h, &nccl_topk_idx, &layout_info, stream)); } void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor topk_idx, @@ -397,14 +394,10 @@ void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor ncclEpDispatchConfig_t dispatch_cfg = NCCL_EP_DISPATCH_CONFIG_INIT; dispatch_cfg.pass_direction = is_forward ? NCCL_EP_FWD_PASS : NCCL_EP_BWD_PASS; - ScopedEpHandle transient; - { - std::lock_guard lock(mutex_); - HandleEntry& cfg = lookup_config(handle_id); - transient = - ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); - } - NVTE_CHECK_NCCL(ncclEpDispatch(transient, &in_struct, &out_struct, + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + ncclEpHandle_t h = get_or_open_handle(cfg, handle_mem); + NVTE_CHECK_NCCL(ncclEpDispatch(h, &in_struct, &out_struct, /*layout_info=*/nullptr, &dispatch_cfg, stream)); } @@ -436,14 +429,10 @@ void EPBackend::combine(uint64_t handle_id, void* handle_mem, const NVTETensor e ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; out_struct.tokens = &nccl_result_out; - ScopedEpHandle transient; - { - std::lock_guard lock(mutex_); - HandleEntry& cfg = lookup_config(handle_id); - transient = - ScopedEpHandle(open_handle(handle_mem, cfg.handle_mem_size, cfg.top_k, cfg.alignment)); - } - NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, /*config=*/nullptr, stream)); + std::lock_guard lock(mutex_); + HandleEntry& cfg = lookup_config(handle_id); + ncclEpHandle_t h = get_or_open_handle(cfg, handle_mem); + NVTE_CHECK_NCCL(ncclEpCombine(h, &in_struct, &out_struct, /*config=*/nullptr, stream)); } void EPBackend::dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, @@ -491,14 +480,10 @@ void EPBackend::dispatch_bwd(uint64_t handle_id, void* handle_mem, const NVTETen ncclEpCombineConfig_t cfg = NCCL_EP_COMBINE_CONFIG_INIT; cfg.pass_direction = NCCL_EP_BWD_PASS; - ScopedEpHandle transient; - { - std::lock_guard lock(mutex_); - HandleEntry& entry = lookup_config(handle_id); - transient = ScopedEpHandle( - open_handle(handle_mem, entry.handle_mem_size, entry.top_k, entry.alignment)); - } - NVTE_CHECK_NCCL(ncclEpCombine(transient, &in_struct, &out_struct, &cfg, stream)); + std::lock_guard lock(mutex_); + HandleEntry& entry = lookup_config(handle_id); + ncclEpHandle_t h = get_or_open_handle(entry, handle_mem); + NVTE_CHECK_NCCL(ncclEpCombine(h, &in_struct, &out_struct, &cfg, stream)); } void EPBackend::combine_bwd(uint64_t handle_id, void* handle_mem, const NVTETensor grad, diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h index 18307ebb4f..e82c974c3f 100644 --- a/transformer_engine/common/ep/ep_backend.h +++ b/transformer_engine/common/ep/ep_backend.h @@ -98,6 +98,10 @@ class EPBackend { size_t handle_mem_size; size_t alignment; int top_k; + // Persistent ncclEpHandle bound to cached_handle_mem. Lazily opened on first + // op; reused while handle_mem ptr is unchanged. Destroyed in shutdown(). + ncclEpHandle_t cached_handle{nullptr}; + void* cached_handle_mem{nullptr}; }; std::unordered_map handles_; std::atomic next_handle_id_{1}; // 0 reserved as "no id" @@ -106,6 +110,10 @@ class EPBackend { // Caller must hold mutex_. Throws on cap overflow. uint64_t insert_new_entry(size_t handle_mem_size, int top_k, size_t alignment); HandleEntry& lookup_config(uint64_t handle_id); + // Caller must hold mutex_. Returns the cached handle if handle_mem matches. + // On mismatch: if group_config_.allow_handle_mem_reloc != 0, destroys the + // stale handle and opens a fresh one; otherwise throws. + ncclEpHandle_t get_or_open_handle(HandleEntry& cfg, void* handle_mem); }; } // namespace ep From ed3d73cc84a215cd4d7c2c87db8eb6eae7e44b5e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 23 May 2026 23:09:15 +0000 Subject: [PATCH 14/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/ep/ep_backend.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index 6494a86817..83657943a4 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -324,8 +324,8 @@ void EPBackend::prepare(uint64_t handle_id, const NVTETensor topk_idx, NVTETenso std::lock_guard lock(mutex_); HandleEntry& cfg = lookup_config(handle_id); NVTE_CHECK(cfg.alignment == dispatch_output_per_expert_alignment, - "ep_prepare: alignment mismatch for handle_id=", handle_id, - " (cached=", cfg.alignment, ", got=", dispatch_output_per_expert_alignment, ")"); + "ep_prepare: alignment mismatch for handle_id=", handle_id, " (cached=", cfg.alignment, + ", got=", dispatch_output_per_expert_alignment, ")"); ncclEpHandle_t h = get_or_open_handle(cfg, handle_mem); NVTE_CHECK_NCCL(ncclEpUpdateHandle(h, &nccl_topk_idx, &layout_info, stream)); } From 1923180bff3c02f0c95a91d997ce3a1301d414ec Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 27 May 2026 14:12:53 -0700 Subject: [PATCH 15/36] Build: NCCL_HOME discovery supports Debian/Ubuntu multiarch lib paths Signed-off-by: Phuong Nguyen --- setup.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index db360c8a29..34a3abfd99 100644 --- a/setup.py +++ b/setup.py @@ -167,11 +167,13 @@ def _discover_nccl_home() -> str: f"'{env_home}/include/nccl.h' was not found; falling back to system probes." ) + lib_names = ("libnccl.so", "libnccl.so.2") + # Include Debian/Ubuntu multiarch subdirs (e.g. lib/aarch64-linux-gnu). + lib_subdirs = ("lib", "lib64", "lib/aarch64-linux-gnu", "lib/x86_64-linux-gnu") for cand in ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr"): p = Path(cand) if (p / "include" / "nccl.h").exists() and any( - (p / "lib" / name).exists() or (p / "lib64" / name).exists() - for name in ("libnccl.so", "libnccl.so.2") + (p / sub / name).exists() for sub in lib_subdirs for name in lib_names ): return str(p) @@ -180,9 +182,11 @@ def _discover_nccl_home() -> str: for line in out.splitlines(): if "libnccl.so" in line and "=>" in line: lib_path = Path(line.split("=>")[-1].strip()) - root = lib_path.parent.parent - if (root / "include" / "nccl.h").exists(): - return str(root) + # Walk upward so multiarch layouts (.../lib//libnccl.so) + # resolve to the prefix that contains include/nccl.h. + for root in (lib_path.parent.parent, lib_path.parent.parent.parent): + if (root / "include" / "nccl.h").exists(): + return str(root) except (subprocess.CalledProcessError, FileNotFoundError): pass From 3b8aafb0bd81f5d0d18bd633ac679905c7b47673 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 27 May 2026 14:26:39 -0700 Subject: [PATCH 16/36] bump NCCL Signed-off-by: Phuong Nguyen --- 3rdparty/nccl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/nccl b/3rdparty/nccl index 6a9bc953ac..146496ac88 160000 --- a/3rdparty/nccl +++ b/3rdparty/nccl @@ -1 +1 @@ -Subproject commit 6a9bc953ac1c4eef92d5adbe3092d4c2cb0a4c98 +Subproject commit 146496ac881bc504ed1a52be0ae7b707ce41e706 From 9b225cbed1834d235234d9850ee6ee20f1f64c15 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 28 May 2026 15:25:16 -0700 Subject: [PATCH 17/36] Expert Parallelism: require token_dtype in NVTEEpGroupConfig and enforce at dispatch Signed-off-by: Phuong Nguyen --- tests/cpp_distributed/test_ep_common.h | 4 ++++ transformer_engine/common/ep/ep_backend.cpp | 21 +++++++++++++++---- .../common/include/transformer_engine/ep.h | 3 +++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h index 77baa92b0c..ccb20ee3a0 100644 --- a/tests/cpp_distributed/test_ep_common.h +++ b/tests/cpp_distributed/test_ep_common.h @@ -74,6 +74,7 @@ static int g_ep_size = -1; static int g_num_experts = -1; static int g_hidden_dim = 256; static int g_max_tokens_per_rank = 64; +static NVTEDType g_token_dtype = kNVTEBFloat16; static bool g_ep_initialized = false; static ncclComm_t g_ep_comm = nullptr; // owned by harness, destroyed in ep_teardown @@ -224,6 +225,8 @@ static void ep_parse_args(int argc, char* argv[]) { else if (a.rfind("--num-processes=",0)==0) g_num_processes = std::stoi(a.substr(16)); else if (a.rfind("--nranks=", 0) == 0) g_num_processes = std::stoi(a.substr(9)); else if (a.rfind("--uid-file=", 0) == 0) g_uid_file = a.substr(11); + else if (a.rfind("--token-dtype=", 0) == 0) + g_token_dtype = static_cast(std::stoi(a.substr(14))); } if (g_process_id < 0 || g_num_processes <= 0) { @@ -279,6 +282,7 @@ static bool ep_bootstrap(int argc, char* argv[]) { // Worst-case for top_k fan-out: ep_size * max_tokens_per_rank * 2. group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; group_config.hidden_dim = g_hidden_dim; + group_config.token_dtype = g_token_dtype; ASSERT_NCCL_OK(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); nvte_ep_initialize(static_cast(g_ep_comm), group_config); diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index 83657943a4..1e08cb55df 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -82,9 +82,13 @@ void EPBackend::validate_config(const NVTEEpGroupConfig& config) { NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", config.max_recv_tokens_per_rank); NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); - NVTE_CHECK(config.hidden_dim * sizeof(nv_bfloat16) >= 16, - "hidden_dim * 2 must be >= 16 (NCCL EP 16B row alignment); got hidden_dim=", - config.hidden_dim); + NVTE_CHECK(config.token_dtype >= 0 && config.token_dtype < kNVTENumTypes, + "token_dtype out of range, got ", static_cast(config.token_dtype)); + const size_t elem_bytes = typeToSize(static_cast(config.token_dtype)); + NVTE_CHECK(config.hidden_dim * elem_bytes >= 16, + "hidden_dim * sizeof(token_dtype) must be >= 16 (NCCL EP 16B row alignment); " + "got hidden_dim=", + config.hidden_dim, ", element_bytes=", elem_bytes); NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts, ") must be divisible by ep_size (", config.ep_size, ")"); NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ", @@ -214,7 +218,8 @@ void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) { cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT; cfg.num_experts = static_cast(group_config.num_experts); cfg.max_dispatch_tokens_per_rank = static_cast(group_config.max_tokens_per_rank); - cfg.max_token_bytes = static_cast(group_config.hidden_dim * sizeof(nv_bfloat16)); + const size_t elem_bytes = typeToSize(static_cast(group_config.token_dtype)); + cfg.max_token_bytes = static_cast(group_config.hidden_dim * elem_bytes); cfg.rdma_buffer_size = NCCL_EP_AUTO; cfg.num_qp_per_rank = NCCL_EP_AUTO; cfg.num_channels = NCCL_EP_AUTO; @@ -341,6 +346,10 @@ void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor NVTEShape tok_shape = nvte_tensor_shape(tokens); NVTEDType tok_dtype = nvte_tensor_type(tokens); + NVTE_CHECK(tok_dtype == group_config_.token_dtype, + "tokens dtype (", static_cast(tok_dtype), + ") does not match group token_dtype (", + static_cast(group_config_.token_dtype), ")"); const size_t num_tokens = tok_shape.data[0]; const size_t hidden_dim = tok_shape.data[1]; @@ -367,6 +376,10 @@ void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor NVTEShape recv_shape = nvte_tensor_shape(recv_tokens); NVTEDType recv_dtype = nvte_tensor_type(recv_tokens); + NVTE_CHECK(recv_dtype == group_config_.token_dtype, + "recv_tokens dtype (", static_cast(recv_dtype), + ") does not match group token_dtype (", + static_cast(group_config_.token_dtype), ")"); size_t recv_sizes[2] = {recv_shape.data[0], recv_shape.data[1]}; ncclEpTensor_t nccl_tokens_out = make_payload_tensor(recv_tokens, recv_tokens_win, 2, diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index 8c3a06b5f0..ac7f1dbf07 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -35,6 +35,9 @@ typedef struct { int max_num_sms; /*!< Max SMs for EP kernels. 0 = auto. */ /*! 0 (default): throw on relocated handle_mem for a cached handle_id. 1: silently rebuild. */ int allow_handle_mem_reloc; + /*! Token dtype for this EP group. Sizes NCCL EP staging buffers at group + * create and is enforced against tensors passed to nvte_ep_dispatch. */ + NVTEDType token_dtype; } NVTEEpGroupConfig; /*! \brief Per-layer EP configuration. */ From 03e56d221d28fe129eeede510168723ee2d26d68 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 28 May 2026 15:31:47 -0700 Subject: [PATCH 18/36] Expert Parallelism: document ep_comm lifetime, v0.1 single-GPU scope, static layer registration Signed-off-by: Phuong Nguyen --- .../common/include/transformer_engine/ep.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index ac7f1dbf07..a1c9305e9b 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -54,8 +54,13 @@ typedef struct { /*! \brief Bootstrap from an existing NCCL EP sub-communicator. Requires SM>=90. * * ep_comm is borrowed and must span exactly group_config.ep_size ranks. + * The caller retains ownership and must keep ep_comm alive until + * nvte_ep_shutdown() returns; destroying it earlier is undefined behavior. * Re-init after shutdown is allowed; double-init throws. * + * v0.1 scope: one EP group per process, bound to the current CUDA device at + * initialize time. Multiple GPUs per process are not supported. + * * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group. * \param[in] group_config Group-level EP configuration. */ @@ -69,6 +74,11 @@ void nvte_ep_shutdown(void); /*! \brief Reserve a handle_id for a layer config and report the handle_mem buffer * size the caller must allocate. Host-only. * + * Registration is intended to be static (once per layer at model init). There is + * no per-layer unregister API; all registrations are released by nvte_ep_shutdown. + * Re-registering the same layer config each step is not supported and will + * eventually exhaust the handle cache (NVTE_EP_HANDLE_CACHE_SIZE, default 8192). + * * \param[in] layer_config Per-layer EP configuration. * \param[out] handle_mem_size Bytes the caller must allocate for handle_mem. * \return uint64_t handle_id (non-zero). From 4cefdcb2ad71f95be154516a4234a44c50eef641 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 28 May 2026 15:32:48 -0700 Subject: [PATCH 19/36] Expert Parallelism: drop version label from initialize scope note Signed-off-by: Phuong Nguyen --- tests/cpp_distributed/CMakeLists.txt | 45 +- tests/cpp_distributed/run_test_ep.sh | 123 +--- .../{test_ep_pipeline.cu => test_ep.cu} | 643 ++++++++---------- tests/cpp_distributed/test_ep_common.h | 194 +----- tests/cpp_distributed/test_ep_coverage.cu | 562 --------------- tests/cpp_distributed/test_ep_init.cu | 64 -- transformer_engine/common/ep/ep_backend.cpp | 25 +- .../common/include/transformer_engine/ep.h | 13 +- transformer_engine/common/util/logging.h | 8 + 9 files changed, 376 insertions(+), 1301 deletions(-) rename tests/cpp_distributed/{test_ep_pipeline.cu => test_ep.cu} (51%) delete mode 100644 tests/cpp_distributed/test_ep_coverage.cu delete mode 100644 tests/cpp_distributed/test_ep_init.cu diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt index 3870f57911..7dd8ea33e7 100644 --- a/tests/cpp_distributed/CMakeLists.txt +++ b/tests/cpp_distributed/CMakeLists.txt @@ -30,7 +30,7 @@ if(NOT DEFINED TE_LIB_PATH) get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) endif() -find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED NO_CMAKE_SYSTEM_PATH) +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) @@ -73,10 +73,8 @@ target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest include(GoogleTest) gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) -# ── EP distributed tests (HT mode) ───────────────────────────────────────── -# No MPI dependency — processes are spawned by run_test_ep.sh with -# --rank / --nranks flags. ncclUniqueId exchange uses a -# shared temp file (see test_ep_common.h for details). +# ── EP distributed tests ────────────────────────────────────────────────────── +# Launched via mpirun; ncclUniqueId exchange uses MPI_Bcast (see test_ep_common.h). # Headers + libs come from the in-tree 3rdparty/nccl submodule build. set(NCCL_EP_SUBMODULE_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") @@ -103,41 +101,28 @@ endif() set(EP_TEST_COMMON_INCLUDES ${EP_TEST_NCCL_INCLUDES} + ${MPI_CXX_INCLUDE_PATH} ../../transformer_engine/common/include ../../transformer_engine/common ${CMAKE_CURRENT_SOURCE_DIR}) +# nvrtc must follow TE_LIB so symbols referenced from libtransformer_engine.so +# (loaded via dlopen in Python; not in its DT_NEEDED) resolve through nvrtc. set(EP_TEST_COMMON_LIBS CUDA::cuda_driver CUDA::cudart - CUDA::nvrtc GTest::gtest ${TE_LIB} + CUDA::nvrtc ${NCCL_LIB} - ${NCCL_EP_LIB}) - -# nvrtc symbols are referenced from libtransformer_engine.so but not in its -# DT_NEEDED list (loaded via dlopen in Python). For cpp tests we link nvrtc -# explicitly with --no-as-needed so the linker keeps the dependency. -set(EP_TEST_LINK_OPTS "LINKER:--no-as-needed") - -# ── EP init tests (InitPath, HandleMemSizeQuery) ───────────────────────────── -add_executable(test_ep_init test_ep_init.cu) -target_include_directories(test_ep_init PRIVATE ${EP_TEST_COMMON_INCLUDES}) -target_link_libraries(test_ep_init PUBLIC ${EP_TEST_COMMON_LIBS}) -target_link_options(test_ep_init PUBLIC ${EP_TEST_LINK_OPTS}) - -# ── EP pipeline tests (dispatch, combine, bwd, integrated) ─────────────────── -add_executable(test_ep_pipeline test_ep_pipeline.cu) -target_include_directories(test_ep_pipeline PRIVATE ${EP_TEST_COMMON_INCLUDES}) -target_link_libraries(test_ep_pipeline PUBLIC ${EP_TEST_COMMON_LIBS}) -target_link_options(test_ep_pipeline PUBLIC ${EP_TEST_LINK_OPTS}) - -# ── EP coverage tests (multi-handle, top_k=1, empty experts, negatives, threading) ── -add_executable(test_ep_coverage test_ep_coverage.cu) -target_include_directories(test_ep_coverage PRIVATE ${EP_TEST_COMMON_INCLUDES}) -target_link_libraries(test_ep_coverage PUBLIC ${EP_TEST_COMMON_LIBS}) -target_link_options(test_ep_coverage PUBLIC ${EP_TEST_LINK_OPTS}) + ${NCCL_EP_LIB} + MPI::MPI_CXX + OpenMP::OpenMP_CXX) + +# ── EP distributed tests (per-op + full pipeline + zero-copy symm) ─────────── +add_executable(test_ep test_ep.cu ../cpp/test_common.cu) +target_include_directories(test_ep PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep PUBLIC ${EP_TEST_COMMON_LIBS}) # Do NOT use gtest_discover_tests — these binaries require multi-process # launch via run_test_ep.sh, not direct single-process execution. diff --git a/tests/cpp_distributed/run_test_ep.sh b/tests/cpp_distributed/run_test_ep.sh index 017d3f807b..13e86fa02d 100755 --- a/tests/cpp_distributed/run_test_ep.sh +++ b/tests/cpp_distributed/run_test_ep.sh @@ -3,12 +3,8 @@ # # See LICENSE for license information. # -# Run TE EP distributed unit tests across multiple GPUs. -# -# Spawns one background bash process per GPU (no MPI dependency), matching the -# JAX multi-process launcher style. ncclUniqueId is exchanged via a shared -# temp file (see test_ep_common.h). Each rank builds its own ncclComm_t and -# passes it to nvte_ep_initialize. +# Run TE EP distributed unit tests via mpirun. Each MPI rank pins to one GPU +# (rank % device_count) and exchanges ncclUniqueId through MPI_Bcast. # # Usage: # bash run_test_ep.sh [num_gpus] [build_dir] @@ -18,15 +14,16 @@ # build_dir = /build # # Environment variables: -# GTEST_FILTER — forwarded to all processes (e.g., "EPDispatchTest.*") -# TEST_TIMEOUT_S — per-process timeout in seconds (default: 180) +# GTEST_FILTER — forwarded to all processes (e.g., "EPPipelineTest.*") +# MPIRUN — override the mpirun binary (default: mpirun) +# MPIRUN_EXTRA — extra flags forwarded to mpirun set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BUILD_DIR="${2:-${SCRIPT_DIR}/build}" NUM_GPUS="${1:-$(nvidia-smi -L 2>/dev/null | wc -l)}" -TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}" +MPIRUN="${MPIRUN:-mpirun}" # Skip cleanly on pre-Hopper: NCCL EP requires SM>=90. MIN_SM=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ @@ -36,102 +33,22 @@ if (( MIN_SM > 0 && MIN_SM < 90 )); then exit 0 fi -GTEST_ARGS="${GTEST_FILTER:+--gtest_filter=${GTEST_FILTER}}" -OVERALL_FAIL=0 - -# --------------------------------------------------------------------------- -# run_suite BINARY SUITE_NAME MIN_GPUS -# --------------------------------------------------------------------------- -run_suite() { - local BINARY="$1" - local SUITE_NAME="$2" - local MIN_GPUS="${3:-2}" - - local TEST_BIN="${BUILD_DIR}/${BINARY}" - - if [[ ! -x "${TEST_BIN}" ]]; then - echo "ERROR: binary not found: ${TEST_BIN}" - echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make" - OVERALL_FAIL=1 - return - fi - - if (( NUM_GPUS < MIN_GPUS )); then - echo "${SUITE_NAME}: requires ${MIN_GPUS} GPUs, found ${NUM_GPUS}. Skipping." - return - fi - - local TMPDIR_L="${TMPDIR:-/tmp}" - local UID_FILE="${TMPDIR_L}/te_ep_uid_${BINARY}_$$" - rm -f "${UID_FILE}" - - local LOG_DIR - LOG_DIR=$(mktemp -d) - local FAIL=0 - - echo "=== ${SUITE_NAME} ===" - echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}" - echo - - # Spawn one background process per GPU. ncclUniqueId is exchanged via the - # shared UID_FILE. Each process is wrapped in `timeout` to detect hangs early. - local PIDS=() - for i in $(seq 0 $((NUM_GPUS - 1))); do - timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ - "${TEST_BIN}" \ - --rank="${i}" \ - --nranks="${NUM_GPUS}" \ - --uid-file="${UID_FILE}" \ - ${GTEST_ARGS} \ - > "${LOG_DIR}/rank_${i}.log" 2>&1 & - PIDS+=($!) - done - for i in $(seq 0 $((NUM_GPUS - 1))); do - if ! wait "${PIDS[$i]}"; then - local rc=$? - FAIL=1 - if [[ $rc -eq 137 || $rc -eq 124 ]]; then - echo " rank ${i}: TIMEOUT after ${TEST_TIMEOUT_S}s (rc=${rc})" - fi - fi - done - - echo "--- Rank 0 output ---" - cat "${LOG_DIR}/rank_0.log" - - if (( FAIL )); then - for i in $(seq 1 $((NUM_GPUS - 1))); do - echo "--- Rank ${i} output ---" - cat "${LOG_DIR}/rank_${i}.log" - done - echo "=== ${SUITE_NAME}: FAILED ===" - OVERALL_FAIL=1 - else - echo "=== ${SUITE_NAME}: ALL PASSED ===" - fi - - rm -rf "${LOG_DIR}" - rm -f "${UID_FILE}" -} +TEST_BIN="${BUILD_DIR}/test_ep" +if [[ ! -x "${TEST_BIN}" ]]; then + echo "ERROR: binary not found: ${TEST_BIN}" + echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make" + exit 1 +fi -# --------------------------------------------------------------------------- -# Cleanup on abort -# --------------------------------------------------------------------------- -cleanup() { rm -f "${TMPDIR:-/tmp}"/te_ep_uid_*_"$$" 2>/dev/null || true; } -trap cleanup EXIT INT TERM +if (( NUM_GPUS < 2 )); then + echo "EP Tests: requires at least 2 GPUs, found ${NUM_GPUS}. Skipping." + exit 0 +fi -# --------------------------------------------------------------------------- -# Run all suites -# --------------------------------------------------------------------------- -run_suite "test_ep_init" "EP Init Tests" 2 -run_suite "test_ep_pipeline" "EP Pipeline Tests" 2 -run_suite "test_ep_coverage" "EP Coverage Tests" 2 +GTEST_ARGS="${GTEST_FILTER:+--gtest_filter=${GTEST_FILTER}}" +echo "=== EP Tests ===" +echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}" echo -if (( OVERALL_FAIL )); then - echo "=== SOME SUITES FAILED ===" -else - echo "=== ALL SUITES PASSED ===" -fi -exit "${OVERALL_FAIL}" +"${MPIRUN}" -n "${NUM_GPUS}" ${MPIRUN_EXTRA:-} "${TEST_BIN}" ${GTEST_ARGS} diff --git a/tests/cpp_distributed/test_ep_pipeline.cu b/tests/cpp_distributed/test_ep.cu similarity index 51% rename from tests/cpp_distributed/test_ep_pipeline.cu rename to tests/cpp_distributed/test_ep.cu index 41f83a6d11..bcf4ca3c98 100644 --- a/tests/cpp_distributed/test_ep_pipeline.cu +++ b/tests/cpp_distributed/test_ep.cu @@ -39,10 +39,21 @@ static inline float token_value(int rank, int t, int num_tokens) { return static_cast(rank * num_tokens + t + 1) * (1.0f / 256.0f); } -static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) { - std::vector v(num_tokens * hidden_dim); +// Per-element host-side conversion helpers used by templated test code. +inline float tok_to_float(nv_bfloat16 v) { return __bfloat162float(v); } +inline float tok_to_float(__half v) { return __half2float(v); } +inline float tok_to_float(float v) { return v; } + +template T tok_from_float(float v); +template <> inline nv_bfloat16 tok_from_float(float v) { return __float2bfloat16(v); } +template <> inline __half tok_from_float<__half> (float v) { return __float2half(v); } +template <> inline float tok_from_float (float v) { return v; } + +template +static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) { + std::vector v(num_tokens * hidden_dim); for (int t = 0; t < num_tokens; ++t) { - nv_bfloat16 val = __float2bfloat16(token_value(rank, t, num_tokens)); + T val = tok_from_float(token_value(rank, t, num_tokens)); for (int h = 0; h < hidden_dim; ++h) v[t * hidden_dim + h] = val; } @@ -85,17 +96,20 @@ static std::vector expected_recv_values_sorted( return vals; } -// BF16 has 7 mantissa bits; relative ULP ≈ 2^-7. Use 4× headroom for -// accumulation noise inside dispatch/combine. +// 2^-5 relative tolerance for BF16 (matches mantissa precision with margin), +// plus a small atol floor for near-zero expected values. +static constexpr float kBf16Rtol = 1.0f / 32.0f; +static constexpr float kBf16Atol = 1e-3f; static float bf16_tol(float magnitude) { - return 4.f * std::ldexp(std::fabs(magnitude) + 1e-3f, -7); + return kBf16Atol + kBf16Rtol * std::fabs(magnitude); } -static bool check_no_nan_inf(const nv_bfloat16* dev, int count, const char* name) { - std::vector h(count); - cudaMemcpy(h.data(), dev, count * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost); +template +static bool check_no_nan_inf(const T* dev, int count, const char* name) { + std::vector h(count); + cudaMemcpy(h.data(), dev, count * sizeof(T), cudaMemcpyDeviceToHost); for (int i = 0; i < count; ++i) { - float v = __bfloat162float(h[i]); + float v = tok_to_float(h[i]); if (std::isnan(v) || std::isinf(v)) { fprintf(stderr, "Rank %d: %s in %s[%d]\n", g_process_id, std::isnan(v) ? "NaN" : "Inf", name, i); @@ -107,20 +121,21 @@ static bool check_no_nan_inf(const nv_bfloat16* dev, int count, const char* name // ── Forward buffer set with RAII ────────────────────────────────────────────── +template struct EPBuffers { // Forward DevBuf topk_idx; DevBuf topk_weights; - DevBuf tokens; + DevBuf tokens; DevBuf token_counts; DevBuf handle_mem; - DevBuf recv_tokens; + DevBuf recv_tokens; DevBuf recv_topk_weights; - DevBuf result; + DevBuf result; // Backward - DevBuf grad_result; - DevBuf grad_expert; - DevBuf grad_tokens; + DevBuf grad_result; + DevBuf grad_expert; + DevBuf grad_tokens; DevBuf g_recv_topk_weights; DevBuf grad_topk_weights; @@ -154,42 +169,45 @@ struct EPBuffers { } }; -// Bundled NVTETensor views over an EPBuffers — one place to update the shape -// conventions when the C-API evolves. +// Bundled NVTETensor views over an EPBuffers, with the shapes the EP C API +// expects. +template struct EPTensors { - TensorHandle topk_idx, topk_weights, token_counts, handle_mem, tokens; - TensorHandle recv_tokens, recv_topk_weights, result; - TensorHandle grad_result, grad_expert, grad_tokens; - TensorHandle g_recv_topk_weights, grad_topk_weights; + TensorWrapper topk_idx, topk_weights, token_counts, handle_mem, tokens; + TensorWrapper recv_tokens, recv_topk_weights, result; + TensorWrapper grad_result, grad_expert, grad_tokens; + TensorWrapper g_recv_topk_weights, grad_topk_weights; - EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, + EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, int num_local_experts) { - topk_idx = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - topk_weights = make_nvte_tensor(b.topk_weights.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); - token_counts = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts}, kNVTEInt32); - handle_mem = make_nvte_tensor(b.handle_mem.get(), - {b.handle_mem_size}, kNVTEByte); - tokens = make_nvte_tensor(b.tokens.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - recv_tokens = make_nvte_tensor(b.recv_tokens.get(), - {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); - recv_topk_weights = make_nvte_tensor(b.recv_topk_weights.get(), - {b.recv_capacity}, kNVTEFloat32); - result = make_nvte_tensor(b.result.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - grad_result = make_nvte_tensor(b.grad_result.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - grad_expert = make_nvte_tensor(b.grad_expert.get(), - {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); - grad_tokens = make_nvte_tensor(b.grad_tokens.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - g_recv_topk_weights = make_nvte_tensor(b.g_recv_topk_weights.get(), - {b.recv_capacity}, kNVTEFloat32); - grad_topk_weights = make_nvte_tensor(b.grad_topk_weights.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); + constexpr DType kTokDType = test::TypeInfo::dtype; + using Shape = std::vector; + topk_idx = TensorWrapper(b.topk_idx.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kInt64); + topk_weights = TensorWrapper(b.topk_weights.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32); + token_counts = TensorWrapper(b.token_counts.get(), + Shape{(size_t)num_local_experts}, DType::kInt32); + handle_mem = TensorWrapper(b.handle_mem.get(), + Shape{b.handle_mem_size}, DType::kByte); + tokens = TensorWrapper(b.tokens.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + recv_tokens = TensorWrapper(b.recv_tokens.get(), + Shape{b.recv_capacity, (size_t)hidden_dim}, kTokDType); + recv_topk_weights = TensorWrapper(b.recv_topk_weights.get(), + Shape{b.recv_capacity}, DType::kFloat32); + result = TensorWrapper(b.result.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + grad_result = TensorWrapper(b.grad_result.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + grad_expert = TensorWrapper(b.grad_expert.get(), + Shape{b.recv_capacity, (size_t)hidden_dim}, kTokDType); + grad_tokens = TensorWrapper(b.grad_tokens.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + g_recv_topk_weights = TensorWrapper(b.g_recv_topk_weights.get(), + Shape{b.recv_capacity}, DType::kFloat32); + grad_topk_weights = TensorWrapper(b.grad_topk_weights.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32); } }; @@ -215,29 +233,31 @@ class EpOpTestBase : public ::testing::Test { num_tokens_ = 32; } - void upload_inputs(EPBuffers& buf, int rank = -1) { + template + void upload_inputs(EPBuffers& buf, int rank = -1) { if (rank < 0) rank = g_process_id; auto h_idx = routing_balanced(rank, num_tokens_, top_k_, num_experts_, num_local_experts_); std::vector h_w(num_tokens_ * top_k_, 1.0f / top_k_); - auto h_tok = generate_tokens(rank, num_tokens_, hidden_dim_); - - CHECK_CUDA(cudaMemcpy(buf.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(buf.tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + auto h_tok = generate_tokens(rank, num_tokens_, hidden_dim_); + + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(buf.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(T), cudaMemcpyHostToDevice)); } NVTEEpLayerConfig layer_config(size_t alignment = 0) const { return NVTEEpLayerConfig{num_local_experts_, top_k_, alignment}; } - // ASSERT_CUDA_OK (fprintf+exit) so this non-void helper stays legal. - int read_total_recv(const EPBuffers& buf) const { + // NVTE_CHECK_CUDA (fprintf+exit) so this non-void helper stays legal. + template + int read_total_recv(const EPBuffers& buf) const { std::vector cnt(num_local_experts_); - ASSERT_CUDA_OK(cudaMemcpy(cnt.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); int total = 0; for (int c : cnt) total += c; @@ -252,28 +272,28 @@ class EpOpTestBase : public ::testing::Test { class EPDispatchTest : public EpOpTestBase {}; TEST_F(EPDispatchTest, PrepareAndDispatch) { - EPBuffers buf; + EPBuffers<> buf; buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); - CHECK_CUDA(cudaMemset(buf.recv_tokens.get(), 0, buf.recv_tokens.bytes())); + NVTE_CHECK_CUDA(cudaMemset(buf.recv_tokens.get(), 0, buf.recv_tokens.bytes())); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); // 1. Per-expert counts. std::vector got_counts(num_local_experts_); - CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_, num_experts_, num_local_experts_); @@ -288,7 +308,7 @@ TEST_F(EPDispatchTest, PrepareAndDispatch) { // 2. Recv values: read only the filled prefix per local-expert zone, not the // whole recv buffer — avoids false positives from legitimate-zero token values. std::vector h_recv(buf.recv_capacity * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_recv.data(), buf.recv_tokens.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_recv.data(), buf.recv_tokens.get(), h_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); std::vector got_vals; @@ -312,7 +332,7 @@ TEST_F(EPDispatchTest, PrepareAndDispatch) { // 3. recv_topk_weights: every filled slot must equal the per-token weight (1/top_k). std::vector h_w(buf.recv_capacity); - CHECK_CUDA(cudaMemcpy(h_w.data(), buf.recv_topk_weights.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_w.data(), buf.recv_topk_weights.get(), h_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); const float exp_w = 1.0f / static_cast(top_k_); for (int i = 0; i < total_recv; ++i) @@ -321,7 +341,7 @@ TEST_F(EPDispatchTest, PrepareAndDispatch) { if (g_process_id == 0) printf(" PrepareAndDispatch: passed (recv=%d, values + weights exact)\n", total_recv); - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } // ============================================================================= @@ -331,34 +351,32 @@ TEST_F(EPDispatchTest, PrepareAndDispatch) { class EPCombineTest : public EpOpTestBase {}; TEST_F(EPCombineTest, Combine) { - EPBuffers buf; + EPBuffers<> buf; buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, - t.result.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); std::vector h_result(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(), h_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); - // Spot-check 3 hidden-dim positions per token to catch partial-row writes. - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; for (int tok = 0; tok < num_tokens_; ++tok) { float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); - for (int p : probes) { + for (int p = 0; p < hidden_dim_; ++p) { float got = __bfloat162float(h_result[tok * hidden_dim_ + p]); EXPECT_NEAR(got, exp, bf16_tol(exp)) << "token " << tok << " rank " << g_process_id << " hidden " << p; @@ -368,7 +386,7 @@ TEST_F(EPCombineTest, Combine) { if (g_process_id == 0) printf(" Combine: passed (result == top_k * tokens)\n"); - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } // ============================================================================= @@ -378,41 +396,41 @@ TEST_F(EPCombineTest, Combine) { class EPCombineBwdTest : public EpOpTestBase {}; TEST_F(EPCombineBwdTest, CombineBwdCheck) { - EPBuffers buf; + EPBuffers<> buf; buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, - t.result.tensor, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); std::vector h_grad_r(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); - CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad_r.data(), + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad_r.data(), h_grad_r.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); - ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, - t.grad_expert.tensor, NVTECommWindow{}, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); int total_recv = read_total_recv(buf); std::vector cnt(num_local_experts_); - CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); std::vector h_ge(buf.recv_capacity * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), h_ge.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); // Walk filled slots by per-expert zone (no v != 0 heuristic). @@ -421,9 +439,12 @@ TEST_F(EPCombineBwdTest, CombineBwdCheck) { int filled = 0; for (int e = 0; e < num_local_experts_; ++e) { for (int i = 0; i < cnt[e]; ++i) { - float v = __bfloat162float(h_ge[slot * hidden_dim_]); - EXPECT_NEAR(v, kExpGrad, bf16_tol(kExpGrad)) - << "grad_expert expert " << e << " slot " << i << " (linear " << slot << ")"; + for (int p = 0; p < hidden_dim_; ++p) { + float v = __bfloat162float(h_ge[slot * hidden_dim_ + p]); + EXPECT_NEAR(v, kExpGrad, bf16_tol(kExpGrad)) + << "grad_expert expert " << e << " slot " << i + << " (linear " << slot << ") hidden " << p; + } ++filled; ++slot; } } @@ -432,7 +453,7 @@ TEST_F(EPCombineBwdTest, CombineBwdCheck) { if (g_process_id == 0) printf(" CombineBwdCheck: passed (filled=%d)\n", filled); - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } // ============================================================================= @@ -442,51 +463,53 @@ TEST_F(EPCombineBwdTest, CombineBwdCheck) { class EPDispatchBwdTest : public EpOpTestBase {}; TEST_F(EPDispatchBwdTest, DispatchBwdCheck) { - EPBuffers buf; + EPBuffers<> buf; buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, - t.result.tensor, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); - CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), h_grad.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); - CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); - ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, - t.grad_expert.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, - t.g_recv_topk_weights.tensor, NVTECommWindow{}, - t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_expert.data(), NVTECommWindow{}, + t.g_recv_topk_weights.data(), NVTECommWindow{}, + t.grad_tokens.data(), t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); std::vector h_gt(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); const float kExpGrad = static_cast(top_k_) * 0.1f; for (int tok = 0; tok < num_tokens_; ++tok) - EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) - << "grad_tokens token " << tok; + for (int p = 0; p < hidden_dim_; ++p) + EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_ + p]), kExpGrad, + bf16_tol(kExpGrad)) + << "grad_tokens token " << tok << " hidden " << p; if (g_process_id == 0) printf(" DispatchBwdCheck: passed (grad_tokens == %.2f)\n", kExpGrad); - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } // ============================================================================= @@ -496,11 +519,11 @@ TEST_F(EPDispatchBwdTest, DispatchBwdCheck) { class EPDispatchBwdGradWeightsTest : public EpOpTestBase {}; TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) { - EPBuffers buf; + EPBuffers<> buf; buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); // Distinct per-(rank, t, k) weights so each slot carries a unique value. std::vector h_w(num_tokens_ * top_k_); @@ -508,39 +531,39 @@ TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) { for (int k = 0; k < top_k_; ++k) h_w[tok * top_k_ + k] = 0.1f + 0.01f * tok + 0.001f * k + 0.0001f * (g_process_id + 1); - CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, buf.recv_topk_weights.bytes(), stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); // Sentinel: NaN so any (t, k) the bwd kernel fails to write is immediately visible. std::vector h_nan(num_tokens_ * top_k_, std::numeric_limits::quiet_NaN()); - CHECK_CUDA(cudaMemcpyAsync(buf.grad_topk_weights.get(), h_nan.data(), + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_topk_weights.get(), h_nan.data(), h_nan.size() * sizeof(float), cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); // g_recv_topk_weights := recv_topk_weights (the round-trip input). - auto g_recv_t = make_nvte_tensor(buf.recv_topk_weights.get(), - {buf.recv_capacity}, kNVTEFloat32); - ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, - NVTECommWindow{}, g_recv_t.tensor, NVTECommWindow{}, - t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + auto g_recv_t = TensorWrapper(buf.recv_topk_weights.get(), + std::vector{buf.recv_capacity}, DType::kFloat32); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_expert.data(), + NVTECommWindow{}, g_recv_t.data(), NVTECommWindow{}, + t.grad_tokens.data(), t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); std::vector h_grad_w(num_tokens_ * top_k_); - CHECK_CUDA(cudaMemcpy(h_grad_w.data(), buf.grad_topk_weights.get(), + NVTE_CHECK_CUDA(cudaMemcpy(h_grad_w.data(), buf.grad_topk_weights.get(), h_grad_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); const float kTol = 1e-5f; @@ -566,57 +589,81 @@ TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) { if (g_process_id == 0 && errs == 0 && k0_eq_k1 == 0) printf(" RoundTrip: passed (%d (t, k) gradients)\n", num_tokens_ * top_k_); - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } // ============================================================================= // Integrated FwdBwd: NaN/Inf check end-to-end. // ============================================================================= -class EPPipelineTest : public EpOpTestBase {}; - -TEST_F(EPPipelineTest, FullForwardBackward) { - EPBuffers buf; - buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, - ep_size_, max_tokens_per_rank_); - upload_inputs(buf); - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, NVTECommWindow{}, t.topk_weights.tensor, - NVTECommWindow{}, t.recv_tokens.tensor, NVTECommWindow{}, - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, NVTECommWindow{}, - t.result.tensor, stream)); - - std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); - CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), - h_grad.size() * sizeof(nv_bfloat16), - cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); - CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); - - ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, NVTECommWindow{}, - t.grad_expert.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, NVTECommWindow{}, - t.g_recv_topk_weights.tensor, NVTECommWindow{}, - t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); - ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); - - if (g_process_id == 0) printf(" FullForwardBackward: passed\n"); +class EPPipelineTest : public EpOpTestBase, public ::testing::WithParamInterface { + protected: + template + void run_full_forward_backward() { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + uint64_t handle_id = buf.handle_id; + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, tok_from_float(0.1f)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(Tok), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.data()}, t.grad_expert.data(), NVTECommWindow{}, + t.g_recv_topk_weights.data(), NVTECommWindow{}, + t.grad_tokens.data(), t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); + } +}; - CHECK_CUDA(cudaStreamDestroy(stream)); +TEST_P(EPPipelineTest, FullForwardBackward) { + const DType dtype = GetParam(); + // NCCL EP backend currently asserts ncclBfloat16 in ncclEpDispatch + // (contrib/nccl_ep/nccl_ep.cc); skip FP16/FP32 until the backend supports them. + if (dtype != DType::kBFloat16) { + GTEST_SKIP() << test::typeName(dtype) << " not yet supported by NCCL EP backend"; + } + switch (dtype) { + case DType::kBFloat16: run_full_forward_backward(); break; + case DType::kFloat16: run_full_forward_backward<__half> (); break; + case DType::kFloat32: run_full_forward_backward (); break; + default: FAIL() << "unsupported token dtype " << static_cast(dtype); + } + if (g_process_id == 0) + printf(" FullForwardBackward[%s]: passed\n", test::typeName(dtype).c_str()); } +INSTANTIATE_TEST_SUITE_P( + Dtypes, EPPipelineTest, + ::testing::Values(DType::kBFloat16, DType::kFloat16, DType::kFloat32), + [](const ::testing::TestParamInfo& info) { + return test::typeName(info.param); + }); + // ============================================================================= // EPZeroCopyTest: dispatch/combine with NCCL symmetric-memory windows attached // to payload tensors (zero-copy fast path via ncclEpTensorCreateFromWindow). @@ -646,9 +693,9 @@ struct SymmBuf { void alloc(size_t n_bytes) { bytes = n_bytes; - ASSERT_NCCL_OK(ncclMemAlloc(&ptr, bytes)); - CHECK_CUDA(cudaMemset(ptr, 0, bytes)); - ASSERT_NCCL_OK(ncclCommWindowRegister(g_ep_comm, ptr, bytes, &win, + NVTE_CHECK_NCCL(ncclMemAlloc(&ptr, bytes)); + NVTE_CHECK_CUDA(cudaMemset(ptr, 0, bytes)); + NVTE_CHECK_NCCL(ncclCommWindowRegister(g_ep_comm, ptr, bytes, &win, NCCL_WIN_COLL_SYMMETRIC)); } }; @@ -666,34 +713,34 @@ class EPZeroCopyTest : public EpOpTestBase {}; // vs HBM reference (same routing, same input). TEST_F(EPZeroCopyTest, IdentityAllSymm) { // HBM reference run. - EPBuffers ref_buf; + EPBuffers<> ref_buf; ref_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(ref_buf); - EPTensors ref_t(ref_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> ref_t(ref_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); uint64_t ref_hid = ref_buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, ref_t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.topk_idx.tensor, - ref_t.tokens.tensor, NVTECommWindow{}, ref_t.topk_weights.tensor, - NVTECommWindow{}, ref_t.recv_tokens.tensor, NVTECommWindow{}, - ref_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{ref_hid, ref_t.handle_mem.tensor}, ref_t.recv_tokens.tensor, NVTECommWindow{}, - ref_t.result.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{ref_hid, ref_t.handle_mem.data()}, ref_t.topk_idx.data(), ref_t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{ref_hid, ref_t.handle_mem.data()}, ref_t.topk_idx.data(), + ref_t.tokens.data(), NVTECommWindow{}, ref_t.topk_weights.data(), + NVTECommWindow{}, ref_t.recv_tokens.data(), NVTECommWindow{}, + ref_t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{ref_hid, ref_t.handle_mem.data()}, ref_t.recv_tokens.data(), NVTECommWindow{}, + ref_t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); std::vector ref_recv(ref_buf.recv_capacity * hidden_dim_); std::vector ref_result(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(ref_recv.data(), ref_buf.recv_tokens.get(), + NVTE_CHECK_CUDA(cudaMemcpy(ref_recv.data(), ref_buf.recv_tokens.get(), ref_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - CHECK_CUDA(cudaMemcpy(ref_result.data(), ref_buf.result.get(), + NVTE_CHECK_CUDA(cudaMemcpy(ref_result.data(), ref_buf.result.get(), ref_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); // Symm-mem run: tokens, recv_tokens, combine_input (== recv_tokens) all symm. - EPBuffers sym_buf; // alloc all buffers except the symm ones. + EPBuffers<> sym_buf; // alloc all buffers except the symm ones. sym_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, ep_size_, max_tokens_per_rank_); upload_inputs(sym_buf); @@ -704,32 +751,32 @@ TEST_F(EPZeroCopyTest, IdentityAllSymm) { // Stage same tokens into the symm-mem input. auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); - CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + NVTE_CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - EPTensors sym_t(sym_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + EPTensors<> sym_t(sym_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); // Replace the tokens/recv_tokens views with ones pointing at the symm buffers. - sym_t.tokens = make_nvte_tensor(sym_tokens.ptr, - {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); - sym_t.recv_tokens = make_nvte_tensor(sym_recv.ptr, - {sym_buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); + sym_t.tokens = TensorWrapper(sym_tokens.ptr, + std::vector{(size_t)num_tokens_, (size_t)hidden_dim_}, DType::kBFloat16); + sym_t.recv_tokens = TensorWrapper(sym_recv.ptr, + std::vector{sym_buf.recv_capacity, (size_t)hidden_dim_}, DType::kBFloat16); uint64_t sym_hid = sym_buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, sym_t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.topk_idx.tensor, - sym_t.tokens.tensor, symm_window(sym_tokens), - sym_t.topk_weights.tensor, NVTECommWindow{}, - sym_t.recv_tokens.tensor, symm_window(sym_recv), - sym_t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{sym_hid, sym_t.handle_mem.tensor}, sym_t.recv_tokens.tensor, - symm_window(sym_recv), sym_t.result.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); + ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{sym_hid, sym_t.handle_mem.data()}, sym_t.topk_idx.data(), sym_t.token_counts.data(), /*alignment=*/0, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{sym_hid, sym_t.handle_mem.data()}, sym_t.topk_idx.data(), + sym_t.tokens.data(), symm_window(sym_tokens), + sym_t.topk_weights.data(), NVTECommWindow{}, + sym_t.recv_tokens.data(), symm_window(sym_recv), + sym_t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{sym_hid, sym_t.handle_mem.data()}, sym_t.recv_tokens.data(), + symm_window(sym_recv), sym_t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); std::vector sym_recv_host(sym_buf.recv_capacity * hidden_dim_); std::vector sym_result(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(sym_recv_host.data(), sym_recv.ptr, + NVTE_CHECK_CUDA(cudaMemcpy(sym_recv_host.data(), sym_recv.ptr, sym_recv_host.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - CHECK_CUDA(cudaMemcpy(sym_result.data(), sym_buf.result.get(), + NVTE_CHECK_CUDA(cudaMemcpy(sym_result.data(), sym_buf.result.get(), sym_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); // Compare per filled recv slot (HBM ref vs symm) and full result. @@ -744,141 +791,9 @@ TEST_F(EPZeroCopyTest, IdentityAllSymm) { if (g_process_id == 0) printf(" IdentityAllSymm: passed (recv_slots=%d, bit-exact vs HBM)\n", total_recv); - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -// Same buffers, 2 iterations — catches window-lifecycle regressions where the -// symm-mem registration goes stale between calls. -TEST_F(EPZeroCopyTest, IdentityAllSymmRepeated) { - EPBuffers buf; - buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, - ep_size_, max_tokens_per_rank_); - upload_inputs(buf); - - SymmBuf sym_tokens, sym_recv; - sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); - sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); - auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); - CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); - t.tokens = make_nvte_tensor(sym_tokens.ptr, - {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); - t.recv_tokens = make_nvte_tensor(sym_recv.ptr, - {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - uint64_t handle_id = buf.handle_id; - for (int iter = 0; iter < 2; ++iter) { - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, symm_window(sym_tokens), - t.topk_weights.tensor, NVTECommWindow{}, - t.recv_tokens.tensor, symm_window(sym_recv), - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, - symm_window(sym_recv), t.result.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - std::vector h_res(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), buf.result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - for (int tok = 0; tok < num_tokens_; ++tok) { - float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); - float got = __bfloat162float(h_res[tok * hidden_dim_]); - ASSERT_NEAR(got, exp, bf16_tol(exp)) << "iter " << iter << " tok " << tok; - } - } - - if (g_process_id == 0) - printf(" IdentityAllSymmRepeated: passed (2 iters)\n"); - - CHECK_CUDA(cudaStreamDestroy(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); } -// Full forward+backward with symm-mem on every spec-mandated buffer: -// dispatch i/o, combine input, combine_bwd i/o, dispatch_bwd input. -// TODO: flaky on rank 0 (grad_tokens partial-zero) when run after the prior -// EPZeroCopyTest cases in the same binary; passes in isolation. Re-enable once -// the root cause (likely NCCL EP NVLS write→read coherence on grad_expert) is -// understood. Tracked separately. -TEST_F(EPZeroCopyTest, DISABLED_FullPipelineSymm) { - EPBuffers buf; - buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, - ep_size_, max_tokens_per_rank_); - upload_inputs(buf); - - // Symm-mem: tokens (dispatch input), recv_tokens (dispatch output AND - // combine input), grad_result (combine_bwd input), grad_expert - // (combine_bwd output AND dispatch_bwd input). - SymmBuf sym_tokens, sym_recv, sym_grad_result, sym_grad_expert; - sym_tokens .alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); - sym_recv .alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); - sym_grad_result.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); - sym_grad_expert.alloc(buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); - - auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); - CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); - t.tokens = make_nvte_tensor(sym_tokens.ptr, - {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); - t.recv_tokens = make_nvte_tensor(sym_recv.ptr, - {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - t.grad_result = make_nvte_tensor(sym_grad_result.ptr, - {(size_t)num_tokens_, (size_t)hidden_dim_}, kNVTEBFloat16); - t.grad_expert = make_nvte_tensor(sym_grad_expert.ptr, - {buf.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - uint64_t handle_id = buf.handle_id; - ASSERT_NO_THROW(nvte_ep_prepare(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, t.token_counts.tensor, /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.topk_idx.tensor, - t.tokens.tensor, symm_window(sym_tokens), - t.topk_weights.tensor, NVTECommWindow{}, - t.recv_tokens.tensor, symm_window(sym_recv), - t.recv_topk_weights.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.recv_tokens.tensor, - symm_window(sym_recv), t.result.tensor, stream)); - - std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); - CHECK_CUDA(cudaMemcpyAsync(sym_grad_result.ptr, h_grad.data(), - h_grad.size() * sizeof(nv_bfloat16), - cudaMemcpyHostToDevice, stream)); - CHECK_CUDA(cudaMemsetAsync(sym_grad_expert.ptr, 0, sym_grad_expert.bytes, stream)); - CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); - CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); - - ASSERT_NO_THROW(nvte_ep_combine_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_result.tensor, - symm_window(sym_grad_result), t.grad_expert.tensor, - symm_window(sym_grad_expert), stream)); - ASSERT_NO_THROW(nvte_ep_dispatch_bwd(NVTEEpHandle{handle_id, t.handle_mem.tensor}, t.grad_expert.tensor, - symm_window(sym_grad_expert), - t.g_recv_topk_weights.tensor, NVTECommWindow{}, - t.grad_tokens.tensor, t.grad_topk_weights.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); - ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); - - std::vector h_gt(num_tokens_ * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), - h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const float kExpGrad = static_cast(top_k_) * 0.1f; - for (int tok = 0; tok < num_tokens_; ++tok) - EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_]), kExpGrad, bf16_tol(kExpGrad)) - << "grad_tokens token " << tok; - - if (g_process_id == 0) printf(" FullPipelineSymm: passed\n"); - - CHECK_CUDA(cudaStreamDestroy(stream)); -} // ── main ────────────────────────────────────────────────────────────────────── diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h index ccb20ee3a0..135a39416e 100644 --- a/tests/cpp_distributed/test_ep_common.h +++ b/tests/cpp_distributed/test_ep_common.h @@ -13,157 +13,67 @@ #include #include +#include #include #include +#include #include -#include #include #include #include #include -#include #include #include #include #include +#include "../cpp/test_common.h" +#include "util/logging.h" -// ── Error-checking macros ───────────────────────────────────────────────────── +using transformer_engine::DType; +using transformer_engine::TensorWrapper; -#define CHECK_NCCL(expr) \ - do { \ - ncclResult_t _err = (expr); \ - if (_err != ncclSuccess) \ - FAIL() << "NCCL error " << _err << ": " << ncclGetErrorString(_err); \ - } while (false) - -#define CHECK_CUDA(expr) \ - do { \ - cudaError_t _err = (expr); \ - if (_err != cudaSuccess) \ - FAIL() << "CUDA error " << _err << ": " << cudaGetErrorString(_err); \ - } while (false) - -#define ASSERT_CUDA_OK(expr) \ - do { \ - cudaError_t _err = (expr); \ - if (_err != cudaSuccess) { \ - fprintf(stderr, "CUDA error %d: %s\n", _err, cudaGetErrorString(_err)); \ - exit(EXIT_FAILURE); \ - } \ - } while (false) - -#define ASSERT_NCCL_OK(expr) \ - do { \ - ncclResult_t _err = (expr); \ - if (_err != ncclSuccess) { \ - fprintf(stderr, "NCCL error %d: %s\n", _err, ncclGetErrorString(_err)); \ - exit(EXIT_FAILURE); \ - } \ +#define CHECK_MPI(expr) \ + do { \ + int _err_mpi = (expr); \ + NVTE_CHECK(_err_mpi == MPI_SUCCESS, "MPI error: ", _err_mpi); \ } while (false) // ── Process-level state ─────────────────────────────────────────────────────── static int g_process_id = -1; static int g_num_processes = -1; -static std::string g_uid_file; static int g_sm_major = -1; // set by ep_bootstrap; -1 until then static int g_ep_size = -1; static int g_num_experts = -1; static int g_hidden_dim = 256; static int g_max_tokens_per_rank = 64; -static NVTEDType g_token_dtype = kNVTEBFloat16; +static NVTEDType g_max_token_dtype = kNVTEFloat32; // staging-buffer sizing static bool g_ep_initialized = false; static ncclComm_t g_ep_comm = nullptr; // owned by harness, destroyed in ep_teardown -// ── TensorHandle RAII wrapper ───────────────────────────────────────────────── - -// View over a caller-owned device buffer; owns NVTETensor metadata only. Move-only. -struct TensorHandle { - NVTETensor tensor = nullptr; - void* dev_ptr = nullptr; - - ~TensorHandle() { - if (tensor) nvte_destroy_tensor(tensor); - } - - TensorHandle() = default; - TensorHandle(const TensorHandle&) = delete; - TensorHandle& operator=(const TensorHandle&) = delete; - - TensorHandle(TensorHandle&& o) noexcept : tensor(o.tensor), dev_ptr(o.dev_ptr) { - o.tensor = nullptr; o.dev_ptr = nullptr; - } - TensorHandle& operator=(TensorHandle&& o) noexcept { - if (this != &o) { - if (tensor) nvte_destroy_tensor(tensor); - tensor = o.tensor; dev_ptr = o.dev_ptr; - o.tensor = nullptr; o.dev_ptr = nullptr; - } - return *this; - } -}; - -static TensorHandle make_nvte_tensor(void* dev_ptr, - const std::vector& shape, - NVTEDType dtype) { - TensorHandle h; - h.dev_ptr = dev_ptr; - h.tensor = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING); - - NVTEShape s; - s.ndim = shape.size(); - for (size_t i = 0; i < shape.size(); ++i) s.data[i] = shape[i]; - - NVTEBasicTensor bt; - bt.data_ptr = dev_ptr; - bt.dtype = dtype; - bt.shape = s; - nvte_set_tensor_param_v2(h.tensor, kNVTERowwiseData, &bt, sizeof(bt)); - - return h; -} - -// RAII owner for a cudaMalloc'd device buffer; frees on destruction. +// RAII owner for a cudaMalloc'd device buffer; element-count API on top of +// test::CudaPtr. template struct DevBuf { - T* ptr = nullptr; + test::CudaPtr ptr; size_t count = 0; DevBuf() = default; explicit DevBuf(size_t n) { alloc(n); } - ~DevBuf() { reset(); } - - DevBuf(const DevBuf&) = delete; - DevBuf& operator=(const DevBuf&) = delete; - DevBuf(DevBuf&& o) noexcept : ptr(o.ptr), count(o.count) { o.ptr = nullptr; o.count = 0; } - DevBuf& operator=(DevBuf&& o) noexcept { - if (this != &o) { reset(); ptr = o.ptr; count = o.count; o.ptr = nullptr; o.count = 0; } - return *this; - } void alloc(size_t n) { - reset(); count = n; - if (n > 0) { - cudaError_t e = cudaMalloc(&ptr, n * sizeof(T)); - if (e != cudaSuccess) { - fprintf(stderr, "DevBuf cudaMalloc(%zu) failed: %s\n", n * sizeof(T), - cudaGetErrorString(e)); - ptr = nullptr; - count = 0; - } - } + ptr = (n > 0) ? test::cuda_alloc(n * sizeof(T)) : test::CudaPtr{}; } - void reset() { - if (ptr) { cudaFree(ptr); ptr = nullptr; } + ptr.reset(); count = 0; } - T* get() const { return ptr; } + T* get() const { return ptr.get(); } size_t bytes() const { return count * sizeof(T); } }; @@ -180,39 +90,11 @@ static inline std::vector routing_balanced( return idx; } -// ── File-based ncclUniqueId exchange ───────────────────────────────────────── +// ── ncclUniqueId exchange via MPI ───────────────────────────────────────────── static void exchange_unique_id(ncclUniqueId* uid) { - const size_t sz = sizeof(ncclUniqueId); - - if (g_process_id == 0) { - ASSERT_NCCL_OK(ncclGetUniqueId(uid)); - FILE* f = fopen(g_uid_file.c_str(), "wb"); - if (!f) { fprintf(stderr, "Cannot open uid file: %s\n", g_uid_file.c_str()); exit(EXIT_FAILURE); } - fwrite(uid, 1, sz, f); - fclose(f); - } else { - auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(60); - while (true) { - FILE* f = fopen(g_uid_file.c_str(), "rb"); - if (f) { - fseek(f, 0, SEEK_END); - if (static_cast(ftell(f)) >= sz) { - fseek(f, 0, SEEK_SET); - size_t n = fread(uid, 1, sz, f); - fclose(f); - if (n == sz) break; - } else { - fclose(f); - } - } - if (std::chrono::steady_clock::now() > deadline) { - fprintf(stderr, "Process %d: timed out waiting for uid file\n", g_process_id); - exit(EXIT_FAILURE); - } - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - } - } + if (g_process_id == 0) NVTE_CHECK_NCCL(ncclGetUniqueId(uid)); + CHECK_MPI(MPI_Bcast(uid, sizeof(*uid), MPI_BYTE, 0, MPI_COMM_WORLD)); } // ── CLI parsing ─────────────────────────────────────────────────────────────── @@ -220,26 +102,8 @@ static void exchange_unique_id(ncclUniqueId* uid) { static void ep_parse_args(int argc, char* argv[]) { for (int i = 1; i < argc; ++i) { std::string a(argv[i]); - if (a.rfind("--process-id=", 0) == 0) g_process_id = std::stoi(a.substr(13)); - else if (a.rfind("--rank=", 0) == 0) g_process_id = std::stoi(a.substr(7)); - else if (a.rfind("--num-processes=",0)==0) g_num_processes = std::stoi(a.substr(16)); - else if (a.rfind("--nranks=", 0) == 0) g_num_processes = std::stoi(a.substr(9)); - else if (a.rfind("--uid-file=", 0) == 0) g_uid_file = a.substr(11); - else if (a.rfind("--token-dtype=", 0) == 0) - g_token_dtype = static_cast(std::stoi(a.substr(14))); - } - - if (g_process_id < 0 || g_num_processes <= 0) { - fprintf(stderr, - "Usage: %s --rank=N --nranks=N [--uid-file=path] [gtest flags]\n" - " Aliases: --process-id=N, --num-processes=N\n", - argc > 0 ? argv[0] : "test_ep"); - exit(EXIT_FAILURE); - } - - if (g_uid_file.empty()) { - const char* t = getenv("TMPDIR"); if (!t) t = "/tmp"; - g_uid_file = std::string(t) + "/te_ep_uid_" + std::to_string(g_process_id); + if (a.rfind("--max-token-dtype=", 0) == 0) + g_max_token_dtype = static_cast(std::stoi(a.substr(18))); } } @@ -247,6 +111,12 @@ static void ep_parse_args(int argc, char* argv[]) { // Returns false if the binary should exit without running tests (wrong SM, etc.). static bool ep_bootstrap(int argc, char* argv[]) { + int mpi_initialized = 0; + MPI_Initialized(&mpi_initialized); + if (!mpi_initialized) CHECK_MPI(MPI_Init(&argc, &argv)); + CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &g_process_id)); + CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &g_num_processes)); + ep_parse_args(argc, argv); ::testing::InitGoogleTest(&argc, argv); @@ -282,9 +152,9 @@ static bool ep_bootstrap(int argc, char* argv[]) { // Worst-case for top_k fan-out: ep_size * max_tokens_per_rank * 2. group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; group_config.hidden_dim = g_hidden_dim; - group_config.token_dtype = g_token_dtype; + group_config.max_token_dtype = g_max_token_dtype; - ASSERT_NCCL_OK(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); + NVTE_CHECK_NCCL(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); nvte_ep_initialize(static_cast(g_ep_comm), group_config); if (g_process_id == 0) { @@ -308,5 +178,7 @@ static void ep_teardown() { } g_ep_initialized = false; } - if (g_process_id == 0) remove(g_uid_file.c_str()); + int finalized = 0; + MPI_Finalized(&finalized); + if (!finalized) MPI_Finalize(); } diff --git a/tests/cpp_distributed/test_ep_coverage.cu b/tests/cpp_distributed/test_ep_coverage.cu deleted file mode 100644 index e9e532386c..0000000000 --- a/tests/cpp_distributed/test_ep_coverage.cu +++ /dev/null @@ -1,562 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/* - * EP C-API coverage tests (paths not exercised by the pipeline suite). - * - * MultiHandleAllocTest — distinct handle ids; each works end-to-end. - * TopK1Test — top_k=1 dispatch/combine/bwd round-trip. - * EmptyExpertsTest — alignment ∈ {0, 2, 8, 16} with experts receiving 0 tokens. - * NegativeTests — alignment mismatch and null handle_mem must throw. - */ - -#include "test_ep_common.h" - -#include -#include - -// top1 -> expert 0, top2 -> expert 2; leaves local-expert 1 empty between two -// full experts. Requires top_k >= 2 and num_experts >= 3. -static std::vector routing_skip_middle(int num_tokens, int top_k) { - std::vector idx(num_tokens * top_k); - for (int t = 0; t < num_tokens; ++t) { - idx[t * top_k + 0] = 0; - if (top_k >= 2) idx[t * top_k + 1] = 2; - for (int k = 2; k < top_k; ++k) idx[t * top_k + k] = 2 + k; // distinct stragglers - } - return idx; -} - -static std::vector tokens_constant(int num_tokens, int hidden_dim, float val) { - std::vector v(num_tokens * hidden_dim); - nv_bfloat16 b = __float2bfloat16(val); - std::fill(v.begin(), v.end(), b); - return v; -} - -namespace { - -class EpCoverageBase : public ::testing::Test { - protected: - int ep_size_, num_experts_, num_local_experts_, hidden_dim_; - int max_tokens_per_rank_; - - void SetUp() override { - if (g_sm_major < 9) - GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; - ASSERT_GE(g_num_processes, 2); - ASSERT_TRUE(g_ep_initialized); - ep_size_ = g_ep_size; - num_experts_ = g_num_experts; - num_local_experts_ = num_experts_ / ep_size_; - hidden_dim_ = g_hidden_dim; - max_tokens_per_rank_ = g_max_tokens_per_rank; - } - - // Helper: allocate buffers + tensor views for a single dispatch+combine. - struct Bundle { - DevBuf topk_idx; - DevBuf topk_weights; - DevBuf tokens; - DevBuf token_counts; - DevBuf handle_mem; - DevBuf recv_tokens; - DevBuf recv_topk_weights; - DevBuf result; - uint64_t handle_id = 0; - size_t handle_mem_size = 0; - size_t recv_capacity = 0; - }; - - Bundle make_bundle(int num_tokens, int top_k, int num_local_experts, - size_t alignment) { - Bundle b; - b.recv_capacity = static_cast(ep_size_) * max_tokens_per_rank_ * 2; - b.topk_idx.alloc(num_tokens * top_k); - b.topk_weights.alloc(num_tokens * top_k); - b.tokens.alloc(num_tokens * hidden_dim_); - b.token_counts.alloc(num_local_experts); - b.recv_tokens.alloc(b.recv_capacity * hidden_dim_); - b.recv_topk_weights.alloc(b.recv_capacity); - b.result.alloc(num_tokens * hidden_dim_); - NVTEEpLayerConfig cfg{num_local_experts, top_k, alignment}; - b.handle_id = nvte_ep_register_layer(cfg, &b.handle_mem_size); - b.handle_mem.alloc(b.handle_mem_size); - return b; - } -}; - -} // namespace - -// ============================================================================= -// MultiHandleAllocTest: ids are distinct and each is independently usable. -// ============================================================================= - -class MultiHandleAllocTest : public EpCoverageBase {}; - -TEST_F(MultiHandleAllocTest, IdsAreDistinct) { - NVTEEpLayerConfig cfg{num_local_experts_, /*top_k=*/2, /*alignment=*/0}; - const int kN = 8; - std::vector ids(kN); - for (int i = 0; i < kN; ++i) { - size_t sz = 0; - ids[i] = nvte_ep_register_layer(cfg, &sz); - } - for (int i = 0; i < kN; ++i) { - EXPECT_NE(ids[i], 0u) << "handle_id 0 is reserved as \"no id\""; - for (int j = i + 1; j < kN; ++j) - EXPECT_NE(ids[i], ids[j]) << "duplicate id " << ids[i] << " at indices " << i << ", " << j; - } -} - -TEST_F(MultiHandleAllocTest, TwoHandlesCoexist) { - const int num_tokens = 16, top_k = 2; - Bundle a = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - std::vector h_w(num_tokens * top_k, 1.0f / top_k); - auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); - for (Bundle* x : {&a, &b}) { - CHECK_CUDA(cudaMemcpy(x->topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(x->topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(x->tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - } - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - ASSERT_NE(a.handle_id, b.handle_id); - - auto run_one = [&](Bundle& x) { - auto topk_idx = make_nvte_tensor(x.topk_idx.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto topk_weights = make_nvte_tensor(x.topk_weights.get(), {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); - auto token_counts = make_nvte_tensor(x.token_counts.get(), {(size_t)num_local_experts_}, kNVTEInt32); - auto handle_mem = make_nvte_tensor(x.handle_mem.get(), {x.handle_mem_size}, kNVTEByte); - auto tokens = make_nvte_tensor(x.tokens.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_tokens = make_nvte_tensor(x.recv_tokens.get(), {x.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_w = make_nvte_tensor(x.recv_topk_weights.get(), {x.recv_capacity}, kNVTEFloat32); - auto result = make_nvte_tensor(x.result.get(), {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - NVTEEpHandle h{x.handle_id, handle_mem.tensor}; - ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx.tensor, token_counts.tensor, - /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx.tensor, tokens.tensor, - NVTECommWindow{}, topk_weights.tensor, NVTECommWindow{}, - recv_tokens.tensor, NVTECommWindow{}, recv_w.tensor, - NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens.tensor, NVTECommWindow{}, - result.tensor, stream)); - }; - run_one(a); - run_one(b); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - // Both round-trips must produce result == top_k * 0.5 = 1.0. - for (Bundle* x : {&a, &b}) { - std::vector h_res(num_tokens * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), x->result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; - for (int t = 0; t < num_tokens; ++t) - for (int p : probes) - EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), - static_cast(top_k) * 0.5f, 1e-2f); - } - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -// ============================================================================= -// TopK1Test: top_k=1 dispatch/combine round-trip, including dispatch_bwd. -// ============================================================================= - -class TopK1Test : public EpCoverageBase {}; - -TEST_F(TopK1Test, RoundTrip) { - const int num_tokens = 16, top_k = 1; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - std::vector h_w(num_tokens * top_k, 1.0f); // top_k=1: weight is unity - auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.25f); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts_}, kNVTEInt32); - auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), - {b.handle_mem_size}, kNVTEByte); - auto tokens_t = make_nvte_tensor(b.tokens.get(), - {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), - {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), - {b.recv_capacity}, kNVTEFloat32); - auto result_t = make_nvte_tensor(b.result.get(), - {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; - ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, - /*alignment=*/0, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, - tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, - NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, - recv_w_t.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, - NVTECommWindow{}, result_t.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - // top_k=1: combine is unweighted gather, so result[t] == tokens[t]. - std::vector h_res(num_tokens * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; - for (int t = 0; t < num_tokens; ++t) - for (int p : probes) - EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), 0.25f, 1e-2f) - << "tok " << t << " hidden " << p; - - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -// ============================================================================= -// EmptyExpertsTest: alignment ∈ {0, 2, 8, 16}, only local-expert 0 receives -// tokens. Round-trip must produce result == top_k * tokens regardless of the -// per-expert padding choice. -// ============================================================================= - -class EmptyExpertsTest : public EpCoverageBase, - public ::testing::WithParamInterface {}; - -TEST_P(EmptyExpertsTest, RoundTripCorrect) { - // routing_skip_middle needs experts {0, 2, ...}; smallest viable num_experts is 3. - ASSERT_GE(num_experts_, 3); - const size_t alignment = GetParam(); - const int num_tokens = 16, top_k = 2; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, alignment); - - // top1 -> expert 0, top2 -> expert 2; rank 0's local-expert 1 receives 0 - // tokens between two non-empty experts. - std::vector h_idx = routing_skip_middle(num_tokens, top_k); - std::vector h_w(num_tokens * top_k, 1.0f / top_k); - auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.3f); - - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts_}, kNVTEInt32); - auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), - {b.handle_mem_size}, kNVTEByte); - auto tokens_t = make_nvte_tensor(b.tokens.get(), - {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), - {b.recv_capacity, (size_t)hidden_dim_}, kNVTEBFloat16); - auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), - {b.recv_capacity}, kNVTEFloat32); - auto result_t = make_nvte_tensor(b.result.get(), - {(size_t)num_tokens, (size_t)hidden_dim_}, kNVTEBFloat16); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; - ASSERT_NO_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, - alignment, stream)); - ASSERT_NO_THROW(nvte_ep_dispatch(h, topk_idx_t.tensor, - tokens_t.tensor, NVTECommWindow{}, topk_weights_t.tensor, - NVTECommWindow{}, recv_tokens_t.tensor, NVTECommWindow{}, - recv_w_t.tensor, NVTECommWindow{}, stream)); - ASSERT_NO_THROW(nvte_ep_combine(h, recv_tokens_t.tensor, - NVTECommWindow{}, result_t.tensor, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - // Identity expert + uniform weights: result[t] == top_k * tokens[t]. - std::vector h_res(num_tokens * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const float expected = static_cast(top_k) * 0.3f; - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; - for (int t = 0; t < num_tokens; ++t) - for (int p : probes) - EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), expected, 1e-2f) - << "alignment=" << alignment << " tok=" << t << " hidden=" << p; - - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -INSTANTIATE_TEST_SUITE_P(Alignments, EmptyExpertsTest, - ::testing::Values(0, 2, 8, 16)); - -// ============================================================================= -// NegativeTests: prepare/dispatch must surface bad inputs as exceptions. -// ============================================================================= - -class NegativeTests : public EpCoverageBase {}; - -TEST_F(NegativeTests, AlignmentMismatchThrows) { - const int num_tokens = 8, top_k = 2; - // Allocate handle for alignment=0, then call prepare with alignment=16. - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts_}, kNVTEInt32); - auto handle_mem_t = make_nvte_tensor(b.handle_mem.get(), - {b.handle_mem_size}, kNVTEByte); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; - EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, - /*alignment=*/16, stream), - std::exception); - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -TEST_F(NegativeTests, NullHandleMemThrows) { - const int num_tokens = 8, top_k = 2; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts_}, kNVTEInt32); - // Construct a tensor view backed by a null device pointer. - auto null_hm_t = make_nvte_tensor(nullptr, {b.handle_mem_size}, kNVTEByte); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - NVTEEpHandle h{b.handle_id, null_hm_t.tensor}; - EXPECT_THROW(nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, - /*alignment=*/0, stream), - std::exception); - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -// ============================================================================= -// HandleCacheTest: persistent ncclEpHandle is reused across ops on the same -// handle_mem ptr; relocation triggers throw by default and rebuild when -// NVTEEpGroupConfig.allow_handle_mem_reloc=1. -// ============================================================================= - -class HandleCacheTest : public EpCoverageBase {}; - -// Run prepare → dispatch → combine on bundle b. handle_mem_data overrides the -// device ptr used for handle_mem (must be the buffer owned by b unless -// reloc-allowed mode is active). Templated on Bundle because EpCoverageBase:: -// Bundle is declared in a protected section. -template -static void run_round_trip(B& b, void* handle_mem_data, - int num_tokens, int top_k, int num_local_experts, - int hidden_dim, size_t alignment, - cudaStream_t stream) { - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto topk_weights_t = make_nvte_tensor(b.topk_weights.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEFloat32); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts}, kNVTEInt32); - auto handle_mem_t = make_nvte_tensor(handle_mem_data, - {b.handle_mem_size}, kNVTEByte); - auto tokens_t = make_nvte_tensor(b.tokens.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - auto recv_tokens_t = make_nvte_tensor(b.recv_tokens.get(), - {b.recv_capacity, (size_t)hidden_dim}, kNVTEBFloat16); - auto recv_w_t = make_nvte_tensor(b.recv_topk_weights.get(), - {b.recv_capacity}, kNVTEFloat32); - auto result_t = make_nvte_tensor(b.result.get(), - {(size_t)num_tokens, (size_t)hidden_dim}, kNVTEBFloat16); - - NVTEEpHandle h{b.handle_id, handle_mem_t.tensor}; - nvte_ep_prepare(h, topk_idx_t.tensor, token_counts_t.tensor, alignment, stream); - nvte_ep_dispatch(h, topk_idx_t.tensor, tokens_t.tensor, NVTECommWindow{}, - topk_weights_t.tensor, NVTECommWindow{}, - recv_tokens_t.tensor, NVTECommWindow{}, - recv_w_t.tensor, NVTECommWindow{}, stream); - nvte_ep_combine(h, recv_tokens_t.tensor, NVTECommWindow{}, result_t.tensor, stream); -} - -// Re-bootstrap EP backend with a different allow_handle_mem_reloc setting. -// Reuses the existing g_ep_comm; caller is responsible for restoring defaults. -static void reinit_ep_with_reloc(int allow_reloc) { - nvte_ep_shutdown(); - NVTEEpGroupConfig cfg{}; - cfg.ep_size = g_ep_size; - cfg.num_experts = g_num_experts; - cfg.max_tokens_per_rank = g_max_tokens_per_rank; - cfg.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; - cfg.hidden_dim = g_hidden_dim; - cfg.allow_handle_mem_reloc = allow_reloc; - nvte_ep_initialize(static_cast(g_ep_comm), cfg); -} - -TEST_F(HandleCacheTest, ReuseSameMemSucceeds) { - const int num_tokens = 16, top_k = 2; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - std::vector h_w(num_tokens * top_k, 1.0f / top_k); - auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - // Two consecutive round-trips on the same handle_mem ptr: first opens the - // cached handle, second hits the cache. Both must succeed and be correct. - for (int iter = 0; iter < 2; ++iter) { - ASSERT_NO_THROW(run_round_trip(b, b.handle_mem.get(), num_tokens, top_k, - num_local_experts_, hidden_dim_, - /*alignment=*/0, stream)); - } - CHECK_CUDA(cudaStreamSynchronize(stream)); - - std::vector h_res(num_tokens * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; - for (int t = 0; t < num_tokens; ++t) - for (int p : probes) - EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), - static_cast(top_k) * 0.5f, 1e-2f); - - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -TEST_F(HandleCacheTest, RelocDefaultThrows) { - // Default bootstrap has allow_handle_mem_reloc=0: a second prepare call on - // the same handle_id with a different handle_mem ptr must throw. - const int num_tokens = 8, top_k = 2; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - DevBuf second_hm(b.handle_mem_size); // distinct device buffer - ASSERT_NE(b.handle_mem.get(), second_hm.get()); - - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - - auto topk_idx_t = make_nvte_tensor(b.topk_idx.get(), - {(size_t)num_tokens, (size_t)top_k}, kNVTEInt64); - auto token_counts_t = make_nvte_tensor(b.token_counts.get(), - {(size_t)num_local_experts_}, kNVTEInt32); - auto hm1_t = make_nvte_tensor(b.handle_mem.get(), - {b.handle_mem_size}, kNVTEByte); - auto hm2_t = make_nvte_tensor(second_hm.get(), - {b.handle_mem_size}, kNVTEByte); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - // First prepare seeds the cache. - NVTEEpHandle h1{b.handle_id, hm1_t.tensor}; - ASSERT_NO_THROW(nvte_ep_prepare(h1, topk_idx_t.tensor, token_counts_t.tensor, - /*alignment=*/0, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - // Same handle_id with a different handle_mem ptr must throw. - NVTEEpHandle h2{b.handle_id, hm2_t.tensor}; - EXPECT_THROW(nvte_ep_prepare(h2, topk_idx_t.tensor, token_counts_t.tensor, - /*alignment=*/0, stream), - std::exception); - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -TEST_F(HandleCacheTest, RelocAllowedRebuilds) { - // Re-init EP backend with allow_handle_mem_reloc=1, run two round-trips with - // distinct handle_mem buffers, verify both succeed numerically, restore. - reinit_ep_with_reloc(/*allow_reloc=*/1); - - struct Restore { ~Restore() { reinit_ep_with_reloc(/*allow_reloc=*/0); } } restore; - - const int num_tokens = 16, top_k = 2; - Bundle b = make_bundle(num_tokens, top_k, num_local_experts_, /*alignment=*/0); - DevBuf alt_hm(b.handle_mem_size); - ASSERT_NE(b.handle_mem.get(), alt_hm.get()); - - auto h_idx = routing_balanced(g_process_id, num_tokens, top_k, - num_experts_, num_local_experts_); - std::vector h_w(num_tokens * top_k, 1.0f / top_k); - auto h_tok = tokens_constant(num_tokens, hidden_dim_, 0.5f); - CHECK_CUDA(cudaMemcpy(b.topk_idx.get(), h_idx.data(), - h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.topk_weights.get(), h_w.data(), - h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); - CHECK_CUDA(cudaMemcpy(b.tokens.get(), h_tok.data(), - h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); - - cudaStream_t stream; - CHECK_CUDA(cudaStreamCreate(&stream)); - - // First on the original handle_mem. - ASSERT_NO_THROW(run_round_trip(b, b.handle_mem.get(), num_tokens, top_k, - num_local_experts_, hidden_dim_, - /*alignment=*/0, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - // Then on the relocated handle_mem — must trigger silent rebuild, not throw. - ASSERT_NO_THROW(run_round_trip(b, alt_hm.get(), num_tokens, top_k, - num_local_experts_, hidden_dim_, - /*alignment=*/0, stream)); - CHECK_CUDA(cudaStreamSynchronize(stream)); - - std::vector h_res(num_tokens * hidden_dim_); - CHECK_CUDA(cudaMemcpy(h_res.data(), b.result.get(), - h_res.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); - const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; - for (int t = 0; t < num_tokens; ++t) - for (int p : probes) - EXPECT_NEAR(__bfloat162float(h_res[t * hidden_dim_ + p]), - static_cast(top_k) * 0.5f, 1e-2f); - - CHECK_CUDA(cudaStreamDestroy(stream)); -} - -// ── main ────────────────────────────────────────────────────────────────────── - -int main(int argc, char* argv[]) { - if (!ep_bootstrap(argc, argv)) return 0; - int ret = RUN_ALL_TESTS(); - ep_teardown(); - return ret; -} diff --git a/tests/cpp_distributed/test_ep_init.cu b/tests/cpp_distributed/test_ep_init.cu deleted file mode 100644 index 08744dfee5..0000000000 --- a/tests/cpp_distributed/test_ep_init.cu +++ /dev/null @@ -1,64 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/* - * Unit tests for EP initialization paths. - * - * Tests: - * EPInitTest/InitPath — backend is live after init, handle_mem_size > 0 - * EPInitTest/NumLocalExperts — handle_mem_size is consistent across num_local_experts values - * - * Run via run_test_ep.sh (both uid and comm init paths are tested by the script). - */ - -#include "test_ep_common.h" - -// ── Fixture ─────────────────────────────────────────────────────────────────── - -class EPInitTest : public ::testing::Test { - protected: - void SetUp() override { - if (g_sm_major < 9) - GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; - ASSERT_GE(g_num_processes, 2) << "EP tests require at least 2 processes"; - ASSERT_TRUE(g_ep_initialized) << "EP not initialized"; - } -}; - -// ── Tests ───────────────────────────────────────────────────────────────────── - -TEST_F(EPInitTest, InitPath) { - int nle = g_num_experts / g_ep_size; - NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; - size_t sz = 0; - (void)nvte_ep_register_layer(cfg, &sz); - ASSERT_GT(sz, 0u) << "handle_mem_size must be > 0 after init"; - - if (g_process_id == 0) { - printf(" handle_mem : %zu bytes\n", sz); - } -} - -TEST_F(EPInitTest, NumLocalExperts) { - // handle_mem_size should be > 0 for any valid num_local_experts value. - for (int nle : {1, g_num_experts / g_ep_size}) { - NVTEEpLayerConfig cfg{nle, /*top_k=*/2}; - size_t sz = 0; - (void)nvte_ep_register_layer(cfg, &sz); - ASSERT_GT(sz, 0u) << "num_local_experts=" << nle; - if (g_process_id == 0) - printf(" nle=%-3d handle_mem_size=%zu bytes\n", nle, sz); - } -} - -// ── main ────────────────────────────────────────────────────────────────────── - -int main(int argc, char* argv[]) { - if (!ep_bootstrap(argc, argv)) return 0; - int ret = RUN_ALL_TESTS(); - ep_teardown(); - return ret; -} diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index 1e08cb55df..a5ae99b089 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -82,11 +82,11 @@ void EPBackend::validate_config(const NVTEEpGroupConfig& config) { NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", config.max_recv_tokens_per_rank); NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); - NVTE_CHECK(config.token_dtype >= 0 && config.token_dtype < kNVTENumTypes, - "token_dtype out of range, got ", static_cast(config.token_dtype)); - const size_t elem_bytes = typeToSize(static_cast(config.token_dtype)); + NVTE_CHECK(config.max_token_dtype >= 0 && config.max_token_dtype < kNVTENumTypes, + "max_token_dtype out of range, got ", static_cast(config.max_token_dtype)); + const size_t elem_bytes = typeToSize(static_cast(config.max_token_dtype)); NVTE_CHECK(config.hidden_dim * elem_bytes >= 16, - "hidden_dim * sizeof(token_dtype) must be >= 16 (NCCL EP 16B row alignment); " + "hidden_dim * sizeof(max_token_dtype) must be >= 16 (NCCL EP 16B row alignment); " "got hidden_dim=", config.hidden_dim, ", element_bytes=", elem_bytes); NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts, @@ -218,7 +218,7 @@ void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) { cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT; cfg.num_experts = static_cast(group_config.num_experts); cfg.max_dispatch_tokens_per_rank = static_cast(group_config.max_tokens_per_rank); - const size_t elem_bytes = typeToSize(static_cast(group_config.token_dtype)); + const size_t elem_bytes = typeToSize(static_cast(group_config.max_token_dtype)); cfg.max_token_bytes = static_cast(group_config.hidden_dim * elem_bytes); cfg.rdma_buffer_size = NCCL_EP_AUTO; cfg.num_qp_per_rank = NCCL_EP_AUTO; @@ -346,10 +346,10 @@ void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor NVTEShape tok_shape = nvte_tensor_shape(tokens); NVTEDType tok_dtype = nvte_tensor_type(tokens); - NVTE_CHECK(tok_dtype == group_config_.token_dtype, - "tokens dtype (", static_cast(tok_dtype), - ") does not match group token_dtype (", - static_cast(group_config_.token_dtype), ")"); + NVTE_CHECK(typeToSize(static_cast(tok_dtype)) <= + typeToSize(static_cast(group_config_.max_token_dtype)), + "tokens dtype (", static_cast(tok_dtype), ") wider than group max_token_dtype (", + static_cast(group_config_.max_token_dtype), ")"); const size_t num_tokens = tok_shape.data[0]; const size_t hidden_dim = tok_shape.data[1]; @@ -376,10 +376,11 @@ void EPBackend::dispatch(uint64_t handle_id, void* handle_mem, const NVTETensor NVTEShape recv_shape = nvte_tensor_shape(recv_tokens); NVTEDType recv_dtype = nvte_tensor_type(recv_tokens); - NVTE_CHECK(recv_dtype == group_config_.token_dtype, + NVTE_CHECK(typeToSize(static_cast(recv_dtype)) <= + typeToSize(static_cast(group_config_.max_token_dtype)), "recv_tokens dtype (", static_cast(recv_dtype), - ") does not match group token_dtype (", - static_cast(group_config_.token_dtype), ")"); + ") wider than group max_token_dtype (", + static_cast(group_config_.max_token_dtype), ")"); size_t recv_sizes[2] = {recv_shape.data[0], recv_shape.data[1]}; ncclEpTensor_t nccl_tokens_out = make_payload_tensor(recv_tokens, recv_tokens_win, 2, diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index a1c9305e9b..22e7ec48ac 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -23,6 +23,8 @@ extern "C" { #endif /* ── Config structs ─────────────────────────────────────────────────────── */ +/* TODO: add a struct_size/version field to these configs (and align with other + * TE public structs) once a TE-wide convention for ABI versioning lands. */ /*! \brief Group-level EP configuration (fixed for the EP group lifetime). */ typedef struct { @@ -35,9 +37,10 @@ typedef struct { int max_num_sms; /*!< Max SMs for EP kernels. 0 = auto. */ /*! 0 (default): throw on relocated handle_mem for a cached handle_id. 1: silently rebuild. */ int allow_handle_mem_reloc; - /*! Token dtype for this EP group. Sizes NCCL EP staging buffers at group - * create and is enforced against tensors passed to nvte_ep_dispatch. */ - NVTEDType token_dtype; + /*! Widest token dtype the group will dispatch. Sizes NCCL EP staging buffers + * at group create. Tensors passed to nvte_ep_dispatch may use any dtype whose + * element size is <= sizeof(max_token_dtype). */ + NVTEDType max_token_dtype; } NVTEEpGroupConfig; /*! \brief Per-layer EP configuration. */ @@ -58,8 +61,8 @@ typedef struct { * nvte_ep_shutdown() returns; destroying it earlier is undefined behavior. * Re-init after shutdown is allowed; double-init throws. * - * v0.1 scope: one EP group per process, bound to the current CUDA device at - * initialize time. Multiple GPUs per process are not supported. + * One EP group per process, bound to the current CUDA device at initialize + * time. Multiple GPUs per process are not supported. * * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group. * \param[in] group_config Group-level EP configuration. diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index da8b9b377d..3308bd22e4 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -98,6 +98,14 @@ } \ } while (false) +#define NVTE_CHECK_NCCL(expr) \ + do { \ + const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ + if (status_NVTE_CHECK_NCCL != ncclSuccess) { \ + NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \ + } \ + } while (false) + #ifdef NVTE_WITH_CUBLASMP #define NVTE_CHECK_CUBLASMP(expr) \ From d10189603939c54304429649ae54969633835eb4 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 22 May 2026 23:05:43 +0000 Subject: [PATCH 20/36] Expert Parallelism: JAX bindings (FFI, custom_vjp, multi-process tests, MoE example) Signed-off-by: Phuong Nguyen --- build_tools/jax.py | 41 +- examples/jax/ep/ep_moe.py | 396 ++++++++ examples/jax/ep/run_test_ep.sh | 85 ++ tests/jax/multi_process_launch_ep.sh | 67 ++ tests/jax/test_multi_process_ep.py | 690 +++++++++++++ .../jax/cpp_extensions/__init__.py | 1 + transformer_engine/jax/cpp_extensions/ep.py | 955 ++++++++++++++++++ transformer_engine/jax/csrc/extensions.h | 19 + transformer_engine/jax/csrc/extensions/ep.cpp | 457 +++++++++ .../jax/csrc/extensions/pybind.cpp | 18 + transformer_engine/jax/ep.py | 303 ++++++ transformer_engine/jax/sharding.py | 12 +- 12 files changed, 3041 insertions(+), 3 deletions(-) create mode 100644 examples/jax/ep/ep_moe.py create mode 100755 examples/jax/ep/run_test_ep.sh create mode 100755 tests/jax/multi_process_launch_ep.sh create mode 100644 tests/jax/test_multi_process_ep.py create mode 100644 transformer_engine/jax/cpp_extensions/ep.py create mode 100644 transformer_engine/jax/csrc/extensions/ep.cpp create mode 100644 transformer_engine/jax/ep.py diff --git a/build_tools/jax.py b/build_tools/jax.py index a7b200f915..49c5001d18 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -103,13 +103,50 @@ def setup_jax_extension( setup_mpi_flags(include_dirs, cxx_flags) + # NCCL EP is on by default. Set NVTE_BUILD_WITH_NCCL_EP=0 to skip it. + build_with_nccl_ep = bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))) + libraries = [] + submod_lib_dir = None + submod_nccl_inc = None + if build_with_nccl_ep: + cxx_flags.append("-DNVTE_WITH_NCCL_EP") + # Headers + libs come from the in-tree 3rdparty/nccl submodule build + # (auto-produced by setup.py). + libraries = ["nccl", "nccl_ep"] + # NCCL EP requires SM>=90 (Hopper+). + archs_env = os.getenv("NVTE_CUDA_ARCHS", "") + for a in archs_env.split(";"): + a_num = "".join(c for c in a if c.isdigit()) + if a_num and int(a_num) < 90: + raise RuntimeError( + f"NCCL EP requires CUDA arch >= 90 (Hopper or newer); got '{a}' in" + " NVTE_CUDA_ARCHS." + ) + submod_root = (common_header_files / ".." / "3rdparty" / "nccl").resolve() + submod_ep_inc = submod_root / "contrib" / "nccl_ep" / "include" + if not (submod_ep_inc / "nccl_ep.h").exists(): + raise RuntimeError( + f"NCCL EP header not found at {submod_ep_inc}/nccl_ep.h. " + "Run `git submodule update --init --recursive` to checkout 3rdparty/nccl." + ) + include_dirs.append(submod_ep_inc) + submod_lib_dir = submod_root / "build" / "lib" + submod_nccl_inc = submod_root / "build" / "include" + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension - return Pybind11Extension( + ext = Pybind11Extension( "transformer_engine_jax", sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], extra_compile_args=cxx_flags, - libraries=["nccl"], + libraries=libraries, ) + if submod_lib_dir is not None: + ext.library_dirs.append(str(submod_lib_dir)) + ext.runtime_library_dirs.append(str(submod_lib_dir)) + # Prefer submodule's nccl.h when present (matches the C++ side). + if (submod_nccl_inc / "nccl.h").exists(): + ext.include_dirs.insert(0, str(submod_nccl_inc)) + return ext diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py new file mode 100644 index 0000000000..8dcac02a04 --- /dev/null +++ b/examples/jax/ep/ep_moe.py @@ -0,0 +1,396 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""End-to-end MoE example: dispatch -> batched expert linear -> combine, fwd + bwd. + +One process per GPU. Run via run_test_ep.sh. +""" + +import argparse +import sys + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.ep import ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +# ── Setup ─────────────────────────────────────────────────────────────────── + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-JAX EP MoE example (fwd + bwd)") + p.add_argument("--coordinator-address", required=True) + p.add_argument("--process-id", type=int, required=True) + p.add_argument("--num-processes", type=int, required=True) + p.add_argument("--num-tokens", type=int, default=8, help="Per-rank token count.") + p.add_argument("--top-k", type=int, default=2) + p.add_argument("--hidden", type=int, default=32) + p.add_argument("--hidden-out", type=int, default=32) + p.add_argument( + "--num-experts", + type=int, + default=None, + help="Total experts across the EP group. Default: num_processes.", + ) + p.add_argument("--dp-size", type=int, default=None, help="Default: num_procs // ep_size.") + p.add_argument( + "--check", + action="store_true", + default=True, + help="Verify fwd+bwd against a single-rank numpy reference.", + ) + return p.parse_args() + + +def _distributed_init(args): + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_processes, + process_id=args.process_id, + local_device_ids=[args.process_id], + ) + assert ( + jax.local_device_count() == 1 + ), f"EP example requires 1 GPU per process; got {jax.local_device_count()}" + + +def _build_mesh_and_resource(args): + """Pick a (2, 2) mesh by default. Override via --dp-size.""" + n = args.num_processes + if n < 4: + raise ValueError(f"num_processes ({n}) must be >= 4 for NCCL EP") + if args.dp_size is None: + if n != 4: + raise ValueError( + f"default mesh expects exactly 4 ranks (got {n}); pass --dp-size to override" + ) + args.dp_size = 2 + assert n % args.dp_size == 0, f"num_processes={n} not divisible by dp_size={args.dp_size}" + args.ep_size = n // args.dp_size + if args.num_experts is None: + args.num_experts = args.num_processes + assert args.num_experts % args.ep_size == 0 + args.num_local_experts = args.num_experts // args.ep_size + args.recv_capacity_per_rank = args.ep_size * args.num_tokens * args.top_k + + devs = np.asarray(jax.devices()).reshape(args.dp_size, args.ep_size) + mesh = Mesh(devs, ("dp", "ep")) + mr = MeshResource(dp_resource="dp", ep_resource="ep") + return mesh, mr + + +def _make_routing(dp_color, num_tokens, top_k, num_experts, num_local_experts): + """Deterministic routing: topk_idx[t, k] = (dp_color*NLE + t*K + k) % E.""" + topk_idx = np.empty((num_tokens, top_k), dtype=np.int32) + for t in range(num_tokens): + for k in range(top_k): + topk_idx[t, k] = (dp_color * num_local_experts + t * top_k + k) % num_experts + return topk_idx + + +def _make_inputs(args): + """Build 3D ``[B, S, H]`` arrays sharded ``(("dp","ep"), None, None)``. + + B = num_processes (sharded across the compound (dp,ep) axis so each rank + holds one slot); S = args.num_tokens. Global numpy views (rank-0 + reference) are kept 2D for the legacy reference implementation. + """ + T, K, H, H_out = args.num_tokens, args.top_k, args.hidden, args.hidden_out + E = args.num_experts + dp_size = args.dp_size + ep_size = args.ep_size + num_procs = args.num_processes + dp_color = args.process_id // ep_size + + rng_dp = np.random.default_rng(seed=42 + dp_color) + tokens_np = (rng_dp.standard_normal((T, H), dtype=np.float32) * 0.5).astype(np.float32) + topk_idx_np = _make_routing(dp_color, T, K, E, args.num_local_experts) + w_np = np.full((T, K), 1.0 / K, dtype=np.float32) + + tokens_global_np = np.concatenate( + [ + ( + np.random.default_rng(seed=42 + c).standard_normal((T, H), dtype=np.float32) * 0.5 + ).astype(np.float32) + for c in range(dp_size) + ], + axis=0, + ) + topk_idx_global_np = np.concatenate( + [_make_routing(c, T, K, E, args.num_local_experts) for c in range(dp_size)], axis=0 + ) + w_global_np = np.full((dp_size * T, K), 1.0 / K, dtype=np.float32) + + # Same seed on every rank → identical kernel array everywhere. + rng = np.random.default_rng(seed=42) + kernels_np = (rng.standard_normal((E, H, H_out), dtype=np.float32) * (1.0 / np.sqrt(H))).astype( + np.float32 + ) + + # Each rank contributes one [1, T, ...] slab; the global shape is + # [num_procs, T, ...] sharded on the first dim across (dp, ep). + mesh = args.mesh + dpep_spec = NamedSharding(mesh, PartitionSpec(("dp", "ep"), None, None)) + tokens = jax.make_array_from_process_local_data( + dpep_spec, tokens_np[None, :, :].astype(np.float32), (num_procs, T, H) + ).astype(jnp.bfloat16) + topk_idx = jax.make_array_from_process_local_data( + dpep_spec, topk_idx_np[None, :, :], (num_procs, T, K) + ) + topk_w = jax.make_array_from_process_local_data(dpep_spec, w_np[None, :, :], (num_procs, T, K)) + kernels = jnp.asarray(kernels_np, dtype=jnp.bfloat16) + return ( + tokens_global_np, + topk_idx_global_np, + w_global_np, + kernels_np, + tokens, + topk_idx, + topk_w, + kernels, + ) + + +# ── MoE step ──────────────────────────────────────────────────────────────── + + +def _batched_expert_linear(recv_tokens, kernels, num_local_experts, dp_size, ep_size): + """Per-expert linear. ``recv_tokens`` is 3D ``[num_procs, recv_pr, H]`` + (compound (dp,ep) leading); ``kernels`` is 4D ``[ep_size, NLE, H, H_out]``, + broadcast over the dp axis. Output matches ``recv_tokens``' 3D layout + with ``H_out`` in place of ``H``.""" + num_procs, recv_pr, H = recv_tokens.shape + H_out = kernels.shape[-1] + slots_per_expert = recv_pr // num_local_experts + # [num_procs, recv_pr, H] -> [dp, ep, NLE, slots, H] + grouped = recv_tokens.reshape(dp_size, ep_size, num_local_experts, slots_per_expert, H) + # Contract H; batch over (ep, NLE) which are present on both sides. + out = jax.lax.dot_general( + grouped, + kernels.astype(grouped.dtype), + dimension_numbers=(((4,), (2,)), ((1, 2), (0, 1))), + ) + # Output dim order from dot_general: batch dims first, then remaining lhs, rhs. + # batch=(ep,NLE), lhs_remaining=(dp,slots), rhs_remaining=(H_out,) + # → shape [ep, NLE, dp, slots, H_out]. Permute to [dp, ep, NLE, slots, H_out]. + out = jnp.transpose(out, (2, 0, 1, 3, 4)) + return out.reshape(num_procs, recv_pr, H_out) + + +def _moe_step(args, topk_idx, tokens, topk_w, kernels): + """Jit'd MoE step: dispatch -> batched per-expert linear -> combine. + + Inputs are 3D ``[B, S, H]`` with the first dim compound-sharded across + ``("dp","ep")``. Combine returns the same 3D shape. + """ + B = args.num_processes + S = args.num_tokens + NLE = args.num_local_experts + dp_size, ep_size = args.dp_size, args.ep_size + mesh = args.mesh + in_spec = PartitionSpec(("dp", "ep"), None, None) # [B, S, ...] + ep3 = PartitionSpec(("dp", "ep"), None, None) # [num_procs, recv_pr, H] + ep2 = PartitionSpec(("dp", "ep"), None) # [num_procs, recv_pr] + # Kernels are EP-replicated across dp colors; shard only the ep-rank axis. + kernel_spec = PartitionSpec("ep", None, None, None) + + kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:]) + + @jax.jit + def step(topk_idx, tokens, topk_w, local_kernels): + topk_idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec)) + tokens = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec)) + topk_w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec)) + local_kernels = jax.lax.with_sharding_constraint( + local_kernels, NamedSharding(mesh, kernel_spec) + ) + slots_per_expert = args.recv_capacity_per_rank // NLE + recv_tokens, recv_topk_w, handle, _tc = ep_dispatch( + topk_idx, + tokens, + topk_w, + args.recv_capacity_per_rank, + dispatch_output_per_expert_alignment=slots_per_expert, + ) + recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3)) + recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2)) + expert_out = _batched_expert_linear(recv_tokens, local_kernels, NLE, dp_size, ep_size) + expert_out = jax.lax.with_sharding_constraint(expert_out, NamedSharding(mesh, ep3)) + return ep_combine( + handle, + _tc, + expert_out, + recv_topk_w, + num_local_tokens=(B, S), + out_sharding=(("dp", "ep"), None, None), + ) + + return step(topk_idx, tokens, topk_w, kernels) + + +# ── Reference (numerical check) ───────────────────────────────────────────── + + +def _reference_moe(tokens, topk_idx, topk_w, kernels): + """Single-rank dense MoE reference. tokens [T, H], output [T, H_out].""" + T, K = topk_idx.shape + H_out = kernels.shape[-1] + out = np.zeros((T, H_out), dtype=np.float32) + for t in range(T): + tok = tokens[t].astype(np.float32) + for k in range(K): + e = int(topk_idx[t, k]) + out[t] += float(topk_w[t, k]) * (tok @ kernels[e].astype(np.float32)) + return out + + +def _reference_grad(tokens, topk_idx, topk_w, kernels): + """d/dtokens of 0.5 * sum(ref_out**2) — used by --check to validate bwd.""" + T, K = topk_idx.shape + H = tokens.shape[-1] + ref_out = _reference_moe(tokens, topk_idx, topk_w, kernels) + grad = np.zeros((T, H), dtype=np.float32) + for t in range(T): + mixed = np.zeros_like(kernels[0]) + for k in range(K): + mixed = mixed + float(topk_w[t, k]) * kernels[int(topk_idx[t, k])] + grad[t] = ref_out[t] @ mixed.T + return ref_out, grad + + +# ── Main ──────────────────────────────────────────────────────────────────── + + +def main(): + args = _parse_args() + _distributed_init(args) + + dev = jax.local_devices()[0] + cap = getattr(dev, "compute_capability", None) + if cap is not None: + major, minor = (int(x) for x in str(cap).split(".")) + if major * 10 + minor < 90: + print(f"[ep_moe] SKIPPED: NCCL EP requires SM>=90 (got SM{major}{minor})") + return + + args.mesh, args.mr = _build_mesh_and_resource(args) + + with args.mesh, global_shard_guard(args.mr): + ep_bootstrap( + world_size=args.num_processes, + rank=args.process_id, + ep_size=args.ep_size, + num_experts=args.num_experts, + max_tokens_per_rank=args.num_tokens, + recv_capacity_per_rank=args.recv_capacity_per_rank, + hidden_dim=args.hidden, + ) + + ( + tokens_global_np, + topk_idx_global_np, + w_global_np, + kernels_np, + tokens, + topk_idx, + topk_w, + kernels, + ) = _make_inputs(args) + + def loss_fn(toks, idx, w, kern): + out = _moe_step(args, idx, toks, w, kern) + return 0.5 * (out.astype(jnp.float32) ** 2).sum(), out + + (loss, out_fwd), grad_tokens = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))( + tokens, topk_idx, topk_w, kernels + ) + grad_tokens.block_until_ready() + out_fwd.block_until_ready() + + if args.process_id == 0: + print( + f"[ep_moe] loss={float(loss):.4f} grad_tokens.shape={grad_tokens.shape} " + f"dp={args.dp_size} ep={args.ep_size} " + f"num_experts={args.num_experts} recv_pr={args.recv_capacity_per_rank}" + ) + + if args.check: + + def _norm(spec, ndim): + return tuple(spec) + (None,) * (ndim - len(spec)) + + # JAX may collapse a size-1 mesh axis: when dp_size==1 the spec can + # appear as ``(("dp","ep"),...)`` or ``("ep",...)``. Accept both. + if args.dp_size > 1: + acceptable_specs = ((("dp", "ep"), None, None),) + else: + acceptable_specs = ((("dp", "ep"), None, None), ("ep", None, None)) + assert ( + _norm(out_fwd.sharding.spec, out_fwd.ndim) in acceptable_specs + ), f"out_fwd.sharding.spec={out_fwd.sharding.spec} (expected one of {acceptable_specs})" + assert _norm(grad_tokens.sharding.spec, grad_tokens.ndim) in acceptable_specs, ( + f"grad_tokens.sharding.spec={grad_tokens.sharding.spec}" + f" (expected one of {acceptable_specs})" + ) + + replicated = NamedSharding(args.mesh, jax.sharding.PartitionSpec()) + out_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))(out_fwd) + grad_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))( + grad_tokens + ) + out_global.block_until_ready() + grad_global.block_until_ready() + + ref_out, ref_grad = _reference_grad( + tokens_global_np, topk_idx_global_np, w_global_np, kernels_np + ) + ref_loss = 0.5 * float((ref_out.astype(np.float32) ** 2).sum()) + # 3D global ``[num_procs, S, H]`` with num_procs = dp * ep. Each EP + # column in a DP color sees identical inputs (and produces identical + # outputs), so collapse the ep dim to one replica before flattening + # to 2D against the dp-only reference. + dp_size, ep_size = args.dp_size, args.ep_size + global_out = ( + np.asarray(out_global.addressable_shards[0].data.astype(jnp.float32)) + .reshape(dp_size, ep_size, -1, ref_out.shape[-1])[:, 0] + .reshape(-1, ref_out.shape[-1]) + ) + global_grad = ( + np.asarray(grad_global.addressable_shards[0].data.astype(jnp.float32)) + .reshape(dp_size, ep_size, -1, ref_grad.shape[-1])[:, 0] + .reshape(-1, ref_grad.shape[-1]) + ) + if args.process_id == 0: + fwd_diff = np.abs(global_out - ref_out) + grad_diff = np.abs(global_grad - ref_grad) + print( + f"[ep_moe] DEBUG loss={float(loss):.4f} ref_loss(global)={ref_loss:.4f} " + f"ratio={float(loss) / max(ref_loss, 1e-9):.4f} (expected ~1.0)" + ) + print(f"[ep_moe] DEBUG fwd max-abs-diff per row: {fwd_diff.max(axis=1)}") + print(f"[ep_moe] DEBUG grad max-abs-diff per row: {grad_diff.max(axis=1)}") + np.testing.assert_allclose( + global_out, + ref_out, + rtol=5e-2, + atol=5e-2, + err_msg=f"rank {args.process_id}: fwd mismatch", + ) + np.testing.assert_allclose( + global_grad, + ref_grad, + rtol=5e-2, + atol=5e-2, + err_msg=f"rank {args.process_id}: bwd mismatch", + ) + if args.process_id == 0: + print(f"[ep_moe] --check PASSED (ref_out.sum()={float(ref_out.sum()):.4f})") + + +if __name__ == "__main__": + main() + sys.exit(0) diff --git a/examples/jax/ep/run_test_ep.sh b/examples/jax/ep/run_test_ep.sh new file mode 100755 index 0000000000..55b958f146 --- /dev/null +++ b/examples/jax/ep/run_test_ep.sh @@ -0,0 +1,85 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +#!/bin/bash + +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "NCCL EP requires at least 4 GPUs (found ${NUM_GPUS}); SKIPPING." + exit 0 +fi +# Default mesh is (2, 2); use exactly 4 ranks even on larger boxes. +NUM_GPUS="${NVTE_EP_NUM_RANKS:-4}" + +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + +# NCCL EP requires NVLink P2P among ranks on the node. +echo "*** Checking NVLINK support ***" +NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) +NVLINK_EXIT_CODE=$? +if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] \ + || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then + echo "NVLINK is not supported on this platform — EP example requires NVLINK; SKIPPING" + exit 0 +fi +echo "NVLINK support detected" + +SCRIPT="$TE_PATH/examples/jax/ep/ep_moe.py" +export PYTHONPATH="${TE_PATH}${PYTHONPATH:+:${PYTHONPATH}}" +COORD="${COORD:-127.0.0.1:12345}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-300}" + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_graph_min_graph_size=1" +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +# Stage NCCL EP JIT cubins on tmpfs to keep build/iteration fast. +: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +echo +echo "*** Executing ep_moe.py across $NUM_GPUS GPUs ***" + +PIDS=() +cleanup() { + for pid in "${PIDS[@]}"; do + kill -0 "$pid" 2>/dev/null && kill -KILL "$pid" 2>/dev/null || true + done +} +trap cleanup EXIT INT TERM + +EXTRA_ARGS=${EXTRA_ARGS:-"--check"} + +for ((i=1; i "stdout_rank_${i}.txt" 2>&1 & + PIDS+=($!) +done +timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + python -u "$SCRIPT" \ + --coordinator-address "$COORD" --process-id "0" --num-processes "$NUM_GPUS" \ + $EXTRA_ARGS 2>&1 | tee stdout_rank_0.txt +wait + +HAS_FAILURE=0 +if grep -qE "FAILED|Traceback|ERROR" stdout_rank_0.txt; then + echo "... ep_moe FAILED" + HAS_FAILURE=1 +elif ! grep -qE "\[ep_moe\]" stdout_rank_0.txt; then + echo "... ep_moe INVALID (rank 0 produced no summary line)" + for ((i=1; i/dev/null + done + HAS_FAILURE=1 +else + echo "... ep_moe PASSED" +fi +rm -f stdout_rank_*.txt +exit $HAS_FAILURE diff --git a/tests/jax/multi_process_launch_ep.sh b/tests/jax/multi_process_launch_ep.sh new file mode 100755 index 0000000000..a37ffc2952 --- /dev/null +++ b/tests/jax/multi_process_launch_ep.sh @@ -0,0 +1,67 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +#!/bin/bash + +SCRIPT_NAMES="${SCRIPT_NAMES:-test_multi_process_ep.py}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}" + + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_graph_min_graph_size=1" + +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +NUM_RUNS=$(nvidia-smi -L | wc -l) + +if [ "${NUM_RUNS}" -lt 4 ]; then + echo "NCCL EP requires at least 4 GPUs (found ${NUM_RUNS}); SKIPPING." + exit 0 +fi +# Default test mesh is (2, 2); use exactly 4 ranks even on larger boxes. +NUM_RUNS="${NVTE_TEST_EP_NUM_RANKS:-4}" + +OVERALL_RET=0 + +for SCRIPT_NAME in $SCRIPT_NAMES; do + echo "=== Running ${SCRIPT_NAME} ===" + for ((i=1; i stdout_rank_${i}.txt 2>&1 & + done + + timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS 2>&1 | tee stdout_multi_process.txt + + wait + + RET=0 + if grep -q "FAILED" stdout_multi_process.txt; then + RET=1 + fi + # Treat missing test summary on rank 0 as hang/crash rather than silent success. + if ! grep -qE "Ran [0-9]+ test|^OK$|PASSED" stdout_multi_process.txt; then + echo "ERROR: rank 0 produced no test summary for ${SCRIPT_NAME} — likely a hang or early crash." + echo " NCCL EP requires NVLS multicast; check NCCL_DEBUG=INFO output." + RET=1 + fi + if [ "$RET" -ne 0 ]; then + for ((i=1; i/dev/null || echo "(no log)" + done + fi + + rm -f stdout_multi_process.txt stdout_rank_*.txt + if [ "$RET" -ne 0 ]; then + OVERALL_RET=1 + fi +done + +exit "$OVERALL_RET" diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py new file mode 100644 index 0000000000..0658ad9750 --- /dev/null +++ b/tests/jax/test_multi_process_ep.py @@ -0,0 +1,690 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Multi-process unit tests for the TE-JAX Expert Parallelism (EP) primitives. + +Default mesh is (dp=2, ep=2); override via ``NVTE_TEST_EP_MESH=DPxEP``. +Coverage: + + - ``ep_bootstrap`` rejects when ``ep_resource`` is unset. + - Individual primitives (``ep_prepare``, ``ep_dispatch_fwd``, ``ep_combine_fwd``) + round-trip an identity expert → output ≈ tokens. + - ``ep_dispatch`` custom_vjp: ``grad_tokens ≈ TOP_K · tokens`` (closed form). + - ``ep_combine`` custom_vjp: ``max|grad_eo| ≈ eo_const / TOP_K`` (closed form). + - ``ep_dispatch`` custom_vjp: exact per-(t, k) ``grad_topk_weights`` under + skewed upstream gradients (no k-axis averaging). + - HLO reshard guard: compile-only, no XLA collectives outside the EP FFI. + +Launch via tests/jax/multi_process_launch_ep.sh (one process per GPU). +""" + +import os +import sys +import unittest + +import jax +import jax.experimental.multihost_utils as jmu +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.jax.ep import ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.cpp_extensions.ep import ( + ep_prepare, + ep_dispatch_fwd, + ep_combine_fwd, +) + + +# ── Test config ───────────────────────────────────────────────────────────── +# NCCL EP requires NUM_LOCAL_EXPERTS*ep % 4 == 0 (TMA alignment in +# device/hybridep_adapter.cu:511). With NUM_LOCAL_EXPERTS=2, ep must be even. + +NUM_LOCAL_EXPERTS = 2 # per-rank → num_experts = NLE * EP +HIDDEN_DIM = 32 +TOP_K = 2 +TOKENS_PER_DP_SHARD = 4 # per device along dp + + +def _factor_dp_ep(num_procs): + """Default to a (2, 2) mesh. Override via ``NVTE_TEST_EP_MESH=DPxEP``. + + NUM_LOCAL_EXPERTS*ep must be a multiple of 4 for NCCL EP's TMA alignment. + """ + override = os.environ.get("NVTE_TEST_EP_MESH") + if override: + dp_str, ep_str = override.lower().split("x") + dp, ep = int(dp_str), int(ep_str) + if dp * ep != num_procs: + raise ValueError( + f"NVTE_TEST_EP_MESH={override!r} does not multiply to num_procs={num_procs}" + ) + if (NUM_LOCAL_EXPERTS * ep) % 4 != 0: + raise ValueError( + f"NUM_LOCAL_EXPERTS*ep ({NUM_LOCAL_EXPERTS}*{ep}) must be a multiple of 4 " + "for NCCL EP TMA alignment" + ) + return dp, ep + if num_procs != 4: + raise ValueError( + f"default mesh expects exactly 4 ranks (got {num_procs}); set " + "NVTE_TEST_EP_MESH=DPxEP to override" + ) + return 2, 2 + + +def _build_mesh(dp, ep): + devs = np.asarray(jax.devices()).reshape(dp, ep) + return Mesh(devs, ("dp", "ep")) + + +def _local_device_sm(): + """Return SM major*10+minor of the first local CUDA device, or None.""" + try: + dev = jax.local_devices()[0] + cap = getattr(dev, "compute_capability", None) + if cap is None: + return None + major, minor = (int(x) for x in str(cap).split(".")) + return major * 10 + minor + except Exception: + return None + + +class TestEP(unittest.TestCase): + @classmethod + def setUpClass(cls): + sm = _local_device_sm() + if sm is not None and sm < 90: + raise unittest.SkipTest(f"NCCL EP requires SM>=90 (got SM{sm})") + cls.num_procs = jax.process_count() + cls.rank = jax.process_index() + cls.dp, cls.ep = _factor_dp_ep(cls.num_procs) + cls.num_experts = NUM_LOCAL_EXPERTS * cls.ep + # recv_capacity is per-DP-group (NCCL EP comms isolated per DP color). + # Under PartitionSpec(("dp","ep"), None) each EP group sees + # T_global/dp = TOKENS_PER_DP_SHARD tokens total; pad for routing skew. + T_per_ep_group = TOKENS_PER_DP_SHARD + active_experts = min(cls.num_experts, T_per_ep_group * TOP_K) + overconc = cls.num_experts // active_experts + cls.recv_capacity_per_rank = ( + NUM_LOCAL_EXPERTS * max(T_per_ep_group * TOP_K, 16) * overconc * 2 + ) + cls.mesh = _build_mesh(cls.dp, cls.ep) + cls.mr = MeshResource(dp_resource="dp", ep_resource="ep") + with cls.mesh, global_shard_guard(cls.mr): + ep_bootstrap( + world_size=cls.num_procs, + rank=cls.rank, + ep_size=cls.ep, + num_experts=cls.num_experts, + max_tokens_per_rank=TOKENS_PER_DP_SHARD, + recv_capacity_per_rank=cls.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + ) + + # ── Bootstrap precondition ──────────────────────────────────────────── + + def test_bootstrap_rejects_missing_ep_axis(self): + """ep_bootstrap raises when MeshResource has no ep_resource.""" + with self.mesh, global_shard_guard(MeshResource()): + with self.assertRaisesRegex(ValueError, "ep_resource"): + ep_bootstrap( + world_size=self.num_procs, + rank=self.rank, + ep_size=self.ep, + num_experts=self.num_experts, + max_tokens_per_rank=TOKENS_PER_DP_SHARD, + recv_capacity_per_rank=self.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + ) + + # ── Helpers ─────────────────────────────────────────────────────────── + + def _make_identity_inputs(self, nonuniform=False): + """Identity routing + uniform weights — combined output ≈ tokens. + + ``nonuniform=False``: ``(t*TOP_K+k) % E`` (round-robin, near-balanced). + ``nonuniform=True``: ``top1=0`` for every token, ``top2=1+(t%(E-1))`` — + expert 0 absorbs the entire batch while the others split the second + slot evenly. Exercises a skewed per-expert load. + """ + T_global = TOKENS_PER_DP_SHARD * self.dp + E = self.num_experts + topk_idx = np.empty((T_global, TOP_K), dtype=np.int32) + if nonuniform: + assert TOP_K == 2, "non-uniform pattern assumes top_k=2" + for t in range(T_global): + topk_idx[t, 0] = 0 + topk_idx[t, 1] = 1 + (t % (E - 1)) + else: + for t in range(T_global): + for k in range(TOP_K): + topk_idx[t, k] = (t * TOP_K + k) % E + topk_idx = jnp.asarray(topk_idx) + topk_weights = jnp.full((T_global, TOP_K), 1.0 / TOP_K, dtype=jnp.float32) + tokens = jnp.asarray( + np.linspace(0.1, 0.9, T_global * HIDDEN_DIM, dtype=np.float32).reshape( + T_global, HIDDEN_DIM + ), + dtype=jnp.bfloat16, + ) + return T_global, topk_idx, tokens, topk_weights + + def _make_random_inputs(self, seed=42, nonuniform=True): + """Random tokens + skewed top-2 routing (top1=0 always; top2 varies). + + Non-uniform load by default — guarantees expert 0 receives every token + while the rest of the experts split the second slot. Use + ``nonuniform=False`` for a balanced (t%E, (t+1)%E) pattern. + """ + T_dp = TOKENS_PER_DP_SHARD * self.dp + E = self.num_experts + rng = np.random.default_rng(seed=seed) + tokens = jnp.asarray( + rng.standard_normal((T_dp, HIDDEN_DIM), dtype=np.float32) * 0.5, + dtype=jnp.bfloat16, + ) + topk_idx_np = np.empty((T_dp, TOP_K), dtype=np.int32) + if nonuniform: + assert TOP_K == 2, "non-uniform pattern assumes top_k=2" + for t in range(T_dp): + topk_idx_np[t, 0] = 0 + topk_idx_np[t, 1] = 1 + (t % (E - 1)) + else: + for t in range(T_dp): + a, b = t % E, (t + 1) % E + topk_idx_np[t, 0], topk_idx_np[t, 1] = (a, b) if a < b else (b, a) + topk_idx = jnp.asarray(topk_idx_np) + topk_weights = jnp.asarray(np.full((T_dp, TOP_K), 1.0 / TOP_K, dtype=np.float32)) + return T_dp, tokens, topk_idx, topk_weights + + # ── Individual primitives (cpp_extensions level) ────────────────────── + + def test_two_prepares_distinct_handle_ids(self): + """Two ep_prepare sites with matching (top_k, alignment) must produce + distinct handle_ids — distinct logical layers cannot share a + HandleEntry. Verified by tracing through jit so the primitive's + outer_primitive.bind path is exercised.""" + _T, topk_idx, _tokens, _w = self._make_identity_inputs() + captured: list = [] + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + + @jax.jit + def run(idx): + _tc_a, ha = ep_prepare(idx) + _tc_b, hb = ep_prepare(idx) + captured.append((ha.handle_id, hb.handle_id)) + return ha.handle_mem, hb.handle_mem + + hm_a, hm_b = run(idx_s) + hm_a.block_until_ready() + hm_b.block_until_ready() + id_a, id_b = captured[0] + self.assertNotEqual(id_a, id_b, "two ep_prepare calls returned the same handle_id") + + def test_primitive_prepare(self): + """ep_prepare returns the expected shapes and a valid handle id.""" + T_global, topk_idx, _tokens, _w = self._make_identity_inputs() + del T_global + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + + @jax.jit + def run(idx): + tc, handle = ep_prepare(idx) + return tc, handle.handle_mem + + tc, hm = run(idx_s) + tc.block_until_ready() + self.assertEqual(tc.shape, (self.dp * self.ep, NUM_LOCAL_EXPERTS)) + self.assertEqual(hm.shape[0], self.dp * self.ep) + self.assertGreater(hm.shape[1], 0) + + def _run_identity_round_trip(self, nonuniform): + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=nonuniform) + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + _tc, handle = ep_prepare(idx) + recv_t, recv_w, handle = ep_dispatch_fwd( + handle, idx, toks, w, self.recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + # Apply the weighted hadamard inline (combine FFI is unweighted). + mask = (recv_w != 0).astype(jnp.float32)[..., None] + weighted = (recv_t.astype(jnp.float32) * recv_w[..., None] * mask).astype( + recv_t.dtype + ) + weighted = jax.lax.with_sharding_constraint( + weighted, NamedSharding(self.mesh, ep_spec_3d) + ) + out = ep_combine_fwd( + handle, weighted, T_global, out_partition_spec=(("dp", "ep"), None) + ) + return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + # Allgather so the rank-0 numpy comparison sees the full global tensor. + out_global = jmu.process_allgather(out, tiled=True) + + # Identity expert + uniform weights → out ≈ tokens (rank-0 check). + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_primitive_dispatch_combine_identity_uniform(self): + """Round-robin routing → identity round-trip via the primitive layer.""" + self._run_identity_round_trip(nonuniform=False) + + def test_primitive_dispatch_combine_identity_nonuniform(self): + """Skewed routing (top1=0 always) → identity round-trip via the primitive layer.""" + self._run_identity_round_trip(nonuniform=True) + + def test_primitive_dispatch_combine_identity_bwd_uniform(self): + """Bwd through identity round-trip: ∇(0.5 ||out||²) w.r.t. tokens ≈ tokens. + + Identity routing + uniform top-k weights ⇒ dispatch∘combine is the + identity, so loss = 0.5||tokens||² and ∇_tokens loss = tokens. + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(toks): + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + recv_t, recv_w, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + out = ep_combine( + handle, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) + ) + return 0.5 * (out.astype(jnp.float32) ** 2).sum() + + grad = jax.jit(jax.grad(loss_fn))(tokens) + grad.block_until_ready() + grad_global = jmu.process_allgather(grad, tiled=True) + + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(grad_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_combine_3d_input_output(self): + """3D input ``[B, S, H]`` sharded on the first dim only — + ``(("dp","ep"), None, None)`` here — dispatch accepts the rank-3 shape + and combine returns a matching 3D ``[B, S, H]`` output. End-to-end + round trip recovers the original tokens under identity routing + + uniform top-k weights.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + # B is sharded across all (dp*ep) ranks; S held in one piece per rank. + B, S, H = T_global, 1, tokens.shape[-1] + tokens_3d = tokens.reshape(B, S, H) + topk_idx_3d = topk_idx.reshape(B, S, -1) + topk_w_3d = topk_w.reshape(B, S, -1) + spec_3d = PartitionSpec(("dp", "ep"), None, None) + out_spec_3d = (("dp", "ep"), None, None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx_3d, NamedSharding(self.mesh, spec_3d)) + tok_s = jax.lax.with_sharding_constraint(tokens_3d, NamedSharding(self.mesh, spec_3d)) + w_s = jax.lax.with_sharding_constraint(topk_w_3d, NamedSharding(self.mesh, spec_3d)) + + ep_t = PartitionSpec(("dp", "ep"), None, None) + ep_w = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + recv_t, recv_w, handle, _tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) + out = ep_combine( + handle, + _tc, + recv_t, + recv_w, + num_local_tokens=(B, S), + out_sharding=out_spec_3d, + ) + return out + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + out_global = jmu.process_allgather(out, tiled=True) + + if self.rank == 0: + self.assertEqual(out_global.shape, (B, S, H)) + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens_3d.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_combine_dp_only_first_dim(self): + """Input sharded ``("dp", None)`` (no ep on leading) — dispatch must + accept it. JAX SPMD slices the missing ep axis locally so the kernel + still sees ``T/(dp*ep)`` tokens per rank.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + dp_only = PartitionSpec("dp", None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_only)) + tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_only)) + w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_only)) + + ep_t = PartitionSpec(("dp", "ep"), None, None) + ep_w = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + recv_t, recv_w, handle, _tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) + out = ep_combine( + handle, + _tc, + recv_t, + recv_w, + num_local_tokens=T_global, + out_sharding=(("dp", "ep"), None), + ) + return out + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + out_global = jmu.process_allgather(out, tiled=True) + + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + # ── Custom-VJP tests ───────────────────────────────────────────────── + + def test_dispatch_vjp_fwd_bwd(self): + """ep_dispatch fwd + jax.grad w.r.t. tokens. + + Identity routing + loss = 0.5||recv_tokens||² ⇒ each token appears + TOP_K times in recv_tokens (all routes fit recv_capacity), so + grad_tokens = TOP_K * tokens (closed form). + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs() + del T_global + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(toks): + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + recv_tokens, _recv_w, _handle, _tc = ep_dispatch( + idx, toks, w, self.recv_capacity_per_rank + ) + recv_tokens = jax.lax.with_sharding_constraint( + recv_tokens, NamedSharding(self.mesh, ep_spec_3d) + ) + return 0.5 * (recv_tokens.astype(jnp.float32) ** 2).sum() + + loss, grad_tokens = jax.jit(jax.value_and_grad(loss_fn))(tokens) + grad_tokens.block_until_ready() + grad_global = jmu.process_allgather(grad_tokens, tiled=True) + + self.assertTrue(np.isfinite(float(loss))) + self.assertEqual(grad_tokens.shape, tokens.shape) + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(grad_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)) * float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + def test_combine_vjp_fwd_bwd(self): + """ep_combine fwd + jax.grad w.r.t. expert_out. + + Identity routing + constant eo=c + uniform topk_w ⇒ combined[t] = c + (sum_k topk_w = 1) and grad_eo[e, s, h] = recv_w[e, s] * c at filled + slots — so max|grad_eo| ≈ c / TOP_K. + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs() + eo_const = 0.5 + expert_out = jnp.full( + (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM), + eo_const, + dtype=jnp.bfloat16, + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(eo): + eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d)) + toks = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + _recv_tokens, recv_w, handle, tc = ep_dispatch( + idx, toks, w, self.recv_capacity_per_rank + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, PartitionSpec(("dp", "ep"), None)) + ) + combined = ep_combine(handle, tc, eo, recv_w, T_global) + # Pin combined to dp-sharded so autodiff transpose feeds + # ep_combine_bwd a per-shard cotangent. + combined = jax.lax.with_sharding_constraint( + combined, NamedSharding(self.mesh, dp_spec) + ) + return 0.5 * (combined.astype(jnp.float32) ** 2).sum() + + loss, grad_eo = jax.jit(jax.value_and_grad(loss_fn))(expert_out) + grad_eo.block_until_ready() + + self.assertTrue(np.isfinite(float(loss))) + self.assertEqual(grad_eo.shape, expert_out.shape) + for shard in grad_eo.addressable_shards: + arr = np.asarray(shard.data.astype(jnp.float32)) + self.assertTrue(np.all(np.isfinite(arr))) + self.assertGreater(arr.max(), 0.0, "grad_eo has no positive entry on filled slots") + np.testing.assert_allclose( + arr.max(), + eo_const / float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_bwd_exact_per_k_topk_weights(self): + """Distinct per-(t, k) upstream grads ⇒ grad[t, 0] != grad[t, 1] for all t. + + Guards against a regression where the bwd would average across the k + axis (per-token mean instead of per-slot exact recovery). + """ + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + dp_spec = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(idx_in, tok_in, w_in): + idx_in = jax.lax.with_sharding_constraint(idx_in, NamedSharding(self.mesh, dp_spec)) + tok_in = jax.lax.with_sharding_constraint(tok_in, NamedSharding(self.mesh, dp_spec)) + w_in = jax.lax.with_sharding_constraint(w_in, NamedSharding(self.mesh, dp_spec)) + _recv_t, recv_w, _h, _tc = ep_dispatch( + idx_in, tok_in, w_in, self.recv_capacity_per_rank + ) + # Per-slot index scale ⇒ each slot's contribution differs. + scale = jnp.asarray( + np.arange(recv_w.size, dtype=np.float32).reshape(recv_w.shape) + 1.0 + ) + return jnp.sum(recv_w * scale) + + grad_topk_w = jax.jit(jax.grad(loss_fn, argnums=2))(topk_idx, tokens, topk_w) + grad_topk_w.block_until_ready() + grad_global = jmu.process_allgather(grad_topk_w, tiled=True) + + if self.rank == 0: + grad_np = np.asarray(grad_global).astype(np.float32) + mismatch = sum(int(abs(grad_np[t, 0] - grad_np[t, 1]) < 1e-6) for t in range(T_dp)) + self.assertEqual( + mismatch, + 0, + f"Expected grad[t, 0] != grad[t, 1] for all {T_dp} tokens under skewed " + f"upstream scaling; got {mismatch} tokens with grad[t, 0] == grad[t, 1].", + ) + + # ── HLO reshard guard ──────────────────────────────────────────────── + # Compile-only: assert XLA inserts no cross-device collectives outside + # the EP FFI. EP-axis flux is carried by the FFI itself. + + def test_z_no_unexpected_reshard_in_hlo_fwd(self): + """Compiled fwd HLO must not insert XLA collectives outside the EP FFI.""" + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + @jax.jit + def run(idx, toks, w): + idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) + recv_t, recv_w, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + out = ep_combine( + handle, tc, recv_t, recv_w, T_dp, out_sharding=(("dp", "ep"), None) + ) + return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) + + compiled = run.lower(topk_idx, tokens, topk_w).compile() + hlo = compiled.as_text() + # Match instruction names; "all-gather-start" and "all-gather-done" + # bracket a single async all-gather. + for op in ("all-gather-start", "all-to-all", "collective-permute"): + self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in fwd HLO:\n{hlo}") + # XLA drops trailing-None entries from the spec; compare as a tuple. + # JAX collapses size-1 mesh axes, so dp=1 reduces ("dp","ep") to "ep". + expected = (("dp", "ep"),) if self.dp > 1 else ("ep",) + self.assertEqual(tuple(compiled.output_shardings.spec), expected) + + def test_z_no_unexpected_reshard_in_hlo_bwd(self): + """Compiled bwd HLO must not insert XLA collectives outside the EP FFI.""" + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + rng = np.random.default_rng(seed=44) + expert_out = jnp.asarray( + rng.standard_normal( + (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM), dtype=np.float32 + ) + * 0.5, + dtype=jnp.bfloat16, + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def fwd(eo, toks, idx, w): + eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d)) + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) + _rt, rw, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d)) + combined = ep_combine(handle, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)) + return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec)) + + # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd + # the expected sharding without relying on XLA-transpose propagation. + def bwd_only(eo, toks, idx, w, g): + _y, vjp_fn = jax.vjp(fwd, eo, toks, idx, w) + g = jax.lax.with_sharding_constraint(g, NamedSharding(self.mesh, dp_spec)) + grads = vjp_fn(g) + return ( + jax.lax.with_sharding_constraint( + grads[0], NamedSharding(self.mesh, ep_spec_3d) + ), + jax.lax.with_sharding_constraint(grads[1], NamedSharding(self.mesh, dp_spec)), + ) + + g_seed = jnp.ones((T_dp, HIDDEN_DIM), dtype=jnp.bfloat16) + compiled = ( + jax.jit(bwd_only).lower(expert_out, tokens, topk_idx, topk_w, g_seed).compile() + ) + hlo = compiled.as_text() + for op in ("all-gather-start", "all-to-all", "collective-permute"): + self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in bwd HLO:\n{hlo}") + + +# ── Entry point ────────────────────────────────────────────────────────────── + + +if __name__ == "__main__": + if len(sys.argv) < 4: + print("Usage: python test_multi_process_ep.py ") + sys.exit(1) + + coord_addr = sys.argv[1] + proc_id = int(sys.argv[2]) + num_procs = int(sys.argv[3]) + + jax.distributed.initialize( + coordinator_address=coord_addr, + num_processes=num_procs, + process_id=proc_id, + local_device_ids=[proc_id], + ) + + loader = unittest.TestLoader() + target = os.environ.get("TARGET_TEST") + if target: + name = target.split(".")[-1] + suite = loader.loadTestsFromName(name, TestEP) + else: + suite = loader.loadTestsFromTestCase(TestEP) + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + sys.exit(0 if result.wasSuccessful() else 1) diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index fe1f93dc7a..604da5e1b7 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -10,4 +10,5 @@ from .softmax import * from .gemm import * from .router import * +from .ep import * from .topk import * diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py new file mode 100644 index 0000000000..7d112ad5f4 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -0,0 +1,955 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for Expert Parallelism (EP). + +Sharding model: + - EpPrepare / EpDispatch outputs carry a single leading ``num_procs`` dim. + Sharded compound ``(dp_resource, ep_resource)`` when DP is set, else + ``ep_resource`` alone. + - EpDispatch inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``; only the first + dim may be sharded, with axis ∈ {ep, (dp, ep), dp, None}. Trailing dims + must be replicated. ``dp`` alone gets ``ep`` folded in locally. + - EpCombine output sharding comes from ``out_sharding`` or defaults to the + compound ``(dp, ep)`` axis on the leading dim. +""" + +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi +from jax.sharding import NamedSharding, PartitionSpec +import jax.tree_util as jtu + +import transformer_engine_jax +from .base import BasePrimitive, register_primitive +from ..sharding import global_mesh_resource + +__all__ = [ + "EpConfig", + "EpHandle", + "set_ep_config", + "get_ep_config", + "get_ep_num_local_experts", + "ep_allocate_handle_id", + "ep_prepare", + "ep_dispatch_fwd", + "ep_combine_fwd", + "ep_dispatch_bwd", + "ep_combine_bwd", +] + + +# Routing-state container threaded through dispatch/combine/*_bwd. +@jtu.register_pytree_node_class +class EpHandle: + def __init__(self, handle_mem, handle_id): + self.handle_mem = handle_mem + self.handle_id = int(handle_id) + + def tree_flatten(self): + return (self.handle_mem,), (self.handle_id,) + + @classmethod + def tree_unflatten(cls, aux, children): + return cls(children[0], aux[0]) + + def __repr__(self): + return f"EpHandle(handle_id={self.handle_id})" + + +# ── Module-level EP config ────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class EpConfig: + """Immutable Python view of the EP bootstrap config (see ep_bootstrap).""" + + world_size: int + rank: int + ep_size: int + num_experts: int + num_local_experts: int + max_tokens_per_rank: int + recv_capacity_per_rank: int + hidden_dim: int + + +_ep_config: EpConfig = None + + +def set_ep_config(config: EpConfig) -> None: + """Cache the EP config for abstract-eval / sharding helpers. Call once.""" + global _ep_config + _ep_config = config + + +def get_ep_config() -> EpConfig: + if _ep_config is None: + raise RuntimeError("EpConfig has not been set. Did you call ep_bootstrap()?") + return _ep_config + + +def get_ep_num_local_experts() -> int: + return get_ep_config().num_local_experts + + +# handle_id -> handle_mem buffer size in bytes. +_HANDLE_MEM_SIZE_BY_ID: dict = {} + + +def ep_allocate_handle_id(top_k: int, dispatch_output_per_expert_alignment: int = 0) -> int: + """Reserve a fresh handle_id for an EP layer. + + Distinct logical layers must each call this — sharing a handle_id across + layers corrupts the routing state, even when (top_k, alignment) match. + """ + handle_id, handle_mem_size = transformer_engine_jax.ep_register_layer( + int(top_k), int(dispatch_output_per_expert_alignment) + ) + handle_id = int(handle_id) + _HANDLE_MEM_SIZE_BY_ID[handle_id] = int(handle_mem_size) + return handle_id + + +def _ep_handle_mem_size(handle_id: int) -> int: + """Return the handle_mem byte size for an id from ep_allocate_handle_id.""" + try: + return _HANDLE_MEM_SIZE_BY_ID[int(handle_id)] + except KeyError as e: + raise RuntimeError( + f"handle_id={handle_id} not registered; call ep_allocate_handle_id first." + ) from e + + +def _leading_axis_ok(spec, ep_axis, outer_axes=()): + # Only the first dim may carry sharding; remaining dims must be replicated. + # The first dim's axis must be one of: + # ``ep_axis`` alone, + # a tuple of dp/fsdp axes (no ep — ep gets sliced in locally), + # a tuple ending in ``ep_axis`` with dp/fsdp axes before it. + # Examples on a (dp, ep) mesh: 2D ``(ep, None)``, ``(("dp","ep"), None)``, + # ``("dp", None)``; 3D ``(ep, None, None)``, ``(("dp","ep"), None, None)``, + # ``("dp", None, None)``. + if len(spec) < 2 or ep_axis is None: + return False + if any(ax is not None for ax in spec[1:]): + return False # only first dim sharded + leading = spec[0] + allowed_outers = {a for a in outer_axes if a is not None} + allowed = allowed_outers | {ep_axis, None} + elts = leading if isinstance(leading, tuple) else (leading,) + return all(a in allowed for a in elts) + + +def _canonical_input_spec(spec, ndim): + """Canonical input PartitionSpec the primitive demands JAX deliver. + + Sharding lives entirely on the first dim. If ``spec[0]`` already includes + ``ep_resource``, returned unchanged. Otherwise ``ep_resource`` is folded + into the first-dim axis tuple, e.g. ``"dp"`` → ``("dp","ep")``. The added + ep axis is a local slice (the missing dim was replicated), no cross-device + comm. + """ + gsr = global_mesh_resource() + ep = gsr.ep_resource + leading = spec[0] + present = leading if isinstance(leading, tuple) else (leading,) if leading is not None else () + if ep in present: + return PartitionSpec(*spec) + if leading is None: + new_leading = ep + elif isinstance(leading, tuple): + new_leading = (*leading, ep) + else: + new_leading = (leading, ep) + return PartitionSpec(new_leading, *([None] * (ndim - 1))) + + +def _dispatch_input_outer_axes(): + """dp/fsdp axes allowed as outer companions to ep_resource on dispatch input.""" + gsr = global_mesh_resource() + return tuple(a for a in (gsr.dp_resource, gsr.fsdp_resource) if a is not None) + + +def _ep_outer_axis(): + """The single dp/fsdp axis (if any) sitting outside ep on EP-output tensors. + + When set, EP-output globals carry an extra leading ``dp_size`` dim so SPMD + sees each DP color's slab as distinct (rather than replicated across DP). + """ + gsr = global_mesh_resource() + return gsr.dp_resource or gsr.fsdp_resource + + +def _ep_leading_dims(is_outer): + """Single leading dim of an EP-output tensor: ``(dp*ep,)`` (or ``(ep,)`` when + DP is unset) globally; ``(1,)`` per shard.""" + cfg = get_ep_config() + outer = _ep_outer_axis() + if not is_outer: + return (1,) + return (cfg.world_size,) if outer is not None else (cfg.ep_size,) + + +def _ep_output_spec(*trailing): + """PartitionSpec for an EP-output tensor: ``(("dp","ep"), *trailing)`` when + DP is set (compound leading axis on a single dim), else ``("ep",*trailing)``.""" + gsr = global_mesh_resource() + outer = _ep_outer_axis() + if outer is None: + return PartitionSpec(gsr.ep_resource, *trailing) + return PartitionSpec((outer, gsr.ep_resource), *trailing) + + +def _ep_spec_ok(spec, trailing_count): + """Accept ``(ep, *[None])`` (no DP) or ``((dp,ep), *[None])`` / + ``(("dp",), *[None])`` / ``("dp", *[None])`` / ``(None, *[None])`` (with DP) + on an EP-output tensor's single leading dim. JAX may collapse a size-1 + mesh axis to ``None`` (matters for dp_size=1 like 1x4).""" + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer = _ep_outer_axis() + expected_len = 1 + trailing_count + if len(spec) != expected_len: + return False + if any(ax is not None for ax in spec[1:]): + return False + leading = spec[0] + if outer is None: + return leading == ep_axis + allowed = {ep_axis, outer, None} + elts = leading if isinstance(leading, tuple) else (leading,) + return all(a in allowed for a in elts) + + +# ── ep_prepare ────────────────────────────────────────────────────────────── + + +class EpPreparePrimitive(BasePrimitive): + name = "te_ep_prepare_ffi" + multiple_results = True + impl_static_args = (1, 2, 3) # handle_id, dispatch_output_per_expert_alignment, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(topk_idx_aval, *, handle_id, dispatch_output_per_expert_alignment, is_outer): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + del dispatch_output_per_expert_alignment + cfg = get_ep_config() + num_local_experts = cfg.num_local_experts + assert ( + len(topk_idx_aval.shape) >= 2 + ), f"topk_idx must be at least 2D [..., top_k], got shape {topk_idx_aval.shape}" + handle_mem_size = _ep_handle_mem_size(handle_id) + leading = _ep_leading_dims(is_outer) + token_counts_aval = jax.core.ShapedArray(leading + (num_local_experts,), jnp.int32) + handle_mem_aval = jax.core.ShapedArray(leading + (handle_mem_size,), jnp.uint8) + # FFI scratch for the int32 -> int64 topk_idx upcast. int32 with last + # dim doubled to keep the int64 byte count without JAX_ENABLE_X64. + # TODO(phuong): drop once NCCL EP supports int32 topk_idx. + workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,) + workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32) + return token_counts_aval, handle_mem_aval, workspace_aval + + @staticmethod + def outer_abstract(topk_idx_aval, *, handle_id, dispatch_output_per_expert_alignment, is_outer): + del is_outer + avals = EpPreparePrimitive.abstract( + topk_idx_aval, + handle_id=handle_id, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + is_outer=True, + ) + return avals[:2] + + @staticmethod + def lowering(ctx, topk_idx, *, handle_id, dispatch_output_per_expert_alignment, is_outer): + del is_outer + return ffi.ffi_lowering(EpPreparePrimitive.name)( + ctx, + topk_idx, + handle_id=int(handle_id), + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + ) + + @staticmethod + def impl(topk_idx, handle_id, dispatch_output_per_expert_alignment, is_outer): + assert EpPreparePrimitive.inner_primitive is not None + token_counts, handle_mem, _workspace = EpPreparePrimitive.inner_primitive.bind( + topk_idx, + handle_id=handle_id, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + is_outer=is_outer, + ) + return token_counts, handle_mem + + @staticmethod + def batcher( + batched_args, batch_dims, *, handle_id, dispatch_output_per_expert_alignment, is_outer + ): + raise NotImplementedError("EpPreparePrimitive does not support vmap") + + @staticmethod + def partition( + handle_id, dispatch_output_per_expert_alignment, is_outer, mesh, arg_infos, result_infos + ): + del is_outer, result_infos + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer_axes = _dispatch_input_outer_axes() + idx_spec = arg_infos[0].sharding.spec + if not _leading_axis_ok(idx_spec, ep_axis, outer_axes): + raise NotImplementedError( + "EpPrepare: topk_idx leading dims must shard on ep_resource" + f" ('{ep_axis}') and/or {outer_axes}, with the topk dim replicated;" + f" got spec={idx_spec}." + ) + idx_ndim = len(arg_infos[0].shape) + arg_shardings = (NamedSharding(mesh, _canonical_input_spec(idx_spec, idx_ndim)),) + tc_sharding = NamedSharding(mesh, _ep_output_spec(None)) + hm_sharding = NamedSharding(mesh, _ep_output_spec(None)) + + def sharded_impl(topk_idx): + return EpPreparePrimitive.impl( + topk_idx, handle_id, dispatch_output_per_expert_alignment, False + ) + + return mesh, sharded_impl, (tc_sharding, hm_sharding), arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args + # for this primitive are (handle_id, dispatch_alignment, is_outer). + value_types = args[-2] + topk_idx_rank = len(value_types[0].shape) + in_axes = " ".join(f"L{i}" for i in range(topk_idx_rank - 1)) + " topk" + return f"{in_axes} -> EPL nle, EPL hm" + + +register_primitive(EpPreparePrimitive) + + +# ── ep_dispatch ───────────────────────────────────────────────────────────── + + +class EpDispatchPrimitive(BasePrimitive): + name = "te_ep_dispatch_ffi" + multiple_results = True + impl_static_args = (4, 5, 6, 7) # handle_id, recv_capacity_per_rank, top_k, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + topk_idx_aval, + tokens_aval, + topk_weights_aval, + *, + handle_id, + recv_capacity_per_rank, + top_k, + is_outer, + ): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + del handle_id, topk_weights_aval, top_k, handle_mem_aval + assert ( + len(tokens_aval.shape) >= 2 + ), f"tokens must be at least 2D [..., H], got shape {tokens_aval.shape}" + recv_pr = recv_capacity_per_rank + tok_dtype = dtypes.canonicalize_dtype(tokens_aval.dtype) + hidden_dim = tokens_aval.shape[-1] + leading = _ep_leading_dims(is_outer) + recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype) + recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32) + # int32 with last dim doubled to keep the int64 byte count without JAX_ENABLE_X64. + workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,) + workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32) + return (recv_tokens_aval, recv_topk_weights_aval, workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = True + avals = EpDispatchPrimitive.abstract(*args, **kwargs) + return avals[:2] + + @staticmethod + def lowering( + ctx, + handle_mem, + topk_idx, + tokens, + topk_weights, + *, + handle_id, + recv_capacity_per_rank, + top_k, + is_outer, + ): + del recv_capacity_per_rank, is_outer + return ffi.ffi_lowering(EpDispatchPrimitive.name)( + ctx, + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id=int(handle_id), + top_k=top_k, + ) + + @staticmethod + def impl( + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id, + recv_capacity_per_rank, + top_k, + is_outer, + ): + assert EpDispatchPrimitive.inner_primitive is not None + recv_tokens, recv_topk_weights, _workspace = EpDispatchPrimitive.inner_primitive.bind( + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id=handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + top_k=top_k, + is_outer=is_outer, + ) + return recv_tokens, recv_topk_weights + + @staticmethod + def batcher(batched_args, batch_dims, *, handle_id, recv_capacity_per_rank, top_k, is_outer): + raise NotImplementedError("EpDispatchPrimitive does not support vmap") + + @staticmethod + def partition( + handle_id, recv_capacity_per_rank, top_k, is_outer, mesh, arg_infos, result_infos + ): + del is_outer, result_infos + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer_axes = _dispatch_input_outer_axes() + tokens_spec = arg_infos[2].sharding.spec + if not _leading_axis_ok(tokens_spec, ep_axis, outer_axes): + raise NotImplementedError( + "EpDispatch: tokens leading dims must shard on ep_resource" + f" ('{ep_axis}') and/or {outer_axes}, hidden dim replicated;" + f" got spec={tokens_spec}." + ) + idx_spec = arg_infos[1].sharding.spec + tw_spec = arg_infos[3].sharding.spec + arg_shardings = ( + arg_infos[0].sharding, + NamedSharding(mesh, _canonical_input_spec(idx_spec, len(arg_infos[1].shape))), + NamedSharding(mesh, _canonical_input_spec(tokens_spec, len(arg_infos[2].shape))), + NamedSharding(mesh, _canonical_input_spec(tw_spec, len(arg_infos[3].shape))), + ) + out_shardings = ( + NamedSharding(mesh, _ep_output_spec(None, None)), + NamedSharding(mesh, _ep_output_spec(None)), + ) + + def sharded_impl(handle_mem, topk_idx, tokens, topk_weights): + return EpDispatchPrimitive.impl( + handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id, + recv_capacity_per_rank, + top_k, + False, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args + # for this primitive are (handle_id, recv_capacity_per_rank, top_k, is_outer). + value_types = args[-2] + # Inputs: handle_mem, topk_idx, tokens, topk_weights. + idx_rank = len(value_types[1].shape) + tok_rank = len(value_types[2].shape) + tw_rank = len(value_types[3].shape) + idx_axes = " ".join(f"I{i}" for i in range(idx_rank - 1)) + " topk_in" + tok_axes = " ".join(f"T{i}" for i in range(tok_rank - 1)) + " H" + tw_axes = " ".join(f"W{i}" for i in range(tw_rank - 1)) + " topk" + return f"EPL hm, {idx_axes}, {tok_axes}, {tw_axes} -> EPL recv_pr H, EPL recv_pr" + + +register_primitive(EpDispatchPrimitive) + + +# ── ep_combine ────────────────────────────────────────────────────────────── +# `expert_out` here is the post-weight buffer; ep.ep_combine applies the +# hadamard before calling. + + +def _normalize_leading_shape(s): + return s if isinstance(s, tuple) else (int(s),) + + +def _prod(seq): + p = 1 + for x in seq: + p *= int(x) + return p + + +def _resolve_out_partition_spec(out_partition_spec, num_leading): + """Pick the combine output PartitionSpec. + + Defaults to a compound leading axis ``(dp_resource, ep_resource)`` when a + DP/FSDP axis is set on the active MeshResource, else just ``ep_resource``. + This matches the input sharding so XLA does not need collective-permutes + in the bwd path. + """ + if out_partition_spec is not None: + assert len(out_partition_spec) == num_leading + 1, ( + f"out_partition_spec length {len(out_partition_spec)} must equal num_leading" + f" + 1 ({num_leading + 1})" + ) + return tuple(out_partition_spec) + gsr = global_mesh_resource() + if gsr.ep_resource is None: + raise ValueError( + "ep_combine: ep_resource is not set on the active MeshResource;" + " pass out_sharding=... explicitly." + ) + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource + return (leading,) + (None,) * num_leading + + +def _per_shard_leading(out_leading_shape, resolved_spec, mesh): + """Per-shard leading shape given resolved partition spec and mesh.""" + per_shard = list(out_leading_shape) + for i, ax in enumerate(resolved_spec[: len(out_leading_shape)]): + if ax is None: + continue + axes = ax if isinstance(ax, tuple) else (ax,) + factor = 1 + for a in axes: + factor *= mesh.shape[a] + assert ( + per_shard[i] % factor == 0 + ), f"leading dim {per_shard[i]} not divisible by shard factor {factor} on axes {axes}" + per_shard[i] //= factor + return tuple(per_shard) + + +class EpCombinePrimitive(BasePrimitive): + name = "te_ep_combine_ffi" + multiple_results = False + impl_static_args = (2, 3, 4) # handle_id, out_leading_shape, out_partition_spec + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + expert_out_aval, + *, + handle_id, + out_leading_shape, + out_partition_spec, + ): + del handle_id, out_partition_spec, handle_mem_aval + assert ( + len(expert_out_aval.shape) == 3 + ), f"expert_out must be 3D [num_procs, recv_pr, H], got shape {expert_out_aval.shape}" + eo_dtype = dtypes.canonicalize_dtype(expert_out_aval.dtype) + hidden_dim = expert_out_aval.shape[-1] + out_shape = tuple(out_leading_shape) + (hidden_dim,) + return jax.core.ShapedArray(out_shape, eo_dtype) + + @staticmethod + def lowering( + ctx, + handle_mem, + expert_out, + *, + handle_id, + out_leading_shape, + out_partition_spec, + ): + del out_partition_spec + return ffi.ffi_lowering(EpCombinePrimitive.name)( + ctx, + handle_mem, + expert_out, + handle_id=int(handle_id), + num_local_tokens=_prod(out_leading_shape), + ) + + @staticmethod + def impl(handle_mem, expert_out, handle_id, out_leading_shape, out_partition_spec): + assert EpCombinePrimitive.inner_primitive is not None + return EpCombinePrimitive.inner_primitive.bind( + handle_mem, + expert_out, + handle_id=handle_id, + out_leading_shape=out_leading_shape, + out_partition_spec=out_partition_spec, + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, handle_id, out_leading_shape, out_partition_spec): + raise NotImplementedError("EpCombinePrimitive does not support vmap") + + @staticmethod + def partition(handle_id, out_leading_shape, out_partition_spec, mesh, arg_infos, result_infos): + del result_infos + eo_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(eo_spec, trailing_count=2): + raise NotImplementedError( + "EpCombine: expert_out must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={eo_spec}." + ) + resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) + per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh) + arg_shardings = tuple(a.sharding for a in arg_infos) + out_sharding = NamedSharding(mesh, PartitionSpec(*resolved)) + + def sharded_impl(handle_mem, expert_out): + return EpCombinePrimitive.impl( + handle_mem, expert_out, handle_id, per_shard_leading, out_partition_spec + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args: + # (handle_id, out_leading_shape, out_partition_spec). + result_types = args[-1] + out_rank = len(result_types[0].shape) + out_axes = " ".join(f"O{i}" for i in range(out_rank - 1)) + " H" + return f"EPL hm, EPL recv_pr H -> {out_axes}" + + +register_primitive(EpCombinePrimitive) + + +# ── ep_dispatch_bwd ───────────────────────────────────────────────────────── + + +class EpDispatchBwdPrimitive(BasePrimitive): + name = "te_ep_dispatch_bwd_ffi" + multiple_results = True + impl_static_args = (3, 4, 5, 6) # handle_id, top_k, out_leading_shape, out_partition_spec + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + grad_aval, + g_recv_topk_weights_aval, + *, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + del handle_id, g_recv_topk_weights_aval, out_partition_spec, handle_mem_aval + assert ( + len(grad_aval.shape) == 3 + ), f"grad must be 3D [num_procs, recv_pr, H], got shape {grad_aval.shape}" + g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype) + hidden_dim = grad_aval.shape[-1] + result_aval = jax.core.ShapedArray(tuple(out_leading_shape) + (hidden_dim,), g_dtype) + grad_topk_weights_aval = jax.core.ShapedArray( + tuple(out_leading_shape) + (top_k,), jnp.float32 + ) + return result_aval, grad_topk_weights_aval + + @staticmethod + def lowering( + ctx, + handle_mem, + grad, + g_recv_topk_weights, + *, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + del out_partition_spec + return ffi.ffi_lowering(EpDispatchBwdPrimitive.name)( + ctx, + handle_mem, + grad, + g_recv_topk_weights, + handle_id=int(handle_id), + num_local_tokens=_prod(out_leading_shape), + top_k=top_k, + ) + + @staticmethod + def impl( + handle_mem, + grad, + g_recv_topk_weights, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + assert EpDispatchBwdPrimitive.inner_primitive is not None + return EpDispatchBwdPrimitive.inner_primitive.bind( + handle_mem, + grad, + g_recv_topk_weights, + handle_id=handle_id, + top_k=top_k, + out_leading_shape=out_leading_shape, + out_partition_spec=out_partition_spec, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + ): + raise NotImplementedError("EpDispatchBwdPrimitive does not support vmap") + + @staticmethod + def partition( + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + mesh, + arg_infos, + result_infos, + ): + del result_infos + g_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(g_spec, trailing_count=2): + raise NotImplementedError( + "EpDispatchBwd: grad must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={g_spec}." + ) + gw_spec = arg_infos[2].sharding.spec + if not _ep_spec_ok(gw_spec, trailing_count=1): + raise NotImplementedError( + "EpDispatchBwd: g_recv_topk_weights must be sharded as" + " PartitionSpec(ep_resource, None) (or ((dp, ep), None) when dp/fsdp is set)" + f" over [num_procs, recv_pr]; got spec={gw_spec}." + ) + resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) + per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh) + arg_shardings = tuple(a.sharding for a in arg_infos) + out_shardings = [ + NamedSharding(mesh, PartitionSpec(*resolved)), + NamedSharding(mesh, PartitionSpec(*resolved, None)), + ] + + def sharded_impl(handle_mem, grad, g_recv_topk_weights): + return EpDispatchBwdPrimitive.impl( + handle_mem, + grad, + g_recv_topk_weights, + handle_id, + top_k, + per_shard_leading, + out_partition_spec, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Result rank + # follows out_leading_shape (static arg #2): rank = len(out_leading) + 1. + result_types = args[-1] + out_rank = len(result_types[0].shape) + out_axes = " ".join(f"O{i}" for i in range(out_rank - 1)) + return f"EPL hm, EPL recv_pr H, EPL recv_pr -> {out_axes} H, {out_axes} k" + + +register_primitive(EpDispatchBwdPrimitive) + + +# ── ep_combine_bwd ────────────────────────────────────────────────────────── + + +class EpCombineBwdPrimitive(BasePrimitive): + name = "te_ep_combine_bwd_ffi" + multiple_results = False + impl_static_args = (2, 3, 4) # handle_id, recv_capacity_per_rank, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(handle_mem_aval, grad_aval, *, handle_id, recv_capacity_per_rank, is_outer): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + del handle_id, handle_mem_aval + assert ( + len(grad_aval.shape) >= 2 + ), f"grad must be at least 2D [..., H], got shape {grad_aval.shape}" + g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype) + hidden_dim = grad_aval.shape[-1] + leading = _ep_leading_dims(is_outer) + return jax.core.ShapedArray(leading + (recv_capacity_per_rank, hidden_dim), g_dtype) + + @staticmethod + def outer_abstract(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = True + return EpCombineBwdPrimitive.abstract(*args, **kwargs) + + @staticmethod + def lowering(ctx, handle_mem, grad, *, handle_id, recv_capacity_per_rank, is_outer): + del recv_capacity_per_rank, is_outer + return ffi.ffi_lowering(EpCombineBwdPrimitive.name)( + ctx, + handle_mem, + grad, + handle_id=int(handle_id), + ) + + @staticmethod + def impl(handle_mem, grad, handle_id, recv_capacity_per_rank, is_outer): + assert EpCombineBwdPrimitive.inner_primitive is not None + return EpCombineBwdPrimitive.inner_primitive.bind( + handle_mem, + grad, + handle_id=handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=is_outer, + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, handle_id, recv_capacity_per_rank, is_outer): + raise NotImplementedError("EpCombineBwdPrimitive does not support vmap") + + @staticmethod + def partition(handle_id, recv_capacity_per_rank, is_outer, mesh, arg_infos, result_infos): + del is_outer, result_infos + arg_shardings = tuple(a.sharding for a in arg_infos) + out_sharding = NamedSharding(mesh, _ep_output_spec(None, None)) + + def sharded_impl(handle_mem, grad): + return EpCombineBwdPrimitive.impl( + handle_mem, grad, handle_id, recv_capacity_per_rank, False + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # T axes are dynamic-rank based on the actual cotangent shape. + value_types = args[-2] + g_rank = len(value_types[1].shape) + g_axes = " ".join(f"T{i}" for i in range(g_rank - 1)) + " H" + return f"EPL hm, {g_axes} -> EPL recv_pr H" + + +register_primitive(EpCombineBwdPrimitive) + + +# ── Public-ish helpers (used by jax/ep.py) ────────────────────────────────── + + +_HANDLE_ID_CALLSITE_CACHE = {} + + +def ep_prepare(topk_idx, dispatch_output_per_expert_alignment=0): + """Exchange routing metadata; return ``(token_counts, EpHandle)``.""" + import sys as _sys + + top_k = int(topk_idx.shape[-1]) + alignment = int(dispatch_output_per_expert_alignment) + # Cache handle_id by caller (file:lineno, top_k, alignment): JAX re-traces + # the same call site (e.g. custom_vjp fwd vs primal) and the resulting + # EpHandles must share the same id to compare equal in pytree aux. + f = _sys._getframe(1) + cache_key = (f.f_code.co_filename, f.f_lineno, top_k, alignment) + handle_id = _HANDLE_ID_CALLSITE_CACHE.get(cache_key) + if handle_id is None: + handle_id = ep_allocate_handle_id(top_k, alignment) + _HANDLE_ID_CALLSITE_CACHE[cache_key] = handle_id + token_counts, handle_mem = EpPreparePrimitive.outer_primitive.bind( + topk_idx, + handle_id=handle_id, + dispatch_output_per_expert_alignment=alignment, + is_outer=True, + ) + return token_counts, EpHandle(handle_mem, handle_id) + + +def ep_dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank): + """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights, handle).""" + top_k = int(topk_weights.shape[-1]) + recv_tokens, recv_topk_weights = EpDispatchPrimitive.outer_primitive.bind( + handle.handle_mem, + topk_idx, + tokens, + topk_weights, + handle_id=handle.handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + top_k=top_k, + is_outer=True, + ) + return recv_tokens, recv_topk_weights, handle + + +def ep_combine_fwd(handle, expert_out, num_local_tokens, out_partition_spec=None): + """Gather expert outputs back to home ranks. expert_out is pre-weighted.""" + out_leading = _normalize_leading_shape(num_local_tokens) + return EpCombinePrimitive.outer_primitive.bind( + handle.handle_mem, + expert_out, + handle_id=handle.handle_id, + out_leading_shape=out_leading, + out_partition_spec=out_partition_spec, + ) + + +def ep_dispatch_bwd( + handle, grad, g_recv_topk_weights, top_k, num_local_tokens, out_partition_spec=None +): + """Backward of dispatch; returns (grad_tokens, grad_topk_weights).""" + out_leading = _normalize_leading_shape(num_local_tokens) + return EpDispatchBwdPrimitive.outer_primitive.bind( + handle.handle_mem, + grad, + g_recv_topk_weights, + handle_id=handle.handle_id, + top_k=int(top_k), + out_leading_shape=out_leading, + out_partition_spec=out_partition_spec, + ) + + +def ep_combine_bwd(handle, grad, recv_capacity_per_rank): + """Backward of combine; returns grad_expert_out [num_procs, recv_capacity_per_rank, H].""" + return EpCombineBwdPrimitive.outer_primitive.bind( + handle.handle_mem, + grad, + handle_id=handle.handle_id, + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=True, + ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 416b18ada0..4d8b097f27 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -200,6 +200,25 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); +// EP bootstrap (called once per process) +void EpInitialize(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms); +// EP shutdown — registered as a Python atexit hook so it runs before +// C++ static destructors of the JAX extension and libtransformer_engine.so. +void EpShutdown(); +// Host-only: register an EP layer. Returns (handle_id, handle_mem_size) where +// handle_id is baked into each FFI op as a static int64 attribute (no D2H sync +// per op) and handle_mem_size sizes the caller's handle_mem buffer. +pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment); + +// EP FFI handlers +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpPrepareHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchBwdHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineBwdHandler); + // TopK XLA_FFI_DECLARE_HANDLER_SYMBOL(TopkHandler); pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp new file mode 100644 index 0000000000..e2c50135aa --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -0,0 +1,457 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifdef NVTE_WITH_NCCL_EP + +#include "transformer_engine/ep.h" + +#include + +#include +#include +#include + +#include "../extensions.h" +#include "common.h" +#include "transformer_engine/gemm.h" + +namespace transformer_engine { +namespace jax { + +namespace { + +// Process-lifetime owner of the EP ncclComm_t. Created from a broadcast +// ncclUniqueId during EpInitialize; destroyed by EpShutdown (registered as a +// Python atexit hook from ep.py so it runs before C++ static destructors). +class EpCommManager { + public: + static EpCommManager& get() { + static EpCommManager inst; + return inst; + } + + void init_from_uid(const uint8_t* uid_bytes, int ep_size, int rank_within_group) { + std::lock_guard lock(mutex_); + NVTE_CHECK(comm_ == nullptr, "EP comm already initialized for this process"); + ncclUniqueId uid; + std::memcpy(&uid, uid_bytes, sizeof(uid)); + NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, ep_size, uid, rank_within_group)); + } + + ncclComm_t comm() const { return comm_; } + + void shutdown() { + std::lock_guard lock(mutex_); + if (comm_ == nullptr) return; + ncclCommDestroy(comm_); + comm_ = nullptr; + } + + private: + EpCommManager() = default; + // Intentionally no NCCL teardown in the destructor: this runs at static-dtor + // time, after Python has finalized and possibly after the CUDA driver + // detaches the context. Calling ncclCommDestroy there has been observed to + // hang or report cudartUnloading. Normal teardown goes through the Python + // atexit hook (shutdown_ep_communicator) registered from ep.py; any path + // that skips that (os._exit, fatal signal) leaks the comm, which the OS + // reaps on process exit. + ~EpCommManager() = default; + EpCommManager(const EpCommManager&) = delete; + EpCommManager& operator=(const EpCommManager&) = delete; + + std::mutex mutex_; + ncclComm_t comm_{nullptr}; +}; + +} // namespace + +// handle_id is baked at jit trace time and carried as a static FFI attribute. + +struct EpPrepareConfig { + int64_t handle_id; + int64_t dispatch_output_per_expert_alignment; +}; + +struct EpDispatchConfig { + int64_t handle_id; + int64_t top_k; +}; + +struct EpCombineConfig { + int64_t handle_id; + int64_t num_local_tokens; +}; + +struct EpDispatchBwdConfig { + int64_t handle_id; + int64_t num_local_tokens; + int64_t top_k; +}; + +struct EpCombineBwdConfig { + int64_t handle_id; +}; + +// ── Bootstrap helpers ───────────────────────────────────────────────────────── + +void EpInitialize(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms) { + std::string uid_str = unique_id_bytes_obj; + NVTE_CHECK(static_cast(uid_str.size()) >= 128, + "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); + EpCommManager::get().init_from_uid(reinterpret_cast(uid_str.data()), ep_size, + rank_within_group); + NVTEEpGroupConfig cfg{.ep_size = ep_size, + .num_experts = num_experts, + .max_tokens_per_rank = max_tokens_per_rank, + .max_recv_tokens_per_rank = max_recv_tokens_per_rank, + .hidden_dim = hidden_dim, + .max_num_sms = max_num_sms}; + // If common rejects the config (validate_config / ncclEpCreateGroup), roll + // the comm back so the two singletons don't end up in inconsistent states + // and the comm doesn't strand until process exit. + try { + nvte_ep_initialize(static_cast(EpCommManager::get().comm()), cfg); + } catch (...) { + EpCommManager::get().shutdown(); + throw; + } +} + +void EpShutdown() { + // Order matters: ep_group_ in common reads from the comm, so tear it down + // first, then destroy the comm. + nvte_ep_shutdown(); + EpCommManager::get().shutdown(); +} + +pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment) { + NVTEEpLayerConfig layer_cfg{0, top_k, dispatch_output_per_expert_alignment}; + size_t handle_mem_size = 0; + uint64_t handle_id = nvte_ep_register_layer(layer_cfg, &handle_mem_size); + return pybind11::make_tuple(handle_id, handle_mem_size); +} + +// ── ep_prepare ──────────────────────────────────────────────────────────────── + +Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, + Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { + auto topk_dims = topk_idx.dimensions(); + NVTE_CHECK(topk_dims.size() >= 2, + "topk_idx must be at least 2D [..., top_k], got ndim=", topk_dims.size()); + auto idx_etype = topk_idx.element_type(); + NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32, + "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype)); + + std::vector topk_shape = {product(topk_dims, 0, topk_dims.size() - 1), + static_cast(topk_dims.back())}; + // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream. + // TODO(phuong): drop once NCCL EP accepts int32. + void* topk_idx_data = topk_idx.untyped_data(); + if (idx_etype == ::xla::ffi::DataType::S32) { + const size_t n = topk_shape[0] * topk_shape[1]; + NVTE_CHECK(static_cast(workspace->element_count()) >= n, + "workspace too small for int32 → int64 upcast: element_count=", + workspace->element_count(), " < required ", n); + int64_t* ws = reinterpret_cast(workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream); + topk_idx_data = ws; + } + auto topk_idx_ = TensorWrapper(topk_idx_data, topk_shape, DType::kInt64); + + std::vector tc_shape = {static_cast(token_counts->element_count())}; + auto token_counts_ = TensorWrapper(token_counts->untyped_data(), tc_shape, DType::kInt32); + + std::vector hm_shape = {static_cast(handle_mem->element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem->untyped_data(), hm_shape, DType::kByte); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + nvte_ep_prepare(handle, topk_idx_.data(), token_counts_.data(), + static_cast(config.dispatch_output_per_expert_alignment), stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // topk_idx + .Ret() // token_counts + .Ret() // handle_mem + .Ret() // workspace (FFI scratch) + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_dispatch ─────────────────────────────────────────────────────────────── + +Error_Type EpDispatchFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type topk_idx, + Buffer_Type tokens, Buffer_Type topk_weights, Result_Type recv_tokens, + Result_Type recv_topk_weights, Result_Type workspace, + EpDispatchConfig config) { + auto token_dims = tokens.dimensions(); + NVTE_CHECK(token_dims.size() >= 2, + "tokens must be at least 2D [..., H], got ndim=", token_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + auto idx_dims = topk_idx.dimensions(); + NVTE_CHECK(idx_dims.size() >= 2, + "topk_idx must be at least 2D [..., top_k], got ndim=", idx_dims.size()); + auto idx_etype = topk_idx.element_type(); + NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32, + "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype)); + NVTE_CHECK(static_cast(idx_dims.back()) == config.top_k, "top_k attr (", config.top_k, + ") must match topk_idx last dim (", idx_dims.back(), ")"); + std::vector idx_shape = {product(idx_dims, 0, idx_dims.size() - 1), + static_cast(idx_dims.back())}; + // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream. + // TODO(phuong): drop once NCCL EP accepts int32. + void* topk_idx_data = topk_idx.untyped_data(); + if (idx_etype == ::xla::ffi::DataType::S32) { + const size_t n = idx_shape[0] * idx_shape[1]; + NVTE_CHECK(static_cast(workspace->element_count()) >= n, + "workspace too small for int32 → int64 upcast: element_count=", + workspace->element_count(), " < required ", n); + int64_t* ws = reinterpret_cast(workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream); + topk_idx_data = ws; + } + auto topk_idx_ = TensorWrapper(topk_idx_data, idx_shape, DType::kInt64); + + const size_t T_flat = product(token_dims, 0, token_dims.size() - 1); + const size_t H = static_cast(token_dims.back()); + std::vector tok_shape = {T_flat, H}; + auto token_dtype = convert_ffi_datatype_to_te_dtype(tokens.element_type()); + auto tokens_ = TensorWrapper(tokens.untyped_data(), tok_shape, token_dtype); + + auto tw_dims = topk_weights.dimensions(); + NVTE_CHECK(tw_dims.size() >= 2, + "topk_weights must be at least 2D [..., top_k], got ndim=", tw_dims.size()); + std::vector tw_shape = {product(tw_dims, 0, tw_dims.size() - 1), + static_cast(tw_dims.back())}; + auto topk_weights_ = TensorWrapper(topk_weights.untyped_data(), tw_shape, DType::kFloat32); + + // recv_tokens: flatten any leading dims into recv_capacity_per_rank. + auto recv_dims = recv_tokens->dimensions(); + NVTE_CHECK(recv_dims.size() >= 2, + "recv_tokens must be at least 2D [..., recv_pr, H]; got ndim=", recv_dims.size()); + const size_t recv_capacity_per_rank = product(recv_dims, 0, recv_dims.size() - 1); + std::vector recv_shape = {recv_capacity_per_rank, H}; + auto recv_tokens_ = TensorWrapper(recv_tokens->untyped_data(), recv_shape, token_dtype); + + auto recv_w_dims = recv_topk_weights->dimensions(); + NVTE_CHECK(recv_w_dims.size() >= 1, + "recv_topk_weights must be at least 1D; got ndim=", recv_w_dims.size()); + const size_t recv_w_total = product(recv_w_dims, 0, recv_w_dims.size()); + NVTE_CHECK(recv_w_total == recv_capacity_per_rank, "recv_topk_weights total (", recv_w_total, + ") must match recv_tokens recv_pr (", recv_capacity_per_rank, ")"); + std::vector recv_w_shape = {recv_capacity_per_rank}; + auto recv_topk_weights_ = + TensorWrapper(recv_topk_weights->untyped_data(), recv_w_shape, DType::kFloat32); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_dispatch(handle, topk_idx_.data(), tokens_.data(), no_win, topk_weights_.data(), no_win, + recv_tokens_.data(), no_win, recv_topk_weights_.data(), no_win, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchHandler, EpDispatchFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // topk_idx + .Arg() // tokens + .Arg() // topk_weights + .Ret() // recv_tokens + .Ret() // recv_topk_weights + .Ret() // workspace (FFI scratch) + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_combine ──────────────────────────────────────────────────────────────── + +Error_Type EpCombineFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type expert_out, + Result_Type result, EpCombineConfig config) { + auto eo_dims = expert_out.dimensions(); + NVTE_CHECK(eo_dims.size() >= 2, + "expert_out must be at least 2D [..., recv_pr, H]; got ndim=", eo_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t recv_capacity_per_rank = product(eo_dims, 0, eo_dims.size() - 1); + const size_t H = static_cast(eo_dims.back()); + std::vector eo_shape = {recv_capacity_per_rank, H}; + auto eo_dtype = convert_ffi_datatype_to_te_dtype(expert_out.element_type()); + auto expert_out_ = TensorWrapper(expert_out.untyped_data(), eo_shape, eo_dtype); + + auto res_dims = result->dimensions(); + NVTE_CHECK(res_dims.size() >= 2, + "result must be at least 2D [..., H]; got ndim=", res_dims.size()); + const size_t res_T_flat = product(res_dims, 0, res_dims.size() - 1); + NVTE_CHECK(static_cast(res_T_flat) == config.num_local_tokens, + "result leading-dim product (", res_T_flat, ") must equal num_local_tokens (", + config.num_local_tokens, ")"); + std::vector res_shape = {res_T_flat, H}; + auto result_ = TensorWrapper(result->untyped_data(), res_shape, eo_dtype); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_combine(handle, expert_out_.data(), no_win, result_.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineHandler, EpCombineFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // expert_out + .Ret() // result + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_dispatch_bwd ─────────────────────────────────────────────────────────── + +Error_Type EpDispatchBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, + Buffer_Type g_recv_topk_weights, Result_Type grad_tokens, + Result_Type grad_topk_weights, EpDispatchBwdConfig config) { + auto grad_dims = grad.dimensions(); + NVTE_CHECK(grad_dims.size() >= 2, + "grad must be at least 2D [..., recv_pr, H]; got ndim=", grad_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t recv_capacity_per_rank = product(grad_dims, 0, grad_dims.size() - 1); + const size_t H = static_cast(grad_dims.back()); + std::vector g_shape = {recv_capacity_per_rank, H}; + auto g_dtype = convert_ffi_datatype_to_te_dtype(grad.element_type()); + auto grad_ = TensorWrapper(grad.untyped_data(), g_shape, g_dtype); + + auto gw_dims = g_recv_topk_weights.dimensions(); + NVTE_CHECK( + gw_dims.size() >= 1, + "g_recv_topk_weights rank must flatten to recv_capacity_per_rank; got ndim=", gw_dims.size()); + const size_t gw_total = product(gw_dims, 0, gw_dims.size()); + NVTE_CHECK(gw_total == recv_capacity_per_rank, "g_recv_topk_weights total (", gw_total, + ") must match grad recv_pr (", recv_capacity_per_rank, ")"); + std::vector gw_shape = {recv_capacity_per_rank}; + auto g_recv_topk_weights_ = + TensorWrapper(g_recv_topk_weights.untyped_data(), gw_shape, DType::kFloat32); + + auto out_dims = grad_tokens->dimensions(); + NVTE_CHECK(out_dims.size() >= 2, + "grad_tokens must be at least 2D [..., H], got ndim=", out_dims.size()); + const size_t T_flat = product(out_dims, 0, out_dims.size() - 1); + NVTE_CHECK(static_cast(T_flat) == config.num_local_tokens, + "grad_tokens leading-dim product (", T_flat, ") must equal num_local_tokens (", + config.num_local_tokens, ")"); + std::vector out_shape = {T_flat, H}; + auto grad_tokens_ = TensorWrapper(grad_tokens->untyped_data(), out_shape, g_dtype); + + auto gtw_dims = grad_topk_weights->dimensions(); + NVTE_CHECK(gtw_dims.size() >= 2, + "grad_topk_weights must be at least 2D [..., top_k]; got ndim=", gtw_dims.size()); + const size_t gtw_T_flat = product(gtw_dims, 0, gtw_dims.size() - 1); + NVTE_CHECK(gtw_T_flat == T_flat, "grad_topk_weights leading-dim product (", gtw_T_flat, + ") must equal grad_tokens leading-dim product (", T_flat, ")"); + const size_t top_k = static_cast(gtw_dims.back()); + NVTE_CHECK(static_cast(top_k) == config.top_k, "top_k attr (", config.top_k, + ") must match grad_topk_weights last dim (", top_k, ")"); + std::vector gtw_shape = {T_flat, top_k}; + auto grad_topk_weights_ = + TensorWrapper(grad_topk_weights->untyped_data(), gtw_shape, DType::kFloat32); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_dispatch_bwd(handle, grad_.data(), no_win, g_recv_topk_weights_.data(), no_win, + grad_tokens_.data(), grad_topk_weights_.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchBwdHandler, EpDispatchBwdFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // grad (w.r.t. recv_tokens) + .Arg() // g_recv_topk_weights + .Ret() // grad_tokens + .Ret() // grad_topk_weights + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_combine_bwd ──────────────────────────────────────────────────────────── + +Error_Type EpCombineBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, + Result_Type grad_expert_out, EpCombineBwdConfig config) { + auto grad_dims = grad.dimensions(); + NVTE_CHECK(grad_dims.size() >= 2, + "grad must be at least 2D [..., H], got ndim=", grad_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t T_flat = product(grad_dims, 0, grad_dims.size() - 1); + const size_t H = static_cast(grad_dims.back()); + std::vector g_shape = {T_flat, H}; + auto g_dtype = convert_ffi_datatype_to_te_dtype(grad.element_type()); + auto grad_ = TensorWrapper(grad.untyped_data(), g_shape, g_dtype); + + auto out_dims = grad_expert_out->dimensions(); + NVTE_CHECK(out_dims.size() >= 2, + "grad_expert_out must be at least 2D [..., recv_pr, H]; got ndim=", out_dims.size()); + const size_t recv_capacity_per_rank = product(out_dims, 0, out_dims.size() - 1); + const size_t out_H = static_cast(out_dims.back()); + NVTE_CHECK(out_H == H, "grad_expert_out hidden dim (", out_H, ") must match grad H (", H, ")"); + std::vector out_shape = {recv_capacity_per_rank, H}; + auto grad_expert_out_ = TensorWrapper(grad_expert_out->untyped_data(), out_shape, g_dtype); + + NVTEEpHandle handle{static_cast(config.handle_id), handle_mem_.data()}; + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_combine_bwd(handle, grad_.data(), no_win, grad_expert_out_.data(), no_win, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineBwdHandler, EpCombineBwdFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // handle_mem + .Arg() // grad (w.r.t. result) + .Ret() // grad_expert_out + .Attrs(), + FFI_CudaGraph_Traits); + +} // namespace jax +} // namespace transformer_engine + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::EpPrepareConfig, ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("dispatch_output_per_expert_alignment")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpDispatchConfig, + ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("top_k")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpCombineConfig, + ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("num_local_tokens")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpDispatchBwdConfig, + ::xla::ffi::StructMember("handle_id"), + ::xla::ffi::StructMember("num_local_tokens"), + ::xla::ffi::StructMember("top_k")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::EpCombineBwdConfig, + ::xla::ffi::StructMember("handle_id")); + +#endif // NVTE_WITH_NCCL_EP diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 70d0403b3e..b34f8739ee 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -101,6 +101,15 @@ pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); +#ifdef NVTE_WITH_NCCL_EP + // Expert Parallelism + dict["te_ep_prepare_ffi"] = EncapsulateFFI(EpPrepareHandler); + dict["te_ep_dispatch_ffi"] = EncapsulateFFI(EpDispatchHandler); + dict["te_ep_combine_ffi"] = EncapsulateFFI(EpCombineHandler); + dict["te_ep_dispatch_bwd_ffi"] = EncapsulateFFI(EpDispatchBwdHandler); + dict["te_ep_combine_bwd_ffi"] = EncapsulateFFI(EpCombineBwdHandler); +#endif // NVTE_WITH_NCCL_EP + // TopK dict["te_topk_ffi"] = EncapsulateFFI(TopkHandler); @@ -127,6 +136,15 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); m.def("get_grouped_gemm_setup_workspace_size", &nvte_get_grouped_gemm_setup_workspace_size); +#ifdef NVTE_WITH_NCCL_EP + m.def("initialize_ep_communicator", &EpInitialize, pybind11::arg("unique_id_bytes"), + pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), + pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), + pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0); + m.def("shutdown_ep_communicator", &EpShutdown); + m.def("ep_register_layer", &EpRegisterLayer, pybind11::arg("top_k"), + pybind11::arg("dispatch_output_per_expert_alignment") = 0); +#endif // NVTE_WITH_NCCL_EP pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py new file mode 100644 index 0000000000..40d07bc3d4 --- /dev/null +++ b/transformer_engine/jax/ep.py @@ -0,0 +1,303 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX Expert Parallelism (EP) API.""" + +import atexit +import ctypes +from functools import partial + +import jax +import jax.numpy as jnp +import jax.experimental.multihost_utils as jmu +import numpy as np + +import transformer_engine_jax +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax.cpp_extensions.ep import EpHandle +from transformer_engine.jax.sharding import global_mesh_resource, get_mesh_axis_size + +ep_prepare = tex.ep_prepare + +__all__ = [ + "EpHandle", + "ep_bootstrap", + "ep_prepare", + "ep_dispatch", + "ep_combine", +] + +_atexit_registered = False + + +# ── Bootstrap ──────────────────────────────────────────────────────────────── + + +def ep_bootstrap( + world_size, + rank, + ep_size, + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + max_num_sms=0, +): + """Initialize the EP communicator. Call once per process before any EP op. + + max_num_sms caps the SMs allotted to EP kernels (0 = auto). + """ + if world_size < 2: + raise ValueError( + f"ep_bootstrap requires world_size >= 2 (got {world_size}); NCCL EP needs" + " at least 2 ranks to form a group." + ) + if world_size % ep_size != 0: + raise ValueError( + f"world_size ({world_size}) must be divisible by ep_size ({ep_size}); otherwise" + " some EP groups would have fewer than ep_size ranks and ncclCommInitRank would hang." + ) + if num_experts % ep_size != 0: + raise ValueError(f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size}).") + if jax.local_device_count() != 1: + raise ValueError( + "ep_bootstrap requires one local device per process (got" + f" jax.local_device_count() = {jax.local_device_count()}); NCCL EP does not" + " support single-process multi-device setups." + ) + UID_SIZE = 128 + dp_color = rank // ep_size + rank_within_group = rank % ep_size + is_color_root = rank_within_group == 0 + if is_color_root: + try: + from nccl import get_unique_id + + uid_bytes = bytes(get_unique_id())[:UID_SIZE] + except ImportError: + libnccl = ctypes.CDLL("libnccl.so.2", use_errno=True) + uid_arr = (ctypes.c_uint8 * UID_SIZE)() + ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p)) + assert ret == 0, f"ncclGetUniqueId failed with code {ret}" + uid_bytes = bytes(uid_arr) + else: + uid_bytes = bytes(UID_SIZE) + + uid_arr = jnp.frombuffer(uid_bytes, dtype=jnp.uint8) + all_uids = jmu.process_allgather(uid_arr).reshape(world_size, UID_SIZE) + uid_bytes = bytes(np.asarray(all_uids[dp_color * ep_size]).tolist()) + + ep_resource = global_mesh_resource().ep_resource + if ep_resource is None: + raise ValueError( + "ep_bootstrap requires MeshResource.ep_resource to be set; enter a" + " global_shard_guard(MeshResource(..., ep_resource=)) before bootstrap." + ) + mesh_ep_size = get_mesh_axis_size(ep_resource) + if mesh_ep_size != ep_size: + raise ValueError( + f"ep_bootstrap: EpConfig.ep_size ({ep_size}) does not match mesh axis" + f" '{ep_resource}' size ({mesh_ep_size})." + ) + + transformer_engine_jax.initialize_ep_communicator( + uid_bytes, + ep_size, + rank_within_group, + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + max_num_sms=int(max_num_sms), + ) + + # Shutdown ordering: + # - Python atexit is LIFO. ep_bootstrap runs jmu.process_allgather first, + # which assumes jax.distributed.initialize() ran earlier, so JAX's + # distributed atexit hooks are already registered before this one. Ours + # therefore fires first at exit — fine, because EpShutdown only touches + # NCCL (ncclEpGroupDestroy + ncclCommDestroy) and does not depend on + # JAX's coordination service. Do not add JAX calls to EpShutdown. + # - Running before C++ static destructors avoids the cudartUnloading + # hazard; the C++ destructors are intentionally no-ops. + global _atexit_registered + if not _atexit_registered: + atexit.register(transformer_engine_jax.shutdown_ep_communicator) + _atexit_registered = True + + tex.ep.set_ep_config( + tex.ep.EpConfig( + world_size=world_size, + rank=rank, + ep_size=ep_size, + num_experts=num_experts, + num_local_experts=num_experts // ep_size, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=hidden_dim, + ) + ) + + +# ── ep_dispatch (custom_vjp) ───────────────────────────────────────────────── + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4)) +def ep_dispatch( + topk_idx, + tokens, + topk_weights, + recv_capacity_per_rank, + dispatch_output_per_expert_alignment=0, +): + """Scatter tokens and weights to expert ranks. + + Inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``. Only the leading dim may + be sharded — axis ∈ {ep, (dp, ep), dp, None}; trailing dims replicated. + + Args: + topk_idx: ``[..., top_k]`` int32/int64 routing indices. + tokens: ``[..., H]`` activations (matching leading dims). + topk_weights: ``[..., top_k]`` float32 routing weights. + recv_capacity_per_rank: STATIC int. Per-rank recv slot count. + dispatch_output_per_expert_alignment: STATIC int. Per-expert slot + alignment; 0 disables. + + Returns: + ``(recv_tokens, recv_topk_weights, handle, token_counts)`` where + ``recv_tokens`` is 3D ``[num_procs, recv_capacity_per_rank, H]`` + sharded ``(("dp","ep"), None, None)`` (or ``("ep", None, None)`` if + DP is unset), and ``recv_topk_weights`` is 2D + ``[num_procs, recv_capacity_per_rank]`` similarly sharded. Pass + ``handle`` to the matching ``ep_combine``. + """ + return _dispatch_fwd( + topk_idx, + tokens, + topk_weights, + recv_capacity_per_rank, + dispatch_output_per_expert_alignment, + )[0] + + +def _dispatch_fwd( + topk_idx, + tokens, + topk_weights, + recv_capacity_per_rank, + dispatch_output_per_expert_alignment, +): + top_k = int(topk_weights.shape[-1]) + token_counts, handle = tex.ep_prepare(topk_idx, dispatch_output_per_expert_alignment) + recv_tokens, recv_topk_weights, handle = tex.ep_dispatch_fwd( + handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank + ) + out_leading = tuple(tokens.shape[:-1]) + primal = (recv_tokens, recv_topk_weights, handle, token_counts) + return primal, (handle, out_leading, top_k) + + +def _dispatch_bwd(recv_capacity_per_rank, dispatch_output_per_expert_alignment, res, g_outputs): + del recv_capacity_per_rank, dispatch_output_per_expert_alignment + handle, out_leading, top_k = res + # Re-pin cotangent sharding: XLA transpose can drop the EP axis on a + # single-fwd-output cotangent, landing a global tensor in the FFI. + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, ep_axis) if outer is not None else ep_axis + g_recv_tokens = jax.lax.with_sharding_constraint( + g_outputs[0], jax.sharding.PartitionSpec(leading, None, None) + ) + g_recv_topk_weights = jax.lax.with_sharding_constraint( + g_outputs[1], jax.sharding.PartitionSpec(leading, None) + ) + grad_tokens, grad_topk_weights = tex.ep_dispatch_bwd( + handle, g_recv_tokens, g_recv_topk_weights, top_k, out_leading + ) + return (None, grad_tokens, grad_topk_weights) + + +ep_dispatch.defvjp(_dispatch_fwd, _dispatch_bwd) + + +# ── ep_combine (custom_vjp) ────────────────────────────────────────────────── + + +@partial(jax.custom_vjp, nondiff_argnums=(4, 5)) +def ep_combine( + handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding=None +): + """Reduce weighted expert outputs back to source ranks. + + Args: + handle: ``EpHandle`` from a matching ``ep_dispatch`` call. + token_counts: ``[num_procs, num_local_experts]`` int32 (passed through). + expert_out: ``[num_procs, recv_capacity_per_rank, H]`` post-FFN activations. + recv_topk_weights: ``[num_procs, recv_capacity_per_rank]`` float32 weights + returned by ``ep_dispatch``. + num_local_tokens: STATIC int or tuple. int → 2D output ``[T, H]``; + tuple → N-D output ``[*tuple, H]``. + out_sharding: STATIC optional ``PartitionSpec`` tuple for the + output. Defaults to ``(("dp","ep"), *None)`` when + DP is set, else ``("ep", *None)``. Pass a custom + spec to override; only the leading dim may be + sharded. + + Returns: + ``[..., H]`` combined output shaped per ``num_local_tokens``. + """ + return _combine_fwd( + handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding + )[0] + + +def _make_valid_mask(recv_topk_weights, dtype): + # recv_topk_weights == 0 marks a padded slot. + return (recv_topk_weights != 0).astype(dtype)[..., None] + + +def _combine_fwd( + handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding +): + del token_counts + w = recv_topk_weights[..., None] + mask = _make_valid_mask(recv_topk_weights, jnp.float32) + weighted = (expert_out.astype(jnp.float32) * w * mask).astype(expert_out.dtype) + result = tex.ep_combine_fwd(handle, weighted, num_local_tokens, out_partition_spec=out_sharding) + return result, (handle, recv_topk_weights, expert_out) + + +def _combine_bwd(_num_local_tokens, _out_sharding, res, g_result): + handle, recv_topk_weights, expert_out = res + # expert_out is [..., recv_pr, H]; pull recv_pr from the second-to-last dim. + recv_capacity_per_rank = expert_out.shape[-2] + # Re-pin cotangent sharding: same XLA-transpose workaround as _dispatch_bwd. + gsr = global_mesh_resource() + if _out_sharding is not None: + spec = jax.sharding.PartitionSpec(*_out_sharding) + else: + ep_axis = gsr.ep_resource + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, ep_axis) if outer is not None and ep_axis is not None else ep_axis + spec = ( + jax.sharding.PartitionSpec(leading, *([None] * (g_result.ndim - 1))) + if leading is not None + else None + ) + if spec is not None: + g_result = jax.lax.with_sharding_constraint(g_result, spec) + grad_weighted = tex.ep_combine_bwd(handle, g_result, recv_capacity_per_rank) + w = recv_topk_weights[..., None] + mask = _make_valid_mask(recv_topk_weights, jnp.float32) + grad_weighted_f32 = grad_weighted.astype(jnp.float32) + grad_expert_out = (grad_weighted_f32 * w * mask).astype(grad_weighted.dtype) + grad_recv_topk_weights = ( + (grad_weighted_f32 * expert_out.astype(jnp.float32) * mask) + .sum(axis=-1) + .astype(recv_topk_weights.dtype) + ) + return (None, None, grad_expert_out, grad_recv_topk_weights) + + +ep_combine.defvjp(_combine_fwd, _combine_bwd) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 182a4a2e00..1dbdfbc533 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -332,7 +332,12 @@ class MeshResource: fsdp_resource: Axis name for full-sharded data parallelism, default is None pp_resource: Axis name for pipeline parallelism (layer sharding), default is None cp_resource: Axis name for context parallelism (sequence sharding), default is None - ep_resource: Axis name for expert parallelism (MoE expert sharding), default is None + ep_resource: Axis name for expert parallelism. Dispatch input tokens + must be sharded on their leading dim by ``ep_resource`` (alone or + compound with ``dp_resource`` / ``fsdp_resource`` as outer, e.g. + ``PartitionSpec(("dp", "ep"), None, None)``). Dispatch output + ``[ep_size, recv_capacity, H]`` is always sharded by ``ep_resource`` + on the leading ``ep_size`` dim. """ dp_resource: str = None @@ -475,3 +480,8 @@ def dp_or_fsdp_axis_size(): dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource) fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource) return dp_size if dp_size > 1 else fsdp_size + + +def ep_axis_size(): + """Get the size of the dispatch/EP axis (ep_resource). Returns 1 if unset.""" + return get_mesh_axis_size(global_mesh_resource().ep_resource) From b43710e538f6626ad91b0dee4a9b735bfe2a5fe9 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Sat, 23 May 2026 00:31:54 +0000 Subject: [PATCH 21/36] JAX EP: tie NCCL comm lifetime to JAX executables via XLA stateful FFI Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/base.py | 11 + transformer_engine/jax/csrc/extensions.h | 21 +- transformer_engine/jax/csrc/extensions/ep.cpp | 273 +++++++++++------- .../jax/csrc/extensions/pybind.cpp | 28 +- transformer_engine/jax/ep.py | 15 +- 5 files changed, 222 insertions(+), 126 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 6eb588c849..2cdef4bfe7 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -266,6 +266,17 @@ def _gspmd_wrapper(*args, **kwargs): for _name, _value in transformer_engine_jax.registrations().items(): ffi.register_ffi_target(_name, _value, platform="CUDA") +# Register EpInstanceState (no-op when TE is built without NCCL EP). +if hasattr(transformer_engine_jax, "get_ep_instance_state_type_id"): + ffi.register_ffi_type( + "EpInstanceState", + { + "type_id": transformer_engine_jax.get_ep_instance_state_type_id(), + "type_info": transformer_engine_jax.get_ep_instance_state_type_info(), + }, + platform="CUDA", + ) + def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): """ diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 4d8b097f27..62e762a5bb 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -200,19 +200,20 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); -// EP bootstrap (called once per process) -void EpInitialize(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, - int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms); -// EP shutdown — registered as a Python atexit hook so it runs before -// C++ static destructors of the JAX extension and libtransformer_engine.so. -void EpShutdown(); -// Host-only: register an EP layer. Returns (handle_id, handle_mem_size) where -// handle_id is baked into each FFI op as a static int64 attribute (no D2H sync -// per op) and handle_mem_size sizes the caller's handle_mem buffer. +// Bootstrap EP (eager NCCL comm init); anchor released by ReleaseEpResources. +void SetEpBootstrapParams(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms); +void ReleaseEpResources(); +// Register an EP layer; returns (handle_id, handle_mem_size). pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment); +// EpInstanceState type_id / type_info capsules for jax.ffi.register_ffi_type. +pybind11::capsule GetEpInstanceStateTypeIdCapsule(); +pybind11::capsule GetEpInstanceStateTypeInfoCapsule(); + // EP FFI handlers +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpInstantiateHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(EpPrepareHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineHandler); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index e2c50135aa..5dc05de0ae 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -10,8 +10,10 @@ #include +#include #include #include +#include #include #include "../extensions.h" @@ -21,52 +23,85 @@ namespace transformer_engine { namespace jax { -namespace { +// NCCL comm + EPBackend lifetime tracks live JAX executables via XLA stateful FFI. + +struct EpBootstrapParams { + std::array uid_bytes{}; + int ep_size = 0; + int rank_within_group = 0; + int num_experts = 0; + int max_tokens_per_rank = 0; + int max_recv_tokens_per_rank = 0; + int hidden_dim = 0; + int max_num_sms = 0; +}; -// Process-lifetime owner of the EP ncclComm_t. Created from a broadcast -// ncclUniqueId during EpInitialize; destroyed by EpShutdown (registered as a -// Python atexit hook from ep.py so it runs before C++ static destructors). -class EpCommManager { +class EpResources { public: - static EpCommManager& get() { - static EpCommManager inst; - return inst; - } - - void init_from_uid(const uint8_t* uid_bytes, int ep_size, int rank_within_group) { - std::lock_guard lock(mutex_); - NVTE_CHECK(comm_ == nullptr, "EP comm already initialized for this process"); + explicit EpResources(const EpBootstrapParams& p) { ncclUniqueId uid; - std::memcpy(&uid, uid_bytes, sizeof(uid)); - NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, ep_size, uid, rank_within_group)); + std::memcpy(&uid, p.uid_bytes.data(), sizeof(uid)); + NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, p.ep_size, uid, p.rank_within_group)); + NVTEEpGroupConfig cfg{.ep_size = p.ep_size, + .num_experts = p.num_experts, + .max_tokens_per_rank = p.max_tokens_per_rank, + .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, + .hidden_dim = p.hidden_dim, + .max_num_sms = p.max_num_sms}; + try { + nvte_ep_initialize(static_cast(comm_), cfg); + } catch (...) { + ncclCommDestroy(comm_); + comm_ = nullptr; + throw; + } } - ncclComm_t comm() const { return comm_; } - - void shutdown() { - std::lock_guard lock(mutex_); + ~EpResources() { if (comm_ == nullptr) return; + nvte_ep_shutdown(); ncclCommDestroy(comm_); - comm_ = nullptr; } + EpResources(const EpResources&) = delete; + EpResources& operator=(const EpResources&) = delete; + + ncclComm_t comm() const { return comm_; } + private: - EpCommManager() = default; - // Intentionally no NCCL teardown in the destructor: this runs at static-dtor - // time, after Python has finalized and possibly after the CUDA driver - // detaches the context. Calling ncclCommDestroy there has been observed to - // hang or report cudartUnloading. Normal teardown goes through the Python - // atexit hook (shutdown_ep_communicator) registered from ep.py; any path - // that skips that (os._exit, fatal signal) leaks the comm, which the OS - // reaps on process exit. - ~EpCommManager() = default; - EpCommManager(const EpCommManager&) = delete; - EpCommManager& operator=(const EpCommManager&) = delete; - - std::mutex mutex_; ncclComm_t comm_{nullptr}; }; +struct EpInstanceState { + static ::xla::ffi::TypeId id; + static ::xla::ffi::TypeInfo info; + std::shared_ptr resources; +}; + +::xla::ffi::TypeId EpInstanceState::id = {}; +::xla::ffi::TypeInfo EpInstanceState::info = ::xla::ffi::MakeTypeInfo(); + +namespace { + +std::mutex g_ep_mu; +EpBootstrapParams g_ep_params; +bool g_ep_params_set = false; +std::weak_ptr g_ep_resources_weak; +// Python-held anchor so trace-time ep_register_layer finds EPBackend ready. +std::shared_ptr g_ep_resources_anchor; + +std::shared_ptr AcquireEpResources() { + std::lock_guard lock(g_ep_mu); + NVTE_CHECK(g_ep_params_set, + "EP bootstrap params not set; call transformer_engine_jax." + "set_ep_bootstrap_params() (typically via ep_bootstrap) first."); + auto sp = g_ep_resources_weak.lock(); + if (sp) return sp; + sp = std::make_shared(g_ep_params); + g_ep_resources_weak = sp; + return sp; +} + } // namespace // handle_id is baked at jit trace time and carried as a static FFI attribute. @@ -98,36 +133,44 @@ struct EpCombineBwdConfig { // ── Bootstrap helpers ───────────────────────────────────────────────────────── -void EpInitialize(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, - int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms) { +// Caches uid + group config and eagerly creates the NCCL comm (ranks +// synchronize via the UID broadcast). +void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms) { std::string uid_str = unique_id_bytes_obj; NVTE_CHECK(static_cast(uid_str.size()) >= 128, "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); - EpCommManager::get().init_from_uid(reinterpret_cast(uid_str.data()), ep_size, - rank_within_group); - NVTEEpGroupConfig cfg{.ep_size = ep_size, - .num_experts = num_experts, - .max_tokens_per_rank = max_tokens_per_rank, - .max_recv_tokens_per_rank = max_recv_tokens_per_rank, - .hidden_dim = hidden_dim, - .max_num_sms = max_num_sms}; - // If common rejects the config (validate_config / ncclEpCreateGroup), roll - // the comm back so the two singletons don't end up in inconsistent states - // and the comm doesn't strand until process exit. - try { - nvte_ep_initialize(static_cast(EpCommManager::get().comm()), cfg); - } catch (...) { - EpCommManager::get().shutdown(); - throw; + std::shared_ptr anchor; + { + std::lock_guard lock(g_ep_mu); + NVTE_CHECK(!g_ep_resources_anchor, + "EP bootstrap already initialized; call release_ep_resources() before re-init."); + std::memcpy(g_ep_params.uid_bytes.data(), uid_str.data(), 128); + g_ep_params.ep_size = ep_size; + g_ep_params.rank_within_group = rank_within_group; + g_ep_params.num_experts = num_experts; + g_ep_params.max_tokens_per_rank = max_tokens_per_rank; + g_ep_params.max_recv_tokens_per_rank = max_recv_tokens_per_rank; + g_ep_params.hidden_dim = hidden_dim; + g_ep_params.max_num_sms = max_num_sms; + g_ep_params_set = true; } + // Acquire outside the lock: EpResources ctor runs ncclCommInitRank which is + // a collective and may block on peer ranks. + anchor = AcquireEpResources(); + std::lock_guard lock(g_ep_mu); + g_ep_resources_anchor = std::move(anchor); } -void EpShutdown() { - // Order matters: ep_group_ in common reads from the comm, so tear it down - // first, then destroy the comm. - nvte_ep_shutdown(); - EpCommManager::get().shutdown(); +// Drops the anchor; comm tears down once the last executable also releases. +void ReleaseEpResources() { + std::shared_ptr to_drop; + { + std::lock_guard lock(g_ep_mu); + to_drop = std::move(g_ep_resources_anchor); + } + // to_drop dtor runs outside the lock. } pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment) { @@ -137,10 +180,35 @@ pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_ali return pybind11::make_tuple(handle_id, handle_mem_size); } +pybind11::capsule GetEpInstanceStateTypeIdCapsule() { + return pybind11::capsule(static_cast(&EpInstanceState::id), "xla.ffi.type_id"); +} + +pybind11::capsule GetEpInstanceStateTypeInfoCapsule() { + return pybind11::capsule(static_cast(&EpInstanceState::info), "xla.ffi.type_info"); +} + +// ── Instantiate handler ───────────────────────────────────────────────────── + +static ::xla::ffi::ErrorOr> EpInstantiateImpl() { + auto state = std::make_unique(); + try { + state->resources = AcquireEpResources(); + } catch (const std::exception& e) { + return ::xla::ffi::Unexpected( + ::xla::ffi::Error::Internal(std::string("EP instantiate failed: ") + e.what())); + } + return state; +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpInstantiateHandler, EpInstantiateImpl, FFI::BindInstantiate()); + // ── ep_prepare ──────────────────────────────────────────────────────────────── -Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, - Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { +Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type topk_idx, + Result_Type token_counts, Result_Type handle_mem, Result_Type workspace, + EpPrepareConfig config) { + (void)ep_state; // lifetime only. auto topk_dims = topk_idx.dimensions(); NVTE_CHECK(topk_dims.size() >= 2, "topk_idx must be at least 2D [..., top_k], got ndim=", topk_dims.size()); @@ -178,20 +246,22 @@ Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type t XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, FFI::Bind() - .Ctx() // stream - .Arg() // topk_idx - .Ret() // token_counts - .Ret() // handle_mem - .Ret() // workspace (FFI scratch) + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // topk_idx + .Ret() // token_counts + .Ret() // handle_mem + .Ret() // workspace (FFI scratch) .Attrs(), FFI_CudaGraph_Traits); // ── ep_dispatch ─────────────────────────────────────────────────────────────── -Error_Type EpDispatchFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type topk_idx, - Buffer_Type tokens, Buffer_Type topk_weights, Result_Type recv_tokens, - Result_Type recv_topk_weights, Result_Type workspace, - EpDispatchConfig config) { +Error_Type EpDispatchFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type topk_idx, Buffer_Type tokens, Buffer_Type topk_weights, + Result_Type recv_tokens, Result_Type recv_topk_weights, + Result_Type workspace, EpDispatchConfig config) { + (void)ep_state; auto token_dims = tokens.dimensions(); NVTE_CHECK(token_dims.size() >= 2, "tokens must be at least 2D [..., H], got ndim=", token_dims.size()); @@ -264,21 +334,23 @@ Error_Type EpDispatchFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Typ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchHandler, EpDispatchFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // topk_idx - .Arg() // tokens - .Arg() // topk_weights - .Ret() // recv_tokens - .Ret() // recv_topk_weights - .Ret() // workspace (FFI scratch) + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // topk_idx + .Arg() // tokens + .Arg() // topk_weights + .Ret() // recv_tokens + .Ret() // recv_topk_weights + .Ret() // workspace (FFI scratch) .Attrs(), FFI_CudaGraph_Traits); // ── ep_combine ──────────────────────────────────────────────────────────────── -Error_Type EpCombineFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type expert_out, - Result_Type result, EpCombineConfig config) { +Error_Type EpCombineFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type expert_out, Result_Type result, EpCombineConfig config) { + (void)ep_state; auto eo_dims = expert_out.dimensions(); NVTE_CHECK(eo_dims.size() >= 2, "expert_out must be at least 2D [..., recv_pr, H]; got ndim=", eo_dims.size()); @@ -311,18 +383,21 @@ Error_Type EpCombineFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineHandler, EpCombineFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // expert_out - .Ret() // result + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // expert_out + .Ret() // result .Attrs(), FFI_CudaGraph_Traits); // ── ep_dispatch_bwd ─────────────────────────────────────────────────────────── -Error_Type EpDispatchBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, - Buffer_Type g_recv_topk_weights, Result_Type grad_tokens, - Result_Type grad_topk_weights, EpDispatchBwdConfig config) { +Error_Type EpDispatchBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type grad, Buffer_Type g_recv_topk_weights, + Result_Type grad_tokens, Result_Type grad_topk_weights, + EpDispatchBwdConfig config) { + (void)ep_state; auto grad_dims = grad.dimensions(); NVTE_CHECK(grad_dims.size() >= 2, "grad must be at least 2D [..., recv_pr, H]; got ndim=", grad_dims.size()); @@ -380,19 +455,22 @@ Error_Type EpDispatchBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchBwdHandler, EpDispatchBwdFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // grad (w.r.t. recv_tokens) - .Arg() // g_recv_topk_weights - .Ret() // grad_tokens - .Ret() // grad_topk_weights + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // grad (w.r.t. recv_tokens) + .Arg() // g_recv_topk_weights + .Ret() // grad_tokens + .Ret() // grad_topk_weights .Attrs(), FFI_CudaGraph_Traits); // ── ep_combine_bwd ──────────────────────────────────────────────────────────── -Error_Type EpCombineBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_Type grad, - Result_Type grad_expert_out, EpCombineBwdConfig config) { +Error_Type EpCombineBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type grad, Result_Type grad_expert_out, + EpCombineBwdConfig config) { + (void)ep_state; auto grad_dims = grad.dimensions(); NVTE_CHECK(grad_dims.size() >= 2, "grad must be at least 2D [..., H], got ndim=", grad_dims.size()); @@ -424,10 +502,11 @@ Error_Type EpCombineBwdFFI(cudaStream_t stream, Buffer_Type handle_mem, Buffer_T XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineBwdHandler, EpCombineBwdFFI, FFI::Bind() - .Ctx() // stream - .Arg() // handle_mem - .Arg() // grad (w.r.t. result) - .Ret() // grad_expert_out + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // grad (w.r.t. result) + .Ret() // grad_expert_out .Attrs(), FFI_CudaGraph_Traits); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index b34f8739ee..0304f37691 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -102,12 +102,22 @@ pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); #ifdef NVTE_WITH_NCCL_EP - // Expert Parallelism - dict["te_ep_prepare_ffi"] = EncapsulateFFI(EpPrepareHandler); - dict["te_ep_dispatch_ffi"] = EncapsulateFFI(EpDispatchHandler); - dict["te_ep_combine_ffi"] = EncapsulateFFI(EpCombineHandler); - dict["te_ep_dispatch_bwd_ffi"] = EncapsulateFFI(EpDispatchBwdHandler); - dict["te_ep_combine_bwd_ffi"] = EncapsulateFFI(EpCombineBwdHandler); + // Expert Parallelism (instantiate handler pins NCCL comm to executable lifetime). + dict["te_ep_prepare_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpPrepareHandler)); + dict["te_ep_dispatch_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpDispatchHandler)); + dict["te_ep_combine_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpCombineHandler)); + dict["te_ep_dispatch_bwd_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpDispatchBwdHandler)); + dict["te_ep_combine_bwd_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpCombineBwdHandler)); #endif // NVTE_WITH_NCCL_EP // TopK @@ -137,13 +147,15 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); m.def("get_grouped_gemm_setup_workspace_size", &nvte_get_grouped_gemm_setup_workspace_size); #ifdef NVTE_WITH_NCCL_EP - m.def("initialize_ep_communicator", &EpInitialize, pybind11::arg("unique_id_bytes"), + m.def("set_ep_bootstrap_params", &SetEpBootstrapParams, pybind11::arg("unique_id_bytes"), pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0); - m.def("shutdown_ep_communicator", &EpShutdown); + m.def("release_ep_resources", &ReleaseEpResources); m.def("ep_register_layer", &EpRegisterLayer, pybind11::arg("top_k"), pybind11::arg("dispatch_output_per_expert_alignment") = 0); + m.def("get_ep_instance_state_type_id", &GetEpInstanceStateTypeIdCapsule); + m.def("get_ep_instance_state_type_info", &GetEpInstanceStateTypeInfoCapsule); #endif // NVTE_WITH_NCCL_EP pybind11::enum_(m, "DType", pybind11::module_local()) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 40d07bc3d4..d2850defaf 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -100,7 +100,8 @@ def ep_bootstrap( f" '{ep_resource}' size ({mesh_ep_size})." ) - transformer_engine_jax.initialize_ep_communicator( + # Eager NCCL init while ranks are barrier-synced by the UID broadcast above. + transformer_engine_jax.set_ep_bootstrap_params( uid_bytes, ep_size, rank_within_group, @@ -111,18 +112,10 @@ def ep_bootstrap( max_num_sms=int(max_num_sms), ) - # Shutdown ordering: - # - Python atexit is LIFO. ep_bootstrap runs jmu.process_allgather first, - # which assumes jax.distributed.initialize() ran earlier, so JAX's - # distributed atexit hooks are already registered before this one. Ours - # therefore fires first at exit — fine, because EpShutdown only touches - # NCCL (ncclEpGroupDestroy + ncclCommDestroy) and does not depend on - # JAX's coordination service. Do not add JAX calls to EpShutdown. - # - Running before C++ static destructors avoids the cudartUnloading - # hazard; the C++ destructors are intentionally no-ops. + # Release the C++ anchor at interpreter shutdown so RAII can tear down NCCL. global _atexit_registered if not _atexit_registered: - atexit.register(transformer_engine_jax.shutdown_ep_communicator) + atexit.register(transformer_engine_jax.release_ep_resources) _atexit_registered = True tex.ep.set_ep_config( From cb44374b9f329718eab090f75dfd61284eb1b3d8 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Sat, 23 May 2026 20:02:43 +0000 Subject: [PATCH 22/36] JAX EP: expose allow_handle_mem_reloc as opt-in ep_bootstrap parameter Signed-off-by: Phuong Nguyen --- examples/jax/ep/ep_moe.py | 2 ++ tests/jax/test_multi_process_ep.py | 2 ++ transformer_engine/jax/csrc/extensions.h | 2 +- transformer_engine/jax/csrc/extensions/ep.cpp | 7 +++++-- transformer_engine/jax/csrc/extensions/pybind.cpp | 3 ++- transformer_engine/jax/ep.py | 7 +++++++ 6 files changed, 19 insertions(+), 4 deletions(-) diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index 8dcac02a04..b2f48a6ad3 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -288,6 +288,8 @@ def main(): max_tokens_per_rank=args.num_tokens, recv_capacity_per_rank=args.recv_capacity_per_rank, hidden_dim=args.hidden, + # XLA reallocates handle_mem between JIT executables. + allow_handle_mem_reloc=True, ) ( diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py index 0658ad9750..7d070fb353 100644 --- a/tests/jax/test_multi_process_ep.py +++ b/tests/jax/test_multi_process_ep.py @@ -122,6 +122,8 @@ def setUpClass(cls): max_tokens_per_rank=TOKENS_PER_DP_SHARD, recv_capacity_per_rank=cls.recv_capacity_per_rank, hidden_dim=HIDDEN_DIM, + # XLA reallocates handle_mem between JIT executables. + allow_handle_mem_reloc=True, ) # ── Bootstrap precondition ──────────────────────────────────────────── diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 62e762a5bb..9e64cf4d73 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -203,7 +203,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); // Bootstrap EP (eager NCCL comm init); anchor released by ReleaseEpResources. void SetEpBootstrapParams(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms); + int hidden_dim, int max_num_sms, int allow_handle_mem_reloc); void ReleaseEpResources(); // Register an EP layer; returns (handle_id, handle_mem_size). pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index 5dc05de0ae..39e2d8be3f 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -34,6 +34,7 @@ struct EpBootstrapParams { int max_recv_tokens_per_rank = 0; int hidden_dim = 0; int max_num_sms = 0; + int allow_handle_mem_reloc = 0; }; class EpResources { @@ -47,7 +48,8 @@ class EpResources { .max_tokens_per_rank = p.max_tokens_per_rank, .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, .hidden_dim = p.hidden_dim, - .max_num_sms = p.max_num_sms}; + .max_num_sms = p.max_num_sms, + .allow_handle_mem_reloc = p.allow_handle_mem_reloc}; try { nvte_ep_initialize(static_cast(comm_), cfg); } catch (...) { @@ -137,7 +139,7 @@ struct EpCombineBwdConfig { // synchronize via the UID broadcast). void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms) { + int hidden_dim, int max_num_sms, int allow_handle_mem_reloc) { std::string uid_str = unique_id_bytes_obj; NVTE_CHECK(static_cast(uid_str.size()) >= 128, "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); @@ -154,6 +156,7 @@ void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int g_ep_params.max_recv_tokens_per_rank = max_recv_tokens_per_rank; g_ep_params.hidden_dim = hidden_dim; g_ep_params.max_num_sms = max_num_sms; + g_ep_params.allow_handle_mem_reloc = allow_handle_mem_reloc; g_ep_params_set = true; } // Acquire outside the lock: EpResources ctor runs ncclCommInitRank which is diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 0304f37691..aeca99510a 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -150,7 +150,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("set_ep_bootstrap_params", &SetEpBootstrapParams, pybind11::arg("unique_id_bytes"), pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), - pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0); + pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0, + pybind11::arg("allow_handle_mem_reloc") = 0); m.def("release_ep_resources", &ReleaseEpResources); m.def("ep_register_layer", &EpRegisterLayer, pybind11::arg("top_k"), pybind11::arg("dispatch_output_per_expert_alignment") = 0); diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index d2850defaf..55b4ebec6c 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -42,10 +42,16 @@ def ep_bootstrap( recv_capacity_per_rank, hidden_dim, max_num_sms=0, + allow_handle_mem_reloc=False, ): """Initialize the EP communicator. Call once per process before any EP op. max_num_sms caps the SMs allotted to EP kernels (0 = auto). + + Set ``allow_handle_mem_reloc=True`` only if the caller cannot guarantee a + stable ``handle_mem`` device pointer across calls (e.g. XLA-managed + buffers reallocated between JIT executables). Default raises on + relocation so callers detect handle-aliasing bugs. """ if world_size < 2: raise ValueError( @@ -110,6 +116,7 @@ def ep_bootstrap( recv_capacity_per_rank, hidden_dim, max_num_sms=int(max_num_sms), + allow_handle_mem_reloc=int(bool(allow_handle_mem_reloc)), ) # Release the C++ anchor at interpreter shutdown so RAII can tear down NCCL. From 2012b0ad716e1a37e10252220bbc302275a619b4 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 28 May 2026 16:30:19 -0700 Subject: [PATCH 23/36] jax/ep: decorate EP ops with @compute_on("gpu_stream:collective") Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/ep.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 7d112ad5f4..26d0291124 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp from jax import dtypes, ffi +from jax.experimental.compute_on import compute_on from jax.sharding import NamedSharding, PartitionSpec import jax.tree_util as jtu @@ -876,6 +877,7 @@ def shardy_sharding_rule(*args): _HANDLE_ID_CALLSITE_CACHE = {} +@compute_on("gpu_stream:collective") def ep_prepare(topk_idx, dispatch_output_per_expert_alignment=0): """Exchange routing metadata; return ``(token_counts, EpHandle)``.""" import sys as _sys @@ -900,6 +902,7 @@ def ep_prepare(topk_idx, dispatch_output_per_expert_alignment=0): return token_counts, EpHandle(handle_mem, handle_id) +@compute_on("gpu_stream:collective") def ep_dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank): """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights, handle).""" top_k = int(topk_weights.shape[-1]) @@ -916,6 +919,7 @@ def ep_dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_ra return recv_tokens, recv_topk_weights, handle +@compute_on("gpu_stream:collective") def ep_combine_fwd(handle, expert_out, num_local_tokens, out_partition_spec=None): """Gather expert outputs back to home ranks. expert_out is pre-weighted.""" out_leading = _normalize_leading_shape(num_local_tokens) @@ -928,6 +932,7 @@ def ep_combine_fwd(handle, expert_out, num_local_tokens, out_partition_spec=None ) +@compute_on("gpu_stream:collective") def ep_dispatch_bwd( handle, grad, g_recv_topk_weights, top_k, num_local_tokens, out_partition_spec=None ): @@ -944,6 +949,7 @@ def ep_dispatch_bwd( ) +@compute_on("gpu_stream:collective") def ep_combine_bwd(handle, grad, recv_capacity_per_rank): """Backward of combine; returns grad_expert_out [num_procs, recv_capacity_per_rank, H].""" return EpCombineBwdPrimitive.outer_primitive.bind( From c04bebb09ae3cd3e1446e251352c4dd89c75f795 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 27 May 2026 18:18:51 -0700 Subject: [PATCH 24/36] ep_bootstrap: add XLA-collective fallback for UID allgather Signed-off-by: Phuong Nguyen --- transformer_engine/jax/ep.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 55b4ebec6c..b0e404972e 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -30,6 +30,35 @@ _atexit_registered = False +def _allgather_uid(uid_arr, world_size, uid_size): + """Allgather UID bytes across all processes. + + Tries ``jax.experimental.multihost_utils.process_allgather`` first; + falls back to an XLA collective (process-local sharded global array + replicated via ``jax.jit``) when the multihost helper returns a + short buffer, which has been observed under some launchers. + """ + try: + gathered = jmu.process_allgather(uid_arr, tiled=True) + if gathered.size == world_size * uid_size: + return np.asarray(gathered).reshape(world_size, uid_size) + except Exception: # pylint: disable=broad-except + pass + devices = np.asarray(jax.devices()) + if devices.size != world_size: + raise RuntimeError( + f"_allgather_uid fallback expected {world_size} global devices," + f" got {devices.size}." + ) + mesh = jax.sharding.Mesh(devices, ("_uid_all",)) + sharded = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("_uid_all", None)) + replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + local = np.asarray(uid_arr).reshape(1, uid_size) + g_in = jax.make_array_from_process_local_data(sharded, local, (world_size, uid_size)) + g_out = jax.jit(lambda x: x, out_shardings=replicated)(g_in) + return np.asarray(g_out).reshape(world_size, uid_size) + + # ── Bootstrap ──────────────────────────────────────────────────────────────── @@ -90,7 +119,7 @@ def ep_bootstrap( uid_bytes = bytes(UID_SIZE) uid_arr = jnp.frombuffer(uid_bytes, dtype=jnp.uint8) - all_uids = jmu.process_allgather(uid_arr).reshape(world_size, UID_SIZE) + all_uids = _allgather_uid(uid_arr, world_size, UID_SIZE) uid_bytes = bytes(np.asarray(all_uids[dp_color * ep_size]).tolist()) ep_resource = global_mesh_resource().ep_resource From 141558051ebaac7819fba9e6d7ccb4245e883f9d Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 29 May 2026 12:25:06 -0700 Subject: [PATCH 25/36] jax/ep: introduce per-layer EpHandle, drop callsite-frame handle_id cache Signed-off-by: Phuong Nguyen --- examples/jax/ep/ep_moe.py | 15 +-- tests/jax/test_multi_process_ep.py | 122 ++++++++++++++------ transformer_engine/jax/cpp_extensions/ep.py | 107 ++++++++--------- transformer_engine/jax/ep.py | 107 +++++++---------- 4 files changed, 184 insertions(+), 167 deletions(-) diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index b2f48a6ad3..dae3710526 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -14,7 +14,7 @@ import numpy as np from jax.sharding import Mesh, NamedSharding, PartitionSpec -from transformer_engine.jax.ep import ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.ep import ep_bootstrap, ep_make_handle, ep_dispatch, ep_combine from transformer_engine.jax.sharding import MeshResource, global_shard_guard @@ -199,6 +199,7 @@ def _moe_step(args, topk_idx, tokens, topk_w, kernels): kernel_spec = PartitionSpec("ep", None, None, None) kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:]) + ep_handle = ep_make_handle(args.top_k, dispatch_output_per_expert_alignment=16) @jax.jit def step(topk_idx, tokens, topk_w, local_kernels): @@ -208,20 +209,16 @@ def step(topk_idx, tokens, topk_w, local_kernels): local_kernels = jax.lax.with_sharding_constraint( local_kernels, NamedSharding(mesh, kernel_spec) ) - slots_per_expert = args.recv_capacity_per_rank // NLE - recv_tokens, recv_topk_w, handle, _tc = ep_dispatch( - topk_idx, - tokens, - topk_w, - args.recv_capacity_per_rank, - dispatch_output_per_expert_alignment=slots_per_expert, + recv_tokens, recv_topk_w, handle_mem, _tc = ep_dispatch( + ep_handle, topk_idx, tokens, topk_w, args.recv_capacity_per_rank ) recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3)) recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2)) expert_out = _batched_expert_linear(recv_tokens, local_kernels, NLE, dp_size, ep_size) expert_out = jax.lax.with_sharding_constraint(expert_out, NamedSharding(mesh, ep3)) return ep_combine( - handle, + ep_handle, + handle_mem, _tc, expert_out, recv_topk_w, diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py index 7d070fb353..abdbcd32ec 100644 --- a/tests/jax/test_multi_process_ep.py +++ b/tests/jax/test_multi_process_ep.py @@ -29,7 +29,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from transformer_engine.jax.sharding import MeshResource, global_shard_guard -from transformer_engine.jax.ep import ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.ep import ep_bootstrap, ep_make_handle, ep_dispatch, ep_combine from transformer_engine.jax.cpp_extensions.ep import ( ep_prepare, ep_dispatch_fwd, @@ -125,6 +125,8 @@ def setUpClass(cls): # XLA reallocates handle_mem between JIT executables. allow_handle_mem_reloc=True, ) + # One handle key shared by all single-layer tests below. + cls.hk = ep_make_handle(TOP_K) # ── Bootstrap precondition ──────────────────────────────────────────── @@ -204,29 +206,76 @@ def _make_random_inputs(self, seed=42, nonuniform=True): # ── Individual primitives (cpp_extensions level) ────────────────────── - def test_two_prepares_distinct_handle_ids(self): - """Two ep_prepare sites with matching (top_k, alignment) must produce - distinct handle_ids — distinct logical layers cannot share a - HandleEntry. Verified by tracing through jit so the primitive's - outer_primitive.bind path is exercised.""" + def test_two_handles_distinct_ids(self): + """Two ``ep_make_handle`` calls must yield distinct ``handle_id``s; + distinct logical layers cannot share a HandleEntry. Verified through a + jit so each ``ep_prepare`` bind path is exercised.""" _T, topk_idx, _tokens, _w = self._make_identity_inputs() - captured: list = [] + ka, kb = ep_make_handle(TOP_K), ep_make_handle(TOP_K) dp_spec = PartitionSpec(("dp", "ep"), None) with self.mesh, global_shard_guard(self.mr): idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) @jax.jit def run(idx): - _tc_a, ha = ep_prepare(idx) - _tc_b, hb = ep_prepare(idx) - captured.append((ha.handle_id, hb.handle_id)) - return ha.handle_mem, hb.handle_mem + _tc_a, ha = ep_prepare(idx, ka) + _tc_b, hb = ep_prepare(idx, kb) + return ha, hb hm_a, hm_b = run(idx_s) hm_a.block_until_ready() hm_b.block_until_ready() - id_a, id_b = captured[0] - self.assertNotEqual(id_a, id_b, "two ep_prepare calls returned the same handle_id") + self.assertNotEqual(ka.handle_id, kb.handle_id) + + def test_two_layer_dispatch_no_handle_aliasing(self): + """Two ep_dispatch calls in one jit with distinct ``EpHandle``s must + not clobber each other's routing state. Different inputs per layer with + identity routing + uniform weights => both recv buffers must independently + identity-round-trip via ep_combine.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + tokens_b = (tokens.astype(jnp.float32) * -1.0 + 0.25).astype(tokens.dtype) + ka, kb = ep_make_handle(TOP_K), ep_make_handle(TOP_K) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + ta = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + tb = jax.lax.with_sharding_constraint(tokens_b, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + + def one_layer(hk, idx, toks, w_): + recv_t, recv_w, hm, tc = ep_dispatch( + hk, idx, toks, w_, self.recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_spec_3d)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_spec_2d)) + return ep_combine( + hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) + ) + + @jax.jit + def run(idx, ta_, tb_, w_): + return one_layer(ka, idx, ta_, w_), one_layer(kb, idx, tb_, w_) + + out_a, out_b = run(idx_s, ta, tb, w) + out_a.block_until_ready() + out_b.block_until_ready() + out_a_g = jmu.process_allgather(out_a, tiled=True) + out_b_g = jmu.process_allgather(out_b, tiled=True) + + self.assertNotEqual(ka.handle_id, kb.handle_id) + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_a_g.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, rtol=5e-2, + ) + np.testing.assert_allclose( + np.asarray(out_b_g.astype(jnp.float32)), + np.asarray(tokens_b.astype(jnp.float32)), + atol=5e-2, rtol=5e-2, + ) def test_primitive_prepare(self): """ep_prepare returns the expected shapes and a valid handle id.""" @@ -238,8 +287,8 @@ def test_primitive_prepare(self): @jax.jit def run(idx): - tc, handle = ep_prepare(idx) - return tc, handle.handle_mem + tc, hm = ep_prepare(idx, self.hk) + return tc, hm tc, hm = run(idx_s) tc.block_until_ready() @@ -260,9 +309,9 @@ def _run_identity_round_trip(self, nonuniform): @jax.jit def run(idx, toks, w): - _tc, handle = ep_prepare(idx) - recv_t, recv_w, handle = ep_dispatch_fwd( - handle, idx, toks, w, self.recv_capacity_per_rank + _tc, hm = ep_prepare(idx, self.hk) + recv_t, recv_w = ep_dispatch_fwd( + self.hk, hm, idx, toks, w, self.recv_capacity_per_rank ) recv_t = jax.lax.with_sharding_constraint( recv_t, NamedSharding(self.mesh, ep_spec_3d) @@ -279,7 +328,8 @@ def run(idx, toks, w): weighted, NamedSharding(self.mesh, ep_spec_3d) ) out = ep_combine_fwd( - handle, weighted, T_global, out_partition_spec=(("dp", "ep"), None) + self.hk, hm, weighted, T_global, + out_partition_spec=(("dp", "ep"), None), ) return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) @@ -322,7 +372,7 @@ def loss_fn(toks): toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) - recv_t, recv_w, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) recv_t = jax.lax.with_sharding_constraint( recv_t, NamedSharding(self.mesh, ep_spec_3d) ) @@ -330,7 +380,7 @@ def loss_fn(toks): recv_w, NamedSharding(self.mesh, ep_spec_2d) ) out = ep_combine( - handle, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) + self.hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) ) return 0.5 * (out.astype(jnp.float32) ** 2).sum() @@ -370,11 +420,12 @@ def test_dispatch_combine_3d_input_output(self): @jax.jit def run(idx, toks, w): - recv_t, recv_w, handle, _tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) out = ep_combine( - handle, + self.hk, + hm, _tc, recv_t, recv_w, @@ -412,11 +463,12 @@ def test_dispatch_combine_dp_only_first_dim(self): @jax.jit def run(idx, toks, w): - recv_t, recv_w, handle, _tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) out = ep_combine( - handle, + self.hk, + hm, _tc, recv_t, recv_w, @@ -457,8 +509,8 @@ def loss_fn(toks): toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) - recv_tokens, _recv_w, _handle, _tc = ep_dispatch( - idx, toks, w, self.recv_capacity_per_rank + recv_tokens, _recv_w, _hm, _tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank ) recv_tokens = jax.lax.with_sharding_constraint( recv_tokens, NamedSharding(self.mesh, ep_spec_3d) @@ -503,13 +555,13 @@ def loss_fn(eo): toks = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) - _recv_tokens, recv_w, handle, tc = ep_dispatch( - idx, toks, w, self.recv_capacity_per_rank + _recv_tokens, recv_w, hm, tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank ) recv_w = jax.lax.with_sharding_constraint( recv_w, NamedSharding(self.mesh, PartitionSpec(("dp", "ep"), None)) ) - combined = ep_combine(handle, tc, eo, recv_w, T_global) + combined = ep_combine(self.hk, hm, tc, eo, recv_w, T_global) # Pin combined to dp-sharded so autodiff transpose feeds # ep_combine_bwd a per-shard cotangent. combined = jax.lax.with_sharding_constraint( @@ -549,7 +601,7 @@ def loss_fn(idx_in, tok_in, w_in): tok_in = jax.lax.with_sharding_constraint(tok_in, NamedSharding(self.mesh, dp_spec)) w_in = jax.lax.with_sharding_constraint(w_in, NamedSharding(self.mesh, dp_spec)) _recv_t, recv_w, _h, _tc = ep_dispatch( - idx_in, tok_in, w_in, self.recv_capacity_per_rank + self.hk, idx_in, tok_in, w_in, self.recv_capacity_per_rank ) # Per-slot index scale ⇒ each slot's contribution differs. scale = jnp.asarray( @@ -589,7 +641,7 @@ def run(idx, toks, w): idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) - recv_t, recv_w, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) recv_t = jax.lax.with_sharding_constraint( recv_t, NamedSharding(self.mesh, ep_spec_3d) ) @@ -597,7 +649,7 @@ def run(idx, toks, w): recv_w, NamedSharding(self.mesh, ep_spec_2d) ) out = ep_combine( - handle, tc, recv_t, recv_w, T_dp, out_sharding=(("dp", "ep"), None) + self.hk, hm, tc, recv_t, recv_w, T_dp, out_sharding=(("dp", "ep"), None) ) return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) @@ -634,9 +686,9 @@ def fwd(eo, toks, idx, w): toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) - _rt, rw, handle, tc = ep_dispatch(idx, toks, w, self.recv_capacity_per_rank) + _rt, rw, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d)) - combined = ep_combine(handle, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)) + combined = ep_combine(self.hk, hm, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)) return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec)) # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 26d0291124..8fb0d90f8a 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -8,7 +8,7 @@ Sharded compound ``(dp_resource, ep_resource)`` when DP is set, else ``ep_resource`` alone. - EpDispatch inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``; only the first - dim may be sharded, with axis ∈ {ep, (dp, ep), dp, None}. Trailing dims + dim may be sharded, with axis in {ep, (dp, ep), dp, None}. Trailing dims must be replicated. ``dp`` alone gets ``ep`` folded in locally. - EpCombine output sharding comes from ``out_sharding`` or defaults to the compound ``(dp, ep)`` axis on the leading dim. @@ -21,7 +21,6 @@ from jax import dtypes, ffi from jax.experimental.compute_on import compute_on from jax.sharding import NamedSharding, PartitionSpec -import jax.tree_util as jtu import transformer_engine_jax from .base import BasePrimitive, register_primitive @@ -34,6 +33,7 @@ "get_ep_config", "get_ep_num_local_experts", "ep_allocate_handle_id", + "ep_make_handle", "ep_prepare", "ep_dispatch_fwd", "ep_combine_fwd", @@ -42,24 +42,6 @@ ] -# Routing-state container threaded through dispatch/combine/*_bwd. -@jtu.register_pytree_node_class -class EpHandle: - def __init__(self, handle_mem, handle_id): - self.handle_mem = handle_mem - self.handle_id = int(handle_id) - - def tree_flatten(self): - return (self.handle_mem,), (self.handle_id,) - - @classmethod - def tree_unflatten(cls, aux, children): - return cls(children[0], aux[0]) - - def __repr__(self): - return f"EpHandle(handle_id={self.handle_id})" - - # ── Module-level EP config ────────────────────────────────────────────────── @@ -101,11 +83,7 @@ def get_ep_num_local_experts() -> int: def ep_allocate_handle_id(top_k: int, dispatch_output_per_expert_alignment: int = 0) -> int: - """Reserve a fresh handle_id for an EP layer. - - Distinct logical layers must each call this — sharing a handle_id across - layers corrupts the routing state, even when (top_k, alignment) match. - """ + """Low-level: reserve a fresh handle_id. Prefer ``ep_make_handle``.""" handle_id, handle_mem_size = transformer_engine_jax.ep_register_layer( int(top_k), int(dispatch_output_per_expert_alignment) ) @@ -114,6 +92,36 @@ def ep_allocate_handle_id(top_k: int, dispatch_output_per_expert_alignment: int return handle_id +@dataclass(frozen=True) +class EpHandle: + """Per-layer EP config + routing-slot identity. + + Carries static layer config and a ``handle_id`` that pins the C++ routing + slot across re-traces. Allocate via ``ep_make_handle``; distinct layers + must hold distinct handles. + """ + + handle_id: int + top_k: int + dispatch_output_per_expert_alignment: int = 0 + + +def ep_make_handle(top_k: int, dispatch_output_per_expert_alignment: int = 0) -> EpHandle: + """Allocate a per-layer EP handle. + + Call once per logical MoE layer at model init (outside ``jax.jit``), then + pass the same handle into every ``ep_dispatch`` / ``ep_combine`` for that + layer. The handle's ``handle_id`` survives re-traces, ``jax.checkpoint`` + rematerialization, and separate inference/training compilations. + """ + handle_id = ep_allocate_handle_id(top_k, dispatch_output_per_expert_alignment) + return EpHandle( + handle_id=handle_id, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + ) + + def _ep_handle_mem_size(handle_id: int) -> int: """Return the handle_mem byte size for an id from ep_allocate_handle_id.""" try: @@ -874,40 +882,23 @@ def shardy_sharding_rule(*args): # ── Public-ish helpers (used by jax/ep.py) ────────────────────────────────── -_HANDLE_ID_CALLSITE_CACHE = {} - - @compute_on("gpu_stream:collective") -def ep_prepare(topk_idx, dispatch_output_per_expert_alignment=0): - """Exchange routing metadata; return ``(token_counts, EpHandle)``.""" - import sys as _sys - - top_k = int(topk_idx.shape[-1]) - alignment = int(dispatch_output_per_expert_alignment) - # Cache handle_id by caller (file:lineno, top_k, alignment): JAX re-traces - # the same call site (e.g. custom_vjp fwd vs primal) and the resulting - # EpHandles must share the same id to compare equal in pytree aux. - f = _sys._getframe(1) - cache_key = (f.f_code.co_filename, f.f_lineno, top_k, alignment) - handle_id = _HANDLE_ID_CALLSITE_CACHE.get(cache_key) - if handle_id is None: - handle_id = ep_allocate_handle_id(top_k, alignment) - _HANDLE_ID_CALLSITE_CACHE[cache_key] = handle_id - token_counts, handle_mem = EpPreparePrimitive.outer_primitive.bind( +def ep_prepare(topk_idx, handle): + """Exchange routing metadata for ``handle``; return ``(token_counts, handle_mem)``.""" + return EpPreparePrimitive.outer_primitive.bind( topk_idx, - handle_id=handle_id, - dispatch_output_per_expert_alignment=alignment, + handle_id=handle.handle_id, + dispatch_output_per_expert_alignment=handle.dispatch_output_per_expert_alignment, is_outer=True, ) - return token_counts, EpHandle(handle_mem, handle_id) @compute_on("gpu_stream:collective") -def ep_dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank): - """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights, handle).""" +def ep_dispatch_fwd(handle, handle_mem, topk_idx, tokens, topk_weights, recv_capacity_per_rank): + """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights).""" top_k = int(topk_weights.shape[-1]) - recv_tokens, recv_topk_weights = EpDispatchPrimitive.outer_primitive.bind( - handle.handle_mem, + return EpDispatchPrimitive.outer_primitive.bind( + handle_mem, topk_idx, tokens, topk_weights, @@ -916,15 +907,14 @@ def ep_dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_ra top_k=top_k, is_outer=True, ) - return recv_tokens, recv_topk_weights, handle @compute_on("gpu_stream:collective") -def ep_combine_fwd(handle, expert_out, num_local_tokens, out_partition_spec=None): +def ep_combine_fwd(handle, handle_mem, expert_out, num_local_tokens, out_partition_spec=None): """Gather expert outputs back to home ranks. expert_out is pre-weighted.""" out_leading = _normalize_leading_shape(num_local_tokens) return EpCombinePrimitive.outer_primitive.bind( - handle.handle_mem, + handle_mem, expert_out, handle_id=handle.handle_id, out_leading_shape=out_leading, @@ -934,12 +924,13 @@ def ep_combine_fwd(handle, expert_out, num_local_tokens, out_partition_spec=None @compute_on("gpu_stream:collective") def ep_dispatch_bwd( - handle, grad, g_recv_topk_weights, top_k, num_local_tokens, out_partition_spec=None + handle, handle_mem, grad, g_recv_topk_weights, top_k, num_local_tokens, + out_partition_spec=None, ): """Backward of dispatch; returns (grad_tokens, grad_topk_weights).""" out_leading = _normalize_leading_shape(num_local_tokens) return EpDispatchBwdPrimitive.outer_primitive.bind( - handle.handle_mem, + handle_mem, grad, g_recv_topk_weights, handle_id=handle.handle_id, @@ -950,10 +941,10 @@ def ep_dispatch_bwd( @compute_on("gpu_stream:collective") -def ep_combine_bwd(handle, grad, recv_capacity_per_rank): +def ep_combine_bwd(handle, handle_mem, grad, recv_capacity_per_rank): """Backward of combine; returns grad_expert_out [num_procs, recv_capacity_per_rank, H].""" return EpCombineBwdPrimitive.outer_primitive.bind( - handle.handle_mem, + handle_mem, grad, handle_id=handle.handle_id, recv_capacity_per_rank=recv_capacity_per_rank, diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index b0e404972e..62bd6691fd 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -14,14 +14,16 @@ import transformer_engine_jax import transformer_engine.jax.cpp_extensions as tex -from transformer_engine.jax.cpp_extensions.ep import EpHandle from transformer_engine.jax.sharding import global_mesh_resource, get_mesh_axis_size ep_prepare = tex.ep_prepare +ep_make_handle = tex.ep_make_handle +EpHandle = tex.EpHandle __all__ = [ "EpHandle", "ep_bootstrap", + "ep_make_handle", "ep_prepare", "ep_dispatch", "ep_combine", @@ -171,64 +173,34 @@ def ep_bootstrap( # ── ep_dispatch (custom_vjp) ───────────────────────────────────────────────── -@partial(jax.custom_vjp, nondiff_argnums=(3, 4)) -def ep_dispatch( - topk_idx, - tokens, - topk_weights, - recv_capacity_per_rank, - dispatch_output_per_expert_alignment=0, -): +@partial(jax.custom_vjp, nondiff_argnums=(0, 4)) +def ep_dispatch(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank): """Scatter tokens and weights to expert ranks. - Inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``. Only the leading dim may - be sharded — axis ∈ {ep, (dp, ep), dp, None}; trailing dims replicated. - - Args: - topk_idx: ``[..., top_k]`` int32/int64 routing indices. - tokens: ``[..., H]`` activations (matching leading dims). - topk_weights: ``[..., top_k]`` float32 routing weights. - recv_capacity_per_rank: STATIC int. Per-rank recv slot count. - dispatch_output_per_expert_alignment: STATIC int. Per-expert slot - alignment; 0 disables. - - Returns: - ``(recv_tokens, recv_topk_weights, handle, token_counts)`` where - ``recv_tokens`` is 3D ``[num_procs, recv_capacity_per_rank, H]`` - sharded ``(("dp","ep"), None, None)`` (or ``("ep", None, None)`` if - DP is unset), and ``recv_topk_weights`` is 2D - ``[num_procs, recv_capacity_per_rank]`` similarly sharded. Pass - ``handle`` to the matching ``ep_combine``. + ``handle`` is a per-layer ``EpHandle`` from ``ep_make_handle``; distinct + layers must hold distinct handles. Inputs are 2D ``[T, H]`` or 3D + ``[B, S, H]`` with only the leading dim sharded + (axis in {ep, (dp, ep), dp, None}). Returns + ``(recv_tokens, recv_topk_weights, handle_mem, token_counts)``; pass + ``handle_mem`` and ``token_counts`` to the matching ``ep_combine``. """ - return _dispatch_fwd( - topk_idx, - tokens, - topk_weights, - recv_capacity_per_rank, - dispatch_output_per_expert_alignment, - )[0] + return _dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank)[0] -def _dispatch_fwd( - topk_idx, - tokens, - topk_weights, - recv_capacity_per_rank, - dispatch_output_per_expert_alignment, -): +def _dispatch_fwd(handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank): top_k = int(topk_weights.shape[-1]) - token_counts, handle = tex.ep_prepare(topk_idx, dispatch_output_per_expert_alignment) - recv_tokens, recv_topk_weights, handle = tex.ep_dispatch_fwd( - handle, topk_idx, tokens, topk_weights, recv_capacity_per_rank + token_counts, handle_mem = tex.ep_prepare(topk_idx, handle) + recv_tokens, recv_topk_weights = tex.ep_dispatch_fwd( + handle, handle_mem, topk_idx, tokens, topk_weights, recv_capacity_per_rank ) out_leading = tuple(tokens.shape[:-1]) - primal = (recv_tokens, recv_topk_weights, handle, token_counts) - return primal, (handle, out_leading, top_k) + primal = (recv_tokens, recv_topk_weights, handle_mem, token_counts) + return primal, (handle_mem, out_leading, top_k) -def _dispatch_bwd(recv_capacity_per_rank, dispatch_output_per_expert_alignment, res, g_outputs): - del recv_capacity_per_rank, dispatch_output_per_expert_alignment - handle, out_leading, top_k = res +def _dispatch_bwd(handle, recv_capacity_per_rank, res, g_outputs): + del recv_capacity_per_rank + handle_mem, out_leading, top_k = res # Re-pin cotangent sharding: XLA transpose can drop the EP axis on a # single-fwd-output cotangent, landing a global tensor in the FFI. gsr = global_mesh_resource() @@ -242,7 +214,7 @@ def _dispatch_bwd(recv_capacity_per_rank, dispatch_output_per_expert_alignment, g_outputs[1], jax.sharding.PartitionSpec(leading, None) ) grad_tokens, grad_topk_weights = tex.ep_dispatch_bwd( - handle, g_recv_tokens, g_recv_topk_weights, top_k, out_leading + handle, handle_mem, g_recv_tokens, g_recv_topk_weights, top_k, out_leading ) return (None, grad_tokens, grad_topk_weights) @@ -253,31 +225,33 @@ def _dispatch_bwd(recv_capacity_per_rank, dispatch_output_per_expert_alignment, # ── ep_combine (custom_vjp) ────────────────────────────────────────────────── -@partial(jax.custom_vjp, nondiff_argnums=(4, 5)) +@partial(jax.custom_vjp, nondiff_argnums=(0, 5, 6)) def ep_combine( - handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding=None + handle, handle_mem, token_counts, expert_out, recv_topk_weights, + num_local_tokens, out_sharding=None, ): """Reduce weighted expert outputs back to source ranks. Args: - handle: ``EpHandle`` from a matching ``ep_dispatch`` call. + handle: ``EpHandle`` matching the ``ep_dispatch`` call. + handle_mem: Routing-state buffer returned by ``ep_dispatch``. token_counts: ``[num_procs, num_local_experts]`` int32 (passed through). expert_out: ``[num_procs, recv_capacity_per_rank, H]`` post-FFN activations. recv_topk_weights: ``[num_procs, recv_capacity_per_rank]`` float32 weights returned by ``ep_dispatch``. - num_local_tokens: STATIC int or tuple. int → 2D output ``[T, H]``; - tuple → N-D output ``[*tuple, H]``. + num_local_tokens: STATIC int or tuple. int -> 2D output ``[T, H]``; + tuple -> N-D output ``[*tuple, H]``. out_sharding: STATIC optional ``PartitionSpec`` tuple for the output. Defaults to ``(("dp","ep"), *None)`` when - DP is set, else ``("ep", *None)``. Pass a custom - spec to override; only the leading dim may be - sharded. + DP is set, else ``("ep", *None)``. Only the leading + dim may be sharded. Returns: ``[..., H]`` combined output shaped per ``num_local_tokens``. """ return _combine_fwd( - handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding + handle, handle_mem, token_counts, expert_out, recv_topk_weights, + num_local_tokens, out_sharding, )[0] @@ -287,18 +261,21 @@ def _make_valid_mask(recv_topk_weights, dtype): def _combine_fwd( - handle, token_counts, expert_out, recv_topk_weights, num_local_tokens, out_sharding + handle, handle_mem, token_counts, expert_out, recv_topk_weights, + num_local_tokens, out_sharding, ): del token_counts w = recv_topk_weights[..., None] mask = _make_valid_mask(recv_topk_weights, jnp.float32) weighted = (expert_out.astype(jnp.float32) * w * mask).astype(expert_out.dtype) - result = tex.ep_combine_fwd(handle, weighted, num_local_tokens, out_partition_spec=out_sharding) - return result, (handle, recv_topk_weights, expert_out) + result = tex.ep_combine_fwd( + handle, handle_mem, weighted, num_local_tokens, out_partition_spec=out_sharding + ) + return result, (handle_mem, recv_topk_weights, expert_out) -def _combine_bwd(_num_local_tokens, _out_sharding, res, g_result): - handle, recv_topk_weights, expert_out = res +def _combine_bwd(handle, _num_local_tokens, _out_sharding, res, g_result): + handle_mem, recv_topk_weights, expert_out = res # expert_out is [..., recv_pr, H]; pull recv_pr from the second-to-last dim. recv_capacity_per_rank = expert_out.shape[-2] # Re-pin cotangent sharding: same XLA-transpose workaround as _dispatch_bwd. @@ -316,7 +293,7 @@ def _combine_bwd(_num_local_tokens, _out_sharding, res, g_result): ) if spec is not None: g_result = jax.lax.with_sharding_constraint(g_result, spec) - grad_weighted = tex.ep_combine_bwd(handle, g_result, recv_capacity_per_rank) + grad_weighted = tex.ep_combine_bwd(handle, handle_mem, g_result, recv_capacity_per_rank) w = recv_topk_weights[..., None] mask = _make_valid_mask(recv_topk_weights, jnp.float32) grad_weighted_f32 = grad_weighted.astype(jnp.float32) From 0eee8b8fc82a2b50a8542477f008478a3dfa5e88 Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 3 Jun 2026 16:26:29 -0700 Subject: [PATCH 26/36] [JAX] EP: wire NVTEEpGroupConfig.max_token_dtype through bootstrap PR #3034 commit 9b225cbe added a required NVTEEpGroupConfig.max_token_dtype field. The C++ backend (ep_backend.cpp:349) enforces typeToSize(tok_dtype) <= typeToSize(max_token_dtype) at every dispatch, and the field is also used at group create to size the NCCL EP staging buffers (ep_backend.cpp:221-222). PR #3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written before this field existed and never set it, so any JAX EP group landed with the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from JAX then failed immediately with: tokens dtype (6) wider than group max_token_dtype (0) This commit threads max_token_dtype end-to-end: - transformer_engine/jax/csrc/extensions.h update SetEpBootstrapParams declaration to match the new arity. - transformer_engine/jax/csrc/extensions/ep.cpp add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams; forward it into NVTEEpGroupConfig in the EpResources ctor. - transformer_engine/jax/csrc/extensions/pybind.cpp add the matching pybind11::arg("max_token_dtype") = 0. - transformer_engine/jax/ep.py add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to NVTEDType int, forward to the C++ setter. Carried on the te-ep-fixes branch until PR #3036 exposes the field upstream. See PR #3034 (commit 9b225cbe, ep.h:43) for the field definition. --- transformer_engine/jax/csrc/extensions.h | 3 +- transformer_engine/jax/csrc/extensions/ep.cpp | 8 +++-- .../jax/csrc/extensions/pybind.cpp | 2 +- transformer_engine/jax/ep.py | 31 +++++++++++++++++++ 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 9e64cf4d73..d6392819c0 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -203,7 +203,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); // Bootstrap EP (eager NCCL comm init); anchor released by ReleaseEpResources. void SetEpBootstrapParams(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms, int allow_handle_mem_reloc); + int hidden_dim, int max_num_sms, int allow_handle_mem_reloc, + int max_token_dtype); void ReleaseEpResources(); // Register an EP layer; returns (handle_id, handle_mem_size). pybind11::tuple EpRegisterLayer(int top_k, size_t dispatch_output_per_expert_alignment); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index 39e2d8be3f..84f24d75bf 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -35,6 +35,7 @@ struct EpBootstrapParams { int hidden_dim = 0; int max_num_sms = 0; int allow_handle_mem_reloc = 0; + int max_token_dtype = 0; }; class EpResources { @@ -49,7 +50,8 @@ class EpResources { .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, .hidden_dim = p.hidden_dim, .max_num_sms = p.max_num_sms, - .allow_handle_mem_reloc = p.allow_handle_mem_reloc}; + .allow_handle_mem_reloc = p.allow_handle_mem_reloc, + .max_token_dtype = static_cast(p.max_token_dtype)}; try { nvte_ep_initialize(static_cast(comm_), cfg); } catch (...) { @@ -139,7 +141,8 @@ struct EpCombineBwdConfig { // synchronize via the UID broadcast). void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, - int hidden_dim, int max_num_sms, int allow_handle_mem_reloc) { + int hidden_dim, int max_num_sms, int allow_handle_mem_reloc, + int max_token_dtype) { std::string uid_str = unique_id_bytes_obj; NVTE_CHECK(static_cast(uid_str.size()) >= 128, "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); @@ -157,6 +160,7 @@ void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int g_ep_params.hidden_dim = hidden_dim; g_ep_params.max_num_sms = max_num_sms; g_ep_params.allow_handle_mem_reloc = allow_handle_mem_reloc; + g_ep_params.max_token_dtype = max_token_dtype; g_ep_params_set = true; } // Acquire outside the lock: EpResources ctor runs ncclCommInitRank which is diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index aeca99510a..6020a228e3 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -151,7 +151,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms") = 0, - pybind11::arg("allow_handle_mem_reloc") = 0); + pybind11::arg("allow_handle_mem_reloc") = 0, pybind11::arg("max_token_dtype") = 0); m.def("release_ep_resources", &ReleaseEpResources); m.def("ep_register_layer", &EpRegisterLayer, pybind11::arg("top_k"), pybind11::arg("dispatch_output_per_expert_alignment") = 0); diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 62bd6691fd..17d00bef87 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -64,6 +64,30 @@ def _allgather_uid(uid_arr, world_size, uid_size): # ── Bootstrap ──────────────────────────────────────────────────────────────── +_TE_DTYPE_FOR_NUMPY = { + np.dtype(np.uint8): transformer_engine_jax.DType.kByte, + np.dtype(np.int32): transformer_engine_jax.DType.kInt32, + np.dtype(np.int64): transformer_engine_jax.DType.kInt64, + np.dtype(np.float32): transformer_engine_jax.DType.kFloat32, + np.dtype(np.float16): transformer_engine_jax.DType.kFloat16, +} + + +def _to_te_dtype_int(dtype): + """Map jax/numpy dtype -> NVTEDType int. bf16 / fp8 / fp4 handled explicitly.""" + if dtype is None: + return int(transformer_engine_jax.DType.kByte) + if dtype == jnp.bfloat16: + return int(transformer_engine_jax.DType.kBFloat16) + np_dtype = np.dtype(dtype) + if np_dtype in _TE_DTYPE_FOR_NUMPY: + return int(_TE_DTYPE_FOR_NUMPY[np_dtype]) + raise ValueError( + f"ep_bootstrap: unsupported max_token_dtype={dtype!r}; supported = " + "uint8 / int32 / int64 / float32 / float16 / bfloat16." + ) + + def ep_bootstrap( world_size, rank, @@ -74,6 +98,7 @@ def ep_bootstrap( hidden_dim, max_num_sms=0, allow_handle_mem_reloc=False, + max_token_dtype=None, ): """Initialize the EP communicator. Call once per process before any EP op. @@ -83,6 +108,11 @@ def ep_bootstrap( stable ``handle_mem`` device pointer across calls (e.g. XLA-managed buffers reallocated between JIT executables). Default raises on relocation so callers detect handle-aliasing bugs. + + ``max_token_dtype`` is the widest token dtype the group will dispatch + (sizes NCCL EP staging buffers at group create). Pass a jax/numpy + dtype, e.g. ``jnp.bfloat16``. Default ``None`` keeps the legacy ``kByte`` + behavior, which only accepts 1-byte tensors. """ if world_size < 2: raise ValueError( @@ -148,6 +178,7 @@ def ep_bootstrap( hidden_dim, max_num_sms=int(max_num_sms), allow_handle_mem_reloc=int(bool(allow_handle_mem_reloc)), + max_token_dtype=_to_te_dtype_int(max_token_dtype), ) # Release the C++ anchor at interpreter shutdown so RAII can tear down NCCL. From 10f4b1c7d833f975b5ba2f9bf47a521a8f9e77a8 Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 2 Jun 2026 16:19:50 -0700 Subject: [PATCH 27/36] [JAX] MoE: enforce (outer_dp, ep) ordering for TE EP compatibility [JAX] MoE: soft re-pin inbound activations sharding at moe() entry [JAX] MoE: scope gate_logits 2D reshape to topk primitive call [JAX] MoE: add apply_topk_weights_early flag (TE EP backend only) [JAX] MoE: stack wi_0/wi_1 on new axis (4D) instead of concat Signed-off-by: tdophung --- transformer_engine/jax/flax/moe.py | 3 + transformer_engine/jax/moe.py | 150 ++++++++++++++++++++--------- 2 files changed, 105 insertions(+), 48 deletions(-) diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 91346a7a48..b5a4afc2ad 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -147,6 +147,8 @@ class _MoEBlock(TransformerEngineBase): permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX _align_size: int = 0 + apply_topk_weights_early: bool = False + # Dtypes / init / misc dtype: DType = jnp.float32 kernel_init: Optional[Initializer] = None @@ -273,6 +275,7 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: permutation_backend=self.permutation_backend, align_size=self._align_size, gate_inside_vjp=True, + apply_topk_weights_early=self.apply_topk_weights_early, ep_axis=ep_axis, data_parallelism_axes=self.data_parallelism_axes, input_axes=self.input_axes, diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 2a1c818cb3..4479b9f176 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -53,11 +53,12 @@ from enum import Enum from functools import partial from typing import Any, NewType, Optional, Tuple, Union +import warnings import jax import jax.numpy as jnp from flax import struct as flax_struct -from jax.sharding import PartitionSpec as P +from jax.sharding import NamedSharding, PartitionSpec as P from . import cpp_extensions as tex from .permutation import ( @@ -212,7 +213,7 @@ class _BodyCtx: routing_map: Any dispatch: Any # _DispatchState casted_sorted_x_lhs_trans: Any - casted_wi_rhs_trans: Any # combined [E, H, 2M] residual for fused wi_0|wi_1 bwd + casted_wi_rhs_trans: Any # stacked [E, H, 2, M] residual for fused wi_0|wi_1 bwd gate_proj_out: Any up_proj_out: Any casted_intermediate_lhs_trans: Any @@ -966,12 +967,20 @@ def _body_fwd( # pylint: disable=unused-argument num_ep: int, num_experts_local: int, recv_buffer_rows: int, + apply_topk_weights_early: bool = False, ) -> Tuple[jnp.ndarray, jnp.ndarray, dict]: """Per-shard forward body. Returns ``(output, aux_loss, ctx_dict)``. ``aux_loss`` is always materialized (zeros scalar when disabled) so the ``shard_map``'s ``out_specs`` has a static structure. """ + if apply_topk_weights_early: + # Requires row-aligned per-token weights at the FFN intermediate; + # only available on the TE EP (tex.ep_dispatch) path. + raise NotImplementedError( + "apply_topk_weights_early=True is supported only with the TE EP " + "(tex.ep_dispatch / tex.ep_combine) backend." + ) if not gate_inside_vjp: raise NotImplementedError( "gate_inside_vjp=False is deferred to a follow-up PR; for now" @@ -992,7 +1001,8 @@ def _body_fwd( # pylint: disable=unused-argument # ---------------- Stage 1: gate ---------------- gate_kernel_cast = gate_kernel.astype(x.dtype) - gate_logits = jnp.einsum("bsh,he->bse", x, gate_kernel_cast) + gate_logits = jnp.einsum("bsh,he->bse", x, gate_kernel_cast) # [B, S, E] + # tex.fused_topk_with_score_function_* requires rank-2 input. logits_2d = gate_logits.reshape(-1, num_experts) inputs_2d = x.reshape(-1, hidden) @@ -1025,7 +1035,9 @@ def _body_fwd( # pylint: disable=unused-argument if aux_loss_coeff > 0.0: if ep_active: collective_axes: Any = ( - ep_axis if not data_parallelism_axes else (ep_axis, *data_parallelism_axes) + ep_axis + if not data_parallelism_axes + else (*data_parallelism_axes, ep_axis) ) global_logits_2d = jax.lax.all_gather( logits_2d, axis_name=collective_axes, axis=0, tiled=True @@ -1100,22 +1112,17 @@ def _body_fwd( # pylint: disable=unused-argument if q_set_wo == noop_quantizer_set: wo = wo.astype(sorted_x.dtype) - # GEMM 1+2 (fused): up_proj_combined = sorted_x @ wi where - # wi := concat([wi_0, wi_1], axis=-1) -> shape [E, H, 2M] - # combined_out := sorted_x @ wi -> shape [T, 2M] - # Splitting the output back into ``gate_proj_out`` / ``up_proj_out`` - # is free (it's a slicing reshape). This collapses two grouped - # GEMMs and two grouped quantizes of ``sorted_x`` (one per kernel) - # into one of each. Bias is concatenated the same way. + # Fused gate+up projection: stack wi_0 / wi_1 on a new axis-(-2) so the + # downstream split is a slice on the (unsharded) stack axis. concat on + # axis=-1 would cross the M axis and force a reshard when M is TP-sharded. # - # FP8/MXFP8 caveat: per-expert amax is now computed over [H, 2M] - # rather than [H, M] for each of wi_0 / wi_1 separately, so the - # representable range for one of the two halves may shift slightly - # vs. the pre-fusion code. Numerics tests cover this. + # FP8/MXFP8 caveat: per-expert amax is computed over [H, 2, M] rather than + # [H, M] for each of wi_0 / wi_1 separately, so the representable range for + # one half may shift slightly vs. an unfused pair of casts. inter_M = wi_0.shape[-1] - wi_combined = jnp.concatenate([wi_0, wi_1], axis=-1) + wi_combined = jnp.stack([wi_0, wi_1], axis=-2) wi_combined_bias = ( - jnp.concatenate([wi_0_bias, wi_1_bias], axis=-1) if wi_0_bias is not None else None + jnp.stack([wi_0_bias, wi_1_bias], axis=-2) if wi_0_bias is not None else None ) casted_sorted_x = tex.grouped_quantize(sorted_x, q_set_w0.x, local_group_sizes, flatten_axis=-1) casted_wi = tex.grouped_quantize(wi_combined, q_set_w0.kernel, flatten_axis=-1) @@ -1125,8 +1132,8 @@ def _body_fwd( # pylint: disable=unused-argument contracting_dims=((1,), (1,)), bias=wi_combined_bias, ) - gate_proj_out = combined_out[..., :inter_M] - up_proj_out = combined_out[..., inter_M:] + gate_proj_out = combined_out[..., 0, :] + up_proj_out = combined_out[..., 1, :] casted_sorted_x_lhs_trans = casted_sorted_x.get_tensor(usage=TensorUsage.LHS_TRANS) casted_wi_rhs_trans = casted_wi.get_tensor(usage=TensorUsage.RHS_TRANS) if isinstance(casted_sorted_x_lhs_trans, ScaledTensor): @@ -1253,15 +1260,21 @@ def _body_bwd( # pylint: disable=unused-argument has_wo_bias: bool, has_expert_bias: bool, x_shape: Tuple[int, ...], + apply_topk_weights_early: bool = False, ) -> dict: """Per-shard backward body. Returns a dict of grads keyed identically to the ``captured`` dict consumed by :func:`_body_fwd`.""" + if apply_topk_weights_early: + raise NotImplementedError( + "apply_topk_weights_early=True is supported only with the TE EP " + "(tex.ep_dispatch / tex.ep_combine) backend." + ) if not gate_inside_vjp: raise NotImplementedError("gate_inside_vjp=False is deferred to a follow-up PR.") d_output, d_aux_loss = dy_pair # The fused FFN bwd quantizes via ``q_set_w0`` only (one quantize for - # the [E, H, 2M] fused wi tensor and one for the [T, 2M] fused dgrad), + # the [E, H, 2, M] stacked wi tensor and one for the [T, 2, M] stacked dgrad), # so ``q_set_w1`` is intentionally unused here. q_set_w0, _q_set_w1, q_set_wo = quantizer_sets batch_size, sequence_length, hidden = x_shape @@ -1347,33 +1360,37 @@ def _body_bwd( # pylint: disable=unused-argument (d_gate_proj_out,) = dact_gate_proj_pullback(d_intermediate * ctx.up_proj_out) # ---------------- FFN bwd: GEMM 1+2 fused (wi_0 | wi_1) ---------------- - # Concat the two upstream grads along the output (M) axis, do one - # grouped quantize + one dgrad GEMM + one wgrad GEMM, then split. - # ``ctx.casted_wi_rhs_trans`` has shape [E, H, 2M] from the fwd - # fused quantize, so the dgrad math is: + # Mirror of the fwd stack: combine d_gate / d_up on a new axis=-2, + # run one dgrad + one wgrad GEMM, then split on axis=-2. # d_sorted_x = [d_gate | d_up] @ wi_rhs_trans # = d_gate @ wi_0^T + d_up @ wi_1^T inter_M = d_gate_proj_out.shape[-1] - d_combined = jnp.concatenate([d_gate_proj_out, d_up_proj_out], axis=-1) + d_combined = jnp.stack([d_gate_proj_out, d_up_proj_out], axis=-2) casted_d_combined = tex.grouped_quantize( d_combined, q_set_w0.dgrad, ctx.local_group_sizes, flatten_axis=-1 ) d_sorted_x = tex.grouped_gemm( casted_d_combined.get_tensor(usage=TensorUsage.LHS), ctx.casted_wi_rhs_trans, - contracting_dims=((1,), (2,)), + contracting_dims=((1, 2), (2, 3)), ) d_wi_combined = tex.grouped_gemm( ctx.casted_sorted_x_lhs_trans, casted_d_combined.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) - d_wi_0 = d_wi_combined[..., :inter_M] - d_wi_1 = d_wi_combined[..., inter_M:] + d_wi_0 = d_wi_combined[..., 0, :] + d_wi_1 = d_wi_combined[..., 1, :] if has_wi_bias: - d_wi_combined_bias = tex.grouped_dbias(d_combined, ctx.local_group_sizes) - d_wi_0_bias = d_wi_combined_bias[..., :inter_M] - d_wi_1_bias = d_wi_combined_bias[..., inter_M:] + # grouped_dbias requires rank-2 input; reshape around the call. + # M is not TP-sharded on the bias path, so the reshape is free. + d_combined_2d = d_combined.reshape(d_combined.shape[0], -1) + d_wi_combined_bias_2d = tex.grouped_dbias(d_combined_2d, ctx.local_group_sizes) + d_wi_combined_bias = d_wi_combined_bias_2d.reshape( + *d_wi_combined_bias_2d.shape[:-1], 2, inter_M + ) + d_wi_0_bias = d_wi_combined_bias[..., 0, :] + d_wi_1_bias = d_wi_combined_bias[..., 1, :] else: d_wi_0_bias = None d_wi_1_bias = None @@ -1458,23 +1475,18 @@ def _body_bwd( # pylint: disable=unused-argument score_function=score_function, compute_aux_scores=True, ) - # Step 3: under EP the aux logits were all_gathered along - # ``(ep_axis, *data_parallelism_axes)`` (the latter being FSDP - # axes that shard the batch). The bwd is the inverse of that - # multi-axis tiled all_gather: ``dynamic_slice`` to pick out - # this shard's local rows from the global cotangent. - # - # JAX's convention for tiled ``all_gather(axis_name=(a, b, ...))`` - # is row-major over the tuple: the shard at mesh position - # ``(i_a, i_b, ...)`` writes to rows - # ``[(i_a * size_b * ... + i_b * ... + ...) * local_T : - # + local_T)``. We invert that by computing the same flat - # index here and slicing. + # Inverse of the fwd tiled all_gather along + # ``(*data_parallelism_axes, ep_axis)``: pick out this shard's + # local rows from the global cotangent. JAX's tiled all_gather + # is row-major over the axis-name tuple, so the shard at mesh + # position (i_a, i_b, ...) writes to a contiguous row block + # starting at flat_index * local_T. if ep_active: local_T_aux = ctx.logits_2d.shape[0] - flat_shard = shard_id # ep is the outermost axis in the gather tuple + flat_shard = 0 for ax, sz in zip(data_parallelism_axes, fsdp_sizes): flat_shard = flat_shard * sz + jax.lax.axis_index(ax) + flat_shard = flat_shard * num_ep + shard_id d_aux_logits_local = jax.lax.dynamic_slice( d_aux_logits.astype(ctx.logits_2d.dtype), start_indices=(flat_shard * local_T_aux, 0), @@ -1698,6 +1710,7 @@ def _moe_fwd_rule( # pylint: disable=unused-argument wo_kernel_axes, quantizer_sets, dtype, + apply_topk_weights_early, ): x = with_sharding_constraint_by_logical_axes(x, input_axes) ep_active = ep_axis is not None @@ -1718,6 +1731,7 @@ def _moe_fwd_rule( # pylint: disable=unused-argument "dtype": dtype, "ep_axis": ep_axis, "data_parallelism_axes": data_parallelism_axes, + "apply_topk_weights_early": apply_topk_weights_early, } captured: dict = { "inputs": x, @@ -1792,7 +1806,10 @@ def _moe_fwd_rule( # pylint: disable=unused-argument if not data_parallelism_axes: batch_pspec_axis: Any = ep_axis else: - batch_pspec_axis = (ep_axis, *data_parallelism_axes) + # ep must be innermost: ep_bootstrap forms NCCL EP comms from + # consecutive global ranks (dp_color = rank // ep_size), so the + # comm only stays within one model replica under (outer_dp, ep). + batch_pspec_axis = (*data_parallelism_axes, ep_axis) dp_size = 1 for ax in data_parallelism_axes: dp_size *= mesh.shape[ax] @@ -1876,6 +1893,7 @@ def _moe_bwd_rule( wo_kernel_axes, quantizer_sets, dtype, + apply_topk_weights_early, ctx, dy_pair, ): @@ -1917,6 +1935,7 @@ def _moe_bwd_rule( "has_wo_bias": has_wo_bias, "has_expert_bias": has_expert_bias, "x_shape": x_shape, + "apply_topk_weights_early": apply_topk_weights_early, } if not ep_active: @@ -1936,7 +1955,10 @@ def _moe_bwd_rule( if not data_parallelism_axes: batch_pspec_axis: Any = ep_axis else: - batch_pspec_axis = (ep_axis, *data_parallelism_axes) + # ep must be innermost: ep_bootstrap forms NCCL EP comms from + # consecutive global ranks (dp_color = rank // ep_size), so the + # comm only stays within one model replica under (outer_dp, ep). + batch_pspec_axis = (*data_parallelism_axes, ep_axis) ctx_spec = _build_ctx_specs( ep_axis, batch_pspec_axis, @@ -1995,7 +2017,7 @@ def _grads_dict_to_tuple( # ============================================================================= -@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 29))) +@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 30))) def _moe( x, gate_kernel, @@ -2026,6 +2048,7 @@ def _moe( wo_kernel_axes, quantizer_sets, dtype, + apply_topk_weights_early, ): # Call in `_moe`'s own signature order to match what JAX will pass # the fwd rule via ``_argnums_partial``. See the comment block at @@ -2061,6 +2084,7 @@ def _moe( wo_kernel_axes, quantizer_sets, dtype, + apply_topk_weights_early, ) return output_pair @@ -2093,8 +2117,14 @@ def moe( # Permutation permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX, align_size: int = 0, - # Gate placement (Phuong: "perhaps as an option") + # Gate placement gate_inside_vjp: bool = True, + # When True, fold per-token top-k weights into the FFN intermediate + # (next to act(gate)*up) instead of into the post-down-projection + # combine. Both placements are mathematically equivalent (down-proj is + # linear); the early placement gives XLA a chance to fuse the multiply + # with the activation. Off by default. + apply_topk_weights_early: bool = False, # Parallelism (resolved by caller from MeshResource) ep_axis: Optional[str] = None, data_parallelism_axes: Tuple[str, ...] = (), @@ -2129,6 +2159,29 @@ def moe( # we bypass also normalizes here. score_function = _validate_score_function(score_function) + # Enforce ((outer_dp..., ep), None, None) on inbound activations. The + # EP comm groups consecutive global ranks (dp_color = rank // ep_size), + # so ep MUST be innermost in the partition spec. Soft re-pin: free if + # upstream already matches, single reshard otherwise. + if ep_axis is not None: + mesh = _get_mesh() + if mesh is None or mesh.empty: + raise ValueError("moe(...) requires an active jax.sharding.Mesh when ep_axis is set.") + expected_leading: Any = ( + (*data_parallelism_axes, ep_axis) if data_parallelism_axes else ep_axis + ) + expected_spec = P(expected_leading, None, None) + actual_spec = getattr(getattr(x, "sharding", None), "spec", None) + if actual_spec is not None and tuple(actual_spec) != tuple(expected_spec): + warnings.warn( + f"moe(...): inbound x sharding {actual_spec} does not match expected " + f"{expected_spec}; inserting a reshard. Apply " + "jax.lax.with_sharding_constraint upstream to avoid this overhead.", + UserWarning, + stacklevel=2, + ) + x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, expected_spec)) + output, aux_loss = _moe( x, gate_kernel, @@ -2159,6 +2212,7 @@ def moe( wo_kernel_axes=wo_kernel_axes, quantizer_sets=quantizer_sets, dtype=dtype, + apply_topk_weights_early=apply_topk_weights_early, ) if aux_loss_coeff <= 0.0: aux_loss = None From 9194fe3e0d7991f7f9f991e01970223fa2aaa48f Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 2 Jun 2026 17:59:03 -0700 Subject: [PATCH 28/36] integrate tex.* calls, remove all ragged-a2a + triton/pure jax step by step paths. change tests to collapse in 1 bigger one with different parameters instead of smaller meaningless dtypes/shapes/finite chhecks Signed-off-by: tdophung --- tests/jax/run_te_ep_moe.sh | 122 ++ tests/jax/test_te_ep_moe.py | 762 ++++++++ transformer_engine/jax/flax/moe.py | 109 +- transformer_engine/jax/moe.py | 2620 ++++++++-------------------- 4 files changed, 1678 insertions(+), 1935 deletions(-) create mode 100755 tests/jax/run_te_ep_moe.sh create mode 100644 tests/jax/test_te_ep_moe.py diff --git a/tests/jax/run_te_ep_moe.sh b/tests/jax/run_te_ep_moe.sh new file mode 100755 index 0000000000..32d5f21956 --- /dev/null +++ b/tests/jax/run_te_ep_moe.sh @@ -0,0 +1,122 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Multiprocess (one-GPU-per-process) launcher for the TE-EP MoE custom_vjp +# test suite. Forks one pytest invocation per visible GPU, passing each +# its own --num-process=N --process-id=i, and waits for all of them. Each +# child calls jax.distributed.initialize(..., local_device_ids=process_id) +# so each Python process only sees its one GPU as a local device and the +# participating processes form a global (ep, fsdp) mesh. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +TEST_FILE="$TE_ROOT/tests/jax/test_te_ep_moe.py" +PYTEST_INI="$TE_ROOT/tests/jax/pytest.ini" + +NUM_GPUS="${NUM_GPUS:-$(nvidia-smi -L | wc -l)}" +if [ "$NUM_GPUS" -lt 4 ]; then + echo "[run_te_ep_moe.sh] need >=4 GPUs (got $NUM_GPUS); aborting" >&2 + exit 1 +fi + +export XLA_PYTHON_CLIENT_PREALLOCATE="${XLA_PYTHON_CLIENT_PREALLOCATE:-false}" +export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.5}" +export TE_EP_MOE_COORDINATOR_ADDRESS="${TE_EP_MOE_COORDINATOR_ADDRESS:-127.0.0.1:13457}" + +echo "============================================================" +echo "TE-EP MoE MULTIPROCESS test (one process per GPU, ${NUM_GPUS} GPUs)" +echo " test file : $TEST_FILE" +echo " coordinator : $TE_EP_MOE_COORDINATOR_ADDRESS" +echo " XLA_PYTHON_CLIENT_PREALLOCATE: $XLA_PYTHON_CLIENT_PREALLOCATE" +echo " XLA_PYTHON_CLIENT_MEM_FRACTION: $XLA_PYTHON_CLIENT_MEM_FRACTION" +echo "============================================================" + +if [ -n "${TE_EP_MOE_MP_LOG_DIR:-}" ]; then + LOG_DIR="$TE_EP_MOE_MP_LOG_DIR" + mkdir -p "$LOG_DIR" +else + LOG_DIR=$(mktemp -d -t te_ep_moe_mp_XXXXXX) +fi +echo "Per-process logs: $LOG_DIR" + +PIDS=() + +cleanup() { + for pid in "${PIDS[@]:-}"; do + if kill -0 "$pid" 2>/dev/null; then + kill -TERM "$pid" 2>/dev/null || true + fi + done + sleep 1 + for pid in "${PIDS[@]:-}"; do + if kill -0 "$pid" 2>/dev/null; then + kill -KILL "$pid" 2>/dev/null || true + fi + done +} +trap cleanup EXIT INT TERM + +for i in $(seq 0 $((NUM_GPUS - 1))); do + LOG_FILE="$LOG_DIR/proc_${i}.log" + PYTEST_CMD=( + python3 -m pytest -c "$PYTEST_INI" + "$TEST_FILE" + -p no:typeguard + -v -s + --num-process="$NUM_GPUS" + --process-id="$i" + ) + if [ "$i" -eq 0 ]; then + echo "=== Live output from process 0 ===" + "${PYTEST_CMD[@]}" 2>&1 | tee "$LOG_FILE" & + else + "${PYTEST_CMD[@]}" > "$LOG_FILE" 2>&1 & + fi + PIDS+=("$!") +done + +EXITS=() +for pid in "${PIDS[@]}"; do + if wait "$pid"; then + EXITS+=("0") + else + EXITS+=("$?") + fi +done + +echo +echo "============================================================" +echo "Per-process exit codes:" +for i in "${!EXITS[@]}"; do + echo " proc $i -> ${EXITS[$i]}" +done + +# Treat exit 0 (pass) and exit 5 (pytest "no tests collected", which the +# file emits via pytest.skip(allow_module_level=True) on pre-Blackwell +# GPUs) as success. +FAILED=0 +for e in "${EXITS[@]}"; do + if [ "$e" != "0" ] && [ "$e" != "5" ]; then + FAILED=1 + break + fi +done + +echo +if [ "$FAILED" -eq 0 ]; then + echo "[run_te_ep_moe.sh] all processes PASSED" + if [ -z "${TE_EP_MOE_MP_LOG_DIR:-}" ]; then + rm -rf "$LOG_DIR" + fi + exit 0 +fi + +echo "[run_te_ep_moe.sh] at least one process FAILED" +echo " retaining logs at $LOG_DIR for diagnosis" +echo " process 0 tail:" +tail -20 "$LOG_DIR/proc_0.log" 2>/dev/null || true +exit 1 diff --git a/tests/jax/test_te_ep_moe.py b/tests/jax/test_te_ep_moe.py new file mode 100644 index 0000000000..cc878e0bd1 --- /dev/null +++ b/tests/jax/test_te_ep_moe.py @@ -0,0 +1,762 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Multi-process (one-GPU-per-process) tests for the TE-EP MoE custom_vjp. + +The launcher ``tests/jax/run_te_ep_moe.sh`` forks one pytest process per +visible GPU (mirroring ``run_multiprocess_moe_vjp.sh``). Each process binds +to exactly one device via +``jax.distributed.initialize(..., local_device_ids=process_id)``; the +participating processes form a global ``(ep, fsdp)`` mesh through JAX's +distributed runtime. + +How to run +---------- + +You typically do NOT invoke pytest on this file directly -- use the +launcher, which passes ``--num-process=N --process-id=i`` to each +forked process. Driving it directly with only one process will skip +every test because :func:`jax.distributed.initialize` requires +multiple participants, and the TE EP NCCL primitives require at +least four ranks. + + bash tests/jax/run_te_ep_moe.sh + +What this suite covers +---------------------- + +This file is the TE-EP-only successor to ``test_moe_vjp.py`` and +``test_multiprocess_moe_vjp.py``. Each test exercises one MoE-block +run and bundles every check that single run supports — shape, dtype, +finiteness AND numerical parity vs a pure-JAX reference. Variations +on the block are pytest parametrize values rather than separate test +classes: + +* ``test_forward`` covers the forward across a curated set of + configurations (apply_topk_weights_early on/off, align_size=0/128, + softmax/sigmoid scoring, optional expert_bias). Each config asserts + shape, dtype, finiteness and numerical parity vs the reference in + one run. +* ``test_backward`` mirrors that for gradients. +* ``TestTeEpMoeAuxLoss`` covers the second return value end-to-end + (returned + parity + aux-only grad propagates to gate + combined + main+aux grads stay finite) in two consolidated tests. +* ``TestTeEpMoEBlockFlax`` exercises the Flax wrapper with the same + parity reference. +* ``TestZZZTeEpMoeBootstrap`` verifies the per-process NCCL bootstrap + rejects a mismatched signature. + +FP8 / MXFP8 recipes are deferred — the ``quantizer_sets`` plumbing +has not yet been re-wired across the TE-EP ``shard_map`` boundary +(see ``.pr3036-review/INTEGRATION_DESIGN.md``). +""" + +import os + +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") +os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.5") + +import sys +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from jax.experimental import mesh_utils +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from flax.linen import partitioning as nn_partitioning + + +def _init_distributed(num_process: int, process_id: int) -> bool: + """Initialize jax.distributed for this pytest process. + + Returns True on a real multi-process launch, False otherwise so + the module can fast-skip when pytest collects it without the + launcher. + """ + if num_process <= 1: + return False + coord = os.environ.get("TE_EP_MOE_COORDINATOR_ADDRESS", "127.0.0.1:13457") + jax.distributed.initialize( + coordinator_address=coord, + num_processes=num_process, + process_id=process_id, + local_device_ids=process_id, + ) + assert jax.local_device_count() == 1, "one GPU per process is required for TE EP" + assert ( + jax.device_count() == num_process + ), f"global device_count {jax.device_count()} != num_process {num_process}" + return True + + +def _read_mp_options(): + num = int(os.environ.get("MP_NUM_PROCESS", "0") or "0") + pid = int(os.environ.get("MP_PROCESS_ID", "0") or "0") + for i, a in enumerate(sys.argv): + if a.startswith("--num-process="): + num = int(a.split("=", 1)[1]) + elif a == "--num-process" and i + 1 < len(sys.argv): + num = int(sys.argv[i + 1]) + elif a.startswith("--process-id="): + pid = int(a.split("=", 1)[1]) + elif a == "--process-id" and i + 1 < len(sys.argv): + pid = int(sys.argv[i + 1]) + return num, pid + + +_MP_NUM_PROCESS, _MP_PROCESS_ID = _read_mp_options() +_MP_ACTIVE = _init_distributed(_MP_NUM_PROCESS, _MP_PROCESS_ID) + +if not _MP_ACTIVE: + pytest.skip( + "test_te_ep_moe.py requires the multiprocess launcher " + "(run_te_ep_moe.sh). Skipping.", + allow_module_level=True, + ) + +from transformer_engine_jax import get_device_compute_capability + +# Grouped GEMM in the MoE custom_vjp requires Blackwell (sm_100+). The +# TE EP NCCL primitives themselves need SM>=90, but the FFN body uses +# grouped_gemm, so the file as a whole gates on sm_100+. +if get_device_compute_capability(0) < 100: + pytest.skip( + "MoE TE EP tests require Blackwell (sm_100+) for grouped GEMM", + allow_module_level=True, + ) + +from transformer_engine.jax.flax import _MoEBlock as MoEBlock +from transformer_engine.jax.moe import moe +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +# ----------------------------------------------------------------------------- +# Mesh / shape config +# ----------------------------------------------------------------------------- + +EP_AXIS = "ep" +FSDP_AXIS = "fsdp" +EP_SIZE = 2 +assert ( + jax.device_count() % EP_SIZE == 0 +), f"device_count {jax.device_count()} must be divisible by EP_SIZE={EP_SIZE}" +FSDP_SIZE = jax.device_count() // EP_SIZE +NUM_DEVICES_REQUIRED = EP_SIZE * FSDP_SIZE + +LOGICAL_AXIS_RULES = ( + ("exp", EP_AXIS), + ("embed", FSDP_AXIS), + ("mlp", None), + ("batch", (EP_AXIS, FSDP_AXIS)), +) + +# Small shapes so the parity tests stay tight on bf16. The block still +# has all four ranks participating in dispatch/combine. +DTYPE = jnp.bfloat16 +BATCH = EP_SIZE * FSDP_SIZE * 2 # 8 on 4-GPU, 16 on 8-GPU +SEQ = 32 +HIDDEN = 64 +INTER = 128 +NUM_EXPERTS = 8 +TOPK = 2 + +# bf16 grouped_gemm + softmax-topk + ep all-to-all stack drifts ~1e-1 vs a +# fp32 numpy reference. Keep these tight enough to catch real bugs but +# loose enough to absorb expected bf16 rounding. +FWD_ATOL = 5e-2 +FWD_RTOL = 5e-2 +GRAD_FFN_ATOL = 1e-1 +GRAD_FFN_RTOL = 1e-1 +GRAD_GATE_ATOL = 5e-1 +GRAD_GATE_RTOL = 5e-1 + +# Two TE EP runs that should be bitwise-equal modulo XLA fusion order +# (align_size rounding, etc.). +TE_TO_TE_ATOL = 5e-3 +TE_TO_TE_RTOL = 5e-3 + +# Aux loss is computed in float32 from the SAME logits as the routing +# path. Numerical drift between TE-EP and the reference is dominated by +# the bf16-rounded softmax inside the topk kernel. +AUX_ATOL = 1e-3 +AUX_RTOL = 1e-3 + + +# ----------------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def mesh(): + if jax.device_count() < NUM_DEVICES_REQUIRED: + pytest.skip( + f"Need >={NUM_DEVICES_REQUIRED} devices for ep={EP_SIZE} x fsdp={FSDP_SIZE};" + f" have {jax.device_count()}" + ) + # ``ep`` must be the inner axis: ``ep_bootstrap`` forms NCCL EP groups + # from consecutive global ranks via ``dp_color = rank // ep_size``, so + # only an (outer_fsdp, inner_ep) device layout groups ranks correctly. + devices = mesh_utils.create_device_mesh((FSDP_SIZE, EP_SIZE)) + return Mesh(devices, axis_names=(FSDP_AXIS, EP_AXIS)) + + +# ----------------------------------------------------------------------------- +# Pure-JAX reference MoE (no EP). Mirrors the exact math of TE's fused +# router primitive (see tests/jax/test_fused_router.py for the same +# reference applied to the standalone router kernel): +# +# softmax + post-softmax (use_pre_softmax=False, the default): +# 1. top_k by raw logits +# 2. softmax over just the K selected logits (so weights sum to 1) +# +# sigmoid + optional expert_bias: +# 1. scores = sigmoid(logits) +# 2. top_k by (scores + expert_bias) [bias only steers selection] +# 3. weights = scores at top_k positions, normalized when K > 1 +# +# Then for both: +# * weights *= scaling_factor (we leave scaling_factor=1.0 in this +# suite, matching _make_block's default). +# * per-expert FFN: silu(layer_w0) * layer_w1 → wo. +# ----------------------------------------------------------------------------- + + +@partial( + jax.jit, + static_argnames=( + "num_experts", + "num_experts_per_tok", + "aux_loss_coeff", + "score_function", + ), +) +def _pure_jax_moe_reference( + x, + gate_kernel, + wi_0, + wi_1, + wo, + expert_bias=None, + *, + num_experts, + num_experts_per_tok, + aux_loss_coeff: float = 0.0, + score_function: str = "softmax", +): + B, S, H = x.shape + T = B * S + K = num_experts_per_tok + x_2d = x.reshape(T, H) + + gate_kernel_cast = gate_kernel.astype(x.dtype) + logits = (x_2d @ gate_kernel_cast).astype(jnp.float32) # [T, E] + + if score_function == "softmax": + # use_pre_softmax=False: topk on raw logits, then softmax over K. + top_logits, top_indices = jax.lax.top_k(logits, k=K) + weights = jax.nn.softmax(top_logits, axis=-1) # [T, K], sums to 1 + elif score_function == "sigmoid": + scores = jax.nn.sigmoid(logits) # [T, E] + if expert_bias is not None and expert_bias.shape != (0,): + scores_for_routing = scores + expert_bias.astype(jnp.float32)[None, :] + _, top_indices = jax.lax.top_k(scores_for_routing, k=K) + weights = jnp.take_along_axis(scores, top_indices, axis=-1) + else: + weights, top_indices = jax.lax.top_k(scores, k=K) + # Sigmoid weights are normalized when K > 1 (matches the kernel). + if K > 1: + weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20) + else: + raise ValueError(f"Unsupported score_function={score_function!r}") + + routing_weights_full = jnp.zeros((T, num_experts), dtype=jnp.float32) + routing_weights_full = routing_weights_full.at[ + jnp.arange(T)[:, None], top_indices + ].set(weights) + + # FFN. ``apply_topk_weights_early`` is a fusion knob that doesn't + # change the math (wo is linear), so the reference is identical for + # both placements. + layer_w0 = jnp.einsum("th,ehm->tem", x_2d, wi_0) + layer_w1 = jnp.einsum("th,ehm->tem", x_2d, wi_1) + intermediate = jax.nn.silu(layer_w0.astype(jnp.float32)) * layer_w1.astype(jnp.float32) + intermediate = intermediate.astype(x.dtype) + expert_out = jnp.einsum("tem,emh->teh", intermediate, wo) # [T, E, H] + output_2d = jnp.einsum( + "te,teh->th", routing_weights_full.astype(x.dtype), expert_out + ) + output = output_2d.reshape(B, S, H).astype(x.dtype) + + if aux_loss_coeff > 0.0: + # tex.fused_moe_aux_loss formula (matches the same + # reference_aux_loss helper from test_fused_router.py). The + # "aux scores" use the same score_function but always with + # K-normalised sigmoid (when sigmoid) / plain softmax (when + # softmax) — see tex.fused_topk_with_score_function_fwd with + # compute_aux_scores=True. + if score_function == "softmax": + aux_scores = jax.nn.softmax(logits, axis=-1) + else: # sigmoid + aux_scores = jax.nn.sigmoid(logits) + if K > 1: + aux_scores = aux_scores / ( + aux_scores.sum(axis=-1, keepdims=True) + 1e-20 + ) + routing_map = (routing_weights_full > 0).astype(jnp.int32) + tokens_per_expert = jnp.sum(routing_map, axis=0) # [E] + sum_probs_per_expert = jnp.sum(aux_scores, axis=0) # [E] + aux_loss = (num_experts * aux_loss_coeff / (K * (T**2))) * jnp.sum( + sum_probs_per_expert * tokens_per_expert.astype(jnp.float32) + ) + aux_loss = aux_loss.astype(x.dtype) + else: + aux_loss = jnp.zeros((), dtype=x.dtype) + return output, aux_loss + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def _make_block( + *, + apply_topk_weights_early=False, + align_size=0, + aux_loss_coeff=0.0, + use_expert_bias=False, + score_function="softmax", + bias_init=None, +): + kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + intermediate_size=INTER, + data_parallelism_axes=(FSDP_AXIS,), + apply_topk_weights_early=apply_topk_weights_early, + align_size=align_size, + aux_loss_coeff=aux_loss_coeff, + use_expert_bias=use_expert_bias, + score_function=score_function, + dtype=DTYPE, + ) + # Custom bias_init lets tests inject a non-zero expert_bias without + # poking variables['params'] post-init. + if bias_init is not None: + kwargs["bias_init"] = bias_init + return MoEBlock(**kwargs) + + +def _strong_expert_bias_init(key, shape, dtype): + """Half +5, half -5 — large enough to force topk onto the +ve half.""" + del key + n = shape[0] + return jnp.concatenate( + [ + jnp.full((n // 2,), 5.0, dtype=dtype), + jnp.full((n - n // 2,), -5.0, dtype=dtype), + ] + ) + + +def _shard_inputs(x, mesh): + # Match the layout moe.py re-pins to: outer dp axes, then ep innermost. + return jax.lax.with_sharding_constraint( + x, NamedSharding(mesh, P((FSDP_AXIS, EP_AXIS), None, None)) + ) + + +def _ctx(mesh): + """Combined mesh + global_shard_guard + axis_rules context.""" + + class _Combo: + def __enter__(self_inner): + self_inner._m = mesh.__enter__() + self_inner._gs = global_shard_guard( + MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) + ) + self_inner._gs.__enter__() + self_inner._ar = nn_partitioning.axis_rules(LOGICAL_AXIS_RULES) + self_inner._ar.__enter__() + return self_inner._m + + def __exit__(self_inner, *args): + self_inner._ar.__exit__(*args) + self_inner._gs.__exit__(*args) + mesh.__exit__(*args) + + return _Combo() + + +def _init_apply(block, mesh, x, key): + with _ctx(mesh): + x_sh = _shard_inputs(x, mesh) + variables = jax.jit(block.init)(key, x_sh) + jax.block_until_ready(jax.tree_util.tree_leaves(variables)[0]) + output, aux = jax.jit(block.apply)(variables, x_sh) + jax.block_until_ready(output) + return variables, output, aux + + +def _grad_step(block, variables, mesh, x, *, include_aux=False): + """Run jax.grad of mean(out^2) [+ aux if include_aux] vs params.""" + with _ctx(mesh): + x_sh = _shard_inputs(x, mesh) + + def loss_fn(variables, x): + output, aux = block.apply(variables, x) + loss = jnp.mean(output.astype(jnp.float32) ** 2) + if include_aux and aux is not None: + loss = loss + aux.astype(jnp.float32) + return loss + + grads = jax.jit(jax.grad(loss_fn))(variables, x_sh) + jax.block_until_ready(jax.tree_util.tree_leaves(grads)[0]) + return grads + + +def _grad_aux_only(block, variables, mesh, x): + """Jit'd grad of just the aux loss scalar — proves it reaches the + gate even when no main-output contribution is present.""" + with _ctx(mesh): + x_sh = _shard_inputs(x, mesh) + + def aux_only(variables, x): + _, aux = block.apply(variables, x) + return aux.astype(jnp.float32) + + grads = jax.jit(jax.grad(aux_only))(variables, x_sh) + jax.block_until_ready(jax.tree_util.tree_leaves(grads)[0]) + return grads + + +def _unwrap(x): + return x.value if hasattr(x, "value") else x + + +def _to_global_numpy(arr, mesh): + """Replicate a sharded JAX array onto every rank and return as numpy. + + Triggers an all-gather inside JIT. The resulting addressable_data(0) + contains the full global array on every process, so we can run the + pure-JAX reference and compare against it from any process. + """ + rep = NamedSharding(mesh, P()) + with mesh: + full = jax.jit(lambda a: jax.lax.with_sharding_constraint(a, rep))(arr) + full.block_until_ready() + return np.asarray(jax.device_get(full.addressable_data(0))) + + +def _params_global_numpy(variables, mesh): + """Pull every entry of variables['params'] to a replicated numpy array.""" + params = variables["params"] + return {name: _to_global_numpy(_unwrap(p), mesh) for name, p in params.items()} + + +def _make_inputs(key): + """Generate a globally-identical input tensor on every process.""" + return jax.random.normal(key, (BATCH, SEQ, HIDDEN), dtype=DTYPE) + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + + +# ----------------------------------------------------------------------------- +# Parametrize variants exercised by both the forward and the backward +# parity tests. Each config is one MoE-block configuration the suite +# wants covered; the test body checks shape, dtype, finiteness AND +# numerical parity vs the same pure-JAX reference (which understands +# the same set of knobs). +# ----------------------------------------------------------------------------- + +_CONFIGS = [ + pytest.param( + dict(score_function="softmax"), + id="softmax", + ), + pytest.param( + dict(score_function="softmax", apply_topk_weights_early=True), + id="softmax-topk-early", + ), + pytest.param( + dict(score_function="softmax", align_size=128), + id="softmax-align128", + ), + pytest.param( + dict(score_function="sigmoid"), + id="sigmoid", + ), + pytest.param( + dict(score_function="sigmoid", use_expert_bias=True), + id="sigmoid-bias-zero", + ), + pytest.param( + dict( + score_function="sigmoid", + use_expert_bias=True, + bias_init=_strong_expert_bias_init, + ), + id="sigmoid-bias-strong", + ), +] + + +def _reference_kwargs_from_config(config, params_np): + """Pick out the reference-relevant pieces of a parametrize config.""" + return dict( + score_function=config.get("score_function", "softmax"), + expert_bias=( + jnp.asarray(params_np["expert_bias"]) + if config.get("use_expert_bias", False) + else None + ), + ) + + +class TestTeEpMoeForward: + """Per-config forward correctness in a single run: shape, dtype, + finiteness AND numerical parity vs the pure-JAX reference.""" + + @pytest.mark.parametrize("config", _CONFIGS) + def test_forward(self, mesh, config): + block = _make_block(**config) + x = _make_inputs(jax.random.PRNGKey(0)) + variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(1)) + + # Shape / dtype / finiteness (cheap; on the local shard). + assert output.shape == x.shape + assert output.dtype == x.dtype + out_local = np.asarray(jax.device_get(output.addressable_data(0))) + assert np.all(np.isfinite(out_local)), "output has NaN/Inf" + assert aux is None, "aux_loss should be None when aux_loss_coeff == 0" + + # Numerical parity (replicated global view -> single rank's numpy). + params_np = _params_global_numpy(variables, mesh) + x_np = np.asarray(jax.device_get(x)) + out_te_np = _to_global_numpy(output, mesh) + + out_ref, _ = _pure_jax_moe_reference( + jnp.asarray(x_np), + jnp.asarray(params_np["gate_kernel"]), + jnp.asarray(params_np["wi_0"]), + jnp.asarray(params_np["wi_1"]), + jnp.asarray(params_np["wo"]), + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + **_reference_kwargs_from_config(config, params_np), + ) + np.testing.assert_allclose( + out_te_np.astype(np.float32), + np.asarray(jax.device_get(out_ref)).astype(np.float32), + atol=FWD_ATOL, + rtol=FWD_RTOL, + err_msg=f"forward parity breach for config={config}", + ) + + +class TestTeEpMoeBackward: + """Per-config backward correctness in a single run: per-tensor + grads finite, non-zero AND parity vs the pure-JAX reference.""" + + @pytest.mark.parametrize("config", _CONFIGS) + def test_backward(self, mesh, config): + block = _make_block(**config) + x = _make_inputs(jax.random.PRNGKey(2)) + variables, _, _ = _init_apply(block, mesh, x, jax.random.PRNGKey(3)) + grads_te = _grad_step(block, variables, mesh, x) + + # Reference grads via jax.grad over the pure-JAX MoE with the + # same config. + params_np = _params_global_numpy(variables, mesh) + x_np = np.asarray(jax.device_get(x)) + ref_kwargs = _reference_kwargs_from_config(config, params_np) + ref_expert_bias = ref_kwargs.pop("expert_bias") + + def loss_fn(params, x): + out, _ = _pure_jax_moe_reference( + x, + params["gate_kernel"], + params["wi_0"], + params["wi_1"], + params["wo"], + ref_expert_bias, + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + **ref_kwargs, + ) + return jnp.mean(out.astype(jnp.float32) ** 2) + + grads_ref = jax.jit(jax.grad(loss_fn))( + {k: jnp.asarray(v) for k, v in params_np.items() if k != "expert_bias"}, + jnp.asarray(x_np), + ) + grads_ref_np = {k: np.asarray(jax.device_get(v)) for k, v in grads_ref.items()} + + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + # Per-tensor: finite + non-zero + parity in one pass. + g_te = _to_global_numpy(_unwrap(grads_te["params"][name]), mesh) + assert np.all(np.isfinite(g_te)), f"{name} grad has NaN/Inf [config={config}]" + assert np.any(g_te != 0.0), f"{name} grad identically zero [config={config}]" + atol, rtol = ( + (GRAD_GATE_ATOL, GRAD_GATE_RTOL) + if name == "gate_kernel" + else (GRAD_FFN_ATOL, GRAD_FFN_RTOL) + ) + np.testing.assert_allclose( + g_te.astype(np.float32), + grads_ref_np[name].astype(np.float32), + atol=atol, + rtol=rtol, + err_msg=f"grad parity breach on {name} [config={config}]", + ) + + +class TestTeEpMoeAuxLoss: + """Aux-loss path. Consolidated into: + * ``test_aux_loss``: one run that checks the returned scalar's + shape / dtype / finiteness / magnitude AND numerical parity vs the + reference AND that the aux-only bwd propagates to gate_kernel. + * ``test_combined_loss_grads``: one run for joint main+aux bwd + finite + non-zero per tensor. + """ + + def test_aux_loss(self, mesh): + coeff = 1e-2 + block = _make_block(aux_loss_coeff=coeff) + x = _make_inputs(jax.random.PRNGKey(20)) + variables, _, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(21)) + + # Shape / dtype / finiteness / magnitude. + assert aux is not None, "aux_loss should be returned when coeff > 0" + assert aux.shape == (), f"aux_loss must be 0-d scalar, got {aux.shape}" + assert aux.dtype == DTYPE, f"aux_loss dtype {aux.dtype} != {DTYPE}" + aux_np = _to_global_numpy(aux, mesh) + assert np.isfinite(aux_np), "aux_loss is NaN/Inf" + assert abs(float(aux_np)) < 1e2, f"aux_loss looks unreasonable: {aux_np}" + + # Numerical parity vs the reference. + params_np = _params_global_numpy(variables, mesh) + x_np = np.asarray(jax.device_get(x)) + _, aux_ref = _pure_jax_moe_reference( + jnp.asarray(x_np), + jnp.asarray(params_np["gate_kernel"]), + jnp.asarray(params_np["wi_0"]), + jnp.asarray(params_np["wi_1"]), + jnp.asarray(params_np["wo"]), + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + aux_loss_coeff=coeff, + ) + np.testing.assert_allclose( + float(aux_np), + float(jax.device_get(aux_ref)), + atol=AUX_ATOL, + rtol=AUX_RTOL, + ) + + # Aux-only bwd must propagate to gate_kernel — proves the + # fused_moe_aux_loss_bwd → topk(compute_aux_scores)_bwd chain is + # wired. + aux_grads = _grad_aux_only(block, variables, mesh, x) + g_gate = np.asarray( + jax.device_get( + _unwrap(aux_grads["params"]["gate_kernel"]).addressable_data(0) + ) + ) + assert np.all(np.isfinite(g_gate)), "gate grad NaN/Inf under aux-only loss" + assert np.any(g_gate != 0.0), "aux bwd should propagate to gate_kernel" + + def test_combined_loss_grads(self, mesh): + """Joint main + aux loss bwd: per-tensor finite + non-zero in + one pass.""" + block = _make_block(aux_loss_coeff=1e-2) + x = _make_inputs(jax.random.PRNGKey(22)) + variables, _, _ = _init_apply(block, mesh, x, jax.random.PRNGKey(23)) + grads = _grad_step(block, variables, mesh, x, include_aux=True) + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g_local = np.asarray( + jax.device_get(_unwrap(grads["params"][name]).addressable_data(0)) + ) + assert np.all(np.isfinite(g_local)), f"{name} grad NaN/Inf under main+aux" + assert np.any(g_local != 0.0), f"{name} grad zero under main+aux" + + +class TestTeEpMoEBlockFlax: + """Flax wrapper end-to-end in one run: shape/dtype/finiteness on the + forward, numerical parity vs the same reference, and per-tensor + grad finiteness + non-zeroness.""" + + def test_init_apply_parity(self, mesh): + block = _make_block() + x = _make_inputs(jax.random.PRNGKey(12)) + variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(13)) + + assert aux is None + assert output.shape == x.shape + assert output.dtype == x.dtype + out_local = np.asarray(jax.device_get(output.addressable_data(0))) + assert np.all(np.isfinite(out_local)) + + params_np = _params_global_numpy(variables, mesh) + x_np = np.asarray(jax.device_get(x)) + out_te_np = _to_global_numpy(output, mesh) + out_ref, _ = _pure_jax_moe_reference( + jnp.asarray(x_np), + jnp.asarray(params_np["gate_kernel"]), + jnp.asarray(params_np["wi_0"]), + jnp.asarray(params_np["wi_1"]), + jnp.asarray(params_np["wo"]), + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + ) + np.testing.assert_allclose( + out_te_np.astype(np.float32), + np.asarray(jax.device_get(out_ref)).astype(np.float32), + atol=FWD_ATOL, + rtol=FWD_RTOL, + ) + + grads = _grad_step(block, variables, mesh, x) + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g_local = np.asarray( + jax.device_get(_unwrap(grads["params"][name]).addressable_data(0)) + ) + assert np.all(np.isfinite(g_local)), f"{name} grad NaN/Inf" + assert np.any(g_local != 0.0), f"{name} grad zero" + + +# Keep the bootstrap-signature test last in the module (the "ZZZ" prefix +# ensures pytest's alphabetic class ordering picks it last): it +# intentionally mismatches the NCCL EP bootstrap signature, which +# permanently taints the per-process bootstrap cache for the rest of +# the file. +class TestZZZTeEpMoeBootstrap: + """Per-process NCCL bootstrap re-bootstrap rejection.""" + + def test_bootstrap_signature_mismatch_raises(self, mesh): + block_a = _make_block() + x_a = _make_inputs(jax.random.PRNGKey(14)) + _init_apply(block_a, mesh, x_a, jax.random.PRNGKey(15)) + + # Different hidden dim → different bootstrap signature. + bigger_hidden = HIDDEN * 2 + x_b = jax.random.normal( + jax.random.PRNGKey(16), (BATCH, SEQ, bigger_hidden), dtype=DTYPE + ) + block_b = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + intermediate_size=INTER, + data_parallelism_axes=(FSDP_AXIS,), + dtype=DTYPE, + ) + with pytest.raises(ValueError, match="bootstrapped"): + _init_apply(block_b, mesh, x_b, jax.random.PRNGKey(17)) diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index b5a4afc2ad..67b2f5dfdd 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -5,27 +5,20 @@ """Flax Linen MoE block for TransformerEngine JAX. This module exposes :class:`_MoEBlock`, an experimental Flax Linen layer -that is a thin wrapper around the framework-agnostic functional MoE entry -point :func:`transformer_engine.jax.moe.moe`. The wrapper's only job is -to: +that wraps the framework-agnostic functional MoE entry point +:func:`transformer_engine.jax.moe.moe`. The wrapper's only job is to: -1. Register the gate kernel, per-expert FFN kernels, and optional biases - as ``self.param`` slots (with the right +1. Register the gate kernel, per-expert FFN kernels, and optional FFN + biases as ``self.param`` slots (with the right :func:`flax.linen.with_logical_partitioning` annotations so JAX's sharding layer FSDPs the params correctly). 2. Resolve the EP axis name from the active :class:`transformer_engine.jax.sharding.MeshResource`. 3. Forward all knobs to :func:`moe`. -All routing, dispatch, FFN, combine, and aux-loss logic lives in -``moe.py`` under a *single* ``jax.custom_vjp`` so future fusions -(FP8-on-the-wire EP, fused ``ragged_all_to_all + grouped_gemm``, gate + -route + dispatch fusion) can land without touching this wrapper. - The class is intentionally underscore-prefixed; the public ``MoEBlock`` alias will be introduced once TE's NCCL-backed EP component (and the -recipe-driven alignment follow-up) stabilises (target: the TE release -following the 2.16 code freeze). +recipe-driven alignment follow-up) stabilises. """ from typing import Any, Callable, NewType, Optional, Tuple, Union @@ -37,8 +30,7 @@ # import P`` without a second jax.sharding import. from jax.sharding import PartitionSpec as P # noqa: F401 # pylint: disable=unused-import -from ..moe import PermutationBackend, moe -from ..quantize import noop_quantizer_set +from ..moe import moe from ..router import ScoreFunction from ..sharding import get_active_resource_axis from .module import TransformerEngineBase @@ -50,22 +42,19 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array] -__all__ = ["PermutationBackend", "_MoEBlock"] +__all__ = ["_MoEBlock"] class _MoEBlock(TransformerEngineBase): """Experimental Flax MoE layer over TransformerEngine. See module docstring for the design (this class is a thin Flax - wrapper around :func:`transformer_engine.jax.moe.moe`). Constructor - knob set kept compatible with the previous bespoke implementation so - existing call sites need no changes. + wrapper around :func:`transformer_engine.jax.moe.moe`). Parameters ---------- num_experts : int - Total number of experts. Under EP this must be divisible by the - EP mesh axis size. + Total number of experts. Must be divisible by the EP mesh axis size. num_experts_per_tok : int Top-k value for routing. intermediate_size : int @@ -82,41 +71,30 @@ class _MoEBlock(TransformerEngineBase): Grouped top-k knobs (DeepSeek-style). ``None`` disables grouping. scaling_factor : float Multiplier on the routing weights. - use_expert_bias : bool - If ``True``, registers a per-expert routing bias (shape ``[E]``). - Only meaningful with ``score_function="sigmoid"``; the underlying - primitive validates the pairing. - aux_loss_coeff : float - If ``> 0``, return the MoE auxiliary load-balancing loss scalar - in addition to the main output. + + apply_topk_weights_early : bool + When True, fold per-token top-k weights into the FFN intermediate + (next to ``act(gate) * up``) instead of into the post-down-projection + combine. Both placements are mathematically equivalent (the down + projection is linear); the early placement gives XLA a chance to + fuse the multiply with the activation. gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, input_axes : Logical sharding axis tuples (consumed by Flax's :func:`with_logical_partitioning` and our internal :func:`with_sharding_constraint_by_logical_axes`). data_parallelism_axes : tuple[str, ...] - FSDP axes over which the input *batch* dim is sharded IN - ADDITION to the EP axis. Empty (default) means activations are - replicated across non-EP axes within an EP group; set e.g. - ``("fsdp",)`` for true FSDP-of-batch where each device owns a - unique slice of the batch. - permutation_backend : PermutationBackend - ``PURE_JAX`` (default) or ``TRITON``. - _align_size : int - Per-expert group-size alignment (``0`` disables; required > 0 - for quantized grouped GEMM). Internal knob; will be inferred - from the active quantization recipe in a follow-up PR. + FSDP axes over which the input *batch* dim is sharded IN ADDITION + to the EP axis. Empty (default) means activations are replicated + across non-EP axes within an EP group; set e.g. ``("fsdp",)`` for + true FSDP-of-batch where each device owns a unique slice of the + batch. dtype : jnp.dtype Compute / parameter dtype. - kernel_init, bias_init, expert_bias_init : Initializers. + kernel_init, bias_init : Initializers. use_bias : bool Register per-expert FFN biases. - - Quantization is currently configured via the standard TE autocast - context (``fp8_autocast``/``with_quantizer_set``); per-call - quantizer sets can also be passed through ``__call__``'s - ``quantizer_sets`` keyword once we stabilise the recipe pipeline. """ # Architecture @@ -131,8 +109,6 @@ class _MoEBlock(TransformerEngineBase): num_groups: Optional[int] = None group_topk: Optional[int] = None scaling_factor: float = 1.0 - use_expert_bias: bool = False - aux_loss_coeff: float = 0.0 # Sharding (logical axes) gate_kernel_axes: Tuple[Optional[str], ...] = () @@ -143,18 +119,27 @@ class _MoEBlock(TransformerEngineBase): # Parallelism data_parallelism_axes: Tuple[str, ...] = () - # Permutation - permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX - _align_size: int = 0 + # Aux loss (global expert-load balancing). 0.0 disables; non-zero + # enables the second return value and routes its gradient back to + # the gate. + aux_loss_coeff: float = 0.0 + # Fusion knob apply_topk_weights_early: bool = False + # Minimum per-expert slot alignment fed to ``tex.ep_prepare``. Default 0 + # uses the natural slot count; set to e.g. 128 to satisfy FP8 grouped-GEMM + # tile alignment. + align_size: int = 0 + # Dtypes / init / misc dtype: DType = jnp.float32 kernel_init: Optional[Initializer] = None bias_init: Initializer = nn.initializers.zeros - expert_bias_init: Initializer = nn.initializers.zeros use_bias: bool = False + # Per-expert router bias added before the top-k. Only meaningful when + # score_function='sigmoid'. + use_expert_bias: bool = False def __post_init__(self): if self.kernel_init is None: @@ -165,11 +150,6 @@ def __post_init__(self): 1.0, "fan_in", "truncated_normal", dtype=self.dtype ), ) - if not isinstance(self.permutation_backend, PermutationBackend): - raise TypeError( - "permutation_backend must be a PermutationBackend, got" - f" {self.permutation_backend!r}" - ) super().__post_init__() @nn.compact @@ -186,18 +166,17 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: output : jnp.ndarray ``[batch, sequence, hidden]``. aux_loss : Optional[jnp.ndarray] - Scalar load-balancing loss when ``aux_loss_coeff > 0``, - else ``None``. + 0-d scalar when ``aux_loss_coeff > 0``, ``None`` otherwise. """ assert ( inputs.ndim == 3 ), f"_MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" _, _, hidden_size = inputs.shape - # Param registrations -- must run OUTSIDE any JAX transform that + # Param registrations must run OUTSIDE any JAX transform that # alters the variable scope (e.g. shard_map). The functional - # ``moe(...)`` opens its own shard_map internally for the EP - # path, so registering params here is correct. + # ``moe(...)`` opens its own shard_map internally for the FFN + # body, so registering params here is correct. gate_kernel = self.param( "gate_kernel", nn.with_logical_partitioning(self.kernel_init, self.gate_kernel_axes), @@ -242,13 +221,14 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: (self.num_experts, hidden_size), self.dtype, ) + expert_bias = None if self.use_expert_bias: expert_bias = self.param( "expert_bias", - nn.with_logical_partitioning(self.expert_bias_init, ("exp",)), + nn.with_logical_partitioning(self.bias_init, ("exp",)), (self.num_experts,), - self.dtype, + jnp.float32, ) ep_axis = get_active_resource_axis("ep_resource") @@ -272,16 +252,13 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: group_topk=self.group_topk, scaling_factor=self.scaling_factor, aux_loss_coeff=self.aux_loss_coeff, - permutation_backend=self.permutation_backend, - align_size=self._align_size, - gate_inside_vjp=True, apply_topk_weights_early=self.apply_topk_weights_early, + align_size=self.align_size, ep_axis=ep_axis, data_parallelism_axes=self.data_parallelism_axes, input_axes=self.input_axes, gate_kernel_axes=self.gate_kernel_axes, wi_kernel_axes=self.wi_kernel_axes, wo_kernel_axes=self.wo_kernel_axes, - quantizer_sets=(noop_quantizer_set, noop_quantizer_set, noop_quantizer_set), dtype=self.dtype, ) diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 4479b9f176..162ea8f7e5 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -1,77 +1,51 @@ # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""Functional Mixture-of-Experts (MoE) entry point with a single fused VJP. - -This module exposes :func:`moe`, the framework-agnostic flat function that -implements an entire MoE block (gate -> top-k routing -> token dispatch -> -per-expert FFN -> token combine, plus optional expert parallelism via a -shard_map / ragged_all_to_all collective) under a *single* -``jax.custom_vjp``. It is the moral analog of -:func:`transformer_engine.jax.layernorm_mlp.layernorm_mlp` for MoE: one -custom_vjp boundary covers the whole block so future fusions (FP8 over the -EP wire, fused ``ragged_all_to_all + grouped_gemm``, gate+route+dispatch -fusion) can land without re-architecting the call site. - -Design rationale ----------------- - -The earlier MoE block (:class:`transformer_engine.jax.flax.moe._MoEBlock`) -composed many narrower custom_vjps -- one per :func:`grouped_dense`, one -per :func:`token_dispatch`, etc. Every nested custom_vjp is a place where -a quantized :class:`ScaledTensor` cannot survive (JAX requires custom_vjp -inputs / outputs to be plain ``jnp.ndarray`` ish pytrees). To enable -end-to-end FP8 flow -- in particular FP8 carried over the EP -ragged_all_to_all -- the dispatch's quantize, the a2a, the per-expert -FFN, the inverse a2a, and the combine all have to live inside the same -VJP. This file collapses them into one. - -Implementation conventions --------------------------- - -* No nested ``custom_vjp``. Every primitive's ``_fwd`` and ``_bwd`` is - called directly (e.g. :func:`tex.fused_topk_with_score_function_fwd` / - ``_bwd``, :func:`unpermute_with_mask_map`, - :func:`unpermute_bwd_with_merging_probs`, - :func:`sort_chunks_by_map(is_forward=False)`, - forward + reverse :func:`jax.lax.ragged_all_to_all`) so the outer - ``_moe_bwd_rule`` controls the bwd graph end-to-end without invoking - ``jax.vjp`` for re-linearization. -* The fwd/bwd context (``ctx``) is a plain ``dict`` whose keys depend on - the static configuration (permutation backend, EP active or not, - presence of biases, aux loss enabled). The ``_moe_fwd_rule`` builds a - matching ``ctx_specs`` dict in lockstep when opening the EP shard_map - so ``out_specs`` structurally matches the body's return. -* :func:`_dispatch` is the helper that wraps - ``permute -> a2a -> local_permute`` (forward); :func:`_combine` is its - inverse. Their ``_bwd`` siblings drive the inverse collectives in the - bwd rule. None of these helpers form a custom_vjp boundary. +"""Mixture-of-Experts (MoE) layer for TransformerEngine JAX. + +This module exposes :func:`moe`, a single fused MoE forward pass + bwd +built on top of TE's NCCL-backed Expert Parallelism primitives +(``tex.ep_dispatch`` / ``tex.ep_combine``). The block runs:: + + gate -> topk -> ep_dispatch -> per-expert FFN (grouped GEMMs) + -> ep_combine -> output + +under a single ``jax.custom_vjp`` so the routing, dispatch, FFN and +combine steps fuse cleanly under XLA without leaking intermediate +residuals into the user-facing autograd graph. + +Sharding model +-------------- +* Inbound activations are 3D ``[B, S, H]`` sharded + ``((*data_parallelism_axes, ep_axis), None, None)``. The public + :func:`moe` soft-repins this on entry and warns when a reshard is + inserted. +* The EP primitives operate at global view (their custom_partitioning + rules handle per-shard execution). The FFN GEMMs run per-shard inside + a small ``shard_map`` whose ``in_specs`` and ``out_specs`` mirror the + same ``((dp, ep), ...)`` layout. + +Out-of-scope (for now) +---------------------- +FP8 / MXFP8 quantizer sets are not yet wired on this path; turning +them on requires recipe-aware residual specs and ``ScaledTensor`` +leaves across the ``shard_map`` boundary. ``aux_loss_coeff`` and +``expert_bias`` are supported (the former forces a per-step +all-gather over the routing-side logits, which lives off the critical +path and overlaps with the dispatch collective). """ -import math from dataclasses import dataclass -from enum import Enum from functools import partial -from typing import Any, NewType, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union import warnings import jax import jax.numpy as jnp -from flax import struct as flax_struct from jax.sharding import NamedSharding, PartitionSpec as P from . import cpp_extensions as tex -from .permutation import ( - PureJaxPermState, - compute_ragged_all_to_all_params, - compute_reverse_ragged_all_to_all_params, - pure_jax_token_combine, - pure_jax_token_dispatch, - routing_map_to_selected_experts, -) from .quantize import ( - QuantizerSet, - ScaledTensor, TensorUsage, noop_quantizer_set, with_sharding_constraint_by_logical_axes, @@ -80,1052 +54,132 @@ from .router import ScoreFunction, _validate_score_function from .sharding import _get_mesh -# Triton-backed primitives are imported lazily: callers on the PURE_JAX -# permutation backend should not need ``triton`` installed. The TRITON -# branches in this module call ``_require_triton()`` first to raise a -# clear error if the import failed. -try: - from .triton_extensions.permutation import ( - make_chunk_sort_map, - make_row_id_map, - permute_with_mask_map, - permute_with_mask_map_and_pad, - sort_chunks_by_map, - unpermute_bwd_with_merging_probs, - unpermute_bwd_with_merging_probs_and_unpad, - unpermute_with_mask_map, - unpermute_with_mask_map_and_unpad, - ) - - _TRITON_AVAILABLE = True -except ImportError: - _TRITON_AVAILABLE = False - make_chunk_sort_map = None - make_row_id_map = None - permute_with_mask_map = None - permute_with_mask_map_and_pad = None - sort_chunks_by_map = None - unpermute_bwd_with_merging_probs = None - unpermute_bwd_with_merging_probs_and_unpad = None - unpermute_with_mask_map = None - unpermute_with_mask_map_and_unpad = None - - -def _require_triton(): - """Raise a clear error if Triton permutation kernels are unavailable.""" - if not _TRITON_AVAILABLE: - raise ImportError( - "PermutationBackend.TRITON requires" - " ``transformer_engine.jax.triton_extensions`` (and ``triton``)." - " Install Triton or pass PermutationBackend.PURE_JAX." - ) - - -PRNGKey = Any -Shape = Tuple[int, ...] -DType = NewType("DType", jnp.dtype) -Array = NewType("Array", jnp.ndarray) - - -__all__ = ["moe", "PermutationBackend"] +__all__ = ["moe"] # ============================================================================= -# Enums +# Process-level NCCL EP bootstrap # ============================================================================= +# +# ``tex.ep_bootstrap`` initialises the NCCL EP communicator exactly once per +# process and stashes its state in a C++ singleton. Subsequent calls with the +# same signature are a no-op; calls with a different signature raise. +_te_ep_bootstrap_signature: Optional[Tuple[int, int, int, int, int]] = None -class PermutationBackend(Enum): - """Token-dispatch / combine backend used by :func:`moe`. - * ``TRITON``: TE's fused Triton kernels. Faster than ``PURE_JAX`` - on current hardware and the recommended default. - * ``PURE_JAX``: ``jnp.argsort`` + gather paths compiled as plain - XLA; useful as a numerical reference and on builds without - Triton available. - """ +def _te_ep_bootstrap_if_needed( + num_experts: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + ep_size: int, +) -> None: + """Bootstrap the NCCL EP communicator on first use within a process.""" + global _te_ep_bootstrap_signature + sig = (num_experts, max_tokens_per_rank, recv_capacity_per_rank, hidden_dim, ep_size) + if _te_ep_bootstrap_signature == sig: + return + if _te_ep_bootstrap_signature is not None: + raise ValueError( + "TE EP was already bootstrapped with signature " + f"{_te_ep_bootstrap_signature}; got {sig}. Re-bootstrap with" + " different params is not supported within a single process." + ) + from transformer_engine.jax.ep import ep_bootstrap # local: avoids import cycle - PURE_JAX = "pure_jax" - TRITON = "triton" + ep_bootstrap( + world_size=jax.process_count(), + rank=jax.process_index(), + ep_size=ep_size, + num_experts=num_experts, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=hidden_dim, + # XLA may relocate the C++ handle buffer between JIT executables; + # allow it rather than asserting on handle aliasing. + allow_handle_mem_reloc=True, + ) + _te_ep_bootstrap_signature = sig # ============================================================================= -# Dispatch-state records (carried _dispatch -> _combine / *_bwd) +# Residual container threaded fwd -> bwd # ============================================================================= -# -# Two NamedTuples (one per permutation backend) so we get type -# discrimination at the consumer side via ``isinstance``. The backend- -# specific residuals are required fields; the EP-only residuals are -# Optional and are populated only when the run is EP-active. Each field -# is either an ``ndarray`` or ``None`` -- nothing static, since these -# values cross the shard_map pytree boundary and would otherwise be -# coerced into JitTracers. - - -@flax_struct.dataclass -class _PureJaxDispatchState: - """Residuals saved by :func:`_dispatch` on the PURE_JAX path. - - Registered as a JAX pytree via ``flax.struct.dataclass``: each - annotated field is a leaf, ``None`` is a non-leaf sentinel. The - matching spec built by :func:`_build_dispatch_specs` mirrors this - layout so shard_map's value and spec trees line up. - """ - - group_sizes: jnp.ndarray - sorted_indices: jnp.ndarray - routing_weights: jnp.ndarray - # EP-only: - all_shards_tokens_per_expert: Optional[jnp.ndarray] = None - local_perm_row_id_map: Optional[jnp.ndarray] = None - - -@flax_struct.dataclass -class _TritonDispatchState: - """Residuals saved by :func:`_dispatch` on the TRITON path.""" - group_sizes: jnp.ndarray - row_id_map: jnp.ndarray - pad_offsets: Optional[jnp.ndarray] # populated only when align_size > 0 - merging_probs: jnp.ndarray - # EP-only: - all_shards_tokens_per_expert: Optional[jnp.ndarray] = None - local_perm_row_id_map: Optional[jnp.ndarray] = None +@dataclass +class _Ctx: + """Residuals carried from the fwd rule into the bwd rule.""" -_DispatchState = Union[_PureJaxDispatchState, _TritonDispatchState] - - -@flax_struct.dataclass -class _BodyCtx: - """Residuals carried fwd_rule -> bwd_rule by :func:`_body_fwd`. - - Optional fields (``expert_bias``, ``aux_*``) are ``None`` when the - matching feature is disabled. :func:`_build_ctx_specs` mirrors that - layout so the shard_map spec and value trees match leaf-for-leaf. - """ - - # Always present. - x: Any - gate_kernel: Any - logits_2d: Any - saved_scores: Any - routing_map: Any - dispatch: Any # _DispatchState + x: jnp.ndarray + gate_kernel: jnp.ndarray + expert_bias: jnp.ndarray + logits_2d: jnp.ndarray + saved_scores: jnp.ndarray + routing_map: jnp.ndarray + handle: Any + token_counts: jnp.ndarray + recv_topk_weights: jnp.ndarray casted_sorted_x_lhs_trans: Any - casted_wi_rhs_trans: Any # stacked [E, H, 2, M] residual for fused wi_0|wi_1 bwd - gate_proj_out: Any - up_proj_out: Any + casted_wi_rhs_trans: Any + gate_proj_out: jnp.ndarray + up_proj_out: jnp.ndarray casted_intermediate_lhs_trans: Any casted_wo_rhs_trans: Any - expert_outputs: Any - local_group_sizes: Any - # Feature-gated. - expert_bias: Any = None + expert_outputs: jnp.ndarray + local_group_sizes: jnp.ndarray + # Aux-loss residuals; None when aux_loss_coeff == 0. aux_const_buf: Any = None aux_tokens_per_expert: Any = None - aux_logits_for_score: Any = None aux_saved_scores: Any = None # ============================================================================= -# ctx / dispatch-state key conventions -# ============================================================================= -# -# Both ``ctx`` (carried fwd_rule -> bwd_rule) and the dispatch state -# (carried _dispatch -> _combine / _dispatch_bwd / _combine_bwd) are plain -# python dicts. Using a dict (rather than a flax_struct.dataclass) lets us -# vary the populated keys with the static config without breaking -# ``shard_map``'s ``out_specs`` structural match: the spec dict and the -# value dict are built with the SAME keys via :func:`_build_ctx_specs`. -# -# Below is the key glossary so the rest of the file reads cleanly. -# -# DispatchState (dict): values are jnp.ndarray unless noted -# Always present: -# "group_sizes" [n_groups] per-expert token counts -# (n_groups = E for no-EP, -# E_local for EP) -# "ep_active" bool (carried as a Python flag, -# not in the dict; passed -# alongside) -# PURE_JAX backend: -# "sorted_indices" [num_real + padding] argsort indices -# "routing_weights" [num_tokens, topk] per-token-per-expert weights -# TRITON backend: -# "row_id_map" [num_tokens, 2*E + 1] -# "pad_offsets" [E] or None -# "merging_probs" [num_tokens, E] -# EP-only: -# "all_shards_tokens_per_expert" [num_ep, E] -# "local_perm_row_id_map" [recv_buffer_rows] -# "local_perm_inv_row_id_map" [recv_buffer_rows] -# -# NOTE: per-shard compile-time-constant shapes (num_real_tokens, -# padding_size, pre/post_a2a_buffer_shape) are NOT stored in this -# dict; they are recomputed in _body_fwd/_body_bwd via -# _compute_static_shape_info and passed as Python ints / int tuples to -# the dispatch/combine helpers. Storing them in the dict would cause -# JAX's pytree-flatten across the shard_map boundary to coerce them -# into JitTracer 0-d arrays, which breaks Python-level control flow -# (e.g. ``if padding > 0``) and ``jnp.zeros(shape)`` in the bwd. -# -# See :class:`_BodyCtx` (NamedTuple) for the ctx layout and field -# documentation. :func:`_build_ctx_specs` returns a matching ``_BodyCtx`` -# of ``P(...)`` specs so shard_map's value/spec trees line up -# leaf-for-leaf. - - -# ============================================================================= -# Static shape helper -# ============================================================================= -# -# A set of per-shard shape/size values that the dispatch and combine -# helpers (both fwd and bwd) need. They're all derivable from existing -# static args, so we recompute them in both ``_body_fwd`` and -# ``_body_bwd`` and pass them as Python ints / int-tuples through -# explicit kwargs. We MUST NOT stash them inside the dynamic -# ``state`` / ``ctx`` dict: when the dict crosses the EP shard_map's -# out_specs/in_specs boundary, JAX's pytree-flatten coerces any Python -# int leaves into traced 0-d arrays, which then breaks dependent Python -# code in the bwd (e.g. ``if padding > 0`` and ``jnp.zeros(shape)``). - - -@dataclass(frozen=True) -class _StaticShapeInfo: - """Per-shard compile-time-constant shape info used by dispatch / - combine fwd and bwd. Fields are Python ints / int tuples (NOT jnp - arrays) so they can be passed as ordinary static keyword args. - - Attributes - ---------- - num_real_tokens : int - Per-shard count of real (non-padding) permuted tokens, - i.e. ``per_shard_num_tokens * num_experts_per_tok``. - padding_size : int - Per-shard number of alignment-padding tokens appended to the - sort buffer (``num_experts * (align_size - 1)`` when - ``align_size > 0``, else ``0``). - pre_a2a_buffer_shape : tuple[int, int] - ``(num_real_tokens + padding_size, hidden)`` -- the per-shard - shape of the sorted-inputs buffer sent over the EP - ragged_all_to_all in the fwd direction. - post_a2a_buffer_shape : Optional[tuple[int, int]] - ``(recv_buffer_rows, hidden)`` when EP is active, ``None`` - otherwise. - """ - - num_real_tokens: int - padding_size: int - pre_a2a_buffer_shape: Tuple[int, int] - post_a2a_buffer_shape: Optional[Tuple[int, int]] - - -def _compute_static_shape_info( - *, - batch_size: int, - sequence_length: int, - hidden: int, - num_experts: int, - num_experts_per_tok: int, - align_size: int, - ep_active: bool, - num_ep: int = 1, - fsdp_sizes: Tuple[int, ...] = (), - recv_buffer_rows: int = 0, - batch_is_per_shard: bool = True, -) -> _StaticShapeInfo: - """Build a :class:`_StaticShapeInfo` for the current rank. - - ``batch_is_per_shard`` controls whether ``batch_size`` is already - sharded (True -- e.g. when this is called from inside a shard_map - body, where ``x.shape[0]`` reports the per-shard batch size) or - global (False -- e.g. when computing from x.shape outside the - shard_map body). - """ - if ep_active and not batch_is_per_shard: - dp_size = math.prod(fsdp_sizes) if fsdp_sizes else 1 - per_shard_batch = batch_size // (num_ep * dp_size) - else: - per_shard_batch = batch_size - per_shard_num_tokens = per_shard_batch * sequence_length - num_real_tokens = per_shard_num_tokens * num_experts_per_tok - padding_size = num_experts * (align_size - 1) if align_size > 0 else 0 - pre_a2a_buffer_shape = (num_real_tokens + padding_size, hidden) - post_a2a_buffer_shape = (recv_buffer_rows, hidden) if ep_active else None - return _StaticShapeInfo( - num_real_tokens=num_real_tokens, - padding_size=padding_size, - pre_a2a_buffer_shape=pre_a2a_buffer_shape, - post_a2a_buffer_shape=post_a2a_buffer_shape, - ) - - -# ============================================================================= -# Dispatch / combine helpers (no VJP boundary -- pure Python) -# ============================================================================= - - -def _dispatch( - inputs_2d: jnp.ndarray, - sparse_probs: jnp.ndarray, - routing_map: jnp.ndarray, - *, - backend: PermutationBackend, - num_experts: int, - num_experts_per_tok: int, - align_size: int, - # EP-only: - ep_active: bool, - ep_axis: Optional[str], - num_ep: int, - recv_buffer_rows: int, - shard_id: Optional[jnp.ndarray] = None, -) -> Tuple[jnp.ndarray, dict]: - """``permute -> (a2a -> local_permute) iff ep_active``. - - Returns ``(sorted_x, state)`` where ``sorted_x`` has shape - ``[buffer_rows, hidden]`` -- ``E`` groups (no-EP) or ``E_local`` groups - (EP) -- and ``state`` is a dict carrying everything :func:`_combine` - and the bwd helpers need to reverse the operation. - - Bypasses the ``custom_vjp``-wrapped public ``token_dispatch`` / - ``pure_jax_token_dispatch`` wrappers (well, mostly: PURE_JAX still - composes through ``pure_jax_token_dispatch`` because that helper has - no ``custom_vjp`` itself -- only its inner ``_sort_activations`` does, - which is fine since we never auto-diff through it from this layer). - For TRITON we call the underlying ``permute_with_mask_map`` / - ``permute_with_mask_map_and_pad`` primitives directly. - """ - num_tokens, hidden = inputs_2d.shape - topk = num_experts_per_tok - - # Backend-specific residuals collected here, then packaged into the - # appropriate _*DispatchState below. - sorted_indices = None - routing_weights_kept = None - row_id_map = None - pad_offsets = None - merging_probs = None - - # ------------------------------------------------------------------ - # Step 1: global permute (every shard routes its own tokens over the - # full expert axis). Backend-specific. - # ------------------------------------------------------------------ - if backend is PermutationBackend.PURE_JAX: - selected_experts, routing_weights = routing_map_to_selected_experts( - sparse_probs, routing_map, topk - ) - sorted_inputs, perm_state, group_sizes = pure_jax_token_dispatch( - inputs_2d, - selected_experts, - num_experts=num_experts, - num_experts_per_tok=topk, - align_size=align_size, - ) - # NOTE: ``perm_state.num_real_tokens`` and ``perm_state.padding_size`` - # are compile-time Python ints; intentionally NOT stored in the - # returned state (would be coerced to JitTracer 0-d arrays under - # the EP shard_map's pytree flatten). Recompute via - # ``_compute_static_shape_info`` in the bwd / EP-combine - # call sites that need them. - sorted_indices = perm_state.sorted_indices - routing_weights_kept = routing_weights - else: - # TRITON backend -- inline the underlying primitive sequence - # (mirrors ``_token_dispatch_fwd_rule`` but exposes the residuals - # to our ctx instead of saving them inside another custom_vjp). - num_out_tokens = num_tokens * topk - row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) - tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32) - if align_size > 0: - target_tokens_per_expert = ( - jnp.ceil(tokens_per_expert / align_size) * align_size - ).astype(jnp.int32) - pad_lengths = target_tokens_per_expert - tokens_per_expert - cum_pad = jnp.cumsum(pad_lengths) - pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]]) - worst_case_out_tokens = ( - (num_out_tokens + num_experts * (align_size - 1)) // align_size - ) * align_size - sorted_inputs, _ = permute_with_mask_map_and_pad( - inputs_2d, - row_id_map, - None, - pad_offsets, - num_tokens, - num_experts, - worst_case_out_tokens, - hidden, - align_size=align_size, - ) - group_sizes = target_tokens_per_expert - else: - sorted_inputs, _ = permute_with_mask_map( - inputs_2d, - row_id_map, - None, - num_tokens, - num_experts, - num_out_tokens, - hidden, - ) - pad_offsets = None - group_sizes = tokens_per_expert - merging_probs = sparse_probs - - def _build_state(group_sizes_val, ep_all=None, ep_local=None): - if backend is PermutationBackend.PURE_JAX: - return _PureJaxDispatchState( - group_sizes=group_sizes_val, - sorted_indices=sorted_indices, - routing_weights=routing_weights_kept, - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, - ) - return _TritonDispatchState( - group_sizes=group_sizes_val, - row_id_map=row_id_map, - pad_offsets=pad_offsets, - merging_probs=merging_probs, - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, - ) - - if not ep_active: - return sorted_inputs, _build_state(group_sizes) - - # ------------------------------------------------------------------ - # Step 2 (EP only): all_gather per-expert counts so every shard knows - # the [num_ep, num_experts] token-count matrix. - # ------------------------------------------------------------------ - all_shards_tokens_per_expert = jax.lax.all_gather( - group_sizes[None, :], - axis_name=ep_axis, - axis=0, - tiled=True, - ) - - # ------------------------------------------------------------------ - # Step 3 (EP only): forward ragged_all_to_all over the EP axis. - # ------------------------------------------------------------------ - in_off, send_sz, out_off, recv_sz = compute_ragged_all_to_all_params( - all_shards_tokens_per_expert, shard_id, num_ep - ) - post_a2a_buffer_shape = (recv_buffer_rows, hidden) - recv_buf = jnp.zeros(post_a2a_buffer_shape, dtype=sorted_inputs.dtype) - x_recv = jax.lax.ragged_all_to_all( - sorted_inputs, recv_buf, in_off, send_sz, out_off, recv_sz, axis_name=ep_axis - ) - - # ------------------------------------------------------------------ - # Step 4 (EP only): local permute -- (source_shard, expert) -> - # (expert, shard). Inlined ``local_permute_after_a2a`` so we control - # both the row_id_map and its inverse for the bwd. - # ------------------------------------------------------------------ - num_experts_local = num_experts // num_ep - local_expert_start = shard_id * num_experts_local - local_expert_columns = jax.lax.dynamic_slice( - all_shards_tokens_per_expert, - start_indices=(0, local_expert_start), - slice_sizes=(num_ep, num_experts_local), - ) - split_sizes = local_expert_columns.reshape(-1) # source-major - indices_matrix = jnp.arange(num_ep * num_experts_local, dtype=jnp.int32).reshape( - num_ep, num_experts_local - ) - sorted_chunk_indices = indices_matrix.T.reshape(-1) # source-major -> expert-major - num_chunks = num_ep * num_experts_local - # Build a SINGLE row_id_map. ``is_forward=True`` permutes - # source-major -> expert-major; ``is_forward=False`` is the exact - # inverse (this is exactly what ``_sort_chunks_by_index_bwd_rule`` - # uses on the saved residual). _MoEBlock builds two row_id_maps - # only because it calls ``sort_chunks_by_index`` twice -- once in - # ``local_permute_after_a2a`` and again in ``local_unpermute_before_a2a``; - # each of those wrappers calls ``make_chunk_sort_map`` internally. - # Here we share one map across (fwd permute, fwd inverse-permute, - # bwd permute, bwd inverse-permute). - local_perm_row_id_map = make_chunk_sort_map( - split_sizes, sorted_chunk_indices, recv_buffer_rows, num_chunks - ) - sorted_x, _ = sort_chunks_by_map( - x_recv, local_perm_row_id_map, None, recv_buffer_rows, hidden, is_forward=True - ) - local_group_sizes = jnp.sum(local_expert_columns, axis=0) - - # NOTE: pre_a2a_buffer_shape and post_a2a_buffer_shape are compile- - # time int tuples; intentionally NOT stored in the returned state - # (would be coerced to JitTracer 0-d arrays under the EP shard_map's - # pytree flatten). Recompute via ``_compute_static_shape_info`` in - # the bwd call sites that need them. For EP, ``group_sizes`` here is - # the per-local-expert count (the FFN runs over E_local groups, not - # E). The global ``group_sizes`` lives inside - # ``all_shards_tokens_per_expert`` if anyone needs it for - # diagnostics. - return sorted_x, _build_state( - local_group_sizes, - ep_all=all_shards_tokens_per_expert, - ep_local=local_perm_row_id_map, - ) - - -def _combine( - expert_outputs: jnp.ndarray, - state: _DispatchState, - *, - backend: PermutationBackend, - ep_active: bool, - batch_size: int, - sequence_length: int, - dtype: jnp.dtype, - num_experts_per_tok: int, - # Per-shard compile-time-constant shape info (Python ints / int tuples). - # Computed by _compute_static_shape_info in the caller, passed here - # rather than stored in ``state`` to survive shard_map crossings. - num_real_tokens: int, - padding_size: int, - pre_a2a_buffer_shape: Tuple[int, int], - # EP-only: - ep_axis: Optional[str], - shard_id: Optional[jnp.ndarray] = None, - num_ep: int = 1, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Inverse of :func:`_dispatch`. - - Returns ``(output, expert_outputs_post_ep)``. ``output`` is the - ``[B, S, H]`` combined activations. ``expert_outputs_post_ep`` is - the FFN-output tensor in the shape that Step 3 of the combine - actually consumed (i.e. after the reverse ragged_all_to_all on EP - runs, or the original input on non-EP). The caller stashes this as - the bwd residual so that ``_combine_bwd``'s Step-3 inverse sees - the same tensor the forward Step 3 used. - """ - if ep_active: - # Step 1 (EP): inverse local permute. Reuse the SAME row_id_map - # built in _dispatch by setting is_forward=False (this is the - # exact inverse, identical to what - # ``_sort_chunks_by_index_bwd_rule`` does with the saved residual). - recv_buffer_rows, hidden = expert_outputs.shape - x_send_back, _ = sort_chunks_by_map( - expert_outputs, - state.local_perm_row_id_map, - None, - recv_buffer_rows, - hidden, - is_forward=False, - ) - # Step 2 (EP): reverse ragged_all_to_all. - in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( - state.all_shards_tokens_per_expert, shard_id, num_ep - ) - send_back_buf = jnp.zeros(pre_a2a_buffer_shape, dtype=expert_outputs.dtype) - expert_outputs = jax.lax.ragged_all_to_all( - x_send_back, - send_back_buf, - in_off_r, - send_sz_r, - out_off_r, - recv_sz_r, - axis_name=ep_axis, - ) - - # Step 3: global combine. ``expert_outputs`` here is the post-A2A - # tensor under EP, or the original input under non-EP -- whichever - # value Step 3 actually consumes. Returned as the second tuple - # element so the caller can stash it as the bwd residual. - if backend is PermutationBackend.PURE_JAX: - # Reuse the reference pure-jax implementation; it has no - # custom_vjp on its outer surface so we can call it freely. - perm_state = PureJaxPermState( - sorted_indices=state.sorted_indices, - num_real_tokens=num_real_tokens, - padding_size=padding_size, - ) - output = pure_jax_token_combine( - expert_outputs, - perm_state, - state.routing_weights, - num_experts_per_tok=num_experts_per_tok, - batch_size=batch_size, - sequence_length=sequence_length, - ) - return output, expert_outputs - # TRITON - num_tokens = state.row_id_map.shape[0] - num_experts = (state.row_id_map.shape[1] - 1) // 2 - hidden = expert_outputs.shape[-1] - if state.pad_offsets is not None: - out_2d, _ = unpermute_with_mask_map_and_unpad( - expert_outputs, - state.row_id_map, - state.merging_probs, - None, - state.pad_offsets, - num_tokens, - num_experts, - hidden, - ) - else: - out_2d, _ = unpermute_with_mask_map( - expert_outputs, - state.row_id_map, - state.merging_probs, - None, - num_tokens, - num_experts, - hidden, - ) - return out_2d.reshape(batch_size, sequence_length, hidden).astype(dtype), expert_outputs - - -def _combine_bwd( # pylint: disable=unused-argument - d_output: jnp.ndarray, - state: _DispatchState, - expert_outputs: jnp.ndarray, - *, - backend: PermutationBackend, - ep_active: bool, - batch_size: int, - sequence_length: int, - dtype: jnp.dtype, - num_experts: int, - num_experts_per_tok: int, - # Per-shard compile-time-constant shape info (Python ints / int tuples). - # See ``_compute_static_shape_info`` and the note in ``_dispatch`` - # for why these are kwargs rather than state-dict entries. - num_real_tokens: int, - padding_size: int, - post_a2a_buffer_shape: Optional[Tuple[int, int]], - # EP-only: - ep_axis: Optional[str], - shard_id: Optional[jnp.ndarray] = None, - num_ep: int = 1, -) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Inverse of :func:`_combine` on the cotangent. - - Returns ``(d_expert_outputs, d_routing_weights_or_merging_probs)``. - - ``expert_outputs`` is the *forward* output of the FFN (same value the - fwd handed to :func:`_combine`). It's required by the TRITON - combine_bwd kernel; for PURE_JAX we don't need it but accept it for - a symmetric signature. - """ - # Step 3 inverse: global combine bwd. - d_output_2d = d_output.reshape(-1, d_output.shape[-1]) - if backend is PermutationBackend.PURE_JAX: - # The pure-jax combine is: - # unsort = _sort_activations(expert_outputs, argsort(sorted_indices)) - # if pad: unsort = unsort[:num_real] - # reshape -> einsum BKE,BK -> BE -> reshape to BSE - # Hand-derive the bwd in plain JAX (no custom_vjp involved): - unsort_indices = jnp.argsort(state.sorted_indices) - topk = num_experts_per_tok - num_real = num_real_tokens - padding = padding_size - # Recover the unsorted intermediate that the fwd produced (we - # need it for the d_routing_weights pullback). Apply the same - # gather the fwd did. - unsort_intermediate = expert_outputs[unsort_indices] - if padding > 0: - unsort_intermediate = unsort_intermediate[:num_real] - # Bwd of einsum/reshape: - # output[B, E] = sum_K intermediate[B, K, E] * weights[B, K] - # d_intermediate[B, K, E] = d_output[B, E] * weights[B, K] - # d_weights[B, K] = sum_E d_output[B, E] * intermediate[B, K, E] - rw = state.routing_weights.reshape(-1, topk) - intermediate_3d = unsort_intermediate.reshape(rw.shape[0], topk, -1) - rw_cast = rw.astype(intermediate_3d.dtype) - d_intermediate_3d = jnp.einsum("BE,BK -> BKE", d_output_2d, rw_cast) - d_routing_weights = jnp.einsum("BE,BKE -> BK", d_output_2d, intermediate_3d).astype( - state.routing_weights.dtype - ) - d_routing_weights = d_routing_weights.reshape(state.routing_weights.shape) - d_unsort_intermediate = d_intermediate_3d.reshape(num_real, -1) - # Pad back with zeros if the fwd stripped padding. - if padding > 0: - d_unsort_intermediate = jnp.concatenate( - [ - d_unsort_intermediate, - jnp.zeros( - (padding, d_unsort_intermediate.shape[-1]), - dtype=d_unsort_intermediate.dtype, - ), - ], - axis=0, - ) - # Bwd of the gather is gather-by-original-indices: - # sorted = unsort[argsort(sorted_indices)] - # d_sorted = scatter d_unsort via argsort(sorted_indices) - # = d_unsort[sorted_indices] (gather by original sorted_indices, - # which is the inverse of argsort(sorted_indices)). - d_expert_outputs_global = d_unsort_intermediate[state.sorted_indices] - else: - # TRITON combine bwd: requires fwd_input (expert_outputs). - num_tokens = state.row_id_map.shape[0] - n_experts = (state.row_id_map.shape[1] - 1) // 2 - hidden = d_output_2d.shape[-1] - num_out_tokens = expert_outputs.shape[0] - if state.pad_offsets is not None: - d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs_and_unpad( - d_output_2d, - state.row_id_map, - expert_outputs, - state.merging_probs, - state.pad_offsets, - num_tokens, - n_experts, - num_out_tokens, - hidden, - ) - # The kernel only writes positions tokens map to; padded - # positions may contain NaN. Replace with zeros (matches - # ``_token_combine_bwd_rule``). - d_expert_outputs_global = jnp.where( - jnp.isnan(d_expert_outputs_global), 0.0, d_expert_outputs_global - ) - else: - d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs( - d_output_2d, - state.row_id_map, - expert_outputs, - state.merging_probs, - num_tokens, - n_experts, - num_out_tokens, - hidden, - ) - d_routing_weights = d_merging_probs - - if not ep_active: - return d_expert_outputs_global, d_routing_weights - - # Step 2 (EP) inverse: bwd of reverse ragged_all_to_all is a forward - # ragged_all_to_all using the SAME forward parameters (sender / - # receiver roles swap from the reverse direction back to forward). - in_off_f, send_sz_f, out_off_f, recv_sz_f = compute_ragged_all_to_all_params( - state.all_shards_tokens_per_expert, shard_id, num_ep - ) - recv_buf_for_bwd = jnp.zeros(post_a2a_buffer_shape, dtype=d_expert_outputs_global.dtype) - d_x_send_back = jax.lax.ragged_all_to_all( - d_expert_outputs_global, - recv_buf_for_bwd, - in_off_f, - send_sz_f, - out_off_f, - recv_sz_f, - axis_name=ep_axis, - ) - # Step 1 (EP) inverse: combine fwd applied is_forward=False; the - # bwd is is_forward=True with the SAME row_id_map. - recv_buffer_rows, hidden = d_x_send_back.shape - d_expert_outputs, _ = sort_chunks_by_map( - d_x_send_back, - state.local_perm_row_id_map, - None, - recv_buffer_rows, - hidden, - is_forward=True, - ) - return d_expert_outputs, d_routing_weights - - -def _dispatch_bwd( - d_sorted_x: jnp.ndarray, - state: _DispatchState, - inputs_2d_shape: Tuple[int, ...], - *, - backend: PermutationBackend, - ep_active: bool, - num_experts: int, - num_experts_per_tok: int, - # Per-shard compile-time-constant shape info (Python ints / int tuples). - # See ``_compute_static_shape_info`` and the note in ``_dispatch`` - # for why these are kwargs rather than state-dict entries. - num_real_tokens: int, - padding_size: int, - pre_a2a_buffer_shape: Tuple[int, int], - # EP-only: - ep_axis: Optional[str], - shard_id: Optional[jnp.ndarray] = None, - num_ep: int = 1, -) -> jnp.ndarray: - """Inverse of :func:`_dispatch` on the cotangent. Returns ``d_inputs_2d``. - - The probs path through dispatch is always discarded (PURE_JAX never - threads probs through dispatch; TRITON technically does but the - caller drops ``permuted_probs``, so its cotangent is structurally - zero). The probs gradient instead flows back through - :func:`_combine_bwd`. - """ - if ep_active: - # Step 4 inverse: dispatch fwd applied is_forward=True; bwd is - # is_forward=False with the SAME row_id_map. - recv_buffer_rows, hidden = d_sorted_x.shape - d_x_recv, _ = sort_chunks_by_map( - d_sorted_x, - state.local_perm_row_id_map, - None, - recv_buffer_rows, - hidden, - is_forward=False, - ) - # Step 3 inverse: bwd of forward ragged_a2a is the reverse-direction - # ragged_a2a using the SAME params with sender/receiver swapped. - in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( - state.all_shards_tokens_per_expert, shard_id, num_ep - ) - recv_buf_pre = jnp.zeros(pre_a2a_buffer_shape, dtype=d_x_recv.dtype) - d_sorted_x = jax.lax.ragged_all_to_all( - d_x_recv, - recv_buf_pre, - in_off_r, - send_sz_r, - out_off_r, - recv_sz_r, - axis_name=ep_axis, - ) - - # Step 1 inverse: global permute bwd. - if backend is PermutationBackend.PURE_JAX: - # Fwd was: replicated = repeat(inputs_2d, topk, axis=0) - # padded = pad(replicated, (0, padding_size)) - # sorted = padded[sorted_indices] - # Bwd: d_padded = scatter via sorted_indices - # = d_sorted[argsort(sorted_indices)] - # d_replicated = d_padded[:num_real] - # d_inputs_2d = d_replicated.reshape(T, topk, H).sum(axis=1) - sorted_indices = state.sorted_indices - num_real = num_real_tokens - padding = padding_size - topk = num_experts_per_tok - unsort_indices = jnp.argsort(sorted_indices) - d_padded = d_sorted_x[unsort_indices] - if padding > 0: - d_replicated = d_padded[:num_real] - else: - d_replicated = d_padded - num_tokens = inputs_2d_shape[0] - hidden = inputs_2d_shape[-1] - d_inputs_2d = d_replicated.reshape(num_tokens, topk, hidden).sum(axis=1) - return d_inputs_2d - - # TRITON: bwd is unpermute_with_mask_map[_and_unpad]. - num_tokens = inputs_2d_shape[0] - hidden = inputs_2d_shape[-1] - if state.pad_offsets is not None: - d_inputs_2d, _ = unpermute_with_mask_map_and_unpad( - d_sorted_x, - state.row_id_map, - None, - None, - state.pad_offsets, - num_tokens, - num_experts, - hidden, - ) - else: - d_inputs_2d, _ = unpermute_with_mask_map( - d_sorted_x, - state.row_id_map, - None, - None, - num_tokens, - num_experts, - hidden, - ) - return d_inputs_2d - - -# ============================================================================= -# Per-shard body +# Per-shard FFN body (runs inside shard_map) # ============================================================================= -def _body_fwd( # pylint: disable=unused-argument - captured: dict, +def _ffn_fwd_per_shard( + recv_tokens_local: jnp.ndarray, + recv_topk_weights_local: jnp.ndarray, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + wi_0_bias: Optional[jnp.ndarray], + wi_1_bias: Optional[jnp.ndarray], + wo_bias: Optional[jnp.ndarray], *, - # Statics - num_experts: int, - num_experts_per_tok: int, + num_local_experts: int, + slots_per_expert: int, activation_type: str, - score_function: ScoreFunction, - use_pre_softmax: bool, - num_groups: Optional[int], - group_topk: Optional[int], - scaling_factor: float, - aux_loss_coeff: float, - permutation_backend: PermutationBackend, - align_size: int, - gate_inside_vjp: bool, - quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet], - dtype: jnp.dtype, - # EP-only statics - ep_active: bool, - ep_axis: Optional[str], - data_parallelism_axes: Tuple[str, ...], - fsdp_sizes: Tuple[int, ...], - num_ep: int, - num_experts_local: int, - recv_buffer_rows: int, - apply_topk_weights_early: bool = False, -) -> Tuple[jnp.ndarray, jnp.ndarray, dict]: - """Per-shard forward body. Returns ``(output, aux_loss, ctx_dict)``. + apply_topk_weights_early: bool, +): + """Per-shard FFN forward. - ``aux_loss`` is always materialized (zeros scalar when disabled) so - the ``shard_map``'s ``out_specs`` has a static structure. + Operates on the shard-local ``[1, recv_pr, H]`` slice that + ``tex.ep_dispatch`` produces. Returns the expert outputs (shaped + ``[1, recv_pr, H_out]`` so the surrounding ``shard_map`` reassembles + them as ``[num_procs, recv_pr, H_out]``) plus the residuals consumed + by the bwd. """ - if apply_topk_weights_early: - # Requires row-aligned per-token weights at the FFN intermediate; - # only available on the TE EP (tex.ep_dispatch) path. - raise NotImplementedError( - "apply_topk_weights_early=True is supported only with the TE EP " - "(tex.ep_dispatch / tex.ep_combine) backend." - ) - if not gate_inside_vjp: - raise NotImplementedError( - "gate_inside_vjp=False is deferred to a follow-up PR; for now" - " the gate GEMM lives inside the MoE VJP." - ) - - x = captured["inputs"] - gate_kernel = captured["gate_kernel"] - wi_0 = captured["wi_0"] - wi_1 = captured["wi_1"] - wo = captured["wo"] - wi_0_bias = captured.get("wi_0_bias") - wi_1_bias = captured.get("wi_1_bias") - wo_bias = captured.get("wo_bias") - expert_bias = captured.get("expert_bias") - - batch_size, sequence_length, hidden = x.shape - - # ---------------- Stage 1: gate ---------------- - gate_kernel_cast = gate_kernel.astype(x.dtype) - gate_logits = jnp.einsum("bsh,he->bse", x, gate_kernel_cast) # [B, S, E] - # tex.fused_topk_with_score_function_* requires rank-2 input. - logits_2d = gate_logits.reshape(-1, num_experts) - inputs_2d = x.reshape(-1, hidden) - - # ---------------- Stage 2: routing ---------------- - # Under EP, expert_bias is sharded P(ep_axis); the router needs the - # full E-dim view, so all_gather it. - if ep_active and expert_bias is not None: - full_expert_bias = jax.lax.all_gather(expert_bias, axis_name=ep_axis, tiled=True) - else: - full_expert_bias = expert_bias - # Pass an empty array sentinel when expert_bias is unused (the - # underlying primitive expects a real ndarray, not None). - eb_arg = ( - full_expert_bias if full_expert_bias is not None else jnp.zeros((0,), dtype=jnp.float32) - ) - sparse_probs, routing_map, saved_scores = tex.fused_topk_with_score_function_fwd( - logits_2d, - topk=num_experts_per_tok, - use_pre_softmax=use_pre_softmax, - num_groups=-1 if num_groups is None else num_groups, - group_topk=-1 if group_topk is None else group_topk, - scaling_factor=scaling_factor, - score_function=score_function, - expert_bias=eb_arg, - compute_aux_scores=False, - ) - sparse_probs = sparse_probs.astype(dtype) + hidden = recv_tokens_local.shape[-1] + sorted_x = recv_tokens_local.reshape(-1, hidden) + recv_w_flat = recv_topk_weights_local.reshape(-1) + local_group_sizes = jnp.full((num_local_experts,), slots_per_expert, dtype=jnp.int32) - # ---------------- Stage 2b: aux loss ---------------- - if aux_loss_coeff > 0.0: - if ep_active: - collective_axes: Any = ( - ep_axis - if not data_parallelism_axes - else (*data_parallelism_axes, ep_axis) - ) - global_logits_2d = jax.lax.all_gather( - logits_2d, axis_name=collective_axes, axis=0, tiled=True - ) - _, global_routing_map, _ = tex.fused_topk_with_score_function_fwd( - global_logits_2d, - topk=num_experts_per_tok, - use_pre_softmax=use_pre_softmax, - num_groups=-1 if num_groups is None else num_groups, - group_topk=-1 if group_topk is None else group_topk, - scaling_factor=scaling_factor, - score_function=score_function, - expert_bias=eb_arg, - compute_aux_scores=False, - ) - aux_tokens_per_expert = jnp.sum(global_routing_map.astype(jnp.int32), axis=0) - aux_logits_for_score = global_logits_2d - else: - aux_tokens_per_expert = jnp.sum(routing_map.astype(jnp.int32), axis=0) - aux_logits_for_score = logits_2d - # Aux-side scores: clean per-expert scores (no grouped routing, - # no bias). compute_aux_scores=True takes a separate path that - # ignores the grouping knobs. - aux_probs, _aux_routing_map, aux_saved_scores = tex.fused_topk_with_score_function_fwd( - aux_logits_for_score.astype(jnp.float32), - topk=num_experts_per_tok, - use_pre_softmax=False, - num_groups=-1, - group_topk=-1, - scaling_factor=1.0, - score_function=score_function, - expert_bias=jnp.zeros((0,), dtype=jnp.float32), - compute_aux_scores=True, - ) - aux_loss, aux_const_buf = tex.fused_moe_aux_loss_fwd( - aux_probs.astype(jnp.float32), - aux_tokens_per_expert.astype(jnp.int32), - topk=num_experts_per_tok, - coeff=aux_loss_coeff, - ) - else: - aux_loss = jnp.zeros((), dtype=dtype) - aux_const_buf = None - aux_tokens_per_expert = None - aux_logits_for_score = None - aux_saved_scores = None + wi_0 = wi_0.astype(sorted_x.dtype) + wi_1 = wi_1.astype(sorted_x.dtype) + wo = wo.astype(sorted_x.dtype) - # ---------------- Stage 3: dispatch ---------------- - shard_id = jax.lax.axis_index(ep_axis) if ep_active else None - sorted_x, dispatch_state = _dispatch( - inputs_2d, - sparse_probs, - routing_map, - backend=permutation_backend, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - align_size=align_size, - ep_active=ep_active, - ep_axis=ep_axis, - num_ep=num_ep, - recv_buffer_rows=recv_buffer_rows, - shard_id=shard_id, - ) - local_group_sizes = dispatch_state.group_sizes - - # ---------------- Stage 4: per-expert FFN (inlined) ---------------- - q_set_w0, q_set_w1, q_set_wo = quantizer_sets - if q_set_w0 == noop_quantizer_set: - wi_0 = wi_0.astype(sorted_x.dtype) - if q_set_w1 == noop_quantizer_set: - wi_1 = wi_1.astype(sorted_x.dtype) - if q_set_wo == noop_quantizer_set: - wo = wo.astype(sorted_x.dtype) - - # Fused gate+up projection: stack wi_0 / wi_1 on a new axis-(-2) so the - # downstream split is a slice on the (unsharded) stack axis. concat on - # axis=-1 would cross the M axis and force a reshard when M is TP-sharded. - # - # FP8/MXFP8 caveat: per-expert amax is computed over [H, 2, M] rather than - # [H, M] for each of wi_0 / wi_1 separately, so the representable range for - # one half may shift slightly vs. an unfused pair of casts. - inter_M = wi_0.shape[-1] wi_combined = jnp.stack([wi_0, wi_1], axis=-2) wi_combined_bias = ( jnp.stack([wi_0_bias, wi_1_bias], axis=-2) if wi_0_bias is not None else None ) - casted_sorted_x = tex.grouped_quantize(sorted_x, q_set_w0.x, local_group_sizes, flatten_axis=-1) - casted_wi = tex.grouped_quantize(wi_combined, q_set_w0.kernel, flatten_axis=-1) + + q_set = noop_quantizer_set + casted_sorted_x = tex.grouped_quantize(sorted_x, q_set.x, local_group_sizes, flatten_axis=-1) + casted_wi = tex.grouped_quantize(wi_combined, q_set.kernel, flatten_axis=-1) combined_out = tex.grouped_gemm( casted_sorted_x.get_tensor(usage=TensorUsage.LHS), casted_wi.get_tensor(usage=TensorUsage.RHS), @@ -1136,20 +190,22 @@ def _body_fwd( # pylint: disable=unused-argument up_proj_out = combined_out[..., 1, :] casted_sorted_x_lhs_trans = casted_sorted_x.get_tensor(usage=TensorUsage.LHS_TRANS) casted_wi_rhs_trans = casted_wi.get_tensor(usage=TensorUsage.RHS_TRANS) - if isinstance(casted_sorted_x_lhs_trans, ScaledTensor): - casted_sorted_x_lhs_trans = casted_sorted_x_lhs_trans.checkpoint(q_set_w0.x) - if isinstance(casted_wi_rhs_trans, ScaledTensor): - casted_wi_rhs_trans = casted_wi_rhs_trans.checkpoint(q_set_w0.kernel) - # Activation: intermediate = act(gate_proj_out) * up_proj_out act_fn = _convert_to_activation_function(activation_type) intermediate = act_fn(gate_proj_out) * up_proj_out - # GEMM 3: expert_outputs = intermediate @ wo + if apply_topk_weights_early: + # Fold the per-token combine weights into the FFN intermediate; + # the downstream wo GEMM is linear so this is equivalent to the + # late-weighting path, modulo elementwise op fusion gains. + w_b = recv_w_flat[:, None] + mask_b = (recv_w_flat != 0).astype(intermediate.dtype)[:, None] + intermediate = intermediate * w_b * mask_b + casted_intermediate = tex.grouped_quantize( - intermediate, q_set_wo.x, local_group_sizes, flatten_axis=-1 + intermediate, q_set.x, local_group_sizes, flatten_axis=-1 ) - casted_wo = tex.grouped_quantize(wo, q_set_wo.kernel, flatten_axis=-1) + casted_wo = tex.grouped_quantize(wo, q_set.kernel, flatten_axis=-1) expert_outputs = tex.grouped_gemm( casted_intermediate.get_tensor(usage=TensorUsage.LHS), casted_wo.get_tensor(usage=TensorUsage.RHS), @@ -1158,234 +214,100 @@ def _body_fwd( # pylint: disable=unused-argument ) casted_intermediate_lhs_trans = casted_intermediate.get_tensor(usage=TensorUsage.LHS_TRANS) casted_wo_rhs_trans = casted_wo.get_tensor(usage=TensorUsage.RHS_TRANS) - if isinstance(casted_intermediate_lhs_trans, ScaledTensor): - casted_intermediate_lhs_trans = casted_intermediate_lhs_trans.checkpoint(q_set_wo.x) - if isinstance(casted_wo_rhs_trans, ScaledTensor): - casted_wo_rhs_trans = casted_wo_rhs_trans.checkpoint(q_set_wo.kernel) - - # ---------------- Stage 5: combine ---------------- - # Compute per-shard static shape info once and pass through both - # _combine and (later) the bwd helpers via kwargs -- never via the - # state dict, which gets pytree-flattened across shard_map and would - # coerce Python ints into JitTracer 0-d arrays. - _static_shape = _compute_static_shape_info( - batch_size=batch_size, - sequence_length=sequence_length, - hidden=hidden, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - align_size=align_size, - ep_active=ep_active, - num_ep=num_ep, - fsdp_sizes=fsdp_sizes, - recv_buffer_rows=recv_buffer_rows, - ) - # ``expert_outputs_residual`` is the post-A2A FFN-output tensor that - # Step 3 of the combine actually consumed. Saving this (rather than - # the pre-A2A shard-local FFN output) is what makes - # ``_combine_bwd``'s Step-3 inverse see the same value the forward - # Step 3 saw -- otherwise EP + TRITON yields wrong d_expert_outputs. - output, expert_outputs_residual = _combine( - expert_outputs, - dispatch_state, - backend=permutation_backend, - ep_active=ep_active, - batch_size=batch_size, - sequence_length=sequence_length, - dtype=dtype, - num_experts_per_tok=num_experts_per_tok, - num_real_tokens=_static_shape.num_real_tokens, - padding_size=_static_shape.padding_size, - pre_a2a_buffer_shape=_static_shape.pre_a2a_buffer_shape, - ep_axis=ep_axis, - shard_id=shard_id, - num_ep=num_ep, - ) - # ---------------- Build ctx ---------------- - aux_enabled = aux_loss_coeff > 0.0 - ctx = _BodyCtx( - x=x, - gate_kernel=gate_kernel, - logits_2d=logits_2d, - saved_scores=saved_scores, - routing_map=routing_map, - dispatch=dispatch_state, - casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans, - casted_wi_rhs_trans=casted_wi_rhs_trans, - gate_proj_out=gate_proj_out, - up_proj_out=up_proj_out, - casted_intermediate_lhs_trans=casted_intermediate_lhs_trans, - casted_wo_rhs_trans=casted_wo_rhs_trans, - expert_outputs=expert_outputs_residual, - local_group_sizes=local_group_sizes, - expert_bias=expert_bias if expert_bias is not None else None, - aux_const_buf=aux_const_buf if aux_enabled else None, - aux_tokens_per_expert=aux_tokens_per_expert if aux_enabled else None, - aux_logits_for_score=aux_logits_for_score if aux_enabled else None, - aux_saved_scores=aux_saved_scores if aux_enabled else None, + expert_outputs_3d = expert_outputs.reshape(1, expert_outputs.shape[0], expert_outputs.shape[1]) + residuals = ( + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out, + up_proj_out, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, + local_group_sizes, ) - - return output, aux_loss, ctx - - -def _body_bwd( # pylint: disable=unused-argument - ctx: _BodyCtx, - dy_pair: Tuple[jnp.ndarray, jnp.ndarray], + return expert_outputs_3d, residuals + + +def _ffn_bwd_per_shard( + d_expert_outputs_local: jnp.ndarray, + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out: jnp.ndarray, + up_proj_out: jnp.ndarray, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, + local_group_sizes: jnp.ndarray, + recv_topk_weights_local: jnp.ndarray, *, - num_experts: int, - num_experts_per_tok: int, activation_type: str, - score_function: ScoreFunction, - use_pre_softmax: bool, - num_groups: Optional[int], - group_topk: Optional[int], - scaling_factor: float, - aux_loss_coeff: float, - permutation_backend: PermutationBackend, - align_size: int, - gate_inside_vjp: bool, - quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet], - dtype: jnp.dtype, - ep_active: bool, - ep_axis: Optional[str], - data_parallelism_axes: Tuple[str, ...], - fsdp_sizes: Tuple[int, ...], - num_ep: int, - num_experts_local: int, - recv_buffer_rows: int, - # Static side info (kept here rather than inside ctx because they're - # python flags / shapes, not array leaves): - has_wi_bias: bool, - has_wo_bias: bool, - has_expert_bias: bool, - x_shape: Tuple[int, ...], - apply_topk_weights_early: bool = False, -) -> dict: - """Per-shard backward body. Returns a dict of grads keyed identically - to the ``captured`` dict consumed by :func:`_body_fwd`.""" - if apply_topk_weights_early: - raise NotImplementedError( - "apply_topk_weights_early=True is supported only with the TE EP " - "(tex.ep_dispatch / tex.ep_combine) backend." - ) - if not gate_inside_vjp: - raise NotImplementedError("gate_inside_vjp=False is deferred to a follow-up PR.") - - d_output, d_aux_loss = dy_pair - # The fused FFN bwd quantizes via ``q_set_w0`` only (one quantize for - # the [E, H, 2, M] stacked wi tensor and one for the [T, 2, M] stacked dgrad), - # so ``q_set_w1`` is intentionally unused here. - q_set_w0, _q_set_w1, q_set_wo = quantizer_sets - batch_size, sequence_length, hidden = x_shape - shard_id = jax.lax.axis_index(ep_axis) if ep_active else None - - # Recompute per-shard static shape info from existing statics - # (Python ints / int tuples). Plumbed via kwargs to _combine_bwd - # and _dispatch_bwd -- NOT through the ctx dict, because the - # dict gets pytree-flattened across the bwd shard_map's in_specs - # and Python ints would be coerced into JitTracer 0-d arrays - # (breaking ``if padding > 0`` and ``jnp.zeros(shape)`` callsites). - # ``batch_size`` here is the GLOBAL batch size (captured in - # ``x_shape`` by the outer fwd rule), hence ``batch_is_per_shard=False``. - _static_shape = _compute_static_shape_info( - batch_size=batch_size, - sequence_length=sequence_length, - hidden=hidden, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - align_size=align_size, - ep_active=ep_active, - num_ep=num_ep, - fsdp_sizes=fsdp_sizes, - recv_buffer_rows=recv_buffer_rows, - batch_is_per_shard=False, - ) + apply_topk_weights_early: bool, + has_bias: bool, +): + """Per-shard FFN backward. - # Compute per-shard input shape: under the EP shard_map body, the - # gradient tensors live at per-shard shape, so the dispatch_bwd - # reshape target and ``d_x_from_dispatch.reshape(x_shape)`` below - # must use the per-shard shape rather than the captured global - # ``x_shape``. - if ep_active: - dp_size = math.prod(fsdp_sizes) if fsdp_sizes else 1 - per_shard_batch = batch_size // (num_ep * dp_size) - per_shard_x_shape: Tuple[int, ...] = (per_shard_batch, sequence_length, hidden) - else: - per_shard_x_shape = x_shape - - # ---------------- Combine bwd ---------------- - d_expert_outputs, d_routing_weights = _combine_bwd( - d_output, - ctx.dispatch, - ctx.expert_outputs, - backend=permutation_backend, - ep_active=ep_active, - batch_size=batch_size, - sequence_length=sequence_length, - dtype=dtype, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - num_real_tokens=_static_shape.num_real_tokens, - padding_size=_static_shape.padding_size, - post_a2a_buffer_shape=_static_shape.post_a2a_buffer_shape, - ep_axis=ep_axis, - shard_id=shard_id, - num_ep=num_ep, - ) + Mirrors :func:`_ffn_fwd_per_shard`. Returns + ``(d_sorted_x [1, recv_pr, H], d_recv_w [1, recv_pr], d_wi_0, d_wi_1, d_wo, + d_wi_0_bias, d_wi_1_bias, d_wo_bias)``. + """ + d_eo_2d = d_expert_outputs_local.reshape(-1, d_expert_outputs_local.shape[-1]) + recv_w_flat = recv_topk_weights_local.reshape(-1) + q_set = noop_quantizer_set - # ---------------- FFN bwd: GEMM 3 (wo) ---------------- - casted_d_eo = tex.grouped_quantize( - d_expert_outputs, q_set_wo.dgrad, ctx.local_group_sizes, flatten_axis=-1 - ) + # wo bwd + casted_d_eo = tex.grouped_quantize(d_eo_2d, q_set.dgrad, local_group_sizes, flatten_axis=-1) d_intermediate = tex.grouped_gemm( casted_d_eo.get_tensor(usage=TensorUsage.LHS), - ctx.casted_wo_rhs_trans, + casted_wo_rhs_trans, contracting_dims=((1,), (2,)), ) d_wo = tex.grouped_gemm( - ctx.casted_intermediate_lhs_trans, + casted_intermediate_lhs_trans, casted_d_eo.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) - d_wo_bias = tex.grouped_dbias(d_expert_outputs, ctx.local_group_sizes) if has_wo_bias else None + d_wo_bias = tex.grouped_dbias(d_eo_2d, local_group_sizes) if has_bias else None - # ---------------- Activation bwd ---------------- - # intermediate = act(gate_proj_out) * up_proj_out - # d(gate_proj_out) = vjp(act, gate_proj_out)(d_intermediate * up_proj_out) - # d(up_proj_out) = d_intermediate * act(gate_proj_out) act_fn = _convert_to_activation_function(activation_type) - act_gate_proj_out, dact_gate_proj_pullback = jax.vjp(act_fn, ctx.gate_proj_out) + if apply_topk_weights_early: + # intermediate' = intermediate * w * mask. Split the cotangent + # across both factors before the activation bwd consumes it. + w_b = recv_w_flat[:, None] + mask_b = (recv_w_flat != 0).astype(d_intermediate.dtype)[:, None] + intermediate_unweighted = act_fn(gate_proj_out) * up_proj_out + d_recv_w_from_intermediate = jnp.sum( + d_intermediate * intermediate_unweighted * mask_b, axis=-1 + ).astype(recv_w_flat.dtype) + d_intermediate = d_intermediate * w_b * mask_b + else: + d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat) + + # Activation bwd + act_gate_proj_out, dact_gate_proj_pullback = jax.vjp(act_fn, gate_proj_out) d_up_proj_out = d_intermediate * act_gate_proj_out - (d_gate_proj_out,) = dact_gate_proj_pullback(d_intermediate * ctx.up_proj_out) + (d_gate_proj_out,) = dact_gate_proj_pullback(d_intermediate * up_proj_out) - # ---------------- FFN bwd: GEMM 1+2 fused (wi_0 | wi_1) ---------------- - # Mirror of the fwd stack: combine d_gate / d_up on a new axis=-2, - # run one dgrad + one wgrad GEMM, then split on axis=-2. - # d_sorted_x = [d_gate | d_up] @ wi_rhs_trans - # = d_gate @ wi_0^T + d_up @ wi_1^T + # wi bwd (fused gate/up) inter_M = d_gate_proj_out.shape[-1] d_combined = jnp.stack([d_gate_proj_out, d_up_proj_out], axis=-2) casted_d_combined = tex.grouped_quantize( - d_combined, q_set_w0.dgrad, ctx.local_group_sizes, flatten_axis=-1 + d_combined, q_set.dgrad, local_group_sizes, flatten_axis=-1 ) d_sorted_x = tex.grouped_gemm( casted_d_combined.get_tensor(usage=TensorUsage.LHS), - ctx.casted_wi_rhs_trans, + casted_wi_rhs_trans, contracting_dims=((1, 2), (2, 3)), ) d_wi_combined = tex.grouped_gemm( - ctx.casted_sorted_x_lhs_trans, + casted_sorted_x_lhs_trans, casted_d_combined.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) d_wi_0 = d_wi_combined[..., 0, :] d_wi_1 = d_wi_combined[..., 1, :] - if has_wi_bias: - # grouped_dbias requires rank-2 input; reshape around the call. - # M is not TP-sharded on the bias path, so the reshape is free. + if has_bias: + # tex.grouped_dbias takes a rank-2 input; reshape around the call. d_combined_2d = d_combined.reshape(d_combined.shape[0], -1) - d_wi_combined_bias_2d = tex.grouped_dbias(d_combined_2d, ctx.local_group_sizes) + d_wi_combined_bias_2d = tex.grouped_dbias(d_combined_2d, local_group_sizes) d_wi_combined_bias = d_wi_combined_bias_2d.reshape( *d_wi_combined_bias_2d.shape[:-1], 2, inter_M ) @@ -1395,292 +317,26 @@ def _body_bwd( # pylint: disable=unused-argument d_wi_0_bias = None d_wi_1_bias = None - # ---------------- Dispatch bwd ---------------- - inputs_2d_shape = (per_shard_x_shape[0] * per_shard_x_shape[1], hidden) - d_inputs_2d = _dispatch_bwd( - d_sorted_x, - ctx.dispatch, - inputs_2d_shape=inputs_2d_shape, - backend=permutation_backend, - ep_active=ep_active, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - num_real_tokens=_static_shape.num_real_tokens, - padding_size=_static_shape.padding_size, - pre_a2a_buffer_shape=_static_shape.pre_a2a_buffer_shape, - ep_axis=ep_axis, - shard_id=shard_id, - num_ep=num_ep, - ) - d_x_from_dispatch = d_inputs_2d.reshape(per_shard_x_shape) - - # ---------------- Routing bwd ---------------- - # The probs cotangent comes from _combine_bwd. For PURE_JAX it's the - # cotangent of routing_weights (post-routing_map_to_selected_experts); - # we need to bridge back to sparse_probs. For TRITON it's already the - # cotangent of merging_probs == sparse_probs. - if d_routing_weights is not None: - if permutation_backend is PermutationBackend.PURE_JAX: - # routing_map_to_selected_experts: - # selected_experts = argsort(routing_map)[..., -topk:] - # weights = take_along_axis(sparse_probs, selected_experts, axis=-1) - # routing_map is bool (non-diff); the gradient of weights - # w.r.t. sparse_probs is a scatter-into-zero along the - # selected_experts indices. - selected_experts = jnp.argsort(ctx.routing_map, axis=-1)[..., -num_experts_per_tok:] - d_sparse_probs = jnp.zeros_like(ctx.saved_scores).astype(d_routing_weights.dtype) - d_sparse_probs = jnp.take_along_axis(d_sparse_probs, selected_experts, axis=-1) - # Actually scatter: build via jnp.zeros + .at[].set - d_sparse_probs = jnp.zeros(ctx.routing_map.shape, dtype=d_routing_weights.dtype) - d_sparse_probs = d_sparse_probs.at[ - jnp.arange(ctx.routing_map.shape[0])[:, None], selected_experts - ].set(d_routing_weights) - else: - d_sparse_probs = d_routing_weights.astype(jnp.float32) - else: - d_sparse_probs = jnp.zeros(ctx.routing_map.shape, dtype=jnp.float32) - - # Topk bwd primitive: returns d_logits (no d_expert_bias). - d_logits_2d_main = tex.fused_topk_with_score_function_bwd( - ctx.routing_map, - ctx.saved_scores, - d_sparse_probs.astype(ctx.saved_scores.dtype), - topk=num_experts_per_tok, - use_pre_softmax=use_pre_softmax, - scaling_factor=scaling_factor, - score_function=score_function, - compute_aux_scores=False, - ) - - # ---------------- Aux loss bwd ---------------- - if aux_loss_coeff > 0.0: - # Step 1: aux_loss bwd -> d_aux_probs - aux_num_tokens = ctx.aux_logits_for_score.shape[0] - d_aux_probs = tex.fused_moe_aux_loss_bwd( - ctx.aux_const_buf, - ctx.aux_tokens_per_expert.astype(jnp.int32), - d_aux_loss.reshape(()), - num_tokens=aux_num_tokens, - ) - # Step 2: aux-side topk bwd (compute_aux_scores=True path). - # The routing_map argument is ignored in this branch (the kernel - # uses saved_scores); pass any shape-correct integer tensor. - d_aux_logits = tex.fused_topk_with_score_function_bwd( - jnp.zeros(ctx.aux_logits_for_score.shape, dtype=jnp.bool_), - ctx.aux_saved_scores, - d_aux_probs.astype(ctx.aux_saved_scores.dtype), - topk=num_experts_per_tok, - use_pre_softmax=False, - scaling_factor=1.0, - score_function=score_function, - compute_aux_scores=True, - ) - # Inverse of the fwd tiled all_gather along - # ``(*data_parallelism_axes, ep_axis)``: pick out this shard's - # local rows from the global cotangent. JAX's tiled all_gather - # is row-major over the axis-name tuple, so the shard at mesh - # position (i_a, i_b, ...) writes to a contiguous row block - # starting at flat_index * local_T. - if ep_active: - local_T_aux = ctx.logits_2d.shape[0] - flat_shard = 0 - for ax, sz in zip(data_parallelism_axes, fsdp_sizes): - flat_shard = flat_shard * sz + jax.lax.axis_index(ax) - flat_shard = flat_shard * num_ep + shard_id - d_aux_logits_local = jax.lax.dynamic_slice( - d_aux_logits.astype(ctx.logits_2d.dtype), - start_indices=(flat_shard * local_T_aux, 0), - slice_sizes=(local_T_aux, num_experts), - ) - else: - d_aux_logits_local = d_aux_logits.astype(d_logits_2d_main.dtype) - d_logits_2d = d_logits_2d_main + d_aux_logits_local.astype(d_logits_2d_main.dtype) - else: - d_logits_2d = d_logits_2d_main - - # ---------------- Gate bwd ---------------- - d_gate_logits = d_logits_2d.reshape(per_shard_x_shape[0], per_shard_x_shape[1], num_experts) - gate_kernel_cast = ctx.gate_kernel.astype(ctx.x.dtype) - d_x_from_gate = jnp.einsum("bse,he->bsh", d_gate_logits, gate_kernel_cast) - d_gate_kernel = jnp.einsum("bsh,bse->he", ctx.x, d_gate_logits).astype(ctx.gate_kernel.dtype) - d_x = d_x_from_gate + d_x_from_dispatch - - # Reduce per-rank partial contributions to match the out_specs - # declared by _build_grads_specs: - # gate_kernel : P() -> psum across (ep, *fsdp) - # wi_0/wi_1/wo : P(ep_axis, ...) -> psum across (*fsdp) only - # inputs : P((ep, fsdp), ...) -> already shard-local, no reduction - if ep_active: - replicate_all = (ep_axis,) + tuple(data_parallelism_axes) - d_gate_kernel = jax.lax.psum(d_gate_kernel, axis_name=replicate_all) - if data_parallelism_axes: - replicate_fsdp = tuple(data_parallelism_axes) - d_wi_0 = jax.lax.psum(d_wi_0, axis_name=replicate_fsdp) - d_wi_1 = jax.lax.psum(d_wi_1, axis_name=replicate_fsdp) - d_wo = jax.lax.psum(d_wo, axis_name=replicate_fsdp) - if has_wi_bias: - d_wi_0_bias = jax.lax.psum(d_wi_0_bias, axis_name=replicate_fsdp) - d_wi_1_bias = jax.lax.psum(d_wi_1_bias, axis_name=replicate_fsdp) - if has_wo_bias: - d_wo_bias = jax.lax.psum(d_wo_bias, axis_name=replicate_fsdp) - - grads: dict = { - "inputs": d_x, - "gate_kernel": d_gate_kernel, - "wi_0": d_wi_0, - "wi_1": d_wi_1, - "wo": d_wo, - } - if has_wi_bias: - grads["wi_0_bias"] = d_wi_0_bias - grads["wi_1_bias"] = d_wi_1_bias - if has_wo_bias: - grads["wo_bias"] = d_wo_bias - if has_expert_bias: - # expert_bias has no gradient through topk (the topk bwd returns - # None for it). Emit a structural zero so the outer rule has - # something to package. - grads["expert_bias"] = jnp.zeros_like(ctx.expert_bias) - return grads - - -# ============================================================================= -# Spec builders for shard_map (lockstep with ctx_dict / captured_dict) -# ============================================================================= - - -def _build_in_specs( - ep_axis: str, - batch_pspec_axis: Any, - *, - has_bias: bool, - has_expert_bias: bool, -) -> dict: - """Build the ``in_specs`` dict for the EP fwd shard_map.""" - specs: dict = { - "inputs": P(batch_pspec_axis, None, None), - "gate_kernel": P(), - "wi_0": P(ep_axis, None, None), - "wi_1": P(ep_axis, None, None), - "wo": P(ep_axis, None, None), - } - if has_bias: - for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): - specs[name] = P(ep_axis, None) - if has_expert_bias: - specs["expert_bias"] = P(ep_axis) - return specs - - -def _build_dispatch_specs( # pylint: disable=unused-argument - ep_axis: str, - *, - backend: PermutationBackend, - ep_active: bool, - align_size: int, -) -> _DispatchState: - """Build the shard_map ``out_specs`` for the dispatch state. - - Returns a :data:`_DispatchState` (either :class:`_PureJaxDispatchState` - or :class:`_TritonDispatchState`) whose fields are - :class:`PartitionSpec` placeholders. Optional fields are set to - ``P()`` when populated by :func:`_dispatch` and to ``None`` when - intentionally omitted, so the spec's pytree structure mirrors the - value's structure leaf-for-leaf. - """ - ep_all = P() if ep_active else None - ep_local = P() if ep_active else None - if backend is PermutationBackend.PURE_JAX: - return _PureJaxDispatchState( - group_sizes=P(), - sorted_indices=P(), - routing_weights=P(), - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, - ) - return _TritonDispatchState( - group_sizes=P(), - row_id_map=P(), - pad_offsets=P() if align_size > 0 else None, - merging_probs=P(), - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, - ) - - -def _build_ctx_specs( # pylint: disable=unused-argument - ep_axis: str, - batch_pspec_axis: Any, - *, - backend: PermutationBackend, - ep_active: bool, - has_bias: bool, - has_expert_bias: bool, - aux_loss_enabled: bool, - align_size: int, -) -> _BodyCtx: - """Build the spec :class:`_BodyCtx` mirroring :func:`_body_fwd`'s ctx. - - Fields gated off by the static config (``expert_bias``, ``aux_*``) - are ``None`` here so the spec pytree matches the value pytree - leaf-for-leaf. - """ - return _BodyCtx( - # Per-shard local activations along the batch axis. - x=P(batch_pspec_axis, None, None), - gate_kernel=P(), - logits_2d=P(batch_pspec_axis, None), - saved_scores=P(batch_pspec_axis, None), - routing_map=P(batch_pspec_axis, None), - dispatch=_build_dispatch_specs( - ep_axis, backend=backend, ep_active=ep_active, align_size=align_size - ), - # FFN residuals: the LHS_TRANS / RHS_TRANS variants of - # grouped_quantize have leading "rows"/"experts" dims that are - # already shard-local (post-dispatch). Use P(ep_axis,...) on - # leading dim; that works whether the leaf is a plain ndarray - # or a ScaledTensor (shard_map applies the spec leaf-wise to - # the registered ScaledTensor pytree). - casted_sorted_x_lhs_trans=P(), - casted_wi_rhs_trans=P(ep_axis, None, None), - gate_proj_out=P(), - up_proj_out=P(), - casted_intermediate_lhs_trans=P(), - casted_wo_rhs_trans=P(ep_axis, None, None), - expert_outputs=P(), - local_group_sizes=P(), - expert_bias=P(ep_axis) if has_expert_bias else None, - aux_const_buf=P() if aux_loss_enabled else None, - aux_tokens_per_expert=P() if aux_loss_enabled else None, - aux_logits_for_score=P() if aux_loss_enabled else None, - aux_saved_scores=P() if aux_loss_enabled else None, - ) - - -def _build_grads_specs( - ep_axis: str, - batch_pspec_axis: Any, - *, - has_bias: bool, - has_expert_bias: bool, -) -> dict: - """Spec dict for the grads dict returned by :func:`_body_bwd`.""" - return _build_in_specs( - ep_axis, - batch_pspec_axis, - has_bias=has_bias, - has_expert_bias=has_expert_bias, + d_sorted_x_3d = d_sorted_x.reshape(1, d_sorted_x.shape[0], d_sorted_x.shape[1]) + d_recv_w_3d = d_recv_w_from_intermediate.reshape(1, -1) + return ( + d_sorted_x_3d, + d_recv_w_3d, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, ) # ============================================================================= -# Top-level VJP rules +# Full fwd / bwd rules (custom_vjp halves) # ============================================================================= -def _moe_fwd_rule( # pylint: disable=unused-argument - # Args MUST match the positional order of ``_moe`` (diff first, - # then nondiff). See ``_moe_bwd_rule`` for the opposite convention. +def _moe_fwd_rule( x, gate_kernel, wi_0, @@ -1699,109 +355,71 @@ def _moe_fwd_rule( # pylint: disable=unused-argument group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, apply_topk_weights_early, + align_size, ): - x = with_sharding_constraint_by_logical_axes(x, input_axes) - ep_active = ep_axis is not None - body_kwargs = { - "num_experts": num_experts, - "num_experts_per_tok": num_experts_per_tok, - "activation_type": activation_type, - "score_function": score_function, - "use_pre_softmax": use_pre_softmax, - "num_groups": num_groups, - "group_topk": group_topk, - "scaling_factor": scaling_factor, - "aux_loss_coeff": aux_loss_coeff, - "permutation_backend": permutation_backend, - "align_size": align_size, - "gate_inside_vjp": gate_inside_vjp, - "quantizer_sets": quantizer_sets, - "dtype": dtype, - "ep_axis": ep_axis, - "data_parallelism_axes": data_parallelism_axes, - "apply_topk_weights_early": apply_topk_weights_early, - } - captured: dict = { - "inputs": x, - "gate_kernel": gate_kernel, - "wi_0": wi_0, - "wi_1": wi_1, - "wo": wo, - } - has_bias = wi_0_bias is not None - has_expert_bias = expert_bias is not None - if has_bias: - captured["wi_0_bias"] = wi_0_bias - captured["wi_1_bias"] = wi_1_bias - captured["wo_bias"] = wo_bias - if has_expert_bias: - captured["expert_bias"] = expert_bias - - if not ep_active: - output, aux_loss, ctx = _body_fwd( - captured, - **body_kwargs, - ep_active=False, - fsdp_sizes=(), - num_ep=1, - num_experts_local=num_experts, - recv_buffer_rows=0, - ) - # Carry static side info to the bwd rule alongside ctx. These - # are Python ints/bools/tuples (NOT pytree leaves), so we - # bundle them as a plain dict rather than putting them on the - # ``_BodyCtx`` NamedTuple where shard_map would try to flatten - # them into JitTracers. - static = { - "has_wi_bias": has_bias, - "has_wo_bias": has_bias, - "has_expert_bias": has_expert_bias, - "x_shape": x.shape, - "num_experts_local": num_experts, - "recv_buffer_rows": 0, - } - return (output, aux_loss), (ctx, static) - - # ---------------- EP path ---------------- + """Forward: gate -> topk -> ep_dispatch -> shard_map(FFN) -> ep_combine. + + Returns ``(output, aux_loss)``. ``aux_loss`` is a zero scalar when + ``aux_loss_coeff == 0``. + """ + del gate_kernel_axes, wi_kernel_axes, wo_kernel_axes # used in bwd only from jax.experimental.shard_map import shard_map + x = with_sharding_constraint_by_logical_axes(x, input_axes) + mesh = _get_mesh() if mesh is None or mesh.empty: - raise ValueError("moe(...) requires an active jax.sharding.Mesh when ep_axis is set.") + raise ValueError("moe(...) requires an active jax.sharding.Mesh.") + if ep_axis is None: + raise ValueError("moe(...) requires ep_axis to be set (TE EP backend).") num_ep = mesh.shape[ep_axis] if num_experts % num_ep != 0: raise ValueError(f"num_experts={num_experts} must be divisible by EP size={num_ep}") - num_experts_local = num_experts // num_ep + num_local_experts = num_experts // num_ep - # Reject overlapping EP / FSDP axes. Listing ep_axis in - # data_parallelism_axes would produce a duplicate-axis PartitionSpec - # ((ep, ep, ...)) which JAX rejects, and would also double-count - # num_ep in dp_size (under-sizing recv_buffer_rows by a factor of - # num_ep). Catch it up front with a clear error. + dp_size = 1 for ax in data_parallelism_axes: - if ax not in mesh.shape: - raise ValueError( - f"data_parallelism_axes contains {ax!r} but mesh has" - f" axes {tuple(mesh.shape.keys())}" - ) - if ax == ep_axis: - raise ValueError( - f"data_parallelism_axes={data_parallelism_axes!r} contains the EP" - f" axis {ep_axis!r}; EP is implicit in the batch sharding and must" - " not also be listed as a data-parallel axis." - ) + dp_size *= mesh.shape[ax] + num_procs = num_ep * dp_size + + B, S, H = x.shape + K = num_experts_per_tok + if B % num_procs != 0: + raise ValueError(f"batch={B} not divisible by ep*dp={num_procs}") + + # Per-rank receive capacity (dropless): every rank may receive all of one + # replica's K-expanded tokens. ``slots_per_expert`` is rounded up to a + # multiple of ``align_size`` (FP8 recipes typically need 128 here); the + # rounded value is what we feed to ``tex.ep_prepare`` as the + # ``dispatch_output_per_expert_alignment`` so each local expert's slot + # block starts on the alignment boundary that grouped_gemm expects. + natural_recv_pr = (B // dp_size) * S * K + natural_spe = (natural_recv_pr + num_local_experts - 1) // num_local_experts + if align_size > 0: + slots_per_expert = ((natural_spe + align_size - 1) // align_size) * align_size + else: + slots_per_expert = natural_spe + recv_pr = num_local_experts * slots_per_expert + # Per-rank input token count: B/num_procs rows x S tokens. The bootstrap + # uses this to size the dispatch send buffer; recv_pr above sizes the + # per-rank receive buffer. + max_tokens_per_rank = (B // num_procs) * S + + _te_ep_bootstrap_if_needed( + num_experts=num_experts, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_pr, + hidden_dim=H, + ep_size=num_ep, + ) if not data_parallelism_axes: batch_pspec_axis: Any = ep_axis @@ -1810,64 +428,211 @@ def _moe_fwd_rule( # pylint: disable=unused-argument # consecutive global ranks (dp_color = rank // ep_size), so the # comm only stays within one model replica under (outer_dp, ep). batch_pspec_axis = (*data_parallelism_axes, ep_axis) - dp_size = 1 - for ax in data_parallelism_axes: - dp_size *= mesh.shape[ax] + ep3_spec = P(batch_pspec_axis, None, None) + ep2_spec = P(batch_pspec_axis, None) + x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, ep3_spec)) - global_batch_size, sequence_length, _hidden = x.shape - topk = num_experts_per_tok - if global_batch_size % (num_ep * dp_size) != 0: - raise ValueError(f"batch={global_batch_size} not divisible by ep*dp={num_ep * dp_size}") - recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk - if align_size > 0: - recv_buffer_rows += num_experts * (align_size - 1) + # ---------------- Gate (global view) ---------------- + gate_kernel_cast = gate_kernel.astype(x.dtype) + gate_logits = jnp.einsum("bsh,he->bse", x, gate_kernel_cast) + logits_2d = gate_logits.reshape(-1, num_experts) - in_specs = _build_in_specs( - ep_axis, - batch_pspec_axis, - has_bias=has_bias, - has_expert_bias=has_expert_bias, + # ---------------- Routing (global view) ---------------- + # expert_bias is an empty (shape-(0,)) sentinel when the caller did + # not enable it; the primitive treats that as "no bias". + eb_arg = expert_bias if expert_bias.shape != (0,) else jnp.zeros((0,), dtype=jnp.float32) + sparse_probs, routing_map, saved_scores = tex.fused_topk_with_score_function_fwd( + logits_2d, + topk=K, + use_pre_softmax=use_pre_softmax, + num_groups=-1 if num_groups is None else num_groups, + group_topk=-1 if group_topk is None else group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=eb_arg, + compute_aux_scores=False, ) - output_spec = P(batch_pspec_axis, None, None) - aux_spec = P() - ctx_spec = _build_ctx_specs( - ep_axis, - batch_pspec_axis, - backend=permutation_backend, - ep_active=True, - has_bias=has_bias, - has_expert_bias=has_expert_bias, - aux_loss_enabled=(aux_loss_coeff > 0.0), - align_size=align_size, + sparse_probs = sparse_probs.astype(dtype) + + # ---------------- Aux loss (global view, replicated) ---------------- + # ``fused_moe_aux_loss_fwd`` sums probs and tokens_per_expert across + # all tokens, which is wrong when T is sharded. Force-replicate the + # gate logits and recompute the routing map at global view so the + # kernel sees a complete [T_global, E] tensor. The replication is a + # single all-gather over (*dp, ep) and lives off the dispatch + # critical path. + if aux_loss_coeff > 0.0: + global_logits_2d = jax.lax.with_sharding_constraint( + logits_2d, NamedSharding(mesh, P()) + ) + _, global_routing_map, _ = tex.fused_topk_with_score_function_fwd( + global_logits_2d, + topk=K, + use_pre_softmax=use_pre_softmax, + num_groups=-1 if num_groups is None else num_groups, + group_topk=-1 if group_topk is None else group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=eb_arg, + compute_aux_scores=False, + ) + aux_tokens_per_expert = jnp.sum(global_routing_map.astype(jnp.int32), axis=0) + # compute_aux_scores=True takes a separate kernel path: clean + # per-expert softmax, no grouping / bias / scaling. + aux_probs, _aux_rm, aux_saved_scores = tex.fused_topk_with_score_function_fwd( + global_logits_2d.astype(jnp.float32), + topk=K, + use_pre_softmax=False, + num_groups=-1, + group_topk=-1, + scaling_factor=1.0, + score_function=score_function, + expert_bias=jnp.zeros((0,), dtype=jnp.float32), + compute_aux_scores=True, + ) + aux_loss, aux_const_buf = tex.fused_moe_aux_loss_fwd( + aux_probs.astype(jnp.float32), + aux_tokens_per_expert.astype(jnp.int32), + topk=K, + coeff=aux_loss_coeff, + ) + aux_loss = aux_loss.astype(dtype) + else: + aux_loss = jnp.zeros((), dtype=dtype) + aux_const_buf = None + aux_tokens_per_expert = None + aux_saved_scores = None + + # ---------------- Routing -> (topk_idx, topk_w) at 3D ---------------- + # argsort on a bool tensor places True last (False=0 < True=1), so the + # last K indices are the selected expert IDs. + selected_experts = jnp.argsort(routing_map, axis=-1)[..., -K:] + routing_weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) + topk_idx_3d = selected_experts.reshape(B, S, K).astype(jnp.int32) + topk_w_3d = routing_weights.reshape(B, S, K).astype(jnp.float32) + + # ---------------- TE EP dispatch (global view) ---------------- + token_counts, handle = tex.ep_prepare(topk_idx_3d, slots_per_expert) + recv_tokens, recv_topk_weights, handle = tex.ep_dispatch_fwd( + handle, topk_idx_3d, x, topk_w_3d, recv_pr + ) + recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3_spec)) + recv_topk_weights = jax.lax.with_sharding_constraint( + recv_topk_weights, NamedSharding(mesh, ep2_spec) + ) + + # ---------------- FFN (per-shard via shard_map) ---------------- + has_bias = wi_0_bias is not None + kernel_spec = P(ep_axis, None, None) + bias_spec = P(ep_axis, None) if has_bias else None + ffn_in_specs = (ep3_spec, ep2_spec, kernel_spec, kernel_spec, kernel_spec) + ffn_in_args = [recv_tokens, recv_topk_weights, wi_0, wi_1, wo] + if has_bias: + ffn_in_specs = ffn_in_specs + (bias_spec, bias_spec, bias_spec) + ffn_in_args.extend([wi_0_bias, wi_1_bias, wo_bias]) + + # FFN residuals live entirely on the local ep rank, so the leading + # "experts" / "rows" dims map to P() (already shard-local). + residuals_spec = ( + P(), # casted_sorted_x_lhs_trans + P(ep_axis, None, None), # casted_wi_rhs_trans + P(), # gate_proj_out + P(), # up_proj_out + P(), # casted_intermediate_lhs_trans + P(ep_axis, None, None), # casted_wo_rhs_trans + P(), # local_group_sizes ) + out_specs = (ep3_spec, residuals_spec) - _fsdp_sizes: Tuple[int, ...] = tuple(mesh.shape[ax] for ax in data_parallelism_axes) - - def _shardmap_body(captured_local): - return _body_fwd( - captured_local, - **body_kwargs, - ep_active=True, - fsdp_sizes=_fsdp_sizes, - num_ep=num_ep, - num_experts_local=num_experts_local, - recv_buffer_rows=recv_buffer_rows, + def _body(*args): + if has_bias: + (r_tok, r_w, w0, w1, w_o, w0b, w1b, wob) = args + else: + (r_tok, r_w, w0, w1, w_o) = args + w0b = w1b = wob = None + return _ffn_fwd_per_shard( + r_tok, + r_w, + w0, + w1, + w_o, + w0b, + w1b, + wob, + num_local_experts=num_local_experts, + slots_per_expert=slots_per_expert, + activation_type=activation_type, + apply_topk_weights_early=apply_topk_weights_early, ) - output, aux_loss, ctx = shard_map( - _shardmap_body, + expert_outputs, ffn_residuals = shard_map( + _body, mesh=mesh, - in_specs=(in_specs,), - out_specs=(output_spec, aux_spec, ctx_spec), + in_specs=ffn_in_specs, + out_specs=out_specs, check_rep=False, - )(captured) + )(*ffn_in_args) + expert_outputs = jax.lax.with_sharding_constraint( + expert_outputs, NamedSharding(mesh, ep3_spec) + ) + + # ---------------- TE EP combine (global view) ---------------- + out_partition_spec = (batch_pspec_axis, None, None) + if apply_topk_weights_early: + # expert_outputs is already weighted upstream. + output = tex.ep_combine_fwd( + handle, + expert_outputs, + num_local_tokens=(B, S), + out_partition_spec=out_partition_spec, + ) + else: + w = recv_topk_weights[..., None] + mask = (recv_topk_weights != 0).astype(expert_outputs.dtype)[..., None] + weighted = expert_outputs * w * mask + output = tex.ep_combine_fwd( + handle, + weighted, + num_local_tokens=(B, S), + out_partition_spec=out_partition_spec, + ) + + ( + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out, + up_proj_out, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, + local_group_sizes, + ) = ffn_residuals + + ctx = _Ctx( + x=x, + gate_kernel=gate_kernel, + expert_bias=expert_bias, + logits_2d=logits_2d, + saved_scores=saved_scores, + routing_map=routing_map, + handle=handle, + token_counts=token_counts, + recv_topk_weights=recv_topk_weights, + casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans, + casted_wi_rhs_trans=casted_wi_rhs_trans, + gate_proj_out=gate_proj_out, + up_proj_out=up_proj_out, + casted_intermediate_lhs_trans=casted_intermediate_lhs_trans, + casted_wo_rhs_trans=casted_wo_rhs_trans, + expert_outputs=expert_outputs, + local_group_sizes=local_group_sizes, + aux_const_buf=aux_const_buf, + aux_tokens_per_expert=aux_tokens_per_expert, + aux_saved_scores=aux_saved_scores, + ) static = { - "has_wi_bias": has_bias, - "has_wo_bias": has_bias, - "has_expert_bias": has_expert_bias, + "has_bias": has_bias, "x_shape": x.shape, - "num_experts_local": num_experts_local, - "recv_buffer_rows": recv_buffer_rows, + "recv_pr": recv_pr, } return (output, aux_loss), (ctx, static) @@ -1882,133 +647,260 @@ def _moe_bwd_rule( group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, apply_topk_weights_early, - ctx, - dy_pair, + align_size, + residuals, + cotangents, ): - ctx, static = ctx # split tensor residuals from static side info - has_wi_bias = static["has_wi_bias"] - has_wo_bias = static["has_wo_bias"] - has_expert_bias = static["has_expert_bias"] - x_shape = static["x_shape"] - num_experts_local = static["num_experts_local"] - recv_buffer_rows = static["recv_buffer_rows"] + """Backward mirror of :func:`_moe_fwd_rule`.""" + del num_groups, group_topk, dtype, align_size # captured in residuals / unused in bwd + from jax.experimental.shard_map import shard_map - ep_active = ep_axis is not None - mesh = _get_mesh() if ep_active else None - fsdp_sizes: Tuple[int, ...] = ( - tuple(mesh.shape[ax] for ax in data_parallelism_axes) if ep_active else () - ) - body_kwargs = { - "num_experts": num_experts, - "num_experts_per_tok": num_experts_per_tok, - "activation_type": activation_type, - "score_function": score_function, - "use_pre_softmax": use_pre_softmax, - "num_groups": num_groups, - "group_topk": group_topk, - "scaling_factor": scaling_factor, - "aux_loss_coeff": aux_loss_coeff, - "permutation_backend": permutation_backend, - "align_size": align_size, - "gate_inside_vjp": gate_inside_vjp, - "quantizer_sets": quantizer_sets, - "dtype": dtype, - "ep_axis": ep_axis, - "data_parallelism_axes": data_parallelism_axes, - "fsdp_sizes": fsdp_sizes, - "num_ep": 1 if not ep_active else mesh.shape[ep_axis], - "num_experts_local": num_experts_local, - "recv_buffer_rows": recv_buffer_rows, - "has_wi_bias": has_wi_bias, - "has_wo_bias": has_wo_bias, - "has_expert_bias": has_expert_bias, - "x_shape": x_shape, - "apply_topk_weights_early": apply_topk_weights_early, - } + d_output, d_aux_loss = cotangents - if not ep_active: - grads = _body_bwd(ctx, dy_pair, ep_active=False, **body_kwargs) - # Apply sharding constraints on grads. - grads["gate_kernel"] = with_sharding_constraint_by_logical_axes( - grads["gate_kernel"], gate_kernel_axes - ) - grads["wi_0"] = with_sharding_constraint_by_logical_axes(grads["wi_0"], wi_kernel_axes) - grads["wi_1"] = with_sharding_constraint_by_logical_axes(grads["wi_1"], wi_kernel_axes) - grads["wo"] = with_sharding_constraint_by_logical_axes(grads["wo"], wo_kernel_axes) - grads["inputs"] = with_sharding_constraint_by_logical_axes(grads["inputs"], input_axes) - return _grads_dict_to_tuple(grads, has_wi_bias, has_wo_bias, has_expert_bias) + ctx, static = residuals + has_bias = static["has_bias"] + x_shape = static["x_shape"] + recv_pr = static["recv_pr"] - from jax.experimental.shard_map import shard_map + mesh = _get_mesh() + if mesh is None or mesh.empty: + raise ValueError("moe(...) requires an active jax.sharding.Mesh.") + num_ep = mesh.shape[ep_axis] + dp_size = 1 + for ax in data_parallelism_axes: + dp_size *= mesh.shape[ax] + B, S, _ = x_shape + K = num_experts_per_tok if not data_parallelism_axes: batch_pspec_axis: Any = ep_axis else: - # ep must be innermost: ep_bootstrap forms NCCL EP comms from - # consecutive global ranks (dp_color = rank // ep_size), so the - # comm only stays within one model replica under (outer_dp, ep). batch_pspec_axis = (*data_parallelism_axes, ep_axis) - ctx_spec = _build_ctx_specs( - ep_axis, - batch_pspec_axis, - backend=permutation_backend, - ep_active=True, - has_bias=has_wi_bias, - has_expert_bias=has_expert_bias, - aux_loss_enabled=(aux_loss_coeff > 0.0), - align_size=align_size, + ep3_spec = P(batch_pspec_axis, None, None) + ep2_spec = P(batch_pspec_axis, None) + out_partition_spec = (batch_pspec_axis, None, None) + + # ---------------- Combine bwd (global view) ---------------- + d_output = jax.lax.with_sharding_constraint(d_output, NamedSharding(mesh, ep3_spec)) + grad_pre_combine = tex.ep_combine_bwd(ctx.handle, d_output, recv_pr) + grad_pre_combine = jax.lax.with_sharding_constraint( + grad_pre_combine, NamedSharding(mesh, ep3_spec) ) - dy_specs = (P(batch_pspec_axis, None, None), P()) - grads_spec = _build_grads_specs( - ep_axis, batch_pspec_axis, has_bias=has_wi_bias, has_expert_bias=has_expert_bias + + if apply_topk_weights_early: + # combine_fwd consumed already-weighted expert_outputs; the recv_w + # cotangent flows through the early-weighting step inside the FFN bwd. + d_expert_outputs = grad_pre_combine + d_recv_w_from_combine = jnp.zeros_like(ctx.recv_topk_weights) + else: + # combine_fwd consumed weighted = expert_out * w * mask; + # split the cotangent across both factors. + w = ctx.recv_topk_weights[..., None] + mask = (ctx.recv_topk_weights != 0).astype(grad_pre_combine.dtype)[..., None] + d_expert_outputs = grad_pre_combine * w * mask + d_recv_w_from_combine = (grad_pre_combine * ctx.expert_outputs * mask).sum(axis=-1) + d_recv_w_from_combine = d_recv_w_from_combine.astype(ctx.recv_topk_weights.dtype) + + # ---------------- FFN bwd (per-shard via shard_map) ---------------- + kernel_spec = P(ep_axis, None, None) + bias_spec = P(ep_axis, None) if has_bias else None + + bwd_in_specs = ( + ep3_spec, # d_expert_outputs + P(), # casted_sorted_x_lhs_trans + P(ep_axis, None, None), # casted_wi_rhs_trans + P(), # gate_proj_out + P(), # up_proj_out + P(), # casted_intermediate_lhs_trans + P(ep_axis, None, None), # casted_wo_rhs_trans + P(), # local_group_sizes + ep2_spec, # recv_topk_weights + ) + bwd_in_args = [ + d_expert_outputs, + ctx.casted_sorted_x_lhs_trans, + ctx.casted_wi_rhs_trans, + ctx.gate_proj_out, + ctx.up_proj_out, + ctx.casted_intermediate_lhs_trans, + ctx.casted_wo_rhs_trans, + ctx.local_group_sizes, + ctx.recv_topk_weights, + ] + bwd_out_specs = ( + ep3_spec, # d_sorted_x + ep2_spec, # d_recv_w_from_intermediate + kernel_spec, # d_wi_0 + kernel_spec, # d_wi_1 + kernel_spec, # d_wo + bias_spec if has_bias else None, # d_wi_0_bias + bias_spec if has_bias else None, # d_wi_1_bias + bias_spec if has_bias else None, # d_wo_bias ) - def _bwd_body(ctx_local, dy_local): - return _body_bwd(ctx_local, dy_local, ep_active=True, **body_kwargs) + def _bwd_body(*args): + ( + d_sorted_x_3d, + d_recv_w_3d, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, + ) = _ffn_bwd_per_shard( + *args, + activation_type=activation_type, + apply_topk_weights_early=apply_topk_weights_early, + has_bias=has_bias, + ) + # Weight grads accumulate per-DP-shard inside the body; psum across + # DP axes so each replica sees the full sum (matches out_specs + # P(ep_axis, ...) which is DP-replicated). + if data_parallelism_axes: + dp = tuple(data_parallelism_axes) + d_wi_0 = jax.lax.psum(d_wi_0, axis_name=dp) + d_wi_1 = jax.lax.psum(d_wi_1, axis_name=dp) + d_wo = jax.lax.psum(d_wo, axis_name=dp) + if has_bias: + d_wi_0_bias = jax.lax.psum(d_wi_0_bias, axis_name=dp) + d_wi_1_bias = jax.lax.psum(d_wi_1_bias, axis_name=dp) + d_wo_bias = jax.lax.psum(d_wo_bias, axis_name=dp) + return ( + d_sorted_x_3d, + d_recv_w_3d, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, + ) - grads = shard_map( + ( + d_sorted_x, + d_recv_w_from_intermediate, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, + ) = shard_map( _bwd_body, mesh=mesh, - in_specs=(ctx_spec, dy_specs), - out_specs=grads_spec, + in_specs=bwd_in_specs, + out_specs=bwd_out_specs, check_rep=False, - )(ctx, dy_pair) + )(*bwd_in_args) + + d_recv_w_total = d_recv_w_from_combine + d_recv_w_from_intermediate - grads["gate_kernel"] = with_sharding_constraint_by_logical_axes( - grads["gate_kernel"], gate_kernel_axes + # ---------------- Dispatch bwd (global view) ---------------- + d_sorted_x = jax.lax.with_sharding_constraint(d_sorted_x, NamedSharding(mesh, ep3_spec)) + d_recv_w_total = jax.lax.with_sharding_constraint( + d_recv_w_total, NamedSharding(mesh, ep2_spec) + ) + d_x_from_dispatch, d_topk_w = tex.ep_dispatch_bwd( + ctx.handle, + d_sorted_x, + d_recv_w_total, + top_k=K, + num_local_tokens=(B, S), + out_partition_spec=out_partition_spec, ) - grads["wi_0"] = with_sharding_constraint_by_logical_axes(grads["wi_0"], wi_kernel_axes) - grads["wi_1"] = with_sharding_constraint_by_logical_axes(grads["wi_1"], wi_kernel_axes) - grads["wo"] = with_sharding_constraint_by_logical_axes(grads["wo"], wo_kernel_axes) - grads["inputs"] = with_sharding_constraint_by_logical_axes(grads["inputs"], input_axes) - return _grads_dict_to_tuple(grads, has_wi_bias, has_wo_bias, has_expert_bias) + # ---------------- Routing bwd (global view) ---------------- + # The cotangent on routing_weights is a sparse scatter into sparse_probs + # at the selected_experts indices. + selected_experts = jnp.argsort(ctx.routing_map, axis=-1)[..., -K:] + d_topk_w_flat = d_topk_w.reshape(-1, K) + d_sparse_probs = jnp.zeros(ctx.routing_map.shape, dtype=d_topk_w_flat.dtype) + d_sparse_probs = d_sparse_probs.at[ + jnp.arange(ctx.routing_map.shape[0])[:, None], selected_experts + ].set(d_topk_w_flat) + + d_logits_2d = tex.fused_topk_with_score_function_bwd( + ctx.routing_map, + ctx.saved_scores, + d_sparse_probs.astype(ctx.saved_scores.dtype), + topk=K, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=False, + ) + + # ---------------- Aux loss bwd (global view, replicated) ---------------- + # Reverse the fwd's all-gather/aux pipeline: aux_loss_bwd produces + # d_aux_probs, then topk_bwd(compute_aux_scores=True) produces the + # extra d_logits contribution. The replicated tensor adds into the + # T-sharded routing-side d_logits via JAX's normal broadcast. + if aux_loss_coeff > 0.0: + T_global = ctx.logits_2d.shape[0] + d_aux_loss_scalar = d_aux_loss.reshape(()).astype(jnp.float32) + d_aux_probs = tex.fused_moe_aux_loss_bwd( + ctx.aux_const_buf, + ctx.aux_tokens_per_expert.astype(jnp.int32), + d_aux_loss_scalar, + num_tokens=int(T_global), + ) + # routing_map is ignored by the kernel when compute_aux_scores=True, + # so pass a zero placeholder of the right shape/dtype. + zero_routing_map = jnp.zeros( + ctx.aux_saved_scores.shape, dtype=ctx.routing_map.dtype + ) + d_logits_aux = tex.fused_topk_with_score_function_bwd( + zero_routing_map, + ctx.aux_saved_scores, + d_aux_probs.astype(ctx.aux_saved_scores.dtype), + topk=K, + use_pre_softmax=False, + scaling_factor=1.0, + score_function=score_function, + compute_aux_scores=True, + ) + d_logits_2d = d_logits_2d + d_logits_aux.astype(d_logits_2d.dtype) + + # ---------------- Gate bwd (global view) ---------------- + d_gate_logits = d_logits_2d.reshape(B, S, num_experts) + gate_kernel_cast = ctx.gate_kernel.astype(ctx.x.dtype) + d_x_from_gate = jnp.einsum("bse,he->bsh", d_gate_logits, gate_kernel_cast) + d_gate_kernel = jnp.einsum("bsh,bse->he", ctx.x, d_gate_logits).astype(ctx.gate_kernel.dtype) + d_x = d_x_from_gate + d_x_from_dispatch + + # Pin output grads to the declared logical axes so downstream + # optimizers see consistent shardings. + d_x = with_sharding_constraint_by_logical_axes(d_x, input_axes) + d_gate_kernel = with_sharding_constraint_by_logical_axes(d_gate_kernel, gate_kernel_axes) + d_wi_0 = with_sharding_constraint_by_logical_axes(d_wi_0, wi_kernel_axes) + d_wi_1 = with_sharding_constraint_by_logical_axes(d_wi_1, wi_kernel_axes) + d_wo = with_sharding_constraint_by_logical_axes(d_wo, wo_kernel_axes) + + # expert_bias has no learnable bwd path through fused_topk: the + # primitive's bwd returns None for the bias slot. Match that with a + # zero cotangent of the right shape so custom_vjp's arity check + # passes. + d_expert_bias = jnp.zeros_like(ctx.expert_bias) -def _grads_dict_to_tuple( - grads: dict, has_wi_bias: bool, has_wo_bias: bool, has_expert_bias: bool -) -> Tuple: - """Pack the body_bwd's grads dict into the positional tuple JAX expects.""" return ( - grads["inputs"], - grads["gate_kernel"], - grads["wi_0"], - grads["wi_1"], - grads["wo"], - grads.get("wi_0_bias") if has_wi_bias else None, - grads.get("wi_1_bias") if has_wi_bias else None, - grads.get("wo_bias") if has_wo_bias else None, - grads.get("expert_bias") if has_expert_bias else None, + d_x, + d_gate_kernel, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias if has_bias else None, + d_wi_1_bias if has_bias else None, + d_wo_bias if has_bias else None, + d_expert_bias, ) @@ -2017,7 +909,7 @@ def _grads_dict_to_tuple( # ============================================================================= -@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 30))) +@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 27))) def _moe( x, gate_kernel, @@ -2037,24 +929,17 @@ def _moe( group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, apply_topk_weights_early, + align_size, ): - # Call in `_moe`'s own signature order to match what JAX will pass - # the fwd rule via ``_argnums_partial``. See the comment block at - # the top of ``_moe_fwd_rule`` for why this differs from - # ``_moe_bwd_rule``'s convention. - output_pair, _ = _moe_fwd_rule( + primal, _ = _moe_fwd_rule( x, gate_kernel, wi_0, @@ -2073,20 +958,17 @@ def _moe( group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, apply_topk_weights_early, + align_size, ) - return output_pair + return primal _moe.defvjp(_moe_fwd_rule, _moe_bwd_rule) @@ -2103,84 +985,87 @@ def moe( wo_bias: Optional[jnp.ndarray] = None, expert_bias: Optional[jnp.ndarray] = None, *, - # Architecture num_experts: int, num_experts_per_tok: int, activation_type: str = "silu", - # Routing score_function: Union[str, ScoreFunction] = "softmax", use_pre_softmax: bool = False, num_groups: Optional[int] = None, group_topk: Optional[int] = None, scaling_factor: float = 1.0, aux_loss_coeff: float = 0.0, - # Permutation - permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX, - align_size: int = 0, - # Gate placement - gate_inside_vjp: bool = True, - # When True, fold per-token top-k weights into the FFN intermediate - # (next to act(gate)*up) instead of into the post-down-projection - # combine. Both placements are mathematically equivalent (down-proj is - # linear); the early placement gives XLA a chance to fuse the multiply - # with the activation. Off by default. apply_topk_weights_early: bool = False, - # Parallelism (resolved by caller from MeshResource) - ep_axis: Optional[str] = None, + align_size: int = 0, + ep_axis: str, data_parallelism_axes: Tuple[str, ...] = (), - # Logical axes for sharding constraints input_axes: Tuple[Optional[str], ...] = (), gate_kernel_axes: Tuple[Optional[str], ...] = (), wi_kernel_axes: Tuple[Optional[str], ...] = ("exp", "embed", "mlp"), wo_kernel_axes: Tuple[Optional[str], ...] = ("exp", "mlp", "embed"), - # Quantization - quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet] = ( - noop_quantizer_set, - noop_quantizer_set, - noop_quantizer_set, - ), dtype: jnp.dtype = jnp.float32, ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Run a full MoE block under a single fused custom_vjp. + """Run a full MoE block under a single fused custom_vjp on the TE EP path. + + Returns ``(output, aux_loss)``. ``aux_loss`` is ``None`` when + ``aux_loss_coeff == 0`` and a 0-d scalar otherwise. - Parameters and return are documented at the call site of - ``_MoEBlock.__call__``. See module docstring for design rationale. + Parameters + ---------- + expert_bias : Optional[jnp.ndarray] + ``[num_experts]`` learnable router bias added before the top-k + when ``score_function='sigmoid'``. Pass ``None`` to disable. + The bias has no gradient through the top-k primitive itself (it + only steers expert selection); a zero cotangent is returned for + it. + aux_loss_coeff : float + Per-step expert-load-balance loss coefficient. ``0.0`` (default) + disables the aux loss entirely. When non-zero, an extra + all-gather over the routing-side logits is inserted so the + ``fused_moe_aux_loss`` kernel sees a global ``[T_global, E]`` + view; this lives off the dispatch critical path. + align_size : int + Minimum per-expert slot alignment passed to ``tex.ep_prepare`` + as ``dispatch_output_per_expert_alignment``. ``0`` (default) + means use the natural slot count + ``ceil((B/dp)*S*K / num_local_experts)``. Any positive value + rounds that count up to the nearest multiple, growing the + per-rank receive buffer accordingly. Set to ``128`` for FP8 + recipes that require 128-aligned grouped-GEMM tiles. + + See module docstring for the rest of the parameter semantics and the + surrounding design rationale. """ - if not isinstance(permutation_backend, PermutationBackend): - raise TypeError( - f"permutation_backend must be a PermutationBackend, got {permutation_backend!r}" - ) - if permutation_backend is PermutationBackend.TRITON: - _require_triton() - # Normalize string score_function ("softmax" / "sigmoid") to the - # ScoreFunction enum once here. The underlying primitive - # ``tex.fused_topk_with_score_function_fwd`` expects an int-coercible - # value (the enum has integer .value), and the public router wrapper - # we bypass also normalizes here. score_function = _validate_score_function(score_function) # Enforce ((outer_dp..., ep), None, None) on inbound activations. The # EP comm groups consecutive global ranks (dp_color = rank // ep_size), # so ep MUST be innermost in the partition spec. Soft re-pin: free if # upstream already matches, single reshard otherwise. - if ep_axis is not None: - mesh = _get_mesh() - if mesh is None or mesh.empty: - raise ValueError("moe(...) requires an active jax.sharding.Mesh when ep_axis is set.") - expected_leading: Any = ( - (*data_parallelism_axes, ep_axis) if data_parallelism_axes else ep_axis + mesh = _get_mesh() + if mesh is None or mesh.empty: + raise ValueError("moe(...) requires an active jax.sharding.Mesh.") + expected_leading: Any = ( + (*data_parallelism_axes, ep_axis) if data_parallelism_axes else ep_axis + ) + expected_spec = P(expected_leading, None, None) + actual_spec = getattr(getattr(x, "sharding", None), "spec", None) + if actual_spec is not None and tuple(actual_spec) != tuple(expected_spec): + warnings.warn( + f"moe(...): inbound x sharding {actual_spec} does not match expected " + f"{expected_spec}; inserting a reshard. Apply " + "jax.lax.with_sharding_constraint upstream to avoid this overhead.", + UserWarning, + stacklevel=2, ) - expected_spec = P(expected_leading, None, None) - actual_spec = getattr(getattr(x, "sharding", None), "spec", None) - if actual_spec is not None and tuple(actual_spec) != tuple(expected_spec): - warnings.warn( - f"moe(...): inbound x sharding {actual_spec} does not match expected " - f"{expected_spec}; inserting a reshard. Apply " - "jax.lax.with_sharding_constraint upstream to avoid this overhead.", - UserWarning, - stacklevel=2, - ) - x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, expected_spec)) + x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, expected_spec)) + + # custom_vjp can't trace through None args; lower expert_bias to an + # empty shape-(0,) tensor that fused_topk_with_score_function treats + # as "no bias". + if expert_bias is None: + expert_bias_arg = jnp.zeros((0,), dtype=jnp.float32) + else: + expert_bias_arg = expert_bias output, aux_loss = _moe( x, @@ -2191,28 +1076,25 @@ def moe( wi_0_bias, wi_1_bias, wo_bias, - expert_bias, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - activation_type=activation_type, - score_function=score_function, - use_pre_softmax=use_pre_softmax, - num_groups=num_groups, - group_topk=group_topk, - scaling_factor=scaling_factor, - aux_loss_coeff=aux_loss_coeff, - permutation_backend=permutation_backend, - align_size=align_size, - gate_inside_vjp=gate_inside_vjp, - ep_axis=ep_axis, - data_parallelism_axes=data_parallelism_axes, - input_axes=input_axes, - gate_kernel_axes=gate_kernel_axes, - wi_kernel_axes=wi_kernel_axes, - wo_kernel_axes=wo_kernel_axes, - quantizer_sets=quantizer_sets, - dtype=dtype, - apply_topk_weights_early=apply_topk_weights_early, + expert_bias_arg, + num_experts, + num_experts_per_tok, + activation_type, + score_function, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + float(aux_loss_coeff), + ep_axis, + data_parallelism_axes, + input_axes, + gate_kernel_axes, + wi_kernel_axes, + wo_kernel_axes, + dtype, + apply_topk_weights_early, + align_size, ) if aux_loss_coeff <= 0.0: aux_loss = None From 776c5effd8238b029064ba83577b931fa2b5c8dd Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 3 Jun 2026 15:57:37 -0700 Subject: [PATCH 29/36] [JAX] MoE: bootstrap TE EP eagerly outside jit; assert compatibility per-call ep_bootstrap allgathers a NCCL UID via the JAX runtime, which traces under jax.jit and fails with TracerArrayConversionError. Move the bootstrap to the test fixture (matching the test_multi_process_ep.py pattern from the TE EP JAX PR): caller invokes ep_bootstrap once per process, then calls record_ep_bootstrap_signature_for_moe with the same params. _moe_fwd_rule now only asserts that the recorded bootstrap signature is wide enough (num_experts/hidden_dim/ep_size exact match; per-call max_tokens_per_rank and recv_capacity_per_rank <= bootstrap values). Test mesh fixture bootstraps with the worst-case recv_pr across _CONFIGS so every parametrized config is compatible with a single per-process bootstrap. --- tests/jax/test_te_ep_moe.py | 51 ++++++++++++++++++++- transformer_engine/jax/moe.py | 84 +++++++++++++++++++++++------------ 2 files changed, 104 insertions(+), 31 deletions(-) diff --git a/tests/jax/test_te_ep_moe.py b/tests/jax/test_te_ep_moe.py index cc878e0bd1..9373873bec 100644 --- a/tests/jax/test_te_ep_moe.py +++ b/tests/jax/test_te_ep_moe.py @@ -130,7 +130,8 @@ def _read_mp_options(): ) from transformer_engine.jax.flax import _MoEBlock as MoEBlock -from transformer_engine.jax.moe import moe +from transformer_engine.jax.moe import moe, record_ep_bootstrap_signature_for_moe +from transformer_engine.jax.ep import ep_bootstrap from transformer_engine.jax.sharding import MeshResource, global_shard_guard @@ -191,6 +192,23 @@ def _read_mp_options(): # ----------------------------------------------------------------------------- +def _compute_worst_case_recv_pr(): + """Worst-case per-rank recv buffer across every config in _CONFIGS. + + Bootstrap reserves NCCL EP buffers; per-call recv_pr <= bootstrap + recv_pr is fine. We size with the largest align_size in _CONFIGS so + the align128 config still fits the same singleton bootstrap. + """ + num_procs = jax.device_count() + dp_size = num_procs // EP_SIZE + num_local_experts = NUM_EXPERTS // EP_SIZE + natural_recv_pr = (BATCH // dp_size) * SEQ * TOPK + natural_spe = (natural_recv_pr + num_local_experts - 1) // num_local_experts + worst_align = 128 + worst_spe = ((natural_spe + worst_align - 1) // worst_align) * worst_align + return num_local_experts * worst_spe + + @pytest.fixture(scope="module") def mesh(): if jax.device_count() < NUM_DEVICES_REQUIRED: @@ -202,7 +220,36 @@ def mesh(): # from consecutive global ranks via ``dp_color = rank // ep_size``, so # only an (outer_fsdp, inner_ep) device layout groups ranks correctly. devices = mesh_utils.create_device_mesh((FSDP_SIZE, EP_SIZE)) - return Mesh(devices, axis_names=(FSDP_AXIS, EP_AXIS)) + mesh_obj = Mesh(devices, axis_names=(FSDP_AXIS, EP_AXIS)) + + num_procs = jax.process_count() + max_tokens_per_rank = (BATCH // num_procs) * SEQ + recv_capacity_per_rank = _compute_worst_case_recv_pr() + + # Eager bootstrap: ep_bootstrap does a host-side NCCL UID allgather + # and cannot run from inside jax.jit. Sized to the worst-case recv_pr + # across _CONFIGS so every parametrized config is bootstrap-compatible. + with mesh_obj, global_shard_guard( + MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) + ): + ep_bootstrap( + world_size=num_procs, + rank=jax.process_index(), + ep_size=EP_SIZE, + num_experts=NUM_EXPERTS, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=HIDDEN, + allow_handle_mem_reloc=True, + ) + record_ep_bootstrap_signature_for_moe( + num_experts=NUM_EXPERTS, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=HIDDEN, + ep_size=EP_SIZE, + ) + return mesh_obj # ----------------------------------------------------------------------------- diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 162ea8f7e5..da6f0ac1e4 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -58,49 +58,75 @@ # ============================================================================= -# Process-level NCCL EP bootstrap +# Process-level NCCL EP bootstrap (must run eagerly, outside jax.jit) # ============================================================================= # -# ``tex.ep_bootstrap`` initialises the NCCL EP communicator exactly once per -# process and stashes its state in a C++ singleton. Subsequent calls with the -# same signature are a no-op; calls with a different signature raise. +# ``tex.ep_bootstrap`` does a NCCL UID allgather over the JAX runtime, which +# cannot run from inside a jit-traced function. The caller must bootstrap +# eagerly once per process before any jitted MoE call, then record the +# bootstrap signature via ``record_ep_bootstrap_signature_for_moe``. The +# per-call check below verifies the recorded signature is wide enough for +# the current MoE invocation (smaller per-call usage is fine since the C++ +# backend reserves worst-case buffers at bootstrap time). _te_ep_bootstrap_signature: Optional[Tuple[int, int, int, int, int]] = None -def _te_ep_bootstrap_if_needed( +def record_ep_bootstrap_signature_for_moe( num_experts: int, max_tokens_per_rank: int, recv_capacity_per_rank: int, hidden_dim: int, ep_size: int, ) -> None: - """Bootstrap the NCCL EP communicator on first use within a process.""" + """Record the params passed to ``ep_bootstrap`` so the per-call check + in ``_moe_fwd_rule`` can verify compatibility. Call this once per + process immediately after ``ep_bootstrap``. + """ global _te_ep_bootstrap_signature - sig = (num_experts, max_tokens_per_rank, recv_capacity_per_rank, hidden_dim, ep_size) - if _te_ep_bootstrap_signature == sig: - return - if _te_ep_bootstrap_signature is not None: + _te_ep_bootstrap_signature = ( + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + ep_size, + ) + + +def _te_ep_assert_compatible_bootstrap( + num_experts: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + ep_size: int, +) -> None: + """Verify a prior eager ``ep_bootstrap`` is wide enough for this call.""" + if _te_ep_bootstrap_signature is None: + raise RuntimeError( + "TE EP was not bootstrapped. Call" + " transformer_engine.jax.ep.ep_bootstrap(...) eagerly (outside" + " any jax.jit) once per process, then" + " transformer_engine.jax.moe.record_ep_bootstrap_signature_for_moe(...)" + " with the same params, before invoking moe()." + ) + b_num_experts, b_max_tpr, b_recv_pr, b_hidden, b_ep_size = _te_ep_bootstrap_signature + if ( + num_experts != b_num_experts + or hidden_dim != b_hidden + or ep_size != b_ep_size + or max_tokens_per_rank > b_max_tpr + or recv_capacity_per_rank > b_recv_pr + ): raise ValueError( - "TE EP was already bootstrapped with signature " - f"{_te_ep_bootstrap_signature}; got {sig}. Re-bootstrap with" - " different params is not supported within a single process." + "TE EP was already bootstrapped with signature" + f" (num_experts={b_num_experts}, max_tokens_per_rank={b_max_tpr}," + f" recv_capacity_per_rank={b_recv_pr}, hidden_dim={b_hidden}," + f" ep_size={b_ep_size}); this moe() call needs" + f" (num_experts={num_experts}, max_tokens_per_rank={max_tokens_per_rank}," + f" recv_capacity_per_rank={recv_capacity_per_rank}, hidden_dim={hidden_dim}," + f" ep_size={ep_size}). Re-bootstrap with wider params (or matching exact" + " sizes) is required." ) - from transformer_engine.jax.ep import ep_bootstrap # local: avoids import cycle - - ep_bootstrap( - world_size=jax.process_count(), - rank=jax.process_index(), - ep_size=ep_size, - num_experts=num_experts, - max_tokens_per_rank=max_tokens_per_rank, - recv_capacity_per_rank=recv_capacity_per_rank, - hidden_dim=hidden_dim, - # XLA may relocate the C++ handle buffer between JIT executables; - # allow it rather than asserting on handle aliasing. - allow_handle_mem_reloc=True, - ) - _te_ep_bootstrap_signature = sig # ============================================================================= @@ -413,7 +439,7 @@ def _moe_fwd_rule( # per-rank receive buffer. max_tokens_per_rank = (B // num_procs) * S - _te_ep_bootstrap_if_needed( + _te_ep_assert_compatible_bootstrap( num_experts=num_experts, max_tokens_per_rank=max_tokens_per_rank, recv_capacity_per_rank=recv_pr, From acb610ff1028f85614a803efc675e2135088d29e Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 3 Jun 2026 16:15:10 -0700 Subject: [PATCH 30/36] [JAX] MoE: thread EpHandle + handle_mem through dispatch / combine The cpp_extensions/ep.py API (post the per-layer EpHandle refactor in e927903c) expects an EpHandle object plus a separate handle_mem buffer for every dispatch/combine call. The MoE wrapper was still passing the raw slots_per_expert int as the second positional and unpacking ep_dispatch_fwd as a 3-tuple, which now blows up with "AttributeError: 'int' object has no attribute 'handle_id'". Changes: - Cache one EpHandle per (top_k, alignment) at module scope so repeated jit traces don't burn the NVTE_EP_HANDLE_CACHE_SIZE pool. - _moe_fwd_rule: mint/lookup the handle, call ep_prepare(topk_idx, handle) -> (token_counts, handle_mem), and pass (handle, handle_mem) into the fwd dispatch/combine calls. ep_dispatch_fwd now returns a 2-tuple. - _Ctx: stash handle_mem alongside handle so the bwd can hand both back to ep_combine_bwd and ep_dispatch_bwd. - _moe_bwd_rule: thread ctx.handle_mem into the bwd dispatch/combine calls. --- transformer_engine/jax/moe.py | 37 ++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index da6f0ac1e4..067be5a60b 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -37,7 +37,7 @@ from dataclasses import dataclass from functools import partial -from typing import Any, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import warnings import jax @@ -93,6 +93,25 @@ def record_ep_bootstrap_signature_for_moe( ) +# Per-(top_k, alignment) EpHandle cache. ``tex.ep_make_handle`` mints a +# fresh handle_id from a singleton pool capped at NVTE_EP_HANDLE_CACHE_SIZE +# (default 8192); caching here keeps the pool steady across many jit traces +# of the same MoE block configuration. +_te_ep_handle_cache: Dict[Tuple[int, int], Any] = {} + + +def _get_or_make_ep_handle(top_k: int, dispatch_output_per_expert_alignment: int): + key = (int(top_k), int(dispatch_output_per_expert_alignment)) + h = _te_ep_handle_cache.get(key) + if h is None: + h = tex.ep_make_handle( + top_k=key[0], + dispatch_output_per_expert_alignment=key[1], + ) + _te_ep_handle_cache[key] = h + return h + + def _te_ep_assert_compatible_bootstrap( num_experts: int, max_tokens_per_rank: int, @@ -145,6 +164,7 @@ class _Ctx: saved_scores: jnp.ndarray routing_map: jnp.ndarray handle: Any + handle_mem: Any token_counts: jnp.ndarray recv_topk_weights: jnp.ndarray casted_sorted_x_lhs_trans: Any @@ -538,9 +558,12 @@ def _moe_fwd_rule( topk_w_3d = routing_weights.reshape(B, S, K).astype(jnp.float32) # ---------------- TE EP dispatch (global view) ---------------- - token_counts, handle = tex.ep_prepare(topk_idx_3d, slots_per_expert) - recv_tokens, recv_topk_weights, handle = tex.ep_dispatch_fwd( - handle, topk_idx_3d, x, topk_w_3d, recv_pr + handle = _get_or_make_ep_handle( + top_k=K, dispatch_output_per_expert_alignment=slots_per_expert + ) + token_counts, handle_mem = tex.ep_prepare(topk_idx_3d, handle) + recv_tokens, recv_topk_weights = tex.ep_dispatch_fwd( + handle, handle_mem, topk_idx_3d, x, topk_w_3d, recv_pr ) recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3_spec)) recv_topk_weights = jax.lax.with_sharding_constraint( @@ -608,6 +631,7 @@ def _body(*args): # expert_outputs is already weighted upstream. output = tex.ep_combine_fwd( handle, + handle_mem, expert_outputs, num_local_tokens=(B, S), out_partition_spec=out_partition_spec, @@ -618,6 +642,7 @@ def _body(*args): weighted = expert_outputs * w * mask output = tex.ep_combine_fwd( handle, + handle_mem, weighted, num_local_tokens=(B, S), out_partition_spec=out_partition_spec, @@ -641,6 +666,7 @@ def _body(*args): saved_scores=saved_scores, routing_map=routing_map, handle=handle, + handle_mem=handle_mem, token_counts=token_counts, recv_topk_weights=recv_topk_weights, casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans, @@ -716,7 +742,7 @@ def _moe_bwd_rule( # ---------------- Combine bwd (global view) ---------------- d_output = jax.lax.with_sharding_constraint(d_output, NamedSharding(mesh, ep3_spec)) - grad_pre_combine = tex.ep_combine_bwd(ctx.handle, d_output, recv_pr) + grad_pre_combine = tex.ep_combine_bwd(ctx.handle, ctx.handle_mem, d_output, recv_pr) grad_pre_combine = jax.lax.with_sharding_constraint( grad_pre_combine, NamedSharding(mesh, ep3_spec) ) @@ -837,6 +863,7 @@ def _bwd_body(*args): ) d_x_from_dispatch, d_topk_w = tex.ep_dispatch_bwd( ctx.handle, + ctx.handle_mem, d_sorted_x, d_recv_w_total, top_k=K, From 458d1c4cf26fb20a714f0f49284b6efe23a883f7 Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 3 Jun 2026 16:56:16 -0700 Subject: [PATCH 31/36] [JAX] MoE: pass bf16 as max_token_dtype to test fixture's ep_bootstrap te-ep-fixes plumbs NVTEEpGroupConfig.max_token_dtype through ep_bootstrap. Tests dispatch bf16 tokens; without this arg the group lands with the legacy kByte default (1 byte) and every dispatch aborts at the ep_backend.cpp:349 dtype check. --- tests/jax/test_te_ep_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/jax/test_te_ep_moe.py b/tests/jax/test_te_ep_moe.py index 9373873bec..d56dad2c92 100644 --- a/tests/jax/test_te_ep_moe.py +++ b/tests/jax/test_te_ep_moe.py @@ -241,6 +241,7 @@ def mesh(): recv_capacity_per_rank=recv_capacity_per_rank, hidden_dim=HIDDEN, allow_handle_mem_reloc=True, + max_token_dtype=DTYPE, ) record_ep_bootstrap_signature_for_moe( num_experts=NUM_EXPERTS, From 3d6825c7093a82a6ecc4fbfa618dc81c4fd1f6b4 Mon Sep 17 00:00:00 2001 From: tdophung Date: Thu, 4 Jun 2026 10:39:15 -0700 Subject: [PATCH 32/36] patching the sharding stripped by flattening logits input to topk, will fix for real in later commits Signed-off-by: tdophung --- transformer_engine/jax/moe.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 067be5a60b..a96a77b991 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -243,8 +243,10 @@ def _ffn_fwd_per_shard( if apply_topk_weights_early: # Fold the per-token combine weights into the FFN intermediate; # the downstream wo GEMM is linear so this is equivalent to the - # late-weighting path, modulo elementwise op fusion gains. - w_b = recv_w_flat[:, None] + # late-weighting path, modulo elementwise op fusion gains. w_b is + # cast to intermediate.dtype so the multiply doesn't promote + # expert_outputs to f32 (NCCL EP combine hard-asserts bf16). + w_b = recv_w_flat[:, None].astype(intermediate.dtype) mask_b = (recv_w_flat != 0).astype(intermediate.dtype)[:, None] intermediate = intermediate * w_b * mask_b @@ -317,7 +319,9 @@ def _ffn_bwd_per_shard( if apply_topk_weights_early: # intermediate' = intermediate * w * mask. Split the cotangent # across both factors before the activation bwd consumes it. - w_b = recv_w_flat[:, None] + # Cast w_b so the multiply stays in d_intermediate.dtype and + # d_sorted_x (downstream into ep_dispatch_bwd) stays bf16. + w_b = recv_w_flat[:, None].astype(d_intermediate.dtype) mask_b = (recv_w_flat != 0).astype(d_intermediate.dtype)[:, None] intermediate_unweighted = act_fn(gate_proj_out) * up_proj_out d_recv_w_from_intermediate = jnp.sum( @@ -556,6 +560,17 @@ def _moe_fwd_rule( routing_weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) topk_idx_3d = selected_experts.reshape(B, S, K).astype(jnp.int32) topk_w_3d = routing_weights.reshape(B, S, K).astype(jnp.float32) + # tex.ep_prepare/dispatch's partition only folds ep_axis into a replicated + # leading dim, not the outer dp/fsdp axes, so a replicated topk_idx makes + # each rank see B/ep rows (not B/num_procs) and overrun the bootstrap-sized + # send buffer. Pin both routing tensors to the (outer, ep) leading sharding + # so per-rank token counts match max_tokens_per_rank. + topk_idx_3d = jax.lax.with_sharding_constraint( + topk_idx_3d, NamedSharding(mesh, ep3_spec) + ) + topk_w_3d = jax.lax.with_sharding_constraint( + topk_w_3d, NamedSharding(mesh, ep3_spec) + ) # ---------------- TE EP dispatch (global view) ---------------- handle = _get_or_make_ep_handle( @@ -637,7 +652,7 @@ def _body(*args): out_partition_spec=out_partition_spec, ) else: - w = recv_topk_weights[..., None] + w = recv_topk_weights[..., None].astype(expert_outputs.dtype) mask = (recv_topk_weights != 0).astype(expert_outputs.dtype)[..., None] weighted = expert_outputs * w * mask output = tex.ep_combine_fwd( @@ -754,8 +769,10 @@ def _moe_bwd_rule( d_recv_w_from_combine = jnp.zeros_like(ctx.recv_topk_weights) else: # combine_fwd consumed weighted = expert_out * w * mask; - # split the cotangent across both factors. - w = ctx.recv_topk_weights[..., None] + # split the cotangent across both factors. w is cast to + # grad_pre_combine.dtype so the multiply stays bf16 and + # d_sorted_x (downstream into ep_dispatch_bwd) stays bf16. + w = ctx.recv_topk_weights[..., None].astype(grad_pre_combine.dtype) mask = (ctx.recv_topk_weights != 0).astype(grad_pre_combine.dtype)[..., None] d_expert_outputs = grad_pre_combine * w * mask d_recv_w_from_combine = (grad_pre_combine * ctx.expert_outputs * mask).sum(axis=-1) From df61642a35c4e1a12c6b10cdbaee2a34e772beb0 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 8 Jun 2026 15:17:49 -0700 Subject: [PATCH 33/36] Fix MoEBlock tests Signed-off-by: Jeremy Berchtold --- tests/jax/test_multi_process_ep.py | 1 + .../jax/cpp_extensions/router.py | 2 +- transformer_engine/jax/moe.py | 108 ++++++++++++++---- 3 files changed, 89 insertions(+), 22 deletions(-) diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py index abdbcd32ec..fc685bf092 100644 --- a/tests/jax/test_multi_process_ep.py +++ b/tests/jax/test_multi_process_ep.py @@ -122,6 +122,7 @@ def setUpClass(cls): max_tokens_per_rank=TOKENS_PER_DP_SHARD, recv_capacity_per_rank=cls.recv_capacity_per_rank, hidden_dim=HIDDEN_DIM, + max_token_dtype=jnp.bfloat16, # XLA reallocates handle_mem between JIT executables. allow_handle_mem_reloc=True, ) diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index 0ae267cbf3..f5fa7722f4 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -571,7 +571,7 @@ def shardy_sharding_rule(*args): # backward reconstructs the full [num_tokens, num_experts] grad_probs from # scalar inputs. Shardy will leave num_tokens unsharded, which matches the # replicated PartitionSpec(None, None) in partition(). - return "const_buf_one, num_experts, grad_one -> i num_experts" + return "const_buf_one, num_experts, -> i num_experts" register_primitive(FusedMoEAuxLossBwdPrimitive) diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index a96a77b991..50fdde2dc9 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -153,6 +153,7 @@ def _te_ep_assert_compatible_bootstrap( # ============================================================================= +@jax.tree_util.register_pytree_node_class @dataclass class _Ctx: """Residuals carried from the fwd rule into the bwd rule.""" @@ -180,6 +181,79 @@ class _Ctx: aux_tokens_per_expert: Any = None aux_saved_scores: Any = None + def tree_flatten(self): + children = ( + self.x, + self.gate_kernel, + self.expert_bias, + self.logits_2d, + self.saved_scores, + self.routing_map, + self.handle_mem, + self.token_counts, + self.recv_topk_weights, + self.casted_sorted_x_lhs_trans, + self.casted_wi_rhs_trans, + self.gate_proj_out, + self.up_proj_out, + self.casted_intermediate_lhs_trans, + self.casted_wo_rhs_trans, + self.expert_outputs, + self.local_group_sizes, + self.aux_const_buf, + self.aux_tokens_per_expert, + self.aux_saved_scores, + ) + return children, self.handle + + @classmethod + def tree_unflatten(cls, aux_data, children): + ( + x, + gate_kernel, + expert_bias, + logits_2d, + saved_scores, + routing_map, + handle_mem, + token_counts, + recv_topk_weights, + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out, + up_proj_out, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, + expert_outputs, + local_group_sizes, + aux_const_buf, + aux_tokens_per_expert, + aux_saved_scores, + ) = children + return cls( + x=x, + gate_kernel=gate_kernel, + expert_bias=expert_bias, + logits_2d=logits_2d, + saved_scores=saved_scores, + routing_map=routing_map, + handle=aux_data, + handle_mem=handle_mem, + token_counts=token_counts, + recv_topk_weights=recv_topk_weights, + casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans, + casted_wi_rhs_trans=casted_wi_rhs_trans, + gate_proj_out=gate_proj_out, + up_proj_out=up_proj_out, + casted_intermediate_lhs_trans=casted_intermediate_lhs_trans, + casted_wo_rhs_trans=casted_wo_rhs_trans, + expert_outputs=expert_outputs, + local_group_sizes=local_group_sizes, + aux_const_buf=aux_const_buf, + aux_tokens_per_expert=aux_tokens_per_expert, + aux_saved_scores=aux_saved_scores, + ) + # ============================================================================= # Per-shard FFN body (runs inside shard_map) @@ -218,9 +292,9 @@ def _ffn_fwd_per_shard( wi_1 = wi_1.astype(sorted_x.dtype) wo = wo.astype(sorted_x.dtype) - wi_combined = jnp.stack([wi_0, wi_1], axis=-2) + wi_combined = jnp.concatenate([wi_0, wi_1], axis=-1) wi_combined_bias = ( - jnp.stack([wi_0_bias, wi_1_bias], axis=-2) if wi_0_bias is not None else None + jnp.concatenate([wi_0_bias, wi_1_bias], axis=-1) if wi_0_bias is not None else None ) q_set = noop_quantizer_set @@ -232,8 +306,7 @@ def _ffn_fwd_per_shard( contracting_dims=((1,), (1,)), bias=wi_combined_bias, ) - gate_proj_out = combined_out[..., 0, :] - up_proj_out = combined_out[..., 1, :] + gate_proj_out, up_proj_out = jnp.split(combined_out, 2, axis=-1) casted_sorted_x_lhs_trans = casted_sorted_x.get_tensor(usage=TensorUsage.LHS_TRANS) casted_wi_rhs_trans = casted_wi.get_tensor(usage=TensorUsage.RHS_TRANS) @@ -337,32 +410,24 @@ def _ffn_bwd_per_shard( (d_gate_proj_out,) = dact_gate_proj_pullback(d_intermediate * up_proj_out) # wi bwd (fused gate/up) - inter_M = d_gate_proj_out.shape[-1] - d_combined = jnp.stack([d_gate_proj_out, d_up_proj_out], axis=-2) + d_combined = jnp.concatenate([d_gate_proj_out, d_up_proj_out], axis=-1) casted_d_combined = tex.grouped_quantize( d_combined, q_set.dgrad, local_group_sizes, flatten_axis=-1 ) d_sorted_x = tex.grouped_gemm( casted_d_combined.get_tensor(usage=TensorUsage.LHS), casted_wi_rhs_trans, - contracting_dims=((1, 2), (2, 3)), + contracting_dims=((1,), (2,)), ) d_wi_combined = tex.grouped_gemm( casted_sorted_x_lhs_trans, casted_d_combined.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) - d_wi_0 = d_wi_combined[..., 0, :] - d_wi_1 = d_wi_combined[..., 1, :] + d_wi_0, d_wi_1 = jnp.split(d_wi_combined, 2, axis=-1) if has_bias: - # tex.grouped_dbias takes a rank-2 input; reshape around the call. - d_combined_2d = d_combined.reshape(d_combined.shape[0], -1) - d_wi_combined_bias_2d = tex.grouped_dbias(d_combined_2d, local_group_sizes) - d_wi_combined_bias = d_wi_combined_bias_2d.reshape( - *d_wi_combined_bias_2d.shape[:-1], 2, inter_M - ) - d_wi_0_bias = d_wi_combined_bias[..., 0, :] - d_wi_1_bias = d_wi_combined_bias[..., 1, :] + d_wi_combined_bias = tex.grouped_dbias(d_combined, local_group_sizes) + d_wi_0_bias, d_wi_1_bias = jnp.split(d_wi_combined_bias, 2, axis=-1) else: d_wi_0_bias = None d_wi_1_bias = None @@ -453,10 +518,11 @@ def _moe_fwd_rule( # block starts on the alignment boundary that grouped_gemm expects. natural_recv_pr = (B // dp_size) * S * K natural_spe = (natural_recv_pr + num_local_experts - 1) // num_local_experts - if align_size > 0: - slots_per_expert = ((natural_spe + align_size - 1) // align_size) * align_size - else: - slots_per_expert = natural_spe + # NCCL EP requires each expert-major output block to be at least + # 128-token aligned. Keep larger caller-requested alignments, but do + # not emit the smaller natural block size for tiny tests. + effective_align = max(int(align_size), 128) + slots_per_expert = ((natural_spe + effective_align - 1) // effective_align) * effective_align recv_pr = num_local_experts * slots_per_expert # Per-rank input token count: B/num_procs rows x S tokens. The bootstrap # uses this to size the dispatch send buffer; recv_pr above sizes the From 2210702aa02b3f63fc45b1df160ad18ed861701e Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 9 Jun 2026 07:24:26 -0700 Subject: [PATCH 34/36] Fixes during maxtext integration Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/ep.py | 8 +- transformer_engine/jax/cpp_extensions/gemm.py | 64 +++- transformer_engine/jax/ep.py | 13 +- transformer_engine/jax/moe.py | 327 +++++++++--------- 4 files changed, 249 insertions(+), 163 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 8fb0d90f8a..456cab4fc0 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -24,7 +24,7 @@ import transformer_engine_jax from .base import BasePrimitive, register_primitive -from ..sharding import global_mesh_resource +from ..sharding import global_mesh_resource, get_mesh_axis_size __all__ = [ "EpConfig", @@ -189,6 +189,10 @@ def _ep_outer_axis(): sees each DP color's slab as distinct (rather than replicated across DP). """ gsr = global_mesh_resource() + if gsr.dp_resource is not None and get_mesh_axis_size(gsr.dp_resource) > 1: + return gsr.dp_resource + if gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1: + return gsr.fsdp_resource return gsr.dp_resource or gsr.fsdp_resource @@ -536,7 +540,7 @@ def _resolve_out_partition_spec(out_partition_spec, num_leading): "ep_combine: ep_resource is not set on the active MeshResource;" " pass out_sharding=... explicitly." ) - outer = gsr.dp_resource or gsr.fsdp_resource + outer = _ep_outer_axis() leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource return (leading,) + (None,) * num_leading diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 655aa6e3f9..bda58dc8e9 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1790,6 +1790,24 @@ def _parse_partition_specs( additional_arg_0_spec = filter_spec_axes(original_arg_specs[11], allowed_axes) additional_arg_1_spec = filter_spec_axes(original_arg_specs[12], allowed_axes) + def spec_has_axes(spec): + return any(axis_spec is not None for axis_spec in spec) + + if arg_infos[5].size > 0 and len(lhs_data_spec) > 0 and not spec_has_axes(lhs_first_dims_spec): + lhs_first_dims_spec = (merge_axis_specs(lhs_first_dims_spec[0], lhs_data_spec[0]),) + if arg_infos[6].size > 0 and len(lhs_data_spec) > 0 and not spec_has_axes(lhs_last_dims_spec): + lhs_last_dims_spec = (merge_axis_specs(lhs_last_dims_spec[0], lhs_data_spec[-1]),) + if arg_infos[7].size > 0 and len(rhs_data_spec) > 0 and not spec_has_axes(rhs_first_dims_spec): + rhs_first_dims_spec = (merge_axis_specs(rhs_first_dims_spec[0], rhs_data_spec[0]),) + if arg_infos[8].size > 0 and len(rhs_data_spec) > 0 and not spec_has_axes(rhs_last_dims_spec): + rhs_last_dims_spec = (merge_axis_specs(rhs_last_dims_spec[0], rhs_data_spec[-1]),) + if arg_infos[9].size > 0 and not spec_has_axes(out_first_dims_spec): + out_first_dims_spec = ( + merge_axis_specs(out_first_dims_spec[0], lhs_first_dims_spec[0]), + ) + if arg_infos[10].size > 0 and not spec_has_axes(out_last_dims_spec): + out_last_dims_spec = (merge_axis_specs(out_last_dims_spec[0], lhs_last_dims_spec[0]),) + grouped_dim_specs = ( lhs_first_dims_spec, lhs_last_dims_spec, @@ -1828,6 +1846,20 @@ def _parse_partition_specs( if len(bias_spec) > 0 and not spec_contains_axis(bias_spec, ep_axis): bias_spec = (merge_axis_specs(bias_spec[0], ep_axis), *bias_spec[1:]) + if not rhs_is_ragged and spec_contains_axis(active_group_spec, fsdp_axis): + if len(rhs_data_spec) > 0: + rhs_data_spec = ( + merge_axis_specs(rhs_data_spec[0], active_group_spec[0]), + *rhs_data_spec[1:], + ) + if len(rhs_scale_spec) > 0: + rhs_scale_spec = ( + merge_axis_specs(rhs_scale_spec[0], active_group_spec[0]), + *rhs_scale_spec[1:], + ) + if len(bias_spec) > 0: + bias_spec = (merge_axis_specs(bias_spec[0], active_group_spec[0]), *bias_spec[1:]) + gather_rhs_fsdp = ( fsdp_axis is not None and not rhs_is_ragged @@ -1839,9 +1871,23 @@ def _parse_partition_specs( ) if gather_rhs_fsdp: - rhs_data_spec = strip_axis_from_spec(rhs_data_spec, fsdp_axis) - rhs_scale_spec = strip_axis_from_spec(rhs_scale_spec, fsdp_axis) - bias_spec = strip_axis_from_spec(bias_spec, fsdp_axis) + if spec_contains_axis(active_group_spec, fsdp_axis): + if len(rhs_data_spec) > 0: + rhs_data_spec = ( + rhs_data_spec[0], + *strip_axis_from_spec(rhs_data_spec[1:], fsdp_axis), + ) + if len(rhs_scale_spec) > 0: + rhs_scale_spec = ( + rhs_scale_spec[0], + *strip_axis_from_spec(rhs_scale_spec[1:], fsdp_axis), + ) + if len(bias_spec) > 0: + bias_spec = (bias_spec[0], *strip_axis_from_spec(bias_spec[1:], fsdp_axis)) + else: + rhs_data_spec = strip_axis_from_spec(rhs_data_spec, fsdp_axis) + rhs_scale_spec = strip_axis_from_spec(rhs_scale_spec, fsdp_axis) + bias_spec = strip_axis_from_spec(bias_spec, fsdp_axis) reducible_axes = tuple( axis for axis in (gsr.dp_resource, gsr.fsdp_resource) if axis is not None @@ -1857,6 +1903,18 @@ def _parse_partition_specs( original_out_spec = None out_spec = (None,) * (len(out_shape) if out_shape is not None else 1) + if not rhs_is_ragged and lhs_is_trans is not None and lhs_axis_boundary is not None: + lhs_non_contracting_dims = ( + range(lhs_axis_boundary, len(lhs_data_spec)) + if lhs_is_trans + else range(0, lhs_axis_boundary) + ) + out_spec = list(out_spec) + for out_idx, lhs_dim in enumerate(lhs_non_contracting_dims): + if out_idx < len(out_spec) and out_spec[out_idx] is None: + out_spec[out_idx] = merge_axis_specs(out_spec[out_idx], lhs_data_spec[lhs_dim]) + out_spec = tuple(out_spec) + if rhs_is_ragged and lhs_is_trans is not None and lhs_axis_boundary is not None: lhs_non_contracting_dims = ( range(lhs_axis_boundary, len(lhs_data_spec)) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 17d00bef87..cf7f936b38 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -236,7 +236,7 @@ def _dispatch_bwd(handle, recv_capacity_per_rank, res, g_outputs): # single-fwd-output cotangent, landing a global tensor in the FFI. gsr = global_mesh_resource() ep_axis = gsr.ep_resource - outer = gsr.dp_resource or gsr.fsdp_resource + outer = _ep_outer_axis() leading = (outer, ep_axis) if outer is not None else ep_axis g_recv_tokens = jax.lax.with_sharding_constraint( g_outputs[0], jax.sharding.PartitionSpec(leading, None, None) @@ -253,6 +253,15 @@ def _dispatch_bwd(handle, recv_capacity_per_rank, res, g_outputs): ep_dispatch.defvjp(_dispatch_fwd, _dispatch_bwd) +def _ep_outer_axis(): + gsr = global_mesh_resource() + if gsr.dp_resource is not None and get_mesh_axis_size(gsr.dp_resource) > 1: + return gsr.dp_resource + if gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1: + return gsr.fsdp_resource + return gsr.dp_resource or gsr.fsdp_resource + + # ── ep_combine (custom_vjp) ────────────────────────────────────────────────── @@ -315,7 +324,7 @@ def _combine_bwd(handle, _num_local_tokens, _out_sharding, res, g_result): spec = jax.sharding.PartitionSpec(*_out_sharding) else: ep_axis = gsr.ep_resource - outer = gsr.dp_resource or gsr.fsdp_resource + outer = _ep_outer_axis() leading = (outer, ep_axis) if outer is not None and ep_axis is not None else ep_axis spec = ( jax.sharding.PartitionSpec(leading, *([None] * (g_result.ndim - 1))) diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 50fdde2dc9..807d4ef13a 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -21,18 +21,18 @@ :func:`moe` soft-repins this on entry and warns when a reshard is inserted. * The EP primitives operate at global view (their custom_partitioning - rules handle per-shard execution). The FFN GEMMs run per-shard inside - a small ``shard_map`` whose ``in_specs`` and ``out_specs`` mirror the - same ``((dp, ep), ...)`` layout. + rules handle per-shard execution). The FFN also stays at global view: + grouped GEMM custom partitioning handles the per-shard grouped math, + including FSDP reductions for wgrads. Out-of-scope (for now) ---------------------- FP8 / MXFP8 quantizer sets are not yet wired on this path; turning them on requires recipe-aware residual specs and ``ScaledTensor`` -leaves across the ``shard_map`` boundary. ``aux_loss_coeff`` and -``expert_bias`` are supported (the former forces a per-step -all-gather over the routing-side logits, which lives off the critical -path and overlaps with the dispatch collective). +residual handling. ``aux_loss_coeff`` and ``expert_bias`` are supported +(the former forces a per-step all-gather over the routing-side logits, +which lives off the critical path and overlaps with the dispatch +collective). """ from dataclasses import dataclass @@ -57,6 +57,23 @@ __all__ = ["moe"] +def _with_sharding_constraint_cast_bwd(x: jax.Array, sharding: NamedSharding) -> jax.Array: + """Apply a sharding constraint while keeping cotangents in the primal dtype.""" + + @jax.custom_vjp + def _constraint(y): + return jax.lax.with_sharding_constraint(y, sharding) + + def _constraint_fwd(y): + return jax.lax.with_sharding_constraint(y, sharding), jnp.zeros((), dtype=y.dtype) + + def _constraint_bwd(dtype_ref, grad): + return (jax.lax.with_sharding_constraint(grad.astype(dtype_ref.dtype), sharding),) + + _constraint.defvjp(_constraint_fwd, _constraint_bwd) + return _constraint(x) + + # ============================================================================= # Process-level NCCL EP bootstrap (must run eagerly, outside jax.jit) # ============================================================================= @@ -256,7 +273,7 @@ def tree_unflatten(cls, aux_data, children): # ============================================================================= -# Per-shard FFN body (runs inside shard_map) +# Custom-partitioned grouped-GEMM FFN body # ============================================================================= @@ -270,6 +287,10 @@ def _ffn_fwd_per_shard( wi_1_bias: Optional[jnp.ndarray], wo_bias: Optional[jnp.ndarray], *, + ep_axis: str, + data_parallelism_axes: Tuple[str, ...], + dp_size: int, + num_ep: int, num_local_experts: int, slots_per_expert: int, activation_type: str, @@ -277,20 +298,61 @@ def _ffn_fwd_per_shard( ): """Per-shard FFN forward. - Operates on the shard-local ``[1, recv_pr, H]`` slice that - ``tex.ep_dispatch`` produces. Returns the expert outputs (shaped - ``[1, recv_pr, H_out]`` so the surrounding ``shard_map`` reassembles - them as ``[num_procs, recv_pr, H_out]``) plus the residuals consumed - by the bwd. + Operates on the global EP-dispatch output while grouped GEMM custom + partitioning lowers the math to per-shard groups. Returns expert + outputs shaped as ``[num_procs, recv_pr, H_out]`` plus the residuals + consumed by the bwd. """ hidden = recv_tokens_local.shape[-1] + mesh = _get_mesh() + batch_axis: Any = ep_axis if not data_parallelism_axes else (*data_parallelism_axes, ep_axis) + row_sharding = NamedSharding(mesh, P(batch_axis, None)) + group_sharding = NamedSharding(mesh, P(batch_axis)) + weight_sharding = NamedSharding(mesh, P(batch_axis, None, None)) + bias_sharding = NamedSharding(mesh, P(batch_axis, None)) + sorted_x = recv_tokens_local.reshape(-1, hidden) recv_w_flat = recv_topk_weights_local.reshape(-1) - local_group_sizes = jnp.full((num_local_experts,), slots_per_expert, dtype=jnp.int32) + num_groups = dp_size * num_ep * num_local_experts + group_sizes = jnp.full((num_groups,), slots_per_expert, dtype=jnp.int32) + sorted_x = jax.lax.with_sharding_constraint(sorted_x, row_sharding) + recv_w_flat = jax.lax.with_sharding_constraint(recv_w_flat, group_sharding) + group_sizes = jax.lax.with_sharding_constraint(group_sizes, group_sharding) wi_0 = wi_0.astype(sorted_x.dtype) wi_1 = wi_1.astype(sorted_x.dtype) wo = wo.astype(sorted_x.dtype) + wi_0 = jnp.broadcast_to( + wi_0.reshape(1, num_ep, num_local_experts, *wi_0.shape[1:]), + (dp_size, num_ep, num_local_experts, *wi_0.shape[1:]), + ).reshape(num_groups, *wi_0.shape[1:]) + wi_1 = jnp.broadcast_to( + wi_1.reshape(1, num_ep, num_local_experts, *wi_1.shape[1:]), + (dp_size, num_ep, num_local_experts, *wi_1.shape[1:]), + ).reshape(num_groups, *wi_1.shape[1:]) + wo = jnp.broadcast_to( + wo.reshape(1, num_ep, num_local_experts, *wo.shape[1:]), + (dp_size, num_ep, num_local_experts, *wo.shape[1:]), + ).reshape(num_groups, *wo.shape[1:]) + wi_0 = jax.lax.with_sharding_constraint(wi_0, weight_sharding) + wi_1 = jax.lax.with_sharding_constraint(wi_1, weight_sharding) + wo = jax.lax.with_sharding_constraint(wo, weight_sharding) + if wi_0_bias is not None: + wi_0_bias = jnp.broadcast_to( + wi_0_bias.reshape(1, num_ep, num_local_experts, *wi_0_bias.shape[1:]), + (dp_size, num_ep, num_local_experts, *wi_0_bias.shape[1:]), + ).reshape(num_groups, *wi_0_bias.shape[1:]) + wi_1_bias = jnp.broadcast_to( + wi_1_bias.reshape(1, num_ep, num_local_experts, *wi_1_bias.shape[1:]), + (dp_size, num_ep, num_local_experts, *wi_1_bias.shape[1:]), + ).reshape(num_groups, *wi_1_bias.shape[1:]) + wo_bias = jnp.broadcast_to( + wo_bias.reshape(1, num_ep, num_local_experts, *wo_bias.shape[1:]), + (dp_size, num_ep, num_local_experts, *wo_bias.shape[1:]), + ).reshape(num_groups, *wo_bias.shape[1:]) + wi_0_bias = jax.lax.with_sharding_constraint(wi_0_bias, bias_sharding) + wi_1_bias = jax.lax.with_sharding_constraint(wi_1_bias, bias_sharding) + wo_bias = jax.lax.with_sharding_constraint(wo_bias, bias_sharding) wi_combined = jnp.concatenate([wi_0, wi_1], axis=-1) wi_combined_bias = ( @@ -298,7 +360,7 @@ def _ffn_fwd_per_shard( ) q_set = noop_quantizer_set - casted_sorted_x = tex.grouped_quantize(sorted_x, q_set.x, local_group_sizes, flatten_axis=-1) + casted_sorted_x = tex.grouped_quantize(sorted_x, q_set.x, group_sizes, flatten_axis=-1) casted_wi = tex.grouped_quantize(wi_combined, q_set.kernel, flatten_axis=-1) combined_out = tex.grouped_gemm( casted_sorted_x.get_tensor(usage=TensorUsage.LHS), @@ -306,12 +368,14 @@ def _ffn_fwd_per_shard( contracting_dims=((1,), (1,)), bias=wi_combined_bias, ) + combined_out = jax.lax.with_sharding_constraint(combined_out, row_sharding) gate_proj_out, up_proj_out = jnp.split(combined_out, 2, axis=-1) casted_sorted_x_lhs_trans = casted_sorted_x.get_tensor(usage=TensorUsage.LHS_TRANS) casted_wi_rhs_trans = casted_wi.get_tensor(usage=TensorUsage.RHS_TRANS) act_fn = _convert_to_activation_function(activation_type) intermediate = act_fn(gate_proj_out) * up_proj_out + intermediate = jax.lax.with_sharding_constraint(intermediate, row_sharding) if apply_topk_weights_early: # Fold the per-token combine weights into the FFN intermediate; @@ -322,9 +386,10 @@ def _ffn_fwd_per_shard( w_b = recv_w_flat[:, None].astype(intermediate.dtype) mask_b = (recv_w_flat != 0).astype(intermediate.dtype)[:, None] intermediate = intermediate * w_b * mask_b + intermediate = jax.lax.with_sharding_constraint(intermediate, row_sharding) casted_intermediate = tex.grouped_quantize( - intermediate, q_set.x, local_group_sizes, flatten_axis=-1 + intermediate, q_set.x, group_sizes, flatten_axis=-1 ) casted_wo = tex.grouped_quantize(wo, q_set.kernel, flatten_axis=-1) expert_outputs = tex.grouped_gemm( @@ -333,10 +398,13 @@ def _ffn_fwd_per_shard( contracting_dims=((1,), (1,)), bias=wo_bias, ) + expert_outputs = jax.lax.with_sharding_constraint(expert_outputs, row_sharding) casted_intermediate_lhs_trans = casted_intermediate.get_tensor(usage=TensorUsage.LHS_TRANS) casted_wo_rhs_trans = casted_wo.get_tensor(usage=TensorUsage.RHS_TRANS) - expert_outputs_3d = expert_outputs.reshape(1, expert_outputs.shape[0], expert_outputs.shape[1]) + expert_outputs_3d = expert_outputs.reshape( + dp_size * num_ep, num_local_experts * slots_per_expert, expert_outputs.shape[-1] + ) residuals = ( casted_sorted_x_lhs_trans, casted_wi_rhs_trans, @@ -344,7 +412,7 @@ def _ffn_fwd_per_shard( up_proj_out, casted_intermediate_lhs_trans, casted_wo_rhs_trans, - local_group_sizes, + group_sizes, ) return expert_outputs_3d, residuals @@ -360,6 +428,12 @@ def _ffn_bwd_per_shard( local_group_sizes: jnp.ndarray, recv_topk_weights_local: jnp.ndarray, *, + ep_axis: str, + data_parallelism_axes: Tuple[str, ...], + dp_size: int, + num_ep: int, + num_local_experts: int, + slots_per_expert: int, activation_type: str, apply_topk_weights_early: bool, has_bias: bool, @@ -370,8 +444,18 @@ def _ffn_bwd_per_shard( ``(d_sorted_x [1, recv_pr, H], d_recv_w [1, recv_pr], d_wi_0, d_wi_1, d_wo, d_wi_0_bias, d_wi_1_bias, d_wo_bias)``. """ + mesh = _get_mesh() + batch_axis: Any = ep_axis if not data_parallelism_axes else (*data_parallelism_axes, ep_axis) + row_sharding = NamedSharding(mesh, P(batch_axis, None)) + group_sharding = NamedSharding(mesh, P(batch_axis)) + weight_sharding = NamedSharding(mesh, P(batch_axis, None, None)) + bias_sharding = NamedSharding(mesh, P(batch_axis, None)) + d_eo_2d = d_expert_outputs_local.reshape(-1, d_expert_outputs_local.shape[-1]) recv_w_flat = recv_topk_weights_local.reshape(-1) + d_eo_2d = jax.lax.with_sharding_constraint(d_eo_2d, row_sharding) + recv_w_flat = jax.lax.with_sharding_constraint(recv_w_flat, group_sharding) + local_group_sizes = jax.lax.with_sharding_constraint(local_group_sizes, group_sharding) q_set = noop_quantizer_set # wo bwd @@ -381,12 +465,16 @@ def _ffn_bwd_per_shard( casted_wo_rhs_trans, contracting_dims=((1,), (2,)), ) + d_intermediate = jax.lax.with_sharding_constraint(d_intermediate, row_sharding) d_wo = tex.grouped_gemm( casted_intermediate_lhs_trans, casted_d_eo.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) + d_wo = jax.lax.with_sharding_constraint(d_wo, weight_sharding) d_wo_bias = tex.grouped_dbias(d_eo_2d, local_group_sizes) if has_bias else None + if has_bias: + d_wo_bias = jax.lax.with_sharding_constraint(d_wo_bias, bias_sharding) act_fn = _convert_to_activation_function(activation_type) if apply_topk_weights_early: @@ -401,6 +489,7 @@ def _ffn_bwd_per_shard( d_intermediate * intermediate_unweighted * mask_b, axis=-1 ).astype(recv_w_flat.dtype) d_intermediate = d_intermediate * w_b * mask_b + d_intermediate = jax.lax.with_sharding_constraint(d_intermediate, row_sharding) else: d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat) @@ -411,6 +500,7 @@ def _ffn_bwd_per_shard( # wi bwd (fused gate/up) d_combined = jnp.concatenate([d_gate_proj_out, d_up_proj_out], axis=-1) + d_combined = jax.lax.with_sharding_constraint(d_combined, row_sharding) casted_d_combined = tex.grouped_quantize( d_combined, q_set.dgrad, local_group_sizes, flatten_axis=-1 ) @@ -419,21 +509,42 @@ def _ffn_bwd_per_shard( casted_wi_rhs_trans, contracting_dims=((1,), (2,)), ) + d_sorted_x = jax.lax.with_sharding_constraint(d_sorted_x, row_sharding) d_wi_combined = tex.grouped_gemm( casted_sorted_x_lhs_trans, casted_d_combined.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) + d_wi_combined = jax.lax.with_sharding_constraint(d_wi_combined, weight_sharding) d_wi_0, d_wi_1 = jnp.split(d_wi_combined, 2, axis=-1) + d_wi_0 = d_wi_0.reshape(dp_size, num_ep * num_local_experts, *d_wi_0.shape[1:]).sum(axis=0) + d_wi_1 = d_wi_1.reshape(dp_size, num_ep * num_local_experts, *d_wi_1.shape[1:]).sum(axis=0) if has_bias: d_wi_combined_bias = tex.grouped_dbias(d_combined, local_group_sizes) + d_wi_combined_bias = jax.lax.with_sharding_constraint( + d_wi_combined_bias, bias_sharding + ) d_wi_0_bias, d_wi_1_bias = jnp.split(d_wi_combined_bias, 2, axis=-1) + d_wi_0_bias = d_wi_0_bias.reshape( + dp_size, num_ep * num_local_experts, *d_wi_0_bias.shape[1:] + ).sum(axis=0) + d_wi_1_bias = d_wi_1_bias.reshape( + dp_size, num_ep * num_local_experts, *d_wi_1_bias.shape[1:] + ).sum(axis=0) else: d_wi_0_bias = None d_wi_1_bias = None - d_sorted_x_3d = d_sorted_x.reshape(1, d_sorted_x.shape[0], d_sorted_x.shape[1]) - d_recv_w_3d = d_recv_w_from_intermediate.reshape(1, -1) + d_wo = d_wo.reshape(dp_size, num_ep * num_local_experts, *d_wo.shape[1:]).sum(axis=0) + if has_bias: + d_wo_bias = d_wo_bias.reshape(dp_size, num_ep * num_local_experts, *d_wo_bias.shape[1:]).sum(axis=0) + + d_sorted_x_3d = d_sorted_x.reshape( + dp_size * num_ep, num_local_experts * slots_per_expert, d_sorted_x.shape[-1] + ) + d_recv_w_3d = d_recv_w_from_intermediate.reshape( + dp_size * num_ep, num_local_experts * slots_per_expert + ) return ( d_sorted_x_3d, d_recv_w_3d, @@ -480,14 +591,12 @@ def _moe_fwd_rule( apply_topk_weights_early, align_size, ): - """Forward: gate -> topk -> ep_dispatch -> shard_map(FFN) -> ep_combine. + """Forward: gate -> topk -> ep_dispatch -> grouped-GEMM FFN -> ep_combine. Returns ``(output, aux_loss)``. ``aux_loss`` is a zero scalar when ``aux_loss_coeff == 0``. """ del gate_kernel_axes, wi_kernel_axes, wo_kernel_axes # used in bwd only - from jax.experimental.shard_map import shard_map - x = with_sharding_constraint_by_logical_axes(x, input_axes) mesh = _get_mesh() @@ -651,57 +760,26 @@ def _moe_fwd_rule( recv_topk_weights, NamedSharding(mesh, ep2_spec) ) - # ---------------- FFN (per-shard via shard_map) ---------------- + # ---------------- FFN (custom-partitioned grouped GEMM) ---------------- has_bias = wi_0_bias is not None - kernel_spec = P(ep_axis, None, None) - bias_spec = P(ep_axis, None) if has_bias else None - ffn_in_specs = (ep3_spec, ep2_spec, kernel_spec, kernel_spec, kernel_spec) - ffn_in_args = [recv_tokens, recv_topk_weights, wi_0, wi_1, wo] - if has_bias: - ffn_in_specs = ffn_in_specs + (bias_spec, bias_spec, bias_spec) - ffn_in_args.extend([wi_0_bias, wi_1_bias, wo_bias]) - - # FFN residuals live entirely on the local ep rank, so the leading - # "experts" / "rows" dims map to P() (already shard-local). - residuals_spec = ( - P(), # casted_sorted_x_lhs_trans - P(ep_axis, None, None), # casted_wi_rhs_trans - P(), # gate_proj_out - P(), # up_proj_out - P(), # casted_intermediate_lhs_trans - P(ep_axis, None, None), # casted_wo_rhs_trans - P(), # local_group_sizes + expert_outputs, ffn_residuals = _ffn_fwd_per_shard( + recv_tokens, + recv_topk_weights, + wi_0, + wi_1, + wo, + wi_0_bias, + wi_1_bias, + wo_bias, + ep_axis=ep_axis, + data_parallelism_axes=data_parallelism_axes, + dp_size=dp_size, + num_ep=num_ep, + num_local_experts=num_local_experts, + slots_per_expert=slots_per_expert, + activation_type=activation_type, + apply_topk_weights_early=apply_topk_weights_early, ) - out_specs = (ep3_spec, residuals_spec) - - def _body(*args): - if has_bias: - (r_tok, r_w, w0, w1, w_o, w0b, w1b, wob) = args - else: - (r_tok, r_w, w0, w1, w_o) = args - w0b = w1b = wob = None - return _ffn_fwd_per_shard( - r_tok, - r_w, - w0, - w1, - w_o, - w0b, - w1b, - wob, - num_local_experts=num_local_experts, - slots_per_expert=slots_per_expert, - activation_type=activation_type, - apply_topk_weights_early=apply_topk_weights_early, - ) - - expert_outputs, ffn_residuals = shard_map( - _body, - mesh=mesh, - in_specs=ffn_in_specs, - out_specs=out_specs, - check_rep=False, - )(*ffn_in_args) expert_outputs = jax.lax.with_sharding_constraint( expert_outputs, NamedSharding(mesh, ep3_spec) ) @@ -794,7 +872,6 @@ def _moe_bwd_rule( ): """Backward mirror of :func:`_moe_fwd_rule`.""" del num_groups, group_topk, dtype, align_size # captured in residuals / unused in bwd - from jax.experimental.shard_map import shard_map d_output, d_aux_loss = cotangents @@ -844,22 +921,17 @@ def _moe_bwd_rule( d_recv_w_from_combine = (grad_pre_combine * ctx.expert_outputs * mask).sum(axis=-1) d_recv_w_from_combine = d_recv_w_from_combine.astype(ctx.recv_topk_weights.dtype) - # ---------------- FFN bwd (per-shard via shard_map) ---------------- - kernel_spec = P(ep_axis, None, None) - bias_spec = P(ep_axis, None) if has_bias else None - - bwd_in_specs = ( - ep3_spec, # d_expert_outputs - P(), # casted_sorted_x_lhs_trans - P(ep_axis, None, None), # casted_wi_rhs_trans - P(), # gate_proj_out - P(), # up_proj_out - P(), # casted_intermediate_lhs_trans - P(ep_axis, None, None), # casted_wo_rhs_trans - P(), # local_group_sizes - ep2_spec, # recv_topk_weights - ) - bwd_in_args = [ + # ---------------- FFN bwd (custom-partitioned grouped GEMM) ---------------- + ( + d_sorted_x, + d_recv_w_from_intermediate, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, + ) = _ffn_bwd_per_shard( d_expert_outputs, ctx.casted_sorted_x_lhs_trans, ctx.casted_wi_rhs_trans, @@ -869,74 +941,17 @@ def _moe_bwd_rule( ctx.casted_wo_rhs_trans, ctx.local_group_sizes, ctx.recv_topk_weights, - ] - bwd_out_specs = ( - ep3_spec, # d_sorted_x - ep2_spec, # d_recv_w_from_intermediate - kernel_spec, # d_wi_0 - kernel_spec, # d_wi_1 - kernel_spec, # d_wo - bias_spec if has_bias else None, # d_wi_0_bias - bias_spec if has_bias else None, # d_wi_1_bias - bias_spec if has_bias else None, # d_wo_bias + ep_axis=ep_axis, + data_parallelism_axes=data_parallelism_axes, + dp_size=dp_size, + num_ep=num_ep, + num_local_experts=num_experts // num_ep, + slots_per_expert=recv_pr // (num_experts // num_ep), + activation_type=activation_type, + apply_topk_weights_early=apply_topk_weights_early, + has_bias=has_bias, ) - def _bwd_body(*args): - ( - d_sorted_x_3d, - d_recv_w_3d, - d_wi_0, - d_wi_1, - d_wo, - d_wi_0_bias, - d_wi_1_bias, - d_wo_bias, - ) = _ffn_bwd_per_shard( - *args, - activation_type=activation_type, - apply_topk_weights_early=apply_topk_weights_early, - has_bias=has_bias, - ) - # Weight grads accumulate per-DP-shard inside the body; psum across - # DP axes so each replica sees the full sum (matches out_specs - # P(ep_axis, ...) which is DP-replicated). - if data_parallelism_axes: - dp = tuple(data_parallelism_axes) - d_wi_0 = jax.lax.psum(d_wi_0, axis_name=dp) - d_wi_1 = jax.lax.psum(d_wi_1, axis_name=dp) - d_wo = jax.lax.psum(d_wo, axis_name=dp) - if has_bias: - d_wi_0_bias = jax.lax.psum(d_wi_0_bias, axis_name=dp) - d_wi_1_bias = jax.lax.psum(d_wi_1_bias, axis_name=dp) - d_wo_bias = jax.lax.psum(d_wo_bias, axis_name=dp) - return ( - d_sorted_x_3d, - d_recv_w_3d, - d_wi_0, - d_wi_1, - d_wo, - d_wi_0_bias, - d_wi_1_bias, - d_wo_bias, - ) - - ( - d_sorted_x, - d_recv_w_from_intermediate, - d_wi_0, - d_wi_1, - d_wo, - d_wi_0_bias, - d_wi_1_bias, - d_wo_bias, - ) = shard_map( - _bwd_body, - mesh=mesh, - in_specs=bwd_in_specs, - out_specs=bwd_out_specs, - check_rep=False, - )(*bwd_in_args) - d_recv_w_total = d_recv_w_from_combine + d_recv_w_from_intermediate # ---------------- Dispatch bwd (global view) ---------------- @@ -1193,7 +1208,7 @@ def moe( UserWarning, stacklevel=2, ) - x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, expected_spec)) + x = _with_sharding_constraint_cast_bwd(x, NamedSharding(mesh, expected_spec)) # custom_vjp can't trace through None args; lower expert_bias to an # empty shape-(0,) tensor that fused_topk_with_score_function treats From c88812172f6df3dad40670378f463ae9ebcc0917 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 15 Jun 2026 08:32:57 -0700 Subject: [PATCH 35/36] Integrate MoEBlock with grouped quant+gemm custom partitioning Signed-off-by: Jeremy Berchtold --- tests/jax/test_distributed_grouped_gemm.py | 44 +++++- tests/jax/test_te_ep_moe.py | 109 +++++++++++++-- transformer_engine/jax/flax/moe.py | 25 ++-- transformer_engine/jax/moe.py | 147 ++++++++++++--------- 4 files changed, 242 insertions(+), 83 deletions(-) diff --git a/tests/jax/test_distributed_grouped_gemm.py b/tests/jax/test_distributed_grouped_gemm.py index be2487a8df..5eea6b544c 100644 --- a/tests/jax/test_distributed_grouped_gemm.py +++ b/tests/jax/test_distributed_grouped_gemm.py @@ -1,8 +1,16 @@ # Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""Partitioning tests for grouped quantize and grouped GEMM.""" +"""Partitioning tests for grouped quantize and grouped GEMM. +The file normally runs in a single process over all local GPUs. It also +supports the ``--num-process`` / ``--process-id`` options from +``tests/jax/conftest.py`` so the executable grouped-dense MXFP8 test can +be run with one process per GPU, matching the TE-EP MoE launcher model. +""" + +import os +import sys from types import SimpleNamespace import jax @@ -18,6 +26,40 @@ from transformer_engine.jax.sharding import MeshResource, global_shard_guard +def _init_distributed_from_pytest_args() -> bool: + num_process = int(os.environ.get("MP_NUM_PROCESS", "0") or "0") + process_id = int(os.environ.get("MP_PROCESS_ID", "0") or "0") + for i, arg in enumerate(sys.argv): + if arg.startswith("--num-process="): + num_process = int(arg.split("=", 1)[1]) + elif arg == "--num-process" and i + 1 < len(sys.argv): + num_process = int(sys.argv[i + 1]) + elif arg.startswith("--process-id="): + process_id = int(arg.split("=", 1)[1]) + elif arg == "--process-id" and i + 1 < len(sys.argv): + process_id = int(sys.argv[i + 1]) + + if num_process <= 1: + return False + + coordinator = os.environ.get( + "TE_GROUPED_GEMM_COORDINATOR_ADDRESS", + os.environ.get("TE_EP_MOE_COORDINATOR_ADDRESS", "127.0.0.1:13457"), + ) + jax.distributed.initialize( + coordinator_address=coordinator, + num_processes=num_process, + process_id=process_id, + local_device_ids=process_id, + ) + assert jax.local_device_count() == 1, "one GPU per process is required" + assert jax.device_count() == num_process + return True + + +_MP_ACTIVE = _init_distributed_from_pytest_args() + + def _mesh(): devices = jax.devices() if len(devices) < 4: diff --git a/tests/jax/test_te_ep_moe.py b/tests/jax/test_te_ep_moe.py index d56dad2c92..a350914bb8 100644 --- a/tests/jax/test_te_ep_moe.py +++ b/tests/jax/test_te_ep_moe.py @@ -47,9 +47,8 @@ * ``TestZZZTeEpMoeBootstrap`` verifies the per-process NCCL bootstrap rejects a mismatched signature. -FP8 / MXFP8 recipes are deferred — the ``quantizer_sets`` plumbing -has not yet been re-wired across the TE-EP ``shard_map`` boundary -(see ``.pr3036-review/INTEGRATION_DESIGN.md``). +FP8 / MXFP8 recipes are deferred — the quantizer-set plumbing needs +additional validation with the TE-EP dispatch/combine path. """ import os @@ -132,6 +131,8 @@ def _read_mp_options(): from transformer_engine.jax.flax import _MoEBlock as MoEBlock from transformer_engine.jax.moe import moe, record_ep_bootstrap_signature_for_moe from transformer_engine.jax.ep import ep_bootstrap +from transformer_engine.jax.dense import grouped_dense +from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode from transformer_engine.jax.sharding import MeshResource, global_shard_guard @@ -557,6 +558,26 @@ def _make_inputs(key): ] +def _forward_tolerances(_config): + return FWD_ATOL, FWD_RTOL + + +def _grad_tolerances(_config, param_name): + if param_name == "gate_kernel": + return GRAD_GATE_ATOL, GRAD_GATE_RTOL + return GRAD_FFN_ATOL, GRAD_FFN_RTOL + + +def _mxfp8_grouped_quantizer_set(n_groups): + return QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=True, + n_groups=n_groups, + ) + + def _reference_kwargs_from_config(config, params_np): """Pick out the reference-relevant pieces of a parametrize config.""" return dict( @@ -569,6 +590,77 @@ def _reference_kwargs_from_config(config, params_np): ) +class TestTeEpMoeMXFP8GroupedGemm: + """MXFP8 grouped-GEMM smoke coverage under the TE-EP multiprocess launcher.""" + + def test_grouped_dense_mxfp8_ep_fsdp_outside_shard_map(self): + devices = jax.devices() + if len(devices) < 4: + pytest.skip("MXFP8 grouped GEMM test requires at least 4 visible GPUs.") + + mesh = Mesh(np.asarray(devices[:4]).reshape(2, 2), ("expert", FSDP_AXIS)) + n_groups = 4 + group_tokens = 128 + hidden = 256 + out_hidden = 128 + x_shape = (n_groups * group_tokens, hidden) + w_shape = (n_groups, hidden, out_hidden) + + x_sharding = NamedSharding(mesh, P("expert", None)) + w_sharding = NamedSharding(mesh, P("expert", FSDP_AXIS, None)) + group_sharding = NamedSharding(mesh, P("expert")) + out_sharding = NamedSharding(mesh, P("expert", None)) + + quantizer_set = _mxfp8_grouped_quantizer_set(n_groups) + + with mesh, global_shard_guard( + MeshResource(fsdp_resource=FSDP_AXIS, ep_resource="expert") + ): + x = jax.device_put( + jax.random.normal(jax.random.PRNGKey(20), x_shape, dtype=DTYPE) + * jnp.asarray(0.01, dtype=DTYPE), + x_sharding, + ) + w = jax.device_put( + jax.random.normal(jax.random.PRNGKey(21), w_shape, dtype=DTYPE) + * jnp.asarray(0.01, dtype=DTYPE), + w_sharding, + ) + group_sizes = jax.device_put( + jnp.full((n_groups,), group_tokens, dtype=jnp.int32), + group_sharding, + ) + + def apply_with_vjp(x, w, group_sizes): + def apply(x, w): + return grouped_dense( + x, + w, + group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + ) + + out, vjp_fn = jax.vjp(apply, x, w) + dx, dw = vjp_fn(out) + return out, dx, dw + + out, dx, dw = jax.jit( + apply_with_vjp, + in_shardings=(x_sharding, w_sharding, group_sharding), + out_shardings=(out_sharding, x_sharding, w_sharding), + )(x, w, group_sizes) + out, dx, dw = jax.block_until_ready((out, dx, dw)) + + assert tuple(out.sharding.spec) == ("expert", None) + assert tuple(dx.sharding.spec) == ("expert", None) + assert tuple(dw.sharding.spec) == ("expert", FSDP_AXIS, None) + for name, value in (("out", out), ("dx", dx), ("dw", dw)): + local_value = np.asarray(jax.device_get(value.addressable_data(0))) + assert np.all(np.isfinite(local_value)), f"{name} has NaN/Inf" + assert np.any(local_value != 0.0), f"{name} is identically zero" + + class TestTeEpMoeForward: """Per-config forward correctness in a single run: shape, dtype, finiteness AND numerical parity vs the pure-JAX reference.""" @@ -601,11 +693,12 @@ def test_forward(self, mesh, config): num_experts_per_tok=TOPK, **_reference_kwargs_from_config(config, params_np), ) + atol, rtol = _forward_tolerances(config) np.testing.assert_allclose( out_te_np.astype(np.float32), np.asarray(jax.device_get(out_ref)).astype(np.float32), - atol=FWD_ATOL, - rtol=FWD_RTOL, + atol=atol, + rtol=rtol, err_msg=f"forward parity breach for config={config}", ) @@ -653,11 +746,7 @@ def loss_fn(params, x): g_te = _to_global_numpy(_unwrap(grads_te["params"][name]), mesh) assert np.all(np.isfinite(g_te)), f"{name} grad has NaN/Inf [config={config}]" assert np.any(g_te != 0.0), f"{name} grad identically zero [config={config}]" - atol, rtol = ( - (GRAD_GATE_ATOL, GRAD_GATE_RTOL) - if name == "gate_kernel" - else (GRAD_FFN_ATOL, GRAD_FFN_RTOL) - ) + atol, rtol = _grad_tolerances(config, name) np.testing.assert_allclose( g_te.astype(np.float32), grads_ref_np[name].astype(np.float32), diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 67b2f5dfdd..f6aba7436f 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -21,6 +21,7 @@ recipe-driven alignment follow-up) stabilises. """ +from dataclasses import field from typing import Any, Callable, NewType, Optional, Tuple, Union import jax.numpy as jnp @@ -31,6 +32,7 @@ from jax.sharding import PartitionSpec as P # noqa: F401 # pylint: disable=unused-import from ..moe import moe +from ..quantize import QuantizerSet, noop_quantizer_set from ..router import ScoreFunction from ..sharding import get_active_resource_axis from .module import TransformerEngineBase @@ -84,17 +86,20 @@ class _MoEBlock(TransformerEngineBase): :func:`with_logical_partitioning` and our internal :func:`with_sharding_constraint_by_logical_axes`). data_parallelism_axes : tuple[str, ...] - FSDP axes over which the input *batch* dim is sharded IN ADDITION - to the EP axis. Empty (default) means activations are replicated - across non-EP axes within an EP group; set e.g. ``("fsdp",)`` for - true FSDP-of-batch where each device owns a unique slice of the - batch. + DP/FSDP axes over which the input *batch* dim is sharded IN + ADDITION to the EP axis. Empty (default) means activations are + replicated across non-EP axes within an EP group; set e.g. + ``("fsdp",)`` or ``("dp", "fsdp")`` for outer data axes where + each EP group owns a unique slice of the batch. dtype : jnp.dtype Compute / parameter dtype. kernel_init, bias_init : Initializers. use_bias : bool Register per-expert FFN biases. + ffn_quantizer_set : QuantizerSet + Quantizer set for the grouped-GEMM FFN body. Defaults to no + quantization; MXFP8 callers should also set ``align_size=128``. """ # Architecture @@ -140,6 +145,7 @@ class _MoEBlock(TransformerEngineBase): # Per-expert router bias added before the top-k. Only meaningful when # score_function='sigmoid'. use_expert_bias: bool = False + ffn_quantizer_set: QuantizerSet = field(default_factory=lambda: noop_quantizer_set) def __post_init__(self): if self.kernel_init is None: @@ -173,10 +179,10 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: ), f"_MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" _, _, hidden_size = inputs.shape - # Param registrations must run OUTSIDE any JAX transform that - # alters the variable scope (e.g. shard_map). The functional - # ``moe(...)`` opens its own shard_map internally for the FFN - # body, so registering params here is correct. + # Param registrations stay in the Flax module scope. The functional + # ``moe(...)`` calls custom-partitioned EP and grouped-GEMM primitives + # directly, so the FFN body no longer wraps the layer params in a + # separate mapping transform. gate_kernel = self.param( "gate_kernel", nn.with_logical_partitioning(self.kernel_init, self.gate_kernel_axes), @@ -243,6 +249,7 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: wi_1_bias, wo_bias, expert_bias, + self.ffn_quantizer_set, num_experts=self.num_experts, num_experts_per_tok=self.num_experts_per_tok, activation_type=self.activation_type, diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 807d4ef13a..fa53ee2020 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -27,12 +27,12 @@ Out-of-scope (for now) ---------------------- -FP8 / MXFP8 quantizer sets are not yet wired on this path; turning -them on requires recipe-aware residual specs and ``ScaledTensor`` -residual handling. ``aux_loss_coeff`` and ``expert_bias`` are supported -(the former forces a per-step all-gather over the routing-side logits, -which lives off the critical path and overlaps with the dispatch -collective). +The FFN grouped GEMMs accept a caller-supplied quantizer set, including +MXFP8 grouped quantizers. Full recipe/state-module integration is not +yet wired through this functional MoE API. ``aux_loss_coeff`` and +``expert_bias`` are supported (the former forces a per-step all-gather +over the routing-side logits, which lives off the critical path and +overlaps with the dispatch collective). """ from dataclasses import dataclass @@ -47,6 +47,7 @@ from . import cpp_extensions as tex from .quantize import ( TensorUsage, + QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, ) @@ -273,13 +274,21 @@ def tree_unflatten(cls, aux_data, children): # ============================================================================= -# Custom-partitioned grouped-GEMM FFN body +# Global-view grouped-GEMM FFN body # ============================================================================= -def _ffn_fwd_per_shard( - recv_tokens_local: jnp.ndarray, - recv_topk_weights_local: jnp.ndarray, +def _compound_data_ep_axis( + data_parallelism_axes: Tuple[str, ...], + ep_axis: str, +) -> Union[str, Tuple[str, ...]]: + """Leading sharding axis for tensors split across DP/FSDP and EP.""" + return ep_axis if not data_parallelism_axes else (*data_parallelism_axes, ep_axis) + + +def _ffn_fwd_global( + recv_tokens: jnp.ndarray, + recv_topk_weights: jnp.ndarray, wi_0: jnp.ndarray, wi_1: jnp.ndarray, wo: jnp.ndarray, @@ -295,24 +304,26 @@ def _ffn_fwd_per_shard( slots_per_expert: int, activation_type: str, apply_topk_weights_early: bool, + ffn_quantizer_set: QuantizerSet, ): - """Per-shard FFN forward. + """Global-view FFN forward. - Operates on the global EP-dispatch output while grouped GEMM custom - partitioning lowers the math to per-shard groups. Returns expert - outputs shaped as ``[num_procs, recv_pr, H_out]`` plus the residuals - consumed by the bwd. + The operands stay in the same global shape produced by EP dispatch. + Grouped GEMM custom partitioning owns the local grouped math, including + EP expert sharding and DP/FSDP handling on the leading group dimension. + Returns expert outputs shaped as ``[num_procs, recv_pr, H_out]`` plus + the residuals consumed by the bwd. """ - hidden = recv_tokens_local.shape[-1] + hidden = recv_tokens.shape[-1] mesh = _get_mesh() - batch_axis: Any = ep_axis if not data_parallelism_axes else (*data_parallelism_axes, ep_axis) - row_sharding = NamedSharding(mesh, P(batch_axis, None)) - group_sharding = NamedSharding(mesh, P(batch_axis)) - weight_sharding = NamedSharding(mesh, P(batch_axis, None, None)) - bias_sharding = NamedSharding(mesh, P(batch_axis, None)) - - sorted_x = recv_tokens_local.reshape(-1, hidden) - recv_w_flat = recv_topk_weights_local.reshape(-1) + data_ep_axis = _compound_data_ep_axis(data_parallelism_axes, ep_axis) + row_sharding = NamedSharding(mesh, P(data_ep_axis, None)) + group_sharding = NamedSharding(mesh, P(data_ep_axis)) + weight_sharding = NamedSharding(mesh, P(data_ep_axis, None, None)) + bias_sharding = NamedSharding(mesh, P(data_ep_axis, None)) + + sorted_x = recv_tokens.reshape(-1, hidden) + recv_w_flat = recv_topk_weights.reshape(-1) num_groups = dp_size * num_ep * num_local_experts group_sizes = jnp.full((num_groups,), slots_per_expert, dtype=jnp.int32) sorted_x = jax.lax.with_sharding_constraint(sorted_x, row_sharding) @@ -359,9 +370,12 @@ def _ffn_fwd_per_shard( jnp.concatenate([wi_0_bias, wi_1_bias], axis=-1) if wi_0_bias is not None else None ) - q_set = noop_quantizer_set - casted_sorted_x = tex.grouped_quantize(sorted_x, q_set.x, group_sizes, flatten_axis=-1) - casted_wi = tex.grouped_quantize(wi_combined, q_set.kernel, flatten_axis=-1) + casted_sorted_x = tex.grouped_quantize( + sorted_x, ffn_quantizer_set.x, group_sizes, flatten_axis=-1 + ) + casted_wi = tex.grouped_quantize( + wi_combined, ffn_quantizer_set.kernel, flatten_axis=-1 + ) combined_out = tex.grouped_gemm( casted_sorted_x.get_tensor(usage=TensorUsage.LHS), casted_wi.get_tensor(usage=TensorUsage.RHS), @@ -389,9 +403,9 @@ def _ffn_fwd_per_shard( intermediate = jax.lax.with_sharding_constraint(intermediate, row_sharding) casted_intermediate = tex.grouped_quantize( - intermediate, q_set.x, group_sizes, flatten_axis=-1 + intermediate, ffn_quantizer_set.x, group_sizes, flatten_axis=-1 ) - casted_wo = tex.grouped_quantize(wo, q_set.kernel, flatten_axis=-1) + casted_wo = tex.grouped_quantize(wo, ffn_quantizer_set.kernel, flatten_axis=-1) expert_outputs = tex.grouped_gemm( casted_intermediate.get_tensor(usage=TensorUsage.LHS), casted_wo.get_tensor(usage=TensorUsage.RHS), @@ -417,8 +431,8 @@ def _ffn_fwd_per_shard( return expert_outputs_3d, residuals -def _ffn_bwd_per_shard( - d_expert_outputs_local: jnp.ndarray, +def _ffn_bwd_global( + d_expert_outputs: jnp.ndarray, casted_sorted_x_lhs_trans, casted_wi_rhs_trans, gate_proj_out: jnp.ndarray, @@ -426,7 +440,7 @@ def _ffn_bwd_per_shard( casted_intermediate_lhs_trans, casted_wo_rhs_trans, local_group_sizes: jnp.ndarray, - recv_topk_weights_local: jnp.ndarray, + recv_topk_weights: jnp.ndarray, *, ep_axis: str, data_parallelism_axes: Tuple[str, ...], @@ -437,29 +451,30 @@ def _ffn_bwd_per_shard( activation_type: str, apply_topk_weights_early: bool, has_bias: bool, + ffn_quantizer_set: QuantizerSet, ): - """Per-shard FFN backward. + """Global-view FFN backward. - Mirrors :func:`_ffn_fwd_per_shard`. Returns + Mirrors :func:`_ffn_fwd_global`. Returns ``(d_sorted_x [1, recv_pr, H], d_recv_w [1, recv_pr], d_wi_0, d_wi_1, d_wo, d_wi_0_bias, d_wi_1_bias, d_wo_bias)``. """ mesh = _get_mesh() - batch_axis: Any = ep_axis if not data_parallelism_axes else (*data_parallelism_axes, ep_axis) - row_sharding = NamedSharding(mesh, P(batch_axis, None)) - group_sharding = NamedSharding(mesh, P(batch_axis)) - weight_sharding = NamedSharding(mesh, P(batch_axis, None, None)) - bias_sharding = NamedSharding(mesh, P(batch_axis, None)) - - d_eo_2d = d_expert_outputs_local.reshape(-1, d_expert_outputs_local.shape[-1]) - recv_w_flat = recv_topk_weights_local.reshape(-1) + data_ep_axis = _compound_data_ep_axis(data_parallelism_axes, ep_axis) + row_sharding = NamedSharding(mesh, P(data_ep_axis, None)) + group_sharding = NamedSharding(mesh, P(data_ep_axis)) + weight_sharding = NamedSharding(mesh, P(data_ep_axis, None, None)) + bias_sharding = NamedSharding(mesh, P(data_ep_axis, None)) + + d_eo_2d = d_expert_outputs.reshape(-1, d_expert_outputs.shape[-1]) + recv_w_flat = recv_topk_weights.reshape(-1) d_eo_2d = jax.lax.with_sharding_constraint(d_eo_2d, row_sharding) recv_w_flat = jax.lax.with_sharding_constraint(recv_w_flat, group_sharding) local_group_sizes = jax.lax.with_sharding_constraint(local_group_sizes, group_sharding) - q_set = noop_quantizer_set - # wo bwd - casted_d_eo = tex.grouped_quantize(d_eo_2d, q_set.dgrad, local_group_sizes, flatten_axis=-1) + casted_d_eo = tex.grouped_quantize( + d_eo_2d, ffn_quantizer_set.dgrad, local_group_sizes, flatten_axis=-1 + ) d_intermediate = tex.grouped_gemm( casted_d_eo.get_tensor(usage=TensorUsage.LHS), casted_wo_rhs_trans, @@ -502,7 +517,7 @@ def _ffn_bwd_per_shard( d_combined = jnp.concatenate([d_gate_proj_out, d_up_proj_out], axis=-1) d_combined = jax.lax.with_sharding_constraint(d_combined, row_sharding) casted_d_combined = tex.grouped_quantize( - d_combined, q_set.dgrad, local_group_sizes, flatten_axis=-1 + d_combined, ffn_quantizer_set.dgrad, local_group_sizes, flatten_axis=-1 ) d_sorted_x = tex.grouped_gemm( casted_d_combined.get_tensor(usage=TensorUsage.LHS), @@ -572,6 +587,7 @@ def _moe_fwd_rule( wi_1_bias, wo_bias, expert_bias, + ffn_quantizer_set, num_experts, num_experts_per_tok, activation_type, @@ -646,13 +662,10 @@ def _moe_fwd_rule( ep_size=num_ep, ) - if not data_parallelism_axes: - batch_pspec_axis: Any = ep_axis - else: - # ep must be innermost: ep_bootstrap forms NCCL EP comms from - # consecutive global ranks (dp_color = rank // ep_size), so the - # comm only stays within one model replica under (outer_dp, ep). - batch_pspec_axis = (*data_parallelism_axes, ep_axis) + # ep must be innermost: ep_bootstrap forms NCCL EP comms from + # consecutive global ranks (dp_color = rank // ep_size), so the comm + # only stays within one model replica under (outer DP/FSDP, ep). + batch_pspec_axis = _compound_data_ep_axis(data_parallelism_axes, ep_axis) ep3_spec = P(batch_pspec_axis, None, None) ep2_spec = P(batch_pspec_axis, None) x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, ep3_spec)) @@ -762,7 +775,7 @@ def _moe_fwd_rule( # ---------------- FFN (custom-partitioned grouped GEMM) ---------------- has_bias = wi_0_bias is not None - expert_outputs, ffn_residuals = _ffn_fwd_per_shard( + expert_outputs, ffn_residuals = _ffn_fwd_global( recv_tokens, recv_topk_weights, wi_0, @@ -779,6 +792,7 @@ def _moe_fwd_rule( slots_per_expert=slots_per_expert, activation_type=activation_type, apply_topk_weights_early=apply_topk_weights_early, + ffn_quantizer_set=ffn_quantizer_set, ) expert_outputs = jax.lax.with_sharding_constraint( expert_outputs, NamedSharding(mesh, ep3_spec) @@ -844,6 +858,7 @@ def _moe_fwd_rule( "has_bias": has_bias, "x_shape": x.shape, "recv_pr": recv_pr, + "ffn_quantizer_set": ffn_quantizer_set, } return (output, aux_loss), (ctx, static) @@ -879,6 +894,7 @@ def _moe_bwd_rule( has_bias = static["has_bias"] x_shape = static["x_shape"] recv_pr = static["recv_pr"] + ffn_quantizer_set = static["ffn_quantizer_set"] mesh = _get_mesh() if mesh is None or mesh.empty: @@ -890,10 +906,7 @@ def _moe_bwd_rule( B, S, _ = x_shape K = num_experts_per_tok - if not data_parallelism_axes: - batch_pspec_axis: Any = ep_axis - else: - batch_pspec_axis = (*data_parallelism_axes, ep_axis) + batch_pspec_axis = _compound_data_ep_axis(data_parallelism_axes, ep_axis) ep3_spec = P(batch_pspec_axis, None, None) ep2_spec = P(batch_pspec_axis, None) out_partition_spec = (batch_pspec_axis, None, None) @@ -931,7 +944,7 @@ def _moe_bwd_rule( d_wi_0_bias, d_wi_1_bias, d_wo_bias, - ) = _ffn_bwd_per_shard( + ) = _ffn_bwd_global( d_expert_outputs, ctx.casted_sorted_x_lhs_trans, ctx.casted_wi_rhs_trans, @@ -950,6 +963,7 @@ def _moe_bwd_rule( activation_type=activation_type, apply_topk_weights_early=apply_topk_weights_early, has_bias=has_bias, + ffn_quantizer_set=ffn_quantizer_set, ) d_recv_w_total = d_recv_w_from_combine + d_recv_w_from_intermediate @@ -1052,6 +1066,7 @@ def _moe_bwd_rule( d_wi_1_bias if has_bias else None, d_wo_bias if has_bias else None, d_expert_bias, + ffn_quantizer_set, ) @@ -1060,7 +1075,7 @@ def _moe_bwd_rule( # ============================================================================= -@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 27))) +@partial(jax.custom_vjp, nondiff_argnums=tuple(range(10, 28))) def _moe( x, gate_kernel, @@ -1071,6 +1086,7 @@ def _moe( wi_1_bias, wo_bias, expert_bias, + ffn_quantizer_set, num_experts, num_experts_per_tok, activation_type, @@ -1100,6 +1116,7 @@ def _moe( wi_1_bias, wo_bias, expert_bias, + ffn_quantizer_set, num_experts, num_experts_per_tok, activation_type, @@ -1135,6 +1152,7 @@ def moe( wi_1_bias: Optional[jnp.ndarray] = None, wo_bias: Optional[jnp.ndarray] = None, expert_bias: Optional[jnp.ndarray] = None, + ffn_quantizer_set: QuantizerSet = noop_quantizer_set, *, num_experts: int, num_experts_per_tok: int, @@ -1182,6 +1200,10 @@ def moe( rounds that count up to the nearest multiple, growing the per-rank receive buffer accordingly. Set to ``128`` for FP8 recipes that require 128-aligned grouped-GEMM tiles. + ffn_quantizer_set : QuantizerSet + Quantizers used by the two grouped-GEMM FFN projections. Defaults + to no quantization; pass an MXFP8 grouped quantizer set with + ``align_size=128`` to exercise the MXFP8 grouped-GEMM kernels. See module docstring for the rest of the parameter semantics and the surrounding design rationale. @@ -1195,9 +1217,7 @@ def moe( mesh = _get_mesh() if mesh is None or mesh.empty: raise ValueError("moe(...) requires an active jax.sharding.Mesh.") - expected_leading: Any = ( - (*data_parallelism_axes, ep_axis) if data_parallelism_axes else ep_axis - ) + expected_leading: Any = _compound_data_ep_axis(data_parallelism_axes, ep_axis) expected_spec = P(expected_leading, None, None) actual_spec = getattr(getattr(x, "sharding", None), "spec", None) if actual_spec is not None and tuple(actual_spec) != tuple(expected_spec): @@ -1228,6 +1248,7 @@ def moe( wi_1_bias, wo_bias, expert_bias_arg, + ffn_quantizer_set, num_experts, num_experts_per_tok, activation_type, From 766c36a0ef9a207aad59da84dfc810a5d0e4d760 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 15 Jun 2026 15:50:38 -0700 Subject: [PATCH 36/36] Use 2D output from grouped quant and support GSPMD partition rules Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/base.py | 11 +- transformer_engine/jax/cpp_extensions/ep.py | 116 ++++++++++++++++++ transformer_engine/jax/cpp_extensions/gemm.py | 43 +++++++ .../jax/cpp_extensions/quantization.py | 58 ++++++++- .../jax/cpp_extensions/router.py | 63 ++++++++++ transformer_engine/jax/quantize/tensor.py | 1 - 6 files changed, 276 insertions(+), 16 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 2cdef4bfe7..f41616fa81 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -4,27 +4,20 @@ """JAX/TE base custom ops""" import os import re +import inspect import warnings from abc import ABCMeta, abstractmethod from functools import partial -import jax from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching from jax._src import dispatch from jax import ffi -from packaging.version import Version as PkgVersion import transformer_engine_jax -# GSPMD sharding propagation (infer_sharding_from_operands) is removed in JAX > 0.9.1. -# Only register it for older JAX versions to maintain backwards compatibility. -# For JAX > 0.9.1, infer_sharding_from_operands is also removed from def_partition's signature, -# so it must not be passed at all. -_JAX_GSPMD_SUPPORTED = PkgVersion(jax.__version__) <= PkgVersion("0.9.1") - class BasePrimitive(metaclass=ABCMeta): """ @@ -235,7 +228,7 @@ def name_of_wrapper_p(): batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) - if _JAX_GSPMD_SUPPORTED: + if "infer_sharding_from_operands" in inspect.signature(outer_p_lower.def_partition).parameters: fn = cls.__dict__.get("infer_sharding_from_operands") if fn is not None: actual_fn = ( diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 456cab4fc0..a2a1718b59 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -237,6 +237,14 @@ def _ep_spec_ok(spec, trailing_count): return all(a in allowed for a in elts) +def _result_tree_like(result_infos, shardings): + if isinstance(result_infos, list): + return list(shardings) + if isinstance(result_infos, tuple): + return tuple(shardings) + return tuple(shardings) + + # ── ep_prepare ────────────────────────────────────────────────────────────── @@ -333,6 +341,29 @@ def sharded_impl(topk_idx): return mesh, sharded_impl, (tc_sharding, hm_sharding), arg_shardings + @staticmethod + def infer_sharding_from_operands( + handle_id, + dispatch_output_per_expert_alignment, + is_outer, + mesh, + arg_infos, + result_infos, + ): + del handle_id, dispatch_output_per_expert_alignment, is_outer + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer_axes = _dispatch_input_outer_axes() + idx_spec = arg_infos[0].sharding.spec + if not _leading_axis_ok(idx_spec, ep_axis, outer_axes): + raise NotImplementedError( + "EpPrepare: topk_idx leading dims must shard on ep_resource" + f" ('{ep_axis}') and/or {outer_axes}, with the topk dim replicated;" + f" got spec={idx_spec}." + ) + out_sharding = NamedSharding(mesh, _ep_output_spec(None)) + return _result_tree_like(result_infos, (out_sharding, out_sharding)) + @staticmethod def shardy_sharding_rule(*args): # Signature: (*static_args, mesh, value_types, result_types). Static args @@ -486,6 +517,35 @@ def sharded_impl(handle_mem, topk_idx, tokens, topk_weights): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def infer_sharding_from_operands( + handle_id, + recv_capacity_per_rank, + top_k, + is_outer, + mesh, + arg_infos, + result_infos, + ): + del handle_id, recv_capacity_per_rank, top_k, is_outer + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer_axes = _dispatch_input_outer_axes() + tokens_spec = arg_infos[2].sharding.spec + if not _leading_axis_ok(tokens_spec, ep_axis, outer_axes): + raise NotImplementedError( + "EpDispatch: tokens leading dims must shard on ep_resource" + f" ('{ep_axis}') and/or {outer_axes}, hidden dim replicated;" + f" got spec={tokens_spec}." + ) + return _result_tree_like( + result_infos, + ( + NamedSharding(mesh, _ep_output_spec(None, None)), + NamedSharding(mesh, _ep_output_spec(None)), + ), + ) + @staticmethod def shardy_sharding_rule(*args): # Signature: (*static_args, mesh, value_types, result_types). Static args @@ -643,6 +703,21 @@ def sharded_impl(handle_mem, expert_out): return mesh, sharded_impl, out_sharding, arg_shardings + @staticmethod + def infer_sharding_from_operands( + handle_id, out_leading_shape, out_partition_spec, mesh, arg_infos, result_infos + ): + del handle_id, result_infos + eo_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(eo_spec, trailing_count=2): + raise NotImplementedError( + "EpCombine: expert_out must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={eo_spec}." + ) + resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) + return NamedSharding(mesh, PartitionSpec(*resolved)) + @staticmethod def shardy_sharding_rule(*args): # Signature: (*static_args, mesh, value_types, result_types). Static args: @@ -791,6 +866,40 @@ def sharded_impl(handle_mem, grad, g_recv_topk_weights): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def infer_sharding_from_operands( + handle_id, + top_k, + out_leading_shape, + out_partition_spec, + mesh, + arg_infos, + result_infos, + ): + del handle_id, top_k + g_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(g_spec, trailing_count=2): + raise NotImplementedError( + "EpDispatchBwd: grad must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={g_spec}." + ) + gw_spec = arg_infos[2].sharding.spec + if not _ep_spec_ok(gw_spec, trailing_count=1): + raise NotImplementedError( + "EpDispatchBwd: g_recv_topk_weights must be sharded as" + " PartitionSpec(ep_resource, None) (or ((dp, ep), None) when dp/fsdp is set)" + f" over [num_procs, recv_pr]; got spec={gw_spec}." + ) + resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) + return _result_tree_like( + result_infos, + ( + NamedSharding(mesh, PartitionSpec(*resolved)), + NamedSharding(mesh, PartitionSpec(*resolved, None)), + ), + ) + @staticmethod def shardy_sharding_rule(*args): # Signature: (*static_args, mesh, value_types, result_types). Result rank @@ -871,6 +980,13 @@ def sharded_impl(handle_mem, grad): return mesh, sharded_impl, out_sharding, arg_shardings + @staticmethod + def infer_sharding_from_operands( + handle_id, recv_capacity_per_rank, is_outer, mesh, arg_infos, result_infos + ): + del handle_id, recv_capacity_per_rank, is_outer, arg_infos, result_infos + return NamedSharding(mesh, _ep_output_spec(None, None)) + @staticmethod def shardy_sharding_rule(*args): # T axes are dynamic-rank based on the actual cotangent shape. diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index bda58dc8e9..69f9e26228 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1973,6 +1973,49 @@ def spec_has_axes(spec): reduce_axis, ) + @staticmethod + def infer_sharding_from_operands( + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + use_async_d2h_group_sizes, + use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, + mesh, + arg_infos, + result_infos, + ): + del ( + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + use_async_d2h_group_sizes, + use_v2_ffi, + rhs_axis_boundary, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, + ) + _, out_spec, _ = GroupedGemmPrimitive._parse_partition_specs( + mesh, + arg_infos, + result_infos, + out_shape, + lhs_is_trans=lhs_is_trans, + lhs_axis_boundary=lhs_axis_boundary, + ) + return NamedSharding(mesh, PartitionSpec(*out_spec)) + @staticmethod def partition( lhs_is_trans, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 9266ab08f0..77b955ee78 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -64,6 +64,18 @@ def _flat_data_spec(input_spec): return (merge_axis_specs(*input_spec),) +def _grouped_data_spec(input_spec, flatten_axis): + return _contiguous_flat_input_spec(input_spec, flatten_axis) + + +def _result_tree_like(result_infos, shardings): + if isinstance(result_infos, list): + return list(shardings) + if isinstance(result_infos, tuple): + return tuple(shardings) + return tuple(shardings) + + def _normalize_flatten_axis(flatten_axis, ndim): return flatten_axis + ndim if flatten_axis < 0 else flatten_axis @@ -1128,7 +1140,7 @@ def abstract( """ dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - out_shape = math.prod(x_aval.shape) + out_shape = x_aval.shape # TODO(Phuong): can scale_aval be None? assert scale_aval is None or scale_aval.dtype == jnp.float32 @@ -1306,10 +1318,11 @@ def _parse_partition_specs(scaling_mode, q_layout, flatten_axis, mesh, arg_infos group_spec = (x_spec[0],) _warn_if_axes_ignored("group_sizes", original_group_spec, group_spec) flat_spec = _flat_data_spec(x_spec) + data_spec = _grouped_data_spec(x_spec, flatten_axis) replicated_spec = (None,) - rowwise_out_spec = flat_spec if q_layout.has_rowwise else replicated_spec - colwise_out_spec = flat_spec if q_layout.has_colwise else replicated_spec + rowwise_out_spec = data_spec if q_layout.has_rowwise else replicated_spec + colwise_out_spec = data_spec if q_layout.has_colwise else replicated_spec rowwise_scale_inv_spec = replicated_spec colwise_scale_inv_spec = replicated_spec @@ -1333,6 +1346,33 @@ def _parse_partition_specs(scaling_mode, q_layout, flatten_axis, mesh, arg_infos ), ) + @staticmethod + def infer_sharding_from_operands( + out_dtype, + scaling_mode, + q_layout, + flatten_axis, + scale_dtype, + mesh, + arg_infos, + result_infos, + ): + del out_dtype, scale_dtype + _, _, out_specs = GroupedQuantizePrimitive._parse_partition_specs( + scaling_mode, q_layout, flatten_axis, mesh, arg_infos + ) + return _result_tree_like( + result_infos, + ( + NamedSharding( + mesh, + PartitionSpec(*spec), + desc=f"GroupedQuantizePrimitive.out_sharding_{idx}", + ) + for idx, spec in enumerate(out_specs) + ) + ) + @staticmethod def partition( out_dtype, @@ -1406,16 +1446,22 @@ def shardy_sharding_rule( value_types, result_types, ): - del out_dtype, scale_dtype, mesh, result_types, flatten_axis + del out_dtype, scale_dtype, mesh, result_types prefix = "GroupedQuantize" input_spec = tuple(f"{prefix}_x_{i}" for i in range(len(value_types[0].shape))) + data_spec = tuple( + input_spec[i] + if i < _normalize_flatten_axis(flatten_axis, len(input_spec)) + else f"{prefix}_data_{i}" + for i in range(len(input_spec)) + ) flat_spec = (f"{prefix}_flat",) group_spec = (BATCHING + f"{prefix}_group",) scalar_spec = (BATCHING + f"{prefix}_scalar",) - rowwise_out_spec = flat_spec if q_layout.has_rowwise else scalar_spec - colwise_out_spec = flat_spec if q_layout.has_colwise else scalar_spec + rowwise_out_spec = data_spec if q_layout.has_rowwise else scalar_spec + colwise_out_spec = data_spec if q_layout.has_colwise else scalar_spec if ScalingMode(scaling_mode).is_block_scaling: rowwise_scale_spec = flat_spec if q_layout.has_rowwise else scalar_spec diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index f5fa7722f4..60d7c1e2d6 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -28,6 +28,14 @@ class ScoreFunction(IntEnum): SOFTMAX = int(JAXX_Score_Function.SOFTMAX) +def _result_tree_like(result_infos, shardings): + if isinstance(result_infos, list): + return list(shardings) + if isinstance(result_infos, tuple): + return tuple(shardings) + return tuple(shardings) + + # =========================================== ================================== # Fused Top-K with Score Function - Forward # ============================================================================= @@ -203,6 +211,32 @@ def sharded_impl(logits, expert_bias): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def infer_sharding_from_operands( + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, + ): + del ( + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ) + logits_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + return _result_tree_like(result_infos, (out_sharding, out_sharding, out_sharding)) + @staticmethod def shardy_sharding_rule(*args): del args @@ -368,6 +402,22 @@ def sharded_impl(routing_map, intermediate, grad_probs): return mesh, sharded_impl, out_sharding, arg_shardings + @staticmethod + def infer_sharding_from_operands( + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, + ): + del topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores + del result_infos + grad_spec = get_padded_spec(arg_infos[2]) + return NamedSharding(mesh, PartitionSpec(*grad_spec)) + @staticmethod def shardy_sharding_rule(*args): del args @@ -463,6 +513,14 @@ def sharded_impl(probs, tokens_per_expert): return mesh, sharded_impl, out_shardings, arg_shardings + @staticmethod + def infer_sharding_from_operands(topk, coeff, mesh, arg_infos, result_infos): + del topk, coeff, arg_infos + return _result_tree_like( + result_infos, + (NamedSharding(mesh, PartitionSpec()), NamedSharding(mesh, PartitionSpec(None))), + ) + @staticmethod def shardy_sharding_rule(*args): del args @@ -564,6 +622,11 @@ def sharded_impl(const_buf, tokens_per_expert, grad_aux_loss): return mesh, sharded_impl, out_sharding, arg_shardings + @staticmethod + def infer_sharding_from_operands(num_tokens, mesh, arg_infos, result_infos): + del num_tokens, arg_infos, result_infos + return NamedSharding(mesh, PartitionSpec(None, None)) + @staticmethod def shardy_sharding_rule(*args): del args diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c5ad0451fd..837b179e04 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -428,7 +428,6 @@ def group_sizes(self) -> jnp.ndarray: def __post_init__(self): assert self.scale_inv.ndim == 1, "Only support flattened scale_inv" - assert self.data.ndim == 1, "Only support flattened data" assert self.flatten_axis > 0 data_ndim = len(self.original_shape)