Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions backends/cuda/runtime/shims/int4_plain_mm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ __host__ __forceinline__ int32_t log2_pow2(int32_t v) {
// blocks)
// ---------------------------------------------------------------------------

struct Q8Block {
// alignas(16) pads sizeof(Q8Block) from 36 to 48 bytes so that every block in
// an array (and each of its 16-byte qs_even/qs_odd halves) is 16-byte aligned,
// provided the array base is itself 16-byte aligned. This lets the matvec load
// a whole block's int8 activations with two vectorized uint4 loads instead of
// eight scalar int32 loads, cutting activation load instructions ~4x.
struct alignas(16) Q8Block {
int8_t qs_even[Q8_BLOCK_SIZE / 2];
int8_t qs_odd[Q8_BLOCK_SIZE / 2];
float d; // scale
Expand Down Expand Up @@ -149,6 +153,18 @@ __global__ void __launch_bounds__(MV_THREADS)
int32_t k_base = i * 32;
uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w};

// One uint4 (32 weights) maps to exactly one Q8 activation block (32
// activations), i.e. q8_block_idx == i. Load the whole block with two
// vectorized uint4 loads (+ one scale load) instead of eight scalar int32
// loads. ae.{x,y,z,w} == qs_even[0:4],[4:8],[8:12],[12:16] == a_even for
// w=0..3 (same for ao/qs_odd) -> bit-identical to the scalar path.
const Q8Block* qb = &q8_row[i];
uint4 ae = *reinterpret_cast<const uint4*>(qb->qs_even);
uint4 ao = *reinterpret_cast<const uint4*>(qb->qs_odd);
float a_scale = qb->d;
const uint32_t a_even[4] = {ae.x, ae.y, ae.z, ae.w};
const uint32_t a_odd[4] = {ao.x, ao.y, ao.z, ao.w};

#pragma unroll
for (int32_t w = 0; w < 4; w++) {
uint32_t packed = words[w];
Expand All @@ -164,22 +180,11 @@ __global__ void __launch_bounds__(MV_THREADS)
int32_t vi_lo = packed & 0x0F0F0F0F;
int32_t vi_hi = (packed >> 4) & 0x0F0F0F0F;

int32_t q8_block_idx = k_word / Q8_BLOCK_SIZE;
int32_t q8_half_offset = (k_word % Q8_BLOCK_SIZE) / 2;
const Q8Block* qb = &q8_row[q8_block_idx];

int32_t a_even =
*reinterpret_cast<const int32_t*>(qb->qs_even + q8_half_offset);
int32_t a_odd =
*reinterpret_cast<const int32_t*>(qb->qs_odd + q8_half_offset);

int32_t dp = __dp4a(vi_lo, a_even, 0);
dp = __dp4a(vi_hi, a_odd, dp);

float a_scale = qb->d;
int32_t dp = __dp4a(vi_lo, static_cast<int32_t>(a_even[w]), 0);
dp = __dp4a(vi_hi, static_cast<int32_t>(a_odd[w]), dp);

int32_t a_sum8 = __dp4a(0x01010101, a_even, 0);
a_sum8 = __dp4a(0x01010101, a_odd, a_sum8);
int32_t a_sum8 = __dp4a(0x01010101, static_cast<int32_t>(a_even[w]), 0);
a_sum8 = __dp4a(0x01010101, static_cast<int32_t>(a_odd[w]), a_sum8);

sum += ws * a_scale *
(static_cast<float>(dp) - wz * static_cast<float>(a_sum8));
Expand Down
24 changes: 17 additions & 7 deletions backends/cuda/runtime/shims/int8_plain_mm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ __host__ __forceinline__ int32_t log2_pow2_i8(int32_t v) {
// blocks, NATURAL order — qs[k] holds the quantized value for element k).
// ---------------------------------------------------------------------------

struct Q8BlockNat {
// alignas(16) pads sizeof(Q8BlockNat) from 36 to 48 bytes so that every block
// in an array (and each of its two 16-byte qs halves) is 16-byte aligned,
// provided the array base is itself 16-byte aligned. This lets the matvec load
// 16 int8 activations with one vectorized uint4 load instead of four scalar
// int32 loads, cutting activation load instructions ~4x.
struct alignas(16) Q8BlockNat {
// Quantized int8 activations in natural order: qs[k] is the value for
// element k of the block. The matvec reads this array 16 bytes at a time
// through a uint4 pointer, at byte offset 0 or 16.
int8_t qs[Q8_NAT_BLOCK_SIZE];
// Scale for the block; the matvec reads it as a_scale and multiplies it into
// each partial sum.
float d; // scale
};
Expand Down Expand Up @@ -135,6 +139,17 @@ __global__ void __launch_bounds__(MV8_THREADS) int8_w8a8_matvec_kernel(
int32_t k_base = i * 16;
uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w};

// One uint4 (16 int8 weights) maps to exactly one 16-byte half of a Q8
// activation block (16 activations): block i>>1, byte offset 0 (i even) or
// 16 (i odd). Load those 16 int8 activations with a single vectorized uint4
// load (+ one scale load) instead of four scalar int32 loads + four scale
// reloads. av.{x,y,z,w} == qs[off+0:4],[4:8],[8:12],[12:16] == a_word for
// w=0..3 -> bit-identical to the scalar path.
const Q8BlockNat* qb = &q8_row[i >> 1];
uint4 av = *reinterpret_cast<const uint4*>(qb->qs + ((i & 1) ? 16 : 0));
float a_scale = qb->d;
const uint32_t a_words[4] = {av.x, av.y, av.z, av.w};

#pragma unroll
for (int32_t w = 0; w < 4; w++) {
int32_t k_word = k_base + w * 4; // 4 int8 weights start here
Expand All @@ -147,15 +162,10 @@ __global__ void __launch_bounds__(MV8_THREADS) int8_w8a8_matvec_kernel(
}

int32_t w_word = static_cast<int32_t>(words[w]);

int32_t q8_block_idx = k_word / Q8_NAT_BLOCK_SIZE;
int32_t q8_offset = k_word % Q8_NAT_BLOCK_SIZE;
const Q8BlockNat* qb = &q8_row[q8_block_idx];
int32_t a_word = *reinterpret_cast<const int32_t*>(qb->qs + q8_offset);
int32_t a_word = static_cast<int32_t>(a_words[w]);

int32_t dp = __dp4a(w_word, a_word, 0);
int32_t a_sum = __dp4a(0x01010101, a_word, 0);
float a_scale = qb->d;

sum += ws * a_scale *
(static_cast<float>(dp) - wz * static_cast<float>(a_sum));
Expand Down
Loading