Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 145 additions & 57 deletions ggml/src/ggml-cpu/arch/arm/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@
// precomputed tables for expanding 8bits to 8 bytes:
static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4

#if defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
// Direct -1/+1 expansion for q1_0 dot products (DOTPROD path)
static const uint64_t table_q1_signs[256] = { B8(ff, 01) };
#endif
#if !defined(__ARM_FEATURE_DOTPROD)
// Sign mask expansion for q1_0 dot products (plain NEON path)
static const uint64_t table_q1_mask[256] = { B8(ff, 00) };
#endif
#endif

void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
Expand Down Expand Up @@ -138,11 +147,15 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
//===================================== Dot products =================================

void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK1_0; // 128
const int qk = QK1_0;
const int nb = n / qk;

assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
assert((nrc == 2) || (nrc == 1));
#else
assert(nrc == 1);
#endif
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
Expand All @@ -151,66 +164,141 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
const block_q1_0 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__ARM_NEON)
float32x4_t sumv = vdupq_n_f32(0.0f);
#if defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q1_0 * GGML_RESTRICT vx0 = vx;
const block_q1_0 * GGML_RESTRICT vx1 = (const block_q1_0 *) ((const uint8_t *)vx + bx);
const block_q8_0 * GGML_RESTRICT vy0 = vy;
const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t *)vy + by);

for (int i = 0; i < nb; i++) {
const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d);

// Process 4 Q8_0 blocks (each has 32 elements)
for (int k = 0; k < 4; k++) {
const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k];
const float d1 = GGML_CPU_FP16_TO_FP32(yb->d);

// Get the 4 bytes of bits for this Q8_0 block (32 bits = 4 bytes)
// Bits are at offset k*4 bytes in x[i].qs
const uint8_t * bits = &x[i].qs[k * 4];

// Load 32 int8 values from y
const int8x16_t y0 = vld1q_s8(yb->qs);
const int8x16_t y1 = vld1q_s8(yb->qs + 16);

// Byte 0-1: bits for y0[0..15]
const uint64_t expand0 = table_b2b_0[bits[0]];
const uint64_t expand1 = table_b2b_0[bits[1]];
// Byte 2-3: bits for y1[0..15]
const uint64_t expand2 = table_b2b_0[bits[2]];
const uint64_t expand3 = table_b2b_0[bits[3]];

// Build the sign vectors by reinterpreting the table values
uint8x8_t e0 = vcreate_u8(expand0);
uint8x8_t e1 = vcreate_u8(expand1);
uint8x8_t e2 = vcreate_u8(expand2);
uint8x8_t e3 = vcreate_u8(expand3);

// Shift right by 4 to get 0 or 1
int8x8_t s0 = vreinterpret_s8_u8(vshr_n_u8(e0, 4));
int8x8_t s1 = vreinterpret_s8_u8(vshr_n_u8(e1, 4));
int8x8_t s2 = vreinterpret_s8_u8(vshr_n_u8(e2, 4));
int8x8_t s3 = vreinterpret_s8_u8(vshr_n_u8(e3, 4));

// Convert 0/1 to -1/+1: sign = 2*val - 1
int8x8_t one = vdup_n_s8(1);
s0 = vsub_s8(vadd_s8(s0, s0), one); // 2*s0 - 1
s1 = vsub_s8(vadd_s8(s1, s1), one);
s2 = vsub_s8(vadd_s8(s2, s2), one);
s3 = vsub_s8(vadd_s8(s3, s3), one);

// Combine into 16-element vectors
int8x16_t signs0 = vcombine_s8(s0, s1);
int8x16_t signs1 = vcombine_s8(s2, s3);

// Multiply signs with y values and accumulate
// dot(signs, y) where signs are +1/-1
int32x4_t p0 = ggml_vdotq_s32(vdupq_n_s32(0), signs0, y0);
int32x4_t p1 = ggml_vdotq_s32(p0, signs1, y1);

// Scale by d1 and accumulate
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(p1), d0 * d1);
float32x4_t sumv0 = vdupq_n_f32(0.0f);

for (int i = 0; i < nb; i++) {
const uint8_t * GGML_RESTRICT bits0 = vx0[i].qs;
const uint8_t * GGML_RESTRICT bits1 = vx1[i].qs;
const float dx0 = GGML_CPU_FP16_TO_FP32(vx0[i].d);
const float dx1 = GGML_CPU_FP16_TO_FP32(vx1[i].d);

float32x4_t accv = vdupq_n_f32(0.0f);

for (int k = 0; k < 4; k++) {
const block_q8_0 * GGML_RESTRICT yb0 = &vy0[i * 4 + k];
const block_q8_0 * GGML_RESTRICT yb1 = &vy1[i * 4 + k];

const int8x16_t y0_0 = vld1q_s8(yb0->qs);
const int8x16_t y0_1 = vld1q_s8(yb0->qs + 16);
const int8x16_t y1_0 = vld1q_s8(yb1->qs);
const int8x16_t y1_1 = vld1q_s8(yb1->qs + 16);

const uint8_t * GGML_RESTRICT b0 = bits0 + 4 * k;
const uint8_t * GGML_RESTRICT b1 = bits1 + 4 * k;

const int8x16_t l0 = vcombine_s8(vreinterpret_s8_u8(vcreate_u8(table_q1_signs[b0[0]])),
vreinterpret_s8_u8(vcreate_u8(table_q1_signs[b1[0]])));
const int8x16_t l1 = vcombine_s8(vreinterpret_s8_u8(vcreate_u8(table_q1_signs[b0[1]])),
vreinterpret_s8_u8(vcreate_u8(table_q1_signs[b1[1]])));
const int8x16_t l2 = vcombine_s8(vreinterpret_s8_u8(vcreate_u8(table_q1_signs[b0[2]])),
vreinterpret_s8_u8(vcreate_u8(table_q1_signs[b1[2]])));
const int8x16_t l3 = vcombine_s8(vreinterpret_s8_u8(vcreate_u8(table_q1_signs[b0[3]])),
vreinterpret_s8_u8(vcreate_u8(table_q1_signs[b1[3]])));

const int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_0), vreinterpretq_s64_s8(y1_0)));
const int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_0), vreinterpretq_s64_s8(y1_0)));
const int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_1), vreinterpretq_s64_s8(y1_1)));
const int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_1), vreinterpretq_s64_s8(y1_1)));

int32x4_t p = vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), l0, r0), l1, r1), l2, r2), l3, r3);

const float dy0 = GGML_CPU_FP16_TO_FP32(yb0->d);
const float dy1 = GGML_CPU_FP16_TO_FP32(yb1->d);
const float32x4_t scale_y = vcombine_f32(vset_lane_f32(dy1, vdup_n_f32(dy0), 1),
vset_lane_f32(dy1, vdup_n_f32(dy0), 1));
accv = vmlaq_f32(accv, vcvtq_f32_s32(p), scale_y);
}

const float32x4_t scale_x = vcombine_f32(vdup_n_f32(dx0), vdup_n_f32(dx1));
sumv0 = vmlaq_f32(sumv0, accv, scale_x);
}

float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

vst1_f32(s, vget_low_f32(sumv2));
vst1_f32(s + bs, vget_high_f32(sumv2));

return;
}
#endif

#if defined(__ARM_FEATURE_DOTPROD)
{
float32x4_t sumv = vdupq_n_f32(0.0f);

*s = vaddvq_f32(sumv);
for (int i = 0; i < nb; i++) {
const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d);
float32x4_t accv = vdupq_n_f32(0.0f);

for (int k = 0; k < 4; k++) {
const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k];
const uint8_t * GGML_RESTRICT bits = &x[i].qs[k * 4];
const float d1 = GGML_CPU_FP16_TO_FP32(yb->d);

const int8x16_t y0 = vld1q_s8(yb->qs);
const int8x16_t y1 = vld1q_s8(yb->qs + 16);

const int8x16_t signs0 = vcombine_s8(vreinterpret_s8_u8(vcreate_u8(table_q1_signs[bits[0]])),
vreinterpret_s8_u8(vcreate_u8(table_q1_signs[bits[1]])));
const int8x16_t signs1 = vcombine_s8(vreinterpret_s8_u8(vcreate_u8(table_q1_signs[bits[2]])),
vreinterpret_s8_u8(vcreate_u8(table_q1_signs[bits[3]])));

int32x4_t p = vdupq_n_s32(0);
p = ggml_vdotq_s32(p, signs0, y0);
p = ggml_vdotq_s32(p, signs1, y1);

accv = vmlaq_n_f32(accv, vcvtq_f32_s32(p), d1);
}

sumv = vmlaq_n_f32(sumv, accv, d0);
}

*s = vaddvq_f32(sumv);
}
#elif defined(__ARM_NEON)
{
float32x4_t sumv = vdupq_n_f32(0.0f);

for (int i = 0; i < nb; i++) {
const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d);
float32x4_t accv = vdupq_n_f32(0.0f);

for (int k = 0; k < 4; k++) {
const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k];
const uint8_t * GGML_RESTRICT bits = &x[i].qs[k * 4];
const float d1 = GGML_CPU_FP16_TO_FP32(yb->d);

const int8x16_t y0 = vld1q_s8(yb->qs);
const int8x16_t y1 = vld1q_s8(yb->qs + 16);

const int8x16_t sm0 = vreinterpretq_s8_u8(vcombine_u8(vcreate_u8(table_q1_mask[bits[0]]),
vcreate_u8(table_q1_mask[bits[1]])));
const int8x16_t sm1 = vreinterpretq_s8_u8(vcombine_u8(vcreate_u8(table_q1_mask[bits[2]]),
vcreate_u8(table_q1_mask[bits[3]])));

const int8x16_t sy0 = vsubq_s8(veorq_s8(y0, sm0), sm0);
const int8x16_t sy1 = vsubq_s8(veorq_s8(y1, sm1), sm1);

int32x4_t p = vdupq_n_s32(0);
p = vpadalq_s16(p, vpaddlq_s8(sy0));
p = vpadalq_s16(p, vpaddlq_s8(sy1));

accv = vmlaq_n_f32(accv, vcvtq_f32_s32(p), d1);
}

sumv = vmlaq_n_f32(sumv, accv, d0);
}

*s = vaddvq_f32(sumv);
}
#else
UNUSED(nb);
UNUSED(x);
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q1_0,
.vec_dot = ggml_vec_dot_q1_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
.nrows = 1,
#endif
},
[GGML_TYPE_Q4_0] = {
.from_float = quantize_row_q4_0,
Expand Down