From eda40b8b50d6492364d015767386ae986d1ee743 Mon Sep 17 00:00:00 2001 From: uddeshsingh Date: Thu, 11 Jun 2026 23:12:53 -0500 Subject: [PATCH 1/5] Add fused Q4_K Metal kernels for GGUF lowering (#20172) Replace the export-time GGUF-to-MLX qparam repack path with fused Metal kernels --- .../mlx/custom_kernel_ops/gguf/patterns.py | 14 +- .../custom_kernel_ops/gguf/q4k/__init__.py | 2 +- .../mlx/custom_kernel_ops/gguf/q4k/common.py | 150 ++++- .../custom_kernel_ops/gguf/q4k/embedding.py | 116 +++- .../mlx/custom_kernel_ops/gguf/q4k/linear.py | 535 ++++++++++++++++-- .../gguf/test/test_embedding.py | 38 +- .../gguf/test/test_linear.py | 56 +- 7 files changed, 768 insertions(+), 143 deletions(-) diff --git a/backends/mlx/custom_kernel_ops/gguf/patterns.py b/backends/mlx/custom_kernel_ops/gguf/patterns.py index 7d3a5bc307c..e3a9fea97b8 100644 --- a/backends/mlx/custom_kernel_ops/gguf/patterns.py +++ b/backends/mlx/custom_kernel_ops/gguf/patterns.py @@ -18,8 +18,7 @@ lower it without materializing the dequantized weight: * **Q6_K** -> fused custom Metal kernels in :mod:`.q6k`. -* **Q4_K** -> MLX's native 4-bit affine ops via :mod:`.q4k` (GGUF blocks - repacked into MLX qparams at export time). +* **Q4_K** -> fused custom Metal kernels in :mod:`.q4k`. Both cover linear and embedding. @@ -42,8 +41,7 @@ from torch.export.exported_program import ExportedProgram from torch.fx.node import Node -# Quant types each pattern can lower (Q6_K via custom Metal kernels, Q4_K via -# MLX-native affine ops). +# Quant types each pattern can lower (both via fused custom Metal kernels). _LINEAR_TYPES = {"q4_k", "q6_k"} _EMBEDDING_TYPES = {"q4_k", "q6_k"} @@ -79,8 +77,8 @@ class GGUFQuantizedLinearHandler(PatternHandler): """Lower ``dequantize_gguf + linear`` to a fused quantized matmul. Matches ``linear(x, dequantize_gguf(weight, ggml_type, out_dtype), bias)`` - and dispatches on ``ggml_type``: Q6_K -> custom Metal kernels, Q4_K -> MLX - 4-bit ``quantized_matmul``. 
+ and dispatches on ``ggml_type``: Q6_K / Q4_K -> custom Metal kernels in + :mod:`.q6k` / :mod:`.q4k`. """ def __init__(self, head, body, weight, ggml_type, output_dtype): @@ -126,8 +124,8 @@ class GGUFQuantizedEmbeddingHandler(PatternHandler): """Lower ``dequantize_gguf + embedding`` to a quantized gather. Matches ``embedding(dequantize_gguf(weight, ggml_type, out_dtype), indices)`` - and dispatches on ``ggml_type``: Q6_K -> custom Metal gather, Q4_K -> MLX - quantized gather. + and dispatches on ``ggml_type``: Q6_K / Q4_K -> custom Metal gather kernels + in :mod:`.q6k` / :mod:`.q4k`. """ def __init__(self, head, body, weight, ggml_type, output_dtype): diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py b/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py index 6f89cfe2c82..30fbcf96118 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. # -"""GGUF Q4_K format lowering for the MLX backend (native affine 4-bit). +"""GGUF Q4_K format lowering for the MLX backend (fused Metal kernels). See :mod:`.linear` / :mod:`.embedding` for the ``emit_*`` lowerings (called by ``custom_kernel_ops.gguf.patterns``); they are not imported here to keep the diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/common.py b/backends/mlx/custom_kernel_ops/gguf/q4k/common.py index d58a8b71afd..4a234359360 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/common.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/common.py @@ -6,41 +6,139 @@ # LICENSE file in the root directory of this source tree. # -"""Shared Q4_K -> MLX qparam repack for the Q4_K lowering. +"""Shared GGUF **Q4_K** primitives for the MLX backend. 
-Q4_K maps cleanly onto MLX's affine 4-bit kernels (group_size 32): the GGUF -blocks are unpacked to the torchao ``IntxUnpackedToInt8Tensor`` layout and -repacked into MLX qparams (``S * Q + B``) at export time, so the weight is -stored MLX-ready and decoded by MLX itself. +This module holds the pieces common to every Q4_K kernel (linear matmul/matvec +and the embedding gather): + +* ``QK_K`` / ``Q4K_BLOCK_BYTES`` and the per-super-block byte layout constants. +* ``_Q4K_HEADER`` -- the Metal header (the ``block_q4_K`` struct plus the + per-element and vectorized dequant helpers) shared by all Q4_K Metal kernels. + +Q4_K layout (per 256-element super-block, 144 bytes, see llama.cpp +``block_q4_K`` in ``ggml-common.h``):: + + half d # super-block scale for quantized scales + half dmin # super-block scale for quantized mins + uint8 scales[12] # 6-bit packed scales + mins + uint8 qs[128] # 4-bit quants + +The dequantized value for a 4-bit code ``q`` in sub-block ``s`` is +``d * scale[s] * q - dmin * min[s]`` (affine). + +Attribution +----------- +The Q4_K block layout and the Metal dequant helpers in ``_Q4K_HEADER`` follow +llama.cpp +(``ggml-common.h`` / ``ggml-metal.metal``: ``block_q4_K``, ``dequantize_q4_K``, +``get_scale_min_k4``), which is MIT-licensed (Copyright (c) 2023-2024 The ggml +authors). """ from __future__ import annotations -from typing import Tuple +# --------------------------------------------------------------------------- +# Q4_K constants +# --------------------------------------------------------------------------- + +QK_K = 256 +K_SCALE_SIZE = 12 +_Q4K_D_BYTES = 2 +_Q4K_DMIN_BYTES = 2 +_Q4K_SCALES_BYTES = K_SCALE_SIZE +_Q4K_QS_BYTES = QK_K // 2 # 128 +Q4K_BLOCK_BYTES = ( + _Q4K_D_BYTES + _Q4K_DMIN_BYTES + _Q4K_SCALES_BYTES + _Q4K_QS_BYTES +) # 144 + +# Q4_K has QK_K / 32 = 8 sub-blocks of 32 elements; the mat-mat kernel's own NL is 16 (QK_K / 16).
+Q4K_NL = QK_K // 32 # 8 + +# --------------------------------------------------------------------------- +# Shared Metal header +# --------------------------------------------------------------------------- -from executorch.backends.mlx.builder.op_helpers import to_mlx_qparams -from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder -from executorch.backends.mlx.builder.slot_manager import Slot -from torch.fx.node import Node +# Ported from llama.cpp ggml-common.h (block_q4_K, get_scale_min_k4) and +# ggml-metal.metal (dequantize_q4_K). Struct field order matches GGUF bytes: +# d(0:2), dmin(2:4), scales(4:16), qs(16:144). +_Q4K_HEADER = """ +#include +#include +using namespace metal; -_BITS = 4 +#define QK_K 256 +#define K_SCALE_SIZE 12 +typedef struct { + half d; + half dmin; + uint8_t scales[K_SCALE_SIZE]; + uint8_t qs[QK_K/2]; +} block_q4_K; -def _repack_mlx( - P: MLXProgramBuilder, weight_node: Node -) -> Tuple[Slot, Slot, Slot, int]: - """Unpack a raw Q4_K blob and repack into MLX qparam constants. +// Unpack 6-bit scale and min for sub-block index j (0..7). +// Ported from llama.cpp get_scale_min_k4 (ggml-quants.c). +inline void get_scale_min_k4(int j, device const uint8_t * q, + thread uint8_t & sc, thread uint8_t & m) { + if (j < 4) { + sc = q[j] & 63; + m = q[j + 4] & 63; + } else { + sc = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4); + } +} - Returns ``(packed_slot, scales_slot, biases_slot, group_size)``. - """ - from executorch.extension.llm.export.gguf import ExportableGGUFTensor +// Metal variant used by dequantize_q4_K_16 (matches ggml-metal.metal). +inline uchar2 get_scale_min_k4_just2(int j, int k, device const uint8_t * q) { + return j < 4 + ? 
uchar2{uint8_t(q[j + 0 + k] & 63), uint8_t(q[j + 4 + k] & 63)} + : uchar2{ + uint8_t((q[j + 4 + k] & 0xF) | ((q[j - 4 + k] >> 6) << 4)), + uint8_t((q[j + 4 + k] >> 4) | ((q[j + 0 + k] >> 6) << 4))}; +} - weight_target, raw = P.get_placeholder_target_and_tensor(weight_node) - intx = ExportableGGUFTensor.from_raw(raw, "q4_k").to_intx_unpacked_to_int8_tensor() - group_size = int(intx.block_size[-1]) - packed, biases = to_mlx_qparams(intx.qdata, intx.scale, intx.zero_point, _BITS) +// Dequantize a single element at within-block position p (0..255). +// Mirrors dequantize_row_q4_K (ggml-quants.c): 64-element chunks, 32 lows +// then 32 highs per chunk, each half using its own scale/min pair. +inline float dequant_q4k_elem(device const block_q4_K * blk, int p) { + const int chunk = p >> 6; // 0..3 (64-element groups) + const int sub = p & 63; // 0..63 within chunk + const int q_idx = (chunk << 5) + (sub & 31); + device const uint8_t * q = blk->qs + q_idx; - packed_slot = P.make_or_get_constant(f"{weight_target}_q4k_packed", packed) - scales_slot = P.make_or_get_constant(f"{weight_target}_q4k_scales", intx.scale) - biases_slot = P.make_or_get_constant(f"{weight_target}_q4k_biases", biases) - return packed_slot, scales_slot, biases_slot, group_size + uint8_t sc, mn; + get_scale_min_k4((chunk << 1) + (sub >= 32 ? 1 : 0), blk->scales, sc, mn); + + const float d = (float) blk->d; + const float dm = (float) blk->dmin; + const float dl = d * (float) sc; + const float ml = dm * (float) mn; + + const uint8_t nib = (sub < 32) ? (q[0] & 0xF) : (q[0] >> 4); + return dl * (float) nib - ml; +} + +// Vectorized Q4_K dequantize: decodes 16 values into half4x4. +// Ported from llama.cpp dequantize_q4_K (ggml-metal.metal). +// il ranges 0..15 (NL = QK_K / 16 = 16; each call decodes 16 of the 256 values).
+inline void dequantize_q4_K_16(device const block_q4_K * xb, short il, + thread half4x4 & reg) { + device const uint8_t * q = xb->qs; + + short is = (il / 4) * 2; + q = q + (il / 4) * 32 + 16 * (il & 1); + il = il & 3; + + const uchar2 sc = get_scale_min_k4_just2(is, il / 2, xb->scales); + const float d = il < 2 ? (float) xb->d : (float) xb->d / 16.f; + const float dm = (float) xb->dmin; + const float dl = d * (float) sc[0]; + const float ml = dm * (float) sc[1]; + + const ushort mask = il < 2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i / 4][i % 4] = (half)(dl * (float)(q[i] & mask) - ml); + } +} +""" diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py index 7b5bbcff0e1..35ee1db4242 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py @@ -6,50 +6,116 @@ # LICENSE file in the root directory of this source tree. # -"""GGUF **Q4_K** embedding lowering via MLX's native 4-bit quantized gather. +"""GGUF **Q4_K** embedding lowering for the MLX GGUF pattern handler. -Lowers a ``dequantize_gguf -> embedding`` pattern to a quantized gather: gather -the packed quants / scales / biases by index, then dequantize the gathered rows -(``DequantizeNode``, mode "affine"). The GGUF blob is repacked into MLX qparams -at export time (see :mod:`.common`). +Lowers a ``dequantize_gguf -> embedding`` pattern to a fused gather Metal kernel +that reads raw ``block_q4_K`` bytes directly (same approach as :mod:`..q6k.embedding`). 
""" from __future__ import annotations -from executorch.backends.mlx.builder.op_helpers import emit_quantized_gather +import torch +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder from executorch.backends.mlx.builder.slot_manager import Slot -from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import _BITS, _repack_mlx +from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import ( + _Q4K_HEADER, + Q4K_BLOCK_BYTES, + QK_K, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MetalKernelNode, +) from torch.fx.node import Node +# --------------------------------------------------------------------------- +# Metal kernel source +# --------------------------------------------------------------------------- + + +# One thread per output element. grid = (K, num_idx, 1): x picks the feature j, +# y picks the gathered row; each thread dequantizes a single Q4_K element. +_Q4K_EMBED_SOURCE = """ + const uint j = thread_position_in_grid.x; // 0..K-1 + const uint r = thread_position_in_grid.y; // gathered row + const int row = (int) indices[r]; + const int nb = K / QK_K; + device const block_q4_K * blk = + ((device const block_q4_K *) weight) + (uint)row * nb + (j / QK_K); + out[r * (uint)K + j] = (OutT) dequant_q4k_elem(blk, j % QK_K); +""" + + def emit_embedding( P: MLXProgramBuilder, head: Node, weight_node: Node, indices_node: Node, - output_dtype, + output_dtype: torch.dtype, ) -> Slot: - """Lower a Q4_K ``dequantize_gguf -> embedding`` pattern to a quantized gather. + """Lower a Q4_K ``dequantize_gguf`` -> ``embedding`` pattern to a fused gather. - Gathers the packed quants / scales / biases by index, then dequantizes the - gathered rows (MLX affine 4-bit) -- the same shape as MLX's generic quantized - embedding. 
+ ``weight_node`` is the raw GGUF blob (the dequantize op's weight input) and + ``head`` is the ``aten.embedding`` node that owns the output slot. """ - w_slot, scales_slot, biases_slot, group_size = _repack_mlx(P, weight_node) - (indices_slot,) = P.slot_map([indices_node]) + weight_slot, indices_slot = P.slot_map([weight_node, indices_node]) + + weight_meta = weight_node.meta["val"] + if weight_meta.dim() != 2: + raise NotImplementedError( + f"gguf q4k embedding: weight must be 2-D (vocab, row_bytes); got " + f"shape {tuple(weight_meta.shape)}" + ) + row_bytes = weight_meta.shape[1] + if not isinstance(row_bytes, int): + raise NotImplementedError( + "gguf q4k embedding: weight shape must be statically known" + ) + if row_bytes % Q4K_BLOCK_BYTES != 0: + raise ValueError( + f"gguf q4k embedding: weight row bytes {row_bytes} must be a " + f"multiple of {Q4K_BLOCK_BYTES}" + ) + K = (row_bytes // Q4K_BLOCK_BYTES) * QK_K + + out_dtype_int = torch_dtype_to_scalar_type(output_dtype) out = P.make_or_get_slot(head) - emit_quantized_gather( - P, - out, - indices_slot, - w_slot, - scales_slot, - biases_slot, - group_size=group_size, - bits=_BITS, - mode="affine", - out_dtype=output_dtype, + leading = emit_shape(P, indices_node, indices_slot, end_dim=None) + num_idx_iov = emit_product(P, leading) + out_shape_flat = leading + [IntOrVid.from_literal(K)] + + # threadgroup.x must divide grid.x (= K, a multiple of 256). 
+ tg_x = 256 if K % 256 == 0 else K + + P.emit( + MetalKernelNode( + name="gguf_q4k_embedding", + source=_Q4K_EMBED_SOURCE, + header=_Q4K_HEADER, + inputs=[P.slot_to_tid(weight_slot), P.slot_to_tid(indices_slot)], + outputs=[P.slot_to_tid(out)], + grid=[IntOrVid.from_literal(K), num_idx_iov, IntOrVid.from_literal(1)], + threadgroup=[ + IntOrVid.from_literal(tg_x), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=["weight", "indices"], + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[out_dtype_int], + template_arg_names=["OutT", "K"], + template_arg_kinds=[2, 0], # dtype, int + template_arg_values=[out_dtype_int, K], + ) ) + return out diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py index 41d032a2d4a..0078b029255 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py @@ -6,29 +6,416 @@ # LICENSE file in the root directory of this source tree. # -"""GGUF **Q4_K** linear lowering via MLX's native 4-bit quantized matmul. +"""GGUF **Q4_K** linear implementation. -Lowers a ``dequantize_gguf -> linear`` pattern to a ``QuantizedMatmulNode`` -(mode "affine", group_size 32); the GGUF blob is repacked into MLX qparams at -export time (see :mod:`.common`). +Same structure as :mod:`..q6k.linear`: mat-vec (M==1), mat-mat (M>1), IfNode (dynamic M). +Kernels ported from llama.cpp ``kernel_mul_mv_q4_K_f32_impl`` and ``kernel_mul_mm`` (Q4_K). 
""" from __future__ import annotations from typing import Optional -from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder from executorch.backends.mlx.builder.slot_manager import Slot -from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import _BITS, _repack_mlx +from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import ( + _Q4K_HEADER, + Q4K_BLOCK_BYTES, + QK_K, +) from executorch.backends.mlx.serialization.mlx_graph_schema import ( - AddNode, - AsTypeNode, - QuantizedMatmulNode, + AddIntNode, + FloorDivideIntNode, + IfNode, + IntOrVid, + MetalKernelNode, + MultiplyIntNode, + SubtractIntNode, ) from torch.fx.node import Node +# --------------------------------------------------------------------------- +# Metal kernel sources +# --------------------------------------------------------------------------- + + +# Decode mat-vec kernel, ported from llama.cpp kernel_mul_mv_q4_K_f32_impl. +# Threadgroup = (32 * NSG, 1, 1): NSG simdgroups, each computing N_R0 output +# rows for one activation row (grid.y). Accumulate in float, reduce via simd_sum. 
+def _q4k_matvec_source(has_bias: bool) -> str: + write = "out[(uint)m * N + r] = (InT)(tot" + write += " + (float)bias[r]);" if has_bias else ");" + return f""" + constexpr short N_R0 = 2; + constexpr uint16_t kmask1 = 0x3f3f; + constexpr uint16_t kmask2 = 0x0f0f; + constexpr uint16_t kmask3 = 0xc0c0; + + const ushort tiisg = thread_index_in_simdgroup; + const ushort sgitg = simdgroup_index_in_threadgroup; + const uint m = thread_position_in_grid.y; + const uint tgx = thread_position_in_grid.x / (32u * NSG); + const int nb = K / QK_K; + const int first_row = (int)(tgx * NSG + sgitg) * N_R0; + + const short ix = tiisg / 8; + const short it = tiisg % 8; + const short iq = it / 4; + const short ir = it % 4; + + device const block_q4_K * xrows = (device const block_q4_K *) weight; + device const InT * yy = x + (uint)m * (uint)K; + device const InT * y4 = yy + ix * QK_K + 64 * iq + 8 * ir; + + float sumf[N_R0]; + for (short row = 0; row < N_R0; ++row) {{ sumf[row] = 0.f; }} + + float yl[16]; + float yh[16]; + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + for (int ib = ix; ib < nb; ib += 4) {{ + float4 sumy = {{0.f, 0.f, 0.f, 0.f}}; + for (short i = 0; i < 8; ++i) {{ + yl[i+0] = (float) y4[i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = (float) y4[i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = (float) y4[i+128]; sumy[2] += yh[i+0]; + yh[i+8] = (float) y4[i+160]; sumy[3] += yh[i+8]; + }} + + for (short row = 0; row < N_R0; ++row) {{ + const int r = first_row + row; + if (r >= N) {{ break; }} + + device const block_q4_K * blk = xrows + (uint)r * nb + ib; + device const uint16_t * sc = (device const uint16_t *)blk->scales + iq; + device const uint16_t * q1 = (device const uint16_t *)blk->qs + 16 * iq + 4 * ir; + device const uint16_t * q2 = q1 + 32; + const float d = (float) blk->d; + const float dm = (float) blk->dmin; + + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = 
((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + float4 acc1 = {{0.f, 0.f, 0.f, 0.f}}; + float4 acc2 = {{0.f, 0.f, 0.f, 0.f}}; + for (short i = 0; i < 4; ++i) {{ + acc1[0] += yl[2*i + 0] * (float)(q1[i] & 0x000F); + acc1[1] += yl[2*i + 1] * (float)(q1[i] & 0x0F00); + acc1[2] += yl[2*i + 8] * (float)(q1[i] & 0x00F0); + acc1[3] += yl[2*i + 9] * (float)(q1[i] & 0xF000); + acc2[0] += yh[2*i + 0] * (float)(q2[i] & 0x000F); + acc2[1] += yh[2*i + 1] * (float)(q2[i] & 0x0F00); + acc2[2] += yh[2*i + 8] * (float)(q2[i] & 0x00F0); + acc2[3] += yh[2*i + 9] * (float)(q2[i] & 0xF000); + }} + + sumf[row] += d * ((acc1[0] + acc1[1] / 256.f) * (float)sc8[0] + + (acc1[2] + acc1[3] / 256.f) * (float)sc8[1] / 16.f + + (acc2[0] + acc2[1] / 256.f) * (float)sc8[4] + + (acc2[2] + acc2[3] / 256.f) * (float)sc8[5] / 16.f) - + dm * (sumy[0] * (float)sc8[2] + sumy[1] * (float)sc8[3] + + sumy[2] * (float)sc8[6] + sumy[3] * (float)sc8[7]); + }} + + y4 += 4 * QK_K; + }} + + for (short row = 0; row < N_R0; ++row) {{ + const int r = first_row + row; + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && r < N) {{ + {write} + }} + }} +""" + + +# Prefill mat-mat kernel, ported from llama.cpp kernel_mul_mm (Q4_K variant). +# 64x32 output tiles, 4 simdgroups / 128 threads per threadgroup. +# Uses vectorized dequantize_q4_K_16 to decode 16 weight values per thread +# into threadgroup memory, then runs simdgroup_multiply_accumulate on 8x8 +# tiles. NL=16 for Q4_K (QK_K / 16 = 16 dequant steps per super-block).
+def _q4k_matmul_source(has_bias: bool) -> str: + bias_add = "+ (float) bias[r0 + i]" if has_bias else "" + return f""" + constexpr short NR0 = 64; // weight/output rows per tile (N dim) + constexpr short NR1 = 32; // activation rows per tile (M dim) + constexpr short NK = 32; // K-chunk per iteration + constexpr short NL = 16; // Q4_K: QK_K / 16 + constexpr short NL0 = NK / 16; // = 2 — dequant iterations per thread for weight + constexpr short NL1 = NK / 8; // = 4 — load iterations per thread for activation + + threadgroup half sa[4096]; // NR0 * NK storage (strided by 64) + threadgroup half sb[4096]; // NR1 * NK storage (strided by 64) + + const ushort tid = thread_index_in_threadgroup; // 0..127 + const ushort sgitg = simdgroup_index_in_threadgroup; // 0..3 + + const uint r0 = thread_position_in_grid.y * NR0; // first weight row + const uint r1 = (thread_position_in_grid.x / 128u) * NR1; // first activation row + + // M (number of activation rows) read at runtime. + int M = 1; + for (uint d = 0; d + 1 < x_ndim; ++d) {{ M *= (int) x_shape[d]; }} + + const int nb = K / QK_K; + + // Clamp tile edges. + const short nr0 = (N - (int)r0 < NR0) ? (N - (int)r0) : NR0; + const short nr1 = (M - (int)r1 < NR1) ? (M - (int)r1) : NR1; + + // Thread → element mapping for cooperative loads. + const short lr0 = ((short)(tid / NL0) < nr0) ? (short)(tid / NL0) : (nr0 - 1); // 0..63 + const short lr1 = ((short)(tid / NL1) < nr1) ? (short)(tid / NL1) : (nr1 - 1); // 0..31 + + short il0 = tid % NL0; + short il = il0; // current dequant sub-block index within Q4_K block + + const short offset1 = il0 / NL; // always 0 for NL=16, NL0=2 + + // Pointer to weight block for this thread's assigned row. + device const block_q4_K * wblk = (device const block_q4_K *) weight + + (uint)(r0 + lr0) * nb + offset1; + + // Pointer to activation row for this thread.
+ const short iy = 8 * (tid % NL1); + device const InT * yp = x + (uint)(r1 + lr1) * (uint)K + iy; + + // Accumulator: 8 simdgroup 8x8 matrices (4 sgitg configs x 2 sub-tiles). + simdgroup_half8x8 ma[4]; + simdgroup_half8x8 mb[2]; + simdgroup_float8x8 mc[8]; + for (short i = 0; i < 8; ++i) {{ + mc[i] = make_filled_simdgroup_matrix(0.f); + }} + + for (int loop_k = 0; loop_k < K; loop_k += NK) {{ + // --- Cooperative load: dequantized weight tile (NR0 x NK) into sa --- + half4x4 temp_a; + dequantize_q4_K_16(wblk, il, temp_a); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short i = 0; i < 16; ++i) {{ + const short sx = 2 * il0 + i / 8; + const short sy = (tid / NL0) / 8; + const short lx = (tid / NL0) % 8; + const short ly = i % 8; + const short ib = 8 * sx + sy; + *(sa + 64 * ib + 8 * ly + lx) = temp_a[i / 4][i % 4]; + }} + + // --- Cooperative load: activation tile (NR1 x NK) into sb --- + const short sx_b = tid % NL1; + const short sy_b = (tid / NL1) / 8; + const short ly_b = (tid / NL1) % 8; + const short ib_b = 4 * sx_b + sy_b; + + for (short i = 0; i < 8; ++i) {{ + *(sb + 64 * ib_b + 8 * ly_b + i) = (half) *(yp + i); + }} + + // Advance weight pointer through Q4_K sub-blocks. + il = (il + 2 < NL) ? il + 2 : il % 2; + wblk = (il < 2) ? 
wblk + (2 + NL - 1) / NL : wblk; + + yp += NK; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // --- Simdgroup matmul on loaded tiles --- + threadgroup const half * lsma = sa + 4 * 64 * (sgitg % 2); + threadgroup const half * lsmb = sb + 2 * 64 * (sgitg / 2); + + for (short ik = 0; ik < NK / 8; ++ik) {{ + simdgroup_barrier(mem_flags::mem_none); + for (short i = 0; i < 4; ++i) {{ + simdgroup_load(ma[i], lsma + 64 * i, 8, ulong2(0, 0), false); + }} + simdgroup_barrier(mem_flags::mem_none); + for (short i = 0; i < 2; ++i) {{ + simdgroup_load(mb[i], lsmb + 64 * i, 8, ulong2(0, 0), false); + }} + simdgroup_barrier(mem_flags::mem_none); + for (short i = 0; i < 8; ++i) {{ + simdgroup_multiply_accumulate(mc[i], mb[i / 4], ma[i % 4], mc[i]); + }} + lsma += 8 * 64; + lsmb += 4 * 64; + }} + }} + + // --- Write results: always via threadgroup memory for float→InT cast --- + // Barrier needed: sa was used for weight tiles during the K-loop and is now + // reused as float staging for the output. Without this barrier, a fast + // simdgroup could start writing mc[] into sa while a slower one is still + // reading the last weight tile via simdgroup_load(ma[]). + // (Mirrors the barrier in llama.cpp kernel_mul_mm's bounds-checked write path.) + threadgroup_barrier(mem_flags::mem_threadgroup); + {{ + threadgroup float * temp_str = ((threadgroup float *) sa) + + 32 * (sgitg & 1) + (16 * (sgitg >> 1)) * NR0; + for (short i = 0; i < 8; ++i) {{ + simdgroup_store(mc[i], temp_str + 8 * (i % 4) + 8 * NR0 * (i / 4), + NR0, ulong2(0, 0), false); + }} + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) {{ + for (int j = tid; j < nr1; j += NR1) {{ + device InT * D = out + (uint)(r1 + j) * (uint)N + r0; + threadgroup float * Cp = ((threadgroup float *) sa) + j * NR0; + for (int i = 0; i < nr0; ++i) {{ + float v = Cp[i]; + D[i] = (InT)(v {bias_add}); + }} + }} + }} + }} +""" + + +# Number of simdgroups per threadgroup for the mat-vec kernel. 
+_Q4K_MV_NSG = 4 +# Tile sizes for the mat-mat kernel (from llama.cpp kernel_mul_mm). +_Q4K_MM_NR0 = 64 # weight/output rows (N dim) per threadgroup +_Q4K_MM_NR1 = 32 # activation rows (M dim) per threadgroup + + +def _emit_q4k_matvec( + P: MLXProgramBuilder, + x_node: Node, + x_slot: Slot, + weight_slot: Slot, + bias_slot: Optional[Slot], + N: int, + K: int, + out: Slot, +) -> None: + in_dtype_int = torch_dtype_to_scalar_type(x_node.meta["val"].dtype) + + leading = emit_shape(P, x_node, x_slot, end_dim=-1) + M_iov = emit_product(P, leading) + out_shape_flat = leading + [IntOrVid.from_literal(N)] + + n_r0 = 2 + nsg = _Q4K_MV_NSG + num_row_groups = (N + nsg * n_r0 - 1) // (nsg * n_r0) + grid_x = num_row_groups * 32 * nsg + + has_bias = bias_slot is not None + inputs = [P.slot_to_tid(x_slot), P.slot_to_tid(weight_slot)] + input_names = ["x", "weight"] + if has_bias: + inputs.append(P.slot_to_tid(bias_slot)) + input_names.append("bias") + + P.emit( + MetalKernelNode( + name="gguf_q4k_matvec", + source=_q4k_matvec_source(has_bias), + header=_Q4K_HEADER, + inputs=inputs, + outputs=[P.slot_to_tid(out)], + grid=[ + IntOrVid.from_literal(grid_x), + M_iov, + IntOrVid.from_literal(1), + ], + threadgroup=[ + IntOrVid.from_literal(32 * nsg), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=input_names, + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[in_dtype_int], + template_arg_names=["InT", "N", "K", "NSG"], + template_arg_kinds=[2, 0, 0, 0], # dtype, int, int, int + template_arg_values=[in_dtype_int, N, K, nsg], + ) + ) + + +def _emit_q4k_matmul( + P: MLXProgramBuilder, + x_node: Node, + x_slot: Slot, + weight_slot: Slot, + bias_slot: Optional[Slot], + N: int, + K: int, + blocks_m_iov: IntOrVid, + out: Slot, +) -> None: + in_dtype_int = torch_dtype_to_scalar_type(x_node.meta["val"].dtype) + + leading = emit_shape(P, x_node, x_slot, end_dim=-1) + out_shape_flat = leading + 
[IntOrVid.from_literal(N)] + + # grid.x = ceil(M / NR1) * 128 threads (activation tiles) + # grid.y = ceil(N / NR0) (weight tiles) + blocks_n = (N + _Q4K_MM_NR0 - 1) // _Q4K_MM_NR0 + + has_bias = bias_slot is not None + inputs = [P.slot_to_tid(x_slot), P.slot_to_tid(weight_slot)] + input_names = ["x", "weight"] + if has_bias: + inputs.append(P.slot_to_tid(bias_slot)) + input_names.append("bias") + + # blocks_m_iov = ceil(M / NR1); multiply by 128 for grid.x + _, grid_x_slot = P.make_tmp_value_slot() + P.emit( + MultiplyIntNode( + a=blocks_m_iov, + b=IntOrVid.from_literal(128), + out=P.slot_to_vid(grid_x_slot), + ) + ) + grid_x_iov = IntOrVid.from_vid(P.slot_to_vid(grid_x_slot)) + + P.emit( + MetalKernelNode( + name="gguf_q4k_matmul", + source=_q4k_matmul_source(has_bias), + header=_Q4K_HEADER, + inputs=inputs, + outputs=[P.slot_to_tid(out)], + grid=[ + grid_x_iov, + IntOrVid.from_literal(blocks_n), + IntOrVid.from_literal(1), + ], + threadgroup=[ + IntOrVid.from_literal(128), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=input_names, + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[in_dtype_int], + template_arg_names=["InT", "N", "K"], + template_arg_kinds=[2, 0, 0], + template_arg_values=[in_dtype_int, N, K], + ) + ) + + def emit_linear( P: MLXProgramBuilder, head: Node, @@ -36,47 +423,119 @@ def emit_linear( weight_node: Node, bias_node: Optional[Node], ) -> Slot: - """Lower a Q4_K ``dequantize_gguf -> linear`` pattern to MLX 4-bit matmul. + """Lower a Q4_K ``dequantize_gguf`` -> ``linear`` pattern to fused kernels. - ``weight_node`` is the raw GGUF blob constant; ``head`` is the ``aten.linear`` - node. The blob is repacked into MLX qparams at export time, so only the - MLX-format constants are serialized. + ``weight_node`` is the raw GGUF blob (the dequantize op's weight input) and + ``head`` is the ``aten.linear`` node that owns the output slot. 
""" - w_slot, scales_slot, biases_slot, group_size = _repack_mlx(P, weight_node) - x_slot, bias_slot = P.slot_map([x_node, bias_node]) + x_slot, weight_slot, bias_slot = P.slot_map([x_node, weight_node, bias_node]) + + weight_meta = weight_node.meta["val"] + if weight_meta.dim() != 2: + raise NotImplementedError( + f"gguf q4k linear: weight must be 2-D (N, row_bytes); got " + f"shape {tuple(weight_meta.shape)}" + ) + N = weight_meta.shape[0] + row_bytes = weight_meta.shape[1] + if not isinstance(N, int) or not isinstance(row_bytes, int): + raise NotImplementedError( + "gguf q4k linear: weight shape must be statically known" + ) + if row_bytes % Q4K_BLOCK_BYTES != 0: + raise ValueError( + f"gguf q4k linear: weight row bytes {row_bytes} must be a multiple of " + f"{Q4K_BLOCK_BYTES}" + ) + K = (row_bytes // Q4K_BLOCK_BYTES) * QK_K + + # Determine M (product of x's leading dims). Static M lets us pick the + # optimal kernel and (for mat-mat) compute a literal launch grid. + x_meta = x_node.meta["val"] + leading_dims = x_meta.shape[:-1] + M: Optional[int] = 1 + for d in leading_dims: + if isinstance(d, int): + M *= d + else: + M = None # dynamic / symbolic + break out = P.make_or_get_slot(head) - P.emit( - QuantizedMatmulNode( - x=P.slot_to_tid(x_slot), - w=P.slot_to_tid(w_slot), - scales=P.slot_to_tid(scales_slot), - biases=P.slot_to_tid(biases_slot), - out=P.slot_to_tid(out), - group_size=group_size, - bits=_BITS, - mode="affine", - transpose=True, + tile = _Q4K_MM_NR1 # M-dimension tile (activation rows per threadgroup) + if M == 1: + # Static decode -> mat-vec. + _emit_q4k_matvec(P, x_node, x_slot, weight_slot, bias_slot, N, K, out) + elif M is not None: + # Static prefill -> tiled simdgroup mat-mat (literal grid). 
+ blocks_m = (M + tile - 1) // tile + _emit_q4k_matmul( + P, + x_node, + x_slot, + weight_slot, + bias_slot, + N, + K, + IntOrVid.from_literal(blocks_m), + out, ) - ) + else: + # Dynamic seqlen -> emit both kernels in separate chains and select at + # runtime with an IfNode. cond = M - 1: nonzero (M>1) runs the mat-mat + # (then) chain, zero (M==1) runs the mat-vec (else) chain. + leading = emit_shape(P, x_node, x_slot, end_dim=-1) + m_iov = emit_product(P, leading) - if bias_node is not None: + _, cond_slot = P.make_tmp_value_slot() P.emit( - AddNode( - a=P.slot_to_tid(out), - b=P.slot_to_tid(bias_slot), - out=P.slot_to_tid(out), + SubtractIntNode( + a=m_iov, + b=IntOrVid.from_literal(1), + out=P.slot_to_vid(cond_slot), ) ) + cond_iov = IntOrVid.from_vid(P.slot_to_vid(cond_slot)) - out_dtype = head.meta["val"].dtype - if out_dtype != x_node.meta["val"].dtype: + # blocks_m = (M + tile - 1) // tile (mat-mat grid.x = blocks_m * 128). + _, sum_slot = P.make_tmp_value_slot() + P.emit( + AddIntNode( + a=m_iov, + b=IntOrVid.from_literal(tile - 1), + out=P.slot_to_vid(sum_slot), + ) + ) + _, blocks_m_slot = P.make_tmp_value_slot() P.emit( - AsTypeNode( - x=P.slot_to_tid(out), - out=P.slot_to_tid(out), - scalar_type=torch_dtype_to_scalar_type(out_dtype), + FloorDivideIntNode( + a=IntOrVid.from_vid(P.slot_to_vid(sum_slot)), + b=IntOrVid.from_literal(tile), + out=P.slot_to_vid(blocks_m_slot), ) ) + blocks_m_iov = IntOrVid.from_vid(P.slot_to_vid(blocks_m_slot)) + with P.new_chain() as then_idx: # prefill / mat-mat + _emit_q4k_matmul( + P, + x_node, + x_slot, + weight_slot, + bias_slot, + N, + K, + blocks_m_iov, + out, + ) + with P.new_chain() as else_idx: # decode / mat-vec + _emit_q4k_matvec(P, x_node, x_slot, weight_slot, bias_slot, N, K, out) + + P.emit( + IfNode( + cond=cond_iov, + then_chain_idx=then_idx, + else_chain_idx=else_idx, + ) + ) return out diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py index
3f8e60b7aa8..12acd4ebc75 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py @@ -6,13 +6,13 @@ # LICENSE file in the root directory of this source tree. """ -Tests for the GGUF Q6_K embedding lowering. +Tests for the GGUF Q6_K / Q4_K embedding lowering. An ``nn.Embedding`` whose weight is an ``ExportableGGUFTensor`` exports to -``embedding(torchao::dequantize_gguf(weight, "q6_k", ...), indices)``. The MLX -``GGUF_QUANTIZED_EMBEDDING`` pattern matches that subgraph and lowers it to the -fused Q6_K gather Metal kernel. These tests compare the kernel against the eager -reference (``gguf``-package dequant + ``F.embedding``) on the same packed table. +``embedding(torchao::dequantize_gguf(weight, ggml_type, ...), indices)``. The MLX +``GGUF_QUANTIZED_EMBEDDING`` pattern matches that subgraph and lowers it to fused +gather Metal kernels. These tests compare the kernel against the eager reference +(``gguf``-package dequant + ``F.embedding``) on the same packed table. 
Usage:: @@ -27,18 +27,24 @@ import torch import torch.nn as nn from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( + make_q4_k_blob, make_q6_k_blob, ) from executorch.backends.mlx.test.test_utils import OpTestCase from executorch.extension.llm.export.gguf import ExportableGGUFTensor -def _make_gguf_embedding_model(vocab: int, K: int, seed: int = 0) -> nn.Module: - """An ``nn.Embedding`` whose weight is a Q6_K ``ExportableGGUFTensor``.""" +def _make_gguf_embedding_model( + vocab: int, + K: int, + ggml_type: str = "q6_k", + seed: int = 0, +) -> nn.Module: emb = nn.Embedding(vocab, K) - blob = make_q6_k_blob(vocab, K, seed=seed) + blob_maker = make_q4_k_blob if ggml_type == "q4_k" else make_q6_k_blob + blob = blob_maker(vocab, K, seed=seed) emb.weight = nn.Parameter( - ExportableGGUFTensor.from_raw(blob, "q6_k", torch.bfloat16), + ExportableGGUFTensor.from_raw(blob, ggml_type, torch.bfloat16), requires_grad=False, ) return emb @@ -56,12 +62,14 @@ def __init__( vocab: int = 512, K: int = 256, idx_shape: Tuple[int, ...] = (8,), + ggml_type: str = "q6_k", ): self.vocab = vocab self.K = K self.idx_shape = idx_shape + self.ggml_type = ggml_type shp = "x".join(str(d) for d in idx_shape) - self.name = f"gguf_embedding_v{vocab}_k{K}_idx{shp}" + self.name = f"gguf_embedding_{ggml_type}_v{vocab}_k{K}_idx{shp}" @classmethod def get_test_configs(cls) -> List["GGUFEmbeddingTest"]: @@ -77,6 +85,10 @@ def get_test_configs(cls) -> List["GGUFEmbeddingTest"]: # kept small so the packed weight fits CI-runner GPU buffer limits; the # gather + per-row dequant path is identical regardless of vocab. 
cls(vocab=2048, K=5376, idx_shape=(8,)), + cls(vocab=512, K=256, idx_shape=(8,), ggml_type="q4_k"), + cls(vocab=512, K=512, idx_shape=(8,), ggml_type="q4_k"), + cls(vocab=512, K=256, idx_shape=(2, 3), ggml_type="q4_k"), + cls(vocab=2048, K=5376, idx_shape=(8,), ggml_type="q4_k"), ] def get_edge_compile_config(self): @@ -86,7 +98,7 @@ def get_edge_compile_config(self): return EdgeCompileConfig(_check_ir_validity=False) def create_model(self) -> nn.Module: - return _make_gguf_embedding_model(self.vocab, self.K) + return _make_gguf_embedding_model(self.vocab, self.K, self.ggml_type) def create_inputs(self) -> Tuple[torch.Tensor, ...]: torch.manual_seed(0) @@ -100,7 +112,9 @@ def _main() -> None: # noqa: C901 from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner - parser = argparse.ArgumentParser(description="Test GGUF Q6_K embedding lowering") + parser = argparse.ArgumentParser( + description="Test GGUF Q6_K / Q4_K embedding lowering" + ) parser.add_argument("action", choices=["generate", "compare", "run", "list"]) parser.add_argument("--verbose", "-v", action="store_true") parser.add_argument("--rebuild", action="store_true") diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py index 4a7defbe107..55094d4acd1 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py @@ -6,16 +6,16 @@ # LICENSE file in the root directory of this source tree. """ -Tests for the GGUF Q6_K linear lowering. +Tests for the GGUF Q6_K / Q4_K linear lowering. A linear whose weight is an ``ExportableGGUFTensor`` (extension/llm/export/gguf) -exports to ``linear(x, torchao::dequantize_gguf(weight, "q6_k", ...), bias)``. +exports to ``linear(x, torchao::dequantize_gguf(weight, ggml_type, ...), bias)``. 
The MLX ``GGUF_QUANTIZED_LINEAR`` pattern (custom_kernel_ops/gguf/patterns.py) -matches that subgraph and lowers it to the fused Q6_K Metal kernels (mat-vec for -decode, mat-mat for prefill). These tests compare the fused kernels against the -eager reference (``gguf``-package dequant + ``F.linear``) on the same packed -weight, so quantization quality is irrelevant -- only kernel-vs-reference -numerics are checked. +matches that subgraph and lowers it to fused Metal kernels (mat-vec for decode, +mat-mat for prefill). These tests compare the fused kernels against the eager +reference (``gguf``-package dequant + ``F.linear``) on the same packed weight, +so quantization quality is irrelevant -- only kernel-vs-reference numerics are +checked. ``GGUFLinearDynamicTest`` exports once with a symbolic seqlen and runs the same .pte with M=1 and M>1 to exercise both branches of the runtime ``IfNode`` @@ -40,7 +40,7 @@ # --------------------------------------------------------------------------- -# GGUF Q6_K test fixtures. +# GGUF Q6_K / Q4_K test fixtures. # # The Python ``gguf`` package can dequantize Q6_K but does NOT implement Q6_K # quantization, so we build the packed weight here. Quantization quality is @@ -130,30 +130,12 @@ def _fp32_linear_reference(model: "GGUFLinearModel", x: torch.Tensor): a bf16 eager matmul is too noisy an oracle over large K. Dequantize in fp32, matmul in fp32, then cast back -- differences collapse to ~1 output ULP. - The reference weight must match the representation the kernel consumes: - Q6_K dequantizes the raw blob in-kernel at full precision (use the gguf-exact - dequant), while Q4_K is repacked into bf16 MLX qparams, so use that repacked - dequant (repack precision vs gguf is covered separately by test_gguf.py). + Both Q6_K and Q4_K kernels dequantize the raw GGUF blob in-kernel; use the + gguf-exact dequant as the reference oracle. 
""" lin = model.linear weight = lin.weight - if getattr(weight, "ggml_type", None) == "q4_k": - # Q4_K is repacked into bf16 MLX affine qparams (S, Q, B); reconstruct - # exactly what the kernel dequantizes so the oracle isolates kernel - # accumulation (repack precision vs gguf is covered by test_gguf.py). - from executorch.backends.mlx.builder.op_helpers import to_mlx_qparams - - intx = weight.to_intx_unpacked_to_int8_tensor() - gs = int(intx.block_size[-1]) - Q, B = to_mlx_qparams(intx.qdata, intx.scale, intx.zero_point, 4) - qb = Q.view(torch.uint8) - nibbles = torch.stack([(qb & 0xF).float(), ((qb >> 4) & 0xF).float()], dim=-1) - q_unsigned = nibbles.reshape(intx.qdata.shape[0], -1) - scale = intx.scale.float().repeat_interleave(gs, dim=1) - bias_b = B.float().repeat_interleave(gs, dim=1) - w = scale * q_unsigned + bias_b - else: - w = weight.dequantize(torch.float32) + w = weight.dequantize(torch.float32) bias = lin.bias.float() if lin.bias is not None else None out = torch.nn.functional.linear(x.float(), w, bias) return [out.to(x.dtype)] @@ -223,7 +205,7 @@ def get_test_configs(cls) -> List["GGUFLinearTest"]: # fits CI-runner GPU buffer limits; the mat-vec N-tiling path is the # same at any N. cfgs.append(cls(M=1, N=16384, K=5376, dtype=torch.bfloat16)) # lm_head - # Q4_K -> MLX native 4-bit quantized_matmul (group_size 32). + # Q4_K fused Metal kernels (mat-vec / mat-mat). 
cfgs.append(cls(M=1, N=512, K=512, dtype=torch.bfloat16, ggml_type="q4_k")) cfgs.append(cls(M=8, N=512, K=512, dtype=torch.bfloat16, ggml_type="q4_k")) cfgs.append(cls(M=1, N=5376, K=5376, dtype=torch.bfloat16, ggml_type="q4_k")) @@ -264,15 +246,17 @@ def __init__( N: int = 512, K: int = 512, dtype: torch.dtype = torch.bfloat16, + ggml_type: str = "q6_k", ): self.export_M = export_M self.test_M = test_M self.N = N self.K = K self.dtype = dtype + self.ggml_type = ggml_type self.rtol, self.atol = _DTYPE_TOL[dtype] self.name = ( - f"gguf_linear_dyn_exp{export_M}_test{test_M}_n{N}_k{K}_" + f"gguf_linear_dyn_{ggml_type}_exp{export_M}_test{test_M}_n{N}_k{K}_" f"{_DTYPE_TAG[dtype]}" ) @@ -284,6 +268,8 @@ def get_test_configs(cls) -> List["GGUFLinearDynamicTest"]: cls(export_M=4, test_M=4, dtype=torch.bfloat16), # control cls(export_M=4, test_M=1, dtype=torch.float16), cls(export_M=4, test_M=40, N=300, K=256, dtype=torch.bfloat16), # ragged + cls(export_M=4, test_M=1, dtype=torch.bfloat16, ggml_type="q4_k"), + cls(export_M=4, test_M=8, dtype=torch.bfloat16, ggml_type="q4_k"), ] def get_dynamic_shapes(self): @@ -296,7 +282,9 @@ def get_edge_compile_config(self): def create_model(self) -> nn.Module: # Deterministic weight so export-time and run-time use the same model. 
return GGUFLinearModel( - _make_gguf_linear_model(self.N, self.K, self.dtype, bias=True) + _make_gguf_linear_model( + self.N, self.K, self.dtype, bias=True, ggml_type=self.ggml_type + ) ) def create_inputs(self) -> Tuple[torch.Tensor, ...]: @@ -331,7 +319,9 @@ def _eager_sanity() -> None: from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner - parser = argparse.ArgumentParser(description="Test GGUF Q6_K linear lowering") + parser = argparse.ArgumentParser( + description="Test GGUF Q6_K / Q4_K linear lowering" + ) parser.add_argument( "action", choices=["generate", "compare", "run", "list", "eager"] ) From 49ac1d2c7092858754144ed387750e0dff50ef7f Mon Sep 17 00:00:00 2001 From: uddeshsingh Date: Fri, 12 Jun 2026 00:23:36 -0500 Subject: [PATCH 2/5] Guard Q4_K GGUF lowering behind ET_MLX_EMIT_DIRECT_GGUF Keep the legacy MLX-native repack path available when the env var is set to 0, per maintainer request on #20172. --- .../mlx/custom_kernel_ops/gguf/patterns.py | 29 +++++-- .../custom_kernel_ops/gguf/q4k/__init__.py | 12 +++ .../gguf/q4k/embedding_mlx_native.py | 58 +++++++++++++ .../gguf/q4k/linear_mlx_native.py | 85 +++++++++++++++++++ .../custom_kernel_ops/gguf/q4k/repack_mlx.py | 43 ++++++++++ 5 files changed, 222 insertions(+), 5 deletions(-) create mode 100644 backends/mlx/custom_kernel_ops/gguf/q4k/embedding_mlx_native.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/q4k/linear_mlx_native.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/q4k/repack_mlx.py diff --git a/backends/mlx/custom_kernel_ops/gguf/patterns.py b/backends/mlx/custom_kernel_ops/gguf/patterns.py index e3a9fea97b8..a8652796bfb 100644 --- a/backends/mlx/custom_kernel_ops/gguf/patterns.py +++ b/backends/mlx/custom_kernel_ops/gguf/patterns.py @@ -18,7 +18,8 @@ lower it without materializing the dequantized weight: * **Q6_K** -> fused custom Metal kernels in :mod:`.q6k`. -* **Q4_K** -> fused custom Metal kernels in :mod:`.q4k`. 
+* **Q4_K** -> fused custom Metal kernels in :mod:`.q4k` (default), or the legacy + MLX-native repack path when ``ET_MLX_EMIT_DIRECT_GGUF=0``. Both cover linear and embedding. @@ -113,9 +114,18 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: emit_linear, ) else: # q4_k - from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import ( - emit_linear, + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k import ( + emit_direct_gguf, ) + + if emit_direct_gguf(): + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import ( + emit_linear, + ) + else: + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear_mlx_native import ( + emit_linear, + ) return emit_linear(P, n, x_node, self.weight, bias_node) @@ -159,7 +169,16 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: emit_embedding, ) else: # q4_k - from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import ( - emit_embedding, + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k import ( + emit_direct_gguf, ) + + if emit_direct_gguf(): + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import ( + emit_embedding, + ) + else: + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding_mlx_native import ( + emit_embedding, + ) return emit_embedding(P, n, self.weight, indices_node, self.output_dtype) diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py b/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py index 30fbcf96118..cd517828b7d 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py @@ -11,4 +11,16 @@ See :mod:`.linear` / :mod:`.embedding` for the ``emit_*`` lowerings (called by ``custom_kernel_ops.gguf.patterns``); they are not imported here to keep the package import light. + +Set ``ET_MLX_EMIT_DIRECT_GGUF=0`` to use the legacy export-time repack path +(:mod:`.linear_mlx_native` / :mod:`.embedding_mlx_native`) instead. 
""" + +from __future__ import annotations + +import os + + +def emit_direct_gguf() -> bool: + """Return True to emit fused kernels that read raw GGUF bytes (default).""" + return os.environ.get("ET_MLX_EMIT_DIRECT_GGUF", "1") != "0" diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/embedding_mlx_native.py b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding_mlx_native.py new file mode 100644 index 00000000000..73b386fc44d --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding_mlx_native.py @@ -0,0 +1,58 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF **Q4_K** embedding lowering via MLX's native 4-bit quantized gather. + +Lowers a ``dequantize_gguf -> embedding`` pattern to a quantized gather: gather +the packed quants / scales / biases by index, then dequantize the gathered rows +(``DequantizeNode``, mode "affine"). The GGUF blob is repacked into MLX qparams +at export time (see :mod:`.repack_mlx`). +""" + +from __future__ import annotations + +from executorch.backends.mlx.builder.op_helpers import emit_quantized_gather +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.repack_mlx import ( + _BITS, + repack_mlx, +) +from torch.fx.node import Node + + +def emit_embedding( + P: MLXProgramBuilder, + head: Node, + weight_node: Node, + indices_node: Node, + output_dtype, +) -> Slot: + """Lower a Q4_K ``dequantize_gguf -> embedding`` pattern to a quantized gather. + + Gathers the packed quants / scales / biases by index, then dequantizes the + gathered rows (MLX affine 4-bit) -- the same shape as MLX's generic quantized + embedding. 
+ """ + w_slot, scales_slot, biases_slot, group_size = repack_mlx(P, weight_node) + (indices_slot,) = P.slot_map([indices_node]) + + out = P.make_or_get_slot(head) + emit_quantized_gather( + P, + out, + indices_slot, + w_slot, + scales_slot, + biases_slot, + group_size=group_size, + bits=_BITS, + mode="affine", + out_dtype=output_dtype, + ) + return out diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/linear_mlx_native.py b/backends/mlx/custom_kernel_ops/gguf/q4k/linear_mlx_native.py new file mode 100644 index 00000000000..4ecc6d89527 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/linear_mlx_native.py @@ -0,0 +1,85 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF **Q4_K** linear lowering via MLX's native 4-bit quantized matmul. + +Lowers a ``dequantize_gguf -> linear`` pattern to a ``QuantizedMatmulNode`` +(mode "affine", group_size 32); the GGUF blob is repacked into MLX qparams at +export time (see :mod:`.repack_mlx`). +""" + +from __future__ import annotations + +from typing import Optional + +from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.repack_mlx import ( + _BITS, + repack_mlx, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddNode, + AsTypeNode, + QuantizedMatmulNode, +) +from torch.fx.node import Node + + +def emit_linear( + P: MLXProgramBuilder, + head: Node, + x_node: Node, + weight_node: Node, + bias_node: Optional[Node], +) -> Slot: + """Lower a Q4_K ``dequantize_gguf -> linear`` pattern to MLX 4-bit matmul. + + ``weight_node`` is the raw GGUF blob constant; ``head`` is the ``aten.linear`` + node. 
The blob is repacked into MLX qparams at export time, so only the + MLX-format constants are serialized. + """ + w_slot, scales_slot, biases_slot, group_size = repack_mlx(P, weight_node) + x_slot, bias_slot = P.slot_map([x_node, bias_node]) + + out = P.make_or_get_slot(head) + P.emit( + QuantizedMatmulNode( + x=P.slot_to_tid(x_slot), + w=P.slot_to_tid(w_slot), + scales=P.slot_to_tid(scales_slot), + biases=P.slot_to_tid(biases_slot), + out=P.slot_to_tid(out), + group_size=group_size, + bits=_BITS, + mode="affine", + transpose=True, + ) + ) + + if bias_node is not None: + P.emit( + AddNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(bias_slot), + out=P.slot_to_tid(out), + ) + ) + + out_dtype = head.meta["val"].dtype + if out_dtype != x_node.meta["val"].dtype: + P.emit( + AsTypeNode( + x=P.slot_to_tid(out), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(out_dtype), + ) + ) + + return out diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/repack_mlx.py b/backends/mlx/custom_kernel_ops/gguf/q4k/repack_mlx.py new file mode 100644 index 00000000000..5775bd83d5c --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/repack_mlx.py @@ -0,0 +1,43 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""Q4_K -> MLX qparam repack for the legacy MLX-native lowering path. + +Used when ``ET_MLX_EMIT_DIRECT_GGUF=0``: the GGUF blob is unpacked and repacked +into MLX affine 4-bit qparams at export time instead of being consumed directly +by fused Metal kernels. 
+""" + +from __future__ import annotations + +from typing import Tuple + +from executorch.backends.mlx.builder.op_helpers import to_mlx_qparams +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from torch.fx.node import Node + +_BITS = 4 + + +def repack_mlx(P: MLXProgramBuilder, weight_node: Node) -> Tuple[Slot, Slot, Slot, int]: + """Unpack a raw Q4_K blob and repack into MLX qparam constants. + + Returns ``(packed_slot, scales_slot, biases_slot, group_size)``. + """ + from executorch.extension.llm.export.gguf import ExportableGGUFTensor + + weight_target, raw = P.get_placeholder_target_and_tensor(weight_node) + intx = ExportableGGUFTensor.from_raw(raw, "q4_k").to_intx_unpacked_to_int8_tensor() + group_size = int(intx.block_size[-1]) + packed, biases = to_mlx_qparams(intx.qdata, intx.scale, intx.zero_point, _BITS) + + packed_slot = P.make_or_get_constant(f"{weight_target}_q4k_packed", packed) + scales_slot = P.make_or_get_constant(f"{weight_target}_q4k_scales", intx.scale) + biases_slot = P.make_or_get_constant(f"{weight_target}_q4k_biases", biases) + return packed_slot, scales_slot, biases_slot, group_size From c23c9e4ad5e9195dc887e36746b95aa69f73718e Mon Sep 17 00:00:00 2001 From: uddeshsingh Date: Fri, 12 Jun 2026 14:37:31 -0500 Subject: [PATCH 3/5] Extract emit_if_else/emit_sub_int/emit_ceil_div helpers, fix output dtype handling, add legacy-path test coverage, and harden the embedding kernel. 
--- backends/mlx/builder/op_helpers.py | 106 +++++++++++++++++- .../custom_kernel_ops/gguf/q4k/embedding.py | 18 ++- .../mlx/custom_kernel_ops/gguf/q4k/linear.py | 100 ++++++----------- .../mlx/custom_kernel_ops/gguf/q6k/linear.py | 61 +++------- .../gguf/test/test_embedding.py | 27 ++++- .../gguf/test/test_linear.py | 64 ++++++++++- 6 files changed, 256 insertions(+), 120 deletions(-) diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py index 2f94a808adc..a1f1ff5747c 100644 --- a/backends/mlx/builder/op_helpers.py +++ b/backends/mlx/builder/op_helpers.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Dict, Optional, Tuple, TYPE_CHECKING, Union +from typing import Callable, Dict, Optional, Tuple, TYPE_CHECKING, Union import torch from executorch.backends.mlx.builder.slot_manager import Slot @@ -285,6 +285,110 @@ def emit_product( return P.to_int_or_vid(final_val) +def emit_add_int( + P: "MLXProgramBuilder", + a: "IntOrVid", + b: "IntOrVid", +) -> "IntOrVid": + """Emit ``a + b``, folding to a literal when both operands are static.""" + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddIntNode, + IntOrVid, + ) + + if not a.is_vid and not b.is_vid: + return IntOrVid.from_literal(a.literal + b.literal) + + _, out_slot = P.make_tmp_value_slot() + P.emit(AddIntNode(a=a, b=b, out=P.slot_to_vid(out_slot))) + return P.to_int_or_vid(out_slot) + + +def emit_sub_int( + P: "MLXProgramBuilder", + a: "IntOrVid", + b: "IntOrVid", +) -> "IntOrVid": + """Emit ``a - b``, folding to a literal when both operands are static.""" + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + SubtractIntNode, + ) + + if not a.is_vid and not b.is_vid: + return IntOrVid.from_literal(a.literal - b.literal) + + _, out_slot = P.make_tmp_value_slot() + P.emit(SubtractIntNode(a=a, b=b, out=P.slot_to_vid(out_slot))) + return P.to_int_or_vid(out_slot) + + +def emit_ceil_div( + P: 
"MLXProgramBuilder", + a: "IntOrVid", + b: int, +) -> "IntOrVid": + """Emit ``ceil(a / b)``, folding to a literal when ``a`` is static. + + Computes ``(a + b - 1) // b``. + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + FloorDivideIntNode, + IntOrVid, + ) + + if not a.is_vid: + return IntOrVid.from_literal((a.literal + b - 1) // b) + + sum_iov = emit_add_int(P, a, IntOrVid.from_literal(b - 1)) + _, out_slot = P.make_tmp_value_slot() + P.emit( + FloorDivideIntNode( + a=sum_iov, + b=IntOrVid.from_literal(b), + out=P.slot_to_vid(out_slot), + ) + ) + return P.to_int_or_vid(out_slot) + + +def emit_if_else( + P: "MLXProgramBuilder", + cond: "IntOrVid", + emit_then: Callable[[], None], + emit_else: Callable[[], None], +) -> None: + """Emit a branch on ``cond``: nonzero -> then, zero -> else. + + If ``cond`` is a literal, no IfNode or chains are emitted; the + selected callback is invoked directly in the current chain. + Otherwise both callbacks are emitted into fresh chains and an + ``IfNode`` selects between them at runtime. Nodes emitted inside a + callback land in that branch's chain only. 
+ """ + from executorch.backends.mlx.serialization.mlx_graph_schema import IfNode + + if not cond.is_vid: + if cond.literal: + emit_then() + else: + emit_else() + return + + with P.new_chain() as then_idx: + emit_then() + with P.new_chain() as else_idx: + emit_else() + + P.emit( + IfNode( + cond=cond, + then_chain_idx=then_idx, + else_chain_idx=else_idx, + ) + ) + + def emit_quantized_biases( P: "MLXProgramBuilder", zero_point_key: str, diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py index 35ee1db4242..8a547268a84 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py @@ -45,6 +45,10 @@ const uint j = thread_position_in_grid.x; // 0..K-1 const uint r = thread_position_in_grid.y; // gathered row const int row = (int) indices[r]; + if (row < 0 || row >= V) { + out[r * (uint)K + j] = (OutT)0; + return; + } const int nb = K / QK_K; device const block_q4_K * blk = ((device const block_q4_K *) weight) + (uint)row * nb + (j / QK_K); @@ -72,8 +76,9 @@ def emit_embedding( f"gguf q4k embedding: weight must be 2-D (vocab, row_bytes); got " f"shape {tuple(weight_meta.shape)}" ) + vocab = weight_meta.shape[0] row_bytes = weight_meta.shape[1] - if not isinstance(row_bytes, int): + if not isinstance(vocab, int) or not isinstance(row_bytes, int): raise NotImplementedError( "gguf q4k embedding: weight shape must be statically known" ) @@ -91,8 +96,9 @@ def emit_embedding( num_idx_iov = emit_product(P, leading) out_shape_flat = leading + [IntOrVid.from_literal(K)] - # threadgroup.x must divide grid.x (= K, a multiple of 256). 
- tg_x = 256 if K % 256 == 0 else K + if K % QK_K != 0: + raise AssertionError(f"gguf q4k embedding: K={K} must be divisible by {QK_K}") + tg_x = QK_K P.emit( MetalKernelNode( @@ -112,9 +118,9 @@ def emit_embedding( output_shapes_flat=out_shape_flat, output_shape_lengths=[len(out_shape_flat)], output_dtypes=[out_dtype_int], - template_arg_names=["OutT", "K"], - template_arg_kinds=[2, 0], # dtype, int - template_arg_values=[out_dtype_int, K], + template_arg_names=["OutT", "K", "V"], + template_arg_kinds=[2, 0, 0], # dtype, int, int + template_arg_values=[out_dtype_int, K, vocab], ) ) diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py index 0078b029255..77cd0ec710d 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py @@ -17,8 +17,11 @@ from typing import Optional from executorch.backends.mlx.builder.op_helpers import ( + emit_ceil_div, + emit_if_else, emit_product, emit_shape, + emit_sub_int, torch_dtype_to_scalar_type, ) from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder @@ -29,13 +32,9 @@ QK_K, ) from executorch.backends.mlx.serialization.mlx_graph_schema import ( - AddIntNode, - FloorDivideIntNode, - IfNode, IntOrVid, MetalKernelNode, MultiplyIntNode, - SubtractIntNode, ) from torch.fx.node import Node @@ -49,7 +48,7 @@ # Threadgroup = (32 * NSG, 1, 1): NSG simdgroups, each computing N_R0 output # rows for one activation row (grid.y). Accumulate in float, reduce via simd_sum. def _q4k_matvec_source(has_bias: bool) -> str: - write = "out[(uint)m * N + r] = (InT)(tot" + write = "out[(uint)m * N + r] = (OutT)(tot" write += " + (float)bias[r]);" if has_bias else ");" return f""" constexpr short N_R0 = 2; @@ -144,14 +143,14 @@ def _q4k_matvec_source(has_bias: bool) -> str: # 64x32 output tiles, 4 simdgroups / 128 threads per threadgroup. 
# Uses vectorized dequantize_q4_K_16 to decode 16 weight values per thread # into threadgroup memory, then runs simdgroup_multiply_accumulate on 8x8 -# tiles. NL=8 for Q4_K (QK_K / 32 = 8 dequant steps per super-block). +# tiles. NL=16 for Q4_K (QK_K / 16 = 16 dequant steps per super-block). def _q4k_matmul_source(has_bias: bool) -> str: bias_add = "+ (float) bias[r0 + i]" if has_bias else "" return f""" constexpr short NR0 = 64; // weight/output rows per tile (N dim) constexpr short NR1 = 32; // activation rows per tile (M dim) constexpr short NK = 32; // K-chunk per iteration - constexpr short NL = 16; // Q4_K: QK_K / 32 + constexpr short NL = 16; // Q4_K: QK_K / 16 constexpr short NL0 = NK / 16; // = 2 — dequant iterations per thread for weight constexpr short NL1 = NK / 8; // = 4 — load iterations per thread for activation @@ -181,7 +180,7 @@ def _q4k_matmul_source(has_bias: bool) -> str: short il0 = tid % NL0; short il = il0; // current dequant sub-block index within Q4_K block - const short offset1 = il0 / NL; // always 0 for NL=8, NL0=4 + const short offset1 = il0 / NL; // always 0 for NL=16, NL0=2 // Pointer to weight block for this thread's assigned row. device const block_q4_K * wblk = (device const block_q4_K *) weight @@ -255,7 +254,7 @@ def _q4k_matmul_source(has_bias: bool) -> str: }} }} - // --- Write results: always via threadgroup memory for float→InT cast --- + // --- Write results: always via threadgroup memory for float→OutT cast --- // Barrier needed: sa was used for weight tiles during the K-loop and is now // reused as float staging for the output.
Without this barrier, a fast // simdgroup could start writing mc[] into sa while a slower one is still @@ -273,11 +272,11 @@ def _q4k_matmul_source(has_bias: bool) -> str: if (sgitg == 0) {{ for (int j = tid; j < nr1; j += NR1) {{ - device InT * D = out + (uint)(r1 + j) * (uint)N + r0; + device OutT * D = out + (uint)(r1 + j) * (uint)N + r0; threadgroup float * Cp = ((threadgroup float *) sa) + j * NR0; for (int i = 0; i < nr0; ++i) {{ float v = Cp[i]; - D[i] = (InT)(v {bias_add}); + D[i] = (OutT)(v {bias_add}); }} }} }} @@ -300,6 +299,7 @@ def _emit_q4k_matvec( bias_slot: Optional[Slot], N: int, K: int, + out_dtype_int: int, out: Slot, ) -> None: in_dtype_int = torch_dtype_to_scalar_type(x_node.meta["val"].dtype) @@ -341,10 +341,10 @@ def _emit_q4k_matvec( output_names=["out"], output_shapes_flat=out_shape_flat, output_shape_lengths=[len(out_shape_flat)], - output_dtypes=[in_dtype_int], - template_arg_names=["InT", "N", "K", "NSG"], - template_arg_kinds=[2, 0, 0, 0], # dtype, int, int, int - template_arg_values=[in_dtype_int, N, K, nsg], + output_dtypes=[out_dtype_int], + template_arg_names=["InT", "OutT", "N", "K", "NSG"], + template_arg_kinds=[2, 2, 0, 0, 0], # dtype, dtype, int, int, int + template_arg_values=[in_dtype_int, out_dtype_int, N, K, nsg], ) ) @@ -358,6 +358,7 @@ def _emit_q4k_matmul( N: int, K: int, blocks_m_iov: IntOrVid, + out_dtype_int: int, out: Slot, ) -> None: in_dtype_int = torch_dtype_to_scalar_type(x_node.meta["val"].dtype) @@ -408,10 +409,10 @@ def _emit_q4k_matmul( output_names=["out"], output_shapes_flat=out_shape_flat, output_shape_lengths=[len(out_shape_flat)], - output_dtypes=[in_dtype_int], - template_arg_names=["InT", "N", "K"], - template_arg_kinds=[2, 0, 0], - template_arg_values=[in_dtype_int, N, K], + output_dtypes=[out_dtype_int], + template_arg_names=["InT", "OutT", "N", "K"], + template_arg_kinds=[2, 2, 0, 0], + template_arg_values=[in_dtype_int, out_dtype_int, N, K], ) ) @@ -462,10 +463,13 @@ def emit_linear( break out = 
P.make_or_get_slot(head) + out_dtype_int = torch_dtype_to_scalar_type(head.meta["val"].dtype) tile = _Q4K_MM_NR1 # M-dimension tile (activation rows per threadgroup) if M == 1: # Static decode -> mat-vec. - _emit_q4k_matvec(P, x_node, x_slot, weight_slot, bias_slot, N, K, out) + _emit_q4k_matvec( + P, x_node, x_slot, weight_slot, bias_slot, N, K, out_dtype_int, out + ) elif M is not None: # Static prefill -> tiled simdgroup mat-mat (literal grid). blocks_m = (M + tile - 1) // tile @@ -478,46 +482,21 @@ def emit_linear( N, K, IntOrVid.from_literal(blocks_m), + out_dtype_int, out, ) else: # Dynamic seqlen -> emit both kernels in separate chains and select at # runtime with an IfNode. cond = M - 1: nonzero (M>1) runs the mat-mat # (then) chain, zero (M==1) runs the mat-vec (else) chain. - leading = emit_shape(P, x_node, x_slot, end_dim=-1) - m_iov = emit_product(P, leading) - - _, cond_slot = P.make_tmp_value_slot() - P.emit( - SubtractIntNode( - a=m_iov, - b=IntOrVid.from_literal(1), - out=P.slot_to_vid(cond_slot), - ) - ) - cond_iov = IntOrVid.from_vid(P.slot_to_vid(cond_slot)) - - # blocks_m = (M + tile - 1) // tile (mat-mat grid.y). 
- _, sum_slot = P.make_tmp_value_slot() - P.emit( - AddIntNode( - a=m_iov, - b=IntOrVid.from_literal(tile - 1), - out=P.slot_to_vid(sum_slot), - ) - ) - _, blocks_m_slot = P.make_tmp_value_slot() - P.emit( - FloorDivideIntNode( - a=IntOrVid.from_vid(P.slot_to_vid(sum_slot)), - b=IntOrVid.from_literal(tile), - out=P.slot_to_vid(blocks_m_slot), - ) - ) - blocks_m_iov = IntOrVid.from_vid(P.slot_to_vid(blocks_m_slot)) + m_iov = emit_product(P, emit_shape(P, x_node, x_slot, end_dim=-1)) + cond_iov = emit_sub_int(P, m_iov, IntOrVid.from_literal(1)) + blocks_m_iov = emit_ceil_div(P, m_iov, tile) - with P.new_chain() as then_idx: # prefill / mat-mat - _emit_q4k_matmul( + emit_if_else( + P, + cond_iov, + emit_then=lambda: _emit_q4k_matmul( P, x_node, x_slot, @@ -526,16 +505,11 @@ def emit_linear( N, K, blocks_m_iov, + out_dtype_int, out, - ) - with P.new_chain() as else_idx: # decode / mat-vec - _emit_q4k_matvec(P, x_node, x_slot, weight_slot, bias_slot, N, K, out) - - P.emit( - IfNode( - cond=cond_iov, - then_chain_idx=then_idx, - else_chain_idx=else_idx, - ) + ), + emit_else=lambda: _emit_q4k_matvec( + P, x_node, x_slot, weight_slot, bias_slot, N, K, out_dtype_int, out + ), ) return out diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py index 99a82053e90..9e63a6a3bd3 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py @@ -45,8 +45,11 @@ from typing import Optional from executorch.backends.mlx.builder.op_helpers import ( + emit_ceil_div, + emit_if_else, emit_product, emit_shape, + emit_sub_int, torch_dtype_to_scalar_type, ) from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder @@ -57,13 +60,9 @@ QK_K, ) from executorch.backends.mlx.serialization.mlx_graph_schema import ( - AddIntNode, - FloorDivideIntNode, - IfNode, IntOrVid, MetalKernelNode, MultiplyIntNode, - SubtractIntNode, ) from torch.fx.node import Node @@ -492,40 
+491,14 @@ def emit_linear( # Dynamic seqlen -> emit both kernels in separate chains and select at # runtime with an IfNode. cond = M - 1: nonzero (M>1) runs the mat-mat # (then) chain, zero (M==1) runs the mat-vec (else) chain. - leading = emit_shape(P, x_node, x_slot, end_dim=-1) - m_iov = emit_product(P, leading) - - _, cond_slot = P.make_tmp_value_slot() - P.emit( - SubtractIntNode( - a=m_iov, - b=IntOrVid.from_literal(1), - out=P.slot_to_vid(cond_slot), - ) - ) - cond_iov = IntOrVid.from_vid(P.slot_to_vid(cond_slot)) - - # blocks_m = (M + tile - 1) // tile (mat-mat grid.y). - _, sum_slot = P.make_tmp_value_slot() - P.emit( - AddIntNode( - a=m_iov, - b=IntOrVid.from_literal(tile - 1), - out=P.slot_to_vid(sum_slot), - ) - ) - _, blocks_m_slot = P.make_tmp_value_slot() - P.emit( - FloorDivideIntNode( - a=IntOrVid.from_vid(P.slot_to_vid(sum_slot)), - b=IntOrVid.from_literal(tile), - out=P.slot_to_vid(blocks_m_slot), - ) - ) - blocks_m_iov = IntOrVid.from_vid(P.slot_to_vid(blocks_m_slot)) + m_iov = emit_product(P, emit_shape(P, x_node, x_slot, end_dim=-1)) + cond_iov = emit_sub_int(P, m_iov, IntOrVid.from_literal(1)) + blocks_m_iov = emit_ceil_div(P, m_iov, tile) - with P.new_chain() as then_idx: # prefill / mat-mat - _emit_q6k_matmul( + emit_if_else( + P, + cond_iov, + emit_then=lambda: _emit_q6k_matmul( P, x_node, x_slot, @@ -535,15 +508,9 @@ def emit_linear( K, blocks_m_iov, out, - ) - with P.new_chain() as else_idx: # decode / mat-vec - _emit_q6k_matvec(P, x_node, x_slot, weight_slot, bias_slot, N, K, out) - - P.emit( - IfNode( - cond=cond_iov, - then_chain_idx=then_idx, - else_chain_idx=else_idx, - ) + ), + emit_else=lambda: _emit_q6k_matvec( + P, x_node, x_slot, weight_slot, bias_slot, N, K, out + ), ) return out diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py index 12acd4ebc75..d86a9a55d37 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py +++ 
b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py @@ -27,6 +27,8 @@ import torch import torch.nn as nn from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( + _emit_direct_gguf_env, + _q4k_mlx_native_dequant, make_q4_k_blob, make_q6_k_blob, ) @@ -63,13 +65,18 @@ def __init__( K: int = 256, idx_shape: Tuple[int, ...] = (8,), ggml_type: str = "q6_k", + emit_direct_gguf: bool = True, ): self.vocab = vocab self.K = K self.idx_shape = idx_shape self.ggml_type = ggml_type + self.emit_direct_gguf = emit_direct_gguf shp = "x".join(str(d) for d in idx_shape) - self.name = f"gguf_embedding_{ggml_type}_v{vocab}_k{K}_idx{shp}" + tag = f"gguf_embedding_{ggml_type}_v{vocab}_k{K}_idx{shp}" + if ggml_type == "q4_k" and not emit_direct_gguf: + tag += "_mlx_native" + self.name = tag @classmethod def get_test_configs(cls) -> List["GGUFEmbeddingTest"]: @@ -89,8 +96,19 @@ def get_test_configs(cls) -> List["GGUFEmbeddingTest"]: cls(vocab=512, K=512, idx_shape=(8,), ggml_type="q4_k"), cls(vocab=512, K=256, idx_shape=(2, 3), ggml_type="q4_k"), cls(vocab=2048, K=5376, idx_shape=(8,), ggml_type="q4_k"), + cls( + vocab=512, + K=256, + idx_shape=(8,), + ggml_type="q4_k", + emit_direct_gguf=False, + ), ] + def generate_test_files(self, verbose: bool = False): + with _emit_direct_gguf_env(self.emit_direct_gguf): + return super().generate_test_files(verbose=verbose) + def get_edge_compile_config(self): from executorch.exir import EdgeCompileConfig @@ -105,6 +123,13 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: indices = torch.randint(0, self.vocab, self.idx_shape, dtype=torch.int64) return (indices,) + def compute_expected_outputs(self, model, test_inputs): + if self.ggml_type == "q4_k" and not self.emit_direct_gguf: + weight = _q4k_mlx_native_dequant(model.weight) + out = torch.nn.functional.embedding(test_inputs[0], weight) + return [out.to(model.weight.orig_dtype)] + return model(*test_inputs) + def _main() -> None: # noqa: C901 import argparse 
diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py index 55094d4acd1..68f86eec289 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py @@ -28,6 +28,8 @@ python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear list """ +import os +from contextlib import contextmanager from typing import List, Tuple # Importing the patterns module registers GGUF_QUANTIZED_LINEAR / _EMBEDDING. @@ -95,6 +97,19 @@ def make_q4_k_blob(N: int, K: int, seed: int = 0) -> torch.Tensor: _BLOB_MAKERS = {"q6_k": make_q6_k_blob, "q4_k": make_q4_k_blob} +@contextmanager +def _emit_direct_gguf_env(enabled: bool): + old = os.environ.get("ET_MLX_EMIT_DIRECT_GGUF") + os.environ["ET_MLX_EMIT_DIRECT_GGUF"] = "1" if enabled else "0" + try: + yield + finally: + if old is None: + os.environ.pop("ET_MLX_EMIT_DIRECT_GGUF", None) + else: + os.environ["ET_MLX_EMIT_DIRECT_GGUF"] = old + + def _make_gguf_linear_model( N: int, K: int, @@ -130,8 +145,9 @@ def _fp32_linear_reference(model: "GGUFLinearModel", x: torch.Tensor): a bf16 eager matmul is too noisy an oracle over large K. Dequantize in fp32, matmul in fp32, then cast back -- differences collapse to ~1 output ULP. - Both Q6_K and Q4_K kernels dequantize the raw GGUF blob in-kernel; use the - gguf-exact dequant as the reference oracle. + Direct Q6_K / Q4_K kernels dequantize the raw GGUF blob in-kernel; use the + gguf-exact dequant as the reference oracle. Legacy Q4_K tests override this + to match the export-time MLX qparam repack path. 
""" lin = model.linear weight = lin.weight @@ -141,6 +157,30 @@ def _fp32_linear_reference(model: "GGUFLinearModel", x: torch.Tensor): return [out.to(x.dtype)] +def _q4k_mlx_native_dequant(weight) -> torch.Tensor: + from executorch.backends.mlx.builder.op_helpers import to_mlx_qparams + + intx = weight.to_intx_unpacked_to_int8_tensor() + group_size = int(intx.block_size[-1]) + packed, biases = to_mlx_qparams(intx.qdata, intx.scale, intx.zero_point, 4) + packed_bytes = packed.view(torch.uint8) + nibbles = torch.stack( + [(packed_bytes & 0xF).float(), ((packed_bytes >> 4) & 0xF).float()], dim=-1 + ) + q_unsigned = nibbles.reshape(intx.qdata.shape[0], -1) + scale = intx.scale.float().repeat_interleave(group_size, dim=1) + bias = biases.float().repeat_interleave(group_size, dim=1) + return scale * q_unsigned + bias + + +def _fp32_linear_mlx_native_reference(model: "GGUFLinearModel", x: torch.Tensor): + lin = model.linear + w = _q4k_mlx_native_dequant(lin.weight) + bias = lin.bias.float() if lin.bias is not None else None + out = torch.nn.functional.linear(x.float(), w, bias) + return [out.to(x.dtype)] + + _DTYPE_TOL = { torch.bfloat16: (2e-2, 2e-2), # The mat-mat (prefill) kernel stores tiles in half precision (as in @@ -169,6 +209,7 @@ def __init__( dtype: torch.dtype = torch.bfloat16, bias: bool = True, ggml_type: str = "q6_k", + emit_direct_gguf: bool = True, ): self.M = M self.N = N @@ -176,8 +217,11 @@ def __init__( self.dtype = dtype self.bias = bias self.ggml_type = ggml_type + self.emit_direct_gguf = emit_direct_gguf self.rtol, self.atol = _DTYPE_TOL[dtype] tag = f"gguf_linear_{ggml_type}_m{M}_n{N}_k{K}_{_DTYPE_TAG[dtype]}" + if ggml_type == "q4_k" and not emit_direct_gguf: + tag += "_mlx_native" self.name = tag if bias else tag + "_nobias" @classmethod @@ -212,8 +256,22 @@ def get_test_configs(cls) -> List["GGUFLinearTest"]: cfgs.append( cls(M=1, N=512, K=512, dtype=torch.bfloat16, bias=False, ggml_type="q4_k") ) + cfgs.append( + cls( + M=1, + N=512, + K=512, 
+ dtype=torch.bfloat16, + ggml_type="q4_k", + emit_direct_gguf=False, + ) + ) return cfgs + def generate_test_files(self, verbose: bool = False): + with _emit_direct_gguf_env(self.emit_direct_gguf): + return super().generate_test_files(verbose=verbose) + def get_edge_compile_config(self): return _edge_compile_config() @@ -229,6 +287,8 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: return (torch.randn(self.M, self.K, dtype=self.dtype),) def compute_expected_outputs(self, model, test_inputs): + if self.ggml_type == "q4_k" and not self.emit_direct_gguf: + return _fp32_linear_mlx_native_reference(model, test_inputs[0]) return _fp32_linear_reference(model, test_inputs[0]) From bbfe5d6659047367cd1daf7903a0152e922f5170 Mon Sep 17 00:00:00 2001 From: uddeshsingh Date: Fri, 12 Jun 2026 14:45:37 -0500 Subject: [PATCH 4/5] Move Q4_K env-var dispatch into emit_linear/emit_embedding so patterns stays unchanged. --- .../mlx/custom_kernel_ops/gguf/patterns.py | 21 +++++---------- .../mlx/custom_kernel_ops/gguf/q4k/common.py | 2 -- .../custom_kernel_ops/gguf/q4k/embedding.py | 27 ++++++++++++++++++- .../mlx/custom_kernel_ops/gguf/q4k/linear.py | 23 ++++++++++++++-- 4 files changed, 53 insertions(+), 20 deletions(-) diff --git a/backends/mlx/custom_kernel_ops/gguf/patterns.py b/backends/mlx/custom_kernel_ops/gguf/patterns.py index a8652796bfb..129284c5509 100644 --- a/backends/mlx/custom_kernel_ops/gguf/patterns.py +++ b/backends/mlx/custom_kernel_ops/gguf/patterns.py @@ -114,18 +114,9 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: emit_linear, ) else: # q4_k - from executorch.backends.mlx.custom_kernel_ops.gguf.q4k import ( - emit_direct_gguf, + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import ( + emit_linear, ) - - if emit_direct_gguf(): - from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import ( - emit_linear, - ) - else: - from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear_mlx_native import ( - emit_linear, 
- ) return emit_linear(P, n, x_node, self.weight, bias_node) @@ -177,8 +168,8 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import ( emit_embedding, ) - else: - from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding_mlx_native import ( - emit_embedding, - ) + else: # q4_k + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import ( + emit_embedding, + ) return emit_embedding(P, n, self.weight, indices_node, self.output_dtype) diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/common.py b/backends/mlx/custom_kernel_ops/gguf/q4k/common.py index 4a234359360..a20261a301b 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/common.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/common.py @@ -51,8 +51,6 @@ _Q4K_D_BYTES + _Q4K_DMIN_BYTES + _Q4K_SCALES_BYTES + _Q4K_QS_BYTES ) # 144 -# Q4_K mat-mat uses NL = QK_K / 32 (8 sub-blocks of 32 elements). -Q4K_NL = QK_K // 32 # 8 # --------------------------------------------------------------------------- # Shared Metal header diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py index 8a547268a84..2b0d831d8bd 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py @@ -56,7 +56,7 @@ """ -def emit_embedding( +def _emit_embedding_fused( P: MLXProgramBuilder, head: Node, weight_node: Node, @@ -125,3 +125,28 @@ def emit_embedding( ) return out + + + +def emit_embedding( + P: MLXProgramBuilder, + head: Node, + weight_node: Node, + indices_node: Node, + output_dtype: torch.dtype, +) -> Slot: + """Dispatch to fused Metal gather or the legacy MLX-native repack path.""" + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k import emit_direct_gguf + + if emit_direct_gguf(): + return _emit_embedding_fused( + P, head, weight_node, indices_node, output_dtype + ) + + from 
executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding_mlx_native import ( + emit_embedding as emit_embedding_mlx_native, + ) + + return emit_embedding_mlx_native( + P, head, weight_node, indices_node, output_dtype + ) diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py index 77cd0ec710d..4c098e9b4d8 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py @@ -180,7 +180,7 @@ def _q4k_matmul_source(has_bias: bool) -> str: short il0 = tid % NL0; short il = il0; // current dequant sub-block index within Q4_K block - const short offset1 = il0 / NL; // always 0 for NL=8, NL0=2 + const short offset1 = il0 / NL; // always 0 (il0 < NL0=2, NL=16) // Pointer to weight block for this thread's assigned row. device const block_q4_K * wblk = (device const block_q4_K *) weight @@ -417,7 +417,7 @@ def _emit_q4k_matmul( ) -def emit_linear( +def _emit_linear_fused( P: MLXProgramBuilder, head: Node, x_node: Node, @@ -513,3 +513,22 @@ def emit_linear( ), ) return out + +def emit_linear( + P: MLXProgramBuilder, + head: Node, + x_node: Node, + weight_node: Node, + bias_node: Optional[Node], +) -> Slot: + """Dispatch to fused Metal kernels or the legacy MLX-native repack path.""" + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k import emit_direct_gguf + + if emit_direct_gguf(): + return _emit_linear_fused(P, head, x_node, weight_node, bias_node) + + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear_mlx_native import ( + emit_linear as emit_linear_mlx_native, + ) + + return emit_linear_mlx_native(P, head, x_node, weight_node, bias_node) From e6ebd600acd5eb38812c6572386dd1dea24aa82a Mon Sep 17 00:00:00 2001 From: uddeshsingh Date: Fri, 12 Jun 2026 14:56:27 -0500 Subject: [PATCH 5/5] Move Q4_K embedding env-var dispatch fully into emit_embedding. 
--- backends/mlx/custom_kernel_ops/gguf/patterns.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/backends/mlx/custom_kernel_ops/gguf/patterns.py b/backends/mlx/custom_kernel_ops/gguf/patterns.py index 129284c5509..308d6bd00aa 100644 --- a/backends/mlx/custom_kernel_ops/gguf/patterns.py +++ b/backends/mlx/custom_kernel_ops/gguf/patterns.py @@ -160,16 +160,7 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: emit_embedding, ) else: # q4_k - from executorch.backends.mlx.custom_kernel_ops.gguf.q4k import ( - emit_direct_gguf, - ) - - if emit_direct_gguf(): - from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import ( - emit_embedding, - ) - else: # q4_k - from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import ( + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import ( emit_embedding, ) return emit_embedding(P, n, self.weight, indices_node, self.output_dtype)