Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion tests/pytorch/test_fused_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,15 @@ def aux_loss_pytorch(
return aux_loss


def topk_indices_to_routing_map(topk_indices: torch.Tensor, num_experts: int) -> torch.Tensor:
"""Convert dense [num_tokens, topk] top-k indices to a bool routing map."""
routing_map = torch.zeros(
topk_indices.size(0), num_experts, dtype=torch.bool, device=topk_indices.device
)
routing_map.scatter_(1, topk_indices.long(), True)
return routing_map


def run_comparison(
dtype,
num_tokens,
Expand All @@ -177,6 +186,8 @@ def run_comparison(
scaling_factor,
score_function,
enable_bias,
topk_output_mode="sparse",
topk_index_dtype=torch.int16,
):
if topk >= num_experts:
pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})")
Expand Down Expand Up @@ -235,8 +246,12 @@ def run_comparison(
expert_bias=expert_bias,
)

topk_indices = None
if topk_output_mode == "dense":
topk_indices = torch.empty((num_tokens, topk), device="cuda", dtype=topk_index_dtype)

# Run the fused implementation
probs_fused, routing_map_fused = fused_topk_with_score_function(
probs_fused, routing_output_fused = fused_topk_with_score_function(
logits=logits_clone,
topk=topk,
use_pre_softmax=use_pre_softmax,
Expand All @@ -245,7 +260,14 @@ def run_comparison(
scaling_factor=scaling_factor,
score_function=score_function,
expert_bias=expert_bias_clone,
topk_indices=topk_indices,
)
if topk_output_mode == "dense":
assert routing_output_fused.data_ptr() == topk_indices.data_ptr()
assert routing_output_fused.dtype == topk_index_dtype
routing_map_fused = topk_indices_to_routing_map(routing_output_fused, num_experts)
else:
routing_map_fused = routing_output_fused

atol, rtol = _get_tolerances(dtype, num_experts)
torch.testing.assert_close(probs, probs_fused, atol=atol, rtol=rtol)
Expand All @@ -270,6 +292,7 @@ def run_comparison(
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
@pytest.mark.parametrize("enable_bias", [True, False])
@pytest.mark.parametrize("topk_index_dtype", [None, torch.int16, torch.int32, torch.int64])
def test_topk_sigmoid(
dtype,
num_tokens,
Expand All @@ -278,6 +301,7 @@ def test_topk_sigmoid(
group_topk,
scaling_factor,
enable_bias,
topk_index_dtype,
):
num_groups = 8 if group_topk else None
run_comparison(
Expand All @@ -291,6 +315,8 @@ def test_topk_sigmoid(
scaling_factor=scaling_factor,
score_function="sigmoid",
enable_bias=enable_bias,
topk_output_mode="dense" if topk_index_dtype is not None else "sparse",
topk_index_dtype=topk_index_dtype or torch.int16,
)


Expand All @@ -301,6 +327,7 @@ def test_topk_sigmoid(
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
@pytest.mark.parametrize("enable_bias", [True, False])
@pytest.mark.parametrize("topk_index_dtype", [None, torch.int16, torch.int32, torch.int64])
def test_topk_sqrtsoftplus(
dtype,
num_tokens,
Expand All @@ -309,6 +336,7 @@ def test_topk_sqrtsoftplus(
group_topk,
scaling_factor,
enable_bias,
topk_index_dtype,
):
num_groups = 8 if group_topk else None
run_comparison(
Expand All @@ -322,6 +350,8 @@ def test_topk_sqrtsoftplus(
scaling_factor=scaling_factor,
score_function="sqrtsoftplus",
enable_bias=enable_bias,
topk_output_mode="dense" if topk_index_dtype is not None else "sparse",
topk_index_dtype=topk_index_dtype or torch.int16,
)


Expand All @@ -332,6 +362,7 @@ def test_topk_sqrtsoftplus(
@pytest.mark.parametrize("use_pre_softmax", [True, False])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
@pytest.mark.parametrize("topk_index_dtype", [None, torch.int16, torch.int32, torch.int64])
def test_topk_softmax(
dtype,
num_tokens,
Expand All @@ -340,6 +371,7 @@ def test_topk_softmax(
use_pre_softmax,
group_topk,
scaling_factor,
topk_index_dtype,
):
num_groups = 8 if group_topk else None
run_comparison(
Expand All @@ -353,8 +385,37 @@ def test_topk_softmax(
scaling_factor=scaling_factor,
score_function="softmax",
enable_bias=False,
topk_output_mode="dense" if topk_index_dtype is not None else "sparse",
topk_index_dtype=topk_index_dtype or torch.int16,
)


@pytest.mark.parametrize("topk_index_dtype", [None, torch.int16])
def test_topk_preserves_leading_dims(topk_index_dtype):
num_tokens = 128
num_experts = 32
topk = 4
logits = torch.randn(num_tokens, 2, num_experts, device="cuda", dtype=torch.float32)
topk_indices = None
if topk_index_dtype is not None:
topk_indices = torch.empty(num_tokens, 2, topk, device="cuda", dtype=topk_index_dtype)

probs, routing_output = fused_topk_with_score_function(
logits=logits,
topk=topk,
use_pre_softmax=False,
num_groups=None,
group_topk=None,
scaling_factor=None,
score_function="softmax",
expert_bias=None,
topk_indices=topk_indices,
)

assert probs.shape == logits.shape
expected_routing_shape = topk_indices.shape if topk_indices is not None else logits.shape
assert routing_output.shape == expected_routing_shape


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens,
Tensor &routing_map,
NVTERoutingMapFormat routing_map_format,
Tensor &intermediate_output, cudaStream_t stream) {
check_routing_map_format(routing_map_format);
NVTE_CHECK(num_tokens > 0 && num_experts > 0,
"num_tokens and num_experts must be positive; got num_tokens=", num_tokens,
", num_experts=", num_experts);
Expand Down
Loading
Loading