diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cuh b/backends/cuda/runtime/shims/int4_plain_mm.cuh index 31214bc0bf6..db54da91687 100644 --- a/backends/cuda/runtime/shims/int4_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int4_plain_mm.cuh @@ -55,7 +55,11 @@ __host__ __forceinline__ int32_t log2_pow2(int32_t v) { // blocks) // --------------------------------------------------------------------------- -struct Q8Block { +// alignas(16) pads sizeof(Q8Block) to 48 so each block (and its qs_even/qs_odd +// 16-byte halves) is 16-byte aligned. This lets the matvec load a whole block's +// int8 activations with two vectorized uint4 loads instead of eight scalar +// int32 loads, cutting activation load instructions ~4x. +struct alignas(16) Q8Block { int8_t qs_even[Q8_BLOCK_SIZE / 2]; int8_t qs_odd[Q8_BLOCK_SIZE / 2]; float d; // scale @@ -149,6 +153,18 @@ __global__ void __launch_bounds__(MV_THREADS) int32_t k_base = i * 32; uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; + // One uint4 (32 weights) maps to exactly one Q8 activation block (32 + // activations), i.e. q8_block_idx == i. Load the whole block with two + // vectorized uint4 loads (+ one scale load) instead of eight scalar int32 + // loads. ae.{x,y,z,w} == qs_even[0:4],[4:8],[8:12],[12:16] == a_even for + // w=0..3 (same for ao/qs_odd) -> bit-identical to the scalar path. + const Q8Block* qb = &q8_row[i]; + uint4 ae = *reinterpret_cast(qb->qs_even); + uint4 ao = *reinterpret_cast(qb->qs_odd); + float a_scale = qb->d; + const uint32_t a_even[4] = {ae.x, ae.y, ae.z, ae.w}; + const uint32_t a_odd[4] = {ao.x, ao.y, ao.z, ao.w}; + #pragma unroll for (int32_t w = 0; w < 4; w++) { uint32_t packed = words[w]; @@ -164,22 +180,11 @@ __global__ void __launch_bounds__(MV_THREADS) int32_t vi_lo = packed & 0x0F0F0F0F; int32_t vi_hi = (packed >> 4) & 0x0F0F0F0F; - int32_t q8_block_idx = k_word / Q8_BLOCK_SIZE; - int32_t q8_half_offset = (k_word % Q8_BLOCK_SIZE) / 2; - const Q8Block* qb = &q8_row[q8_block_idx]; - - int32_t a_even = - *reinterpret_cast(qb->qs_even + q8_half_offset); - int32_t a_odd = - *reinterpret_cast(qb->qs_odd + q8_half_offset); - - int32_t dp = __dp4a(vi_lo, a_even, 0); - dp = __dp4a(vi_hi, a_odd, dp); - - float a_scale = qb->d; + int32_t dp = __dp4a(vi_lo, static_cast(a_even[w]), 0); + dp = __dp4a(vi_hi, static_cast(a_odd[w]), dp); - int32_t a_sum8 = __dp4a(0x01010101, a_even, 0); - a_sum8 = __dp4a(0x01010101, a_odd, a_sum8); + int32_t a_sum8 = __dp4a(0x01010101, static_cast(a_even[w]), 0); + a_sum8 = __dp4a(0x01010101, static_cast(a_odd[w]), a_sum8); sum += ws * a_scale * (static_cast(dp) - wz * static_cast(a_sum8)); diff --git a/backends/cuda/runtime/shims/int8_plain_mm.cuh b/backends/cuda/runtime/shims/int8_plain_mm.cuh index 2c478854644..8458c7680b5 100644 --- a/backends/cuda/runtime/shims/int8_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int8_plain_mm.cuh @@ -58,7 +58,11 @@ __host__ __forceinline__ int32_t log2_pow2_i8(int32_t v) { // blocks, NATURAL order — qs[k] holds the quantized value for element k). // --------------------------------------------------------------------------- -struct Q8BlockNat { +// alignas(16) pads sizeof(Q8BlockNat) 36->48 so each block (and its two 16-byte +// qs halves) is 16-byte aligned. This lets the matvec load 16 int8 activations +// with one vectorized uint4 load instead of four scalar int32 loads, cutting +// activation load instructions ~4x. +struct alignas(16) Q8BlockNat { int8_t qs[Q8_NAT_BLOCK_SIZE]; float d; // scale }; @@ -135,6 +139,17 @@ __global__ void __launch_bounds__(MV8_THREADS) int8_w8a8_matvec_kernel( int32_t k_base = i * 16; uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; + // One uint4 (16 int8 weights) maps to exactly one 16-byte half of a Q8 + // activation block (16 activations): block i>>1, byte offset 0 (i even) or + // 16 (i odd). Load those 16 int8 activations with a single vectorized uint4 + // load (+ one scale load) instead of four scalar int32 loads + four scale + // reloads. av.{x,y,z,w} == qs[off+0:4],[4:8],[8:12],[12:16] == a_word for + // w=0..3 -> bit-identical to the scalar path. + const Q8BlockNat* qb = &q8_row[i >> 1]; + uint4 av = *reinterpret_cast(qb->qs + ((i & 1) ? 16 : 0)); + float a_scale = qb->d; + const uint32_t a_words[4] = {av.x, av.y, av.z, av.w}; + #pragma unroll for (int32_t w = 0; w < 4; w++) { int32_t k_word = k_base + w * 4; // 4 int8 weights start here @@ -147,15 +162,10 @@ __global__ void __launch_bounds__(MV8_THREADS) int8_w8a8_matvec_kernel( } int32_t w_word = static_cast(words[w]); - - int32_t q8_block_idx = k_word / Q8_NAT_BLOCK_SIZE; - int32_t q8_offset = k_word % Q8_NAT_BLOCK_SIZE; - const Q8BlockNat* qb = &q8_row[q8_block_idx]; - int32_t a_word = *reinterpret_cast(qb->qs + q8_offset); + int32_t a_word = static_cast(a_words[w]); int32_t dp = __dp4a(w_word, a_word, 0); int32_t a_sum = __dp4a(0x01010101, a_word, 0); - float a_scale = qb->d; sum += ws * a_scale * (static_cast(dp) - wz * static_cast(a_sum));