Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion backends/mlx/builder/op_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from __future__ import annotations

from typing import Dict, Optional, Tuple, TYPE_CHECKING, Union
from typing import Callable, Dict, Optional, Tuple, TYPE_CHECKING, Union

import torch
from executorch.backends.mlx.builder.slot_manager import Slot
Expand Down Expand Up @@ -285,6 +285,110 @@ def emit_product(
return P.to_int_or_vid(final_val)


def emit_add_int(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is this used?

P: "MLXProgramBuilder",
a: "IntOrVid",
b: "IntOrVid",
) -> "IntOrVid":
"""Emit ``a + b``, folding to a literal when both operands are static."""
from executorch.backends.mlx.serialization.mlx_graph_schema import (
AddIntNode,
IntOrVid,
)

if not a.is_vid and not b.is_vid:
return IntOrVid.from_literal(a.literal + b.literal)

_, out_slot = P.make_tmp_value_slot()
P.emit(AddIntNode(a=a, b=b, out=P.slot_to_vid(out_slot)))
return P.to_int_or_vid(out_slot)


def emit_sub_int(
P: "MLXProgramBuilder",
a: "IntOrVid",
b: "IntOrVid",
) -> "IntOrVid":
"""Emit ``a - b``, folding to a literal when both operands are static."""
from executorch.backends.mlx.serialization.mlx_graph_schema import (
IntOrVid,
SubtractIntNode,
)

if not a.is_vid and not b.is_vid:
return IntOrVid.from_literal(a.literal - b.literal)

_, out_slot = P.make_tmp_value_slot()
P.emit(SubtractIntNode(a=a, b=b, out=P.slot_to_vid(out_slot)))
return P.to_int_or_vid(out_slot)


def emit_ceil_div(
P: "MLXProgramBuilder",
a: "IntOrVid",
b: int,
) -> "IntOrVid":
"""Emit ``ceil(a / b)``, folding to a literal when ``a`` is static.

Computes ``(a + b - 1) // b``.
"""
from executorch.backends.mlx.serialization.mlx_graph_schema import (
FloorDivideIntNode,
IntOrVid,
)

if not a.is_vid:
return IntOrVid.from_literal((a.literal + b - 1) // b)

sum_iov = emit_add_int(P, a, IntOrVid.from_literal(b - 1))
_, out_slot = P.make_tmp_value_slot()
P.emit(
FloorDivideIntNode(
a=sum_iov,
b=IntOrVid.from_literal(b),
out=P.slot_to_vid(out_slot),
)
)
return P.to_int_or_vid(out_slot)


def emit_if_else(
P: "MLXProgramBuilder",
cond: "IntOrVid",
emit_then: Callable[[], None],
emit_else: Callable[[], None],
) -> None:
"""Emit a branch on ``cond``: nonzero -> then, zero -> else.

If ``cond`` is a literal, no IfNode or chains are emitted; the
selected callback is invoked directly in the current chain.
Otherwise both callbacks are emitted into fresh chains and an
``IfNode`` selects between them at runtime. Nodes emitted inside a
callback land in that branch's chain only.
"""
from executorch.backends.mlx.serialization.mlx_graph_schema import IfNode

if not cond.is_vid:
if cond.literal:
emit_then()
else:
emit_else()
return

with P.new_chain() as then_idx:
emit_then()
with P.new_chain() as else_idx:
emit_else()

P.emit(
IfNode(
cond=cond,
then_chain_idx=then_idx,
else_chain_idx=else_idx,
)
)


def emit_quantized_biases(
P: "MLXProgramBuilder",
zero_point_key: str,
Expand Down
15 changes: 7 additions & 8 deletions backends/mlx/custom_kernel_ops/gguf/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
lower it without materializing the dequantized weight:

* **Q6_K** -> fused custom Metal kernels in :mod:`.q6k`.
* **Q4_K** -> MLX's native 4-bit affine ops via :mod:`.q4k` (GGUF blocks
repacked into MLX qparams at export time).
* **Q4_K** -> fused custom Metal kernels in :mod:`.q4k` (default), or the legacy
MLX-native repack path when ``ET_MLX_EMIT_DIRECT_GGUF=0``.

Both cover linear and embedding.

Expand All @@ -42,8 +42,7 @@
from torch.export.exported_program import ExportedProgram
from torch.fx.node import Node

# Quant types each pattern can lower (Q6_K via custom Metal kernels, Q4_K via
# MLX-native affine ops).
# Quant types each pattern can lower (both via fused custom Metal kernels).
_LINEAR_TYPES = {"q4_k", "q6_k"}
_EMBEDDING_TYPES = {"q4_k", "q6_k"}

Expand Down Expand Up @@ -79,8 +78,8 @@ class GGUFQuantizedLinearHandler(PatternHandler):
"""Lower ``dequantize_gguf + linear`` to a fused quantized matmul.

Matches ``linear(x, dequantize_gguf(weight, ggml_type, out_dtype), bias)``
and dispatches on ``ggml_type``: Q6_K -> custom Metal kernels, Q4_K -> MLX
4-bit ``quantized_matmul``.
and dispatches on ``ggml_type``: Q6_K / Q4_K -> custom Metal kernels in
:mod:`.q6k` / :mod:`.q4k`.
"""

def __init__(self, head, body, weight, ggml_type, output_dtype):
Expand Down Expand Up @@ -126,8 +125,8 @@ class GGUFQuantizedEmbeddingHandler(PatternHandler):
"""Lower ``dequantize_gguf + embedding`` to a quantized gather.

Matches ``embedding(dequantize_gguf(weight, ggml_type, out_dtype), indices)``
and dispatches on ``ggml_type``: Q6_K -> custom Metal gather, Q4_K -> MLX
quantized gather.
and dispatches on ``ggml_type``: Q6_K / Q4_K -> custom Metal gather kernels
in :mod:`.q6k` / :mod:`.q4k`.
"""

def __init__(self, head, body, weight, ggml_type, output_dtype):
Expand Down
14 changes: 13 additions & 1 deletion backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,21 @@
# LICENSE file in the root directory of this source tree.
#

"""GGUF Q4_K format lowering for the MLX backend (native affine 4-bit).
"""GGUF Q4_K format lowering for the MLX backend (fused Metal kernels).

See :mod:`.linear` / :mod:`.embedding` for the ``emit_*`` lowerings (called by
``custom_kernel_ops.gguf.patterns``); they are not imported here to keep the
package import light.

Set ``ET_MLX_EMIT_DIRECT_GGUF=0`` to use the legacy export-time repack path
(:mod:`.linear_mlx_native` / :mod:`.embedding_mlx_native`) instead.
"""

from __future__ import annotations

import os


def emit_direct_gguf() -> bool:
"""Return True to emit fused kernels that read raw GGUF bytes (default)."""
return os.environ.get("ET_MLX_EMIT_DIRECT_GGUF", "1") != "0"
148 changes: 122 additions & 26 deletions backends/mlx/custom_kernel_ops/gguf/q4k/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,137 @@
# LICENSE file in the root directory of this source tree.
#

"""Shared Q4_K -> MLX qparam repack for the Q4_K lowering.
"""Shared GGUF **Q4_K** primitives for the MLX backend.

Q4_K maps cleanly onto MLX's affine 4-bit kernels (group_size 32): the GGUF
blocks are unpacked to the torchao ``IntxUnpackedToInt8Tensor`` layout and
repacked into MLX qparams (``S * Q + B``) at export time, so the weight is
stored MLX-ready and decoded by MLX itself.
This module holds the pieces common to every Q4_K kernel (linear matmul/matvec
and the embedding gather):

* ``QK_K`` / ``Q4K_BLOCK_BYTES`` and the per-super-block byte layout constants.
* ``_Q4K_HEADER`` -- the Metal header (the ``block_q4_K`` struct plus the
per-element and vectorized dequant helpers) shared by all Q4_K Metal kernels.

Q4_K layout (per 256-element super-block, 144 bytes, see llama.cpp
``block_q4_K`` in ``ggml-common.h``)::

half d # super-block scale for quantized scales
half dmin # super-block scale for quantized mins
uint8 scales[12] # 6-bit packed scales + mins
uint8 qs[128] # 4-bit quants

The dequantized value for a 4-bit code ``q`` in sub-block ``s`` is
``d * scale[s] * q - dmin * min[s]`` (affine).

Attribution
-----------
The Q4_K block layout and the Metal dequant helpers in ``_Q4K_HEADER`` follow
llama.cpp
(``ggml-common.h`` / ``ggml-metal.metal``: ``block_q4_K``, ``dequantize_q4_K``,
``get_scale_min_k4``), which is MIT-licensed (Copyright (c) 2023-2024 The ggml
authors).
"""

from __future__ import annotations

from typing import Tuple
# ---------------------------------------------------------------------------
# Q4_K constants
# ---------------------------------------------------------------------------

QK_K = 256
K_SCALE_SIZE = 12
_Q4K_D_BYTES = 2
_Q4K_DMIN_BYTES = 2
_Q4K_SCALES_BYTES = K_SCALE_SIZE
_Q4K_QS_BYTES = QK_K // 2 # 128
Q4K_BLOCK_BYTES = (
_Q4K_D_BYTES + _Q4K_DMIN_BYTES + _Q4K_SCALES_BYTES + _Q4K_QS_BYTES
) # 144


# ---------------------------------------------------------------------------
# Shared Metal header
# ---------------------------------------------------------------------------

from executorch.backends.mlx.builder.op_helpers import to_mlx_qparams
from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder
from executorch.backends.mlx.builder.slot_manager import Slot
from torch.fx.node import Node
# Ported from llama.cpp ggml-common.h (block_q4_K, get_scale_min_k4) and
# ggml-metal.metal (dequantize_q4_K). Struct field order matches GGUF bytes:
# d(0:2), dmin(2:4), scales(4:16), qs(16:144).
_Q4K_HEADER = """
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
using namespace metal;

_BITS = 4
#define QK_K 256
#define K_SCALE_SIZE 12

typedef struct {
half d;
half dmin;
uint8_t scales[K_SCALE_SIZE];
uint8_t qs[QK_K/2];
} block_q4_K;

def _repack_mlx(
P: MLXProgramBuilder, weight_node: Node
) -> Tuple[Slot, Slot, Slot, int]:
"""Unpack a raw Q4_K blob and repack into MLX qparam constants.
// Unpack 6-bit scale and min for sub-block index j (0..7).
// Ported from llama.cpp get_scale_min_k4 (ggml-quants.c).
inline void get_scale_min_k4(int j, device const uint8_t * q,
thread uint8_t & sc, thread uint8_t & m) {
if (j < 4) {
sc = q[j] & 63;
m = q[j + 4] & 63;
} else {
sc = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
}
}

Returns ``(packed_slot, scales_slot, biases_slot, group_size)``.
"""
from executorch.extension.llm.export.gguf import ExportableGGUFTensor
// Metal variant used by dequantize_q4_K_16 (matches ggml-metal.metal).
inline uchar2 get_scale_min_k4_just2(int j, int k, device const uint8_t * q) {
return j < 4
? uchar2{uint8_t(q[j + 0 + k] & 63), uint8_t(q[j + 4 + k] & 63)}
: uchar2{
uint8_t((q[j + 4 + k] & 0xF) | ((q[j - 4 + k] >> 6) << 4)),
uint8_t((q[j + 4 + k] >> 4) | ((q[j + 0 + k] >> 6) << 4))};
}

weight_target, raw = P.get_placeholder_target_and_tensor(weight_node)
intx = ExportableGGUFTensor.from_raw(raw, "q4_k").to_intx_unpacked_to_int8_tensor()
group_size = int(intx.block_size[-1])
packed, biases = to_mlx_qparams(intx.qdata, intx.scale, intx.zero_point, _BITS)
// Dequantize a single element at within-block position p (0..255).
// Mirrors dequantize_row_q4_K (ggml-quants.c): 64-element chunks, 32 lows
// then 32 highs per chunk, each half using its own scale/min pair.
inline float dequant_q4k_elem(device const block_q4_K * blk, int p) {
const int chunk = p >> 6; // 0..3 (64-element groups)
const int sub = p & 63; // 0..63 within chunk
const int q_idx = (chunk << 5) + (sub & 31);
device const uint8_t * q = blk->qs + q_idx;

packed_slot = P.make_or_get_constant(f"{weight_target}_q4k_packed", packed)
scales_slot = P.make_or_get_constant(f"{weight_target}_q4k_scales", intx.scale)
biases_slot = P.make_or_get_constant(f"{weight_target}_q4k_biases", biases)
return packed_slot, scales_slot, biases_slot, group_size
uint8_t sc, mn;
get_scale_min_k4((chunk << 1) + (sub >= 32 ? 1 : 0), blk->scales, sc, mn);

const float d = (float) blk->d;
const float dm = (float) blk->dmin;
const float dl = d * (float) sc;
const float ml = dm * (float) mn;

const uint8_t nib = (sub < 32) ? (q[0] & 0xF) : (q[0] >> 4);
return dl * (float) nib - ml;
}

// Vectorized Q4_K dequantize: decodes 16 values into half4x4.
// Ported from llama.cpp dequantize_q4_K (ggml-metal.metal).
// il ranges 0..7 (Q4_K uses NL=8, not Q6_K's NL=16).
inline void dequantize_q4_K_16(device const block_q4_K * xb, short il,
thread half4x4 & reg) {
device const uint8_t * q = xb->qs;

short is = (il / 4) * 2;
q = q + (il / 4) * 32 + 16 * (il & 1);
il = il & 3;

const uchar2 sc = get_scale_min_k4_just2(is, il / 2, xb->scales);
const float d = il < 2 ? (float) xb->d : (float) xb->d / 16.f;
const float dm = (float) xb->dmin;
const float dl = d * (float) sc[0];
const float ml = dm * (float) sc[1];

const ushort mask = il < 2 ? 0x0F : 0xF0;
for (int i = 0; i < 16; ++i) {
reg[i / 4][i % 4] = (half)(dl * (float)(q[i] & mask) - ml);
}
Comment on lines +137 to +140
}
"""
Loading
Loading