diff --git a/backends/cuda/benchmarks/benchmark_sdpa.py b/backends/cuda/benchmarks/benchmark_sdpa.py index 3c117f4574f..0b95f736102 100644 --- a/backends/cuda/benchmarks/benchmark_sdpa.py +++ b/backends/cuda/benchmarks/benchmark_sdpa.py @@ -6,16 +6,27 @@ # LICENSE file in the root directory of this source tree. """ -Benchmark the Triton SDPA kernel against PyTorch SDPA backends. - -Measures latency across decode shapes matching the Qwen3.5 MoE model -(B=1, H_q=16, H_kv=2, D=256). The ET Triton kernel uses native GQA -(2 KV heads), while Flash/Efficient/Math require pre-expanded KV -(16 heads) since they lack native GQA support. - +Benchmark the Triton SDPA kernels against PyTorch SDPA backends at decode. + +Cross-backend latency comparison ("is our kernel competitive vs PyTorch / +Flash?") across a few representative decode configs and the L_kv range, in BOTH +CUDA-graph and plain timing modes. The ET Triton kernels use native GQA; the +Flash/Efficient/Math backends require pre-expanded KV (no native GQA), matching +the test reference. PyTorch (default) is the correctness reference. + +Timing: CUDA-graph mode (capture+replay) is faithful to the deployed +``--cuda_graph`` runtime; plain ``do_bench`` charges each kernel its full +per-call launch/alloc overhead. Run both to see the effect (it is large for ET +split-K, which allocates partial buffers per call). 
+ +Usage: + python benchmark_sdpa.py # both timing modes + python benchmark_sdpa.py --mode cudagraph + python benchmark_sdpa.py --mode plain """ import argparse +import statistics import warnings from functools import partial @@ -23,17 +34,67 @@ import torch.nn.functional as F from executorch.backends.cuda.triton.kernels.sdpa import ( - sdpa as triton_sdpa, - sdpa_decode_splitk as triton_splitk, + sdpa as _triton_sdpa, + sdpa_decode_splitk as _triton_splitk, ) from torch.nn.attention import sdpa_kernel, SDPBackend -from triton.testing import do_bench +from triton.testing import do_bench, do_bench_cudagraph + + +# -- Timing primitive + ET kernel runners (self-contained) ------------------- +# do_bench budgets are millisecond windows (NOT iteration counts). +_WARMUP_MS = 10 +_REP_MS = 50 +# Warmup calls before graph capture so the Triton autotuner has cached a config +# (autotuning cannot run inside graph capture). +_GRAPH_WARMUP_CALLS = 20 + + +def run_standard(q, k, v, attn_mask, enable_gqa): + return _triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) + + +def run_splitk(q, k, v, attn_mask, enable_gqa): + return _triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) + + +def time_us(fn, cudagraph: bool = True) -> float: + """Median latency (us). cudagraph=True is faithful to the --cuda_graph path. + + Under CUDA-graph the op is captured once (its split-K partial/LSE workspace + is allocated once into the graph's private pool and reused across replays) + and only replay() is timed, so the per-call buffer alloc + launch overhead + is excluded -- exactly as the deployed runtime eliminates it. We warm up + first so the Triton autotuner has cached a config before capture. 
+ """ + if cudagraph: + for _ in range(_GRAPH_WARMUP_CALLS): + fn() + torch.cuda.synchronize() + ms = do_bench_cudagraph(fn, rep=_REP_MS, return_mode="median") + else: + ms = do_bench(fn, warmup=_WARMUP_MS, rep=_REP_MS, return_mode="median") + return ms * 1000.0 + + +# Each reported number repeats the timing primitive N_RUNS times, discards the +# first N_WARMUP as warmup, and reports mean +/- std over the remaining runs. +N_RUNS = 10 +N_WARMUP = 4 + + +def measure_us(fn, cudagraph: bool): + """Repeat time_us N_RUNS times; return (mean, std) over runs[N_WARMUP:].""" + samples = [time_us(fn, cudagraph=cudagraph) for _ in range(N_RUNS)] + kept = samples[N_WARMUP:] + mean = statistics.fmean(kept) + std = statistics.stdev(kept) if len(kept) > 1 else 0.0 + return mean, std # PyTorch's Flash/Efficient backends don't support GQA (H_q != H_kv) directly. -# We expand KV heads via repeat_interleave so they can run, matching what -# the test reference does. This is fair: it measures the kernel itself, not -# the GQA dispatch overhead. +# We expand KV heads via repeat_interleave so they can run, matching what the +# test reference does. This measures the kernel itself, not GQA dispatch. def _expand_kv(k, v, num_groups): @@ -49,21 +110,9 @@ def _expand_mask(mask, H_q): return mask -def _run_triton(q, k, v, attn_mask, enable_gqa): - return triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) - - -def _run_splitk(q, k, v, attn_mask, enable_gqa): - return triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) - - def _run_pytorch_default(q, k, v, attn_mask, enable_gqa): return F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=attn_mask, - enable_gqa=enable_gqa, + q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa ) @@ -75,50 +124,40 @@ def run(q, k, v, attn_mask, enable_gqa): return run -# Flash doesn't support attn_mask at all, only is_causal. -# Our benchmark mask is all-ones, so no mask is equivalent. 
+# Flash doesn't support attn_mask at all, only is_causal. Our benchmark mask is +# all-ones, so no mask is equivalent. def _run_flash(q, k, v, attn_mask, enable_gqa): with sdpa_kernel(SDPBackend.FLASH_ATTENTION): return F.scaled_dot_product_attention(q, k, v) +# ET Triton kernels reuse the shared helper runners (the real lowered kernels). BACKENDS = { - "triton": ("ET Triton (GQA)", _run_triton), - "splitk": ("ET Split-K (GQA)", _run_splitk), + "triton": ("ET Triton (GQA)", run_standard), + "splitk": ("ET Split-K (GQA)", run_splitk), "pytorch": ("PyTorch", _run_pytorch_default), - "flash": ("Flash (expanded KV)", _run_flash), + "flash": ("Flash (exp KV)", _run_flash), "efficient": ( - "Efficient (expanded KV)", + "Efficient (exp KV)", _make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION), ), - "math": ("Math (expanded KV)", _make_pytorch_runner(SDPBackend.MATH)), + "math": ("Math (exp KV)", _make_pytorch_runner(SDPBackend.MATH)), } -# Backends that need KV heads expanded before calling (no native GQA support) +# Backends that need KV heads expanded before calling (no native GQA support). _NEEDS_KV_EXPAND = {"flash", "efficient", "math"} -# -- Shapes ------------------------------------------------------------------ - -# Qwen3.5 MoE: B=1, H_q=16, H_kv=2, D=256 -QWEN35_BASE = {"B": 1, "H_q": 16, "H_kv": 2, "D": 256} - -DECODE_SHAPES = [ - dict(**QWEN35_BASE, Lq=1, Lk=64), - dict(**QWEN35_BASE, Lq=1, Lk=128), - dict(**QWEN35_BASE, Lq=1, Lk=256), - dict(**QWEN35_BASE, Lq=1, Lk=512), - dict(**QWEN35_BASE, Lq=1, Lk=1024), - dict(**QWEN35_BASE, Lq=1, Lk=2048), - dict(**QWEN35_BASE, Lq=1, Lk=4096), - dict(**QWEN35_BASE, Lq=1, Lk=8192), - dict(**QWEN35_BASE, Lq=1, Lk=16384), +# Representative decode configs (label, B, H_q, H_kv, D). CTA = B * H_kv. 
+CONFIGS = [ + ("gemma sliding (D=256, CTA=16)", 1, 32, 16, 256), + ("qwen (D=256, CTA=2)", 1, 16, 2, 256), + ("head_dim=128 (D=128, CTA=16)", 1, 32, 16, 128), ] -SCENARIOS = { - "decode": DECODE_SHAPES, -} +L_KV_RANGE = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384] -# -- Helpers ----------------------------------------------------------------- +# Cross-backend validation tolerance (bf16 vs bf16). +MAX_ABS_TOL = 1e-2 def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16): @@ -128,7 +167,6 @@ def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16): mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device=device) enable_gqa = H_q != H_kv num_groups = H_q // H_kv - # Pre-expanded versions for backends without native GQA k_exp, v_exp = _expand_kv(k, v, num_groups) mask_exp = _expand_mask(mask, H_q) return q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa @@ -138,170 +176,132 @@ def _max_abs_error(out, ref): return (out.float() - ref.float()).abs().max().item() -# Cross-backend validation tolerance (bf16 vs bf16). 
-MAX_ABS_TOL = 1e-2 - - -def _bench_us(fn, num_warmup, num_iters): - """Return median latency in microseconds using triton.testing.do_bench.""" - ms = do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median") - return ms * 1000.0 - - def _try_run(run_fn, q, k, v, mask, enable_gqa): - """Run a backend, returning output or None on failure.""" try: return run_fn(q, k, v, mask, enable_gqa) - except RuntimeError: + except Exception: return None -def _try_bench(run_fn, q, k, v, mask, enable_gqa, num_warmup, num_iters): - """Benchmark a backend, returning median us or None on failure.""" +def _try_bench(run_fn, q, k, v, mask, enable_gqa, cudagraph): + """Benchmark one backend, returning (mean_us, std_us) or None on failure.""" fn = partial(run_fn, q, k, v, mask, enable_gqa) try: run_fn(q, k, v, mask, enable_gqa) - return _bench_us(fn, num_warmup, num_iters) - except RuntimeError: + return measure_us(fn, cudagraph=cudagraph) + except Exception: return None -# -- Main -------------------------------------------------------------------- - - -def _shape_label(shape): - return ( - f"B={shape['B']} Hq={shape['H_q']} Hkv={shape['H_kv']} " - f"D={shape['D']} Lq={shape['Lq']} Lk={shape['Lk']}" - ) - - -def _short_label(shape, scenario="decode"): - return f"Lq={shape['Lq']},Lk={shape['Lk']}" +def _bench_inputs(name, q, k, v, k_exp, v_exp, mask, mask_exp): + """Return the (k, v, mask) a backend should use (expanded or native).""" + if name in _NEEDS_KV_EXPAND: + return k_exp, v_exp, mask_exp + return k, v, mask @torch.inference_mode() -def run_benchmark( - scenario: str = "decode", - num_warmup: int = 25, - num_iters: int = 100, -): - shapes = SCENARIOS[scenario] +def run_benchmark(cudagraph: bool): + """Print a cross-backend decode latency table for each config.""" backends = [(name, *BACKENDS[name]) for name in BACKENDS] + mode = "CUDA-graph (capture+replay)" if cudagraph else "plain do_bench" + device = torch.cuda.get_device_name() + n_sm = 
torch.cuda.get_device_properties(0).multi_processor_count - device_name = torch.cuda.get_device_name() print() - print("=" * 100) - print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}") - print(f" Device: {device_name}") - print(f" Warmup: {num_warmup}, Iters: {num_iters}") - print(f" Backends: {', '.join(label for _, label, _ in backends)}") - print("=" * 100) - - # Build column specs: (header_text, unit_text, min_width) - # Each column gets width = max(len(header), len(unit), min_width) - max_label = max(len(_short_label(s, scenario)) for s in shapes) - col_specs = [("Shape", "", max(8, max_label))] - for _, label, _ in backends: - col_specs.append((label, "(us)", 8)) - - col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] - - header = " | ".join( - f"{h:<{w}}" if i == 0 else f"{h:>{w}}" - for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths)) + print("=" * 124) + print(f"SDPA decode cross-backend benchmark | timing: {mode}") + print(f" device: {device} (n_SM={n_sm}) L_q=1, bf16, all-ones mask") + print(f" backends: {', '.join(label for _, label, _ in backends)}") + print( + f" each cell = mean+/-std us over last {N_RUNS - N_WARMUP} of {N_RUNS} " + f"runs ({N_WARMUP} warmup)" ) - units = " | ".join( - f"{'':>{w}}" if i == 0 else f"{u:>{w}}" - for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths)) - ) - print(header) - print(units) - print("-" * len(header)) - - for shape in shapes: - q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors(**shape) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - # Validate outputs across backends before benchmarking - outputs = {} - for name, _label, run_fn in backends: - if name in _NEEDS_KV_EXPAND: - bk, bv, bmask = k_exp, v_exp, mask_exp - else: - bk, bv, bmask = k, v, mask - outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa) - - # Use PyTorch F.sdpa as the trusted reference — never validate - # against our own Triton kernels. 
- ref_name, ref_out = None, None - if outputs.get("pytorch") is not None: - ref_name, ref_out = "pytorch", outputs["pytorch"] - - if ref_out is not None: - for name, label, _ in backends: - if name == ref_name or outputs[name] is None: - continue - err = _max_abs_error(outputs[name], ref_out) - assert err < MAX_ABS_TOL, ( - f"Output mismatch for {_shape_label(shape)}: " - f"{label} vs {BACKENDS[ref_name][0]}, " - f"max abs error {err:.3e} >= 1e-2" + print("=" * 124) + + for label, B, H_q, H_kv, D in CONFIGS: + print(f"\n{label} [B={B} H_q={H_q} H_kv={H_kv} D={D}]") + col_specs = [("L_kv", "", 6)] + [(lbl, "(us)", 13) for _, lbl, _ in backends] + widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] + header = " | ".join( + f"{h:<{w}}" if i == 0 else f"{h:>{w}}" + for i, ((h, _, _), w) in enumerate(zip(col_specs, widths)) + ) + units = " | ".join( + f"{'':>{w}}" if i == 0 else f"{u:>{w}}" + for i, ((_, u, _), w) in enumerate(zip(col_specs, widths)) + ) + print(" " + header) + print(" " + units) + print(" " + "-" * len(header)) + + for Lk in L_KV_RANGE: + q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors( + B, H_q, H_kv, 1, Lk, D + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + # Correctness: validate every backend against PyTorch (default). 
+ outputs = {} + for name, _lbl, run_fn in backends: + bk, bv, bmask = _bench_inputs( + name, q, k, v, k_exp, v_exp, mask, mask_exp + ) + outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa) + ref = outputs.get("pytorch") + if ref is not None: + for name, lbl, _ in backends: + if name == "pytorch" or outputs[name] is None: + continue + err = _max_abs_error(outputs[name], ref) + assert err < MAX_ABS_TOL, ( + f"Output mismatch {label} L_kv={Lk}: {lbl} vs PyTorch, " + f"max abs error {err:.3e} >= {MAX_ABS_TOL}" + ) + del outputs + + times = {} + for name, _lbl, run_fn in backends: + bk, bv, bmask = _bench_inputs( + name, q, k, v, k_exp, v_exp, mask, mask_exp + ) + times[name] = _try_bench( + run_fn, q, bk, bv, bmask, enable_gqa, cudagraph ) - del outputs - # Benchmark all backends - times = {} - for name, _label, run_fn in backends: - if name in _NEEDS_KV_EXPAND: - bk, bv, bmask = k_exp, v_exp, mask_exp + row = [f"{Lk:<{widths[0]}}"] + for ci, (name, _, _) in enumerate(backends, start=1): + t = times[name] + if t is not None: + cell = f"{t[0]:.1f}\u00b1{t[1]:.1f}" else: - bk, bv, bmask = k, v, mask - times[name] = _try_bench( - run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters - ) - - # Format row using col_widths - ci = 0 - row_parts = [f"{_short_label(shape, scenario):<{col_widths[ci]}}"] - ci += 1 - for name, _, _ in backends: - t = times[name] - w = col_widths[ci] - row_parts.append(f"{t:>{w}.1f}" if t is not None else f"{'N/A':>{w}}") - ci += 1 - print(" | ".join(row_parts)) - - del q, k, v, k_exp, v_exp, mask, mask_exp - torch.cuda.empty_cache() - - print("-" * len(header)) + cell = "N/A" + row.append(f"{cell:>{widths[ci]}}") + print(" " + " | ".join(row)) + + del q, k, v, k_exp, v_exp, mask, mask_exp + torch.cuda.empty_cache() print() def main(): parser = argparse.ArgumentParser( - description="Benchmark Triton SDPA vs PyTorch backends" + description="Benchmark Triton SDPA vs PyTorch backends (decode)" ) parser.add_argument( - "--scenario", 
- choices=list(SCENARIOS.keys()) + ["all"], - default="all", - help="Which shape set to benchmark (default: all)", + "--mode", + choices=["cudagraph", "plain", "both"], + default="both", + help="Timing mode(s) to run (default: both).", ) - parser.add_argument("--num_warmup", type=int, default=25) - parser.add_argument("--num_iters", type=int, default=100) args = parser.parse_args() - scenarios = list(SCENARIOS.keys()) if args.scenario == "all" else [args.scenario] - for s in scenarios: - run_benchmark( - scenario=s, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - ) + if args.mode in ("cudagraph", "both"): + run_benchmark(cudagraph=True) + if args.mode in ("plain", "both"): + run_benchmark(cudagraph=False) if __name__ == "__main__": diff --git a/backends/cuda/coalesced_int4_tensor.py b/backends/cuda/coalesced_int4_tensor.py new file mode 100644 index 00000000000..a623f7f41c4 --- /dev/null +++ b/backends/cuda/coalesced_int4_tensor.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""ExecuTorch-internal INT4 tensor for the CUDA W4A8 dp4a decode kernel. + +``CudaCoalescedInt4Tensor`` is an ExecuTorch-internal tensor subclass. It is +**NOT** torchao's ``Int4Tensor`` and is intentionally not a subclass of it, so +torchao's ``Int4Tensor`` F.linear handlers never match it via the method +resolution order. The CUDA decode/prefill dispatch (``int4_dispatch.py``) is +selected by *type* — it is registered on this class only — so stock +``Int4Tensor`` weights keep falling back to torchao's default (mslk/tinygemm) +path. 
+ +Layout difference from torchao ``Int4Tensor``: + qdata : packed int4 weight (N, K/2), nibble-packed (same as Int4Tensor) + scale : (N, n_groups) — the *coalesced* layout, transposed from + torchao's documented (n_groups, N) + zero_point : (N, n_groups) — coalesced, transposed from (n_groups, N) + +The coalesced [N, n_groups] layout is exactly what the W4A8 dp4a matvec kernel +(``executorch_cuda::int4_plain_mm`` / ``int4_plain_mm.cuh``) reads row-for-row +with qdata, so the exported decode graph carries no per-step transpose. The +transpose is owned by :meth:`from_int4_tensor` so it is baked into the +serialized weight constant once at pack time. +""" + +from typing import List, Optional + +import torch +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor +from torchao.utils import TorchAOBaseTensor + +__all__ = [ + "CudaCoalescedInt4Tensor", +] + + +class CudaCoalescedInt4Tensor(TorchAOBaseTensor): + """INT4 weight with scale/zero_point in the coalesced [N, n_groups] layout. + + ExecuTorch-internal; see the module docstring. Mirrors torchao + ``Int4Tensor``'s data/attribute layout (so the common tensor utilities and + serialization work) but owns the [n_groups, N] -> [N, n_groups] transpose + of scale/zero_point via :meth:`from_int4_tensor`. 
+ """ + + tensor_data_names = ["qdata", "scale", "zero_point"] + tensor_attribute_names = ["block_size", "shape"] + optional_tensor_data_names = ["act_pre_scale"] + optional_tensor_attribute_names = ["activation_dtype"] + + def __new__( + cls, + qdata: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: List[int], + shape: torch.Size, + act_pre_scale: Optional[torch.Tensor] = None, + activation_dtype: Optional[torch.dtype] = None, + ): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: List[int], + shape: torch.Size, + act_pre_scale: Optional[torch.Tensor] = None, + activation_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.block_size = block_size + self.activation_dtype = ( + activation_dtype if activation_dtype is not None else torch.bfloat16 + ) + self.act_pre_scale = act_pre_scale + + def _quantization_type(self): + s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}, activation_dtype={self.activation_dtype}" + if self.act_pre_scale is not None: + s += f", act_pre_scale.shape={self.act_pre_scale.shape}" + return s + + @classmethod + def from_int4_tensor(cls, t: Int4Tensor) -> "CudaCoalescedInt4Tensor": + """Build a coalesced tensor from a torchao ``Int4Tensor``. + + Owns the transpose: torchao stores scale/zero_point as (n_groups, N); + the CUDA decode kernel reads (N, n_groups). The ``.t().contiguous()`` + here is baked into the serialized weight constant so the exported + decode graph has no per-step transpose/clone. 
+ """ + return cls( + t.qdata, + t.scale.t().contiguous(), + t.zero_point.t().contiguous(), + t.block_size, + t.shape, + t.act_pre_scale, + t.activation_dtype, + ) + + +# Allow a model with CudaCoalescedInt4Tensor weights to be loaded with +# `weights_only=True` (mirrors torchao Int4Tensor). +torch.serialization.add_safe_globals([CudaCoalescedInt4Tensor]) diff --git a/backends/cuda/quantize_op_dispatch/__init__.py b/backends/cuda/quantize_op_dispatch/__init__.py index 2248ef0b5c1..005c2b6e7c7 100644 --- a/backends/cuda/quantize_op_dispatch/__init__.py +++ b/backends/cuda/quantize_op_dispatch/__init__.py @@ -10,8 +10,8 @@ weight tensors so that torch.export traces through ExecuTorch's custom ops and dequant logic instead of torchao's defaults. It registers: - * INT4 (``Int4Tensor``) → ``executorch_cuda::int4_plain_mm`` - * INT8 (``IntxUnpackedToInt8Tensor``) → ``executorch_cuda::int8_plain_mm`` + * INT4 (``CudaCoalescedInt4Tensor``) → ``executorch_cuda::int4_plain_mm`` + * INT8 (``IntxUnpackedToInt8Tensor``) → ``executorch_cuda::int8_plain_mm`` See ``int4_dispatch`` and ``int8_dispatch`` for the per-dtype details. diff --git a/backends/cuda/quantize_op_dispatch/int4_dispatch.py b/backends/cuda/quantize_op_dispatch/int4_dispatch.py index 27f491fef06..c3b8921e2fe 100644 --- a/backends/cuda/quantize_op_dispatch/int4_dispatch.py +++ b/backends/cuda/quantize_op_dispatch/int4_dispatch.py @@ -4,12 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Int4Tensor F.linear dispatch for CUDA — runs at eager / export trace time. +"""CudaCoalescedInt4Tensor F.linear dispatch for CUDA — runs at eager / export trace time. -This module overrides Int4Tensor's F.linear dispatch so that torch.export -traces through our custom op and dequant logic instead of torchao's default -(mslk/tinygemm). 
The code here executes during eager inference and during -AOTI export tracing — it does NOT run at .pte runtime. +This module registers an F.linear dispatch on ``CudaCoalescedInt4Tensor`` (an +ExecuTorch-internal subclass, see ``coalesced_int4_tensor.py``) so that +torch.export traces through our custom op and dequant logic. Routing is by +*type*: stock torchao ``Int4Tensor`` weights are left untouched and keep using +torchao's default (mslk/tinygemm) path. The code here executes during eager +inference and during AOTI export tracing — it does NOT run at .pte runtime. At .pte runtime, the captured graph is executed by the AOTI-generated .so: - The custom op ``executorch_cuda::int4_plain_mm`` maps to a C shim that @@ -22,17 +24,17 @@ Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops) Importing the parent ``quantize_op_dispatch`` package registers this dispatch -override (along with the INT8 one) before using nn.Linear with Int4Tensor -weights:: +override (along with the INT8 one) before using nn.Linear with +CudaCoalescedInt4Tensor weights:: import executorch.backends.cuda.quantize_op_dispatch # noqa: F401 """ import torch import torch.nn.functional as F +from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor from executorch.backends.cuda.quantize_op_dispatch._library import lib as _lib from torch.library import impl -from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor # --------------------------------------------------------------------------- # Custom op for decode (M=1): dp4a matvec in C shim, dequant+F.linear in eager @@ -52,11 +54,18 @@ def _meta(self, qdata, scale, zero, group_size): @impl(_lib, "int4_plain_mm", "CUDA") def _cuda(self, qdata, scale, zero, group_size): + # scale/zero are stored in the coalesced [N, n_groups] layout (transposed + # at pack time, see pack_cuda.pack_linear_for_cuda), which is exactly what + # _dequant_matmul expects. 
return _dequant_matmul(self, qdata, scale, zero, group_size) def _dequant_matmul(x, qdata, scale, zero, group_size): - """Dequant INT4 weights to input dtype and call F.linear.""" + """Dequant INT4 weights to input dtype and call F.linear. + + scale/zero are in the coalesced [N, n_groups] layout (baked into the + weight constant at pack time), aligned row-for-row with qdata's [N, *]. + """ N, K_half = qdata.shape K = K_half * 2 n_groups = K // group_size @@ -68,20 +77,20 @@ def _dequant_matmul(x, qdata, scale, zero, group_size): high = ((p >> 4) & 0x0F).to(dtype) data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size) - s = scale.to(dtype).t().unsqueeze(-1) - z = zero.to(dtype).t().unsqueeze(-1) + s = scale.to(dtype).unsqueeze(-1) + z = zero.to(dtype).unsqueeze(-1) w_deq = ((data - z) * s).reshape(N, K) return F.linear(x, w_deq) # --------------------------------------------------------------------------- -# Int4Tensor F.linear dispatch +# CudaCoalescedInt4Tensor F.linear dispatch # --------------------------------------------------------------------------- aten = torch.ops.aten -_implements = Int4Tensor.implements -_implements_torch_function = Int4Tensor.implements_torch_function +_implements = CudaCoalescedInt4Tensor.implements +_implements_torch_function = CudaCoalescedInt4Tensor.implements_torch_function @_implements([aten.linear.default]) @@ -101,6 +110,11 @@ def _(func, types, args, kwargs): M = x_2d.shape[0] if M <= 4: + # scale/zero are already in the coalesced [N, n_groups] layout the + # decode kernel reads directly (baked into the weight constant at pack + # time). Passing them straight through keeps the export graph free of + # any per-step transpose/clone, so the coalesced layout is realized + # without recomputing it every decode step. 
out = torch.ops.executorch_cuda.int4_plain_mm(x_2d, qdata, scale, zero, gs) else: out = _dequant_matmul(x_2d, qdata, scale, zero, gs) diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cu b/backends/cuda/runtime/shims/int4_plain_mm.cu index fd8fe3b0c3b..7cda801c348 100644 --- a/backends/cuda/runtime/shims/int4_plain_mm.cu +++ b/backends/cuda/runtime/shims/int4_plain_mm.cu @@ -52,8 +52,43 @@ AOTITorchError aoti_torch_cuda_int4_plain_mm( InvalidArgument, "aoti_torch_cuda_int4_plain_mm: ret0 is null"); + // Validate the coalesced scale/zero layout [N, K/group_size] + + const int64_t N = qdata->size(0); + const int64_t K = qdata->size(1) * 2; + + ET_CHECK_OR_RETURN_ERROR( + group_size > 0 && (group_size & (group_size - 1)) == 0, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: group_size=%lld must be a positive power of 2", + static_cast(group_size)); + + const int64_t n_groups = K / group_size; + + ET_CHECK_OR_RETURN_ERROR( + scale->dim() == 2 && zero->dim() == 2, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: scale/zero must be 2D (got scale.dim()=%lld, zero.dim()=%lld)", + static_cast(scale->dim()), + static_cast(zero->dim())); + + ET_CHECK_OR_RETURN_ERROR( + scale->size(0) == N && zero->size(0) == N, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: scale/zero must be coalesced [N, K/group_size] (AOT layout); native [n_groups, N] is not supported - repack via pack_linear_for_cuda. Expected size(0)=N=%lld, got scale.size(0)=%lld, zero.size(0)=%lld", + static_cast(N), + static_cast(scale->size(0)), + static_cast(zero->size(0))); + + ET_CHECK_OR_RETURN_ERROR( + scale->size(1) == n_groups && zero->size(1) == n_groups, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: scale/zero must be coalesced [N, K/group_size] (AOT layout); native [n_groups, N] is not supported - repack via pack_linear_for_cuda. 
Expected size(1)=K/group_size=%lld, got scale.size(1)=%lld, zero.size(1)=%lld", + static_cast(n_groups), + static_cast(scale->size(1)), + static_cast(zero->size(1))); + int32_t M = self->size(0); - int32_t N = qdata->size(0); Tensor* C = nullptr; std::array c_shape = {M, N}; std::array c_stride = {N, 1}; diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cuh b/backends/cuda/runtime/shims/int4_plain_mm.cuh index 42700969fa4..31214bc0bf6 100644 --- a/backends/cuda/runtime/shims/int4_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int4_plain_mm.cuh @@ -9,7 +9,7 @@ // W4A8 dp4a matvec for INT4 decode (M <= 4). // // Reads plain nibble-packed [N, K//2] weights (Int4Tensor format). -// Scale/zero layout: [K//gs, N] (Int4Tensor's native layout). +// Scale/zero layout: [N, K//gs] (transposed AOT for coalesced loads). // // Dynamically quantizes bf16 activations to INT8 (per-32-element blocks), // then uses dp4a for fused int4×int8 dot products with 16-byte vectorized @@ -98,18 +98,28 @@ __global__ void quantize_activations_q8_kernel( } // --------------------------------------------------------------------------- -// W4A8 dp4a matvec kernel +// Coalesced-scale W4A8 dp4a matvec +// +// Reads scale/zero in the transposed [N, n_groups] layout (transposed AOT at +// export time). With group_size >= 32, one uint4 (32 weights) maps to exactly +// one activation block and one weight group, so within a warp the 32 lanes +// touch 32 consecutive groups. In [N, n_groups] layout those 32 group scales +// are contiguous => a single coalesced load, vs 32 stride-N cache lines in the +// native layout. For the gemma group_size=32 weights this is the dominant +// decode-matvec cost. 
// --------------------------------------------------------------------------- -__global__ void __launch_bounds__(MV_THREADS) int4_w4a8_matvec_kernel( - const uint8_t* __restrict__ qdata, - const __nv_bfloat16* __restrict__ w_scale, - const __nv_bfloat16* __restrict__ w_zero, - const Q8Block* __restrict__ q8, - __nv_bfloat16* __restrict__ out, - int32_t N, - int32_t K, - int32_t gs_shift) { +__global__ void __launch_bounds__(MV_THREADS) + int4_w4a8_matvec_coalesced_kernel( + const uint8_t* __restrict__ qdata, + const __nv_bfloat16* __restrict__ w_scale_t, // [N, n_groups] + const __nv_bfloat16* __restrict__ w_zero_t, // [N, n_groups] + const Q8Block* __restrict__ q8, + __nv_bfloat16* __restrict__ out, + int32_t N, + int32_t K, + int32_t gs_shift, + int32_t n_groups) { const int32_t n = blockIdx.x * MV_NWARPS + threadIdx.y; const int32_t m = blockIdx.y; if (n >= N) @@ -120,9 +130,10 @@ __global__ void __launch_bounds__(MV_THREADS) int4_w4a8_matvec_kernel( const int32_t n_q8_blocks = K / Q8_BLOCK_SIZE; const uint8_t* qrow = qdata + static_cast(n) * K_half; - const __nv_bfloat16* scale_base = w_scale + n; - const __nv_bfloat16* zero_base = w_zero + n; - const int32_t scale_stride = N; + const __nv_bfloat16* scale_row = + w_scale_t + static_cast(n) * n_groups; + const __nv_bfloat16* zero_row = + w_zero_t + static_cast(n) * n_groups; const Q8Block* q8_row = q8 + static_cast(m) * n_q8_blocks; const uint4* qrow16 = reinterpret_cast(qrow); @@ -145,8 +156,8 @@ __global__ void __launch_bounds__(MV_THREADS) int4_w4a8_matvec_kernel( int32_t g = k_word >> gs_shift; if (g != prev_g) { - ws = __bfloat162float(__ldg(&scale_base[g * scale_stride])); - wz = __bfloat162float(__ldg(&zero_base[g * scale_stride])); + ws = __bfloat162float(__ldg(&scale_row[g])); + wz = __bfloat162float(__ldg(&zero_row[g])); prev_g = g; } @@ -227,8 +238,8 @@ static Q8Block* get_q8_buffer(size_t needed) { void _int4_plain_mm_cuda( const Tensor& A, // [M, K] bf16 const Tensor& qdata, // [N, K//2] uint8 - 
const Tensor& scale, // [K//gs, N] bf16 - const Tensor& zero, // [K//gs, N] bf16 + const Tensor& scale, // [N, K//gs] bf16 + const Tensor& zero, // [N, K//gs] bf16 int64_t group_size, Tensor* output) { // [M, N] bf16, pre-allocated int32_t M = A.size(0); @@ -245,9 +256,9 @@ void _int4_plain_mm_cuda( ET_CHECK(qdata.dim() == 2); ET_CHECK(qdata.size(1) == K / 2); ET_CHECK(scale.dim() == 2); - ET_CHECK(scale.size(1) == N); + ET_CHECK(scale.size(0) == N); ET_CHECK(zero.dim() == 2); - ET_CHECK(zero.size(1) == N); + ET_CHECK(zero.size(0) == N); int32_t gs = static_cast(group_size); ET_CHECK_MSG( @@ -279,15 +290,15 @@ void _int4_plain_mm_cuda( // dp4a matvec dim3 grid((N + MV_NWARPS - 1) / MV_NWARPS, M); dim3 block(MV_WARP_SIZE, MV_NWARPS); - int4_w4a8_matvec_kernel<<>>( + + int32_t n_groups = static_cast(scale.size(1)); + int4_w4a8_matvec_coalesced_kernel<<>>( reinterpret_cast(qdata.data_ptr()), reinterpret_cast(scale.data_ptr()), reinterpret_cast(zero.data_ptr()), q8_buf, reinterpret_cast<__nv_bfloat16*>(output->data_ptr()), - N, - K, - gs_shift); + N, K, gs_shift, n_groups); } } // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp index ab18e33c713..de5fd9774e0 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp @@ -70,6 +70,18 @@ class AOTITorchInt4PlainMMTest : public ::testing::Test { cudaMemcpy(host_data, t->data_ptr(), bytes, cudaMemcpyDeviceToHost); } + // Transpose a uint16 [rows, cols] row-major buffer into [cols, rows]. + // Used to convert native [n_groups, N] scale/zero literals into the + // [N, n_groups] layout the shim now expects (transposed AOT at export). 
+ static std::vector + transpose_u16(const uint16_t* src, int rows, int cols) { + std::vector dst(static_cast(rows) * cols); + for (int r = 0; r < rows; r++) + for (int c = 0; c < cols; c++) + dst[static_cast(c) * rows + r] = src[r * cols + c]; + return dst; + } + // Run the shim and return the output tensor (asserts success). Tensor* run( Tensor* A, @@ -111,7 +123,7 @@ class AOTITorchInt4PlainMMTest : public ::testing::Test { }; // MultiGroupRandom: M=1, N=4, K=32, gs=16 -// scale/zero layout: [K//gs=2, N=4] +// scale/zero layout: [N=4, K//gs=2] (transposed AOT) TEST_F(AOTITorchInt4PlainMMTest, MultiGroupRandom) { int64_t M = 1, K = 32, N = 4, gs = 16; @@ -132,14 +144,17 @@ TEST_F(AOTITorchInt4PlainMMTest, MultiGroupRandom) { uint16_t expected[] = {0xBFCC, 0x3FB5, 0x4046, 0xC01E}; // clang-format on + int64_t ng = K / gs; Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - Tensor* scale = create_bf16({K / gs, N}); - Tensor* zero = create_bf16({K / gs, N}); + Tensor* scale = create_bf16({N, ng}); + Tensor* zero = create_bf16({N, ng}); + auto scale_t = transpose_u16(scale_host, ng, N); + auto zero_t = transpose_u16(zero_host, ng, N); upload(A, A_host, sizeof(A_host)); upload(qdata, qdata_host, sizeof(qdata_host)); - upload(scale, scale_host, sizeof(scale_host)); - upload(zero, zero_host, sizeof(zero_host)); + upload(scale, scale_t.data(), scale_t.size() * sizeof(uint16_t)); + upload(zero, zero_t.data(), zero_t.size() * sizeof(uint16_t)); Tensor* output = run(A, qdata, scale, zero, gs); ASSERT_NE(output, nullptr); @@ -149,7 +164,7 @@ TEST_F(AOTITorchInt4PlainMMTest, MultiGroupRandom) { } // SingleGroup: M=1, N=8, K=32, gs=32 -// scale/zero layout: [K//gs=1, N=8] +// scale/zero layout: [N=8, K//gs=1] (transposed AOT) TEST_F(AOTITorchInt4PlainMMTest, SingleGroup) { int64_t M = 1, K = 32, N = 8, gs = 32; @@ -178,14 +193,17 @@ TEST_F(AOTITorchInt4PlainMMTest, SingleGroup) { uint16_t expected[] = {0xC031, 0x3BF8, 0x3E81, 0xBF19, 0x3FCB, 0xBF56, 
0x4076, 0x3F20}; // clang-format on + int64_t ng = K / gs; Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - Tensor* scale = create_bf16({K / gs, N}); - Tensor* zero = create_bf16({K / gs, N}); + Tensor* scale = create_bf16({N, ng}); + Tensor* zero = create_bf16({N, ng}); + auto scale_t = transpose_u16(scale_host, ng, N); + auto zero_t = transpose_u16(zero_host, ng, N); upload(A, A_host, sizeof(A_host)); upload(qdata, qdata_host, sizeof(qdata_host)); - upload(scale, scale_host, sizeof(scale_host)); - upload(zero, zero_host, sizeof(zero_host)); + upload(scale, scale_t.data(), scale_t.size() * sizeof(uint16_t)); + upload(zero, zero_t.data(), zero_t.size() * sizeof(uint16_t)); Tensor* output = run(A, qdata, scale, zero, gs); ASSERT_NE(output, nullptr); @@ -195,7 +213,7 @@ TEST_F(AOTITorchInt4PlainMMTest, SingleGroup) { } // PrefillBatch: M=8, N=4, K=64, gs=32 -// scale/zero layout: [K//gs=2, N=4] +// scale/zero layout: [N=4, K//gs=2] (transposed AOT) TEST_F(AOTITorchInt4PlainMMTest, PrefillBatch) { int64_t M = 8, K = 64, N = 4, gs = 32; @@ -224,14 +242,17 @@ TEST_F(AOTITorchInt4PlainMMTest, PrefillBatch) { uint16_t expected[] = {0x40BD, 0xC0E3, 0x4037, 0x40A9, 0x406F, 0x4116, 0x3F8D, 0xC01F, 0xC039, 0xC043, 0x3F86, 0x410A, 0x3F07, 0xC100, 0x4019, 0x40D7, 0x40A9, 0x40F1, 0xBF89, 0x406F, 0x40FE, 0xBFB8, 0xBF88, 0x406A, 0x4004, 0x3EDE, 0x3E17, 0x4102, 0xC081, 0xC0BA, 0xBFFB, 0x3F25}; // clang-format on + int64_t ng = K / gs; Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - Tensor* scale = create_bf16({K / gs, N}); - Tensor* zero = create_bf16({K / gs, N}); + Tensor* scale = create_bf16({N, ng}); + Tensor* zero = create_bf16({N, ng}); + auto scale_t = transpose_u16(scale_host, ng, N); + auto zero_t = transpose_u16(zero_host, ng, N); upload(A, A_host, sizeof(A_host)); upload(qdata, qdata_host, sizeof(qdata_host)); - upload(scale, scale_host, sizeof(scale_host)); - upload(zero, zero_host, sizeof(zero_host)); + 
upload(scale, scale_t.data(), scale_t.size() * sizeof(uint16_t)); + upload(zero, zero_t.data(), zero_t.size() * sizeof(uint16_t)); Tensor* output = run(A, qdata, scale, zero, gs); ASSERT_NE(output, nullptr); @@ -241,7 +262,7 @@ TEST_F(AOTITorchInt4PlainMMTest, PrefillBatch) { } // GroupSize128: M=1, N=2, K=256, gs=128 -// scale/zero layout: [K//gs=2, N=2] +// scale/zero layout: [N=2, K//gs=2] (transposed AOT) TEST_F(AOTITorchInt4PlainMMTest, GroupSize128) { int64_t M = 1, K = 256, N = 2, gs = 128; @@ -286,14 +307,17 @@ TEST_F(AOTITorchInt4PlainMMTest, GroupSize128) { uint16_t expected[] = {0xC013, 0xBF05}; // clang-format on + int64_t ng = K / gs; Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - Tensor* scale = create_bf16({K / gs, N}); - Tensor* zero = create_bf16({K / gs, N}); + Tensor* scale = create_bf16({N, ng}); + Tensor* zero = create_bf16({N, ng}); + auto scale_t = transpose_u16(scale_host, ng, N); + auto zero_t = transpose_u16(zero_host, ng, N); upload(A, A_host, sizeof(A_host)); upload(qdata, qdata_host, sizeof(qdata_host)); - upload(scale, scale_host, sizeof(scale_host)); - upload(zero, zero_host, sizeof(zero_host)); + upload(scale, scale_t.data(), scale_t.size() * sizeof(uint16_t)); + upload(zero, zero_t.data(), zero_t.size() * sizeof(uint16_t)); Tensor* output = run(A, qdata, scale, zero, gs); ASSERT_NE(output, nullptr); @@ -307,8 +331,8 @@ TEST_F(AOTITorchInt4PlainMMTest, NullInputHandling) { Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - Tensor* scale = create_bf16({K / gs, N}); - Tensor* zero = create_bf16({K / gs, N}); + Tensor* scale = create_bf16({N, K / gs}); + Tensor* zero = create_bf16({N, K / gs}); Tensor* output = nullptr; EXPECT_EQ( @@ -357,7 +381,7 @@ TEST_F(AOTITorchInt4PlainMMTest, RealInt4TensorLayout) { 0x63, 0x9A, 0x95, 0x78, 0x95, 0x69, 0xF8, 0x58, 0x65, 0x0A, 0x6B, 0x47, 0x9C, 0x5C, 0x6A, 0x35, 0xA2, 0x8A, 0x74, 0x93, 0x28, 0x6D, 0xF0, 0xAB, 0x23, 0xA6, 0xA6, 0x3A}; - // 
scale/zero are [K//gs, N] = [2, 8] — Int4Tensor's native layout + // scale/zero are [N, K//gs] = [8, 2] — transposed AOT for the coalesced kernel uint16_t scale_host[] = { 0x3E46, 0x3E94, 0x3E8F, 0x3E94, 0x3E94, 0x3E8D, 0x3EA5, 0x3EA5, 0x3E9F, 0x3EAD, 0x3E91, 0x3EA0, 0x3E88, 0x3EB7, 0x3E89, 0x3E92}; @@ -380,13 +404,15 @@ TEST_F(AOTITorchInt4PlainMMTest, RealInt4TensorLayout) { Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - // Note: scale/zero shape is [n_groups, N], NOT [N, n_groups] - Tensor* scale = create_bf16({n_groups, N}); - Tensor* zero = create_bf16({n_groups, N}); + // scale/zero shape is [N, n_groups] (transposed AOT) + Tensor* scale = create_bf16({N, n_groups}); + Tensor* zero = create_bf16({N, n_groups}); + auto scale_t = transpose_u16(scale_host, n_groups, N); + auto zero_t = transpose_u16(zero_host, n_groups, N); upload(A, A_host, sizeof(A_host)); upload(qdata, qdata_host, sizeof(qdata_host)); - upload(scale, scale_host, sizeof(scale_host)); - upload(zero, zero_host, sizeof(zero_host)); + upload(scale, scale_t.data(), scale_t.size() * sizeof(uint16_t)); + upload(zero, zero_t.data(), zero_t.size() * sizeof(uint16_t)); Tensor* output = run(A, qdata, scale, zero, gs); ASSERT_NE(output, nullptr); @@ -395,3 +421,25 @@ TEST_F(AOTITorchInt4PlainMMTest, RealInt4TensorLayout) { // W4A8 adds quantization noise vs bf16 reference — use wider tolerance check_bf16_output(output, expected, M * N, 0.5f); } + +// RejectsNativeLayout: scale/zero passed in the un-transposed native +// [n_groups, N] layout (instead of the coalesced [N, n_groups] AOT layout) +// must be rejected gracefully with Error::InvalidArgument, not crash. +// K=64, gs=32 -> n_groups=2, N=8; native scale is [2, 8] while the shim +// expects coalesced [8, 2]. n_groups != N so the shape guard can catch it. 
+TEST_F(AOTITorchInt4PlainMMTest, RejectsNativeLayout) { + int64_t M = 1, K = 64, N = 8, gs = 32; + int64_t n_groups = K / gs; // 2 + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + // Native torchao layout [n_groups, N] = [2, 8], NOT the coalesced + // [N, n_groups] = [8, 2] the shim expects. + Tensor* scale = create_bf16({n_groups, N}); + Tensor* zero = create_bf16({n_groups, N}); + Tensor* output = nullptr; + + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(A, qdata, scale, zero, gs, &output), + Error::InvalidArgument); +} diff --git a/backends/cuda/tests/test_int4_dispatch.py b/backends/cuda/tests/test_int4_dispatch.py index 51d573d33a3..fd748ae8584 100644 --- a/backends/cuda/tests/test_int4_dispatch.py +++ b/backends/cuda/tests/test_int4_dispatch.py @@ -24,13 +24,21 @@ python -m pytest backends/cuda/tests/test_int4_dispatch.py -v """ +import contextlib import unittest +from unittest import mock import executorch.backends.cuda.quantize_op_dispatch.int4_dispatch # noqa: F401 import torch import torch.nn as nn import torch.nn.functional as F -from executorch.examples.models.gemma4_31b.quant.quantize import quantize_weight +from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor +from executorch.backends.cuda.quantize_op_dispatch.int4_dispatch import _dequant_matmul +from executorch.examples.models.gemma4_31b.quant.pack_cuda import pack_linear_for_cuda +from executorch.examples.models.gemma4_31b.quant.quantize import ( + dequantize_weight, + quantize_weight, +) from executorch.examples.models.gemma4_31b.quant.recipe import QuantConfig @@ -51,8 +59,9 @@ def _make_int4_linear(N, K, group_size=128, symmetric=False, bias=False): ) int4_w = quantize_weight(w_bf16, config) - module = nn.Linear(K, N, bias=bias, dtype=torch.bfloat16, device="cuda") - module.weight = nn.Parameter(int4_w.cuda(), requires_grad=False) + module = nn.Linear(K, N, bias=bias, dtype=torch.bfloat16) + pack_linear_for_cuda(module, {"weight": 
int4_w}) + module.cuda() return module, w_bf16.cuda() @@ -174,7 +183,7 @@ def test_to_cuda(self): config = QuantConfig(bits=4, group_size=128, symmetric=False, method="min_max") int4_w = quantize_weight(w_bf16, config) module = nn.Linear(512, 256, bias=False) - module.weight = nn.Parameter(int4_w, requires_grad=False) + pack_linear_for_cuda(module, {"weight": int4_w}) module = module.to("cuda") x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") self._check(module(x), F.linear(x, w_bf16.cuda())) @@ -207,5 +216,114 @@ def test_21504x5376_prefill(self): self._check(module(x), F.linear(x, w_ref)) +def _make_int4_tensor(N, K, group_size=128, symmetric=False): + """Build a stock torchao ``Int4Tensor`` (NOT packed/coalesced) on CPU.""" + w = torch.randn(N, K, dtype=torch.bfloat16) + config = QuantConfig( + bits=4, group_size=group_size, symmetric=symmetric, method="min_max" + ) + return quantize_weight(w, config), w + + +@contextlib.contextmanager +def _record_int4_plain_mm(): + """Record calls to the decode custom op without needing a GPU. + + Replaces ``torch.ops.executorch_cuda.int4_plain_mm`` (whose real impl is the + CUDA C shim) with a recorder that computes the result via the eager CPU + dequant, so the dispatch handler still returns a valid tensor. + """ + calls = [] + + def _fake(self, qdata, scale, zero, group_size): + calls.append((tuple(self.shape), group_size)) + return _dequant_matmul(self, qdata, scale, zero, group_size) + + with mock.patch.object(torch.ops.executorch_cuda, "int4_plain_mm", _fake): + yield calls + + +class TestDispatchRouting(unittest.TestCase): + """Type-based routing: only CudaCoalescedInt4Tensor reaches int4_plain_mm. + + These tests run without a GPU by recording calls to the decode custom op + and computing the result with the eager CPU dequant. They guard the + comment-8 refactor: the CUDA decode path must be selected by weight *type*, + not by globally overriding torchao ``Int4Tensor``'s F.linear. 
+ """ + + def setUp(self): + torch.manual_seed(0) + + def _rel_err(self, out, ref): + return ( + (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + ).item() + + def test_stock_int4tensor_does_not_route_to_int4_plain_mm(self): + """A plain torchao Int4Tensor must fall back to torchao's default path.""" + t, _ = _make_int4_tensor(16, 64, group_size=32) + x = torch.randn(1, 64, dtype=torch.bfloat16) # M=1 (decode regime) + with _record_int4_plain_mm() as calls: + # torchao's default path uses mslk/CUDA and is not exercised on CPU; + # we only assert that our decode op is NOT reached. + with contextlib.suppress(Exception): + F.linear(x, t) + self.assertEqual(calls, []) + + def test_coalesced_tensor_routes_to_int4_plain_mm(self): + """CudaCoalescedInt4Tensor with M<=4 routes to the decode custom op.""" + t, _ = _make_int4_tensor(16, 64, group_size=32) + c = CudaCoalescedInt4Tensor.from_int4_tensor(t) + x = torch.randn(1, 64, dtype=torch.bfloat16) # M=1 (decode regime) + with _record_int4_plain_mm() as calls: + out = F.linear(x, c) + self.assertEqual(len(calls), 1) + self.assertEqual(out.shape, (1, 16)) + + def test_coalesced_tensor_prefill_uses_dequant(self): + """M>4 uses inline dequant (no custom op) and is numerically correct.""" + t, _ = _make_int4_tensor(16, 64, group_size=32) + c = CudaCoalescedInt4Tensor.from_int4_tensor(t) + x = torch.randn(8, 64, dtype=torch.bfloat16) # M=8 > 4 (prefill regime) + with _record_int4_plain_mm() as calls: + out = F.linear(x, c) + self.assertEqual(calls, []) + ref = F.linear(x, dequantize_weight(t, torch.bfloat16)) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_square_shape_not_misrouted(self): + """N == n_groups (square scale) stock tensor is still not routed. + + K = group_size * N makes scale square (n_groups == N); the old shape + heuristic could not distinguish this coalesced-looking case. Type-based + routing makes the scale shape irrelevant. 
+ """ + t, _ = _make_int4_tensor(4, 128, group_size=32) + self.assertEqual(tuple(t.scale.shape), (4, 4)) # (n_groups, N), square + x = torch.randn(1, 128, dtype=torch.bfloat16) + with _record_int4_plain_mm() as calls: + with contextlib.suppress(Exception): + F.linear(x, t) + self.assertEqual(calls, []) + + def test_from_int4_tensor_transpose_correct(self): + """from_int4_tensor owns the (n_groups, N) -> (N, n_groups) transpose.""" + t, _ = _make_int4_tensor(24, 192, group_size=64) + c = CudaCoalescedInt4Tensor.from_int4_tensor(t) + n_groups = 192 // 64 + self.assertEqual(tuple(t.scale.shape), (n_groups, 24)) # torchao layout + self.assertEqual(tuple(c.scale.shape), (24, n_groups)) # coalesced layout + self.assertTrue(torch.equal(c.scale, t.scale.t().contiguous())) + self.assertTrue(torch.equal(c.zero_point, t.zero_point.t().contiguous())) + # End-to-end decode result matches a reference dequant of the original. + x = torch.randn(2, 192, dtype=torch.bfloat16) + with _record_int4_plain_mm() as calls: + out = F.linear(x, c) + self.assertEqual(len(calls), 1) + ref = F.linear(x, dequantize_weight(t, torch.bfloat16)) + self.assertLess(self._rel_err(out, ref), 0.02) + + if __name__ == "__main__": unittest.main() diff --git a/backends/cuda/tests/test_sdpa_splitk_replacement.py b/backends/cuda/tests/test_sdpa_splitk_replacement.py index 414a1308777..465b0b7ecf4 100644 --- a/backends/cuda/tests/test_sdpa_splitk_replacement.py +++ b/backends/cuda/tests/test_sdpa_splitk_replacement.py @@ -6,9 +6,9 @@ """Test ReplaceEdgeOpWithTritonOpPass split-K SDPA kernel selection. -Exports a minimal model containing F.scaled_dot_product_attention through -the CUDA backend and verifies that the pass routes to split-K for decode -(L_q=1, large L_kv) and standard SDPA otherwise. +Exports a minimal model containing F.scaled_dot_product_attention through the +CUDA backend and verifies that the pass routes to split-K for decode +(L_q==1, L_kv >= 256) and standard SDPA otherwise. 
""" import logging @@ -106,9 +106,9 @@ class TestSplitKReplacement(unittest.TestCase): def setUp(self): _require_cuda(self) - def test_large_kv_cache_uses_splitk(self): - """L_kv=4096 > threshold → split-K selected for decode.""" - model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=4096).to( + def test_below_threshold_uses_standard(self): + """L_kv=128 < threshold (256) -> standard SDPA, no split-K.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=128).to( torch.bfloat16 ) args = ( @@ -119,12 +119,17 @@ def test_large_kv_cache_uses_splitk(self): _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) splitk = [m for m in msgs if "split-K" in m] - self.assertEqual(len(splitk), 1, f"Expected 1 split-K selection. Log: {msgs}") - self.assertIn("L_kv=4096", splitk[0]) + self.assertEqual(len(splitk), 0, f"Expected no split-K. Got: {splitk}") - def test_small_kv_cache_uses_standard(self): - """L_kv=512 <= threshold → standard SDPA, no split-K.""" - model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=512).to( + replaced = [m for m in msgs if "Replaced" in m] + self.assertTrue( + any("1 nodes" in m for m in replaced), + f"Expected 1 SDPA replaced with standard kernel. Log: {msgs}", + ) + + def test_at_threshold_uses_splitk(self): + """L_kv=256 == threshold -> split-K selected (boundary, inclusive).""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=256).to( torch.bfloat16 ) args = ( @@ -135,16 +140,27 @@ def test_small_kv_cache_uses_standard(self): _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) splitk = [m for m in msgs if "split-K" in m] - self.assertEqual(len(splitk), 0, f"Expected no split-K. Got: {splitk}") + self.assertEqual(len(splitk), 1, f"Expected 1 split-K selection. 
Log: {msgs}") + self.assertIn("L_kv=256", splitk[0]) - replaced = [m for m in msgs if "Replaced" in m] - self.assertTrue( - any("1 nodes" in m for m in replaced), - f"Expected 1 SDPA replaced with standard kernel. Log: {msgs}", + def test_large_kv_cache_uses_splitk(self): + """L_kv=4096 > threshold -> split-K selected for decode.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=4096).to( + torch.bfloat16 ) + args = ( + torch.zeros(1, 1, 256, dtype=torch.bfloat16), + torch.tensor([0], dtype=torch.long), + ) + + _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) + + splitk = [m for m in msgs if "split-K" in m] + self.assertEqual(len(splitk), 1, f"Expected 1 split-K selection. Log: {msgs}") + self.assertIn("L_kv=4096", splitk[0]) def test_non_pow2_head_dim_uses_standard(self): - """Non-power-of-2 head_dim → standard SDPA even with large L_kv.""" + """Non-power-of-2 head_dim -> standard SDPA even with large L_kv.""" model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=96, kv_len=8192).to( torch.bfloat16 ) diff --git a/backends/cuda/triton/replacement_pass.py b/backends/cuda/triton/replacement_pass.py index 628222e46f7..c55965a00e1 100644 --- a/backends/cuda/triton/replacement_pass.py +++ b/backends/cuda/triton/replacement_pass.py @@ -27,7 +27,8 @@ exir_ops.edge.aten.topk.default: triton.topk, } -_SPLITK_LKV_THRESHOLD = 2048 + +_SPLITK_LKV_THRESHOLD = 256 class ReplaceEdgeOpWithTritonOpPass(PassBase): @@ -94,6 +95,9 @@ def _pick_sdpa_kernel(node: Node): (full-attention KV caches) but loses to the standard kernel for small L_kv (sliding-window ring buffers) due to the overhead of allocating partial buffers and running the reduction kernel. + + TODO(gasoonjia): Benchmarking to determine the optimal + implementation for each shape. 
""" q_shape = node.args[0].meta["val"].shape k_shape = node.args[1].meta["val"].shape @@ -104,7 +108,7 @@ def _pick_sdpa_kernel(node: Node): isinstance(L_q, int) and L_q == 1 and isinstance(L_kv, int) - and L_kv > _SPLITK_LKV_THRESHOLD + and L_kv >= _SPLITK_LKV_THRESHOLD and D > 0 and (D & (D - 1)) == 0 # power of 2 ): diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py index 037c3bd8310..655d773e7b3 100644 --- a/examples/models/gemma4_31b/quant/pack_cuda.py +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -6,8 +6,10 @@ """CUDA packer: assign quantized weights to model modules. -Passes ``Int4Tensor`` and ``IntxUnpackedToInt8Tensor`` through as -``nn.Parameter`` without conversion. The quantize_op_dispatch package +Converts ``Int4Tensor`` weights to the ExecuTorch-internal +``CudaCoalescedInt4Tensor`` (which owns the scale/zero transpose to the +coalesced [N, n_groups] layout) and passes ``IntxUnpackedToInt8Tensor`` through +as ``nn.Parameter`` without conversion. The quantize_op_dispatch package (``int4_dispatch`` / ``int8_dispatch``) handles F.linear at runtime. No CUDA is required for packing. The backend-agnostic ``pack_model`` @@ -28,11 +30,24 @@ def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: """Assign a quantized weight to an ``nn.Linear`` module.""" + from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor w = weights["weight"] - if isinstance(w, (Int4Tensor, IntxUnpackedToInt8Tensor)): + if isinstance(w, Int4Tensor): + # Convert to the ExecuTorch-internal CudaCoalescedInt4Tensor, which + # repacks scale/zero from torchao's native [n_groups, N] layout into the + # coalesced [N, n_groups] layout the CUDA decode kernel reads (see + # int4_dispatch.py / int4_plain_mm.cuh). 
The transpose lives in + # CudaCoalescedInt4Tensor.from_int4_tensor, so it is baked into the + # serialized weight constant and the exported decode graph carries NO + # per-step transpose/clone — AOTInductor (freezing=False) does not + # constant-fold ops on parameters, so the transpose must already live in + # the constant for the coalesced layout to pay off. + w = CudaCoalescedInt4Tensor.from_int4_tensor(w) + module.weight = nn.Parameter(w, requires_grad=False) + elif isinstance(w, IntxUnpackedToInt8Tensor): module.weight = nn.Parameter(w, requires_grad=False) else: raise ValueError(f"Unsupported weight type: {type(w).__name__}")