Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions backends/cuda/int4_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,18 @@
dequant + cuBLAS matmul kernels.

Dispatch strategy (determines what gets captured in the export graph):
Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm``
Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops)
Small M (M<=MATVEC_MAX_M): Custom op ``executorch_cuda::int4_plain_mm``
Large M (M>MATVEC_MAX_M): Inline dequant + F.linear (standard PyTorch ops)

The custom op is memory-bound and beats dequant+cuBLAS for small M (M==1 matvec;
2<=M<=8 weight-stationary GEMM). ``MATVEC_MAX_M`` defaults to 4 (decode only).
An export may raise it up to the shim's GEMM limit (``GEMM_MAX_M`` = 8 in
``int4_plain_mm.cuh``), but then its *dynamic* shapes must not straddle the
threshold: a dynamic linear whose M range crosses MATVEC_MAX_M makes
torch.export's branch guard ambiguous, so a long-prefill export must declare
``min > MATVEC_MAX_M``. Raising the global default would break exports whose
dynamic prefill range starts below the threshold, so callers set it locally
instead.

Import this module before using nn.Linear with Int4Tensor weights::

Expand All @@ -35,6 +45,16 @@
# Custom op for small M (M<=MATVEC_MAX_M): dp4a matvec (M==1) or weight-stationary GEMM (M>1) in C shim, dequant+F.linear in eager
# ---------------------------------------------------------------------------

# Largest M the C++ shim's GEMM kernel handles (GEMM_MAX_M in int4_plain_mm.cuh).
# MATVEC_MAX_M must not exceed it, else export captures a shape the shim rejects
# at runtime; the dispatch asserts this below.
# This is a hand-maintained copy of the C++ constant; TestLimitConsistency in
# tests/test_int4_dispatch.py compares the two so that drift is caught.
SHIM_GEMM_MAX_M = 8

# Max M routed to the custom INT4 op; above this, dequant+cuBLAS wins. Defaults
# to 4 (decode); an export may raise it (<= SHIM_GEMM_MAX_M) for small-M GEMM,
# subject to the dynamic-shape constraint documented above.
# The dispatch reads this module-level value each time it runs, so an override
# must be in place before the model is traced/exported.
MATVEC_MAX_M = 4

_lib = Library("executorch_cuda", "DEF")
_lib.define(
"int4_plain_mm(Tensor self, Tensor qdata, Tensor scale, Tensor zero, int group_size) -> Tensor"
Expand Down Expand Up @@ -98,7 +118,11 @@ def _(func, types, args, kwargs):
gs = weight_tensor.block_size[-1]

M = x_2d.shape[0]
if M <= 4:
if M <= MATVEC_MAX_M:
assert MATVEC_MAX_M <= SHIM_GEMM_MAX_M, (
f"MATVEC_MAX_M={MATVEC_MAX_M} exceeds the shim's GEMM_MAX_M="
f"{SHIM_GEMM_MAX_M} (int4_plain_mm.cuh)"
)
out = torch.ops.executorch_cuda.int4_plain_mm(x_2d, qdata, scale, zero, gs)
else:
out = _dequant_matmul(x_2d, qdata, scale, zero, gs)
Expand Down
139 changes: 130 additions & 9 deletions backends/cuda/runtime/shims/int4_plain_mm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,107 @@ __global__ void __launch_bounds__(MV_THREADS)
out[static_cast<int64_t>(m) * N + n] = __float2bfloat16(sum);
}

// ---------------------------------------------------------------------------
// W4A8 dp4a GEMM kernel — weight-stationary for small M (2..M_MAX)
//
// The matvec above launches one block-row per M (grid.y = M) and re-reads the
// packed weights for every activation row, so its weight traffic scales with M.
// For speculative verification (M = chain_len+1, a handful of rows) that makes
// one verify step cost roughly M decodes. This kernel instead loads each weight
// chunk once and accumulates it into all M output rows (grid.y = 1), so weight
// traffic is 1x regardless of M — turning a verify back into ~one decode.
// Activations (tiny INT8) are the only thing read per row.
//
// Launch layout (as used by the host wrapper):
//   block = (MV_WARP_SIZE, MV_NWARPS), grid = (ceil(N / MV_NWARPS), 1),
//   no shared memory. Each warp (one threadIdx.y) owns one output column n and
//   its lanes stride over that column's packed weight row.
//
// Operands, as indexed below:
//   qdata   : packed INT4 weights, row n starts at byte n * (K / 2)
//   w_scale : bf16 per-group scales, element [g * N + n]
//   w_zero  : bf16 per-group zero points, element [g * N + n]
//   q8      : quantized activations, block [m * (K / Q8_BLOCK_SIZE) + b]
//   out     : bf16 result, element [m * N + n]
//   gs_shift: log2 of the weight group size (g = k >> gs_shift)
//   M       : number of activation rows actually present (rows >= M_MAX are
//             never touched; the host wrapper is expected to reject M > M_MAX)
//
// NOTE(review): the K loop covers floor(K / 32) * 32 elements and reads the
// weight row as 16-byte vectors, so K is presumably a multiple of 32 and the
// row base 16-byte aligned; the group size is presumably >= 8 so that one
// packed word never straddles two groups — confirm against the host checks.
// ---------------------------------------------------------------------------

template <int32_t M_MAX>
__global__ void __launch_bounds__(MV_THREADS)
    int4_w4a8_gemm_kernel(
        const uint8_t* __restrict__ qdata,
        const __nv_bfloat16* __restrict__ w_scale,
        const __nv_bfloat16* __restrict__ w_zero,
        const Q8Block* __restrict__ q8,
        __nv_bfloat16* __restrict__ out,
        int32_t N,
        int32_t K,
        int32_t gs_shift,
        int32_t M) {
  // n depends only on blockIdx.x / threadIdx.y, so all lanes sharing a
  // threadIdx.y leave together and the full-mask shuffles below stay valid.
  const int32_t n = blockIdx.x * MV_NWARPS + threadIdx.y;
  if (n >= N)
    return;

  const int32_t K_half = K / 2;
  const int32_t lane_id = threadIdx.x;
  const int32_t n_q8_blocks = K / Q8_BLOCK_SIZE;

  const uint8_t* qrow = qdata + static_cast<int64_t>(n) * K_half;
  const __nv_bfloat16* scale_base = w_scale + n;
  const __nv_bfloat16* zero_base = w_zero + n;
  const int32_t scale_stride = N;

  // 16-byte vector view of the weight row: one load = 32 packed INT4 values.
  const uint4* qrow16 = reinterpret_cast<const uint4*>(qrow);
  const int32_t K_half_16 = K_half / 16;

  // Per-row partial sums. Every loop that touches sum[] runs to the
  // compile-time bound M_MAX under `#pragma unroll` with an `mm < M` predicate,
  // so all indices are constants and the array can live in registers; a
  // runtime-bounded loop would force it into local memory and would overrun the
  // array if M ever exceeded M_MAX.
  float sum[M_MAX];
#pragma unroll
  for (int32_t mm = 0; mm < M_MAX; mm++)
    sum[mm] = 0.0f;

  // Cache of the last group's scale / zero point so they are only reloaded
  // when the group index changes.
  int32_t prev_g = -1;
  float ws = 0.0f, wz = 0.0f;

  for (int32_t i = lane_id; i < K_half_16; i += MV_WARP_SIZE) {
    uint4 packed16 = __ldg(&qrow16[i]);
    int32_t k_base = i * 32;
    uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w};

#pragma unroll
    for (int32_t w = 0; w < 4; w++) {
      uint32_t packed = words[w];
      int32_t k_word = k_base + w * 8;
      int32_t g = k_word >> gs_shift;

      if (g != prev_g) {
        // 64-bit offset: g * N is not guaranteed to fit in int32 for very
        // large weight matrices.
        const int64_t sz_off = static_cast<int64_t>(g) * scale_stride;
        ws = __bfloat162float(__ldg(&scale_base[sz_off]));
        wz = __bfloat162float(__ldg(&zero_base[sz_off]));
        prev_g = g;
      }

      // Weight nibbles loaded once and reused across all M activation rows.
      int32_t vi_lo = packed & 0x0F0F0F0F;
      int32_t vi_hi = (packed >> 4) & 0x0F0F0F0F;

      int32_t q8_block_idx = k_word / Q8_BLOCK_SIZE;
      int32_t q8_half_offset = (k_word % Q8_BLOCK_SIZE) / 2;

#pragma unroll
      for (int32_t mm = 0; mm < M_MAX; mm++) {
        if (mm < M) {
          const Q8Block* qb =
              &q8[static_cast<int64_t>(mm) * n_q8_blocks + q8_block_idx];
          int32_t a_even =
              *reinterpret_cast<const int32_t*>(qb->qs_even + q8_half_offset);
          int32_t a_odd =
              *reinterpret_cast<const int32_t*>(qb->qs_odd + q8_half_offset);

          // Integer dot product of the 8 weights with the 8 activations.
          int32_t dp = __dp4a(vi_lo, a_even, 0);
          dp = __dp4a(vi_hi, a_odd, dp);

          // Sum of the same 8 activations, needed for the zero-point term.
          int32_t a_sum8 = __dp4a(0x01010101, a_even, 0);
          a_sum8 = __dp4a(0x01010101, a_odd, a_sum8);

          sum[mm] += ws * qb->d *
              (static_cast<float>(dp) - wz * static_cast<float>(a_sum8));
        }
      }
    }
  }

  // Butterfly reduction across the warp, one output row at a time. M is a
  // kernel argument (uniform across the warp), so the predicate never splits
  // the lanes participating in the shuffles.
#pragma unroll
  for (int32_t mm = 0; mm < M_MAX; mm++) {
    if (mm < M) {
      float s = sum[mm];
      for (int offset = MV_WARP_SIZE / 2; offset > 0; offset >>= 1)
        s += __shfl_xor_sync(0xffffffff, s, offset);
      if (lane_id == 0)
        out[static_cast<int64_t>(mm) * N + n] = __float2bfloat16(s);
    }
  }
}

// ---------------------------------------------------------------------------
// Persistent Q8 buffer (lazy init, not thread-safe — single-stream only)
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -263,16 +364,36 @@ void _int4_plain_mm_cuda(
q8_buf,
K);

// dp4a matvec
dim3 grid((N + MV_NWARPS - 1) / MV_NWARPS, M);
// M==1 (decode): per-row matvec. M>1 (e.g. speculative verify): weight-
// stationary GEMM that reads each weight once and accumulates into all M rows.
// Must be >= the Python dispatch's MATVEC_MAX_M (int4_dispatch.py): the export
// captures int4_plain_mm for M<=MATVEC_MAX_M, and those M reach this kernel.
constexpr int32_t GEMM_MAX_M = 8;
dim3 block(MV_WARP_SIZE, MV_NWARPS);
int4_w4a8_matvec_kernel<<<grid, block, 0, stream>>>(
reinterpret_cast<const uint8_t*>(qdata.data_ptr()),
reinterpret_cast<const __nv_bfloat16*>(scale.data_ptr()),
reinterpret_cast<const __nv_bfloat16*>(zero.data_ptr()),
q8_buf,
reinterpret_cast<__nv_bfloat16*>(output->data_ptr()),
N, K, gs_shift);
if (M == 1) {
dim3 grid((N + MV_NWARPS - 1) / MV_NWARPS, 1);
int4_w4a8_matvec_kernel<<<grid, block, 0, stream>>>(
reinterpret_cast<const uint8_t*>(qdata.data_ptr()),
reinterpret_cast<const __nv_bfloat16*>(scale.data_ptr()),
reinterpret_cast<const __nv_bfloat16*>(zero.data_ptr()),
q8_buf,
reinterpret_cast<__nv_bfloat16*>(output->data_ptr()),
N, K, gs_shift);
} else {
ET_CHECK_MSG(
M <= GEMM_MAX_M,
"int4_plain_mm GEMM kernel supports M<=%d, got M=%d",
GEMM_MAX_M,
M);
dim3 grid((N + MV_NWARPS - 1) / MV_NWARPS, 1);
int4_w4a8_gemm_kernel<GEMM_MAX_M><<<grid, block, 0, stream>>>(
reinterpret_cast<const uint8_t*>(qdata.data_ptr()),
reinterpret_cast<const __nv_bfloat16*>(scale.data_ptr()),
reinterpret_cast<const __nv_bfloat16*>(zero.data_ptr()),
q8_buf,
reinterpret_cast<__nv_bfloat16*>(output->data_ptr()),
N, K, gs_shift, M);
}
}

} // namespace executorch::backends::cuda
31 changes: 31 additions & 0 deletions backends/cuda/tests/test_int4_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,5 +208,36 @@ def test_21504x5376_prefill(self):
self._check(module(x), F.linear(x, w_ref))


class TestLimitConsistency(unittest.TestCase):
    """The Python export-side GEMM limit must match the C++ shim's GEMM_MAX_M.

    ``SHIM_GEMM_MAX_M`` is a hand-maintained Python copy of the C++ constant; if
    they drift, export can capture a shape the runtime shim rejects (or block one
    it supports). No CUDA needed -- this just compares the two source constants.
    """

    def test_shim_gemm_max_m_matches_cuh(self):
        """Parse ``GEMM_MAX_M`` out of the .cuh source and compare it to Python's copy."""
        import os
        import re

        import executorch.backends.cuda.int4_dispatch as int4_dispatch

        # The header lives next to the dispatch module, under runtime/shims/.
        cuh = os.path.join(
            os.path.dirname(int4_dispatch.__file__),
            "runtime",
            "shims",
            "int4_plain_mm.cuh",
        )
        # The header's comments contain non-ASCII characters, so decode it
        # explicitly as UTF-8 instead of relying on the locale's default
        # encoding (which can raise UnicodeDecodeError under an ASCII locale).
        with open(cuh, encoding="utf-8") as f:
            # \b keeps the pattern from matching inside longer identifiers
            # that merely end in GEMM_MAX_M.
            m = re.search(r"\bGEMM_MAX_M\s*=\s*(\d+)", f.read())
        self.assertIsNotNone(m, "GEMM_MAX_M not found in int4_plain_mm.cuh")
        self.assertEqual(
            int(m.group(1)),
            int4_dispatch.SHIM_GEMM_MAX_M,
            "SHIM_GEMM_MAX_M (int4_dispatch.py) is out of sync with the C++ "
            "GEMM_MAX_M (int4_plain_mm.cuh)",
        )


if __name__ == "__main__":
unittest.main()
Loading