diff --git a/backends/cuda/int4_dispatch.py b/backends/cuda/int4_dispatch.py index d8bcb1acbd0..506b85768ef 100644 --- a/backends/cuda/int4_dispatch.py +++ b/backends/cuda/int4_dispatch.py @@ -18,8 +18,18 @@ dequant + cuBLAS matmul kernels. Dispatch strategy (determines what gets captured in the export graph): - Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm`` - Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops) + Small M (M<=MATVEC_MAX_M): Custom op ``executorch_cuda::int4_plain_mm`` + Large M (M>MATVEC_MAX_M): Inline dequant + F.linear (standard PyTorch ops) + +The custom op is memory-bound and beats dequant+cuBLAS for small M (M==1 matvec; +2<=M<=8 weight-stationary GEMM). ``MATVEC_MAX_M`` defaults to 4 (decode only). +An export may raise it up to the shim's GEMM limit (``GEMM_MAX_M`` = 8 in +``int4_plain_mm.cuh``), but then its *dynamic* shapes must not straddle the +threshold: a dynamic linear whose M range crosses MATVEC_MAX_M makes +torch.export's branch guard ambiguous, so a long-prefill export must declare +``min > MATVEC_MAX_M``. Raising the global default would break exports whose +dynamic prefill range starts below the threshold, so callers set it locally +instead. Import this module before using nn.Linear with Int4Tensor weights:: @@ -35,6 +45,16 @@ # Custom op for decode (M=1): dp4a matvec in C shim, dequant+F.linear in eager # --------------------------------------------------------------------------- +# Largest M the C++ shim's GEMM kernel handles (GEMM_MAX_M in int4_plain_mm.cuh). +# MATVEC_MAX_M must not exceed it, else export captures a shape the shim rejects +# at runtime; the dispatch asserts this below. +SHIM_GEMM_MAX_M = 8 + +# Max M routed to the custom INT4 op; above this, dequant+cuBLAS wins. Defaults +# to 4 (decode); an export may raise it (<= SHIM_GEMM_MAX_M) for small-M GEMM, +# subject to the dynamic-shape constraint documented above. +MATVEC_MAX_M = 4 + _lib = Library("executorch_cuda", "DEF") _lib.define( "int4_plain_mm(Tensor self, Tensor qdata, Tensor scale, Tensor zero, int group_size) -> Tensor" @@ -98,7 +118,11 @@ def _(func, types, args, kwargs): gs = weight_tensor.block_size[-1] M = x_2d.shape[0] - if M <= 4: + if M <= MATVEC_MAX_M: + assert MATVEC_MAX_M <= SHIM_GEMM_MAX_M, ( + f"MATVEC_MAX_M={MATVEC_MAX_M} exceeds the shim's GEMM_MAX_M=" + f"{SHIM_GEMM_MAX_M} (int4_plain_mm.cuh)" + ) out = torch.ops.executorch_cuda.int4_plain_mm(x_2d, qdata, scale, zero, gs) else: out = _dequant_matmul(x_2d, qdata, scale, zero, gs) diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cuh b/backends/cuda/runtime/shims/int4_plain_mm.cuh index ea236e8d069..d0815d7b481 100644 --- a/backends/cuda/runtime/shims/int4_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int4_plain_mm.cuh @@ -182,6 +182,107 @@ __global__ void __launch_bounds__(MV_THREADS) out[static_cast(m) * N + n] = __float2bfloat16(sum); } +// --------------------------------------------------------------------------- +// W4A8 dp4a GEMM kernel — weight-stationary for small M (2..M_MAX) +// +// The matvec above launches one block-row per M (grid.y = M) and re-reads the +// packed weights for every activation row, so its weight traffic scales with M. +// For speculative verification (M = chain_len+1, a handful of rows) that makes a +// verify cost ~M decodes. This kernel instead loads each weight chunk once and +// accumulates it into all M output rows (grid.y = 1), so weight traffic is 1x +// regardless of M — turning a verify back into ~one decode. Activations (tiny +// INT8) are the only thing read per row. +// --------------------------------------------------------------------------- + +template +__global__ void __launch_bounds__(MV_THREADS) + int4_w4a8_gemm_kernel( + const uint8_t* __restrict__ qdata, + const __nv_bfloat16* __restrict__ w_scale, + const __nv_bfloat16* __restrict__ w_zero, + const Q8Block* __restrict__ q8, + __nv_bfloat16* __restrict__ out, + int32_t N, + int32_t K, + int32_t gs_shift, + int32_t M) { + const int32_t n = blockIdx.x * MV_NWARPS + threadIdx.y; + if (n >= N) + return; + + const int32_t K_half = K / 2; + const int32_t lane_id = threadIdx.x; + const int32_t n_q8_blocks = K / Q8_BLOCK_SIZE; + + const uint8_t* qrow = qdata + static_cast(n) * K_half; + const __nv_bfloat16* scale_base = w_scale + n; + const __nv_bfloat16* zero_base = w_zero + n; + const int32_t scale_stride = N; + + const uint4* qrow16 = reinterpret_cast(qrow); + const int32_t K_half_16 = K_half / 16; + + float sum[M_MAX]; +#pragma unroll + for (int32_t mm = 0; mm < M_MAX; mm++) + sum[mm] = 0.0f; + + int32_t prev_g = -1; + float ws = 0.0f, wz = 0.0f; + + for (int32_t i = lane_id; i < K_half_16; i += MV_WARP_SIZE) { + uint4 packed16 = __ldg(&qrow16[i]); + int32_t k_base = i * 32; + uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; + +#pragma unroll + for (int32_t w = 0; w < 4; w++) { + uint32_t packed = words[w]; + int32_t k_word = k_base + w * 8; + int32_t g = k_word >> gs_shift; + + if (g != prev_g) { + ws = __bfloat162float(__ldg(&scale_base[g * scale_stride])); + wz = __bfloat162float(__ldg(&zero_base[g * scale_stride])); + prev_g = g; + } + + // Weight nibbles loaded once and reused across all M activation rows. + int32_t vi_lo = packed & 0x0F0F0F0F; + int32_t vi_hi = (packed >> 4) & 0x0F0F0F0F; + + int32_t q8_block_idx = k_word / Q8_BLOCK_SIZE; + int32_t q8_half_offset = (k_word % Q8_BLOCK_SIZE) / 2; + + for (int32_t mm = 0; mm < M; mm++) { + const Q8Block* qb = + &q8[static_cast(mm) * n_q8_blocks + q8_block_idx]; + int32_t a_even = + *reinterpret_cast(qb->qs_even + q8_half_offset); + int32_t a_odd = + *reinterpret_cast(qb->qs_odd + q8_half_offset); + + int32_t dp = __dp4a(vi_lo, a_even, 0); + dp = __dp4a(vi_hi, a_odd, dp); + + int32_t a_sum8 = __dp4a(0x01010101, a_even, 0); + a_sum8 = __dp4a(0x01010101, a_odd, a_sum8); + + sum[mm] += ws * qb->d * + (static_cast(dp) - wz * static_cast(a_sum8)); + } + } + } + + for (int32_t mm = 0; mm < M; mm++) { + float s = sum[mm]; + for (int offset = MV_WARP_SIZE / 2; offset > 0; offset >>= 1) + s += __shfl_xor_sync(0xffffffff, s, offset); + if (lane_id == 0) + out[static_cast(mm) * N + n] = __float2bfloat16(s); + } +} + // --------------------------------------------------------------------------- // Persistent Q8 buffer (lazy init, not thread-safe — single-stream only) // --------------------------------------------------------------------------- @@ -263,16 +364,36 @@ void _int4_plain_mm_cuda( q8_buf, K); - // dp4a matvec - dim3 grid((N + MV_NWARPS - 1) / MV_NWARPS, M); + // M==1 (decode): per-row matvec. M>1 (e.g. speculative verify): weight- + // stationary GEMM that reads each weight once and accumulates into all M rows. + // Must be >= the Python dispatch's MATVEC_MAX_M (int4_dispatch.py): the export + // captures int4_plain_mm for M<=MATVEC_MAX_M, and those M reach this kernel. + constexpr int32_t GEMM_MAX_M = 8; dim3 block(MV_WARP_SIZE, MV_NWARPS); - int4_w4a8_matvec_kernel<<>>( - reinterpret_cast(qdata.data_ptr()), - reinterpret_cast(scale.data_ptr()), - reinterpret_cast(zero.data_ptr()), - q8_buf, - reinterpret_cast<__nv_bfloat16*>(output->data_ptr()), - N, K, gs_shift); + if (M == 1) { + dim3 grid((N + MV_NWARPS - 1) / MV_NWARPS, 1); + int4_w4a8_matvec_kernel<<>>( + reinterpret_cast(qdata.data_ptr()), + reinterpret_cast(scale.data_ptr()), + reinterpret_cast(zero.data_ptr()), + q8_buf, + reinterpret_cast<__nv_bfloat16*>(output->data_ptr()), + N, K, gs_shift); + } else { + ET_CHECK_MSG( + M <= GEMM_MAX_M, + "int4_plain_mm GEMM kernel supports M<=%d, got M=%d", + GEMM_MAX_M, + M); + dim3 grid((N + MV_NWARPS - 1) / MV_NWARPS, 1); + int4_w4a8_gemm_kernel<<>>( + reinterpret_cast(qdata.data_ptr()), + reinterpret_cast(scale.data_ptr()), + reinterpret_cast(zero.data_ptr()), + q8_buf, + reinterpret_cast<__nv_bfloat16*>(output->data_ptr()), + N, K, gs_shift, M); + } } } // namespace executorch::backends::cuda diff --git a/backends/cuda/tests/test_int4_dispatch.py b/backends/cuda/tests/test_int4_dispatch.py index c793544ad48..d760109b626 100644 --- a/backends/cuda/tests/test_int4_dispatch.py +++ b/backends/cuda/tests/test_int4_dispatch.py @@ -208,5 +208,36 @@ def test_21504x5376_prefill(self): self._check(module(x), F.linear(x, w_ref)) +class TestLimitConsistency(unittest.TestCase): + """The Python export-side GEMM limit must match the C++ shim's GEMM_MAX_M. + + ``SHIM_GEMM_MAX_M`` is a hand-maintained Python copy of the C++ constant; if + they drift, export can capture a shape the runtime shim rejects (or block one + it supports). No CUDA needed -- this just compares the two source constants. + """ + + def test_shim_gemm_max_m_matches_cuh(self): + import os + import re + + import executorch.backends.cuda.int4_dispatch as int4_dispatch + + cuh = os.path.join( + os.path.dirname(int4_dispatch.__file__), + "runtime", + "shims", + "int4_plain_mm.cuh", + ) + with open(cuh) as f: + m = re.search(r"\bGEMM_MAX_M\s*=\s*(\d+)", f.read()) + self.assertIsNotNone(m, "GEMM_MAX_M not found in int4_plain_mm.cuh") + self.assertEqual( + int(m.group(1)), + int4_dispatch.SHIM_GEMM_MAX_M, + "SHIM_GEMM_MAX_M (int4_dispatch.py) is out of sync with the C++ " + "GEMM_MAX_M (int4_plain_mm.cuh)", + ) + + if __name__ == "__main__": unittest.main()