Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions tests/pytorch/nvfp4/bench_graph_safe_swizzle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer

M, N = 8192, 7168  # your actual shape


def main() -> None:
    """Benchmark ``tex.group_quantize`` with ``optimize_for_gemm`` off vs. on.

    Quantizes an (M, N) bf16 tensor split into equal 128-row groups, 100
    timed launches after 10 warmup launches, and prints the mean per-call
    latency in microseconds for each setting.
    """
    x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
    split_sections = torch.tensor([128] * (M // 128), dtype=torch.int64, device="cuda")

    for optimize_for_gemm in [False, True]:
        q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
        q.optimize_for_gemm = optimize_for_gemm

        # warmup
        for _ in range(10):
            tex.group_quantize(x, q, split_sections.shape[0], split_sections)
        torch.cuda.synchronize()

        # time
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(100):
            tex.group_quantize(x, q, split_sections.shape[0], split_sections)
        end.record()
        torch.cuda.synchronize()
        # Event.elapsed_time() is milliseconds; report per-iteration microseconds.
        print(f"optimize_for_gemm={optimize_for_gemm}: {start.elapsed_time(end) / 100 * 1000:.1f} μs")


# Guard: pytest imports files under tests/ during collection; without this
# guard the CUDA benchmark would run (or crash) at import time.
if __name__ == "__main__":
    main()
91 changes: 91 additions & 0 deletions tests/pytorch/nvfp4/bench_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
import torch
import torch.cuda.nvtx as nvtx

N = 7168  # feature dimension (columns) shared by all benchmark shapes
num_experts = 64  # number of row-groups per quantize call
ITERS = 50  # timed iterations per benchmark label

# Row counts (M) to sweep; shapes that cannot host the splits are skipped below.
M_VALUES = [8192, 16384, 32768, 65536, 131072]


def make_unequal_splits(M, num_experts):
    """Return ``num_experts`` split sizes that sum exactly to ``M``.

    Even-indexed groups get 128 fewer rows than the integer-division base
    and odd-indexed groups 128 more; any remainder from the division is
    folded into the last entry so the sizes always sum to ``M``.
    """
    base = M // num_experts
    sizes = [base + (128 if idx % 2 else -128) for idx in range(num_experts)]
    # fix rounding so sum == M
    sizes[-1] += M - sum(sizes)
    return sizes


def bench(fn, label, iters=ITERS):
    """Time ``fn`` on the GPU and print/return the mean per-call latency (us).

    Runs 10 untimed warmup calls, then ``iters`` timed calls bracketed by
    CUDA events and an NVTX range (for profiler correlation).
    """
    # warmup so allocator / first-launch overheads don't pollute the timing
    for _ in range(10):
        fn()
    torch.cuda.synchronize()

    begin_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    nvtx.range_push(label)
    begin_evt.record()
    for _ in range(iters):
        fn()
    end_evt.record()
    nvtx.range_pop()
    torch.cuda.synchronize()

    # Event.elapsed_time() is milliseconds -> per-iteration microseconds
    us = begin_evt.elapsed_time(end_evt) / iters * 1000
    print(f" {label}: {us:.1f} us")
    return us


def main() -> None:
    """Sweep M and compare group-quantize group-lookup strategies.

    For each M: graph-safe equal splits (O(1) division), graph-safe unequal
    splits (binary search), and the non-graph-safe split_quantize path
    (linear scan). Prints mean per-call latency for each via ``bench``.
    """
    print(f"N={N}, num_experts={num_experts}")
    print("-" * 60)

    for M in M_VALUES:
        # skip shapes that cannot host the +/-128-row unequal splits
        if M % num_experts != 0 or (M // num_experts) <= 128:
            print(f"M={M}: skipped")
            continue

        x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
        label_prefix = f"M{M}"

        print(f"\nM={M}:")

        # --- graph-safe, equal splits (O(1) division) ---
        equal_splits = [M // num_experts] * num_experts
        equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda")
        q_eq = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
        q_eq.optimize_for_gemm = False
        bench(
            lambda: tex.group_quantize(x, q_eq, num_experts, equal_tensor),
            f"{label_prefix}_graph_safe_equal_O1",
        )

        # --- graph-safe, unequal splits (binary search) ---
        unequal_splits = make_unequal_splits(M, num_experts)
        unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda")
        q_uneq = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
        q_uneq.optimize_for_gemm = False
        bench(
            lambda: tex.group_quantize(x, q_uneq, num_experts, unequal_tensor),
            f"{label_prefix}_graph_safe_unequal_bsearch",
        )

        # --- non-graph-safe (linear scan) ---
        q_list = [
            NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
            for _ in range(num_experts)
        ]
        bench(
            lambda: tex.split_quantize(x, equal_splits, q_list),
            f"{label_prefix}_non_graph_safe_linear",
        )


# Guard: pytest imports files under tests/ during collection; keep GPU work
# out of module scope so import is side-effect free.
if __name__ == "__main__":
    main()
67 changes: 67 additions & 0 deletions tests/pytorch/nvfp4/bench_structural.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
import torch
import torch.cuda.nvtx as nvtx

N = 7168  # feature dimension (columns) shared by all benchmark shapes
num_experts = 64  # number of row-groups per quantize call


def make_quantizer():
    """Build an NVFP4Quantizer (row+col-wise, RHT, post-RHT amax) with
    ``optimize_for_gemm`` enabled."""
    quantizer = NVFP4Quantizer(
        rowwise=True,
        columnwise=True,
        with_rht=True,
        with_post_rht_amax=True,
    )
    quantizer.optimize_for_gemm = True
    return quantizer


def bench(fn, label, iters=100):
    """Warm up, then time ``fn`` over ``iters`` GPU launches; print mean in us."""
    for _ in range(10):  # warmup launches, untimed
        fn()
    torch.cuda.synchronize()

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)
    nvtx.range_push(label)
    tic.record()
    for _ in range(iters):
        fn()
    toc.record()
    nvtx.range_pop()
    torch.cuda.synchronize()
    # Event.elapsed_time() is ms -> report per-iteration microseconds
    print(f"{label}: {tic.elapsed_time(toc) / iters * 1000:.1f} us")


def main() -> None:
    """Compare the three structural group-quantize paths across sizes:
    equal splits (O(1) division), unequal splits (binary search), and the
    non-graph-safe linear-scan path."""
    for M in [16384, 65536, 131072]:
        x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")

        # 1. graph-safe + equal splits -> O(1) division (SAME_BOTH_DIMS)
        equal_splits = [M // num_experts] * num_experts
        equal_tensor = torch.tensor(equal_splits, dtype=torch.int64, device="cuda")
        q1 = make_quantizer()
        bench(
            lambda: tex.group_quantize(x, q1, num_experts, equal_tensor), f"[M={M}] graph_safe_equal_O1"
        )

        # 2. graph-safe + unequal splits -> binary search (VARYING_FIRST_DIM)
        base = M // num_experts
        unequal_splits = [base - 128 if i % 2 == 0 else base + 128 for i in range(num_experts)]
        unequal_tensor = torch.tensor(unequal_splits, dtype=torch.int64, device="cuda")
        q2 = make_quantizer()
        bench(
            lambda: tex.group_quantize(x, q2, num_experts, unequal_tensor),
            f"[M={M}] graph_safe_unequal_binary_search",
        )

        # 3. non-graph-safe + linear scan (GetGroupIdx)
        q_list = [
            NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
            for _ in range(num_experts)
        ]
        bench(
            lambda: tex.split_quantize(x, equal_splits, q_list), f"[M={M}] non_graph_safe_linear_scan"
        )

        print()


# Guard: pytest imports files under tests/ during collection; keep GPU work
# out of module scope so import is side-effect free.
if __name__ == "__main__":
    main()
Comment on lines +5 to +67
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Module-level GPU code will execute on pytest import

All five new scripts (bench_structural.py, bench_sweep_swizzle.py, bench_search.py, bench_graph_safe_swizzle.py, ncu_test.py) contain GPU kernel launches at module scope. When pytest discovers files in tests/pytorch/nvfp4/, it imports each one to collect tests; the imports execute the benchmarks immediately — potentially hanging or crashing CI on machines without the required GPU or package.

Wrap the benchmark body in a if __name__ == "__main__": guard on all five files, e.g.:

if __name__ == "__main__":
    for M in [16384, 65536, 131072]:
        ...

93 changes: 93 additions & 0 deletions tests/pytorch/nvfp4/bench_sweep_swizzle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
import torch
import torch.cuda.nvtx as nvtx

N = 7168  # feature dimension (columns) shared by all benchmark shapes
num_experts = 64  # number of row-groups per quantize call
ITERS = 50  # timed iterations per benchmark label

# Row counts (M) to sweep; entries not divisible by num_experts are skipped.
M_VALUES = [8192, 16384, 32768, 65536, 131072]


def bench(fn, label, iters=ITERS):
    """Time ``fn`` on the GPU and print/return the mean per-call latency (us).

    Performs 10 untimed warmup calls, then ``iters`` timed calls bracketed
    by CUDA events and an NVTX range for profiler correlation.
    """
    # warmup
    for _ in range(10):
        fn()
    torch.cuda.synchronize()

    t0 = torch.cuda.Event(enable_timing=True)
    t1 = torch.cuda.Event(enable_timing=True)
    nvtx.range_push(label)
    t0.record()
    for _ in range(iters):
        fn()
    t1.record()
    nvtx.range_pop()
    torch.cuda.synchronize()

    # Event.elapsed_time() is milliseconds -> per-iteration microseconds
    us = t0.elapsed_time(t1) / iters * 1000
    print(f" {label}: {us:.1f} us")
    return us


def main() -> None:
    """Sweep M and compare graph-safe group_quantize with swizzle
    (``optimize_for_gemm``) ON vs. OFF, against the non-graph-safe
    split_quantize path."""
    print(f"N={N}, num_experts={num_experts}")
    print("-" * 60)

    for M in M_VALUES:
        if M % num_experts != 0:
            print(f"M={M}: skipped (not divisible by num_experts={num_experts})")
            continue

        rows_per_expert = M // num_experts
        split_sections = [rows_per_expert] * num_experts
        split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda")
        x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")

        print(f"\nM={M} ({rows_per_expert} rows/expert):")

        label_prefix = f"M{M}"

        # --- graph-safe, swizzle ON ---
        q_on = NVFP4Quantizer(
            rowwise=True,
            columnwise=True,
            with_rht=True,
            with_post_rht_amax=True,
        )
        q_on.optimize_for_gemm = True
        bench(
            lambda: tex.group_quantize(x, q_on, num_experts, split_section_tensor),
            f"{label_prefix}_graph_safe_swizzle_ON",
        )

        # --- graph-safe, swizzle OFF ---
        q_off = NVFP4Quantizer(
            rowwise=True,
            columnwise=True,
            with_rht=True,
            with_post_rht_amax=True,
        )
        q_off.optimize_for_gemm = False
        bench(
            lambda: tex.group_quantize(x, q_off, num_experts, split_section_tensor),
            f"{label_prefix}_graph_safe_swizzle_OFF",
        )

        # --- non-graph-safe ---
        q_list = [
            NVFP4Quantizer(
                rowwise=True,
                columnwise=True,
                with_rht=True,
                with_post_rht_amax=True,
            )
            for _ in range(num_experts)
        ]
        bench(
            lambda: tex.split_quantize(x, split_sections, q_list),
            f"{label_prefix}_non_graph_safe",
        )


# Guard: pytest imports files under tests/ during collection; keep GPU work
# out of module scope so import is side-effect free.
if __name__ == "__main__":
    main()
23 changes: 23 additions & 0 deletions tests/pytorch/nvfp4/ncu_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
import torch

M, N, num_experts = 16384, 7168, 64


def main() -> None:
    """Warm up, then issue one group_quantize launch for Nsight Compute.

    The final launch after the sync is the one to profile with ncu.
    """
    x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
    splits = [M // num_experts] * num_experts
    split_tensor = torch.tensor(splits, dtype=torch.int64, device="cuda")

    # warmup
    q = NVFP4Quantizer(rowwise=True, columnwise=True, with_rht=True, with_post_rht_amax=True)
    for _ in range(3):
        tex.group_quantize(x, q, num_experts, split_tensor)
    torch.cuda.synchronize()

    # single measured launch
    tex.group_quantize(x, q, num_experts, split_tensor)
    torch.cuda.synchronize()


# Guard: pytest imports files under tests/ during collection; keep GPU work
# out of module scope so import is side-effect free.
if __name__ == "__main__":
    main()
Loading
Loading