Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
372 changes: 186 additions & 186 deletions backends/cuda/benchmarks/benchmark_sdpa.py

Large diffs are not rendered by default.

119 changes: 119 additions & 0 deletions backends/cuda/coalesced_int4_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""ExecuTorch-internal INT4 tensor for the CUDA W4A8 dp4a decode kernel.

``CudaCoalescedInt4Tensor`` is an ExecuTorch-internal tensor subclass. It is
**NOT** torchao's ``Int4Tensor`` and is intentionally not a subclass of it, so
torchao's ``Int4Tensor`` F.linear handlers never match it via the method
resolution order. The CUDA decode/prefill dispatch (``int4_dispatch.py``) is
selected by *type* — it is registered on this class only — so stock
``Int4Tensor`` weights keep falling back to torchao's default (mslk/tinygemm)
path.

Layout difference from torchao ``Int4Tensor``:
qdata : packed int4 weight (N, K/2), nibble-packed (same as Int4Tensor)
scale : (N, n_groups) — the *coalesced* layout, transposed from
torchao's documented (n_groups, N)
zero_point : (N, n_groups) — coalesced, transposed from (n_groups, N)

The coalesced [N, n_groups] layout is exactly what the W4A8 dp4a matvec kernel
(``executorch_cuda::int4_plain_mm`` / ``int4_plain_mm.cuh``) reads row-for-row
with qdata, so the exported decode graph carries no per-step transpose. The
transpose is owned by :meth:`from_int4_tensor` so it is baked into the
serialized weight constant once at pack time.
"""

from typing import List, Optional

import torch
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
from torchao.utils import TorchAOBaseTensor

__all__ = [
"CudaCoalescedInt4Tensor",
]


class CudaCoalescedInt4Tensor(TorchAOBaseTensor):
"""INT4 weight with scale/zero_point in the coalesced [N, n_groups] layout.

ExecuTorch-internal; see the module docstring. Mirrors torchao
``Int4Tensor``'s data/attribute layout (so the common tensor utilities and
serialization work) but owns the [n_groups, N] -> [N, n_groups] transpose
of scale/zero_point via :meth:`from_int4_tensor`.
"""

tensor_data_names = ["qdata", "scale", "zero_point"]
tensor_attribute_names = ["block_size", "shape"]
optional_tensor_data_names = ["act_pre_scale"]
optional_tensor_attribute_names = ["activation_dtype"]

def __new__(
cls,
qdata: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
block_size: List[int],
shape: torch.Size,
act_pre_scale: Optional[torch.Tensor] = None,
activation_dtype: Optional[torch.dtype] = None,
):
kwargs = {}
kwargs["device"] = qdata.device
kwargs["dtype"] = scale.dtype
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
qdata: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
block_size: List[int],
shape: torch.Size,
act_pre_scale: Optional[torch.Tensor] = None,
activation_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.qdata = qdata
self.scale = scale
self.zero_point = zero_point
self.block_size = block_size
self.activation_dtype = (
activation_dtype if activation_dtype is not None else torch.bfloat16
)
self.act_pre_scale = act_pre_scale

def _quantization_type(self):
s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}, activation_dtype={self.activation_dtype}"
if self.act_pre_scale is not None:
s += f", act_pre_scale.shape={self.act_pre_scale.shape}"
return s

@classmethod
def from_int4_tensor(cls, t: Int4Tensor) -> "CudaCoalescedInt4Tensor":
"""Build a coalesced tensor from a torchao ``Int4Tensor``.

Owns the transpose: torchao stores scale/zero_point as (n_groups, N);
the CUDA decode kernel reads (N, n_groups). The ``.t().contiguous()``
here is baked into the serialized weight constant so the exported
decode graph has no per-step transpose/clone.
"""
return cls(
t.qdata,
t.scale.t().contiguous(),
t.zero_point.t().contiguous(),
t.block_size,
t.shape,
t.act_pre_scale,
t.activation_dtype,
)


# Allow a model with CudaCoalescedInt4Tensor weights to be loaded with
# `weights_only=True` (mirrors torchao Int4Tensor).
torch.serialization.add_safe_globals([CudaCoalescedInt4Tensor])
4 changes: 2 additions & 2 deletions backends/cuda/quantize_op_dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
weight tensors so that torch.export traces through ExecuTorch's custom ops and
dequant logic instead of torchao's defaults. It registers:

* INT4 (``Int4Tensor``) → ``executorch_cuda::int4_plain_mm``
* INT8 (``IntxUnpackedToInt8Tensor``) → ``executorch_cuda::int8_plain_mm``
* INT4 (``CudaCoalescedInt4Tensor``) → ``executorch_cuda::int4_plain_mm``
* INT8 (``IntxUnpackedToInt8Tensor``) → ``executorch_cuda::int8_plain_mm``

See ``int4_dispatch`` and ``int8_dispatch`` for the per-dtype details.

Expand Down
42 changes: 28 additions & 14 deletions backends/cuda/quantize_op_dispatch/int4_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Int4Tensor F.linear dispatch for CUDA — runs at eager / export trace time.
"""CudaCoalescedInt4Tensor F.linear dispatch for CUDA — runs at eager / export trace time.

This module overrides Int4Tensor's F.linear dispatch so that torch.export
traces through our custom op and dequant logic instead of torchao's default
(mslk/tinygemm). The code here executes during eager inference and during
AOTI export tracing — it does NOT run at .pte runtime.
This module registers an F.linear dispatch on ``CudaCoalescedInt4Tensor`` (an
ExecuTorch-internal subclass, see ``coalesced_int4_tensor.py``) so that
torch.export traces through our custom op and dequant logic. Routing is by
*type*: stock torchao ``Int4Tensor`` weights are left untouched and keep using
torchao's default (mslk/tinygemm) path. The code here executes during eager
inference and during AOTI export tracing — it does NOT run at .pte runtime.

At .pte runtime, the captured graph is executed by the AOTI-generated .so:
- The custom op ``executorch_cuda::int4_plain_mm`` maps to a C shim that
Expand All @@ -22,17 +24,17 @@
Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops)

Importing the parent ``quantize_op_dispatch`` package registers this dispatch
override (along with the INT8 one) before using nn.Linear with Int4Tensor
weights::
override (along with the INT8 one) before using nn.Linear with
CudaCoalescedInt4Tensor weights::

import executorch.backends.cuda.quantize_op_dispatch # noqa: F401
"""

import torch
import torch.nn.functional as F
from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor
from executorch.backends.cuda.quantize_op_dispatch._library import lib as _lib
from torch.library import impl
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor

# ---------------------------------------------------------------------------
# Custom op for decode (M=1): dp4a matvec in C shim, dequant+F.linear in eager
Expand All @@ -52,11 +54,18 @@ def _meta(self, qdata, scale, zero, group_size):

@impl(_lib, "int4_plain_mm", "CUDA")
def _cuda(self, qdata, scale, zero, group_size):
# scale/zero are stored in the coalesced [N, n_groups] layout (transposed
# at pack time, see pack_cuda.pack_linear_for_cuda), which is exactly what
# _dequant_matmul expects.
return _dequant_matmul(self, qdata, scale, zero, group_size)


def _dequant_matmul(x, qdata, scale, zero, group_size):
"""Dequant INT4 weights to input dtype and call F.linear."""
"""Dequant INT4 weights to input dtype and call F.linear.

scale/zero are in the coalesced [N, n_groups] layout (baked into the
weight constant at pack time), aligned row-for-row with qdata's [N, *].
"""
N, K_half = qdata.shape
K = K_half * 2
n_groups = K // group_size
Expand All @@ -68,20 +77,20 @@ def _dequant_matmul(x, qdata, scale, zero, group_size):
high = ((p >> 4) & 0x0F).to(dtype)
data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size)

s = scale.to(dtype).t().unsqueeze(-1)
z = zero.to(dtype).t().unsqueeze(-1)
s = scale.to(dtype).unsqueeze(-1)
z = zero.to(dtype).unsqueeze(-1)
w_deq = ((data - z) * s).reshape(N, K)

return F.linear(x, w_deq)


# ---------------------------------------------------------------------------
# Int4Tensor F.linear dispatch
# CudaCoalescedInt4Tensor F.linear dispatch
# ---------------------------------------------------------------------------

aten = torch.ops.aten
_implements = Int4Tensor.implements
_implements_torch_function = Int4Tensor.implements_torch_function
_implements = CudaCoalescedInt4Tensor.implements
_implements_torch_function = CudaCoalescedInt4Tensor.implements_torch_function


@_implements([aten.linear.default])
Expand All @@ -101,6 +110,11 @@ def _(func, types, args, kwargs):

M = x_2d.shape[0]
if M <= 4:
# scale/zero are already in the coalesced [N, n_groups] layout the
# decode kernel reads directly (baked into the weight constant at pack
# time). Passing them straight through keeps the export graph free of
# any per-step transpose/clone, so the coalesced layout is realized
# without recomputing it every decode step.
out = torch.ops.executorch_cuda.int4_plain_mm(x_2d, qdata, scale, zero, gs)
else:
out = _dequant_matmul(x_2d, qdata, scale, zero, gs)
Expand Down
37 changes: 36 additions & 1 deletion backends/cuda/runtime/shims/int4_plain_mm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,43 @@ AOTITorchError aoti_torch_cuda_int4_plain_mm(
InvalidArgument,
"aoti_torch_cuda_int4_plain_mm: ret0 is null");

// Validate the coalesced scale/zero layout [N, K/group_size]

const int64_t N = qdata->size(0);
const int64_t K = qdata->size(1) * 2;

ET_CHECK_OR_RETURN_ERROR(
group_size > 0 && (group_size & (group_size - 1)) == 0,
InvalidArgument,
"aoti_torch_cuda_int4_plain_mm: group_size=%lld must be a positive power of 2",
static_cast<long long>(group_size));

const int64_t n_groups = K / group_size;

ET_CHECK_OR_RETURN_ERROR(
scale->dim() == 2 && zero->dim() == 2,
InvalidArgument,
"aoti_torch_cuda_int4_plain_mm: scale/zero must be 2D (got scale.dim()=%lld, zero.dim()=%lld)",
static_cast<long long>(scale->dim()),
static_cast<long long>(zero->dim()));

ET_CHECK_OR_RETURN_ERROR(
scale->size(0) == N && zero->size(0) == N,
InvalidArgument,
"aoti_torch_cuda_int4_plain_mm: scale/zero must be coalesced [N, K/group_size] (AOT layout); native [n_groups, N] is not supported - repack via pack_linear_for_cuda. Expected size(0)=N=%lld, got scale.size(0)=%lld, zero.size(0)=%lld",
static_cast<long long>(N),
static_cast<long long>(scale->size(0)),
static_cast<long long>(zero->size(0)));

ET_CHECK_OR_RETURN_ERROR(
scale->size(1) == n_groups && zero->size(1) == n_groups,
InvalidArgument,
"aoti_torch_cuda_int4_plain_mm: scale/zero must be coalesced [N, K/group_size] (AOT layout); native [n_groups, N] is not supported - repack via pack_linear_for_cuda. Expected size(1)=K/group_size=%lld, got scale.size(1)=%lld, zero.size(1)=%lld",
static_cast<long long>(n_groups),
static_cast<long long>(scale->size(1)),
static_cast<long long>(zero->size(1)));

int32_t M = self->size(0);
int32_t N = qdata->size(0);
Tensor* C = nullptr;
std::array<int64_t, 2> c_shape = {M, N};
std::array<int64_t, 2> c_stride = {N, 1};
Expand Down
59 changes: 35 additions & 24 deletions backends/cuda/runtime/shims/int4_plain_mm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
// W4A8 dp4a matvec for INT4 decode (M <= 4).
//
// Reads plain nibble-packed [N, K//2] weights (Int4Tensor format).
// Scale/zero layout: [K//gs, N] (Int4Tensor's native layout).
// Scale/zero layout: [N, K//gs] (transposed AOT for coalesced loads).
//
// Dynamically quantizes bf16 activations to INT8 (per-32-element blocks),
// then uses dp4a for fused int4×int8 dot products with 16-byte vectorized
Expand Down Expand Up @@ -98,18 +98,28 @@ __global__ void quantize_activations_q8_kernel(
}

// ---------------------------------------------------------------------------
// W4A8 dp4a matvec kernel
// Coalesced-scale W4A8 dp4a matvec
//
// Reads scale/zero in the transposed [N, n_groups] layout (transposed AOT at
// export time). With group_size >= 32, one uint4 (32 weights) maps to exactly
// one activation block and one weight group, so within a warp the 32 lanes
// touch 32 consecutive groups. In [N, n_groups] layout those 32 group scales
// are contiguous => a single coalesced load, vs 32 stride-N cache lines in the
// native layout. For the gemma group_size=32 weights this is the dominant
// decode-matvec cost.
// ---------------------------------------------------------------------------

__global__ void __launch_bounds__(MV_THREADS) int4_w4a8_matvec_kernel(
const uint8_t* __restrict__ qdata,
const __nv_bfloat16* __restrict__ w_scale,
const __nv_bfloat16* __restrict__ w_zero,
const Q8Block* __restrict__ q8,
__nv_bfloat16* __restrict__ out,
int32_t N,
int32_t K,
int32_t gs_shift) {
__global__ void __launch_bounds__(MV_THREADS)
int4_w4a8_matvec_coalesced_kernel(
const uint8_t* __restrict__ qdata,
const __nv_bfloat16* __restrict__ w_scale_t, // [N, n_groups]
const __nv_bfloat16* __restrict__ w_zero_t, // [N, n_groups]
const Q8Block* __restrict__ q8,
__nv_bfloat16* __restrict__ out,
int32_t N,
int32_t K,
int32_t gs_shift,
int32_t n_groups) {
const int32_t n = blockIdx.x * MV_NWARPS + threadIdx.y;
const int32_t m = blockIdx.y;
if (n >= N)
Expand All @@ -120,9 +130,10 @@ __global__ void __launch_bounds__(MV_THREADS) int4_w4a8_matvec_kernel(
const int32_t n_q8_blocks = K / Q8_BLOCK_SIZE;

const uint8_t* qrow = qdata + static_cast<int64_t>(n) * K_half;
const __nv_bfloat16* scale_base = w_scale + n;
const __nv_bfloat16* zero_base = w_zero + n;
const int32_t scale_stride = N;
const __nv_bfloat16* scale_row =
w_scale_t + static_cast<int64_t>(n) * n_groups;
const __nv_bfloat16* zero_row =
w_zero_t + static_cast<int64_t>(n) * n_groups;
const Q8Block* q8_row = q8 + static_cast<int64_t>(m) * n_q8_blocks;

const uint4* qrow16 = reinterpret_cast<const uint4*>(qrow);
Expand All @@ -145,8 +156,8 @@ __global__ void __launch_bounds__(MV_THREADS) int4_w4a8_matvec_kernel(
int32_t g = k_word >> gs_shift;

if (g != prev_g) {
ws = __bfloat162float(__ldg(&scale_base[g * scale_stride]));
wz = __bfloat162float(__ldg(&zero_base[g * scale_stride]));
ws = __bfloat162float(__ldg(&scale_row[g]));
wz = __bfloat162float(__ldg(&zero_row[g]));
prev_g = g;
}

Expand Down Expand Up @@ -227,8 +238,8 @@ static Q8Block* get_q8_buffer(size_t needed) {
void _int4_plain_mm_cuda(
const Tensor& A, // [M, K] bf16
const Tensor& qdata, // [N, K//2] uint8
const Tensor& scale, // [K//gs, N] bf16
const Tensor& zero, // [K//gs, N] bf16
const Tensor& scale, // [N, K//gs] bf16
const Tensor& zero, // [N, K//gs] bf16
int64_t group_size,
Tensor* output) { // [M, N] bf16, pre-allocated
int32_t M = A.size(0);
Expand All @@ -245,9 +256,9 @@ void _int4_plain_mm_cuda(
ET_CHECK(qdata.dim() == 2);
ET_CHECK(qdata.size(1) == K / 2);
ET_CHECK(scale.dim() == 2);
ET_CHECK(scale.size(1) == N);
ET_CHECK(scale.size(0) == N);
ET_CHECK(zero.dim() == 2);
ET_CHECK(zero.size(1) == N);
ET_CHECK(zero.size(0) == N);

int32_t gs = static_cast<int32_t>(group_size);
ET_CHECK_MSG(
Expand Down Expand Up @@ -279,15 +290,15 @@ void _int4_plain_mm_cuda(
// dp4a matvec
dim3 grid((N + MV_NWARPS - 1) / MV_NWARPS, M);
dim3 block(MV_WARP_SIZE, MV_NWARPS);
int4_w4a8_matvec_kernel<<<grid, block, 0, stream>>>(

int32_t n_groups = static_cast<int32_t>(scale.size(1));
int4_w4a8_matvec_coalesced_kernel<<<grid, block, 0, stream>>>(
reinterpret_cast<const uint8_t*>(qdata.data_ptr()),
reinterpret_cast<const __nv_bfloat16*>(scale.data_ptr()),
reinterpret_cast<const __nv_bfloat16*>(zero.data_ptr()),
q8_buf,
reinterpret_cast<__nv_bfloat16*>(output->data_ptr()),
N,
K,
gs_shift);
N, K, gs_shift, n_groups);
}

} // namespace executorch::backends::cuda
Loading
Loading